├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── images
│   ├── CIFAR-10.png
│   ├── Fashion-MNIST.png
│   ├── MNIST.png
│   ├── inception_graph_generator_iters.png
│   ├── inception_graph_time.png
│   ├── latent-mnist.png
│   └── latent_fashion.png
├── main.py
├── models
│   ├── __init__.py
│   ├── dcgan.py
│   ├── gan.py
│   ├── wgan_clipping.py
│   └── wgan_gradient_penalty.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── config.py
    ├── data_loader.py
    ├── fashion_mnist.py
    ├── feature_extraction_test.py
    ├── inception_score.py
    └── tensorboard_logger.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .python-version
2 | **__pycache__**
3 | *.pkl
4 | logs/*
5 | datasets/*
6 | training_result_images/
7 | inception_score_graph.txt
8 | .vscode/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at filip.zelic@protonmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Green9
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PyTorch code for GAN models
2 | This is a PyTorch implementation of three different GAN models that share the same convolutional architecture; a rough summary of how their objectives differ follows the list below.
3 |
4 |
5 | - DCGAN (Deep convolutional GAN)
6 | - WGAN-CP (Wasserstein GAN using weight clipping)
7 | - WGAN-GP (Wasserstein GAN using gradient penalty)
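As a quick reference (these are the standard formulations from the papers linked under *Useful Resources* below, with $D$ the discriminator/critic and $G$ the generator, not a description of this code's implementation details):

$$\text{DCGAN:}\quad \min_G \max_D\; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z}[\log(1 - D(G(z)))]$$

$$\text{WGAN-CP:}\quad \min_G \max_{D}\; \mathbb{E}_{x \sim p_{\text{data}}}[D(x)] - \mathbb{E}_{z}[D(G(z))], \quad \text{critic weights clipped to } [-c,\, c]$$

$$\text{WGAN-GP:}\quad \text{the same critic objective, plus the penalty } \lambda\, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$

In this code, $c = 0.01$ (`weight_clipping_limit`) and $\lambda = 10$ (`lambda_term`).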
8 |
9 |
10 |
11 | ## Dependencies
12 | The prominent packages are:
13 |
14 | * numpy
15 | * scikit-learn
16 | * tensorflow 2.5.0
17 | * pytorch 1.8.1
18 | * torchvision 0.9.1
19 |
20 | To install all the dependencies quickly and easily, use __pip__:
21 |
22 | ```bash
23 | pip install -r requirements.txt
24 | ```
25 |
26 |
27 |
28 | *Training*
29 | ---
30 | Running training of the DCGAN model on the Fashion-MNIST dataset:
31 |
32 |
33 | ```
34 | python main.py --model DCGAN \
35 | --is_train True \
36 | --download True \
37 | --dataroot datasets/fashion-mnist \
38 | --dataset fashion-mnist \
39 | --epochs 30 \
40 | --cuda True \
41 | --batch_size 64
42 | ```
43 |
44 | Running training of the WGAN-GP model on the CIFAR-10 dataset:
45 |
46 | ```
47 | python main.py --model WGAN-GP \
48 | --is_train True \
49 | --download True \
50 | --dataroot datasets/cifar \
51 | --dataset cifar \
52 | --generator_iters 40000 \
53 | --cuda True \
54 | --batch_size 64
55 | ```
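To evaluate a trained model instead of training one, a sketch based on the flags consumed in `main.py` (training writes the checkpoints `./generator.pkl` and `./discriminator.pkl`, which `--load_G`/`--load_D` point back at; with `--is_train False`, `main.py` also writes a series of latent-walk interpolation images after evaluation):

```
python main.py --model WGAN-GP \
               --is_train False \
               --dataroot datasets/cifar \
               --dataset cifar \
               --load_D ./discriminator.pkl \
               --load_G ./generator.pkl
```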
56 |
57 | Start tensorboard:
58 |
59 | ```
60 | tensorboard --logdir ./logs/
61 | ```
62 |
63 | *Walk in latent space*
64 | ---
65 | *Interpolation between two random latent vectors z over 10 points shows that the generated samples have smooth transitions.*
66 |
67 | ![Latent space walk on MNIST](images/latent-mnist.png)
68 |
69 | ![Latent space walk on Fashion-MNIST](images/latent_fashion.png)
70 |
71 |
72 |
73 |
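A minimal sketch of the interpolation itself (assuming a trained generator `G` with the `(N, 100, 1, 1)` latent input used by the models in this repo; the variable names here are illustrative):

```python
import torch

# Two random endpoints in latent space, shaped like the models' z input.
z1, z2 = torch.randn(1, 100, 1, 1), torch.randn(1, 100, 1, 1)
steps = 10

with torch.no_grad():
    for i in range(1, steps + 1):
        alpha = i / (steps + 1)              # interpolation weight in (0, 1)
        z = alpha * z1 + (1.0 - alpha) * z2  # point on the line between z2 and z1
        image = G(z).mul(0.5).add(0.5)       # generate, then denormalize to [0, 1]
```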
74 | *Generated examples: MNIST, Fashion-MNIST, CIFAR-10*
75 | ---
76 |
77 | ![Generated MNIST samples](images/MNIST.png)
78 |
79 | ![Generated Fashion-MNIST samples](images/Fashion-MNIST.png)
80 |
81 | ![Generated CIFAR-10 samples](images/CIFAR-10.png)
82 |
83 |
84 |
85 |
86 | *Inception score*
87 | ---
88 | [About Inception score](https://arxiv.org/pdf/1801.01973.pdf)
89 |
90 | ![Inception score vs. generator iterations](images/inception_graph_generator_iters.png)
91 |
92 | ![Inception score vs. training time](images/inception_graph_time.png)
93 |
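As a rough sketch of what the score measures (the standard formula, not necessarily the exact implementation in `utils/inception_score.py`): given Inception softmax outputs p(y|x) for the generated images, the score is exp(E_x KL(p(y|x) || p(y))), so it rewards confident per-image predictions together with a diverse marginal class distribution.

```python
import numpy as np

def inception_score_from_probs(probs, eps=1e-12):
    """probs: (N, num_classes) softmax outputs of an Inception network
    over N generated images. Returns exp(E_x KL(p(y|x) || p(y)))."""
    p_y = probs.mean(axis=0)  # marginal class distribution p(y)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))
```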
94 | *Useful Resources*
95 | ---
96 |
97 |
98 | - [WGAN reddit thread](https://www.reddit.com/r/MachineLearning/comments/5qxoaz/r_170107875_wasserstein_gan/)
99 | - [Blogpost](https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html)
100 | - [Deconvolution and Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/)
101 | - [WGAN-CP paper](https://arxiv.org/pdf/1701.07875.pdf)
102 | - [WGAN-GP paper](https://arxiv.org/pdf/1704.00028.pdf)
103 | - [DCGAN paper](https://arxiv.org/pdf/1511.06434.pdf)
104 | - [Working remotely with PyCharm and SSH](https://medium.com/@erikhallstrm/work-remotely-with-pycharm-tensorflow-and-ssh-c60564be862d)
105 |
--------------------------------------------------------------------------------
/images/CIFAR-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/CIFAR-10.png
--------------------------------------------------------------------------------
/images/Fashion-MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/Fashion-MNIST.png
--------------------------------------------------------------------------------
/images/MNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/MNIST.png
--------------------------------------------------------------------------------
/images/inception_graph_generator_iters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/inception_graph_generator_iters.png
--------------------------------------------------------------------------------
/images/inception_graph_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/inception_graph_time.png
--------------------------------------------------------------------------------
/images/latent-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/latent-mnist.png
--------------------------------------------------------------------------------
/images/latent_fashion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/latent_fashion.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from utils.config import parse_args
2 | from utils.data_loader import get_data_loader
3 |
4 | from models.gan import GAN
5 | from models.dcgan import DCGAN_MODEL
6 | from models.wgan_clipping import WGAN_CP
7 | from models.wgan_gradient_penalty import WGAN_GP
8 |
9 |
10 | def main(args):
11 | model = None
12 | if args.model == 'GAN':
13 | model = GAN(args)
14 | elif args.model == 'DCGAN':
15 | model = DCGAN_MODEL(args)
16 | elif args.model == 'WGAN-CP':
17 | model = WGAN_CP(args)
18 | elif args.model == 'WGAN-GP':
19 | model = WGAN_GP(args)
20 | else:
21 |         print("Unknown model type. Try again.")
22 | exit(-1)
23 |
24 | # Load datasets to train and test loaders
25 | train_loader, test_loader = get_data_loader(args)
26 | #feature_extraction = FeatureExtractionTest(train_loader, test_loader, args.cuda, args.batch_size)
27 |
28 | # Start model training
29 | if args.is_train == 'True':
30 | model.train(train_loader)
31 |
32 | # start evaluating on test data
33 | else:
34 | model.evaluate(test_loader, args.load_D, args.load_G)
35 | for i in range(50):
36 | model.generate_latent_walk(i)
37 |
38 |
39 | if __name__ == '__main__':
40 | args = parse_args()
41 | print(args.cuda)
42 | main(args)
43 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/models/__init__.py
--------------------------------------------------------------------------------
/models/dcgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import time as t
5 | import os
6 | from utils.tensorboard_logger import Logger
7 | from utils.inception_score import get_inception_score
8 | from itertools import chain
9 | from torchvision import utils
10 |
11 | class Generator(torch.nn.Module):
12 | def __init__(self, channels):
13 | super().__init__()
14 | # Filters [1024, 512, 256]
15 | # Input_dim = 100
16 | # Output_dim = C (number of channels)
17 | self.main_module = nn.Sequential(
18 | # Z latent vector 100
19 | nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0),
20 | nn.BatchNorm2d(num_features=1024),
21 | nn.ReLU(True),
22 |
23 | # State (1024x4x4)
24 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
25 | nn.BatchNorm2d(num_features=512),
26 | nn.ReLU(True),
27 |
28 | # State (512x8x8)
29 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
30 | nn.BatchNorm2d(num_features=256),
31 | nn.ReLU(True),
32 |
33 | # State (256x16x16)
34 | nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1))
35 | # output of main module --> Image (Cx32x32)
36 |
37 | self.output = nn.Tanh()
38 |
39 | def forward(self, x):
40 | x = self.main_module(x)
41 | return self.output(x)
42 |
43 |
44 | class Discriminator(torch.nn.Module):
45 | def __init__(self, channels):
46 | super().__init__()
47 | # Filters [256, 512, 1024]
48 |         # Input_dim = channels (Cx32x32)
49 | # Output_dim = 1
50 | self.main_module = nn.Sequential(
51 | # Image (Cx32x32)
52 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
53 | nn.LeakyReLU(0.2, inplace=True),
54 |
55 | # State (256x16x16)
56 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
57 | nn.BatchNorm2d(512),
58 | nn.LeakyReLU(0.2, inplace=True),
59 |
60 | # State (512x8x8)
61 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
62 | nn.BatchNorm2d(1024),
63 | nn.LeakyReLU(0.2, inplace=True))
64 |             # output of main module --> State (1024x4x4)
65 |
66 | self.output = nn.Sequential(
67 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0),
68 | # Output 1
69 | nn.Sigmoid())
70 |
71 | def forward(self, x):
72 | x = self.main_module(x)
73 | return self.output(x)
74 |
75 | def feature_extraction(self, x):
76 | # Use discriminator for feature extraction then flatten to vector of 16384 features
77 | x = self.main_module(x)
78 | return x.view(-1, 1024*4*4)
79 |
80 | class DCGAN_MODEL(object):
81 | def __init__(self, args):
82 |         print("DCGAN model initialization.")
83 | self.G = Generator(args.channels)
84 | self.D = Discriminator(args.channels)
85 | self.C = args.channels
86 |
87 | # binary cross entropy loss and optimizer
88 | self.loss = nn.BCELoss()
89 |
90 | self.cuda = False
91 | self.cuda_index = 0
92 | # check if cuda is available
93 | self.check_cuda(args.cuda)
94 |
95 |         # Using a lower learning rate than the Adam default: lr=0.0002 and beta_1 = 0.5 instead of 0.9 works better [Radford2015]
96 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
97 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
98 |
99 | self.epochs = args.epochs
100 | self.batch_size = args.batch_size
101 |
102 | # Set the logger
103 | self.logger = Logger('./logs')
104 | self.number_of_images = 10
105 |
106 | # cuda support
107 | def check_cuda(self, cuda_flag=False):
108 | if cuda_flag:
109 | self.cuda = True
110 | self.D.cuda(self.cuda_index)
111 | self.G.cuda(self.cuda_index)
112 | self.loss = nn.BCELoss().cuda(self.cuda_index)
113 | print("Cuda enabled flag: ")
114 | print(self.cuda)
115 |
116 |
117 | def train(self, train_loader):
118 | self.t_begin = t.time()
119 | generator_iter = 0
120 | #self.file = open("inception_score_graph.txt", "w")
121 |
122 | for epoch in range(self.epochs):
123 | self.epoch_start_time = t.time()
124 |
125 | for i, (images, _) in enumerate(train_loader):
126 |                 # Stop the epoch at the last full batch
127 | if i == train_loader.dataset.__len__() // self.batch_size:
128 | break
129 |
130 | z = torch.rand((self.batch_size, 100, 1, 1))
131 | real_labels = torch.ones(self.batch_size)
132 | fake_labels = torch.zeros(self.batch_size)
133 |
134 | if self.cuda:
135 | images, z = Variable(images).cuda(self.cuda_index), Variable(z).cuda(self.cuda_index)
136 | real_labels, fake_labels = Variable(real_labels).cuda(self.cuda_index), Variable(fake_labels).cuda(self.cuda_index)
137 | else:
138 | images, z = Variable(images), Variable(z)
139 | real_labels, fake_labels = Variable(real_labels), Variable(fake_labels)
140 |
141 |
142 | # Train discriminator
143 | # Compute BCE_Loss using real images
144 | outputs = self.D(images)
145 | d_loss_real = self.loss(outputs.flatten(), real_labels)
146 | real_score = outputs
147 |
148 | # Compute BCE Loss using fake images
149 | if self.cuda:
150 | z = Variable(torch.randn(self.batch_size, 100, 1, 1)).cuda(self.cuda_index)
151 | else:
152 | z = Variable(torch.randn(self.batch_size, 100, 1, 1))
153 | fake_images = self.G(z)
154 | outputs = self.D(fake_images)
155 | d_loss_fake = self.loss(outputs.flatten(), fake_labels)
156 | fake_score = outputs
157 |
158 | # Optimize discriminator
159 | d_loss = d_loss_real + d_loss_fake
160 | self.D.zero_grad()
161 | d_loss.backward()
162 | self.d_optimizer.step()
163 |
164 | # Train generator
165 | # Compute loss with fake images
166 | if self.cuda:
167 | z = Variable(torch.randn(self.batch_size, 100, 1, 1)).cuda(self.cuda_index)
168 | else:
169 | z = Variable(torch.randn(self.batch_size, 100, 1, 1))
170 | fake_images = self.G(z)
171 | outputs = self.D(fake_images)
172 | g_loss = self.loss(outputs.flatten(), real_labels)
173 |
174 | # Optimize generator
175 | self.D.zero_grad()
176 | self.G.zero_grad()
177 | g_loss.backward()
178 | self.g_optimizer.step()
179 | generator_iter += 1
180 |
181 |
182 | if generator_iter % 1000 == 0:
183 |                     # Workaround because the GPU can't hold more than ~800 examples in memory when generating images
184 |                     # Therefore we loop, generating 800 examples at a time, and stack them into a list to get 8000 generated images
185 |                     # This way the Inception score is more accurate since there are generated examples from every class of the Inception model
186 | # sample_list = []
187 | # for i in range(10):
188 | # z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index)
189 | # samples = self.G(z)
190 | # sample_list.append(samples.data.cpu().numpy())
191 | #
192 | # # Flattening list of lists into one list of numpy arrays
193 | # new_sample_list = list(chain.from_iterable(sample_list))
194 | # print("Calculating Inception Score over 8k generated images")
195 | # # Feeding list of numpy arrays
196 | # inception_score = get_inception_score(new_sample_list, cuda=True, batch_size=32,
197 | # resize=True, splits=10)
198 | print('Epoch-{}'.format(epoch + 1))
199 | self.save_model()
200 |
201 | if not os.path.exists('training_result_images/'):
202 | os.makedirs('training_result_images/')
203 |
204 | # Denormalize images and save them in grid 8x8
205 |                     z = Variable(torch.randn(800, 100, 1, 1).cuda(self.cuda_index) if self.cuda else torch.randn(800, 100, 1, 1))
206 | samples = self.G(z)
207 | samples = samples.mul(0.5).add(0.5)
208 | samples = samples.data.cpu()[:64]
209 | grid = utils.make_grid(samples)
210 |                     utils.save_image(grid, 'training_result_images/img_generator_iter_{}.png'.format(str(generator_iter).zfill(3)))
211 |
212 | time = t.time() - self.t_begin
213 | #print("Inception score: {}".format(inception_score))
214 | print("Generator iter: {}".format(generator_iter))
215 | print("Time {}".format(time))
216 |
217 | # Write to file inception_score, gen_iters, time
218 | #output = str(generator_iter) + " " + str(time) + " " + str(inception_score[0]) + "\n"
219 | #self.file.write(output)
220 |
221 |
222 | if ((i + 1) % 100) == 0:
223 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
224 | ((epoch + 1), (i + 1), train_loader.dataset.__len__() // self.batch_size, d_loss.data, g_loss.data))
225 |
226 |                     z = Variable(torch.randn(self.batch_size, 100, 1, 1).cuda(self.cuda_index) if self.cuda else torch.randn(self.batch_size, 100, 1, 1))
227 |
228 | # TensorBoard logging
229 | # Log the scalar values
230 | info = {
231 | 'd_loss': d_loss.data,
232 | 'g_loss': g_loss.data
233 | }
234 |
235 | for tag, value in info.items():
236 | self.logger.scalar_summary(tag, value, generator_iter)
237 |
238 | # Log values and gradients of the parameters
239 | for tag, value in self.D.named_parameters():
240 | tag = tag.replace('.', '/')
241 | self.logger.histo_summary(tag, self.to_np(value), generator_iter)
242 | self.logger.histo_summary(tag + '/grad', self.to_np(value.grad), generator_iter)
243 |
244 | # Log the images while training
245 | info = {
246 | 'real_images': self.real_images(images, self.number_of_images),
247 | 'generated_images': self.generate_img(z, self.number_of_images)
248 | }
249 |
250 | for tag, images in info.items():
251 | self.logger.image_summary(tag, images, generator_iter)
252 |
253 |
254 | self.t_end = t.time()
255 | print('Time of training-{}'.format((self.t_end - self.t_begin)))
256 | #self.file.close()
257 |
258 | # Save the trained parameters
259 | self.save_model()
260 |
261 | def evaluate(self, test_loader, D_model_path, G_model_path):
262 | self.load_model(D_model_path, G_model_path)
263 |         z = Variable(torch.randn(self.batch_size, 100, 1, 1).cuda(self.cuda_index) if self.cuda else torch.randn(self.batch_size, 100, 1, 1))
264 | samples = self.G(z)
265 | samples = samples.mul(0.5).add(0.5)
266 | samples = samples.data.cpu()
267 | grid = utils.make_grid(samples)
268 |         print("Grid of 8x8 images saved to 'dcgan_model_image.png'.")
269 |         utils.save_image(grid, 'dcgan_model_image.png')
270 |
271 | def real_images(self, images, number_of_images):
272 | if (self.C == 3):
273 | return self.to_np(images.view(-1, self.C, 32, 32)[:self.number_of_images])
274 | else:
275 | return self.to_np(images.view(-1, 32, 32)[:self.number_of_images])
276 |
277 | def generate_img(self, z, number_of_images):
278 | samples = self.G(z).data.cpu().numpy()[:number_of_images]
279 | generated_images = []
280 | for sample in samples:
281 | if self.C == 3:
282 | generated_images.append(sample.reshape(self.C, 32, 32))
283 | else:
284 | generated_images.append(sample.reshape(32, 32))
285 | return generated_images
286 |
287 | def to_np(self, x):
288 | return x.data.cpu().numpy()
289 |
290 | def save_model(self):
291 | torch.save(self.G.state_dict(), './generator.pkl')
292 | torch.save(self.D.state_dict(), './discriminator.pkl')
293 |         print('Models saved to ./generator.pkl & ./discriminator.pkl')
294 |
295 | def load_model(self, D_model_filename, G_model_filename):
296 | D_model_path = os.path.join(os.getcwd(), D_model_filename)
297 | G_model_path = os.path.join(os.getcwd(), G_model_filename)
298 | self.D.load_state_dict(torch.load(D_model_path))
299 | self.G.load_state_dict(torch.load(G_model_path))
300 | print('Generator model loaded from {}.'.format(G_model_path))
301 |         print('Discriminator model loaded from {}.'.format(D_model_path))
302 |
303 | def generate_latent_walk(self, number):
304 | if not os.path.exists('interpolated_images/'):
305 | os.makedirs('interpolated_images/')
306 |
307 |         # Interpolate between two noise vectors (z1, z2) with number_int steps in between
308 | number_int = 10
309 | z_intp = torch.FloatTensor(1, 100, 1, 1)
310 | z1 = torch.randn(1, 100, 1, 1)
311 | z2 = torch.randn(1, 100, 1, 1)
312 | if self.cuda:
313 | z_intp = z_intp.cuda()
314 | z1 = z1.cuda()
315 | z2 = z2.cuda()
316 |
317 | z_intp = Variable(z_intp)
318 | images = []
319 | alpha = 1.0 / float(number_int + 1)
320 | print(alpha)
321 | for i in range(1, number_int + 1):
322 | z_intp.data = z1*alpha + z2*(1.0 - alpha)
323 |             alpha += 1.0 / float(number_int + 1)  # advance linearly from z2 towards z1
324 | fake_im = self.G(z_intp)
325 | fake_im = fake_im.mul(0.5).add(0.5) #denormalize
326 | images.append(fake_im.view(self.C,32,32).data.cpu())
327 |
328 | grid = utils.make_grid(images, nrow=number_int )
329 | utils.save_image(grid, 'interpolated_images/interpolated_{}.png'.format(str(number).zfill(3)))
330 |         print("Saved interpolated images to interpolated_images/interpolated_{}.png".format(str(number).zfill(3)))
331 |
--------------------------------------------------------------------------------
/models/gan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import utils
6 | from torch.autograd import Variable
7 | from utils.tensorboard_logger import Logger
8 |
9 |
10 | class GAN(object):
11 | def __init__(self, args):
12 | # Generator architecture
13 | self.G = nn.Sequential(
14 | nn.Linear(100, 256),
15 | nn.LeakyReLU(0.2),
16 | nn.Linear(256, 512),
17 | nn.LeakyReLU(0.2),
18 | nn.Linear(512, 1024),
19 | nn.LeakyReLU(0.2),
20 | nn.Tanh())
21 |
22 | # Discriminator architecture
23 | self.D = nn.Sequential(
24 | nn.Linear(1024, 512),
25 | nn.LeakyReLU(0.2),
26 | nn.Linear(512, 256),
27 | nn.LeakyReLU(0.2),
28 | nn.Linear(256, 1),
29 | nn.Sigmoid())
30 |
31 | self.cuda = False
32 | self.cuda_index = 0
33 | # check if cuda is available
34 | self.check_cuda(args.cuda)
35 |
36 | # Binary cross entropy loss and optimizer
37 | self.loss = nn.BCELoss()
38 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=0.0002, weight_decay=0.00001)
39 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=0.0002, weight_decay=0.00001)
40 |
41 | # Set the logger
42 | self.logger = Logger('./logs')
43 | self.number_of_images = 10
44 | self.epochs = args.epochs
45 | self.batch_size = args.batch_size
46 |
47 | # Cuda support
48 | def check_cuda(self, cuda_flag=False):
49 | if cuda_flag:
50 | self.cuda_index = 0
51 | self.cuda = True
52 | self.D.cuda(self.cuda_index)
53 | self.G.cuda(self.cuda_index)
54 | self.loss = nn.BCELoss().cuda(self.cuda_index)
55 | print("Cuda enabled flag: ")
56 | print(self.cuda)
57 |
58 | def train(self, train_loader):
59 | self.t_begin = time.time()
60 | generator_iter = 0
61 |
62 | for epoch in range(self.epochs+1):
63 | for i, (images, _) in enumerate(train_loader):
64 |                 # Stop the epoch at the last full batch
65 | if i == train_loader.dataset.__len__() // self.batch_size:
66 | break
67 |
68 |                 # Flatten image from (1, 32, 32) to a 1024-dim vector
69 | images = images.view(self.batch_size, -1)
70 | z = torch.rand((self.batch_size, 100))
71 |
72 | if self.cuda:
73 | real_labels = Variable(torch.ones(self.batch_size)).cuda(self.cuda_index)
74 | fake_labels = Variable(torch.zeros(self.batch_size)).cuda(self.cuda_index)
75 | images, z = Variable(images.cuda(self.cuda_index)), Variable(z.cuda(self.cuda_index))
76 | else:
77 | real_labels = Variable(torch.ones(self.batch_size))
78 | fake_labels = Variable(torch.zeros(self.batch_size))
79 | images, z = Variable(images), Variable(z)
80 |
81 | # Train discriminator
82 | # compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
83 | # [Training discriminator = Maximizing discriminator being correct]
84 | outputs = self.D(images)
85 | d_loss_real = self.loss(outputs.flatten(), real_labels)
86 | real_score = outputs
87 |
88 | # Compute BCELoss using fake images
89 | fake_images = self.G(z)
90 | outputs = self.D(fake_images)
91 | d_loss_fake = self.loss(outputs.flatten(), fake_labels)
92 | fake_score = outputs
93 |
94 |                 # Optimize discriminator
95 | d_loss = d_loss_real + d_loss_fake
96 | self.D.zero_grad()
97 | d_loss.backward()
98 | self.d_optimizer.step()
99 |
100 | # Train generator
101 | if self.cuda:
102 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index))
103 | else:
104 | z = Variable(torch.randn(self.batch_size, 100))
105 | fake_images = self.G(z)
106 | outputs = self.D(fake_images)
107 |
108 |                 # We train G to maximize log(D(G(z))) [maximize likelihood of the discriminator being wrong] instead of
109 |                 # minimizing log(1-D(G(z))) [minimizing likelihood of the discriminator being correct]
110 | # From paper [https://arxiv.org/pdf/1406.2661.pdf]
111 | g_loss = self.loss(outputs.flatten(), real_labels)
112 |
113 | # Optimize generator
114 | self.D.zero_grad()
115 | self.G.zero_grad()
116 | g_loss.backward()
117 | self.g_optimizer.step()
118 | generator_iter += 1
119 |
120 |
121 | if ((i + 1) % 100) == 0:
122 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
123 | ((epoch + 1), (i + 1), train_loader.dataset.__len__() // self.batch_size, d_loss.data, g_loss.data))
124 |
125 | if self.cuda:
126 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index))
127 | else:
128 | z = Variable(torch.randn(self.batch_size, 100))
129 |
130 | # ============ TensorBoard logging ============#
131 | # (1) Log the scalar values
132 | info = {
133 | 'd_loss': d_loss.data,
134 | 'g_loss': g_loss.data
135 | }
136 |
137 | for tag, value in info.items():
138 | self.logger.scalar_summary(tag, value, i + 1)
139 |
140 | # (2) Log values and gradients of the parameters (histogram)
141 | for tag, value in self.D.named_parameters():
142 | tag = tag.replace('.', '/')
143 | self.logger.histo_summary(tag, self.to_np(value), i + 1)
144 | self.logger.histo_summary(tag + '/grad', self.to_np(value.grad), i + 1)
145 |
146 | # (3) Log the images
147 | info = {
148 | 'real_images': self.to_np(images.view(-1, 32, 32)[:self.number_of_images]),
149 | 'generated_images': self.generate_img(z, self.number_of_images)
150 | }
151 |
152 | for tag, images in info.items():
153 | self.logger.image_summary(tag, images, i + 1)
154 |
155 |
156 | if generator_iter % 1000 == 0:
157 | print('Generator iter-{}'.format(generator_iter))
158 | self.save_model()
159 |
160 | if not os.path.exists('training_result_images/'):
161 | os.makedirs('training_result_images/')
162 |
163 | # Denormalize images and save them in grid 8x8
164 | if self.cuda:
165 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index))
166 | else:
167 | z = Variable(torch.randn(self.batch_size, 100))
168 | samples = self.G(z)
169 | samples = samples.mul(0.5).add(0.5)
170 | samples = samples.data.cpu()
171 | grid = utils.make_grid(samples)
172 | utils.save_image(grid, 'training_result_images/gan_image_iter_{}.png'.format(
173 | str(generator_iter).zfill(3)))
174 |
175 | self.t_end = time.time()
176 | print('Time of training-{}'.format((self.t_end - self.t_begin)))
177 | # Save the trained parameters
178 | self.save_model()
179 |
180 | def evaluate(self, test_loader, D_model_path, G_model_path):
181 | self.load_model(D_model_path, G_model_path)
182 | if self.cuda:
183 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index))
184 | else:
185 | z = Variable(torch.randn(self.batch_size, 100))
186 | samples = self.G(z)
187 | samples = samples.mul(0.5).add(0.5)
188 | samples = samples.data.cpu()
189 | grid = utils.make_grid(samples)
190 | print("Grid of 8x8 images saved to 'gan_model_image.png'.")
191 | utils.save_image(grid, 'gan_model_image.png')
192 |
193 | def generate_img(self, z, number_of_images):
194 | samples = self.G(z).data.cpu().numpy()[:number_of_images]
195 | generated_images = []
196 | for sample in samples:
197 | generated_images.append(sample.reshape(32,32))
198 | return generated_images
199 |
200 | def to_np(self, x):
201 | return x.data.cpu().numpy()
202 |
203 | def save_model(self):
204 | torch.save(self.G.state_dict(), './generator.pkl')
205 | torch.save(self.D.state_dict(), './discriminator.pkl')
206 |         print('Models saved to ./generator.pkl & ./discriminator.pkl')
207 |
208 | def load_model(self, D_model_filename, G_model_filename):
209 | D_model_path = os.path.join(os.getcwd(), D_model_filename)
210 | G_model_path = os.path.join(os.getcwd(), G_model_filename)
211 | self.D.load_state_dict(torch.load(D_model_path))
212 | self.G.load_state_dict(torch.load(G_model_path))
213 | print('Generator model loaded from {}.'.format(G_model_path))
214 |         print('Discriminator model loaded from {}.'.format(D_model_path))
215 |
--------------------------------------------------------------------------------
/models/wgan_clipping.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import time as t
5 | import matplotlib.pyplot as plt
6 | plt.switch_backend('agg')
7 | import os
8 | from utils.tensorboard_logger import Logger
9 | from torchvision import utils
10 |
11 |
12 | SAVE_PER_TIMES = 1000
13 |
14 | class Generator(torch.nn.Module):
15 | def __init__(self, channels):
16 | super().__init__()
17 | # Filters [1024, 512, 256]
18 | # Input_dim = 100
19 | # Output_dim = C (number of channels)
20 | self.main_module = nn.Sequential(
21 | # Z latent vector 100
22 | nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0),
23 | nn.BatchNorm2d(num_features=1024),
24 | nn.ReLU(True),
25 |
26 | # State (1024x4x4)
27 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
28 | nn.BatchNorm2d(num_features=512),
29 | nn.ReLU(True),
30 |
31 | # State (512x8x8)
32 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
33 | nn.BatchNorm2d(num_features=256),
34 | nn.ReLU(True),
35 |
36 | # State (256x16x16)
37 | nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1))
38 | # output of main module --> Image (Cx32x32)
39 |
40 | self.output = nn.Tanh()
41 |
42 | def forward(self, x):
43 | x = self.main_module(x)
44 | return self.output(x)
45 |
46 | class Discriminator(torch.nn.Module):
47 | def __init__(self, channels):
48 | super().__init__()
49 | # Filters [256, 512, 1024]
50 |         # Input_dim = channels (Cx32x32)
51 | # Output_dim = 1
52 | self.main_module = nn.Sequential(
53 | # Image (Cx32x32)
54 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
55 | nn.BatchNorm2d(num_features=256),
56 | nn.LeakyReLU(0.2, inplace=True),
57 |
58 | # State (256x16x16)
59 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
60 | nn.BatchNorm2d(num_features=512),
61 | nn.LeakyReLU(0.2, inplace=True),
62 |
63 | # State (512x8x8)
64 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
65 | nn.BatchNorm2d(num_features=1024),
66 | nn.LeakyReLU(0.2, inplace=True))
67 | # output of main module --> State (1024x4x4)
68 |
69 | self.output = nn.Sequential(
70 | # The output of D is no longer a probability, we do not apply sigmoid at the output of D.
71 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0))
72 |
73 |
74 | def forward(self, x):
75 | x = self.main_module(x)
76 | return self.output(x)
77 |
78 | def feature_extraction(self, x):
79 | # Use discriminator for feature extraction then flatten to vector of 16384
80 | x = self.main_module(x)
81 | return x.view(-1, 1024*4*4)
82 |
83 |
84 | class WGAN_CP(object):
85 | def __init__(self, args):
86 | print("WGAN_CP init model.")
87 | self.G = Generator(args.channels)
88 | self.D = Discriminator(args.channels)
89 | self.C = args.channels
90 |
91 | # check if cuda is available
92 | self.check_cuda(args.cuda)
93 |
94 | # WGAN values from paper
95 | self.learning_rate = 0.00005
96 |
97 | self.batch_size = 64
98 |         self.weight_clipping_limit = 0.01
99 |
100 |         # WGAN with weight clipping uses RMSprop instead of Adam
101 | self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), lr=self.learning_rate)
102 | self.g_optimizer = torch.optim.RMSprop(self.G.parameters(), lr=self.learning_rate)
103 |
104 | # Set the logger
105 | self.logger = Logger('./logs')
106 | self.logger.writer.flush()
107 | self.number_of_images = 10
108 |
109 | self.generator_iters = args.generator_iters
110 | self.critic_iter = 5
111 |
112 | def get_torch_variable(self, arg):
113 | if self.cuda:
114 | return Variable(arg).cuda(self.cuda_index)
115 | else:
116 | return Variable(arg)
117 |
118 | def check_cuda(self, cuda_flag=False):
119 | if cuda_flag:
120 | self.cuda_index = 0
121 | self.cuda = True
122 | self.D.cuda(self.cuda_index)
123 | self.G.cuda(self.cuda_index)
124 | print("Cuda enabled flag: {}".format(self.cuda))
125 | else:
126 | self.cuda = False
127 |
128 |
129 | def train(self, train_loader):
130 | self.t_begin = t.time()
131 | #self.file = open("inception_score_graph.txt", "w")
132 |
133 |         # Batches can now be drawn with self.data.__next__()
134 | self.data = self.get_infinite_batches(train_loader)
135 |
136 | one = torch.FloatTensor([1])
137 | mone = one * -1
138 | if self.cuda:
139 | one = one.cuda(self.cuda_index)
140 | mone = mone.cuda(self.cuda_index)
141 |
142 | for g_iter in range(self.generator_iters):
143 |
144 |             # Re-enable grads for D parameters (they are turned off during the generator update below)
145 | for p in self.D.parameters():
146 | p.requires_grad = True
147 |
148 |             # Train Discriminator forward-loss-backward-update self.critic_iter times per 1 Generator forward-loss-backward-update
149 | for d_iter in range(self.critic_iter):
150 | self.D.zero_grad()
151 |
152 |                 # Clamp parameters to the range [-c, c], c = self.weight_clipping_limit
153 | for p in self.D.parameters():
154 |                     p.data.clamp_(-self.weight_clipping_limit, self.weight_clipping_limit)
155 |
156 | images = self.data.__next__()
157 | # Check for batch to have full batch_size
158 | if (images.size()[0] != self.batch_size):
159 | continue
160 |
161 | z = torch.rand((self.batch_size, 100, 1, 1))
162 |
163 | images, z = self.get_torch_variable(images), self.get_torch_variable(z)
164 |
165 |
166 | # Train discriminator
167 | # WGAN - Training discriminator more iterations than generator
168 | # Train with real images
169 | d_loss_real = self.D(images)
170 | d_loss_real = d_loss_real.mean(0).view(1)
171 | d_loss_real.backward(one)
172 |
173 | # Train with fake images
174 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1))
175 | fake_images = self.G(z)
176 | d_loss_fake = self.D(fake_images)
177 | d_loss_fake = d_loss_fake.mean(0).view(1)
178 | d_loss_fake.backward(mone)
179 |
180 | d_loss = d_loss_fake - d_loss_real
181 | Wasserstein_D = d_loss_real - d_loss_fake
182 | self.d_optimizer.step()
183 | print(f' Discriminator iteration: {d_iter}/{self.critic_iter}, loss_fake: {d_loss_fake.data}, loss_real: {d_loss_real.data}')
184 |
185 |
186 |
187 | # Generator update
188 | for p in self.D.parameters():
189 | p.requires_grad = False # to avoid computation
190 |
191 | self.G.zero_grad()
192 |
193 | # Train generator
194 | # Compute loss with fake images
195 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1))
196 | fake_images = self.G(z)
197 | g_loss = self.D(fake_images)
198 |             g_loss = g_loss.mean(0).view(1)
199 | g_loss.backward(one)
200 | g_cost = -g_loss
201 | self.g_optimizer.step()
202 | print(f'Generator iteration: {g_iter}/{self.generator_iters}, g_loss: {g_loss.data}')
203 |
204 | # Saving model and sampling images every 1000th generator iterations
205 | if (g_iter) % SAVE_PER_TIMES == 0:
206 | self.save_model()
207 |                 # Workaround because the GPU can't hold more than ~830 examples in memory when generating images
208 |                 # Therefore we loop, generating 800 examples at a time, and stack them into a list to get 8000 generated images
209 |                 # This way the Inception score is more accurate since there are generated examples from every class of the Inception model
210 | # sample_list = []
211 | # for i in range(10):
212 | # z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index)
213 | # samples = self.G(z)
214 | # sample_list.append(samples.data.cpu().numpy())
215 | #
216 | # # Flattening list of list into one list
217 | # new_sample_list = list(chain.from_iterable(sample_list))
218 | # print("Calculating Inception Score over 8k generated images")
219 | # # Feeding list of numpy arrays
220 | # inception_score = get_inception_score(new_sample_list, cuda=True, batch_size=32,
221 | # resize=True, splits=10)
222 |
223 | if not os.path.exists('training_result_images/'):
224 | os.makedirs('training_result_images/')
225 |
226 | # Denormalize images and save them in grid 8x8
227 | z = self.get_torch_variable(torch.randn(800, 100, 1, 1))
228 | samples = self.G(z)
229 | samples = samples.mul(0.5).add(0.5)
230 | samples = samples.data.cpu()[:64]
231 | grid = utils.make_grid(samples)
232 |                 utils.save_image(grid, 'training_result_images/img_generator_iter_{}.png'.format(str(g_iter).zfill(3)))
233 |
234 | # Testing
235 | time = t.time() - self.t_begin
236 | #print("Inception score: {}".format(inception_score))
237 | print("Generator iter: {}".format(g_iter))
238 | print("Time {}".format(time))
239 |
240 | # Write to file inception_score, gen_iters, time
241 | #output = str(g_iter) + " " + str(time) + " " + str(inception_score[0]) + "\n"
242 | #self.file.write(output)
243 |
244 | # ============ TensorBoard logging ============#
245 | # (1) Log the scalar values
246 | info = {
247 | 'Wasserstein distance': Wasserstein_D.data,
248 | 'Loss D': d_loss.data,
249 | 'Loss G': g_cost.data,
250 | 'Loss D Real': d_loss_real.data,
251 | 'Loss D Fake': d_loss_fake.data
252 | }
253 |
254 | for tag, value in info.items():
255 | self.logger.scalar_summary(tag, value.mean().cpu(), g_iter + 1)
256 |
257 | # (3) Log the images
258 | info = {
259 | 'real_images': self.real_images(images, self.number_of_images),
260 | 'generated_images': self.generate_img(z, self.number_of_images)
261 | }
262 |
263 | for tag, images in info.items():
264 | self.logger.image_summary(tag, images, g_iter + 1)
265 |
266 | self.t_end = t.time()
267 | print('Time of training-{}'.format((self.t_end - self.t_begin)))
268 | #self.file.close()
269 |
270 | # Save the trained parameters
271 | self.save_model()
272 |
273 | def evaluate(self, test_loader, D_model_path, G_model_path):
274 | self.load_model(D_model_path, G_model_path)
275 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1))
276 | samples = self.G(z)
277 | samples = samples.mul(0.5).add(0.5)
278 | samples = samples.data.cpu()
279 | grid = utils.make_grid(samples)
280 |         print("Grid of 8x8 images saved to 'wgan_model_image.png'.")
281 |         utils.save_image(grid, 'wgan_model_image.png')
282 |
283 | def real_images(self, images, number_of_images):
284 | if (self.C == 3):
285 | return self.to_np(images.view(-1, self.C, 32, 32)[:self.number_of_images])
286 | else:
287 | return self.to_np(images.view(-1, 32, 32)[:self.number_of_images])
288 |
289 | def generate_img(self, z, number_of_images):
290 | samples = self.G(z).data.cpu().numpy()[:number_of_images]
291 | generated_images = []
292 | for sample in samples:
293 | if self.C == 3:
294 | generated_images.append(sample.reshape(self.C, 32, 32))
295 | else:
296 | generated_images.append(sample.reshape(32, 32))
297 | return generated_images
298 |
299 | def to_np(self, x):
300 | return x.data.cpu().numpy()
301 |
302 | def save_model(self):
303 | torch.save(self.G.state_dict(), './generator.pkl')
304 | torch.save(self.D.state_dict(), './discriminator.pkl')
305 |         print('Models saved to ./generator.pkl & ./discriminator.pkl')
306 |
307 | def load_model(self, D_model_filename, G_model_filename):
308 | D_model_path = os.path.join(os.getcwd(), D_model_filename)
309 | G_model_path = os.path.join(os.getcwd(), G_model_filename)
310 | self.D.load_state_dict(torch.load(D_model_path))
311 | self.G.load_state_dict(torch.load(G_model_path))
312 | print('Generator model loaded from {}.'.format(G_model_path))
313 |         print('Discriminator model loaded from {}.'.format(D_model_path))
314 |
315 | def get_infinite_batches(self, data_loader):
316 | while True:
317 | for i, (images, _) in enumerate(data_loader):
318 | yield images
319 |
320 |
321 | def generate_latent_walk(self, number):
322 | if not os.path.exists('interpolated_images/'):
323 | os.makedirs('interpolated_images/')
324 |
325 | number_int = 10
326 |         # Interpolate between two noise vectors (z1, z2).
327 | z_intp = torch.FloatTensor(1, 100, 1, 1)
328 | z1 = torch.randn(1, 100, 1, 1)
329 | z2 = torch.randn(1, 100, 1, 1)
330 | if self.cuda:
331 | z_intp = z_intp.cuda()
332 | z1 = z1.cuda()
333 | z2 = z2.cuda()
334 |
335 | z_intp = Variable(z_intp)
336 | images = []
337 | alpha = 1.0 / float(number_int + 1)
338 | print(alpha)
339 | for i in range(1, number_int + 1):
340 | z_intp.data = z1*alpha + z2*(1.0 - alpha)
341 |             alpha += 1.0 / float(number_int + 1)  # advance linearly from z2 towards z1
342 | fake_im = self.G(z_intp)
343 | fake_im = fake_im.mul(0.5).add(0.5) #denormalize
344 | images.append(fake_im.view(self.C,32,32).data.cpu())
345 |
346 | grid = utils.make_grid(images, nrow=number_int )
347 | utils.save_image(grid, 'interpolated_images/interpolated_{}.png'.format(str(number).zfill(3)))
348 | print("Saved interpolated images.")
349 |
--------------------------------------------------------------------------------
/models/wgan_gradient_penalty.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.autograd import Variable
5 | from torch import autograd
6 | import time as t
7 | import matplotlib.pyplot as plt
8 | plt.switch_backend('agg')
9 | import os
10 | from utils.tensorboard_logger import Logger
11 | from itertools import chain
12 | from torchvision import utils
13 |
14 | SAVE_PER_TIMES = 100
15 |
16 | class Generator(torch.nn.Module):
17 | def __init__(self, channels):
18 | super().__init__()
19 | # Filters [1024, 512, 256]
20 | # Input_dim = 100
21 | # Output_dim = C (number of channels)
22 | self.main_module = nn.Sequential(
23 | # Z latent vector 100
24 | nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0),
25 | nn.BatchNorm2d(num_features=1024),
26 | nn.ReLU(True),
27 |
28 | # State (1024x4x4)
29 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
30 | nn.BatchNorm2d(num_features=512),
31 | nn.ReLU(True),
32 |
33 | # State (512x8x8)
34 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
35 | nn.BatchNorm2d(num_features=256),
36 | nn.ReLU(True),
37 |
38 | # State (256x16x16)
39 | nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1))
40 | # output of main module --> Image (Cx32x32)
41 |
42 | self.output = nn.Tanh()
43 |
44 | def forward(self, x):
45 | x = self.main_module(x)
46 | return self.output(x)
47 |
48 |
49 | class Discriminator(torch.nn.Module):
50 | def __init__(self, channels):
51 | super().__init__()
52 | # Filters [256, 512, 1024]
53 |         # Input_dim = channels (Cx32x32)
54 | # Output_dim = 1
55 | self.main_module = nn.Sequential(
56 |             # Omitting batch normalization in the critic because our penalized training objective (WGAN with gradient penalty) is no longer valid
57 |             # in this setting, since we penalize the norm of the critic's gradient with respect to each input independently and not the entire batch.
58 |             # Since there is no good & fast implementation of layer normalization, we use per-instance normalization nn.InstanceNorm2d() instead.
59 | # Image (Cx32x32)
60 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
61 | nn.InstanceNorm2d(256, affine=True),
62 | nn.LeakyReLU(0.2, inplace=True),
63 |
64 | # State (256x16x16)
65 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
66 | nn.InstanceNorm2d(512, affine=True),
67 | nn.LeakyReLU(0.2, inplace=True),
68 |
69 | # State (512x8x8)
70 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
71 | nn.InstanceNorm2d(1024, affine=True),
72 | nn.LeakyReLU(0.2, inplace=True))
73 | # output of main module --> State (1024x4x4)
74 |
75 | self.output = nn.Sequential(
76 | # The output of D is no longer a probability, we do not apply sigmoid at the output of D.
77 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0))
78 |
79 |
80 | def forward(self, x):
81 | x = self.main_module(x)
82 | return self.output(x)
83 |
84 | def feature_extraction(self, x):
85 | # Use discriminator for feature extraction then flatten to vector of 16384
86 | x = self.main_module(x)
87 | return x.view(-1, 1024*4*4)
88 |
89 |
90 | class WGAN_GP(object):
91 | def __init__(self, args):
92 | print("WGAN_GradientPenalty init model.")
93 | self.G = Generator(args.channels)
94 | self.D = Discriminator(args.channels)
95 | self.C = args.channels
96 |
97 | # Check if cuda is available
98 | self.check_cuda(args.cuda)
99 |
100 | # WGAN values from paper
101 | self.learning_rate = 1e-4
102 | self.b1 = 0.5
103 | self.b2 = 0.999
104 | self.batch_size = 64
105 |
106 |         # WGAN with gradient penalty uses Adam
107 | self.d_optimizer = optim.Adam(self.D.parameters(), lr=self.learning_rate, betas=(self.b1, self.b2))
108 | self.g_optimizer = optim.Adam(self.G.parameters(), lr=self.learning_rate, betas=(self.b1, self.b2))
109 |
110 | # Set the logger
111 | self.logger = Logger('./logs')
112 | self.logger.writer.flush()
113 | self.number_of_images = 10
114 |
115 | self.generator_iters = args.generator_iters
116 | self.critic_iter = 5
117 | self.lambda_term = 10
118 |
119 | def get_torch_variable(self, arg):
120 | if self.cuda:
121 | return Variable(arg).cuda(self.cuda_index)
122 | else:
123 | return Variable(arg)
124 |
125 | def check_cuda(self, cuda_flag=False):
126 | print(cuda_flag)
127 | if cuda_flag:
128 | self.cuda_index = 0
129 | self.cuda = True
130 | self.D.cuda(self.cuda_index)
131 | self.G.cuda(self.cuda_index)
132 | print("Cuda enabled flag: {}".format(self.cuda))
133 | else:
134 | self.cuda = False
135 |
136 |
137 | def train(self, train_loader):
138 | self.t_begin = t.time()
139 | self.file = open("inception_score_graph.txt", "w")
140 |
141 |         # Batches can now be drawn with self.data.__next__()
142 | self.data = self.get_infinite_batches(train_loader)
143 |
144 | one = torch.tensor(1, dtype=torch.float)
145 | mone = one * -1
146 | if self.cuda:
147 | one = one.cuda(self.cuda_index)
148 | mone = mone.cuda(self.cuda_index)
149 |
150 | for g_iter in range(self.generator_iters):
151 |             # Re-enable grads for D parameters (they are turned off during the generator update below)
152 | for p in self.D.parameters():
153 | p.requires_grad = True
154 |
155 | d_loss_real = 0
156 | d_loss_fake = 0
157 | Wasserstein_D = 0
158 |             # Train Discriminator forward-loss-backward-update self.critic_iter times per 1 Generator forward-loss-backward-update
159 | for d_iter in range(self.critic_iter):
160 | self.D.zero_grad()
161 |
162 | images = self.data.__next__()
163 | # Check for batch to have full batch_size
164 | if (images.size()[0] != self.batch_size):
165 | continue
166 |
167 | z = torch.rand((self.batch_size, 100, 1, 1))
168 |
169 | images, z = self.get_torch_variable(images), self.get_torch_variable(z)
170 |
171 | # Train discriminator
172 | # WGAN - Training discriminator more iterations than generator
173 | # Train with real images
174 | d_loss_real = self.D(images)
175 | d_loss_real = d_loss_real.mean()
176 | d_loss_real.backward(mone)
177 |
178 | # Train with fake images
179 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1))
180 |
181 | fake_images = self.G(z)
182 | d_loss_fake = self.D(fake_images)
183 | d_loss_fake = d_loss_fake.mean()
184 | d_loss_fake.backward(one)
185 |
186 | # Train with gradient penalty
187 | gradient_penalty = self.calculate_gradient_penalty(images.data, fake_images.data)
188 | gradient_penalty.backward()
189 |
190 |
191 | d_loss = d_loss_fake - d_loss_real + gradient_penalty
192 | Wasserstein_D = d_loss_real - d_loss_fake
193 | self.d_optimizer.step()
194 | print(f' Discriminator iteration: {d_iter}/{self.critic_iter}, loss_fake: {d_loss_fake}, loss_real: {d_loss_real}')
195 |
196 | # Generator update
197 | for p in self.D.parameters():
198 | p.requires_grad = False # to avoid computation
199 |
200 | self.G.zero_grad()
201 | # train generator
202 | # compute loss with fake images
203 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1))
204 | fake_images = self.G(z)
205 | g_loss = self.D(fake_images)
206 | g_loss = g_loss.mean()
207 | g_loss.backward(mone)
208 | g_cost = -g_loss
209 | self.g_optimizer.step()
210 | print(f'Generator iteration: {g_iter}/{self.generator_iters}, g_loss: {g_loss}')
211 | # Saving model and sampling images every 1000th generator iterations
212 | if (g_iter) % SAVE_PER_TIMES == 0:
213 | self.save_model()
214 |             # # Workaround because the GPU can't hold more than ~830 examples in memory when generating images
215 |             # # Therefore we loop, generating 800 examples at a time, and stack them into a list to get 8000 generated images
216 |             # # This way the Inception score is more accurate since there are generated examples from every class of the Inception model
217 | # sample_list = []
218 | # for i in range(125):
219 | # samples = self.data.__next__()
220 | # # z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index)
221 | # # samples = self.G(z)
222 | # sample_list.append(samples.data.cpu().numpy())
223 | # #
224 | # # # Flattening list of list into one list
225 | # new_sample_list = list(chain.from_iterable(sample_list))
226 | # print("Calculating Inception Score over 8k generated images")
227 | # # # Feeding list of numpy arrays
228 | # inception_score = get_inception_score(new_sample_list, cuda=True, batch_size=32,
229 | # resize=True, splits=10)
230 |
231 | if not os.path.exists('training_result_images/'):
232 | os.makedirs('training_result_images/')
233 |
234 | # Denormalize images and save them in grid 8x8
235 | z = self.get_torch_variable(torch.randn(800, 100, 1, 1))
236 | samples = self.G(z)
237 | samples = samples.mul(0.5).add(0.5)
238 | samples = samples.data.cpu()[:64]
239 | grid = utils.make_grid(samples)
240 |                 utils.save_image(grid, 'training_result_images/img_generator_iter_{}.png'.format(str(g_iter).zfill(3)))
241 |
242 | # Testing
243 | time = t.time() - self.t_begin
244 | #print("Real Inception score: {}".format(inception_score))
245 | print("Generator iter: {}".format(g_iter))
246 | print("Time {}".format(time))
247 |
248 | # Write to file inception_score, gen_iters, time
249 | #output = str(g_iter) + " " + str(time) + " " + str(inception_score[0]) + "\n"
250 | #self.file.write(output)
251 |
252 |
253 | # ============ TensorBoard logging ============#
254 | # (1) Log the scalar values
255 | info = {
256 | 'Wasserstein distance': Wasserstein_D.data,
257 | 'Loss D': d_loss.data,
258 | 'Loss G': g_cost.data,
259 | 'Loss D Real': d_loss_real.data,
260 | 'Loss D Fake': d_loss_fake.data
261 |
262 | }
263 |
264 | for tag, value in info.items():
265 | self.logger.scalar_summary(tag, value.cpu(), g_iter + 1)
266 |
267 | # (3) Log the images
268 | info = {
269 | 'real_images': self.real_images(images, self.number_of_images),
270 | 'generated_images': self.generate_img(z, self.number_of_images)
271 | }
272 |
273 | for tag, images in info.items():
274 | self.logger.image_summary(tag, images, g_iter + 1)
275 |
276 |
277 |
278 | self.t_end = t.time()
279 | print('Time of training-{}'.format((self.t_end - self.t_begin)))
280 | #self.file.close()
281 |
282 | # Save the trained parameters
283 | self.save_model()
284 |
285 | def evaluate(self, test_loader, D_model_path, G_model_path):
286 | self.load_model(D_model_path, G_model_path)
287 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1))
288 | samples = self.G(z)
289 | samples = samples.mul(0.5).add(0.5)
290 | samples = samples.data.cpu()
291 | grid = utils.make_grid(samples)
292 |         print("Grid of 8x8 images saved to 'wgan_model_image.png'.")
293 |         utils.save_image(grid, 'wgan_model_image.png')
294 |
295 |
296 | def calculate_gradient_penalty(self, real_images, fake_images):
297 | # random interpolation coefficient eta ~ U(0, 1), one per sample in the batch
298 | eta = torch.FloatTensor(self.batch_size, 1, 1, 1).uniform_(0, 1)
299 | eta = eta.expand(self.batch_size, real_images.size(1), real_images.size(2), real_images.size(3))
300 | if self.cuda:
301 | eta = eta.cuda(self.cuda_index)
302 |
303 | # random points on the straight lines between real and fake samples
304 | interpolated = eta * real_images + ((1 - eta) * fake_images)
305 |
306 | if self.cuda:
307 | interpolated = interpolated.cuda(self.cuda_index)
308 |
309 | # define it as a Variable to calculate the gradient
310 | interpolated = Variable(interpolated, requires_grad=True)
311 |
312 | # calculate the critic output for the interpolated examples
313 | prob_interpolated = self.D(interpolated)
314 |
315 | # calculate gradients of the critic output with respect to the examples
316 | gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
317 | grad_outputs=torch.ones(
318 | prob_interpolated.size()).cuda(self.cuda_index) if self.cuda else torch.ones(
319 | prob_interpolated.size()),
320 | create_graph=True, retain_graph=True)[0]
321 |
322 | # flatten the gradients so the norm is computed batch-wise
323 | gradients = gradients.view(gradients.size(0), -1)
324 |
325 | # two-sided penalty: push the gradient norm toward 1
326 | grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_term
327 | return grad_penalty
328 |
329 |
330 | def real_images(self, images, number_of_images):
331 | if self.C == 3:
332 | return self.to_np(images.view(-1, self.C, 32, 32)[:number_of_images])
333 | else:
334 | return self.to_np(images.view(-1, 32, 32)[:number_of_images])
335 |
336 | def generate_img(self, z, number_of_images):
337 | samples = self.G(z).data.cpu().numpy()[:number_of_images]
338 | generated_images = []
339 | for sample in samples:
340 | if self.C == 3:
341 | generated_images.append(sample.reshape(self.C, 32, 32))
342 | else:
343 | generated_images.append(sample.reshape(32, 32))
344 | return generated_images
345 |
346 | def to_np(self, x):
347 | return x.data.cpu().numpy()
348 |
349 | def save_model(self):
350 | torch.save(self.G.state_dict(), './generator.pkl')
351 | torch.save(self.D.state_dict(), './discriminator.pkl')
352 | print('Models saved to ./generator.pkl & ./discriminator.pkl')
353 |
354 | def load_model(self, D_model_filename, G_model_filename):
355 | D_model_path = os.path.join(os.getcwd(), D_model_filename)
356 | G_model_path = os.path.join(os.getcwd(), G_model_filename)
357 | self.D.load_state_dict(torch.load(D_model_path))
358 | self.G.load_state_dict(torch.load(G_model_path))
359 | print('Generator model loaded from {}.'.format(G_model_path))
360 | print('Discriminator model loaded from {}.'.format(D_model_path))
361 |
362 | def get_infinite_batches(self, data_loader):
363 | while True:
364 | for i, (images, _) in enumerate(data_loader):
365 | yield images
366 |
367 | def generate_latent_walk(self, number):
368 | if not os.path.exists('interpolated_images/'):
369 | os.makedirs('interpolated_images/')
370 |
371 | number_int = 10
372 | # interpolate between two noise vectors (z1, z2)
373 | z_intp = torch.FloatTensor(1, 100, 1, 1)
374 | z1 = torch.randn(1, 100, 1, 1)
375 | z2 = torch.randn(1, 100, 1, 1)
376 | if self.cuda:
377 | z_intp = z_intp.cuda()
378 | z1 = z1.cuda()
379 | z2 = z2.cuda()
380 |
381 | z_intp = Variable(z_intp)
382 | images = []
383 | # evenly spaced interpolation coefficients alpha = i / (number_int + 1)
384 | step = 1.0 / float(number_int + 1)
385 | for i in range(1, number_int + 1):
386 | alpha = step * i
387 | z_intp.data = z1*alpha + z2*(1.0 - alpha)
388 | fake_im = self.G(z_intp)
389 | fake_im = fake_im.mul(0.5).add(0.5)  # denormalize
390 | images.append(fake_im.view(self.C, 32, 32).data.cpu())
391 |
392 | grid = utils.make_grid(images, nrow=number_int)
393 | utils.save_image(grid, 'interpolated_images/interpolated_{}.png'.format(str(number).zfill(3)))
394 | print("Saved interpolated images.")
395 |
--------------------------------------------------------------------------------
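
The gradient penalty implemented above is the defining piece of WGAN-GP: instead of clipping the critic's weights, it samples random points on the straight lines between real and fake images and penalizes any deviation of the critic's gradient norm from 1 there, softly enforcing the 1-Lipschitz constraint. A minimal standalone sketch of the same computation; the `critic`, `real`, `fake` and `lambda_term` names below are illustrative, not taken from the file:

import torch
from torch import autograd

def gradient_penalty(critic, real, fake, lambda_term=10.0):
    # one interpolation coefficient per sample, broadcast over channels and pixels
    eta = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolated = (eta * real + (1 - eta) * fake).requires_grad_(True)

    # critic score at the interpolated points
    scores = critic(interpolated)

    # d(scores)/d(interpolated); create_graph=True keeps the penalty differentiable
    gradients = autograd.grad(outputs=scores, inputs=interpolated,
                              grad_outputs=torch.ones_like(scores),
                              create_graph=True, retain_graph=True)[0]

    # two-sided penalty: (||grad||_2 - 1)^2 averaged over the batch
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_term
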
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.5.1
2 | numpy==1.22.0
3 | Pillow==9.0.0
4 | scikit_learn==1.0.2
5 | scipy==1.7.3
6 | six==1.16.0
7 | tensorflow==2.7.0
8 | torch==1.10.1
9 | torchvision==0.11.2
10 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/utils/__init__.py
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | def parse_args():
6 | parser = argparse.ArgumentParser(description="Pytorch implementation of GAN models.")
7 |
8 | parser.add_argument('--model', type=str, default='DCGAN', choices=['GAN', 'DCGAN', 'WGAN-CP', 'WGAN-GP'])
9 | parser.add_argument('--is_train', type=str, default='True')
10 | parser.add_argument('--dataroot', required=True, help='path to dataset')
11 | parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'cifar', 'stl10'],
12 | help='The name of dataset')
13 | parser.add_argument('--download', type=str, default='False')
14 | parser.add_argument('--epochs', type=int, default=50, help='The number of epochs to run')
15 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
16 | parser.add_argument('--cuda', type=str, default='False', help='Availability of cuda')
17 |
18 | parser.add_argument('--load_D', type=str, default='False', help='Path for loading Discriminator network')
19 | parser.add_argument('--load_G', type=str, default='False', help='Path for loading Generator network')
20 | parser.add_argument('--generator_iters', type=int, default=10000, help='The number of iterations for generator in WGAN model.')
21 | return check_args(parser.parse_args())
22 |
23 |
24 | # Checking arguments
25 | def check_args(args):
26 | # --epoch
27 | try:
28 | assert args.epochs >= 1
29 | except AssertionError:
30 | print('Number of epochs must be larger than or equal to one')
31 |
32 | # --batch_size
33 | try:
34 | assert args.batch_size >= 1
35 | except AssertionError:
36 | print('Batch size must be larger than or equal to one')
37 |
38 | if args.dataset == 'cifar' or args.dataset == 'stl10':
39 | args.channels = 3
40 | else:
41 | args.channels = 1
42 | args.cuda = True if args.cuda == 'True' else False
43 | args.download = True if args.download == 'True' else False
44 | return args
--------------------------------------------------------------------------------
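
check_args fills in the derived fields (the channel count) and converts the string flags to booleans, so downstream code can rely on a fully populated namespace. A hedged sketch of how the parsed result is typically consumed; the actual wiring lives in main.py:

from utils.config import parse_args
from utils.data_loader import get_data_loader

# e.g. invoked as: python main.py --model WGAN-GP --dataset cifar --dataroot datasets/cifar --cuda True --download True
args = parse_args()
train_loader, test_loader = get_data_loader(args)
print(args.model, args.channels, args.cuda)  # args.channels was derived from the dataset in check_args
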
/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import torchvision.datasets as dset
2 | import torchvision.transforms as transforms
3 | import torch.utils.data as data_utils
4 | from utils.fashion_mnist import MNIST, FashionMNIST
5 |
6 |
7 | def get_data_loader(args):
8 |
9 | if args.dataset == 'mnist':
10 | trans = transforms.Compose([
11 | transforms.Resize(32),
12 | transforms.ToTensor(),
13 | transforms.Normalize((0.5, ), (0.5, )),
14 | ])
15 | train_dataset = MNIST(root=args.dataroot, train=True, download=args.download, transform=trans)
16 | test_dataset = MNIST(root=args.dataroot, train=False, download=args.download, transform=trans)
17 |
18 | elif args.dataset == 'fashion-mnist':
19 | trans = transforms.Compose([
20 | transforms.Resize(32),
21 | transforms.ToTensor(),
22 | transforms.Normalize((0.5, ), (0.5, )),
23 | ])
24 | train_dataset = FashionMNIST(root=args.dataroot, train=True, download=args.download, transform=trans)
25 | test_dataset = FashionMNIST(root=args.dataroot, train=False, download=args.download, transform=trans)
26 |
27 | elif args.dataset == 'cifar':
28 | trans = transforms.Compose([
29 | transforms.Resize(32),
30 | transforms.ToTensor(),
31 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
32 | ])
33 |
34 | train_dataset = dset.CIFAR10(root=args.dataroot, train=True, download=args.download, transform=trans)
35 | test_dataset = dset.CIFAR10(root=args.dataroot, train=False, download=args.download, transform=trans)
36 |
37 | elif args.dataset == 'stl10':
38 | trans = transforms.Compose([
39 | transforms.Resize(32),
40 | transforms.ToTensor(),
41 | ])
42 | train_dataset = dset.STL10(root=args.dataroot, split='train', download=args.download, transform=trans)
43 | test_dataset = dset.STL10(root=args.dataroot, split='test', download=args.download, transform=trans)
44 |
45 | # Check if everything is ok with loading datasets
46 | assert train_dataset
47 | assert test_dataset
48 |
49 | train_dataloader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
50 | test_dataloader = data_utils.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)
51 |
52 | return train_dataloader, test_dataloader
53 |
--------------------------------------------------------------------------------
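
Every dataset branch resizes to 32x32, and all but STL-10 normalize to [-1, 1], so each model sees the same input shape regardless of dataset. A quick shape check, assuming a namespace with the same fields parse_args would produce (values illustrative):

from types import SimpleNamespace
from utils.data_loader import get_data_loader

args = SimpleNamespace(dataset='mnist', dataroot='datasets/mnist', download=True, batch_size=64)
train_loader, test_loader = get_data_loader(args)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 32, 32]), pixel values roughly in [-1, 1]
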
/utils/fashion_mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.utils.data as data
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import errno
7 | import torch
8 | import codecs
9 |
10 | # Code referenced from the torch source code to add the Fashion-MNIST dataset to the dataloader
11 | # Url: http://pytorch.org/docs/0.3.0/_modules/torchvision/datasets/mnist.html#FashionMNIST
12 | class MNIST(data.Dataset):
13 | """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
14 | Args:
15 | root (string): Root directory of dataset where ``processed/training.pt``
16 | and ``processed/test.pt`` exist.
17 | train (bool, optional): If True, creates dataset from ``training.pt``,
18 | otherwise from ``test.pt``.
19 | download (bool, optional): If true, downloads the dataset from the internet and
20 | puts it in root directory. If dataset is already downloaded, it is not
21 | downloaded again.
22 | transform (callable, optional): A function/transform that takes in a PIL image
23 | and returns a transformed version. E.g, ``transforms.RandomCrop``
24 | target_transform (callable, optional): A function/transform that takes in the
25 | target and transforms it.
26 | """
27 | urls = [
28 | 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
29 | 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
30 | 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
31 | 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
32 | ]
33 | raw_folder = 'raw'
34 | processed_folder = 'processed'
35 | training_file = 'training.pt'
36 | test_file = 'test.pt'
37 |
38 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
39 | self.root = os.path.expanduser(root)
40 | self.transform = transform
41 | self.target_transform = target_transform
42 | self.train = train # training set or test set
43 |
44 | if download:
45 | self.download()
46 |
47 | if not self._check_exists():
48 | raise RuntimeError('Dataset not found.' +
49 | ' You can use download=True to download it')
50 |
51 | if self.train:
52 | self.train_data, self.train_labels = torch.load(
53 | os.path.join(self.root, self.processed_folder, self.training_file))
54 | else:
55 | self.test_data, self.test_labels = torch.load(
56 | os.path.join(self.root, self.processed_folder, self.test_file))
57 |
58 | def __getitem__(self, index):
59 | """
60 | Args:
61 | index (int): Index
62 | Returns:
63 | tuple: (image, target) where target is index of the target class.
64 | """
65 | if self.train:
66 | img, target = self.train_data[index], self.train_labels[index]
67 | else:
68 | img, target = self.test_data[index], self.test_labels[index]
69 |
70 | # doing this so that it is consistent with all other datasets
71 | # to return a PIL Image
72 | img = Image.fromarray(img.numpy(), mode='L')
73 |
74 | if self.transform is not None:
75 | img = self.transform(img)
76 |
77 | if self.target_transform is not None:
78 | target = self.target_transform(target)
79 |
80 | return img, target
81 |
82 | def __len__(self):
83 | if self.train:
84 | return len(self.train_data)
85 | else:
86 | return len(self.test_data)
87 |
88 | def _check_exists(self):
89 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
90 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
91 |
92 | def download(self):
93 | """Download the MNIST data if it doesn't exist in processed_folder already."""
94 | from six.moves import urllib
95 | import gzip
96 |
97 | if self._check_exists():
98 | return
99 |
100 | # download files
101 | try:
102 | os.makedirs(os.path.join(self.root, self.raw_folder))
103 | os.makedirs(os.path.join(self.root, self.processed_folder))
104 | except OSError as e:
105 | if e.errno == errno.EEXIST:
106 | pass
107 | else:
108 | raise
109 |
110 | for url in self.urls:
111 | print('Downloading ' + url)
112 | data = urllib.request.urlopen(url)
113 | filename = url.rpartition('/')[2]
114 | file_path = os.path.join(self.root, self.raw_folder, filename)
115 | with open(file_path, 'wb') as f:
116 | f.write(data.read())
117 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \
118 | gzip.GzipFile(file_path) as zip_f:
119 | out_f.write(zip_f.read())
120 | os.unlink(file_path)
121 |
122 | # process and save as torch files
123 | print('Processing...')
124 |
125 | training_set = (
126 | read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
127 | read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
128 | )
129 | test_set = (
130 | read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
131 | read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
132 | )
133 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
134 | torch.save(training_set, f)
135 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
136 | torch.save(test_set, f)
137 |
138 | print('Done!')
139 |
140 |
141 | class FashionMNIST(MNIST):
142 | """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
143 | Args:
144 | root (string): Root directory of dataset where ``processed/training.pt``
145 | and ``processed/test.pt`` exist.
146 | train (bool, optional): If True, creates dataset from ``training.pt``,
147 | otherwise from ``test.pt``.
148 | download (bool, optional): If true, downloads the dataset from the internet and
149 | puts it in root directory. If dataset is already downloaded, it is not
150 | downloaded again.
151 | transform (callable, optional): A function/transform that takes in a PIL image
152 | and returns a transformed version. E.g, ``transforms.RandomCrop``
153 | target_transform (callable, optional): A function/transform that takes in the
154 | target and transforms it.
155 | """
156 | urls = [
157 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
158 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
159 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
160 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
161 | ]
162 |
163 |
164 | def get_int(b):
165 | return int(codecs.encode(b, 'hex'), 16)
166 |
167 |
168 | def parse_byte(b):
169 | if isinstance(b, str):
170 | return ord(b)
171 | return b
172 |
173 |
174 | def read_label_file(path):
175 | with open(path, 'rb') as f:
176 | data = f.read()
177 | assert get_int(data[:4]) == 2049
178 | length = get_int(data[4:8])
179 | labels = [parse_byte(b) for b in data[8:]]
180 | assert len(labels) == length
181 | return torch.LongTensor(labels)
182 |
183 |
184 | def read_image_file(path):
185 | with open(path, 'rb') as f:
186 | data = f.read()
187 | assert get_int(data[:4]) == 2051
188 | length = get_int(data[4:8])
189 | num_rows = get_int(data[8:12])
190 | num_cols = get_int(data[12:16])
191 | images = []
192 | idx = 16
193 | for l in range(length):
194 | img = []
195 | images.append(img)
196 | for r in range(num_rows):
197 | row = []
198 | img.append(row)
199 | for c in range(num_cols):
200 | row.append(parse_byte(data[idx]))
201 | idx += 1
202 | assert len(images) == length
203 | return torch.ByteTensor(images).view(-1, 28, 28)
204 |
--------------------------------------------------------------------------------
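
read_label_file and read_image_file parse the raw IDX format: a big-endian header of 32-bit integers (magic number 2049 for labels, 2051 for images, followed by the dimension sizes), then the raw label or pixel bytes. An equivalent, vectorized image reader using numpy; a sketch for illustration, not what the file above uses:

import numpy as np
import torch

def read_image_file_np(path):
    with open(path, 'rb') as f:
        data = f.read()
    # header: four big-endian int32 values -- magic, image count, rows, cols
    magic, length, rows, cols = np.frombuffer(data[:16], dtype='>i4')
    assert magic == 2051
    images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(length, rows, cols)
    return torch.from_numpy(images.copy())  # copy: frombuffer returns a read-only view
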
/utils/feature_extraction_test.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | import torch
3 | from torch.autograd import Variable
4 | from utils.data_loader import get_data_loader
5 | from sklearn.metrics import accuracy_score
6 | from sklearn.linear_model import LogisticRegression
7 |
8 | '''
9 | Running the feature extraction test for a trained GAN model, e.g. on
10 | cifar-10: $ python main.py --dataroot datasets/cifar --dataset cifar --load_D trained_models/dcgan/cifar/discriminator.pkl --load_G trained_models/dcgan/cifar/generator.pkl
11 | '''
12 |
13 | class FeatureExtractionTest():
14 |
15 | def __init__(self, train_loader, test_loader, cuda_flag, batch_size):
16 | self.train_loader = train_loader
17 | self.test_loader = test_loader
18 | print("Train length: {}".format(len(self.train_loader)))
19 | print("Test length: {}".format(len(self.test_loader)))
20 | self.batch_size = batch_size
21 |
22 | # Remove the fully connected layer and extract a 2048-dim vector as the feature representation of an image
23 | self.model = models.resnet152(pretrained=True).cuda()
24 | self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
25 |
26 |
27 | # Feature extraction test #1 flattening image
28 | def flatten_images(self):
29 | """
30 | Flatten the image itself as its feature representation.
31 | Input is an image and output is a flattened self.channels*32*32 dimensional numpy array
32 | """
33 | x_train, y_train = [], []
34 | x_test, y_test = [], []
35 |
36 | # flatten pixels of train images
37 | for i, (images, labels) in enumerate(self.train_loader):
38 | if i == len(self.train_loader) // self.batch_size:
39 | break
40 | images = images.numpy()
41 | labels = labels.numpy()
42 |
43 | # Iterate over batch and save as numpy array features of images and label
44 | for j in range(self.batch_size):
45 | x_train.append(images[j].flatten())
46 | y_train.append(labels[j])
47 |
48 | for i, (images, labels) in enumerate(self.test_loader):
49 | if i == len(self.test_loader) // self.batch_size:
50 | break
51 |
52 | images = images.numpy()
53 | labels = labels.numpy()
54 |
55 | # Iterate over batch and save as numpy array features of images and label
56 | for j in range(self.batch_size):
57 | x_test.append(images[j].flatten())
58 | y_test.append(labels[j])
59 |
60 | return x_train, y_train, x_test, y_test
61 |
62 | # Feature extraction test #4: transfer learning with a ResNet152 pretrained on ImageNet
63 | # Resize images to 224x224 for pretrained models
64 | def inception_feature_extraction(self):
65 | """
66 | Extract features from images with a ResNet152 pretrained on ImageNet, with the fully-connected layer removed.
67 | Input is an image and output is a flattened 2048-dimensional numpy array
68 | """
69 | x_train, y_train = [], []
70 | x_test, y_test = [], []
71 |
72 | for i, (images, labels) in enumerate(self.train_loader):
73 | if i == len(self.train_loader) // self.batch_size:
74 | break
75 |
76 | images = Variable(images).cuda()
77 |
78 | # Feature extraction with Resnet152 resulting with feature vector of 2048 dimension
79 | outputs = self.model(images)
80 |
81 | # Convert FloatTensors to numpy array
82 | features = outputs.data.cpu().numpy()
83 | labels = labels.numpy()
84 |
85 | # Iterate over batch and save as numpy array features of images and label
86 | for j in range(self.batch_size):
87 | x_train.append(features[j].flatten())
88 | y_train.append(labels[j])
89 |
90 |
91 | for i, (images, labels) in enumerate(self.test_loader):
92 | if i == len(self.test_loader) // self.batch_size:
93 | break
94 |
95 | images = Variable(images).cuda()
96 |
97 | # Feature extraction with Resnet152 resulting with feature vector of 2048 dimension
98 | outputs = self.model(images)
99 |
100 | # Convert FloatTensors to numpy array
101 | features = outputs.data.cpu().numpy()
102 | labels = labels.numpy()
103 |
104 | # Iterate over batch and save as numpy array features of images and label
105 | for j in range(self.batch_size):
106 | x_test.append(features[j].flatten())
107 | y_test.append(labels[j])
108 |
109 | return x_train, y_train, x_test, y_test
110 |
111 | # Feature extraction GAN model discriminator output 1024x4x4
112 | def GAN_feature_extraction(self, discriminator):
113 | """
114 | Extract features from images with the trained discriminator of a GAN model.
115 | Input is an image and output is a flattened 16384-dimensional numpy array (1024x4x4)
116 | discriminator -- Trained discriminator of GAN model
117 | """
118 | x_train, y_train = [], []
119 | x_test, y_test = [], []
120 | for i, (images, labels) in enumerate(self.train_loader):
121 | if i == len(self.train_loader) // self.batch_size:
122 | break
123 |
124 | images = Variable(images).cuda()
125 | # Feature extraction DCGAN discriminator output 1024x4x4
126 | outputs = discriminator.feature_extraction(images)
127 |
128 | # Convert FloatTensors to numpy array
129 | features = outputs.data.cpu().numpy()
130 | labels = labels.numpy()
131 |
132 | # Iterate over batch and save as numpy array features of images and label
133 | for j in range(self.batch_size):
134 | x_train.append(features[j].flatten())
135 | y_train.append(labels[j])
136 |
137 | for i, (images, labels) in enumerate(self.test_loader):
138 | if i == len(self.test_loader) // self.batch_size:
139 | break
140 |
141 | images = Variable(images).cuda()
142 | outputs = discriminator.feature_extraction(images)
143 |
144 | # Convert FloatTensors to numpy array
145 | features = outputs.data.cpu().numpy()
146 | labels = labels.numpy()
147 |
148 | # Iterate over batch and save as numpy array features of images and label
149 | for j in range(self.batch_size):
150 | x_test.append(features[j].flatten())
151 | y_test.append(labels[j])
152 |
153 | return x_train, y_train, x_test, y_test
154 |
155 |
156 | def calculate_score(self, args):
157 | """
158 | Calculate the accuracy score by fitting a linear classifier (LinearSVC or LogisticRegression) on the feature representation
159 | """
160 | mean_score = 0
161 | for i in range(10):
162 | # Reload the data loaders so the data is reshuffled every iteration
163 | self.train_loader, self.test_loader = get_data_loader(args)
164 |
165 | x_train, y_train, x_test, y_test = self.inception_feature_extraction()
166 | # x_train, y_train, x_test, y_test = self.GAN_feature_extraction(model.D)
167 | # x_train, y_train, x_test, y_test = self.flatten_images()
168 |
169 | # clf = LinearSVC()
170 | clf = LogisticRegression()
171 | clf.fit(x_train, y_train)
172 |
173 | predicted = clf.predict(x_test)
174 | score = accuracy_score(y_test, predicted)
175 | print("Accuracy score: {}".format(score))
176 | mean_score += score
177 | print("Mean score: {}".format(float(mean_score) / 10.0))
178 | return float(mean_score) / 10.0
179 |
--------------------------------------------------------------------------------
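
All three extraction variants feed the same downstream protocol: map each image to a fixed-length vector, fit a linear classifier, and report accuracy, so the score measures how linearly separable the classes are in that feature space. The protocol condensed, with synthetic features standing in for the extractors above (illustrative):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
x_train, y_train = rng.normal(size=(500, 2048)), rng.integers(0, 10, size=500)
x_test, y_test = rng.normal(size=(100, 2048)), rng.integers(0, 10, size=100)

clf = LogisticRegression(max_iter=1000).fit(x_train, y_train)
print(accuracy_score(y_test, clf.predict(x_test)))  # ~0.10 for random features, i.e. chance level
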
/utils/inception_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 | import torch.utils.data
6 | from torchvision.models.inception import inception_v3
7 | import numpy as np
8 | from scipy.stats import entropy
9 |
10 |
11 | def get_inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
12 | """
13 | Computes the inception score of the generated images imgs
14 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
15 | cuda -- whether or not to run on GPU
16 | batch_size -- batch size for feeding into Inception v3
17 | splits -- number of splits
18 | """
19 | N = len(imgs)
20 |
21 | assert batch_size > 0
22 | assert N > batch_size
23 |
24 | # Set up dtype
25 | if cuda:
26 | dtype = torch.cuda.FloatTensor
27 | else:
28 | if torch.cuda.is_available():
29 | print("WARNING: You have a CUDA device, so you should probably set cuda=True")
30 | dtype = torch.FloatTensor
31 |
32 | # Set up dataloader
33 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
34 |
35 | # Load inception model
36 | inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
37 | inception_model.eval()
38 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
39 | def get_pred(x):
40 | if resize:
41 | x = up(x)
42 | x = inception_model(x)
43 | return F.softmax(x, dim=1).data.cpu().numpy()
44 |
45 | # Get predictions
46 | preds = np.zeros((N, 1000))
47 |
48 | for i, batch in enumerate(dataloader, 0):
49 | batch = batch.type(dtype)
50 | batchv = Variable(batch)
51 | batch_size_i = batch.size()[0]
52 |
53 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
54 |
55 | # Now compute the mean kl-div
56 | split_scores = []
57 |
58 | for k in range(splits):
59 | part = preds[k * (N // splits): (k+1) * (N // splits), :]
60 | py = np.mean(part, axis=0)
61 | scores = []
62 | for i in range(part.shape[0]):
63 | pyx = part[i, :]
64 | scores.append(entropy(pyx, py))
65 | split_scores.append(np.exp(np.mean(scores)))
66 |
67 | return np.mean(split_scores), np.std(split_scores)
68 |
--------------------------------------------------------------------------------
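
Written out, the quantity computed above is the standard Inception score: with p(y|x) the Inception v3 class posterior for a generated image x and p(y) the marginal predicted-class distribution over a split,

    IS = exp( E_x [ KL( p(y|x) || p(y) ) ] )

which is large when each image is classified confidently (low-entropy p(y|x)) and the predictions are diverse across images (high-entropy p(y)). A usage sketch; the random tensors below are placeholders for generated images:

import torch
from utils.inception_score import get_inception_score

fake_images = [torch.randn(3, 32, 32).clamp(-1, 1).numpy() for _ in range(1000)]
mean_is, std_is = get_inception_score(fake_images, cuda=True, batch_size=32, resize=True, splits=10)
print(mean_is, std_is)
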
/utils/tensorboard_logger.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | class Logger(object):
6 | def __init__(self, log_dir):
7 | """Create a summary writer logging to log_dir."""
8 | self.writer = tf.summary.create_file_writer(log_dir)
9 |
10 | def scalar_summary(self, tag, value, step):
11 | """Log a scalar variable."""
12 | with self.writer.as_default():
13 | tf.summary.scalar(tag, data=value, step=step)
14 |
15 | def image_summary(self, tag, images, step):
16 | """Log a list of images.
17 | Args: images: numpy array of shape (Batch x C x H x W) in the range [-1.0, 1.0]
18 | """
19 | with self.writer.as_default():
20 | imgs = None
21 | for i, j in enumerate(images):
22 | img = ((j*0.5+0.5)*255).round().astype('uint8')
23 | if len(img.shape) == 3:
24 | img = img.transpose(1, 2, 0)
25 | else:
26 | img = img[:, :, np.newaxis]
27 | img = img[np.newaxis, :]
28 | if imgs is not None:
29 | imgs = np.append(imgs, img, axis=0)
30 | else:
31 | imgs = img
32 | tf.summary.image('{}'.format(tag), imgs, max_outputs=len(imgs), step=step)
33 |
34 | def histo_summary(self, tag, values, step, bins=1000):
35 | """Log a histogram of the tensor of values."""
36 | with self.writer.as_default():
37 | tf.summary.histogram('{}'.format(tag), values, buckets=bins, step=step)
38 |
--------------------------------------------------------------------------------
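
The Logger wraps tf.summary writers so the PyTorch training loops can emit TensorBoard events without touching TensorFlow directly. Typical use, mirroring the calls made from the model classes (log directory name illustrative):

import numpy as np
from utils.tensorboard_logger import Logger

logger = Logger('./logs')
logger.scalar_summary('Loss D', 0.42, step=1)
# a batch of two fake 1x32x32 images in [-1.0, 1.0]
fakes = np.random.uniform(-1.0, 1.0, size=(2, 1, 32, 32))
logger.image_summary('generated_images', fakes, step=1)
# then inspect with: tensorboard --logdir ./logs
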