├── .idea
│   └── vcs.xml
├── README.md
├── cgan
│   ├── cgan.py
│   └── out
│       ├── 366b924f53bba0.svg
│       └── GAN output [Epoch 80].jpg
├── ct-gan
│   ├── ct_gan.py
│   ├── models.py
│   └── out
│       ├── 367444be15895e.svg
│       └── GAN output [Epoch 34].jpg
├── dcgan
│   ├── dcgan.py
│   └── out
│       ├── 366cebae21d91c.svg
│       ├── GAN output [Epoch 358].jpg
│       └── real_samples.png
├── gan
│   ├── gan.py
│   └── out
│       ├── 366b962ff767a0.svg
│       └── GAN output [Epoch 95].jpg
├── infogan
│   ├── infogan.py
│   └── out
│       ├── (Nothing is varied) GAN output [Epoch 14].jpg
│       ├── (Only vary c1) GAN output [Epoch 14].jpg
│       ├── (Only vary c2) GAN output [Epoch 14].jpg
│       └── 367d344b5db086.svg
├── lsgan
│   ├── lsgan.py
│   └── out
│       ├── 3677c4a9714c1c.svg
│       └── GAN output [Epoch 11].jpg
├── pix2pix
│   ├── datasets.py
│   ├── download_dataset.sh
│   ├── models.py
│   ├── out
│   │   ├── 3685e66f775102.svg
│   │   ├── 3685e7054e8534.svg
│   │   ├── GAN output [Epoch 216].jpg
│   │   └── GAN output [Epoch 240].jpg
│   └── pix2pix.py
├── ralsgan
│   ├── out
│   │   ├── 36844a331a7460.svg
│   │   ├── GAN output [Iteration 31616].png
│   │   └── GAN output [Iteration 31680].png
│   ├── preprocess_cat_dataset.py
│   ├── ralsgan.py
│   └── setting_up_script.sh
├── srgan
│   ├── datasets.py
│   ├── models.py
│   ├── out
│   │   ├── 36896236305c82.svg
│   │   ├── GAN output [Epoch 6] (1).jpg
│   │   └── GAN output [Epoch 6].jpg
│   └── srgan.py
├── wgan-gp
│   ├── models.py
│   ├── out
│   │   ├── 36737d5d4d1ebe.svg
│   │   ├── 3673cb5f59f630.svg
│   │   ├── GAN output [Epoch 105].jpg
│   │   └── GAN output [Epoch 52].jpg
│   └── wgan_gp.py
└── wgan
    ├── dataset_loader.py
    ├── out
    │   ├── bn_366d939cd31210.svg
    │   ├── bn_GAN output [Epoch 99].jpg
    │   ├── no_bn_366d9efab2beac.svg
    │   ├── no_bn_GAN output [Epoch 99] (1).jpg
    │   └── real_samples.png
    └── wgan.py
/cgan/cgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | Although I tried to implement this exactly as described in the paper,
3 | I couldn't get past problems such as mode collapse and a diverging generator with the described network.
4 | I wasn't able to stabilize training with SGD, and I don't understand why a
5 | maxout layer (already nonlinear) has to be followed by a ReLU layer. I also had to use batchnorm for stability (instead of dropout).
6 | The paper introduces the simple yet effective idea (especially when applied to more
7 | complex networks such as pix2pix, or any visual conditional generation system)
8 | of concatenating the condition (e.g., labels) to the input to get conditional outputs,
9 | yet I think it has fundamental flaws in its architecture design.
10 | '''
11 | import argparse
12 |
13 | import torch
14 | from torch import nn, optim
15 | from torch.autograd.variable import Variable
16 |
17 | from torchvision import transforms, datasets
18 | from torch.utils.data import DataLoader
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs of training')
22 | parser.add_argument('--batch_size', type=int, default=100, help='size of the batches')
23 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
24 | parser.add_argument('--b1', type=float, default=0.5, help='adam: beta 1')
25 | parser.add_argument('--b2', type=float, default=0.999, help='adam: beta 2')
26 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
27 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
28 | parser.add_argument('--channels', type=int, default=1, help='number of image channels')
29 | parser.add_argument('--n_classes', type=int, default=10, help='number of classes (e.g., digits 0..9: 10 classes for MNIST)')
30 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
31 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
32 | parser.add_argument('--sample_interval', type=int, default=512, help='interval between image samples')
33 | opt = parser.parse_args()
34 |
35 |
36 | try:
37 | import visdom
38 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
39 | except ImportError:
40 | vis = None
41 | else:
42 | vis.close(None) # Clear all figures.
43 |
44 |
45 | img_dims = (opt.channels, opt.img_size, opt.img_size)
46 | n_features = opt.channels * opt.img_size * opt.img_size
47 |
48 | def weights_init_xavier(m):
49 | classname = m.__class__.__name__
50 | if classname.find('Conv') != -1:
51 | nn.init.xavier_normal_(m.weight.data, gain=0.02)
52 | elif classname.find('Linear') != -1:
53 | nn.init.xavier_normal_(m.weight.data, gain=0.02)
54 | elif classname.find('BatchNorm1d') != -1:
55 | nn.init.normal_(m.weight.data, 1.0, 0.02)
56 | nn.init.constant_(m.bias.data, 0.0)
57 |
58 | class Generator(nn.Module):
59 | def __init__(self):
60 | super(Generator, self).__init__()
61 |
62 | # Map z & y (noise and label) into the hidden layer.
63 | # TODO: factor these three maps into a local helper function.
64 | self.z_map = nn.Sequential(
65 | nn.Linear(opt.latent_dim, 200),
66 | nn.BatchNorm1d(200),
67 | nn.ReLU(inplace=True),
68 | )
69 | self.y_map = nn.Sequential(
70 | nn.Linear(opt.n_classes, 1000),
71 | nn.BatchNorm1d(1000),
72 | nn.ReLU(inplace=True),
73 | )
74 | self.zy_map = nn.Sequential(
75 | nn.Linear(1200, 1200),
76 | nn.BatchNorm1d(1200),
77 | nn.ReLU(inplace=True),
78 | )
79 |
80 | self.model = nn.Sequential(
81 | nn.Linear(1200, n_features),
82 | nn.Tanh()
83 | )
84 | # Tanh > Image values are between [-1, 1]
85 |
86 |
87 | def forward(self, z, y):
88 | zh = self.z_map(z)
89 | yh = self.y_map(y)
90 | zy = torch.cat((zh, yh), dim=1) # Combine noise and labels.
91 | zyh = self.zy_map(zy)
92 | x = self.model(zyh)
93 | x = x.view(x.size(0), *img_dims)
94 | return x
95 |
96 | class Discriminator(nn.Module):
97 | def __init__(self):
98 | super(Discriminator, self).__init__()
99 |
100 | self.model = nn.Sequential(
101 | nn.Linear(240, 1),
102 | nn.Sigmoid()
103 | )
104 |
105 | # Emulate the paper's 3d maxout weight tensor by folding the pieces dimension into the linear output size (units * pieces).
106 | self.x_map = nn.Sequential(nn.Linear(n_features, 240 * 5))
107 | self.y_map = nn.Sequential(nn.Linear(opt.n_classes, 50 * 5))
108 | self.j_map = nn.Sequential(nn.Linear(240 + 50, 240 * 4))
109 |
110 | def forward(self, x, y):
111 | # maxout for x
112 | x = x.view(-1, n_features)
113 | x = self.x_map(x)
114 | x, _ = x.view(-1, 240, 5).max(dim=2) # pytorch outputs max values and indices
115 | # .. and y
116 | y = y.view(-1, opt.n_classes)
117 | y = self.y_map(y)
118 | y, _ = y.view(-1, 50, 5).max(dim=2)
119 | # joint maxout layer
120 | jmx = torch.cat((x, y), dim=1)
121 | jmx = self.j_map(jmx)
122 | jmx, _ = jmx.view(-1, 240, 4).max(dim=2)
123 |
124 | prob = self.model(jmx)
125 | return prob
126 |
127 |
128 | def mnist_data():
129 | compose = transforms.Compose(
130 | [transforms.ToTensor(),
131 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
132 | ])
133 | output_dir = './data/mnist'
134 | return datasets.MNIST(root=output_dir, train=True,
135 | transform=compose, download=True)
136 |
137 | mnist = mnist_data()
138 | batch_iterator = DataLoader(mnist, shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
139 |
140 | cuda = torch.cuda.is_available()
141 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
142 | LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
143 | gan_loss = nn.BCELoss()
144 |
145 | generator = Generator()
146 | discriminator = Discriminator()
147 |
148 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
149 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
150 |
151 |
152 | # Loss record.
153 | g_losses = []
154 | d_losses = []
155 | epochs = []
156 | loss_legend = ['Discriminator', 'Generator']
157 |
158 | if cuda:
159 | generator = generator.cuda()
160 | discriminator = discriminator.cuda()
161 |
162 | # Weight initialization.
163 | generator.apply(weights_init_xavier)
164 | discriminator.apply(weights_init_xavier)
165 |
166 |
167 | for epoch in range(opt.n_epochs):
168 | print('Epoch {}'.format(epoch))
169 | for i, (batch, labels) in enumerate(batch_iterator):
170 |
171 | real = Variable(Tensor(batch.size(0), 1).fill_(1), requires_grad=False)
172 | fake = Variable(Tensor(batch.size(0), 1).fill_(0), requires_grad=False)
173 |
174 | labels_onehot = Variable(Tensor(batch.size(0), opt.n_classes).zero_())
175 | labels_ = labels.type(LongTensor) # Note: MNIST labels come as a LongTensor; .type() moves them to CUDA when available.
176 | labels_ = labels_.view(batch.size(0), 1)
177 | labels_onehot = labels_onehot.scatter_(1, labels_, 1) # As of 0.4, we don't have one-hot built-in function yet in pytorch.
178 |
179 | imgs_real = Variable(batch.type(Tensor))
180 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1))
181 | imgs_fake = generator(noise, labels_onehot)
182 |
183 | # == Discriminator update == #
184 | optimizer_D.zero_grad()
185 |
186 | # A small reminder: given classes c, prob. p, - c*logp - (1-c)*log(1-p) is the BCE/GAN loss.
187 | # Putting D(x) as p, x=real's class as 1, (..and same for fake with c=0) we arrive to BCE loss.
188 | # This is intuitively how well the discriminator can distinguish btw real & fake.
189 | d_loss = gan_loss(discriminator(imgs_real, labels_onehot), real) + \
190 | gan_loss(discriminator(imgs_fake, labels_onehot), fake)
191 |
192 | d_loss.backward()
193 | optimizer_D.step()
194 |
195 |
196 | # == Generator update == #
197 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1))
198 | imgs_fake = generator(noise, labels_onehot)
199 |
200 | optimizer_G.zero_grad()
201 |
202 | # Minimizing (1-log(d(g(noise))) is less stable than maximizing log(d(g)) [*].
203 | # Since BCE loss is defined as a negative sum, maximizing [*] is == minimizing [*]'s negative.
204 | # Intuitively, how well does the G fool the D?
205 | g_loss = gan_loss(discriminator(imgs_fake, labels_onehot), real)
206 |
207 | g_loss.backward()
208 | optimizer_G.step()
209 |
210 | if vis:
211 | batches_done = epoch * len(batch_iterator) + i
212 | if batches_done % opt.sample_interval == 0:
213 |
214 | # Generate digits from 0, ... 9, 5 samples in each column.
215 | noise = Variable(Tensor(5*10, opt.latent_dim).normal_(0, 1))
216 |
217 | labels_onehot = Variable(Tensor(5*10, opt.n_classes).zero_())
218 | labels_ = torch.arange(0, 10) # torch.range is deprecated; arange excludes the endpoint.
219 | labels_ = labels_.view(1, -1).repeat(5, 1).transpose(0, 1).contiguous().view(1, -1) # my version of matlab's repelem
220 | labels_ = labels_.type(LongTensor)
221 | labels_ = labels_.view(-1, 1)
222 | labels_onehot = labels_onehot.scatter_(1, labels_, 1) # As of 0.4, we don't have one-hot built-in function yet in pytorch.
223 |
224 | imgs_fake = Variable(generator(noise, labels_onehot))
225 |
226 | # Keep a record of losses for plotting.
227 | epochs.append(epoch + i/len(batch_iterator))
228 | g_losses.append(g_loss.item())
229 | d_losses.append(d_loss.item())
230 |
231 | # Display results on visdom page.
232 | vis.line(
233 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
234 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
235 | opts={
236 | 'title': 'loss over time',
237 | 'legend': loss_legend,
238 | 'xlabel': 'epoch',
239 | 'ylabel': 'loss',
240 | 'width': 512,
241 | 'height': 512
242 | },
243 | win=1)
244 | vis.images(
245 | imgs_fake.data[:50],
246 | nrow=5, win=2,
247 | opts={
248 | 'title': 'GAN output [Epoch {}]'.format(epoch),
249 | 'width': 512,
250 | 'height': 512,
251 | }
252 | )
253 |
--------------------------------------------------------------------------------
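The view + max trick in the discriminator above is a compact way of writing a maxout unit. A minimal standalone sketch of the same idea (illustrative names, not part of the repo):

import torch
from torch import nn

class Maxout(nn.Module):
    # A maxout unit: k linear "pieces" per output unit; the max is taken over the pieces.
    def __init__(self, n_input, n_output, k=5):
        super().__init__()
        self.n_output, self.k = n_output, k
        self.linear = nn.Linear(n_input, n_output * k)

    def forward(self, x):
        # (N, n_output * k) -> (N, n_output, k) -> max over the pieces dimension.
        return self.linear(x).view(-1, self.n_output, self.k).max(dim=2)[0]

x = torch.randn(8, 784)
print(Maxout(784, 240)(x).shape)  # torch.Size([8, 240])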
/cgan/out/GAN output [Epoch 80].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/cgan/out/GAN output [Epoch 80].jpg
--------------------------------------------------------------------------------
/ct-gan/ct_gan.py:
--------------------------------------------------------------------------------
1 | ''' Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect '''
2 | import argparse
3 |
4 | import torch
5 | from torch import nn, optim
6 | from torch.autograd.variable import Variable
7 |
8 | from torchvision import transforms, datasets
9 | from torch.utils.data import DataLoader
10 |
11 | from models import Discriminator, Generator
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--n_epochs', type=int, default=1000, help='number of epochs of training')
15 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
16 | parser.add_argument('--gamma', type=float, default=0.0003, help='adam: learning rate')
17 | parser.add_argument('--b1', type=float, default=0.9, help='adam: beta 1')
18 | parser.add_argument('--b2', type=float, default=0.999, help='adam: beta 2')
19 | parser.add_argument('--n_critic', type=int, default=5, help='number of critic iterations per generator iteration')
20 | parser.add_argument('--lambda_1', type=int, default=10, help='gradient penalty coefficient')
21 | parser.add_argument('--lambda_2', type=int, default=2, help='consistency term coefficient')
22 | parser.add_argument('--M', type=float, default=0.2, help="margin M' of the consistency term")
23 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
24 | parser.add_argument('--latent_dim', type=int, default=128, help='number of input units to generate an image')
25 | parser.add_argument('--channels', type=int, default=1, help='number of image channels')
26 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
27 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
28 | parser.add_argument('--sample_interval', type=int, default=256, help='interval between image samples')
29 | opt = parser.parse_args()
30 |
31 | try:
32 | import visdom
33 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
34 | except ImportError:
35 | vis = None
36 | else:
37 | vis.close(None) # Clear all figures.
38 |
39 | def load_cifar10(img_size):
40 | compose = transforms.Compose(
41 | [transforms.Resize(img_size),
42 | transforms.ToTensor(),
43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
44 | ])
45 | output_dir = './data/cifar10'
46 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True,
47 | transform=compose)
48 | return cifar
49 |
50 | cifar = load_cifar10(opt.img_size)
51 | batch_iterator = DataLoader(cifar, shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
52 |
53 | cuda = torch.cuda.is_available()
54 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
55 | gan_loss = nn.BCELoss()
56 |
57 | generator = Generator()
58 | discriminator = Discriminator()
59 |
60 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.gamma, betas=(opt.b1, opt.b2))
61 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.gamma, betas=(opt.b1, opt.b2))
62 |
63 | # Loss record.
64 | g_losses = []
65 | d_losses = []
66 | epochs = []
67 | loss_legend = ['Discriminator', 'Generator']
68 |
69 | if cuda:
70 | generator = generator.cuda()
71 | discriminator = discriminator.cuda()
72 |
73 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN.
74 |
75 | for epoch in range(opt.n_epochs):
76 | print('Epoch {}'.format(epoch))
77 | for i, (batch, _) in enumerate(batch_iterator):
78 | # == Discriminator update == #
79 | for iter in range(opt.n_critic):
80 | # Sample real and fake images, using notation in paper.
81 | x = Variable(batch.type(Tensor))
82 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1))
83 | x_tilde = Variable(generator(noise), requires_grad=True)
84 |
85 | epsilon = Variable(Tensor(batch.size(0), 1, 1, 1).uniform_(0, 1))
86 |
87 | x_hat = epsilon*x + (1 - epsilon)*x_tilde
88 | x_hat = torch.autograd.Variable(x_hat, requires_grad=True)
89 |
90 | # Put the interpolated data through critic.
91 | dw_x = discriminator(x_hat)
92 | grad_x = torch.autograd.grad(outputs=dw_x, inputs=x_hat,
93 | grad_outputs=Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1.0), requires_grad=False),
94 | create_graph=True, retain_graph=True, only_inputs=True)
95 | grad_x = grad_x[0].view(batch.size(0), -1)
96 | grad_x = grad_x.norm(p=2, dim=1)
97 |
98 | # Update critic.
99 | optimizer_D.zero_grad()
100 |
101 | # Standard WGAN loss.
102 | d_wgan_loss = torch.mean(discriminator(x_tilde)) - torch.mean(discriminator(x))
103 |
104 | # WGAN-GP loss.
105 | d_wgp_loss = torch.mean((grad_x - 1)**2)
106 |
107 | ###### Consistency term. ######
108 | dw_x1, dw_x1_i = discriminator(x, dropout=0.5, intermediate_output=True) # Perturb the input by applying dropout to hidden layers.
109 | dw_x2, dw_x2_i = discriminator(x, dropout=0.5, intermediate_output=True)
110 | # Using l2 norm as the distance metric d, referring to the official code (paper ambiguous on d).
111 | second_to_last_reg = ((dw_x1_i-dw_x2_i) ** 2).mean(dim=1).unsqueeze_(1).unsqueeze_(2).unsqueeze_(3)
112 | d_wct_loss = (dw_x1-dw_x2) ** 2 \
113 | + 0.1 * second_to_last_reg \
114 | - opt.M
115 | d_wct_loss = torch.clamp(d_wct_loss, min=0).mean() # Element-wise max(0, .) of the consistency term, averaged over the batch; torch.max(x, 0) would instead reduce over the batch dimension.
116 |
117 | # Combined loss.
118 | d_loss = d_wgan_loss + opt.lambda_1*d_wgp_loss + opt.lambda_2*d_wct_loss
119 |
120 | d_loss.backward()
121 | optimizer_D.step()
122 |
123 | # == Generator update == #
124 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1))
125 | imgs_fake = generator(noise)
126 |
127 | optimizer_G.zero_grad()
128 |
129 | g_loss = -torch.mean(discriminator(imgs_fake))
130 |
131 | g_loss.backward()
132 | optimizer_G.step()
133 |
134 | if vis:
135 | batches_done = epoch * len(batch_iterator) + i
136 | if batches_done % opt.sample_interval == 0:
137 |
138 | # Keep a record of losses for plotting.
139 | epochs.append(epoch + i/len(batch_iterator))
140 | g_losses.append(g_loss.item())
141 | d_losses.append(d_loss.item())
142 |
143 | # Generate images for a given set of fixed noise
144 | # so we can track how the GAN learns.
145 | imgs_fake_fixed = generator(noise_fixed).detach().data
146 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2)
147 |
148 | # Display results on visdom page.
149 | vis.line(
150 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
151 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
152 | opts={
153 | 'title': 'loss over time',
154 | 'legend': loss_legend,
155 | 'xlabel': 'epoch',
156 | 'ylabel': 'loss',
157 | 'width': 512,
158 | 'height': 512
159 | },
160 | win=1)
161 | vis.images(
162 | imgs_fake_fixed[:25],
163 | nrow=5, win=2,
164 | opts={
165 | 'title': 'GAN output [Epoch {}]'.format(epoch),
166 | 'width': 512,
167 | 'height': 512,
168 | }
169 | )
--------------------------------------------------------------------------------
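The consistency term in ct_gan.py relies on evaluating the critic twice with dropout active and penalizing the distance between the two outputs beyond the margin M'. A toy sketch of that computation with a hypothetical stand-in critic:

import torch
from torch import nn

# Fresh modules are in training mode, so each pass below applies a different dropout mask.
critic = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Dropout(0.5), nn.Linear(64, 1))
x = torch.randn(4, 10)
M = 0.2  # the margin M'
d1, d2 = critic(x), critic(x)  # two stochastic passes over the same inputs
ct = torch.clamp((d1 - d2) ** 2 - M, min=0).mean()  # batch-averaged max(0, d(D(x), D'(x)) - M')
print(ct)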
/ct-gan/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | # from torch.legacy.nn import Identity
4 |
5 | # Residual network.
6 | # CT-GAN basically uses the same architecture WGAN-GP uses,
7 | # but with a different loss.
8 | class MeanPoolConv(nn.Module):
9 | def __init__(self, n_input, n_output, k_size, kaiming_init=True):
10 | super(MeanPoolConv, self).__init__()
11 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True)
12 | if kaiming_init:
13 | nn.init.kaiming_uniform_(conv1.weight, mode='fan_in', nonlinearity='relu')
14 | self.model = nn.Sequential(conv1)
15 | def forward(self, x):
16 | out = (x[:,:,::2,::2] + x[:,:,1::2,::2] + x[:,:,::2,1::2] + x[:,:,1::2,1::2]) / 4.0
17 | out = self.model(out)
18 | return out
19 |
20 | class ConvMeanPool(nn.Module):
21 | def __init__(self, n_input, n_output, k_size, kaiming_init=True):
22 | super(ConvMeanPool, self).__init__()
23 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True)
24 |
25 | if kaiming_init:
26 | nn.init.kaiming_uniform_(conv1.weight, mode='fan_in', nonlinearity='relu')
27 |
28 | self.model = nn.Sequential(conv1)
29 | def forward(self, x):
30 | out = self.model(x)
31 | out = (out[:,:,::2,::2] + out[:,:,1::2,::2] + out[:,:,::2,1::2] + out[:,:,1::2,1::2]) / 4.0
32 | return out
33 |
34 | class UpsampleConv(nn.Module):
35 | def __init__(self, n_input, n_output, k_size, kaiming_init=True):
36 | super(UpsampleConv, self).__init__()
37 |
38 | conv_layer = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True)
39 | if kaiming_init:
40 | nn.init.kaiming_uniform_(conv_layer.weight, mode='fan_in', nonlinearity='relu')
41 |
42 | self.model = nn.Sequential(
43 | nn.PixelShuffle(2),
44 | conv_layer,
45 | )
46 |
47 | def forward(self, x):
48 | x = x.repeat((1, 4, 1, 1)) # Duplicating channels so PixelShuffle(2) acts as nearest-neighbour upsampling, mirroring WGAN-GP's upsampling.
49 | out = self.model(x)
50 | return out
51 |
52 | class ResidualBlock(nn.Module):
53 | def __init__(self, n_input, n_output, k_size, resample='up', bn=True, spatial_dim=None):
54 | super(ResidualBlock, self).__init__()
55 |
56 | self.resample = resample
57 |
58 | if resample == 'up':
59 | self.conv1 = UpsampleConv(n_input, n_output, k_size, kaiming_init=True)
60 | self.conv2 = nn.Conv2d(n_output, n_output, k_size, padding=(k_size-1)//2)
61 | self.conv_shortcut = UpsampleConv(n_input, n_output, k_size, kaiming_init=True)
62 | self.out_dim = n_output
63 |
64 | # Weight initialization.
65 | nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
66 | elif resample == 'down':
67 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2)
68 | self.conv2 = ConvMeanPool(n_input, n_output, k_size, kaiming_init=True)
69 | self.conv_shortcut = ConvMeanPool(n_input, n_output, k_size, kaiming_init=True)
70 | self.out_dim = n_output
71 | self.ln_dims = [n_input, spatial_dim, spatial_dim] # Define the dimensions for layer normalization.
72 |
73 | nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
74 | else:
75 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2)
76 | self.conv2 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2)
77 | self.conv_shortcut = None # Identity
78 | self.out_dim = n_input
79 | self.ln_dims = [n_input, spatial_dim, spatial_dim]
80 |
81 | nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
82 | nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
83 |
84 | self.model = nn.Sequential(
85 | nn.BatchNorm2d(n_input) if bn else nn.LayerNorm(self.ln_dims),
86 | nn.ReLU(inplace=True),
87 | self.conv1,
88 | nn.BatchNorm2d(self.out_dim) if bn else nn.LayerNorm(self.ln_dims),
89 | nn.ReLU(inplace=True),
90 | self.conv2,
91 | )
92 |
93 | def forward(self, x):
94 | if self.conv_shortcut is None:
95 | return x + self.model(x)
96 | else:
97 | return self.conv_shortcut(x) + self.model(x)
98 |
99 | class DiscBlock1(nn.Module):
100 | def __init__(self, n_output):
101 | super(DiscBlock1, self).__init__()
102 |
103 | self.conv1 = nn.Conv2d(3, n_output, 3, padding=(3-1)//2)
104 | self.conv2 = ConvMeanPool(n_output, n_output, 1, kaiming_init=True)
105 | self.conv_shortcut = MeanPoolConv(3, n_output, 1, kaiming_init=False)
106 |
107 | # Weight initialization:
108 | nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
109 |
110 | self.model = nn.Sequential(
111 | self.conv1,
112 | nn.ReLU(inplace=True),
113 | self.conv2
114 | )
115 |
116 | def forward(self, x):
117 | return self.conv_shortcut(x) + self.model(x)
118 |
119 | class Generator(nn.Module):
120 | def __init__(self):
121 | super(Generator, self).__init__()
122 |
123 | self.model = nn.Sequential( # 128 x 1 x 1
124 | nn.ConvTranspose2d(128, 128, 4, 1, 0), # 128 x 4 x 4
125 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 8 x 8
126 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 16 x 16
127 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 32 x 32
128 | nn.BatchNorm2d(128),
129 | nn.ReLU(inplace=True),
130 | nn.Conv2d(128, 3, 3, padding=(3-1)//2), # 3 x 32 x 32 (no kaiming init. here)
131 | nn.Tanh()
132 | )
133 |
134 | def forward(self, z):
135 | img = self.model(z)
136 | return img
137 |
138 | class Discriminator(nn.Module):
139 | def __init__(self):
140 | super(Discriminator, self).__init__()
141 | n_output = 128
142 | '''
143 | This is a parameter but since we experiment with a single size
144 | of 3 x 32 x 32 images, it is hardcoded here.
145 | '''
146 |
147 | self.DiscBlock1 = DiscBlock1(n_output) # 128 x 16 x 16
148 | self.block1 = nn.Sequential(
149 | ResidualBlock(n_output, n_output, 3, resample='down', bn=False, spatial_dim=16), # 128 x 8 x 8
150 | )
151 | self.block2 = nn.Sequential(
152 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8
153 | )
154 | self.block3 = nn.Sequential(
155 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8
156 | )
157 |
158 | self.l1 = nn.Sequential(nn.Linear(128, 1)) # 128 x 1
159 |
160 | def forward(self, x, dropout=0.0, intermediate_output=False):
161 | # x = x.view(-1, 3, 32, 32)
162 | y = self.DiscBlock1(x)
163 | y = self.block1(y)
164 | y = F.dropout(y, training=True, p=dropout)
165 | y = self.block2(y)
166 | y = F.dropout(y, training=True, p=dropout)
167 | y = self.block3(y)
168 | y = F.dropout(y, training=True, p=dropout)
169 | y = F.relu(y)
170 | y = y.view(x.size(0), 128, -1)
171 | y = y.mean(dim=2)
172 | critic_value = self.l1(y).unsqueeze_(1).unsqueeze_(2) # or .view(x.size(0), 1, 1, 1)
173 |
174 | if intermediate_output:
175 | return critic_value, y # y is the D_(.), intermediate layer given in paper.
176 |
177 | return critic_value
--------------------------------------------------------------------------------
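The strided-slice averaging inside MeanPoolConv and ConvMeanPool is 2x2 mean pooling written by hand; a quick equivalence check (illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
manual = (x[:, :, ::2, ::2] + x[:, :, 1::2, ::2] + x[:, :, ::2, 1::2] + x[:, :, 1::2, 1::2]) / 4.0
print(torch.allclose(manual, F.avg_pool2d(x, 2)))  # True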
/ct-gan/out/GAN output [Epoch 34].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/ct-gan/out/GAN output [Epoch 34].jpg
--------------------------------------------------------------------------------
/dcgan/dcgan.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | from torch import nn, optim
6 | from torch.autograd.variable import Variable
7 |
8 | from torchvision import transforms, datasets
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import save_image
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--n_epochs', type=int, default=1000, help='number of epochs of training')
14 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
15 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate')
16 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
17 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
18 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
19 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
20 | parser.add_argument('--channels', type=int, default=3, help='number of image channels')
21 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
22 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
23 | parser.add_argument('--sample_interval', type=int, default=64, help='interval between image samples')
24 | opt = parser.parse_args()
25 |
26 | # Just trying if these different initializations
27 | # matter that much.
28 | def weights_init_normal(m): # N(0, 0.02) init, not actually Xavier.
29 | classname = m.__class__.__name__
30 | if classname.find('Conv') != -1 or classname.find('Linear') != -1:
31 | m.weight.data.normal_(0.0, 0.02)
32 | elif classname.find('BatchNorm') != -1:
33 | m.weight.data.normal_(1.0, 0.02)
34 | m.bias.data.fill_(0)
35 |
36 | try:
37 | import visdom
38 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
39 | except ImportError:
40 | vis = None
41 | else:
42 | vis.close(None) # Clear all figures.
43 |
44 |
45 | img_dims = (opt.channels, opt.img_size, opt.img_size)
46 | n_features = opt.channels * opt.img_size * opt.img_size
47 |
48 | # Model is taken from Fig. 1 of the DCGAN paper.
49 | # BN before ReLU? https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/
50 | # I tried BN -> ReLU but got poor results.
51 | # Here I skip one layer of DCGAN because I use CIFAR10 (32x32)
52 | # instead of LSUN (64x64). LSUN is too large to load and train (40 GB for a single bedroom category)
53 |
54 | '''
55 | UPDATE: Both the discriminator and the generator are made
56 | fully convolutional, unlike the original
57 | version in the paper. The old code is commented
58 | out for reference. Fully convolutional networks are
59 | standard practice today.
60 | '''
61 |
62 | class Generator(nn.Module):
63 | def __init__(self):
64 | super(Generator, self).__init__()
65 |
66 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0):
67 | block = [
68 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),
69 | nn.ReLU(inplace=True),
70 | nn.BatchNorm2d(n_output),
71 | ]
72 | return block
73 |
74 | self.project = nn.Sequential(
75 | nn.Linear(opt.latent_dim, 256 * 4 * 4),
76 | nn.BatchNorm1d(256 * 4 * 4),
77 | nn.ReLU(inplace=True),
78 | )
79 |
80 | # pytorch doesn't have padding='same' as keras does.
81 | # hence we have to pad manually to get the desired behavior.
82 | # from the manual:
83 | # o = (i−1)∗s−2p+k, i/o = in/out-put, k = kernel size, p = padding, s = stride.
84 |
85 | self.model = nn.Sequential(
86 | *convlayer(opt.latent_dim, 256, 4, 1, 0), # Fully connected layer via convolution.
87 | *convlayer(256, 128, 4, 2, 1), # 4->8, (4-1)*2-2p+4 = 8, p = 1
88 | *convlayer(128, 64, 4, 2, 1), # 8->16, p = 1
89 | nn.ConvTranspose2d(64, opt.channels, 4, 2, 1),
90 | nn.Tanh()
91 | )
92 | # Tanh > Image values are between [-1, 1]
93 |
94 | def forward(self, z):
95 | # p = self.project(z)
96 | # p = p.view(-1, 256, 4, 4) # Project and reshape (notice that pytorch uses NCHW format)
97 | img = self.model(z)
98 | img = img.view(z.size(0), *img_dims)
99 | return img
100 |
101 | # The discriminator mirrors the generator, like an encoder-decoder in reverse.
102 | class Discriminator(nn.Module):
103 | def __init__(self):
104 | super(Discriminator, self).__init__()
105 |
106 | # Again, we must ensure the same padding.
107 | # From the manual:
108 | # o = (i+2p−d∗(kernel_size−1)−1)/s+1, d = dilation (default = 1).
109 |
110 | self.model = nn.Sequential(
111 | nn.Conv2d(opt.channels, 64, 4, 2, 1, bias=False), # 32-> 16, (32+2p-3-1)/2 + 1 = 16, p = 1
112 | nn.LeakyReLU(0.2, inplace=True),
113 | # nn.BatchNorm2d(64), # IS IT CRUCIAL TO SKIP BN AT FIRST LAYER OF DISCRIMINATOR?
114 | nn.Conv2d(64, 128, 4, 2, 1, bias=False),
115 | nn.LeakyReLU(0.2, inplace=True),
116 | nn.BatchNorm2d(128),
117 | nn.Conv2d(128, 256, 4, 2, 1, bias=False),
118 | nn.LeakyReLU(0.2, inplace=True),
119 | nn.BatchNorm2d(256),
120 | nn.Conv2d(256, 1, 4, 1, 0, bias=False), # FC with Conv.
121 | nn.Sigmoid()
122 | )
123 |
124 | # we used 3 convolutional blocks, hence 2^3 = 8 times downsampling.
125 | semantic_dim = opt.img_size // (2**3)
126 |
127 | self.l1 = nn.Sequential(
128 | nn.Linear(256 * semantic_dim**2, 1),
129 | nn.Sigmoid()
130 | )
131 |
132 | def forward(self, img):
133 | prob = self.model(img)
134 | # prob = self.l1(prob.view(prob.size(0), -1))
135 | return prob
136 |
137 |
138 | # Beware: CIFAR10 download is slow.
139 | def custom_data(CIFAR_CLASS = 5): # Only working with dogs.
140 | compose = transforms.Compose(
141 | [transforms.Resize(opt.img_size),
142 | transforms.ToTensor(),
143 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
144 | ])
145 | output_dir = './data/cifar10'
146 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True,
147 | transform=compose)
148 |
149 | # Our custom cifar with a single category.
150 | Tensor__ = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
151 |
152 | features = Tensor__(5000, 3, 32, 32) # 5k samples in each category.
153 | targets = Tensor__(5000, 1)
154 |
155 | # Go through CIFAR to pick only the desired class..
156 | it = 0
157 | for i in range(len(cifar)):
158 | sample = cifar[i]
159 | if sample[1] == CIFAR_CLASS:
160 | features[it, ...] = sample[0]
161 | targets[it] = sample[1]
162 | it += 1
163 | return features, targets
164 |
165 | features, targets = custom_data(CIFAR_CLASS=1) # Cars.
166 | batch_iterator = DataLoader(torch.utils.data.TensorDataset(features, targets), shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
167 |
168 | # Save a batch of real images for reference.
169 | os.makedirs('./out', exist_ok=True)
170 | save_image(next(iter(batch_iterator))[0][:25, ...], './out/real_samples.png', nrow=5, normalize=True)
171 |
172 | cuda = torch.cuda.is_available()
173 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
174 | gan_loss = nn.BCELoss()
175 |
176 | generator = Generator()
177 | discriminator = Discriminator()
178 |
179 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
180 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
181 |
182 | # Loss record.
183 | g_losses = []
184 | d_losses = []
185 | epochs = []
186 | loss_legend = ['Discriminator', 'Generator']
187 |
188 | if cuda:
189 | generator = generator.cuda()
190 | discriminator = discriminator.cuda()
191 |
192 | generator.apply(weights_init_normal)
193 | discriminator.apply(weights_init_normal)
194 |
195 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN.
196 |
197 | for epoch in range(opt.n_epochs):
198 | print('Epoch {}'.format(epoch))
199 | for i, (batch, _) in enumerate(batch_iterator):
200 |
201 | real = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1), requires_grad=False)
202 | fake = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(0), requires_grad=False)
203 |
204 | imgs_real = Variable(batch.type(Tensor), requires_grad=False)
205 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False)
206 | imgs_fake = Variable(generator(noise), requires_grad=False)
207 |
208 |
209 | # == Discriminator update == #
210 | optimizer_D.zero_grad()
211 |
212 | # A small reminder: given classes c, prob. p, - c*logp - (1-c)*log(1-p) is the BCE/GAN loss.
213 | # Putting D(x) as p, x=real's class as 1, (..and same for fake with c=0) we arrive to BCE loss.
214 | # This is intuitively how well the discriminator can distinguish btw real & fake.
215 | d_loss = gan_loss(discriminator(imgs_real), real) + \
216 | gan_loss(discriminator(imgs_fake), fake)
217 |
218 | d_loss.backward()
219 | optimizer_D.step()
220 |
221 |
222 | # == Generator update == #
223 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False)
224 | imgs_fake = generator(noise)
225 |
226 | optimizer_G.zero_grad()
227 |
228 | # Minimizing (1-log(d(g(noise))) is less stable than maximizing log(d(g)) [*].
229 | # Since BCE loss is defined as a negative sum, maximizing [*] is == minimizing [*]'s negative.
230 | # Intuitively, how well does the G fool the D?
231 | g_loss = gan_loss(discriminator(imgs_fake), real)
232 |
233 | g_loss.backward()
234 | optimizer_G.step()
235 |
236 | if vis:
237 | batches_done = epoch * len(batch_iterator) + i
238 | if batches_done % opt.sample_interval == 0:
239 |
240 | # Keep a record of losses for plotting.
241 | epochs.append(epoch + i/len(batch_iterator))
242 | g_losses.append(g_loss.item())
243 | d_losses.append(d_loss.item())
244 |
245 | # Generate images for a given set of fixed noise
246 | # so we can track how the GAN learns.
247 | imgs_fake_fixed = generator(noise_fixed).detach().data
248 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom.
249 |
250 | # Display results on visdom page.
251 | vis.line(
252 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
253 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
254 | opts={
255 | 'title': 'loss over time',
256 | 'legend': loss_legend,
257 | 'xlabel': 'epoch',
258 | 'ylabel': 'loss',
259 | 'width': 512,
260 | 'height': 512
261 | },
262 | win=1)
263 | vis.images(
264 | imgs_fake_fixed,
265 | nrow=5, win=2,
266 | opts={
267 | 'title': 'GAN output [Epoch {}]'.format(epoch),
268 | 'width': 512,
269 | 'height': 512,
270 | }
271 | )
272 |
--------------------------------------------------------------------------------
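The output-size formulas quoted in the comments of dcgan.py can be sanity-checked directly; a small sketch (shapes chosen to match the generator and discriminator above):

import torch
from torch import nn

# ConvTranspose2d: o = (i - 1)*s - 2p + k (dilation 1, no output_padding).
up = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
print(up(torch.randn(1, 128, 4, 4)).shape)    # (1, 64, 8, 8): (4-1)*2 - 2 + 4 = 8

# Conv2d: o = (i + 2p - k)//s + 1.
down = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
print(down(torch.randn(1, 3, 32, 32)).shape)  # (1, 64, 16, 16): (32 + 2 - 4)//2 + 1 = 16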
/dcgan/out/GAN output [Epoch 358].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/dcgan/out/GAN output [Epoch 358].jpg
--------------------------------------------------------------------------------
/dcgan/out/real_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/dcgan/out/real_samples.png
--------------------------------------------------------------------------------
/gan/gan.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torch import nn, optim
5 | from torch.autograd.variable import Variable
6 |
7 | from torchvision import transforms, datasets
8 | from torch.utils.data import DataLoader
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
12 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
13 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
14 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
15 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
16 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
17 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
18 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
19 | parser.add_argument('--channels', type=int, default=1, help='number of image channels')
20 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
21 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
22 | parser.add_argument('--sample_interval', type=int, default=512, help='interval between image samples')
23 | opt = parser.parse_args()
24 |
25 |
26 | # Set up visdom, if it is installed.
27 | # If not, visualization of the training will not be possible.
28 | # If visdom is installed, the server must also be running.
29 | # Use the 'python -m visdom.server' command to start it.
30 | try:
31 | import visdom
32 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
33 | except ImportError:
34 | vis = None
35 | else:
36 | vis.close(None) # Clear all figures.
37 |
38 | img_dims = (opt.channels, opt.img_size, opt.img_size)
39 | n_features = opt.channels * opt.img_size * opt.img_size
40 |
41 | # This does not directly follow the paper; I did not fully understand the
42 | # order of the architecture described there.
43 | class Generator(nn.Module):
44 | def __init__(self):
45 | super(Generator, self).__init__()
46 |
47 | def ganlayer(n_input, n_output, dropout=True):
48 | pipeline = [nn.Linear(n_input, n_output)]
49 | pipeline.append(nn.LeakyReLU(0.2, inplace=True))
50 | if dropout:
51 | pipeline.append(nn.Dropout(0.25))
52 | return pipeline
53 |
54 | self.model = nn.Sequential(
55 | *ganlayer(opt.latent_dim, 128, dropout=False),
56 | *ganlayer(128, 256),
57 | *ganlayer(256, 512),
58 | *ganlayer(512, 1024),
59 | nn.Linear(1024, n_features),
60 | nn.Tanh()
61 | )
62 | # Tanh > Image values are between [-1, 1]
63 |
64 | def forward(self, z):
65 | img = self.model(z)
66 | img = img.view(img.size(0), *img_dims)
67 | return img
68 |
69 | class Discriminator(nn.Module):
70 | def __init__(self):
71 | super(Discriminator, self).__init__()
72 |
73 | self.model = nn.Sequential(
74 | nn.Linear(n_features, 512),
75 | nn.LeakyReLU(0.2, inplace=True),
76 | nn.Linear(512, 256),
77 | nn.LeakyReLU(0.2, inplace=True),
78 | nn.Linear(256, 1),
79 | nn.Sigmoid()
80 | )
81 |
82 | def forward(self, img):
83 | flatimg = img.view(img.size(0), -1)
84 | prob = self.model(flatimg)
85 | return prob
86 |
87 |
88 | def mnist_data():
89 | compose = transforms.Compose(
90 | [transforms.ToTensor(),
91 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
92 | ])
93 | output_dir = './data/mnist'
94 | return datasets.MNIST(root=output_dir, train=True,
95 | transform=compose, download=True)
96 |
97 | mnist = mnist_data()
98 | batch_iterator = DataLoader(mnist, shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
99 |
100 | cuda = torch.cuda.is_available()
101 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
102 | gan_loss = nn.BCELoss()
103 |
104 | generator = Generator()
105 | discriminator = Discriminator()
106 |
107 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
108 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
109 |
110 | # Loss record.
111 | g_losses = []
112 | d_losses = []
113 | epochs = []
114 | loss_legend = ['Discriminator', 'Generator']
115 |
116 | if cuda:
117 | generator = generator.cuda()
118 | discriminator = discriminator.cuda()
119 |
120 | for epoch in range(opt.n_epochs):
121 | print('Epoch {}'.format(epoch))
122 | for i, (batch, _) in enumerate(batch_iterator):
123 |
124 | real = Variable(Tensor(batch.size(0), 1).fill_(1), requires_grad=False)
125 | fake = Variable(Tensor(batch.size(0), 1).fill_(0), requires_grad=False)
126 |
127 | imgs_real = Variable(batch.type(Tensor))
128 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1))
129 | imgs_fake = Variable(generator(noise), requires_grad=False)
130 |
131 |
132 | # == Discriminator update == #
133 | optimizer_D.zero_grad()
134 |
135 | # A small reminder: given classes c, prob. p, - c*logp - (1-c)*log(1-p) is the BCE/GAN loss.
136 | # Putting D(x) as p, x=real's class as 1, (..and same for fake with c=0) we arrive to BCE loss.
137 | # This is intuitively how well the discriminator can distinguish btw real & fake.
138 | d_loss = gan_loss(discriminator(imgs_real), real) + \
139 | gan_loss(discriminator(imgs_fake), fake)
140 |
141 | d_loss.backward()
142 | optimizer_D.step()
143 |
144 |
145 | # == Generator update == #
146 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1))
147 | imgs_fake = generator(noise)
148 |
149 | optimizer_G.zero_grad()
150 |
151 | # Minimizing (1-log(d(g(noise))) is less stable than maximizing log(d(g)) [*].
152 | # Since BCE loss is defined as a negative sum, maximizing [*] is == minimizing [*]'s negative.
153 | # Intuitively, how well does the G fool the D?
154 | g_loss = gan_loss(discriminator(imgs_fake), real)
155 |
156 | g_loss.backward()
157 | optimizer_G.step()
158 |
159 | if vis:
160 | batches_done = epoch * len(batch_iterator) + i
161 | if batches_done % opt.sample_interval == 0:
162 |
163 | # Keep a record of losses for plotting.
164 | epochs.append(epoch + i/len(batch_iterator))
165 | g_losses.append(g_loss.item())
166 | d_losses.append(d_loss.item())
167 |
168 | # Display results on visdom page.
169 | vis.line(
170 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
171 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
172 | opts={
173 | 'title': 'loss over time',
174 | 'legend': loss_legend,
175 | 'xlabel': 'epoch',
176 | 'ylabel': 'loss',
177 | 'width': 512,
178 | 'height': 512
179 | },
180 | win=1)
181 | vis.images(
182 | imgs_fake.data[:25],
183 | nrow=5, win=2,
184 | opts={
185 | 'title': 'GAN output [Epoch {}]'.format(epoch),
186 | 'width': 512,
187 | 'height': 512,
188 | }
189 | )
190 |
--------------------------------------------------------------------------------
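The generator-loss comment in gan.py (maximizing log D(G(z)) via BCE with "real" targets) can be verified numerically; a minimal check:

import torch
from torch import nn

bce = nn.BCELoss()
p = torch.tensor([0.9, 0.1])   # hypothetical discriminator outputs D(G(z))
ones = torch.ones_like(p)      # label the fakes as real
print(bce(p, ones))            # reduces to -log(p), averaged over the batch
print((-torch.log(p)).mean())  # identical value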
/gan/out/GAN output [Epoch 95].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/gan/out/GAN output [Epoch 95].jpg
--------------------------------------------------------------------------------
/infogan/out/(Nothing is varied) GAN output [Epoch 14].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/infogan/out/(Nothing is varied) GAN output [Epoch 14].jpg
--------------------------------------------------------------------------------
/infogan/out/(Only vary c1) GAN output [Epoch 14].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/infogan/out/(Only vary c1) GAN output [Epoch 14].jpg
--------------------------------------------------------------------------------
/infogan/out/(Only vary c2) GAN output [Epoch 14].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/infogan/out/(Only vary c2) GAN output [Epoch 14].jpg
--------------------------------------------------------------------------------
/lsgan/lsgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | Using Eqn. 9 of the LSGAN paper as the loss function.
3 | Figure 2 of the paper's v1 is used as the generator/discriminator combo.
4 | Note that there is a v3 (as of July 2018).
5 | A significant design difference is the use of BN before ReLU (BN > ReLU),
6 | unlike the regular GANs (ReLU > BN).
7 | '''
8 |
9 | import argparse
10 |
11 | import torch
12 | from torch import nn, optim
13 | from torch.autograd.variable import Variable
14 |
15 | from torchvision import transforms, datasets
16 | from torch.utils.data import DataLoader
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
20 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
21 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate')
22 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
23 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
24 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
25 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
26 | parser.add_argument('--channels', type=int, default=1, help='number of image channels')
27 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
28 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
29 | parser.add_argument('--sample_interval', type=int, default=64, help='interval between image samples')
30 | opt = parser.parse_args()
31 |
32 | def weights_init_normal(m):
33 | classname = m.__class__.__name__
34 | if classname.find('Conv') != -1:
35 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
36 | elif classname.find('BatchNorm') != -1:
37 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
38 | torch.nn.init.constant_(m.bias.data, 0.0)
39 |
40 |
41 | try:
42 | import visdom
43 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
44 | except ImportError:
45 | vis = None
46 | else:
47 | vis.close(None) # Clear all figures.
48 |
49 |
50 | class Generator(nn.Module):
51 | def __init__(self):
52 | super(Generator, self).__init__()
53 |
54 | def convlayer(n_input, n_output, k_size=5, stride=2, padding=0, output_padding=0):
55 | block = [
56 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False, output_padding=output_padding),
57 | nn.BatchNorm2d(n_output),
58 | nn.ReLU(inplace=True),
59 | ]
60 | return block
61 |
62 | # MNIST (not constantinople.. sorry not LSUN)
63 | self.model = nn.Sequential(
64 | *convlayer(opt.latent_dim, 128, 7, 1, 0), # Fully connected, 128 x 7 x 7
65 | *convlayer(128, 64, 5, 2, 2, output_padding=1), # 64 x 14 x 14
66 | *convlayer(64, 32, 5, 2, 2, output_padding=1), # 32 x 28 x 28
67 | nn.ConvTranspose2d(32, opt.channels, 5, 1, 2), # 1 x 28 x 28
68 | nn.Tanh()
69 | )
70 |
71 | def forward(self, z):
72 | img = self.model(z)
73 | return img
74 |
75 |
76 | class Discriminator(nn.Module):
77 | def __init__(self):
78 | super(Discriminator, self).__init__()
79 |
80 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, normalize=True, dilation=1):
81 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False, dilation=dilation),]
82 | if normalize:
83 | block.append(nn.BatchNorm2d(n_output))
84 | block.append(nn.LeakyReLU(0.2, inplace=True))
85 | return block
86 |
87 | self.conv_block = nn.Sequential( # 1 x 28 x 28
88 | *convlayer(opt.channels, 32, 5, 2, 2, normalize=False), # 32 x 14 x 14
89 | *convlayer(32, 64, 5, 2, 2), # 64 x 7 x 7
90 | )
91 |
92 | self.fc_block = nn.Sequential(
93 | nn.Linear(64 * 7 * 7, 512),
94 | nn.BatchNorm1d(512),
95 | nn.LeakyReLU(0.2, inplace=True),
96 | nn.Linear(512, 1),
97 | )
98 |
99 | def forward(self, img):
100 | conv_out = self.conv_block(img)
101 | conv_out = conv_out.view(img.size(0), 64 * 7 * 7)
102 | l2_value = self.fc_block(conv_out)
103 | l2_value = l2_value.unsqueeze_(dim=2).unsqueeze_(dim=3)
104 | return l2_value
105 |
106 |
107 | def mnist_data():
108 | compose = transforms.Compose(
109 | [transforms.Resize(opt.img_size),
110 | transforms.ToTensor(),
111 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
112 | ])
113 | output_dir = './data/mnist'
114 | return datasets.MNIST(root=output_dir, train=True,
115 | transform=compose, download=True)
116 |
117 | mnist = mnist_data()
118 | batch_iterator = DataLoader(mnist, shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
119 |
120 |
121 | cuda = torch.cuda.is_available()
122 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
123 | gan_loss = nn.MSELoss() # LSGAN's loss, a shortcut to torch.mean( (x-y) ** 2 )
124 |
125 | generator = Generator()
126 | discriminator = Discriminator()
127 |
128 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
129 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
130 |
131 | # Loss record.
132 | g_losses = []
133 | d_losses = []
134 | epochs = []
135 | loss_legend = ['Discriminator', 'Generator']
136 |
137 | if cuda:
138 | generator = generator.cuda()
139 | discriminator = discriminator.cuda()
140 |
141 | generator.apply(weights_init_normal)
142 | discriminator.apply(weights_init_normal)
143 |
144 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN.
145 |
146 | for epoch in range(opt.n_epochs):
147 | print('Epoch {}'.format(epoch))
148 | for i, (batch, _) in enumerate(batch_iterator):
149 |
150 | real = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1.0), requires_grad=False)
151 | fake = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(0.0), requires_grad=False)
152 |
153 | imgs_real = Variable(batch.type(Tensor), requires_grad=False)
154 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False)
155 | imgs_fake = generator(noise) # Variable(generator(noise), requires_grad=False)
156 |
157 |
158 | # == Discriminator update == #
159 | optimizer_D.zero_grad()
160 |
161 | d_loss = 0.5*gan_loss(discriminator(imgs_real), real) + \
162 | 0.5*gan_loss(discriminator(imgs_fake.data), fake)
163 |
164 | d_loss.backward()
165 | optimizer_D.step()
166 |
167 |
168 | # == Generator update == #
169 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False)
170 | imgs_fake = generator(noise)
171 |
172 | optimizer_G.zero_grad()
173 |
174 | g_loss = 0.5*gan_loss(discriminator(imgs_fake), real)
175 |
176 | g_loss.backward()
177 | optimizer_G.step()
178 |
179 | if vis:
180 | batches_done = epoch * len(batch_iterator) + i
181 | if batches_done % opt.sample_interval == 0:
182 |
183 | # Keep a record of losses for plotting.
184 | epochs.append(epoch + i/len(batch_iterator))
185 | g_losses.append(g_loss.item())
186 | d_losses.append(d_loss.item())
187 |
188 | # Generate images for a given set of fixed noise
189 | # so we can track how the GAN learns.
190 | imgs_fake_fixed = generator(noise_fixed).detach().data
191 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom.
192 |
193 | # Display results on visdom page.
194 | vis.line(
195 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
196 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
197 | opts={
198 | 'title': 'loss over time',
199 | 'legend': loss_legend,
200 | 'xlabel': 'epoch',
201 | 'ylabel': 'loss',
202 | 'width': 512,
203 | 'height': 512
204 | },
205 | win=1)
206 | vis.images(
207 | imgs_fake_fixed,
208 | nrow=5, win=2,
209 | opts={
210 | 'title': 'GAN output [Epoch {}]'.format(epoch),
211 | 'width': 512,
212 | 'height': 512,
213 | }
214 | )
215 |
--------------------------------------------------------------------------------
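As the comment next to nn.MSELoss() says, the LSGAN loss is plain mean squared error against the 0/1 targets; a quick numeric check:

import torch
from torch import nn

mse = nn.MSELoss()
d_out = torch.tensor([0.3, 1.2, 0.8])     # hypothetical discriminator outputs
target = torch.ones_like(d_out)
print(mse(d_out, target))                 # same value as the line below
print(torch.mean((d_out - target) ** 2))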
/lsgan/out/GAN output [Epoch 11].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/lsgan/out/GAN output [Epoch 11].jpg
--------------------------------------------------------------------------------
/pix2pix/datasets.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torch.utils import data
3 | import glob
4 | import torchvision.transforms as transforms
5 |
6 | class Dataset(data.Dataset):
7 | 'Characterizes a dataset for PyTorch'
8 | def __init__(self, data_folder, transform, stage):
9 | 'Initialization'
10 | self.images = glob.glob('{}/{}/*'.format(data_folder, stage))
11 | self.transform = transforms.Compose(transform)
12 | 'Get image size'
13 | _, self.img_size = Image.open(self.images[0]).size # PIL size is (width, height); pairs sit side by side, so height == tile size.
14 |
15 | def __len__(self):
16 | 'Denotes the total number of samples'
17 | return len(self.images)
18 |
19 | def __getitem__(self, index):
20 | 'Generates one sample of data'
21 | # Select one paired sample.
22 |
23 | # Load the A|B image and split it into its two halves.
24 | X = Image.open(self.images[index])
25 | B = X.crop((0, 0, self.img_size, self.img_size)) # (left, upper, right, lower)
26 | A = X.crop((self.img_size, 0, self.img_size+self.img_size, self.img_size))
27 |
28 | return self.transform(A), self.transform(B)
29 |
--------------------------------------------------------------------------------
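As a quick illustration of the cropping logic above: each file stores the target B on the left half and the input A on the right half of one wide image, so `__getitem__` simply slices it in two. A minimal sketch on a hypothetical blank 512x256 pair:

from PIL import Image

# Hypothetical paired file: width = 2 * height, halves side by side.
img = Image.new('RGB', (512, 256))
tile = img.size[1]                       # height equals the tile size (256)
B = img.crop((0, 0, tile, tile))         # left half
A = img.crop((tile, 0, 2 * tile, tile))  # right half
print(A.size, B.size)                    # (256, 256) (256, 256)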
/pix2pix/download_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
3 | TAR_FILE=./$FILE.tar.gz
4 | TARGET_DIR=./$FILE/
5 | wget -N $URL -O $TAR_FILE
6 | mkdir $TARGET_DIR
7 | tar -zxvf $TAR_FILE
8 | rm $TAR_FILE
--------------------------------------------------------------------------------
/pix2pix/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | def weights_init_normal(m):
6 | classname = m.__class__.__name__
7 | if classname.find('Conv') != -1:
8 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
9 | elif classname.find('BatchNorm2d') != -1:
10 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
11 | torch.nn.init.constant_(m.bias.data, 0.0)
12 |
13 | '''
14 | Trying to make it compact, I'm afraid
15 | this u-net implementation became very
16 | messy. In future, simplicity before
17 | compactness.
18 | '''
19 |
20 | class Generator(nn.Module):
21 | def __init__(self, in_channels=3, out_channels=3):
22 | super(Generator, self).__init__()
23 |
24 | ##### Encoder #####
25 | def convblock(n_input, n_output, k_size=4, stride=2, padding=1, normalize=True):
26 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
27 | if normalize:
28 | block.append(nn.BatchNorm2d(n_output))
29 | block.append(nn.LeakyReLU(0.2, inplace=True))
30 | return block
31 |
32 | g_layers = [64, 128, 256, 512, 512, 512, 512, 512] # Generator feature maps (ref: pix2pix paper)
33 | self.encoder = nn.ModuleList(convblock(in_channels, g_layers[0], normalize=False)) # 1st layer is not normalized.
34 | for iter in range(1, len(g_layers)):
35 | self.encoder += convblock(g_layers[iter-1], g_layers[iter])
36 | # Exercise: why use nn.ModuleList instead of a plain Python list?
37 |
38 | ##### Decoder #####
39 | def convdblock(n_input, n_output, k_size=4, stride=2, padding=1, dropout=0.0):
40 | block = [
41 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),
42 | nn.BatchNorm2d(n_output)
43 | ]
44 | if dropout: # Mixing BN and Dropout is not recommended in general.
45 | block.append(nn.Dropout(dropout))
46 | block.append(nn.ReLU(inplace=True))
47 | return block
48 |
49 | d_layers = [512, 1024, 1024, 1024, 1024, 512, 256, 128]
50 | self.decoder = nn.ModuleList(convdblock(g_layers[-1], d_layers[0]))
51 | for iter in range(1, len(d_layers)-1):
52 | self.decoder += convdblock(d_layers[iter], d_layers[iter+1]//2)
53 |
54 | self.model = nn.Sequential(
55 | nn.ConvTranspose2d(d_layers[-1], out_channels, 4, 2, 1),
56 | nn.Tanh()
57 | )
58 |
59 |
60 | def forward(self, z):
61 | # Encode.
62 | e = [self.encoder[1](self.encoder[0](z))]
63 | module_iter = 2
64 | for iter in range(1, 8): # we have 8 downsampling layers
65 | result = self.encoder[module_iter+0](e[iter-1])
66 | result = self.encoder[module_iter+1](result)
67 | result = self.encoder[module_iter+2](result)
68 | e.append(result)
69 | module_iter += 3
70 | # Decode.
71 | # First d-layer.
72 | d1 = self.decoder[2](self.decoder[1](self.decoder[0](e[-1])))
73 | d1 = torch.cat((d1, e[-2]), dim=1)
74 | d = [d1]
75 | module_iter = 3
76 | for iter in range(1, 7): # we have 7 upsampling layers
77 | result = self.decoder[module_iter+0](d[iter-1])
78 | result = self.decoder[module_iter+1](result)
79 | result = self.decoder[module_iter+2](result)
80 | result = torch.cat((result, e[-(iter+2)]), dim=1) # Concatenate the (n-i)th encoder feature map with the i-th decoder output (u-net skip).
81 | d.append(result)
82 | module_iter += 3
83 |
84 | # Pass the decoder output through a "flattening" layer to get the image.
85 | img = self.model(d[-1])
86 | return img
87 |
88 |
89 | class Discriminator(nn.Module):
90 | def __init__(self, a_channels=3, b_channels=3):
91 | super(Discriminator, self).__init__()
92 |
93 | def convblock(n_input, n_output, k_size=4, stride=2, padding=1, normalize=True):
94 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
95 | if normalize:
96 | block.append(nn.BatchNorm2d(n_output))
97 | block.append(nn.LeakyReLU(0.2, inplace=True))
98 | return block
99 |
100 | self.model = nn.Sequential(
101 | *convblock(a_channels + b_channels, 64, normalize=False),
102 | *convblock(64, 128),
103 | *convblock(128, 256),
104 | *convblock(256, 512),
105 | )
106 |
107 | self.l1 = nn.Linear(512 * 16 * 16, 1)
108 |
109 | def forward(self, img_A, img_B):
110 | img = torch.cat((img_A, img_B), dim=1)
111 | conv_out = self.model(img)
112 | conv_out = conv_out.view(img_A.size(0), -1)
113 | prob = self.l1(conv_out)
114 | prob = torch.sigmoid(prob) # F.sigmoid is deprecated.
115 | return prob
--------------------------------------------------------------------------------
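A small shape sanity check (not part of the repo) that the u-net above round-trips a 256x256 image and that the discriminator emits one probability per (A, B) pair. It assumes it is run from the pix2pix folder so models.py imports, and uses a batch of 2 because the 1x1 bottleneck's BatchNorm needs more than one sample in training mode:

import torch
from models import Generator, Discriminator

G, D = Generator(), Discriminator()
a = torch.randn(2, 3, 256, 256)  # two dummy 'A' images
b = G(a)                         # translated 'B' images
print(b.shape)                   # torch.Size([2, 3, 256, 256])
print(D(a, b).shape)             # torch.Size([2, 1])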
/pix2pix/out/3685e66f775102.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pix2pix/out/3685e7054e8534.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pix2pix/out/GAN output [Epoch 216].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/pix2pix/out/GAN output [Epoch 216].jpg
--------------------------------------------------------------------------------
/pix2pix/out/GAN output [Epoch 240].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/pix2pix/out/GAN output [Epoch 240].jpg
--------------------------------------------------------------------------------
/pix2pix/pix2pix.py:
--------------------------------------------------------------------------------
1 | '''
2 | To run, use
3 | bash download_dataset.sh cityscapes
4 | to download the data into the same
5 | directory as the pix2pix.py.
6 | '''
7 |
8 | '''
9 | pix2pix paper suggests using a noise
10 | vector along with the real image (to generate
11 | translated images). However, I noticed from
12 | separate experiments that the noise is simply
13 | filtered out, i.e. it does not provide any diversity
14 | for a given A -> B conversion, hence omitted
15 | here. In fact, there is a GAN called BicycleGAN which
16 | addresses the very same issue by employing
17 | autoencoders.
18 | '''
19 |
20 | import argparse
21 |
22 | import torch
23 | from torch import nn, optim
24 | from torch.autograd.variable import Variable
25 |
26 | from torchvision import transforms, datasets
27 | from torch.utils.data import DataLoader
28 |
29 | from models import Generator, Discriminator, weights_init_normal
30 | from datasets import Dataset
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--n_epochs', type=int, default=1000, help='number of epochs of training')
34 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
35 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
36 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
37 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
38 | parser.add_argument('--lambda_l1', type=int, default=100, help='weight of the L1 penalty for deviating from the target images')
39 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
40 | parser.add_argument('--data_folder', type=str, default="cityscapes", help='where the training image pairs are located')
41 | parser.add_argument('--img_size', type=int, default=256, help='size of the input images')
42 | parser.add_argument('--channels', type=int, default=3, help='number of image channels')
43 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
44 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
45 | parser.add_argument('--sample_interval', type=int, default=64, help='interval between image samples')
46 | opt = parser.parse_args()
47 |
48 |
49 | try:
50 | import visdom
51 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
52 | except ImportError:
53 | vis = None
54 | else:
55 | vis.close(None) # Clear all figures.
56 |
57 | # Cityscapes dataset loader.
58 | transform_image = [
59 | transforms.ToTensor(),
60 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
61 | ]
62 | data_train = Dataset(opt.data_folder, transform_image, stage='train')
63 | data_val = Dataset(opt.data_folder, transform_image, stage='val')
64 |
65 | params = {'batch_size': opt.batch_size, 'shuffle': True}
66 | dataloader_train = DataLoader(data_train, **params)
67 |
68 | params = {'batch_size': 5, 'shuffle': True}
69 | dataloader_val = DataLoader(data_val, **params)
70 |
71 | cuda = torch.cuda.is_available()
72 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
73 | gan_loss = nn.BCELoss()
74 | l1_loss = nn.L1Loss()
75 |
76 | generator = Generator()
77 | discriminator = Discriminator()
78 |
79 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
80 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
81 |
82 | # Loss record.
83 | g_losses = []
84 | d_losses = []
85 | epochs = []
86 | loss_legend = ['Discriminator', 'Generator']
87 |
88 | if cuda:
89 | generator = generator.cuda()
90 | discriminator = discriminator.cuda()
91 | gan_loss = gan_loss.cuda()
92 | l1_loss = l1_loss.cuda()
93 |
94 |
95 | generator.apply(weights_init_normal)
96 | discriminator.apply(weights_init_normal)
97 |
98 | for epoch in range(opt.n_epochs):
99 | print('Epoch {}'.format(epoch))
100 | for i, (batch_A, batch_B) in enumerate(dataloader_train):
101 |
102 | real = Variable(Tensor(batch_A.size(0), 1).fill_(1), requires_grad=False)
103 | fake = Variable(Tensor(batch_A.size(0), 1).fill_(0), requires_grad=False)
104 |
105 | imgs_real_A = Variable(batch_A.type(Tensor))
106 | imgs_real_B = Variable(batch_B.type(Tensor))
107 |
108 | # == Discriminator update == #
109 | optimizer_D.zero_grad()
110 |
111 | imgs_fake = generator(imgs_real_A).detach() # Detach so the discriminator update does not backprop into G.
112 |
113 | d_loss = gan_loss(discriminator(imgs_real_A, imgs_real_B), real) + gan_loss(discriminator(imgs_real_A, imgs_fake), fake)
114 |
115 | d_loss.backward()
116 | optimizer_D.step()
117 |
118 | # == Generator update == #
119 | imgs_fake = generator(imgs_real_A)
120 |
121 | optimizer_G.zero_grad()
122 |
123 | g_loss = gan_loss(discriminator(imgs_real_A, imgs_fake), real) + opt.lambda_l1 * l1_loss(imgs_fake, imgs_real_B)
124 |
125 | g_loss.backward()
126 | optimizer_G.step()
127 |
128 | if vis:
129 | batches_done = epoch * len(dataloader_train) + i
130 | if batches_done % opt.sample_interval == 0:
131 |
132 | # Keep a record of losses for plotting.
133 | epochs.append(epoch + i/len(dataloader_train))
134 | g_losses.append(g_loss.item())
135 | d_losses.append(d_loss.item())
136 |
137 |
138 |
139 | # Generate 5 images from the validation set.
140 | batch_A, batch_B = next(iter(dataloader_val))
141 |
142 | imgs_val_A = Variable(batch_A.type(Tensor))
143 | imgs_val_B = Variable(batch_B.type(Tensor))
144 | imgs_fake_B = generator(imgs_val_A).detach().data
145 |
146 | # For visualization purposes.
147 | imgs_val_A = imgs_val_A.add_(1).div_(2)
148 | imgs_fake_B = imgs_fake_B.add_(1).div_(2)
149 | fake_val = torch.cat((imgs_val_A, imgs_fake_B), dim=2)
150 |
151 |
152 | # Display results on visdom page.
153 | vis.line(
154 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
155 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
156 | opts={
157 | 'title': 'loss over time',
158 | 'legend': loss_legend,
159 | 'xlabel': 'epoch',
160 | 'ylabel': 'loss',
161 | 'width': 512,
162 | 'height': 512
163 | },
164 | win=1)
165 | vis.images(
166 | fake_val,
167 | nrow=5, win=2,
168 | opts={
169 | 'title': 'GAN output [Epoch {}]'.format(epoch),
170 | 'width': 512,
171 | 'height': 512,
172 | }
173 | )
174 |
--------------------------------------------------------------------------------
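Written out, the generator update above is the pix2pix objective from the paper, with the L1 term keeping G(x) close to the ground-truth target y (weighted here by opt.lambda_l1 = 100):

$$G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathbb{E}_{x, y}\big[\lVert y - G(x) \rVert_1\big], \qquad \lambda = 100$$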
/ralsgan/out/GAN output [Iteration 31616].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/ralsgan/out/GAN output [Iteration 31616].png
--------------------------------------------------------------------------------
/ralsgan/out/GAN output [Iteration 31680].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/ralsgan/out/GAN output [Iteration 31680].png
--------------------------------------------------------------------------------
/ralsgan/preprocess_cat_dataset.py:
--------------------------------------------------------------------------------
1 | ## Modified version of https://github.com/microe/angora-blue/blob/master/cascade_training/describe.py by Erik Hovland
2 | ##### I modified it to work with Python 3, changed the paths and made it output a folder for images bigger than 64x64 and a folder for images bigger than 128x128
3 |
4 | import cv2
5 | import glob
6 | import math
7 | import sys
8 |
9 | def rotateCoords(coords, center, angleRadians):
10 | # Positive y is down so reverse the angle, too.
11 | angleRadians = -angleRadians
12 | xs, ys = coords[::2], coords[1::2]
13 | newCoords = []
14 | n = min(len(xs), len(ys))
15 | i = 0
16 | centerX = center[0]
17 | centerY = center[1]
18 | cosAngle = math.cos(angleRadians)
19 | sinAngle = math.sin(angleRadians)
20 | while i < n:
21 | xOffset = xs[i] - centerX
22 | yOffset = ys[i] - centerY
23 | newX = xOffset * cosAngle - yOffset * sinAngle + centerX
24 | newY = xOffset * sinAngle + yOffset * cosAngle + centerY
25 | newCoords += [newX, newY]
26 | i += 1
27 | return newCoords
28 |
29 | def preprocessCatFace(coords, image):
30 |
31 | leftEyeX, leftEyeY = coords[0], coords[1]
32 | rightEyeX, rightEyeY = coords[2], coords[3]
33 | mouthX = coords[4]
34 | if leftEyeX > rightEyeX and leftEyeY < rightEyeY and \
35 | mouthX > rightEyeX:
36 | # The "right eye" is in the second quadrant of the face,
37 | # while the "left eye" is in the fourth quadrant (from the
38 | # viewer's perspective.) Swap the eyes' labels in order to
39 | # simplify the rotation logic.
40 | leftEyeX, rightEyeX = rightEyeX, leftEyeX
41 | leftEyeY, rightEyeY = rightEyeY, leftEyeY
42 |
43 | eyesCenter = (0.5 * (leftEyeX + rightEyeX),
44 | 0.5 * (leftEyeY + rightEyeY))
45 |
46 | eyesDeltaX = rightEyeX - leftEyeX
47 | eyesDeltaY = rightEyeY - leftEyeY
48 | eyesAngleRadians = math.atan2(eyesDeltaY, eyesDeltaX)
49 | eyesAngleDegrees = eyesAngleRadians * 180.0 / math.pi
50 |
51 | # Straighten the image and fill in gray for blank borders.
52 | rotation = cv2.getRotationMatrix2D(
53 | eyesCenter, eyesAngleDegrees, 1.0)
54 | imageSize = image.shape[1::-1]
55 | straight = cv2.warpAffine(image, rotation, imageSize,
56 | borderValue=(128, 128, 128))
57 |
58 | # Straighten the coordinates of the features.
59 | newCoords = rotateCoords(
60 | coords, eyesCenter, eyesAngleRadians)
61 |
62 | # Make the face as wide as the space between the ear bases.
63 | w = abs(newCoords[16] - newCoords[6])
64 | # Make the face square.
65 | h = w
66 | # Put the center point between the eyes at (0.5, 0.4) in
67 | # proportion to the entire face.
68 | minX = eyesCenter[0] - w/2
69 | if minX < 0:
70 | w += minX
71 | minX = 0
72 | minY = eyesCenter[1] - h*2/5
73 | if minY < 0:
74 | h += minY
75 | minY = 0
76 |
77 | # Crop the face.
78 | crop = straight[int(minY):int(minY+h), int(minX):int(minX+w)]
79 | # Return the crop.
80 | return crop
81 |
82 | def describePositive():
83 | output = open('log.txt', 'w')
84 | for imagePath in glob.glob('cat_dataset/*.jpg'):
85 | # Open the '.cat' annotation file associated with this
86 | # image.
87 | input = open('%s.cat' % imagePath, 'r')
88 | # Read the coordinates of the cat features from the
89 | # file. Discard the first number, which is the number
90 | # of features.
91 | coords = [int(i) for i in input.readline().split()[1:]]
92 | # Read the image.
93 | image = cv2.imread(imagePath)
94 | # Straighten and crop the cat face.
95 | crop = preprocessCatFace(coords, image)
96 | if crop is None:
97 | print('Failed to preprocess image at %s.' % imagePath,
98 | file=sys.stderr)
99 | # (Python 2's 'print >>' was a leftover; this file targets Python 3.)
100 | continue
101 | # Save the crop to folders based on size
102 | h, w, colors = crop.shape
103 | if min(h,w) >= 64:
104 | Path1 = imagePath.replace("cat_dataset","cats_bigger_than_64x64")
105 | resized_crop = cv2.resize(crop, (64, 64))
106 | cv2.imwrite(Path1, resized_crop)
107 | if min(h,w) >= 128:
108 | Path2 = imagePath.replace("cat_dataset","cats_bigger_than_128x128")
109 | resized_crop = cv2.resize(crop, (128, 128))
110 | cv2.imwrite(Path2, resized_crop)
111 | if min(h,w) >= 256:
112 | Path2 = imagePath.replace("cat_dataset","cats_bigger_than_256x256")
113 | resized_crop = cv2.resize(crop, (256, 256))
114 | cv2.imwrite(Path2, resized_crop)
115 | # Append the cropped face and its bounds to the
116 | # positive description.
117 | #h, w = crop.shape[:2]
118 | #print (cropPath, 1, 0, 0, w, h, file=output)
119 |
120 |
121 | def main():
122 | describePositive()
123 |
124 | if __name__ == '__main__':
125 | main()
--------------------------------------------------------------------------------
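The coordinate update in rotateCoords is the standard 2D rotation about the eye center (the code negates the angle because image y grows downward):

$$\begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x - c_x \\ y - c_y \end{pmatrix} + \begin{pmatrix} c_x \\ c_y \end{pmatrix}$$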
/ralsgan/ralsgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | To run this script, you must have the CATS dataset.
3 | I put the necessary code under this
4 | folder. Steps to follow (ref: @AlexiaJM's github)
5 | 1. Download: Cat Dataset (http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd)
6 | (I used a version uploaded by @simoninithomas, hence I have
7 | a slightly different version of the cats.)
8 | 2. Run setting_up_script.sh in the same folder as preprocess_cat_dataset.py
9 | and your CAT dataset (open and run it manually).
10 | 3. mv cats_bigger_than_256x256 "cats/0"
11 | (The ImageFolder class requires a class subfolder
12 | under the given folder.)
13 | '''
14 |
15 |
16 | '''
17 | Relativistic GAN paper (https://arxiv.org/abs/1807.00734)
18 | is very inventive and experimental, hence comes with a
19 | bunch of different versions of the same underlying
20 | idea. I used Ralsgan (relativistic average least squares)
21 | for demonstration as it seems to give the most promising
22 | results for the high-definition outputs with a single shot
23 | network.
24 | A really interesting behavior I observe with this model:
25 | around 5k iterations I start to see some "catlike" images,
26 | but before that the outputs look like pure noise for thousands
27 | of iterations, unlike the other GAN losses, where the
28 | quality improvement is more gradual. Also note that proper
29 | cat generations don't appear until tens of thousands of iterations.
30 | '''
31 |
32 | import argparse
33 | import numpy
34 |
35 | import torch
36 | from torch import nn, optim
37 | from torch.autograd.variable import Variable
38 |
39 | from torchvision import transforms, datasets
40 |
41 | import os
42 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
43 |
44 |
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--n_iters', type=int, default=100000, help='number of iterations of training') # (100e3 would stay a float: argparse does not coerce defaults.)
47 | parser.add_argument('--batch_size', type=int, default=48, help='size of the batches')
48 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate')
49 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
50 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
51 | parser.add_argument('--latent_dim', type=int, default=128, help='dimensionality of the latent space')
52 | parser.add_argument('--input_folder', type=str, default="cats", help='source images folder')
53 | parser.add_argument('--img_size', type=int, default=256, help='size of each image dimension')
54 | parser.add_argument('--channels', type=int, default=3, help='number of image channels')
55 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
56 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
57 | parser.add_argument('--sample_interval', type=int, default=64, help='interval between image samples')
58 | opt = parser.parse_args()
59 |
60 | def weights_init_normal(m):
61 | classname = m.__class__.__name__
62 | if classname.find('Conv') != -1:
63 | m.weight.data.normal_(0.0, 0.02)
64 | elif classname.find('BatchNorm') != -1:
65 | m.weight.data.normal_(1.0, 0.02)
66 | m.bias.data.fill_(0)
67 |
68 |
69 | try:
70 | import visdom
71 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
72 | except ImportError:
73 | vis = None
74 | else:
75 | vis.close(None) # Clear all figures.
76 |
77 |
78 | img_dims = (opt.channels, opt.img_size, opt.img_size)
79 | n_features = opt.channels * opt.img_size * opt.img_size
80 |
81 |
82 | # Appendix D.4 of the paper: DCGAN-style architecture.
83 | class Generator(nn.Module):
84 | def __init__(self):
85 | super(Generator, self).__init__()
86 |
87 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0):
88 | block = [
89 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),
90 | nn.BatchNorm2d(n_output),
91 | nn.ReLU(inplace=True),
92 | ]
93 | return block
94 |
95 | self.model = nn.Sequential(
96 | *convlayer(opt.latent_dim, 1024, 4, 1, 0), # Fully connected layer via convolution.
97 | *convlayer(1024, 512, 4, 2, 1),
98 | *convlayer(512, 256, 4, 2, 1),
99 | *convlayer(256, 128, 4, 2, 1),
100 | *convlayer(128, 64, 4, 2, 1),
101 | *convlayer(64, 32, 4, 2, 1),
102 | nn.ConvTranspose2d(32, opt.channels, 4, 2, 1),
103 | nn.Tanh()
104 | )
105 | '''
106 | There is a slight error in v2 of the relativistic gan paper, where
107 | the architecture goes from 128>64>32 but then 64>3.
108 | '''
109 |
110 |
111 | def forward(self, z):
112 | z = z.view(-1, opt.latent_dim, 1, 1)
113 | img = self.model(z)
114 | return img
115 |
116 | # PacGAN(2): pairs of images are channel-concatenated before the critic.
117 | class Discriminator(nn.Module):
118 | def __init__(self):
119 | super(Discriminator, self).__init__()
120 |
121 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, bn=False):
122 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
123 | if bn:
124 | block.append(nn.BatchNorm2d(n_output))
125 | block.append(nn.LeakyReLU(0.2, inplace=True))
126 | return block
127 |
128 | self.model = nn.Sequential(
129 | *convlayer(opt.channels * 2, 32, 4, 2, 1),
130 | *convlayer(32, 64, 4, 2, 1),
131 | *convlayer(64, 128, 4, 2, 1, bn=True),
132 | *convlayer(128, 256, 4, 2, 1, bn=True),
133 | *convlayer(256, 512, 4, 2, 1, bn=True),
134 | *convlayer(512, 1024, 4, 2, 1, bn=True),
135 | nn.Conv2d(1024, 1, 4, 1, 0, bias=False), # FC with Conv.
136 | )
137 |
138 | def forward(self, imgs):
139 | critic_value = self.model(imgs)
140 | critic_value = critic_value.view(imgs.size(0), -1)
141 | return critic_value
142 |
143 |
144 | '''
145 | The worst part of data science: data (prep).
146 | I shamelessly copy these from @AlexiaJM's code.
147 | (https://github.com/AlexiaJM/RelativisticGAN)
148 | '''
149 |
150 | transform = transforms.Compose([
151 | transforms.Resize((opt.img_size, opt.img_size)),
152 | transforms.ToTensor(),
153 | transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
154 | ])
155 | data = datasets.ImageFolder(root=os.path.join(os.getcwd(), opt.input_folder), transform=transform)
156 |
157 |
158 | def generate_random_sample():
159 | while True:
160 | random_indexes = numpy.random.choice(len(data), size=opt.batch_size * 2, replace=False)
161 | batch = [data[i][0] for i in random_indexes]
162 | yield torch.stack(batch, 0)
163 |
164 |
165 | random_sample = generate_random_sample()
166 |
167 | def mse_loss(input, target):
168 | return torch.sum((input - target)**2) / input.data.nelement()
169 |
170 | cuda = torch.cuda.is_available()
171 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
172 | gan_loss = mse_loss
173 |
174 | generator = Generator()
175 | discriminator = Discriminator()
176 |
177 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
178 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
179 |
180 | # Loss record.
181 | g_losses = []
182 | d_losses = []
183 | epochs = []
184 | loss_legend = ['Discriminator', 'Generator']
185 |
186 | if cuda:
187 | generator = generator.cuda()
188 | discriminator = discriminator.cuda()
189 |
190 | generator.apply(weights_init_normal)
191 | discriminator.apply(weights_init_normal)
192 |
193 | noise_fixed = Variable(Tensor(25, opt.latent_dim).normal_(0, 1), requires_grad=False)
194 |
195 | for it in range(int(opt.n_iters)):
196 | print('Iter. {}'.format(it))
197 |
198 | batch = next(random_sample)
199 |
200 | imgs_real = Variable(batch.type(Tensor))
201 | imgs_real = torch.cat((imgs_real[0:opt.batch_size, ...], imgs_real[opt.batch_size:opt.batch_size * 2, ...]), 1)
202 | real = Variable(Tensor(batch.size(0)//2, 1).fill_(1.0), requires_grad=False)
203 |
204 | noise = Variable(Tensor(opt.batch_size * 2, opt.latent_dim).normal_(0, 1))
205 | imgs_fake = generator(noise)
206 | imgs_fake = torch.cat((imgs_fake[0:opt.batch_size, ...], imgs_fake[opt.batch_size:opt.batch_size * 2, ...]), 1)
207 |
208 | # == Discriminator update == #
209 | optimizer_D.zero_grad()
210 |
211 | c_xr = discriminator(imgs_real)
212 | c_xf = discriminator(imgs_fake.detach())
213 |
214 | d_loss = gan_loss(c_xr, torch.mean(c_xf) + real) + gan_loss(c_xf, torch.mean(c_xr) - real)
215 |
216 | d_loss.backward()
217 | optimizer_D.step()
218 |
219 | # == Generator update == #
220 | batch = next(random_sample)
221 |
222 | imgs_real = Variable(batch.type(Tensor))
223 | imgs_real = torch.cat((imgs_real[0:opt.batch_size, ...], imgs_real[opt.batch_size:opt.batch_size * 2, ...]), 1)
224 |
225 | noise = Variable(Tensor(opt.batch_size * 2, opt.latent_dim).normal_(0, 1))
226 | imgs_fake = generator(noise)
227 | imgs_fake = torch.cat((imgs_fake[0:opt.batch_size, ...], imgs_fake[opt.batch_size:opt.batch_size * 2, ...]), 1)
228 |
229 | c_xr = discriminator(imgs_real)
230 | c_xf = discriminator(imgs_fake)
231 | real = Variable(Tensor(batch.size(0)//2, 1).fill_(1.0), requires_grad=False)
232 |
233 | optimizer_G.zero_grad()
234 |
235 | g_loss = gan_loss(c_xf, torch.mean(c_xr) + real) + gan_loss(c_xr, torch.mean(c_xf) - real)
236 |
237 | g_loss.backward()
238 | optimizer_G.step()
239 |
240 | if vis:
241 | if it % opt.sample_interval == 0:
242 |
243 | # Keep a record of losses for plotting.
244 | epochs.append(it)
245 | g_losses.append(g_loss.item())
246 | d_losses.append(d_loss.item())
247 |
248 | # Generate images for a given set of fixed noise
249 | # so we can track how the GAN learns.
250 | imgs_fake_fixed = generator(noise_fixed).detach().data
251 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom.
252 |
253 | # Display results on visdom page.
254 | vis.line(
255 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
256 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
257 | opts={
258 | 'title': 'loss over time',
259 | 'legend': loss_legend,
260 | 'xlabel': 'iteration',
261 | 'ylabel': 'loss',
262 | 'width': 512,
263 | 'height': 512
264 | },
265 | win=1)
266 | vis.images(
267 | imgs_fake_fixed,
268 | nrow=5, win=2,
269 | opts={
270 | 'title': 'GAN output [Iteration {}]'.format(it),
271 | 'width': 512,
272 | 'height': 512,
273 | }
274 | )
--------------------------------------------------------------------------------
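The two mse_loss terms in the loop above are the relativistic average least-squares (RaLSGAN) losses; writing $\bar{C}(\cdot)$ for the batch mean of the critic outputs:

$$L_D = \mathbb{E}_{x_r}\big[(C(x_r) - \bar{C}(x_f) - 1)^2\big] + \mathbb{E}_{x_f}\big[(C(x_f) - \bar{C}(x_r) + 1)^2\big]$$

$$L_G = \mathbb{E}_{x_f}\big[(C(x_f) - \bar{C}(x_r) - 1)^2\big] + \mathbb{E}_{x_r}\big[(C(x_r) - \bar{C}(x_f) + 1)^2\big]$$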
/ralsgan/setting_up_script.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # wget -nc http://www.simoninithomas.com/data/cats.zip
3 | unzip cats.zip -d cat_dataset
4 | # Move everything in one single directory
5 | mv cat_dataset/CAT_00/* cat_dataset
6 | rmdir cat_dataset/CAT_00
7 | mv cat_dataset/CAT_01/* cat_dataset
8 | rmdir cat_dataset/CAT_01
9 | mv cat_dataset/CAT_02/* cat_dataset
10 | rmdir cat_dataset/CAT_02
11 | mv cat_dataset/CAT_03/* cat_dataset
12 | rmdir cat_dataset/CAT_03
13 | mv cat_dataset/CAT_04/* cat_dataset
14 | rmdir cat_dataset/CAT_04
15 | mv cat_dataset/CAT_05/* cat_dataset
16 | rmdir cat_dataset/CAT_05
17 | mv cat_dataset/CAT_06/* cat_dataset
18 | rmdir cat_dataset/CAT_06
19 | ## Error correction
20 | rm cat_dataset/00000003_019.jpg.cat
21 | mv 00000003_015.jpg.cat cat_dataset/00000003_015.jpg.cat
22 | ## Removing outliers
23 | # Corrupted, drawings, badly cropped, inverted, impossible to tell it's a cat, blocked face
24 | cd cat_dataset
25 | rm 00000004_007.jpg 00000007_002.jpg 00000045_028.jpg 00000050_014.jpg 00000056_013.jpg 00000059_002.jpg 00000108_005.jpg 00000122_023.jpg 00000126_005.jpg 00000132_018.jpg 00000142_024.jpg 00000142_029.jpg 00000143_003.jpg 00000145_021.jpg 00000166_021.jpg 00000169_021.jpg 00000186_002.jpg 00000202_022.jpg 00000208_023.jpg 00000210_003.jpg 00000229_005.jpg 00000236_025.jpg 00000249_016.jpg 00000254_013.jpg 00000260_019.jpg 00000261_029.jpg 00000265_029.jpg 00000271_020.jpg 00000282_026.jpg 00000316_004.jpg 00000352_014.jpg 00000400_026.jpg 00000406_006.jpg 00000431_024.jpg 00000443_027.jpg 00000502_015.jpg 00000504_012.jpg 00000510_019.jpg 00000514_016.jpg 00000514_008.jpg 00000515_021.jpg 00000519_015.jpg 00000522_016.jpg 00000523_021.jpg 00000529_005.jpg 00000556_022.jpg 00000574_011.jpg 00000581_018.jpg 00000582_011.jpg 00000588_016.jpg 00000588_019.jpg 00000590_006.jpg 00000592_018.jpg 00000593_027.jpg 00000617_013.jpg 00000618_016.jpg 00000619_025.jpg 00000622_019.jpg 00000622_021.jpg 00000630_007.jpg 00000645_016.jpg 00000656_017.jpg 00000659_000.jpg 00000660_022.jpg 00000660_029.jpg 00000661_016.jpg 00000663_005.jpg 00000672_027.jpg 00000673_027.jpg 00000675_023.jpg 00000692_006.jpg 00000800_017.jpg 00000805_004.jpg 00000807_020.jpg 00000823_010.jpg 00000824_010.jpg 00000836_008.jpg 00000843_021.jpg 00000850_025.jpg 00000862_017.jpg 00000864_007.jpg 00000865_015.jpg 00000870_007.jpg 00000877_014.jpg 00000882_013.jpg 00000887_028.jpg 00000893_022.jpg 00000907_013.jpg 00000921_029.jpg 00000929_022.jpg 00000934_006.jpg 00000960_021.jpg 00000976_004.jpg 00000987_000.jpg 00000993_009.jpg 00001006_014.jpg 00001008_013.jpg 00001012_019.jpg 00001014_005.jpg 00001020_017.jpg 00001039_008.jpg 00001039_023.jpg 00001048_029.jpg 00001057_003.jpg 00001068_005.jpg 00001113_015.jpg 00001140_007.jpg 00001157_029.jpg 00001158_000.jpg 00001167_007.jpg 00001184_007.jpg 00001188_019.jpg 00001204_027.jpg 00001205_022.jpg 00001219_005.jpg 00001243_010.jpg 00001261_005.jpg 00001270_028.jpg 00001274_006.jpg 00001293_015.jpg 00001312_021.jpg 00001365_026.jpg 00001372_006.jpg 00001379_018.jpg 00001388_024.jpg 00001389_026.jpg 00001418_028.jpg 00001425_012.jpg 00001431_001.jpg 00001456_018.jpg 00001458_003.jpg 00001468_019.jpg 00001475_009.jpg 00001487_020.jpg
26 | rm 00000004_007.jpg.cat 00000007_002.jpg.cat 00000045_028.jpg.cat 00000050_014.jpg.cat 00000056_013.jpg.cat 00000059_002.jpg.cat 00000108_005.jpg.cat 00000122_023.jpg.cat 00000126_005.jpg.cat 00000132_018.jpg.cat 00000142_024.jpg.cat 00000142_029.jpg.cat 00000143_003.jpg.cat 00000145_021.jpg.cat 00000166_021.jpg.cat 00000169_021.jpg.cat 00000186_002.jpg.cat 00000202_022.jpg.cat 00000208_023.jpg.cat 00000210_003.jpg.cat 00000229_005.jpg.cat 00000236_025.jpg.cat 00000249_016.jpg.cat 00000254_013.jpg.cat 00000260_019.jpg.cat 00000261_029.jpg.cat 00000265_029.jpg.cat 00000271_020.jpg.cat 00000282_026.jpg.cat 00000316_004.jpg.cat 00000352_014.jpg.cat 00000400_026.jpg.cat 00000406_006.jpg.cat 00000431_024.jpg.cat 00000443_027.jpg.cat 00000502_015.jpg.cat 00000504_012.jpg.cat 00000510_019.jpg.cat 00000514_016.jpg.cat 00000514_008.jpg.cat 00000515_021.jpg.cat 00000519_015.jpg.cat 00000522_016.jpg.cat 00000523_021.jpg.cat 00000529_005.jpg.cat 00000556_022.jpg.cat 00000574_011.jpg.cat 00000581_018.jpg.cat 00000582_011.jpg.cat 00000588_016.jpg.cat 00000588_019.jpg.cat 00000590_006.jpg.cat 00000592_018.jpg.cat 00000593_027.jpg.cat 00000617_013.jpg.cat 00000618_016.jpg.cat 00000619_025.jpg.cat 00000622_019.jpg.cat 00000622_021.jpg.cat 00000630_007.jpg.cat 00000645_016.jpg.cat 00000656_017.jpg.cat 00000659_000.jpg.cat 00000660_022.jpg.cat 00000660_029.jpg.cat 00000661_016.jpg.cat 00000663_005.jpg.cat 00000672_027.jpg.cat 00000673_027.jpg.cat 00000675_023.jpg.cat 00000692_006.jpg.cat 00000800_017.jpg.cat 00000805_004.jpg.cat 00000807_020.jpg.cat 00000823_010.jpg.cat 00000824_010.jpg.cat 00000836_008.jpg.cat 00000843_021.jpg.cat 00000850_025.jpg.cat 00000862_017.jpg.cat 00000864_007.jpg.cat 00000865_015.jpg.cat 00000870_007.jpg.cat 00000877_014.jpg.cat 00000882_013.jpg.cat 00000887_028.jpg.cat 00000893_022.jpg.cat 00000907_013.jpg.cat 00000921_029.jpg.cat 00000929_022.jpg.cat 00000934_006.jpg.cat 00000960_021.jpg.cat 00000976_004.jpg.cat 00000987_000.jpg.cat 00000993_009.jpg.cat 00001006_014.jpg.cat 00001008_013.jpg.cat 00001012_019.jpg.cat 00001014_005.jpg.cat 00001020_017.jpg.cat 00001039_008.jpg.cat 00001039_023.jpg.cat 00001048_029.jpg.cat 00001057_003.jpg.cat 00001068_005.jpg.cat 00001113_015.jpg.cat 00001140_007.jpg.cat 00001157_029.jpg.cat 00001158_000.jpg.cat 00001167_007.jpg.cat 00001184_007.jpg.cat 00001188_019.jpg.cat 00001204_027.jpg.cat 00001205_022.jpg.cat 00001219_005.jpg.cat 00001243_010.jpg.cat 00001261_005.jpg.cat 00001270_028.jpg.cat 00001274_006.jpg.cat 00001293_015.jpg.cat 00001312_021.jpg.cat 00001365_026.jpg.cat 00001372_006.jpg.cat 00001379_018.jpg.cat 00001388_024.jpg.cat 00001389_026.jpg.cat 00001418_028.jpg.cat 00001425_012.jpg.cat 00001431_001.jpg.cat 00001456_018.jpg.cat 00001458_003.jpg.cat 00001468_019.jpg.cat 00001475_009.jpg.cat 00001487_020.jpg.cat
27 | cd ..
28 | ## Preprocessing and putting in folders for different image sizes
29 | mkdir cats_bigger_than_64x64
30 | mkdir cats_bigger_than_128x128
31 | mkdir cats_bigger_than_256x256
32 | python preprocess_cat_dataset.py
33 | ## Removing cat_dataset
34 | rm -r cat_dataset
35 | ## Move to your favorite place
36 | mv cats_bigger_than_256x256 cats
--------------------------------------------------------------------------------
/srgan/datasets.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torch.utils import data
3 | import glob
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 |
7 | class Dataset(data.Dataset):
8 | 'Characterizes a dataset for PyTorch'
9 | def __init__(self, data_folder, transform_lr, transform_hr, stage):
10 | 'Initialization'
11 | file_list = glob.glob('{}/*'.format(data_folder))
12 | n = len(file_list)
13 | train_size = int(np.floor(n * 0.8)) # np.int is deprecated; plain int is enough.
14 | self.images = file_list[:train_size] if stage == 'train' else file_list[train_size:] # 'is' compares identity, not equality.
15 |
16 | self.transform_lr = transforms.Compose(transform_lr)
17 | self.transform_hr = transforms.Compose(transform_hr)
18 |
19 | def __len__(self):
20 | 'Denotes the total number of samples'
21 | return len(self.images)
22 |
23 | def __getitem__(self, index):
24 | 'Generates one sample of data'
25 | # Select one sample.
26 |
27 | # Load the HR image; the LR version is derived from it via the transforms.
28 | hr = Image.open(self.images[index])
29 |
30 | return self.transform_lr(hr), self.transform_hr(hr)
--------------------------------------------------------------------------------
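Each sample is thus the same HR photo pushed through two resize pipelines. A minimal sketch with the sizes srgan.py uses (img_size = 256, so the LR input is 64x64), on a hypothetical blank image and without the Normalize step:

from PIL import Image
import torchvision.transforms as transforms

hr_img = Image.new('RGB', (1024, 1024))  # hypothetical source photo
to_lr = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
to_hr = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
print(to_lr(hr_img).shape, to_hr(hr_img).shape)
# torch.Size([3, 64, 64]) torch.Size([3, 256, 256])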
/srgan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision.models import vgg19
4 | from torch.autograd import Variable
5 |
6 |
7 | def weights_init_normal(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Conv') != -1:
10 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
11 | elif classname.find('BatchNorm2d') != -1:
12 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
13 | torch.nn.init.constant_(m.bias.data, 0.0)
14 |
15 | class ResidualBlock(nn.Module):
16 | def __init__(self, n_output=64, k_size=3, stride=1, padding=1):
17 | super(ResidualBlock, self).__init__()
18 | self.model = nn.Sequential(
19 | nn.Conv2d(n_output, n_output, k_size, stride, padding),
20 | nn.BatchNorm2d(n_output),
21 | nn.PReLU(),
22 | nn.Conv2d(n_output, n_output, k_size, stride, padding),
23 | nn.BatchNorm2d(n_output),
24 | )
25 |
26 | def forward(self, x):
27 | return x + self.model(x)
28 |
29 | class ShuffleBlock(nn.Module):
30 | def __init__(self, n_input, n_output, k_size=3, stride=1, padding=1):
31 | super(ShuffleBlock, self).__init__()
32 | self.model = nn.Sequential(
33 | nn.Conv2d(n_input, n_output, k_size, stride, padding), # N, 256, H, W
34 | nn.PixelShuffle(2), # N, 64, 2H, 2W
35 | nn.PReLU(),
36 | )
37 | '''
38 | Input: :math:`(N, C * upscale_factor^2, H, W)`
39 | Output: :math:`(N, C, H * upscale_factor, W * upscale_factor)`
40 | '''
41 |
42 | def forward(self, x):
43 | return self.model(x)
44 |
45 |
46 | # n_fmap = number of feature maps,
47 | # B = number of cascaded residual blocks.
48 | class Generator(nn.Module):
49 | def __init__(self, n_input=3, n_output=3, n_fmap=64, B=16):
50 | super(Generator, self).__init__()
51 |
52 | self.l1 = nn.Sequential(
53 | nn.Conv2d(n_input, n_fmap, 9, 1, 4),
54 | nn.PReLU(),
55 | )
56 |
57 | # A cascade of B residual blocks.
58 | self.R = []
59 | for _ in range(B):
60 | self.R.append(ResidualBlock(n_fmap))
61 | self.R = nn.Sequential(*self.R)
62 |
63 | self.l2 = nn.Sequential(
64 | nn.Conv2d(n_fmap, n_fmap, 3, 1, 1),
65 | nn.BatchNorm2d(n_fmap),
66 | )
67 |
68 | self.px = nn.Sequential(
69 | ShuffleBlock(64, 256),
70 | ShuffleBlock(64, 256),
71 | )
72 |
73 | self.conv_final = nn.Sequential(
74 | nn.Conv2d(64, n_output, 9, 1, 4),
75 | nn.Tanh(),
76 | )
77 |
78 |
79 | def forward(self, img_in):
80 | out_1 = self.l1(img_in)
81 | out_2 = self.R(out_1)
82 | out_3 = out_1 + self.l2(out_2)
83 | out_4 = self.px(out_3)
84 | return self.conv_final(out_4)
85 |
86 |
87 | class Discriminator(nn.Module):
88 | def __init__(self, lr_channels=3):
89 | super(Discriminator, self).__init__()
90 |
91 | def convblock(n_input, n_output, k_size=3, stride=1, padding=1, bn=True):
92 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
93 | if bn:
94 | block.append(nn.BatchNorm2d(n_output))
95 | block.append(nn.LeakyReLU(0.2, inplace=True))
96 | return block
97 |
98 | self.conv = nn.Sequential(
99 | *convblock(lr_channels, 64, 3, 1, 1, bn=False),
100 | *convblock(64, 64, 3, 2, 1),
101 | *convblock(64, 128, 3, 1, 1),
102 | *convblock(128, 128, 3, 2, 1),
103 | *convblock(128, 256, 3, 1, 1),
104 | *convblock(256, 256, 3, 2, 1),
105 | *convblock(256, 512, 3, 1, 1),
106 | *convblock(512, 512, 3, 2, 1),
107 | )
108 |
109 | self.fc = nn.Sequential(
110 | nn.Linear(512 * 16 * 16, 1024),
111 | nn.LeakyReLU(0.2, inplace=True),
112 | nn.Linear(1024, 1),
113 | nn.Sigmoid()
114 | )
115 |
116 | def forward(self, img):
117 | out_1 = self.conv(img)
118 | out_1 = out_1.view(img.size(0), -1)
119 | out_2 = self.fc(out_1)
120 | return out_2
121 |
122 |
123 | class VGGFeatures(nn.Module):
124 | def __init__(self):
125 | super(VGGFeatures, self).__init__()
126 | model = vgg19(pretrained=True)
127 |
128 | children = list(model.features.children())
129 | max_pool_indices = [index for index, m in enumerate(children) if isinstance(m, nn.MaxPool2d)]
130 | target_features = children[:max_pool_indices[4]]
131 | '''
132 | We use VGG-(5,4): the output of the 4th convolution in the
133 | 5th block, taken right before the 5th max-pool layer.
134 | '''
135 | self.features = nn.Sequential(*target_features)
136 | for p in self.features.parameters():
137 | p.requires_grad = False
138 |
139 | '''
140 | # VGG means and stdevs on pretrained imagenet
141 | mean = -1 + Variable(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
142 | std = 2*Variable(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
143 |
144 | # This is for cuda compatibility.
145 | self.register_buffer('mean', mean)
146 | self.register_buffer('std', std)
147 | '''
148 |
149 | def forward(self, input):
150 | # input = (input - self.mean) / self.std
151 | output = self.features(input)
152 | return output
--------------------------------------------------------------------------------
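A minimal sketch (separate from the repo code) of the sub-pixel upscaling the two ShuffleBlocks perform: each conv widens 64 to 256 channels, and PixelShuffle(2) trades those channels for a 2x larger feature map, so two blocks give the 4x super-resolution factor:

import torch
from torch import nn

up = nn.Sequential(nn.Conv2d(64, 256, 3, 1, 1), nn.PixelShuffle(2))
x = torch.randn(2, 64, 64, 64)
y = up(up(x))   # (N, 256, H, W) -> (N, 64, 2H, 2W), applied twice
print(y.shape)  # torch.Size([2, 64, 256, 256])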
/srgan/out/GAN output [Epoch 6] (1).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/srgan/out/GAN output [Epoch 6] (1).jpg
--------------------------------------------------------------------------------
/srgan/out/GAN output [Epoch 6].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/srgan/out/GAN output [Epoch 6].jpg
--------------------------------------------------------------------------------
/srgan/srgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | Literally, 90% of my time was spent
3 | looking for an HD dataset I could use
4 | for this GAN. It turns out not many
5 | such sets exist. If you happen to have
6 | a nice custom dataset of, say, size
7 | 1024x1024, it would fit this code
8 | perfectly. I settled for the "celeba
9 | align" dataset, which isn't particularly
10 | high definition, but memory constraints
11 | forced my hand. It is accessible at:
12 | https://drive.google.com/open?id=0B7EVK8r0v71pWEZsZE9oNnFzTm8
13 | '''
14 |
15 | import argparse
16 |
17 | import torch
18 | from torch import nn, optim
19 | from torch.autograd.variable import Variable
20 |
21 | from torchvision import transforms
22 | from torch.utils.data import DataLoader
23 |
24 | from models import Generator, Discriminator, VGGFeatures, weights_init_normal
25 | from datasets import Dataset
26 |
27 | from PIL import Image
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs of training')
31 | parser.add_argument('--batch_size', type=int, default=16, help='size of the batches')
32 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate')
33 | parser.add_argument('--b1', type=float, default=0.9, help='adam: decay of first order momentum of gradient')
34 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
35 | parser.add_argument('--img_size', type=int, default=256, help='high resolution image output size')
36 | parser.add_argument('--data_folder', type=str, default="img_align_celeba", help='where the training image pairs are located')
37 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
38 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
39 | parser.add_argument('--sample_interval', type=int, default=16, help='interval between image samples')
40 | opt = parser.parse_args()
41 |
42 |
43 |
44 | try:
45 | import visdom
46 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
47 | except ImportError:
48 | vis = None
49 | else:
50 | vis.close(None) # Clear all figures.
51 |
52 |
53 | # Celeba dataset loader.
54 | transform_image_hr = [
55 | transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
56 | transforms.ToTensor(),
57 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
58 | ]
59 | transform_image_lr = [
60 | transforms.Resize((opt.img_size//4, opt.img_size//4), Image.BICUBIC),
61 | transforms.ToTensor(),
62 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
63 | ]
64 | data_train = Dataset(opt.data_folder, transform_image_lr, transform_image_hr, stage='train')
65 | data_val = Dataset(opt.data_folder, transform_image_lr, transform_image_hr, stage='val')
66 |
67 | params = {'batch_size': opt.batch_size, 'shuffle': True}
68 | dataloader_train = DataLoader(data_train, **params)
69 |
70 | params = {'batch_size': 5, 'shuffle': True}
71 | dataloader_val = DataLoader(data_val, **params)
72 |
73 |
74 | cuda = torch.cuda.is_available()
75 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
76 |
77 | gan_loss = nn.BCELoss()
78 | content_loss = nn.MSELoss()
79 |
80 | generator = Generator()
81 | discriminator = Discriminator()
82 | vgg = VGGFeatures()
83 |
84 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
85 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
86 |
87 | # Loss record.
88 | g_losses = []
89 | d_losses = []
90 | epochs = []
91 | loss_legend = ['Discriminator', 'Generator']
92 |
93 | if cuda:
94 | generator = generator.cuda()
95 | discriminator = discriminator.cuda()
96 | vgg = vgg.cuda()
97 | gan_loss = gan_loss.cuda()
98 | content_loss = content_loss.cuda()
99 |
100 |
101 | generator.apply(weights_init_normal)
102 | discriminator.apply(weights_init_normal)
103 |
104 | for epoch in range(opt.n_epochs):
105 | print('Epoch {}'.format(epoch))
106 | for i, (batch_lr, batch_hr) in enumerate(dataloader_train):
107 |
108 | real = Variable(Tensor(batch_lr.size(0), 1).fill_(1), requires_grad=False)
109 | fake = Variable(Tensor(batch_lr.size(0), 1).fill_(0), requires_grad=False)
110 |
111 | imgs_real_lr = Variable(batch_lr.type(Tensor))
112 | imgs_real_hr = Variable(batch_hr.type(Tensor))
113 |
114 | # == Discriminator update == #
115 | optimizer_D.zero_grad()
116 |
117 | imgs_fake_hr = generator(imgs_real_lr).detach() # Detach so the discriminator update does not backprop into G.
118 |
119 | d_loss = gan_loss(discriminator(imgs_real_hr), real) + gan_loss(discriminator(imgs_fake_hr), fake)
120 |
121 | d_loss.backward()
122 | optimizer_D.step()
123 |
124 | # == Generator update == #
125 | imgs_fake_hr = generator(imgs_real_lr)
126 |
127 | optimizer_G.zero_grad()
128 |
129 | g_loss = (1/12.75) * content_loss(vgg(imgs_fake_hr), vgg(imgs_real_hr.detach())) + 1e-3 * gan_loss(discriminator(imgs_fake_hr), real)
130 |
131 | g_loss.backward()
132 | optimizer_G.step()
133 |
134 | if vis:
135 | batches_done = epoch * len(dataloader_train) + i
136 | if batches_done % opt.sample_interval == 0:
137 |
138 | # Keep a record of losses for plotting.
139 | epochs.append(epoch + i/len(dataloader_train))
140 | g_losses.append(g_loss.item())
141 | d_losses.append(d_loss.item())
142 |
143 |
144 |
145 | # Generate 5 images from the validation set.
146 | batch_lr, batch_hr = next(iter(dataloader_val))
147 |
148 | imgs_val_lr = Variable(batch_lr.type(Tensor))
149 | imgs_val_hr = Variable(batch_hr.type(Tensor))
150 | imgs_fake_hr = generator(imgs_val_lr).detach().data
151 |
152 | # For visualization purposes.
153 | imgs_val_lr = torch.nn.functional.interpolate(imgs_val_lr, size=(imgs_fake_hr.size(2), imgs_fake_hr.size(3)), mode='bilinear', align_corners=False) # F.upsample is deprecated.
154 |
155 | imgs_val_lr = imgs_val_lr.mul_(Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)).add_(Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
156 | imgs_val_hr = imgs_val_hr.mul_(Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)).add_(Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
157 | imgs_fake_hr = imgs_fake_hr.add_(torch.abs(torch.min(imgs_fake_hr))).div_(torch.max(imgs_fake_hr)-torch.min(imgs_fake_hr))
158 | fake_val = torch.cat((imgs_val_lr, imgs_val_hr, imgs_fake_hr), dim=2)
159 |
160 | # Display results on visdom page.
161 | vis.line(
162 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
163 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
164 | opts={
165 | 'title': 'loss over time',
166 | 'legend': loss_legend,
167 | 'xlabel': 'epoch',
168 | 'ylabel': 'loss',
169 | 'width': 512,
170 | 'height': 512
171 | },
172 | win=1)
173 | vis.images(
174 | fake_val,
175 | nrow=5, win=2,
176 | opts={
177 | 'title': 'GAN output [Epoch {}]'.format(epoch),
178 | 'width': 512,
179 | 'height': 512,
180 | }
181 | )
--------------------------------------------------------------------------------
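Schematically, the generator update above is the paper's perceptual loss: a VGG content term rescaled by 1/12.75 plus a 1e-3-weighted adversarial term (the BCE against the "real" label equals $-\log D(G(I^{LR}))$), with $\phi$ the VGG-(5,4) feature extractor and the squared error taken as a mean over feature elements:

$$l^{SR} = \frac{1}{12.75}\,\big\lVert \phi(G(I^{LR})) - \phi(I^{HR}) \big\rVert_2^2 \;+\; 10^{-3}\,\big(\!-\log D(G(I^{LR}))\big)$$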
/wgan-gp/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | # from torch.legacy.nn import Identity
3 |
4 | # Residual network.
5 | # WGAN-GP paper defines a residual block with up & downsampling.
6 | # See the official implementation (given in the paper).
7 | # I use architectures described in the official implementation,
8 | # since I find it hard to deduce the blocks given here from the text alone.
9 | class MeanPoolConv(nn.Module):
10 | def __init__(self, n_input, n_output, k_size):
11 | super(MeanPoolConv, self).__init__()
12 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True)
13 | self.model = nn.Sequential(conv1)
14 | def forward(self, x):
15 | out = (x[:,:,::2,::2] + x[:,:,1::2,::2] + x[:,:,::2,1::2] + x[:,:,1::2,1::2]) / 4.0
16 | out = self.model(out)
17 | return out
18 |
19 | class ConvMeanPool(nn.Module):
20 | def __init__(self, n_input, n_output, k_size):
21 | super(ConvMeanPool, self).__init__()
22 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True)
23 | self.model = nn.Sequential(conv1)
24 | def forward(self, x):
25 | out = self.model(x)
26 | out = (out[:,:,::2,::2] + out[:,:,1::2,::2] + out[:,:,::2,1::2] + out[:,:,1::2,1::2]) / 4.0
27 | return out
28 |
29 | class UpsampleConv(nn.Module):
30 | def __init__(self, n_input, n_output, k_size):
31 | super(UpsampleConv, self).__init__()
32 |
33 | self.model = nn.Sequential(
34 | nn.PixelShuffle(2),
35 | nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True)
36 | )
37 | def forward(self, x):
38 | x = x.repeat((1, 4, 1, 1)) # WGAN-GP's upsampling: replicate channels x4, then PixelShuffle(2) trades them for 2x spatial size.
39 | out = self.model(x)
40 | return out
41 |
42 | class ResidualBlock(nn.Module):
43 | def __init__(self, n_input, n_output, k_size, resample='up', bn=True, spatial_dim=None):
44 | super(ResidualBlock, self).__init__()
45 |
46 | self.resample = resample
47 |
48 | if resample == 'up':
49 | self.conv1 = UpsampleConv(n_input, n_output, k_size)
50 | self.conv2 = nn.Conv2d(n_output, n_output, k_size, padding=(k_size-1)//2)
51 | self.conv_shortcut = UpsampleConv(n_input, n_output, k_size)
52 | self.out_dim = n_output
53 | elif resample == 'down':
54 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2)
55 | self.conv2 = ConvMeanPool(n_input, n_output, k_size)
56 | self.conv_shortcut = ConvMeanPool(n_input, n_output, k_size)
57 | self.out_dim = n_output
58 | self.ln_dims = [n_input, spatial_dim, spatial_dim] # Define the dimensions for layer normalization.
59 | else:
60 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2)
61 | self.conv2 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2)
62 | self.conv_shortcut = None # Identity
63 | self.out_dim = n_input
64 | self.ln_dims = [n_input, spatial_dim, spatial_dim]
65 |
66 | self.model = nn.Sequential(
67 | nn.BatchNorm2d(n_input) if bn else nn.LayerNorm(self.ln_dims),
68 | nn.ReLU(inplace=True),
69 | self.conv1,
70 | nn.BatchNorm2d(self.out_dim) if bn else nn.LayerNorm(self.ln_dims),
71 | nn.ReLU(inplace=True),
72 | self.conv2,
73 | )
74 |
75 | def forward(self, x):
76 | if self.conv_shortcut is None:
77 | return x + self.model(x)
78 | else:
79 | return self.conv_shortcut(x) + self.model(x)
80 |
81 | class DiscBlock1(nn.Module):
82 | def __init__(self, n_output):
83 | super(DiscBlock1, self).__init__()
84 |
85 | self.conv1 = nn.Conv2d(3, n_output, 3, padding=(3-1)//2)
86 | self.conv2 = ConvMeanPool(n_output, n_output, 1)
87 | self.conv_shortcut = MeanPoolConv(3, n_output, 1)
88 |
89 | self.model = nn.Sequential(
90 | self.conv1,
91 | nn.ReLU(inplace=True),
92 | self.conv2
93 | )
94 |
95 | def forward(self, x):
96 | return self.conv_shortcut(x) + self.model(x)
97 |
98 | class Generator(nn.Module):
99 | def __init__(self):
100 | super(Generator, self).__init__()
101 |
102 | self.model = nn.Sequential( # 128 x 1 x 1
103 | nn.ConvTranspose2d(128, 128, 4, 1, 0), # 128 x 4 x 4
104 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 8 x 8
105 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 16 x 16
106 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 32 x 32
107 | nn.BatchNorm2d(128),
108 | nn.ReLU(inplace=True),
109 | nn.Conv2d(128, 3, 3, padding=(3-1)//2), # 3 x 32 x 32
110 | nn.Tanh()
111 | )
112 |
113 | def forward(self, z):
114 | img = self.model(z)
115 | return img
116 |
117 | class Discriminator(nn.Module):
118 | def __init__(self):
119 | super(Discriminator, self).__init__()
120 | n_output = 128
121 | '''
122 | This is a parameter but since we experiment with a single size
123 | of 3 x 32 x 32 images, it is hardcoded here.
124 | '''
125 |
126 | self.DiscBlock1 = DiscBlock1(n_output) # 128 x 16 x 16
127 |
128 | self.model = nn.Sequential(
129 | ResidualBlock(n_output, n_output, 3, resample='down', bn=False, spatial_dim=16), # 128 x 8 x 8
130 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8
131 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8
132 | nn.ReLU(inplace=True),
133 | )
134 | self.l1 = nn.Sequential(nn.Linear(128, 1)) # 128 x 1
135 |
136 | def forward(self, x):
137 | # x = x.view(-1, 3, 32, 32)
138 | y = self.DiscBlock1(x)
139 | y = self.model(y)
140 | y = y.view(x.size(0), 128, -1)
141 | y = y.mean(dim=2)
142 | out = self.l1(y).unsqueeze_(1).unsqueeze_(2) # Reshape the critic score to (N, 1, 1, 1).
143 | return out
144 |
--------------------------------------------------------------------------------
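A quick check (not in the repo) that the strided-slice average used by MeanPoolConv and ConvMeanPool is exactly 2x2 mean pooling:

import torch
from torch import nn

x = torch.randn(2, 3, 8, 8)
manual = (x[:, :, ::2, ::2] + x[:, :, 1::2, ::2]
          + x[:, :, ::2, 1::2] + x[:, :, 1::2, 1::2]) / 4.0
print(torch.allclose(manual, nn.AvgPool2d(2)(x)))  # True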
/wgan-gp/out/36737d5d4d1ebe.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wgan-gp/out/GAN output [Epoch 105].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan-gp/out/GAN output [Epoch 105].jpg
--------------------------------------------------------------------------------
/wgan-gp/out/GAN output [Epoch 52].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan-gp/out/GAN output [Epoch 52].jpg
--------------------------------------------------------------------------------
/wgan-gp/wgan_gp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torch import nn, optim
5 | from torch.autograd.variable import Variable
6 |
7 | from torchvision import transforms, datasets
8 | from torch.utils.data import DataLoader
9 |
10 | from models import Discriminator, Generator
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
14 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
15 | parser.add_argument('--alpha', type=float, default=0.0001, help='adam: learning rate')
16 | parser.add_argument('--b1', type=float, default=0.5, help='adam: beta 1')
17 | parser.add_argument('--b2', type=float, default=0.9, help='adam: beta 2')
18 | parser.add_argument('--n_critic', type=int, default=5, help='number of critic iterations per generator iteration')
19 | parser.add_argument('--lambda_1', type=int, default=10, help='gradient penalty coefficient')
20 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
21 | parser.add_argument('--channels', type=int, default=3, help='is image rgb (3) or grayscale (1) ?')
22 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
23 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
24 | parser.add_argument('--sample_interval', type=int, default=256, help='interval between image samples')
25 | opt = parser.parse_args()
26 |
27 | try:
28 | import visdom
29 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
30 | except ImportError:
31 | vis = None
32 | else:
33 | vis.close(None) # Clear all figures.
34 |
35 | img_dims = (opt.channels, opt.img_size, opt.img_size)
36 | n_features = opt.channels * opt.img_size * opt.img_size
37 |
38 | # TODO: Use some initialization in the future.
39 | def init_weights(m):
40 | if type(m) == nn.ConvTranspose2d:
41 |         torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
42 | elif type(m) == nn.Conv2d:
43 | torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
44 |
45 | def load_cifar10(img_size):
46 | compose = transforms.Compose(
47 | [transforms.Resize(img_size),
48 | transforms.ToTensor(),
49 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
50 | ])
51 | output_dir = './data/cifar10'
52 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True,
53 | transform=compose)
54 | return cifar
55 |
56 | cifar = load_cifar10(opt.img_size)
57 | batch_iterator = DataLoader(cifar, shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
58 |
59 | cuda = torch.cuda.is_available()
60 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
61 | gan_loss = nn.BCELoss()  # Unused: WGAN-GP optimizes raw critic scores, not a BCE objective.
62 |
63 | generator = Generator()
64 | discriminator = Discriminator()
65 |
66 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.alpha, betas=(opt.b1, opt.b2))
67 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.alpha, betas=(opt.b1, opt.b2))
68 |
69 | # Loss record.
70 | g_losses = []
71 | d_losses = []
72 | epochs = []
73 | loss_legend = ['Discriminator', 'Generator']
74 |
75 | if cuda:
76 | generator = generator.cuda()
77 | discriminator = discriminator.cuda()
78 |
79 | noise_fixed = Variable(Tensor(25, 128, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN.
80 |
81 | for epoch in range(opt.n_epochs):
82 | print('Epoch {}'.format(epoch))
83 | for i, (batch, _) in enumerate(batch_iterator):
84 | # == Discriminator update == #
85 |         for _ in range(opt.n_critic):
86 | # Sample real and fake images, using notation in paper.
87 | x = Variable(batch.type(Tensor))
88 | noise = Variable(Tensor(batch.size(0), 128, 1, 1).normal_(0, 1))
89 |             x_tilde = generator(noise).detach()  # Detach: no generator gradients are needed during the critic update.
90 |
91 | epsilon = Variable(Tensor(batch.size(0), 1, 1, 1).uniform_(0, 1))
92 |
93 | x_hat = epsilon*x + (1 - epsilon)*x_tilde
94 | x_hat = torch.autograd.Variable(x_hat, requires_grad=True)
95 |
96 | # Put the interpolated data through critic.
97 | dw_x = discriminator(x_hat)
98 |             # A great exercise for learning how autograd.grad works!
99 | grad_x = torch.autograd.grad(outputs=dw_x, inputs=x_hat,
100 | grad_outputs=Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1.0), requires_grad=False),
101 | create_graph=True, retain_graph=True, only_inputs=True)
102 | grad_x = grad_x[0].view(batch.size(0), -1)
103 | grad_x = grad_x.norm(p=2, dim=1) # My naming is inaccurate, this is the 2-norm of grad(D_w(x_hat))
104 |
105 | # Update discriminator (or critic, since we don't output probabilities anymore).
106 | optimizer_D.zero_grad()
107 |
108 |             # WGAN-GP objective, written directly as a loss to minimize (unlike the max formulation in the WGAN paper).
109 | d_loss = torch.mean(discriminator(x_tilde)) - torch.mean(discriminator(x)) + opt.lambda_1*torch.mean((grad_x - 1)**2)
110 |             # d_loss = torch.mean(d_loss)  # there's a reason why it shouldn't be done this way :)
111 |
112 | d_loss.backward()
113 | optimizer_D.step()
114 |
115 | # == Generator update == #
116 | noise = Variable(Tensor(batch.size(0), 128, 1, 1).normal_(0, 1))
117 | imgs_fake = generator(noise)
118 |
119 | optimizer_G.zero_grad()
120 |
121 | g_loss = -torch.mean(discriminator(imgs_fake))
122 |
123 | g_loss.backward()
124 | optimizer_G.step()
125 |
126 | if vis:
127 | batches_done = epoch * len(batch_iterator) + i
128 | if batches_done % opt.sample_interval == 0:
129 |
130 | # Keep a record of losses for plotting.
131 | epochs.append(epoch + i/len(batch_iterator))
132 | g_losses.append(g_loss.item())
133 | d_losses.append(d_loss.item())
134 |
135 | # Generate images for a given set of fixed noise
136 | # so we can track how the GAN learns.
137 | imgs_fake_fixed = generator(noise_fixed).detach().data
138 |                 imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2)  # Map [-1, 1] back to [0, 1] for display.
139 |
140 | # Display results on visdom page.
141 | vis.line(
142 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
143 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
144 | opts={
145 | 'title': 'loss over time',
146 | 'legend': loss_legend,
147 | 'xlabel': 'epoch',
148 | 'ylabel': 'loss',
149 | 'width': 512,
150 | 'height': 512
151 | },
152 | win=1)
153 | vis.images(
154 | imgs_fake_fixed[:25],
155 | nrow=5, win=2,
156 | opts={
157 | 'title': 'GAN output [Epoch {}]'.format(epoch),
158 | 'width': 512,
159 | 'height': 512,
160 | }
161 | )
--------------------------------------------------------------------------------
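The torch.autograd.grad call above is the heart of WGAN-GP. Here is a minimal, self-contained sketch of the same gradient-penalty computation on a toy critic (all names below are illustrative, not part of this repo):

import torch
from torch import nn

critic = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))  # toy critic
x_real, x_fake = torch.randn(32, 8), torch.randn(32, 8)

eps = torch.rand(32, 1)  # one interpolation factor per sample
x_hat = (eps * x_real + (1 - eps) * x_fake).requires_grad_(True)

scores = critic(x_hat)
grads, = torch.autograd.grad(outputs=scores, inputs=x_hat,
                             grad_outputs=torch.ones_like(scores),
                             create_graph=True)  # keep the graph so the penalty itself is differentiable
penalty = ((grads.view(32, -1).norm(2, dim=1) - 1) ** 2).mean()
print(penalty.item())  # in wgan_gp.py this term is scaled by lambda_1 and added to the critic loss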
/wgan/dataset_loader.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 |
3 | def load_cifar10(img_size):
4 | compose = transforms.Compose(
5 | [transforms.Resize(img_size),
6 | transforms.ToTensor(),
7 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
8 | ])
9 | output_dir = './data/cifar10'
10 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True,
11 | transform=compose)
12 | return cifar
13 |
14 | def load_mnist(img_size):
15 | compose = transforms.Compose(
16 | [transforms.Resize(img_size),
17 | transforms.ToTensor(),
18 |          transforms.Normalize((0.5,), (0.5,))  # MNIST is single-channel.
19 | ])
20 | output_dir = './data/mnist'
21 | return datasets.MNIST(root=output_dir, train=True,
22 | transform=compose, download=True)
23 |
--------------------------------------------------------------------------------
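Both loaders map pixel values into the [-1, 1] range the Tanh generators expect; a quick check (a sketch, run from the wgan directory):

from torch.utils.data import DataLoader
from dataset_loader import load_cifar10

batch, _ = next(iter(DataLoader(load_cifar10(32), batch_size=8)))
print(batch.shape, batch.min().item(), batch.max().item())  # torch.Size([8, 3, 32, 32]), close to -1 and 1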
/wgan/out/bn_GAN output [Epoch 99].jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan/out/bn_GAN output [Epoch 99].jpg
--------------------------------------------------------------------------------
/wgan/out/no_bn_366d9efab2beac.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wgan/out/no_bn_GAN output [Epoch 99] (1).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan/out/no_bn_GAN output [Epoch 99] (1).jpg
--------------------------------------------------------------------------------
/wgan/out/real_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan/out/real_samples.png
--------------------------------------------------------------------------------
/wgan/wgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | Wasserstein DCGAN.
3 |
4 | Q: Does lower loss indicate better images, i.e., does convergence
5 | really mean anything? (In regular GANs, it doesn't.)
6 | A: Kind of. A steadily decreasing loss is nice, but it is not always a one-to-one mapping.
7 | (Related: Many Paths to Equilibrium: GANs Do Not Need to Decrease a Divergence At Every Step)
8 |
9 | Q: Is it possible to train a stable network without batch normalization, as claimed?
10 | A: Sort of, though batch norm still makes things considerably faster. There is also a common
11 | practice of not using BN in the first layer of the discriminator, which I found not to matter.
12 | However, I noticed irregular behavior when BN is not used, such as sudden drops in the error.
13 | '''
14 |
15 |
16 | import argparse
17 | import os
18 |
19 | import torch
20 | from torch import nn, optim
21 | from torch.autograd.variable import Variable
22 |
23 | from torch.utils.data import DataLoader
24 | from torchvision.utils import save_image
25 |
26 | from dataset_loader import load_mnist, load_cifar10
27 |
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--dataset', default='cifar10', help='cifar10 | mnist')
30 | parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
31 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
32 | parser.add_argument('--lr', type=float, default=0.00005, help='adam: learning rate')
33 | parser.add_argument('--c', type=float, default=0.01, help='clipping value for the critic weights')
34 | parser.add_argument('--n_critic', type=int, default=5, help='number of critic updates per generator update')
35 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
36 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
37 | parser.add_argument('--channels', type=int, default=3, help='number of image channels')
38 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
39 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
40 | parser.add_argument('--sample_interval', type=int, default=256, help='interval between image samples')
41 | opt = parser.parse_args()
42 |
43 |
44 | try:
45 | import visdom
46 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
47 | except ImportError:
48 | vis = None
49 | else:
50 | vis.close(None) # Clear all figures.
51 |
52 | def weights_init(m):
53 | classname = m.__class__.__name__
54 | if classname.find('Conv') != -1:
55 | m.weight.data.normal_(0.0, 0.02)
56 | elif classname.find('BatchNorm') != -1:
57 | m.weight.data.normal_(1.0, 0.02)
58 | m.bias.data.fill_(0)
59 |
60 | img_dims = (opt.channels, opt.img_size, opt.img_size)
61 | n_features = opt.channels * opt.img_size * opt.img_size
62 |
63 | # DCGAN network.
64 | class Generator(nn.Module):
65 | def __init__(self):
66 | super(Generator, self).__init__()
67 |
68 | def convblock(n_input, n_output, k_size=4, stride=2, padding=0, normalize=True):
69 | block = [nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
70 | if normalize:
71 | block.append(nn.BatchNorm2d(n_output))
72 | block.append(nn.ReLU(inplace=True))
73 | return block
74 |
75 |         self.project = nn.Sequential(  # Unused: forward() feeds the (N, latent_dim, 1, 1) noise straight into self.model.
76 | nn.Linear(opt.latent_dim, 256 * 4 * 4),
77 | nn.BatchNorm1d(256 * 4 * 4),
78 | nn.ReLU(inplace=True),
79 | )
80 |
81 | self.model = nn.Sequential(
82 | *convblock(opt.latent_dim, 256, 4, 1, 0),
83 | *convblock(256, 128, 4, 2, 1),
84 | *convblock(128, 64, 4, 2, 1),
85 | nn.ConvTranspose2d(64, opt.channels, 4, 2, 1),
86 | nn.Tanh()
87 | )
88 |         # Tanh: image values are in [-1, 1].
89 |
90 | def forward(self, z):
91 | img = self.model(z)
92 | img = img.view(z.size(0), *img_dims)
93 | return img
94 |
95 |
96 | class Discriminator(nn.Module):
97 | def __init__(self):
98 | super(Discriminator, self).__init__()
99 |
100 | def convblock(n_input, n_output, kernel_size=4, stride=2, padding=1, normalize=True):
101 | block = [nn.Conv2d(n_input, n_output, kernel_size, stride, padding, bias=False)]
102 | if normalize:
103 | block.append(nn.BatchNorm2d(n_output))
104 | block.append(nn.LeakyReLU(0.2, inplace=True))
105 | return block
106 |
107 | self.model = nn.Sequential(
108 |             *convblock(opt.channels, 64, 4, 2, 1, normalize=False),  # 32 -> 16: (32 + 2p - 4)/2 + 1 = 16 with p = 1. BN in the first critic layer is reportedly detrimental for WGANs.
109 | *convblock(64, 128, 4, 2, 1),
110 | *convblock(128, 256, 4, 2, 1),
111 | nn.Conv2d(256, 1, 4, 1, 0, bias=False), # FC with Conv.
112 |             # No Sigmoid: the WGAN critic outputs an unbounded score, not a probability.
113 | )
114 |
115 | def forward(self, img):
116 |         score = self.model(img)
117 |         return score
118 |
119 |
120 | assert (opt.dataset == 'cifar10' or opt.dataset == 'mnist'), 'Unknown dataset! Only cifar10 and mnist are supported.'
121 |
122 | if opt.dataset == 'cifar10':
123 | batch_iterator = DataLoader(load_cifar10(opt.img_size), shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
124 | elif opt.dataset == 'mnist':
125 | batch_iterator = DataLoader(load_mnist(opt.img_size), shuffle=True, batch_size=opt.batch_size) # List, NCHW format.
126 |
127 | # Save a batch of real images for reference.
128 | os.makedirs('./out', exist_ok=True)
129 | save_image(next(iter(batch_iterator))[0][:25, ...], './out/real_samples.png', nrow=5, normalize=True)
130 |
131 |
132 | cuda = torch.cuda.is_available()
133 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
134 | gan_loss = nn.BCELoss()  # Unused: the WGAN losses below are plain means of critic scores.
135 |
136 | generator = Generator()
137 | discriminator = Discriminator()
138 |
139 | optimizer_D = optim.RMSprop(discriminator.parameters(), lr=opt.lr)
140 | optimizer_G = optim.RMSprop(generator.parameters(), lr=opt.lr)
141 |
142 | # Loss record.
143 | g_losses = []
144 | d_losses = []
145 | epochs = []
146 | loss_legend = ['Discriminator', 'Generator']
147 |
148 | if cuda:
149 | generator = generator.cuda()
150 | discriminator = discriminator.cuda()
151 |
152 | generator.apply(weights_init)
153 | discriminator.apply(weights_init)
154 |
155 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN.
156 |
157 | for epoch in range(opt.n_epochs):
158 | print('Epoch {}'.format(epoch))
159 | for i, (batch, _) in enumerate(batch_iterator):
160 |
161 | # == Discriminator update == #
162 |         for _ in range(opt.n_critic):
163 | # Sample real and fake images.
164 | imgs_real = Variable(batch.type(Tensor))
165 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1))
166 | imgs_fake = Variable(generator(noise), requires_grad=False)
167 |
168 | # Update discriminator (or critic, since we don't output probabilities anymore).
169 | optimizer_D.zero_grad()
170 |
171 |             # WGAN utility; we ascend on it, so the loss is its negative.
172 | d_loss = -torch.mean(discriminator(imgs_real) - discriminator(imgs_fake))
173 |
174 | d_loss.backward()
175 | optimizer_D.step()
176 |
177 |             # Crude weight clipping of the critic to enforce the Lipschitz constraint,
178 |             # which WGAN-GP replaces with a gradient penalty a few months after this paper.
179 | for params in discriminator.parameters():
180 | params.data.clamp_(-opt.c, +opt.c)
181 |
182 | # == Generator update == #
183 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1))
184 | imgs_fake = generator(noise)
185 |
186 | optimizer_G.zero_grad()
187 |
188 | g_loss = -torch.mean(discriminator(imgs_fake))
189 |
190 | g_loss.backward()
191 | optimizer_G.step()
192 |
193 | if vis:
194 | batches_done = epoch * len(batch_iterator) + i
195 | if batches_done % opt.sample_interval == 0:
196 |
197 | # Keep a record of losses for plotting.
198 | epochs.append(epoch + i/len(batch_iterator))
199 |                 g_losses.append(-g_loss.item())  # Negated so the plot tracks the quantities WGAN maximizes (the critic scores).
200 | d_losses.append(-d_loss.item())
201 |
202 | # Generate images for a given set of fixed noise
203 | # so we can track how the GAN learns.
204 | imgs_fake_fixed = generator(noise_fixed).detach().data
205 |                 imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2)  # Map [-1, 1] back to [0, 1] for display on visdom.
206 |
207 | # Display results on visdom page.
208 | vis.line(
209 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1),
210 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1),
211 | opts={
212 | 'title': 'loss over time',
213 | 'legend': loss_legend,
214 | 'xlabel': 'epoch',
215 | 'ylabel': 'loss',
216 | 'width': 512,
217 | 'height': 512
218 | },
219 | win=1)
220 | vis.images(
221 | imgs_fake_fixed,
222 | nrow=5, win=2,
223 | opts={
224 | 'title': 'GAN output [Epoch {}]'.format(epoch),
225 | 'width': 512,
226 | 'height': 512,
227 | }
228 | )
--------------------------------------------------------------------------------