├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── images ├── CIFAR-10.png ├── Fashion-MNIST.png ├── MNIST.png ├── inception_graph_generator_iters.png ├── inception_graph_time.png ├── latent-mnist.png └── latent_fashion.png ├── main.py ├── models ├── __init__.py ├── dcgan.py ├── gan.py ├── wgan_clipping.py └── wgan_gradient_penalty.py ├── requirements.txt └── utils ├── __init__.py ├── config.py ├── data_loader.py ├── fashion_mnist.py ├── feature_extraction_test.py ├── inception_score.py └── tensorboard_logger.py /.gitignore: -------------------------------------------------------------------------------- 1 | .python-version 2 | **__pycache__** 3 | *.pkl 4 | logs/* 5 | datasets/* 6 | training_result_images/ 7 | inception_score_graph.txt 8 | .vscode/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at filip.zelic@protonmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Green9 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch code for GAN models 2 | This is a PyTorch implementation of three different GAN models, all built on the same convolutional architecture.
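All models are launched through a single entry point, `main.py`. As a minimal sketch (using the argument parser from `utils/config.py` and the loaders from `utils/data_loader.py`, both included below), a model can also be constructed and trained directly from Python:

```python
# Minimal sketch: drive a model programmatically instead of via the CLI.
# parse_args() reads the same flags as main.py (e.g. --model, --dataroot, --dataset).
from utils.config import parse_args
from utils.data_loader import get_data_loader
from models.wgan_gradient_penalty import WGAN_GP

args = parse_args()                                # e.g. --model WGAN-GP --dataroot datasets/cifar --dataset cifar --download True
train_loader, test_loader = get_data_loader(args)  # torchvision datasets wrapped in DataLoaders
model = WGAN_GP(args)                              # builds G, D, and their Adam optimizers
model.train(train_loader)                          # runs args.generator_iters generator updates
```

The implemented models: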
3 | 4 | 5 | - DCGAN (Deep convolutional GAN) 6 | - WGAN-CP (Wasserstein GAN using weight clipping) 7 | - WGAN-GP (Wasserstein GAN using gradient penalty) 8 | 9 | 10 | 11 | ## Dependencies 12 | The prominent packages are: 13 | 14 | * numpy 15 | * scikit-learn 16 | * tensorflow 2.7.0 17 | * pytorch 1.10.1 18 | * torchvision 0.11.2 19 | 20 | To install all the dependencies quickly and easily, use __pip__: 21 | 22 | ``` 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | 27 | 28 | *Training* 29 | --- 30 | Run training of the DCGAN model on the Fashion-MNIST dataset: 31 | 32 | 33 | ``` 34 | python main.py --model DCGAN \ 35 | --is_train True \ 36 | --download True \ 37 | --dataroot datasets/fashion-mnist \ 38 | --dataset fashion-mnist \ 39 | --epochs 30 \ 40 | --cuda True \ 41 | --batch_size 64 42 | ``` 43 | 44 | Run training of the WGAN-GP model on the CIFAR-10 dataset: 45 | 46 | ``` 47 | python main.py --model WGAN-GP \ 48 | --is_train True \ 49 | --download True \ 50 | --dataroot datasets/cifar \ 51 | --dataset cifar \ 52 | --generator_iters 40000 \ 53 | --cuda True \ 54 | --batch_size 64 55 | ``` 56 | 57 | Start TensorBoard: 58 | 59 | ``` 60 | tensorboard --logdir ./logs/ 61 | ``` 62 | 63 | *Walk in latent space* 64 | --- 65 | *Interpolation between two random latent vectors z over 10 points shows that the generated samples have smooth transitions.* 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | *Generated examples MNIST, Fashion-MNIST, CIFAR-10* 75 | --- 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | *Inception score* 87 | --- 88 | [About Inception score](https://arxiv.org/pdf/1801.01973.pdf) 89 | 90 | 91 | 92 | 93 | 94 | *Useful Resources* 95 | --- 96 | 97 | 98 | - [WGAN reddit thread](https://www.reddit.com/r/MachineLearning/comments/5qxoaz/r_170107875_wasserstein_gan/) 99 | - [Blogpost](https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html) 100 | - [Deconvolution and Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/) 101 | - [WGAN-CP paper](https://arxiv.org/pdf/1701.07875.pdf) 102 | - [WGAN-GP paper](https://arxiv.org/pdf/1704.00028.pdf) 103 | - [DCGAN paper](https://arxiv.org/pdf/1511.06434.pdf) 104 | - [Working remotely with PyCharm and SSH](https://medium.com/@erikhallstrm/work-remotely-with-pycharm-tensorflow-and-ssh-c60564be862d) 105 | -------------------------------------------------------------------------------- /images/CIFAR-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/CIFAR-10.png -------------------------------------------------------------------------------- /images/Fashion-MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/Fashion-MNIST.png -------------------------------------------------------------------------------- /images/MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/MNIST.png -------------------------------------------------------------------------------- /images/inception_graph_generator_iters.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/inception_graph_generator_iters.png -------------------------------------------------------------------------------- /images/inception_graph_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/inception_graph_time.png -------------------------------------------------------------------------------- /images/latent-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/latent-mnist.png -------------------------------------------------------------------------------- /images/latent_fashion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/images/latent_fashion.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils.config import parse_args 2 | from utils.data_loader import get_data_loader 3 | 4 | from models.gan import GAN 5 | from models.dcgan import DCGAN_MODEL 6 | from models.wgan_clipping import WGAN_CP 7 | from models.wgan_gradient_penalty import WGAN_GP 8 | 9 | 10 | def main(args): 11 | model = None 12 | if args.model == 'GAN': 13 | model = GAN(args) 14 | elif args.model == 'DCGAN': 15 | model = DCGAN_MODEL(args) 16 | elif args.model == 'WGAN-CP': 17 | model = WGAN_CP(args) 18 | elif args.model == 'WGAN-GP': 19 | model = WGAN_GP(args) 20 | else: 21 | print("Model type does not exist. 
Try again.") 22 | exit(-1) 23 | 24 | # Load datasets to train and test loaders 25 | train_loader, test_loader = get_data_loader(args) 26 | #feature_extraction = FeatureExtractionTest(train_loader, test_loader, args.cuda, args.batch_size) 27 | 28 | # Start model training 29 | if args.is_train == 'True': 30 | model.train(train_loader) 31 | 32 | # start evaluating on test data 33 | else: 34 | model.evaluate(test_loader, args.load_D, args.load_G) 35 | for i in range(50): 36 | model.generate_latent_walk(i) 37 | 38 | 39 | if __name__ == '__main__': 40 | args = parse_args() 41 | print(args.cuda) 42 | main(args) 43 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/models/__init__.py -------------------------------------------------------------------------------- /models/dcgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import time as t 5 | import os 6 | from utils.tensorboard_logger import Logger 7 | from utils.inception_score import get_inception_score 8 | from itertools import chain 9 | from torchvision import utils 10 | 11 | class Generator(torch.nn.Module): 12 | def __init__(self, channels): 13 | super().__init__() 14 | # Filters [1024, 512, 256] 15 | # Input_dim = 100 16 | # Output_dim = C (number of channels) 17 | self.main_module = nn.Sequential( 18 | # Z latent vector 100 19 | nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0), 20 | nn.BatchNorm2d(num_features=1024), 21 | nn.ReLU(True), 22 | 23 | # State (1024x4x4) 24 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1), 25 | nn.BatchNorm2d(num_features=512), 26 | nn.ReLU(True), 27 | 28 | # State (512x8x8) 29 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1), 30 | nn.BatchNorm2d(num_features=256), 31 | nn.ReLU(True), 32 | 33 | # State (256x16x16) 34 | nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1)) 35 | # output of main module --> Image (Cx32x32) 36 | 37 | self.output = nn.Tanh() 38 | 39 | def forward(self, x): 40 | x = self.main_module(x) 41 | return self.output(x) 42 | 43 | 44 | class Discriminator(torch.nn.Module): 45 | def __init__(self, channels): 46 | super().__init__() 47 | # Filters [256, 512, 1024] 48 | # Input_dim = channels (Cx64x64) 49 | # Output_dim = 1 50 | self.main_module = nn.Sequential( 51 | # Image (Cx32x32) 52 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | 55 | # State (256x16x16) 56 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1), 57 | nn.BatchNorm2d(512), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | # State (512x8x8) 61 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1), 62 | nn.BatchNorm2d(1024), 63 | nn.LeakyReLU(0.2, inplace=True)) 64 | # outptut of main module --> State (1024x4x4) 65 | 66 | self.output = nn.Sequential( 67 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0), 68 | # Output 1 69 | nn.Sigmoid()) 70 | 71 | def forward(self, x): 72 | x = self.main_module(x) 73 | return self.output(x) 74 | 75 | def 
feature_extraction(self, x): 76 | # Use discriminator for feature extraction then flatten to vector of 16384 features 77 | x = self.main_module(x) 78 | return x.view(-1, 1024*4*4) 79 | 80 | class DCGAN_MODEL(object): 81 | def __init__(self, args): 82 | print("DCGAN model initalization.") 83 | self.G = Generator(args.channels) 84 | self.D = Discriminator(args.channels) 85 | self.C = args.channels 86 | 87 | # binary cross entropy loss and optimizer 88 | self.loss = nn.BCELoss() 89 | 90 | self.cuda = False 91 | self.cuda_index = 0 92 | # check if cuda is available 93 | self.check_cuda(args.cuda) 94 | 95 | # Using lower learning rate than suggested by (ADAM authors) lr=0.0002 and Beta_1 = 0.5 instead od 0.9 works better [Radford2015] 96 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999)) 97 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999)) 98 | 99 | self.epochs = args.epochs 100 | self.batch_size = args.batch_size 101 | 102 | # Set the logger 103 | self.logger = Logger('./logs') 104 | self.number_of_images = 10 105 | 106 | # cuda support 107 | def check_cuda(self, cuda_flag=False): 108 | if cuda_flag: 109 | self.cuda = True 110 | self.D.cuda(self.cuda_index) 111 | self.G.cuda(self.cuda_index) 112 | self.loss = nn.BCELoss().cuda(self.cuda_index) 113 | print("Cuda enabled flag: ") 114 | print(self.cuda) 115 | 116 | 117 | def train(self, train_loader): 118 | self.t_begin = t.time() 119 | generator_iter = 0 120 | #self.file = open("inception_score_graph.txt", "w") 121 | 122 | for epoch in range(self.epochs): 123 | self.epoch_start_time = t.time() 124 | 125 | for i, (images, _) in enumerate(train_loader): 126 | # Check if round number of batches 127 | if i == train_loader.dataset.__len__() // self.batch_size: 128 | break 129 | 130 | z = torch.rand((self.batch_size, 100, 1, 1)) 131 | real_labels = torch.ones(self.batch_size) 132 | fake_labels = torch.zeros(self.batch_size) 133 | 134 | if self.cuda: 135 | images, z = Variable(images).cuda(self.cuda_index), Variable(z).cuda(self.cuda_index) 136 | real_labels, fake_labels = Variable(real_labels).cuda(self.cuda_index), Variable(fake_labels).cuda(self.cuda_index) 137 | else: 138 | images, z = Variable(images), Variable(z) 139 | real_labels, fake_labels = Variable(real_labels), Variable(fake_labels) 140 | 141 | 142 | # Train discriminator 143 | # Compute BCE_Loss using real images 144 | outputs = self.D(images) 145 | d_loss_real = self.loss(outputs.flatten(), real_labels) 146 | real_score = outputs 147 | 148 | # Compute BCE Loss using fake images 149 | if self.cuda: 150 | z = Variable(torch.randn(self.batch_size, 100, 1, 1)).cuda(self.cuda_index) 151 | else: 152 | z = Variable(torch.randn(self.batch_size, 100, 1, 1)) 153 | fake_images = self.G(z) 154 | outputs = self.D(fake_images) 155 | d_loss_fake = self.loss(outputs.flatten(), fake_labels) 156 | fake_score = outputs 157 | 158 | # Optimize discriminator 159 | d_loss = d_loss_real + d_loss_fake 160 | self.D.zero_grad() 161 | d_loss.backward() 162 | self.d_optimizer.step() 163 | 164 | # Train generator 165 | # Compute loss with fake images 166 | if self.cuda: 167 | z = Variable(torch.randn(self.batch_size, 100, 1, 1)).cuda(self.cuda_index) 168 | else: 169 | z = Variable(torch.randn(self.batch_size, 100, 1, 1)) 170 | fake_images = self.G(z) 171 | outputs = self.D(fake_images) 172 | g_loss = self.loss(outputs.flatten(), real_labels) 173 | 174 | # Optimize generator 175 | self.D.zero_grad() 176 | self.G.zero_grad() 177 | 
g_loss.backward() 178 | self.g_optimizer.step() 179 | generator_iter += 1 180 | 181 | 182 | if generator_iter % 1000 == 0: 183 | # Workaround because graphic card memory can't store more than 800+ examples in memory for generating image 184 | # Therefore doing loop and generating 800 examples and stacking into list of samples to get 8000 generated images 185 | # This way Inception score is more correct since there are different generated examples from every class of Inception model 186 | # sample_list = [] 187 | # for i in range(10): 188 | # z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index) 189 | # samples = self.G(z) 190 | # sample_list.append(samples.data.cpu().numpy()) 191 | # 192 | # # Flattening list of lists into one list of numpy arrays 193 | # new_sample_list = list(chain.from_iterable(sample_list)) 194 | # print("Calculating Inception Score over 8k generated images") 195 | # # Feeding list of numpy arrays 196 | # inception_score = get_inception_score(new_sample_list, cuda=True, batch_size=32, 197 | # resize=True, splits=10) 198 | print('Epoch-{}'.format(epoch + 1)) 199 | self.save_model() 200 | 201 | if not os.path.exists('training_result_images/'): 202 | os.makedirs('training_result_images/') 203 | 204 | # Denormalize images and save them in grid 8x8 205 | z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index) 206 | samples = self.G(z) 207 | samples = samples.mul(0.5).add(0.5) 208 | samples = samples.data.cpu()[:64] 209 | grid = utils.make_grid(samples) 210 | utils.save_image(grid, 'training_result_images/img_generatori_iter_{}.png'.format(str(generator_iter).zfill(3))) 211 | 212 | time = t.time() - self.t_begin 213 | #print("Inception score: {}".format(inception_score)) 214 | print("Generator iter: {}".format(generator_iter)) 215 | print("Time {}".format(time)) 216 | 217 | # Write to file inception_score, gen_iters, time 218 | #output = str(generator_iter) + " " + str(time) + " " + str(inception_score[0]) + "\n" 219 | #self.file.write(output) 220 | 221 | 222 | if ((i + 1) % 100) == 0: 223 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 224 | ((epoch + 1), (i + 1), train_loader.dataset.__len__() // self.batch_size, d_loss.data, g_loss.data)) 225 | 226 | z = Variable(torch.randn(self.batch_size, 100, 1, 1).cuda(self.cuda_index)) 227 | 228 | # TensorBoard logging 229 | # Log the scalar values 230 | info = { 231 | 'd_loss': d_loss.data, 232 | 'g_loss': g_loss.data 233 | } 234 | 235 | for tag, value in info.items(): 236 | self.logger.scalar_summary(tag, value, generator_iter) 237 | 238 | # Log values and gradients of the parameters 239 | for tag, value in self.D.named_parameters(): 240 | tag = tag.replace('.', '/') 241 | self.logger.histo_summary(tag, self.to_np(value), generator_iter) 242 | self.logger.histo_summary(tag + '/grad', self.to_np(value.grad), generator_iter) 243 | 244 | # Log the images while training 245 | info = { 246 | 'real_images': self.real_images(images, self.number_of_images), 247 | 'generated_images': self.generate_img(z, self.number_of_images) 248 | } 249 | 250 | for tag, images in info.items(): 251 | self.logger.image_summary(tag, images, generator_iter) 252 | 253 | 254 | self.t_end = t.time() 255 | print('Time of training-{}'.format((self.t_end - self.t_begin))) 256 | #self.file.close() 257 | 258 | # Save the trained parameters 259 | self.save_model() 260 | 261 | def evaluate(self, test_loader, D_model_path, G_model_path): 262 | self.load_model(D_model_path, G_model_path) 263 | z = Variable(torch.randn(self.batch_size, 
100, 1, 1)).cuda(self.cuda_index) if self.cuda else Variable(torch.randn(self.batch_size, 100, 1, 1))  # guard the .cuda() call so evaluation also runs on CPU 264 | samples = self.G(z) 265 | samples = samples.mul(0.5).add(0.5) 266 | samples = samples.data.cpu() 267 | grid = utils.make_grid(samples) 268 | print("Grid of 8x8 images saved to 'dcgan_model_image.png'.") 269 | utils.save_image(grid, 'dcgan_model_image.png') 270 | 271 | def real_images(self, images, number_of_images): 272 | if (self.C == 3): 273 | return self.to_np(images.view(-1, self.C, 32, 32)[:self.number_of_images]) 274 | else: 275 | return self.to_np(images.view(-1, 32, 32)[:self.number_of_images]) 276 | 277 | def generate_img(self, z, number_of_images): 278 | samples = self.G(z).data.cpu().numpy()[:number_of_images] 279 | generated_images = [] 280 | for sample in samples: 281 | if self.C == 3: 282 | generated_images.append(sample.reshape(self.C, 32, 32)) 283 | else: 284 | generated_images.append(sample.reshape(32, 32)) 285 | return generated_images 286 | 287 | def to_np(self, x): 288 | return x.data.cpu().numpy() 289 | 290 | def save_model(self): 291 | torch.save(self.G.state_dict(), './generator.pkl') 292 | torch.save(self.D.state_dict(), './discriminator.pkl') 293 | print('Models saved to ./generator.pkl & ./discriminator.pkl') 294 | 295 | def load_model(self, D_model_filename, G_model_filename): 296 | D_model_path = os.path.join(os.getcwd(), D_model_filename) 297 | G_model_path = os.path.join(os.getcwd(), G_model_filename) 298 | self.D.load_state_dict(torch.load(D_model_path)) 299 | self.G.load_state_dict(torch.load(G_model_path)) 300 | print('Generator model loaded from {}.'.format(G_model_path)) 301 | print('Discriminator model loaded from {}.'.format(D_model_path)) 302 | 303 | def generate_latent_walk(self, number): 304 | if not os.path.exists('interpolated_images/'): 305 | os.makedirs('interpolated_images/') 306 | 307 | # Interpolate between two noise vectors (z1, z2) with number_int steps in between 308 | number_int = 10 309 | z_intp = torch.FloatTensor(1, 100, 1, 1) 310 | z1 = torch.randn(1, 100, 1, 1) 311 | z2 = torch.randn(1, 100, 1, 1) 312 | if self.cuda: 313 | z_intp = z_intp.cuda() 314 | z1 = z1.cuda() 315 | z2 = z2.cuda() 316 | 317 | z_intp = Variable(z_intp) 318 | images = [] 319 | step = 1.0 / float(number_int + 1) 320 | print(step) 321 | for i in range(1, number_int + 1): 322 | alpha = i * step  # evenly spaced interpolation weights in (0, 1) 323 | z_intp.data = z1*alpha + z2*(1.0 - alpha) 324 | fake_im = self.G(z_intp) 325 | fake_im = fake_im.mul(0.5).add(0.5) #denormalize 326 | images.append(fake_im.view(self.C,32,32).data.cpu()) 327 | 328 | grid = utils.make_grid(images, nrow=number_int) 329 | utils.save_image(grid, 'interpolated_images/interpolated_{}.png'.format(str(number).zfill(3))) 330 | print("Saved interpolated images to interpolated_images/interpolated_{}.".format(str(number).zfill(3))) 331 | -------------------------------------------------------------------------------- /models/gan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import utils 6 | from torch.autograd import Variable 7 | from utils.tensorboard_logger import Logger 8 | 9 | 10 | class GAN(object): 11 | def __init__(self, args): 12 | # Generator architecture 13 | self.G = nn.Sequential( 14 | nn.Linear(100, 256), 15 | nn.LeakyReLU(0.2), 16 | nn.Linear(256, 512), 17 | nn.LeakyReLU(0.2), 18 | nn.Linear(512, 1024), 19 | nn.LeakyReLU(0.2), 20 | nn.Tanh()) 21 | 22 | # Discriminator architecture 23 | self.D = nn.Sequential( 24 | nn.Linear(1024, 512), 25 | nn.LeakyReLU(0.2), 26 | nn.Linear(512, 
256), 27 | nn.LeakyReLU(0.2), 28 | nn.Linear(256, 1), 29 | nn.Sigmoid()) 30 | 31 | self.cuda = False 32 | self.cuda_index = 0 33 | # check if cuda is available 34 | self.check_cuda(args.cuda) 35 | 36 | # Binary cross entropy loss and optimizer 37 | self.loss = nn.BCELoss() 38 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=0.0002, weight_decay=0.00001) 39 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=0.0002, weight_decay=0.00001) 40 | 41 | # Set the logger 42 | self.logger = Logger('./logs') 43 | self.number_of_images = 10 44 | self.epochs = args.epochs 45 | self.batch_size = args.batch_size 46 | 47 | # Cuda support 48 | def check_cuda(self, cuda_flag=False): 49 | if cuda_flag: 50 | self.cuda_index = 0 51 | self.cuda = True 52 | self.D.cuda(self.cuda_index) 53 | self.G.cuda(self.cuda_index) 54 | self.loss = nn.BCELoss().cuda(self.cuda_index) 55 | print("Cuda enabled flag: ") 56 | print(self.cuda) 57 | 58 | def train(self, train_loader): 59 | self.t_begin = time.time() 60 | generator_iter = 0 61 | 62 | for epoch in range(self.epochs+1): 63 | for i, (images, _) in enumerate(train_loader): 64 | # Check if round number of batches 65 | if i == train_loader.dataset.__len__() // self.batch_size: 66 | break 67 | 68 | # Flatten image 1,32x32 to 1024 69 | images = images.view(self.batch_size, -1) 70 | z = torch.rand((self.batch_size, 100)) 71 | 72 | if self.cuda: 73 | real_labels = Variable(torch.ones(self.batch_size)).cuda(self.cuda_index) 74 | fake_labels = Variable(torch.zeros(self.batch_size)).cuda(self.cuda_index) 75 | images, z = Variable(images.cuda(self.cuda_index)), Variable(z.cuda(self.cuda_index)) 76 | else: 77 | real_labels = Variable(torch.ones(self.batch_size)) 78 | fake_labels = Variable(torch.zeros(self.batch_size)) 79 | images, z = Variable(images), Variable(z) 80 | 81 | # Train discriminator 82 | # compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x)) 83 | # [Training discriminator = Maximizing discriminator being correct] 84 | outputs = self.D(images) 85 | d_loss_real = self.loss(outputs.flatten(), real_labels) 86 | real_score = outputs 87 | 88 | # Compute BCELoss using fake images 89 | fake_images = self.G(z) 90 | outputs = self.D(fake_images) 91 | d_loss_fake = self.loss(outputs.flatten(), fake_labels) 92 | fake_score = outputs 93 | 94 | # Optimizie discriminator 95 | d_loss = d_loss_real + d_loss_fake 96 | self.D.zero_grad() 97 | d_loss.backward() 98 | self.d_optimizer.step() 99 | 100 | # Train generator 101 | if self.cuda: 102 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index)) 103 | else: 104 | z = Variable(torch.randn(self.batch_size, 100)) 105 | fake_images = self.G(z) 106 | outputs = self.D(fake_images) 107 | 108 | # We train G to maximize log(D(G(z))[maximize likelihood of discriminator being wrong] instead of 109 | # minimizing log(1-D(G(z)))[minizing likelihood of discriminator being correct] 110 | # From paper [https://arxiv.org/pdf/1406.2661.pdf] 111 | g_loss = self.loss(outputs.flatten(), real_labels) 112 | 113 | # Optimize generator 114 | self.D.zero_grad() 115 | self.G.zero_grad() 116 | g_loss.backward() 117 | self.g_optimizer.step() 118 | generator_iter += 1 119 | 120 | 121 | if ((i + 1) % 100) == 0: 122 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % 123 | ((epoch + 1), (i + 1), train_loader.dataset.__len__() // self.batch_size, d_loss.data, g_loss.data)) 124 | 125 | if self.cuda: 126 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index)) 
127 | else: 128 | z = Variable(torch.randn(self.batch_size, 100)) 129 | 130 | # ============ TensorBoard logging ============# 131 | # (1) Log the scalar values 132 | info = { 133 | 'd_loss': d_loss.data, 134 | 'g_loss': g_loss.data 135 | } 136 | 137 | for tag, value in info.items(): 138 | self.logger.scalar_summary(tag, value, i + 1) 139 | 140 | # (2) Log values and gradients of the parameters (histogram) 141 | for tag, value in self.D.named_parameters(): 142 | tag = tag.replace('.', '/') 143 | self.logger.histo_summary(tag, self.to_np(value), i + 1) 144 | self.logger.histo_summary(tag + '/grad', self.to_np(value.grad), i + 1) 145 | 146 | # (3) Log the images 147 | info = { 148 | 'real_images': self.to_np(images.view(-1, 32, 32)[:self.number_of_images]), 149 | 'generated_images': self.generate_img(z, self.number_of_images) 150 | } 151 | 152 | for tag, images in info.items(): 153 | self.logger.image_summary(tag, images, i + 1) 154 | 155 | 156 | if generator_iter % 1000 == 0: 157 | print('Generator iter-{}'.format(generator_iter)) 158 | self.save_model() 159 | 160 | if not os.path.exists('training_result_images/'): 161 | os.makedirs('training_result_images/') 162 | 163 | # Denormalize images and save them in grid 8x8 164 | if self.cuda: 165 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index)) 166 | else: 167 | z = Variable(torch.randn(self.batch_size, 100)) 168 | samples = self.G(z) 169 | samples = samples.mul(0.5).add(0.5) 170 | samples = samples.data.cpu() 171 | grid = utils.make_grid(samples) 172 | utils.save_image(grid, 'training_result_images/gan_image_iter_{}.png'.format( 173 | str(generator_iter).zfill(3))) 174 | 175 | self.t_end = time.time() 176 | print('Time of training-{}'.format((self.t_end - self.t_begin))) 177 | # Save the trained parameters 178 | self.save_model() 179 | 180 | def evaluate(self, test_loader, D_model_path, G_model_path): 181 | self.load_model(D_model_path, G_model_path) 182 | if self.cuda: 183 | z = Variable(torch.randn(self.batch_size, 100).cuda(self.cuda_index)) 184 | else: 185 | z = Variable(torch.randn(self.batch_size, 100)) 186 | samples = self.G(z) 187 | samples = samples.mul(0.5).add(0.5) 188 | samples = samples.data.cpu() 189 | grid = utils.make_grid(samples) 190 | print("Grid of 8x8 images saved to 'gan_model_image.png'.") 191 | utils.save_image(grid, 'gan_model_image.png') 192 | 193 | def generate_img(self, z, number_of_images): 194 | samples = self.G(z).data.cpu().numpy()[:number_of_images] 195 | generated_images = [] 196 | for sample in samples: 197 | generated_images.append(sample.reshape(32,32)) 198 | return generated_images 199 | 200 | def to_np(self, x): 201 | return x.data.cpu().numpy() 202 | 203 | def save_model(self): 204 | torch.save(self.G.state_dict(), './generator.pkl') 205 | torch.save(self.D.state_dict(), './discriminator.pkl') 206 | print('Models save to ./generator.pkl & ./discriminator.pkl ') 207 | 208 | def load_model(self, D_model_filename, G_model_filename): 209 | D_model_path = os.path.join(os.getcwd(), D_model_filename) 210 | G_model_path = os.path.join(os.getcwd(), G_model_filename) 211 | self.D.load_state_dict(torch.load(D_model_path)) 212 | self.G.load_state_dict(torch.load(G_model_path)) 213 | print('Generator model loaded from {}.'.format(G_model_path)) 214 | print('Discriminator model loaded from {}-'.format(D_model_path)) 215 | -------------------------------------------------------------------------------- /models/wgan_clipping.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import time as t 5 | import matplotlib.pyplot as plt 6 | plt.switch_backend('agg') 7 | import os 8 | from utils.tensorboard_logger import Logger 9 | from torchvision import utils 10 | 11 | 12 | SAVE_PER_TIMES = 1000 13 | 14 | class Generator(torch.nn.Module): 15 | def __init__(self, channels): 16 | super().__init__() 17 | # Filters [1024, 512, 256] 18 | # Input_dim = 100 19 | # Output_dim = C (number of channels) 20 | self.main_module = nn.Sequential( 21 | # Z latent vector 100 22 | nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0), 23 | nn.BatchNorm2d(num_features=1024), 24 | nn.ReLU(True), 25 | 26 | # State (1024x4x4) 27 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1), 28 | nn.BatchNorm2d(num_features=512), 29 | nn.ReLU(True), 30 | 31 | # State (512x8x8) 32 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1), 33 | nn.BatchNorm2d(num_features=256), 34 | nn.ReLU(True), 35 | 36 | # State (256x16x16) 37 | nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1)) 38 | # output of main module --> Image (Cx32x32) 39 | 40 | self.output = nn.Tanh() 41 | 42 | def forward(self, x): 43 | x = self.main_module(x) 44 | return self.output(x) 45 | 46 | class Discriminator(torch.nn.Module): 47 | def __init__(self, channels): 48 | super().__init__() 49 | # Filters [256, 512, 1024] 50 | # Input_dim = channels (Cx64x64) 51 | # Output_dim = 1 52 | self.main_module = nn.Sequential( 53 | # Image (Cx32x32) 54 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1), 55 | nn.BatchNorm2d(num_features=256), 56 | nn.LeakyReLU(0.2, inplace=True), 57 | 58 | # State (256x16x16) 59 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1), 60 | nn.BatchNorm2d(num_features=512), 61 | nn.LeakyReLU(0.2, inplace=True), 62 | 63 | # State (512x8x8) 64 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1), 65 | nn.BatchNorm2d(num_features=1024), 66 | nn.LeakyReLU(0.2, inplace=True)) 67 | # output of main module --> State (1024x4x4) 68 | 69 | self.output = nn.Sequential( 70 | # The output of D is no longer a probability, we do not apply sigmoid at the output of D. 
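# The raw score is used directly: the gap between the critic's mean output on real
# and fake batches estimates the Wasserstein-1 distance (Kantorovich-Rubinstein
# duality), which is only valid while the critic stays approximately 1-Lipschitz.
# WGAN_CP.train() below enforces that constraint by clamping every critic weight
# to the small range [-c, c] before each critic update.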
71 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0)) 72 | 73 | 74 | def forward(self, x): 75 | x = self.main_module(x) 76 | return self.output(x) 77 | 78 | def feature_extraction(self, x): 79 | # Use discriminator for feature extraction then flatten to vector of 16384 80 | x = self.main_module(x) 81 | return x.view(-1, 1024*4*4) 82 | 83 | 84 | class WGAN_CP(object): 85 | def __init__(self, args): 86 | print("WGAN_CP init model.") 87 | self.G = Generator(args.channels) 88 | self.D = Discriminator(args.channels) 89 | self.C = args.channels 90 | 91 | # check if cuda is available 92 | self.check_cuda(args.cuda) 93 | 94 | # WGAN values from paper 95 | self.learning_rate = 0.00005 96 | 97 | self.batch_size = 64 98 | self.weight_cliping_limit = 0.01 99 | 100 | # WGAN with gradient clipping uses RMSprop instead of ADAM 101 | self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), lr=self.learning_rate) 102 | self.g_optimizer = torch.optim.RMSprop(self.G.parameters(), lr=self.learning_rate) 103 | 104 | # Set the logger 105 | self.logger = Logger('./logs') 106 | self.logger.writer.flush() 107 | self.number_of_images = 10 108 | 109 | self.generator_iters = args.generator_iters 110 | self.critic_iter = 5 111 | 112 | def get_torch_variable(self, arg): 113 | if self.cuda: 114 | return Variable(arg).cuda(self.cuda_index) 115 | else: 116 | return Variable(arg) 117 | 118 | def check_cuda(self, cuda_flag=False): 119 | if cuda_flag: 120 | self.cuda_index = 0 121 | self.cuda = True 122 | self.D.cuda(self.cuda_index) 123 | self.G.cuda(self.cuda_index) 124 | print("Cuda enabled flag: {}".format(self.cuda)) 125 | else: 126 | self.cuda = False 127 | 128 | 129 | def train(self, train_loader): 130 | self.t_begin = t.time() 131 | #self.file = open("inception_score_graph.txt", "w") 132 | 133 | # Now batches are callable self.data.next() 134 | self.data = self.get_infinite_batches(train_loader) 135 | 136 | one = torch.FloatTensor([1]) 137 | mone = one * -1 138 | if self.cuda: 139 | one = one.cuda(self.cuda_index) 140 | mone = mone.cuda(self.cuda_index) 141 | 142 | for g_iter in range(self.generator_iters): 143 | 144 | # Requires grad, Generator requires_grad = False 145 | for p in self.D.parameters(): 146 | p.requires_grad = True 147 | 148 | # Train Dicriminator forward-loss-backward-update self.critic_iter times while 1 Generator forward-loss-backward-update 149 | for d_iter in range(self.critic_iter): 150 | self.D.zero_grad() 151 | 152 | # Clamp parameters to a range [-c, c], c=self.weight_cliping_limit 153 | for p in self.D.parameters(): 154 | p.data.clamp_(-self.weight_cliping_limit, self.weight_cliping_limit) 155 | 156 | images = self.data.__next__() 157 | # Check for batch to have full batch_size 158 | if (images.size()[0] != self.batch_size): 159 | continue 160 | 161 | z = torch.rand((self.batch_size, 100, 1, 1)) 162 | 163 | images, z = self.get_torch_variable(images), self.get_torch_variable(z) 164 | 165 | 166 | # Train discriminator 167 | # WGAN - Training discriminator more iterations than generator 168 | # Train with real images 169 | d_loss_real = self.D(images) 170 | d_loss_real = d_loss_real.mean(0).view(1) 171 | d_loss_real.backward(one) 172 | 173 | # Train with fake images 174 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1)) 175 | fake_images = self.G(z) 176 | d_loss_fake = self.D(fake_images) 177 | d_loss_fake = d_loss_fake.mean(0).view(1) 178 | d_loss_fake.backward(mone) 179 | 180 | d_loss = d_loss_fake - d_loss_real 181 | Wasserstein_D = 
d_loss_real - d_loss_fake 182 | self.d_optimizer.step() 183 | print(f' Discriminator iteration: {d_iter}/{self.critic_iter}, loss_fake: {d_loss_fake.data}, loss_real: {d_loss_real.data}') 184 | 185 | 186 | 187 | # Generator update 188 | for p in self.D.parameters(): 189 | p.requires_grad = False # to avoid computation 190 | 191 | self.G.zero_grad() 192 | 193 | # Train generator 194 | # Compute loss with fake images 195 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1)) 196 | fake_images = self.G(z) 197 | g_loss = self.D(fake_images) 198 | g_loss = g_loss.mean().mean(0).view(1) 199 | g_loss.backward(one) 200 | g_cost = -g_loss 201 | self.g_optimizer.step() 202 | print(f'Generator iteration: {g_iter}/{self.generator_iters}, g_loss: {g_loss.data}') 203 | 204 | # Saving model and sampling images every 1000th generator iterations 205 | if (g_iter) % SAVE_PER_TIMES == 0: 206 | self.save_model() 207 | # Workaround because graphic card memory can't store more than 830 examples in memory for generating image 208 | # Therefore doing loop and generating 800 examples and stacking into list of samples to get 8000 generated images 209 | # This way Inception score is more correct since there are different generated examples from every class of Inception model 210 | # sample_list = [] 211 | # for i in range(10): 212 | # z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index) 213 | # samples = self.G(z) 214 | # sample_list.append(samples.data.cpu().numpy()) 215 | # 216 | # # Flattening list of list into one list 217 | # new_sample_list = list(chain.from_iterable(sample_list)) 218 | # print("Calculating Inception Score over 8k generated images") 219 | # # Feeding list of numpy arrays 220 | # inception_score = get_inception_score(new_sample_list, cuda=True, batch_size=32, 221 | # resize=True, splits=10) 222 | 223 | if not os.path.exists('training_result_images/'): 224 | os.makedirs('training_result_images/') 225 | 226 | # Denormalize images and save them in grid 8x8 227 | z = self.get_torch_variable(torch.randn(800, 100, 1, 1)) 228 | samples = self.G(z) 229 | samples = samples.mul(0.5).add(0.5) 230 | samples = samples.data.cpu()[:64] 231 | grid = utils.make_grid(samples) 232 | utils.save_image(grid, 'training_result_images/img_generatori_iter_{}.png'.format(str(g_iter).zfill(3))) 233 | 234 | # Testing 235 | time = t.time() - self.t_begin 236 | #print("Inception score: {}".format(inception_score)) 237 | print("Generator iter: {}".format(g_iter)) 238 | print("Time {}".format(time)) 239 | 240 | # Write to file inception_score, gen_iters, time 241 | #output = str(g_iter) + " " + str(time) + " " + str(inception_score[0]) + "\n" 242 | #self.file.write(output) 243 | 244 | # ============ TensorBoard logging ============# 245 | # (1) Log the scalar values 246 | info = { 247 | 'Wasserstein distance': Wasserstein_D.data, 248 | 'Loss D': d_loss.data, 249 | 'Loss G': g_cost.data, 250 | 'Loss D Real': d_loss_real.data, 251 | 'Loss D Fake': d_loss_fake.data 252 | } 253 | 254 | for tag, value in info.items(): 255 | self.logger.scalar_summary(tag, value.mean().cpu(), g_iter + 1) 256 | 257 | # (3) Log the images 258 | info = { 259 | 'real_images': self.real_images(images, self.number_of_images), 260 | 'generated_images': self.generate_img(z, self.number_of_images) 261 | } 262 | 263 | for tag, images in info.items(): 264 | self.logger.image_summary(tag, images, g_iter + 1) 265 | 266 | self.t_end = t.time() 267 | print('Time of training-{}'.format((self.t_end - self.t_begin))) 268 | 
#self.file.close() 269 | 270 | # Save the trained parameters 271 | self.save_model() 272 | 273 | def evaluate(self, test_loader, D_model_path, G_model_path): 274 | self.load_model(D_model_path, G_model_path) 275 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1)) 276 | samples = self.G(z) 277 | samples = samples.mul(0.5).add(0.5) 278 | samples = samples.data.cpu() 279 | grid = utils.make_grid(samples) 280 | print("Grid of 8x8 images saved to 'dgan_model_image.png'.") 281 | utils.save_image(grid, 'dgan_model_image.png') 282 | 283 | def real_images(self, images, number_of_images): 284 | if (self.C == 3): 285 | return self.to_np(images.view(-1, self.C, 32, 32)[:self.number_of_images]) 286 | else: 287 | return self.to_np(images.view(-1, 32, 32)[:self.number_of_images]) 288 | 289 | def generate_img(self, z, number_of_images): 290 | samples = self.G(z).data.cpu().numpy()[:number_of_images] 291 | generated_images = [] 292 | for sample in samples: 293 | if self.C == 3: 294 | generated_images.append(sample.reshape(self.C, 32, 32)) 295 | else: 296 | generated_images.append(sample.reshape(32, 32)) 297 | return generated_images 298 | 299 | def to_np(self, x): 300 | return x.data.cpu().numpy() 301 | 302 | def save_model(self): 303 | torch.save(self.G.state_dict(), './generator.pkl') 304 | torch.save(self.D.state_dict(), './discriminator.pkl') 305 | print('Models save to ./generator.pkl & ./discriminator.pkl ') 306 | 307 | def load_model(self, D_model_filename, G_model_filename): 308 | D_model_path = os.path.join(os.getcwd(), D_model_filename) 309 | G_model_path = os.path.join(os.getcwd(), G_model_filename) 310 | self.D.load_state_dict(torch.load(D_model_path)) 311 | self.G.load_state_dict(torch.load(G_model_path)) 312 | print('Generator model loaded from {}.'.format(G_model_path)) 313 | print('Discriminator model loaded from {}-'.format(D_model_path)) 314 | 315 | def get_infinite_batches(self, data_loader): 316 | while True: 317 | for i, (images, _) in enumerate(data_loader): 318 | yield images 319 | 320 | 321 | def generate_latent_walk(self, number): 322 | if not os.path.exists('interpolated_images/'): 323 | os.makedirs('interpolated_images/') 324 | 325 | number_int = 10 326 | # interpolate between two noise (z1, z2). 
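# Linear interpolation in latent space: z(alpha) = alpha * z1 + (1 - alpha) * z2,
# with alpha stepped evenly through (0, 1). Smooth transitions between the decoded
# images suggest the generator has learned a continuous latent space rather than
# memorizing individual training samples.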
327 | z_intp = torch.FloatTensor(1, 100, 1, 1) 328 | z1 = torch.randn(1, 100, 1, 1) 329 | z2 = torch.randn(1, 100, 1, 1) 330 | if self.cuda: 331 | z_intp = z_intp.cuda() 332 | z1 = z1.cuda() 333 | z2 = z2.cuda() 334 | 335 | z_intp = Variable(z_intp) 336 | images = [] 337 | step = 1.0 / float(number_int + 1) 338 | print(step) 339 | for i in range(1, number_int + 1): 340 | alpha = i * step  # evenly spaced interpolation weights in (0, 1) 341 | z_intp.data = z1*alpha + z2*(1.0 - alpha) 342 | fake_im = self.G(z_intp) 343 | fake_im = fake_im.mul(0.5).add(0.5) #denormalize 344 | images.append(fake_im.view(self.C,32,32).data.cpu()) 345 | 346 | grid = utils.make_grid(images, nrow=number_int) 347 | utils.save_image(grid, 'interpolated_images/interpolated_{}.png'.format(str(number).zfill(3))) 348 | print("Saved interpolated images.") 349 | -------------------------------------------------------------------------------- /models/wgan_gradient_penalty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from torch import autograd 6 | import time as t 7 | import matplotlib.pyplot as plt 8 | plt.switch_backend('agg') 9 | import os 10 | from utils.tensorboard_logger import Logger 11 | from itertools import chain 12 | from torchvision import utils 13 | 14 | SAVE_PER_TIMES = 100 15 | 16 | class Generator(torch.nn.Module): 17 | def __init__(self, channels): 18 | super().__init__() 19 | # Filters [1024, 512, 256] 20 | # Input_dim = 100 21 | # Output_dim = C (number of channels) 22 | self.main_module = nn.Sequential( 23 | # Z latent vector 100 24 | nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0), 25 | nn.BatchNorm2d(num_features=1024), 26 | nn.ReLU(True), 27 | 28 | # State (1024x4x4) 29 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1), 30 | nn.BatchNorm2d(num_features=512), 31 | nn.ReLU(True), 32 | 33 | # State (512x8x8) 34 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1), 35 | nn.BatchNorm2d(num_features=256), 36 | nn.ReLU(True), 37 | 38 | # State (256x16x16) 39 | nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1)) 40 | # output of main module --> Image (Cx32x32) 41 | 42 | self.output = nn.Tanh() 43 | 44 | def forward(self, x): 45 | x = self.main_module(x) 46 | return self.output(x) 47 | 48 | 49 | class Discriminator(torch.nn.Module): 50 | def __init__(self, channels): 51 | super().__init__() 52 | # Filters [256, 512, 1024] 53 | # Input_dim = channels (Cx32x32) 54 | # Output_dim = 1 55 | self.main_module = nn.Sequential( 56 | # Omitting batch normalization in critic because our new penalized training objective (WGAN with gradient penalty) is no longer valid 57 | # in this setting, since we penalize the norm of the critic's gradient with respect to each input independently and not the entire batch. 
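# The penalty is lambda * E[(||grad_x_hat D(x_hat)||_2 - 1)^2] evaluated at random
# interpolates x_hat = eta * x_real + (1 - eta) * x_fake (see
# calculate_gradient_penalty below), so each sample's gradient has to depend on
# that sample alone.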
58 | # There is not good & fast implementation of layer normalization --> using per instance normalization nn.InstanceNorm2d() 59 | # Image (Cx32x32) 60 | nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1), 61 | nn.InstanceNorm2d(256, affine=True), 62 | nn.LeakyReLU(0.2, inplace=True), 63 | 64 | # State (256x16x16) 65 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1), 66 | nn.InstanceNorm2d(512, affine=True), 67 | nn.LeakyReLU(0.2, inplace=True), 68 | 69 | # State (512x8x8) 70 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1), 71 | nn.InstanceNorm2d(1024, affine=True), 72 | nn.LeakyReLU(0.2, inplace=True)) 73 | # output of main module --> State (1024x4x4) 74 | 75 | self.output = nn.Sequential( 76 | # The output of D is no longer a probability, we do not apply sigmoid at the output of D. 77 | nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0)) 78 | 79 | 80 | def forward(self, x): 81 | x = self.main_module(x) 82 | return self.output(x) 83 | 84 | def feature_extraction(self, x): 85 | # Use discriminator for feature extraction then flatten to vector of 16384 86 | x = self.main_module(x) 87 | return x.view(-1, 1024*4*4) 88 | 89 | 90 | class WGAN_GP(object): 91 | def __init__(self, args): 92 | print("WGAN_GradientPenalty init model.") 93 | self.G = Generator(args.channels) 94 | self.D = Discriminator(args.channels) 95 | self.C = args.channels 96 | 97 | # Check if cuda is available 98 | self.check_cuda(args.cuda) 99 | 100 | # WGAN values from paper 101 | self.learning_rate = 1e-4 102 | self.b1 = 0.5 103 | self.b2 = 0.999 104 | self.batch_size = 64 105 | 106 | # WGAN_gradient penalty uses ADAM 107 | self.d_optimizer = optim.Adam(self.D.parameters(), lr=self.learning_rate, betas=(self.b1, self.b2)) 108 | self.g_optimizer = optim.Adam(self.G.parameters(), lr=self.learning_rate, betas=(self.b1, self.b2)) 109 | 110 | # Set the logger 111 | self.logger = Logger('./logs') 112 | self.logger.writer.flush() 113 | self.number_of_images = 10 114 | 115 | self.generator_iters = args.generator_iters 116 | self.critic_iter = 5 117 | self.lambda_term = 10 118 | 119 | def get_torch_variable(self, arg): 120 | if self.cuda: 121 | return Variable(arg).cuda(self.cuda_index) 122 | else: 123 | return Variable(arg) 124 | 125 | def check_cuda(self, cuda_flag=False): 126 | print(cuda_flag) 127 | if cuda_flag: 128 | self.cuda_index = 0 129 | self.cuda = True 130 | self.D.cuda(self.cuda_index) 131 | self.G.cuda(self.cuda_index) 132 | print("Cuda enabled flag: {}".format(self.cuda)) 133 | else: 134 | self.cuda = False 135 | 136 | 137 | def train(self, train_loader): 138 | self.t_begin = t.time() 139 | self.file = open("inception_score_graph.txt", "w") 140 | 141 | # Now batches are callable self.data.next() 142 | self.data = self.get_infinite_batches(train_loader) 143 | 144 | one = torch.tensor(1, dtype=torch.float) 145 | mone = one * -1 146 | if self.cuda: 147 | one = one.cuda(self.cuda_index) 148 | mone = mone.cuda(self.cuda_index) 149 | 150 | for g_iter in range(self.generator_iters): 151 | # Requires grad, Generator requires_grad = False 152 | for p in self.D.parameters(): 153 | p.requires_grad = True 154 | 155 | d_loss_real = 0 156 | d_loss_fake = 0 157 | Wasserstein_D = 0 158 | # Train Dicriminator forward-loss-backward-update self.critic_iter times while 1 Generator forward-loss-backward-update 159 | for d_iter in range(self.critic_iter): 160 | self.D.zero_grad() 161 | 162 | images = 
self.data.__next__() 163 | # Check for batch to have full batch_size 164 | if (images.size()[0] != self.batch_size): 165 | continue 166 | 167 | z = torch.rand((self.batch_size, 100, 1, 1)) 168 | 169 | images, z = self.get_torch_variable(images), self.get_torch_variable(z) 170 | 171 | # Train discriminator 172 | # WGAN - Training discriminator more iterations than generator 173 | # Train with real images 174 | d_loss_real = self.D(images) 175 | d_loss_real = d_loss_real.mean() 176 | d_loss_real.backward(mone) 177 | 178 | # Train with fake images 179 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1)) 180 | 181 | fake_images = self.G(z) 182 | d_loss_fake = self.D(fake_images) 183 | d_loss_fake = d_loss_fake.mean() 184 | d_loss_fake.backward(one) 185 | 186 | # Train with gradient penalty 187 | gradient_penalty = self.calculate_gradient_penalty(images.data, fake_images.data) 188 | gradient_penalty.backward() 189 | 190 | 191 | d_loss = d_loss_fake - d_loss_real + gradient_penalty 192 | Wasserstein_D = d_loss_real - d_loss_fake 193 | self.d_optimizer.step() 194 | print(f' Discriminator iteration: {d_iter}/{self.critic_iter}, loss_fake: {d_loss_fake}, loss_real: {d_loss_real}') 195 | 196 | # Generator update 197 | for p in self.D.parameters(): 198 | p.requires_grad = False # to avoid computation 199 | 200 | self.G.zero_grad() 201 | # train generator 202 | # compute loss with fake images 203 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1)) 204 | fake_images = self.G(z) 205 | g_loss = self.D(fake_images) 206 | g_loss = g_loss.mean() 207 | g_loss.backward(mone) 208 | g_cost = -g_loss 209 | self.g_optimizer.step() 210 | print(f'Generator iteration: {g_iter}/{self.generator_iters}, g_loss: {g_loss}') 211 | # Saving model and sampling images every 1000th generator iterations 212 | if (g_iter) % SAVE_PER_TIMES == 0: 213 | self.save_model() 214 | # # Workaround because graphic card memory can't store more than 830 examples in memory for generating image 215 | # # Therefore doing loop and generating 800 examples and stacking into list of samples to get 8000 generated images 216 | # # This way Inception score is more correct since there are different generated examples from every class of Inception model 217 | # sample_list = [] 218 | # for i in range(125): 219 | # samples = self.data.__next__() 220 | # # z = Variable(torch.randn(800, 100, 1, 1)).cuda(self.cuda_index) 221 | # # samples = self.G(z) 222 | # sample_list.append(samples.data.cpu().numpy()) 223 | # # 224 | # # # Flattening list of list into one list 225 | # new_sample_list = list(chain.from_iterable(sample_list)) 226 | # print("Calculating Inception Score over 8k generated images") 227 | # # # Feeding list of numpy arrays 228 | # inception_score = get_inception_score(new_sample_list, cuda=True, batch_size=32, 229 | # resize=True, splits=10) 230 | 231 | if not os.path.exists('training_result_images/'): 232 | os.makedirs('training_result_images/') 233 | 234 | # Denormalize images and save them in grid 8x8 235 | z = self.get_torch_variable(torch.randn(800, 100, 1, 1)) 236 | samples = self.G(z) 237 | samples = samples.mul(0.5).add(0.5) 238 | samples = samples.data.cpu()[:64] 239 | grid = utils.make_grid(samples) 240 | utils.save_image(grid, 'training_result_images/img_generatori_iter_{}.png'.format(str(g_iter).zfill(3))) 241 | 242 | # Testing 243 | time = t.time() - self.t_begin 244 | #print("Real Inception score: {}".format(inception_score)) 245 | print("Generator iter: {}".format(g_iter)) 246 | print("Time 
{}".format(time)) 247 | 248 | # Write to file inception_score, gen_iters, time 249 | #output = str(g_iter) + " " + str(time) + " " + str(inception_score[0]) + "\n" 250 | #self.file.write(output) 251 | 252 | 253 | # ============ TensorBoard logging ============# 254 | # (1) Log the scalar values 255 | info = { 256 | 'Wasserstein distance': Wasserstein_D.data, 257 | 'Loss D': d_loss.data, 258 | 'Loss G': g_cost.data, 259 | 'Loss D Real': d_loss_real.data, 260 | 'Loss D Fake': d_loss_fake.data 261 | 262 | } 263 | 264 | for tag, value in info.items(): 265 | self.logger.scalar_summary(tag, value.cpu(), g_iter + 1) 266 | 267 | # (3) Log the images 268 | info = { 269 | 'real_images': self.real_images(images, self.number_of_images), 270 | 'generated_images': self.generate_img(z, self.number_of_images) 271 | } 272 | 273 | for tag, images in info.items(): 274 | self.logger.image_summary(tag, images, g_iter + 1) 275 | 276 | 277 | 278 | self.t_end = t.time() 279 | print('Time of training-{}'.format((self.t_end - self.t_begin))) 280 | #self.file.close() 281 | 282 | # Save the trained parameters 283 | self.save_model() 284 | 285 | def evaluate(self, test_loader, D_model_path, G_model_path): 286 | self.load_model(D_model_path, G_model_path) 287 | z = self.get_torch_variable(torch.randn(self.batch_size, 100, 1, 1)) 288 | samples = self.G(z) 289 | samples = samples.mul(0.5).add(0.5) 290 | samples = samples.data.cpu() 291 | grid = utils.make_grid(samples) 292 | print("Grid of 8x8 images saved to 'dgan_model_image.png'.") 293 | utils.save_image(grid, 'dgan_model_image.png') 294 | 295 | 296 | def calculate_gradient_penalty(self, real_images, fake_images): 297 | eta = torch.FloatTensor(self.batch_size,1,1,1).uniform_(0,1) 298 | eta = eta.expand(self.batch_size, real_images.size(1), real_images.size(2), real_images.size(3)) 299 | if self.cuda: 300 | eta = eta.cuda(self.cuda_index) 301 | else: 302 | eta = eta 303 | 304 | interpolated = eta * real_images + ((1 - eta) * fake_images) 305 | 306 | if self.cuda: 307 | interpolated = interpolated.cuda(self.cuda_index) 308 | else: 309 | interpolated = interpolated 310 | 311 | # define it to calculate gradient 312 | interpolated = Variable(interpolated, requires_grad=True) 313 | 314 | # calculate probability of interpolated examples 315 | prob_interpolated = self.D(interpolated) 316 | 317 | # calculate gradients of probabilities with respect to examples 318 | gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated, 319 | grad_outputs=torch.ones( 320 | prob_interpolated.size()).cuda(self.cuda_index) if self.cuda else torch.ones( 321 | prob_interpolated.size()), 322 | create_graph=True, retain_graph=True)[0] 323 | 324 | # flatten the gradients to it calculates norm batchwise 325 | gradients = gradients.view(gradients.size(0), -1) 326 | 327 | grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_term 328 | return grad_penalty 329 | 330 | def real_images(self, images, number_of_images): 331 | if (self.C == 3): 332 | return self.to_np(images.view(-1, self.C, 32, 32)[:self.number_of_images]) 333 | else: 334 | return self.to_np(images.view(-1, 32, 32)[:self.number_of_images]) 335 | 336 | def generate_img(self, z, number_of_images): 337 | samples = self.G(z).data.cpu().numpy()[:number_of_images] 338 | generated_images = [] 339 | for sample in samples: 340 | if self.C == 3: 341 | generated_images.append(sample.reshape(self.C, 32, 32)) 342 | else: 343 | generated_images.append(sample.reshape(32, 32)) 344 | return generated_images 345 | 346 | 
346 |     def to_np(self, x):
347 |         return x.data.cpu().numpy()
348 | 
349 |     def save_model(self):
350 |         torch.save(self.G.state_dict(), './generator.pkl')
351 |         torch.save(self.D.state_dict(), './discriminator.pkl')
352 |         print('Models saved to ./generator.pkl & ./discriminator.pkl')
353 | 
354 |     def load_model(self, D_model_filename, G_model_filename):
355 |         D_model_path = os.path.join(os.getcwd(), D_model_filename)
356 |         G_model_path = os.path.join(os.getcwd(), G_model_filename)
357 |         self.D.load_state_dict(torch.load(D_model_path))
358 |         self.G.load_state_dict(torch.load(G_model_path))
359 |         print('Generator model loaded from {}.'.format(G_model_path))
360 |         print('Discriminator model loaded from {}.'.format(D_model_path))
361 | 
362 |     def get_infinite_batches(self, data_loader):
363 |         while True:
364 |             for i, (images, _) in enumerate(data_loader):
365 |                 yield images
366 | 
367 |     def generate_latent_walk(self, number):
368 |         if not os.path.exists('interpolated_images/'):
369 |             os.makedirs('interpolated_images/')
370 | 
371 |         number_int = 10
372 |         # interpolate between two noise vectors (z1, z2).
373 |         z_intp = torch.FloatTensor(1, 100, 1, 1)
374 |         z1 = torch.randn(1, 100, 1, 1)
375 |         z2 = torch.randn(1, 100, 1, 1)
376 |         if self.cuda:
377 |             z_intp = z_intp.cuda()
378 |             z1 = z1.cuda()
379 |             z2 = z2.cuda()
380 | 
381 |         z_intp = Variable(z_intp)
382 |         images = []
383 |         alpha = 1.0 / float(number_int + 1)
384 |         print(alpha)
385 |         for i in range(1, number_int + 1):
386 |             z_intp.data = z1*alpha + z2*(1.0 - alpha)
387 |             alpha += 1.0 / float(number_int + 1)
388 |             fake_im = self.G(z_intp)
389 |             fake_im = fake_im.mul(0.5).add(0.5)  # denormalize
390 |             images.append(fake_im.view(self.C, 32, 32).data.cpu())
391 | 
392 |         grid = utils.make_grid(images, nrow=number_int)
393 |         utils.save_image(grid, 'interpolated_images/interpolated_{}.png'.format(str(number).zfill(3)))
394 |         print("Saved interpolated images.")
395 | 
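With the step increment above, `generate_latent_walk` visits `number_int` evenly spaced points strictly between the two endpoints. The schedule in isolation, with shapes and step count taken from the method (`torch.lerp` is just an equivalent formulation, not what the repo calls):

```python
import torch

z1, z2 = torch.randn(1, 100, 1, 1), torch.randn(1, 100, 1, 1)
number_int = 10
for i in range(1, number_int + 1):
    alpha = i / float(number_int + 1)   # 1/11, 2/11, ..., 10/11
    z = torch.lerp(z2, z1, alpha)       # == z1 * alpha + z2 * (1.0 - alpha)
```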
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | matplotlib==3.5.1
2 | numpy==1.22.0
3 | Pillow==9.0.0
4 | scikit_learn==1.0.2
5 | scipy==1.7.3
6 | six==1.16.0
7 | tensorflow==2.7.0
8 | torch==1.10.1
9 | torchvision==0.11.2
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zeleni9/pytorch-wgan/e594e2eef7dbd82d6ad23e9442006f6aee08db6e/utils/__init__.py
-------------------------------------------------------------------------------- /utils/config.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | 
5 | def parse_args():
6 |     parser = argparse.ArgumentParser(description="Pytorch implementation of GAN models.")
7 | 
8 |     parser.add_argument('--model', type=str, default='DCGAN', choices=['GAN', 'DCGAN', 'WGAN-CP', 'WGAN-GP'])
9 |     parser.add_argument('--is_train', type=str, default='True')
10 |     parser.add_argument('--dataroot', required=True, help='path to dataset')
11 |     parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'cifar', 'stl10'],
12 |                         help='The name of dataset')
13 |     parser.add_argument('--download', type=str, default='False')
14 |     parser.add_argument('--epochs', type=int, default=50, help='The number of epochs to run')
15 |     parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
16 |     parser.add_argument('--cuda', type=str, default='False', help='Availability of cuda')
17 | 
18 |     parser.add_argument('--load_D', type=str, default='False', help='Path for loading Discriminator network')
19 |     parser.add_argument('--load_G', type=str, default='False', help='Path for loading Generator network')
20 |     parser.add_argument('--generator_iters', type=int, default=10000, help='The number of iterations for generator in WGAN model.')
21 |     return check_args(parser.parse_args())
22 | 
23 | 
24 | # Checking arguments
25 | def check_args(args):
26 |     # --epochs
27 |     try:
28 |         assert args.epochs >= 1
29 |     except AssertionError:
30 |         print('Number of epochs must be larger than or equal to one')
31 | 
32 |     # --batch_size
33 |     try:
34 |         assert args.batch_size >= 1
35 |     except AssertionError:
36 |         print('Batch size must be larger than or equal to one')
37 | 
38 |     if args.dataset == 'cifar' or args.dataset == 'stl10':
39 |         args.channels = 3
40 |     else:
41 |         args.channels = 1
42 |     args.cuda = True if args.cuda == 'True' else False
43 |     args.download = True if args.download == 'True' else False
44 |     return args
45 | 
-------------------------------------------------------------------------------- /utils/data_loader.py: --------------------------------------------------------------------------------
1 | import torchvision.datasets as dset
2 | import torchvision.transforms as transforms
3 | import torch.utils.data as data_utils
4 | from utils.fashion_mnist import MNIST, FashionMNIST
5 | 
6 | 
7 | def get_data_loader(args):
8 | 
9 |     if args.dataset == 'mnist':
10 |         trans = transforms.Compose([
11 |             transforms.Resize(32),
12 |             transforms.ToTensor(),
13 |             transforms.Normalize((0.5, ), (0.5, )),
14 |         ])
15 |         train_dataset = MNIST(root=args.dataroot, train=True, download=args.download, transform=trans)
16 |         test_dataset = MNIST(root=args.dataroot, train=False, download=args.download, transform=trans)
17 | 
18 |     elif args.dataset == 'fashion-mnist':
19 |         trans = transforms.Compose([
20 |             transforms.Resize(32),
21 |             transforms.ToTensor(),
22 |             transforms.Normalize((0.5, ), (0.5, )),
23 |         ])
24 |         train_dataset = FashionMNIST(root=args.dataroot, train=True, download=args.download, transform=trans)
25 |         test_dataset = FashionMNIST(root=args.dataroot, train=False, download=args.download, transform=trans)
26 | 
27 |     elif args.dataset == 'cifar':
28 |         trans = transforms.Compose([
29 |             transforms.Resize(32),
30 |             transforms.ToTensor(),
31 |             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
32 |         ])
33 | 
34 |         train_dataset = dset.CIFAR10(root=args.dataroot, train=True, download=args.download, transform=trans)
35 |         test_dataset = dset.CIFAR10(root=args.dataroot, train=False, download=args.download, transform=trans)
36 | 
37 |     elif args.dataset == 'stl10':
38 |         trans = transforms.Compose([
39 |             transforms.Resize(32),
40 |             transforms.ToTensor(),
41 |         ])
42 |         train_dataset = dset.STL10(root=args.dataroot, split='train', download=args.download, transform=trans)
43 |         test_dataset = dset.STL10(root=args.dataroot, split='test', download=args.download, transform=trans)
44 | 
45 |     # Check if everything is ok with loading datasets
46 |     assert train_dataset
47 |     assert test_dataset
48 | 
49 |     train_dataloader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
50 |     test_dataloader = data_utils.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)
51 | 
52 |     return train_dataloader, test_dataloader
53 | 
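A minimal usage sketch for `get_data_loader`. The `SimpleNamespace` fields mirror the flags in `utils/config.py`; the dataroot path and `download=True` are assumptions for the example:

```python
from types import SimpleNamespace
from utils.data_loader import get_data_loader

args = SimpleNamespace(dataset='mnist', dataroot='datasets/mnist',
                       download=True, batch_size=64)
train_loader, test_loader = get_data_loader(args)

images, labels = next(iter(train_loader))
print(images.shape)         # torch.Size([64, 1, 32, 32])
print(images.min().item())  # -1.0: Normalize((0.5,), (0.5,)) maps [0, 1] to [-1, 1]
```

The [-1, 1] range matters because the generators end in tanh, and both the training loop and the logger denormalize with `x * 0.5 + 0.5`.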
-------------------------------------------------------------------------------- /utils/fashion_mnist.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.utils.data as data
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import errno
7 | import torch
8 | import codecs
9 | 
10 | # Code referenced from torch source code to add the Fashion-MNIST dataset to the dataloader
11 | # Url: http://pytorch.org/docs/0.3.0/_modules/torchvision/datasets/mnist.html#FashionMNIST
12 | class MNIST(data.Dataset):
13 |     """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
14 |     Args:
15 |         root (string): Root directory of dataset where ``processed/training.pt``
16 |             and ``processed/test.pt`` exist.
17 |         train (bool, optional): If True, creates dataset from ``training.pt``,
18 |             otherwise from ``test.pt``.
19 |         download (bool, optional): If true, downloads the dataset from the internet and
20 |             puts it in root directory. If dataset is already downloaded, it is not
21 |             downloaded again.
22 |         transform (callable, optional): A function/transform that takes in a PIL image
23 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
24 |         target_transform (callable, optional): A function/transform that takes in the
25 |             target and transforms it.
26 |     """
27 |     urls = [
28 |         'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
29 |         'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
30 |         'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
31 |         'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
32 |     ]
33 |     raw_folder = 'raw'
34 |     processed_folder = 'processed'
35 |     training_file = 'training.pt'
36 |     test_file = 'test.pt'
37 | 
38 |     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
39 |         self.root = os.path.expanduser(root)
40 |         self.transform = transform
41 |         self.target_transform = target_transform
42 |         self.train = train  # training set or test set
43 | 
44 |         if download:
45 |             self.download()
46 | 
47 |         if not self._check_exists():
48 |             raise RuntimeError('Dataset not found.' +
49 |                                ' You can use download=True to download it')
50 | 
51 |         if self.train:
52 |             self.train_data, self.train_labels = torch.load(
53 |                 os.path.join(self.root, self.processed_folder, self.training_file))
54 |         else:
55 |             self.test_data, self.test_labels = torch.load(
56 |                 os.path.join(self.root, self.processed_folder, self.test_file))
57 | 
58 |     def __getitem__(self, index):
59 |         """
60 |         Args:
61 |             index (int): Index
62 |         Returns:
63 |             tuple: (image, target) where target is index of the target class.
64 |         """
65 |         if self.train:
66 |             img, target = self.train_data[index], self.train_labels[index]
67 |         else:
68 |             img, target = self.test_data[index], self.test_labels[index]
69 | 
70 |         # doing this so that it is consistent with all other datasets
71 |         # to return a PIL Image
72 |         img = Image.fromarray(img.numpy(), mode='L')
73 | 
74 |         if self.transform is not None:
75 |             img = self.transform(img)
76 | 
77 |         if self.target_transform is not None:
78 |             target = self.target_transform(target)
79 | 
80 |         return img, target
81 | 
82 |     def __len__(self):
83 |         if self.train:
84 |             return len(self.train_data)
85 |         else:
86 |             return len(self.test_data)
87 | 
88 |     def _check_exists(self):
89 |         return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
90 |             os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
91 | 
92 |     def download(self):
93 |         """Download the MNIST data if it doesn't exist in processed_folder already."""
94 |         from six.moves import urllib
95 |         import gzip
96 | 
97 |         if self._check_exists():
98 |             return
99 | 
100 |         # download files
101 |         try:
102 |             os.makedirs(os.path.join(self.root, self.raw_folder))
103 |             os.makedirs(os.path.join(self.root, self.processed_folder))
104 |         except OSError as e:
105 |             if e.errno == errno.EEXIST:
106 |                 pass
107 |             else:
108 |                 raise
109 | 
110 |         for url in self.urls:
111 |             print('Downloading ' + url)
112 |             data = urllib.request.urlopen(url)
113 |             filename = url.rpartition('/')[2]
114 |             file_path = os.path.join(self.root, self.raw_folder, filename)
115 |             with open(file_path, 'wb') as f:
116 |                 f.write(data.read())
117 |             with open(file_path.replace('.gz', ''), 'wb') as out_f, \
118 |                     gzip.GzipFile(file_path) as zip_f:
119 |                 out_f.write(zip_f.read())
120 |             os.unlink(file_path)
121 | 
122 |         # process and save as torch files
123 |         print('Processing...')
124 | 
125 |         training_set = (
126 |             read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
127 |             read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
128 |         )
129 |         test_set = (
130 |             read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
131 |             read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
132 |         )
133 |         with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
134 |             torch.save(training_set, f)
135 |         with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
136 |             torch.save(test_set, f)
137 | 
138 |         print('Done!')
139 | 
140 | 
141 | class FashionMNIST(MNIST):
142 |     """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
143 |     Args:
144 |         root (string): Root directory of dataset where ``processed/training.pt``
145 |             and ``processed/test.pt`` exist.
146 |         train (bool, optional): If True, creates dataset from ``training.pt``,
147 |             otherwise from ``test.pt``.
148 |         download (bool, optional): If true, downloads the dataset from the internet and
149 |             puts it in root directory. If dataset is already downloaded, it is not
150 |             downloaded again.
151 |         transform (callable, optional): A function/transform that takes in a PIL image
152 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
153 |         target_transform (callable, optional): A function/transform that takes in the
154 |             target and transforms it.
155 |     """
156 |     urls = [
157 |         'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
158 |         'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
159 |         'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
160 |         'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
161 |     ]
162 | 
163 | 
164 | def get_int(b):
165 |     return int(codecs.encode(b, 'hex'), 16)
166 | 
167 | 
168 | def parse_byte(b):
169 |     if isinstance(b, str):
170 |         return ord(b)
171 |     return b
172 | 
173 | 
174 | def read_label_file(path):
175 |     with open(path, 'rb') as f:
176 |         data = f.read()
177 |         assert get_int(data[:4]) == 2049
178 |         length = get_int(data[4:8])
179 |         labels = [parse_byte(b) for b in data[8:]]
180 |         assert len(labels) == length
181 |         return torch.LongTensor(labels)
182 | 
183 | 
184 | def read_image_file(path):
185 |     with open(path, 'rb') as f:
186 |         data = f.read()
187 |         assert get_int(data[:4]) == 2051
188 |         length = get_int(data[4:8])
189 |         num_rows = get_int(data[8:12])
190 |         num_cols = get_int(data[12:16])
191 |         images = []
192 |         idx = 16
193 |         for l in range(length):
194 |             img = []
195 |             images.append(img)
196 |             for r in range(num_rows):
197 |                 row = []
198 |                 img.append(row)
199 |                 for c in range(num_cols):
200 |                     row.append(parse_byte(data[idx]))
201 |                     idx += 1
202 |         assert len(images) == length
203 |         return torch.ByteTensor(images).view(-1, 28, 28)
204 | 
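`read_image_file` above parses the IDX format byte by byte in pure Python, which is slow for 60k images. A hypothetical drop-in alternative using `numpy.frombuffer`, relying on the same magic numbers and header layout (not part of the repo):

```python
import numpy as np
import torch

def read_image_file_np(path):
    with open(path, 'rb') as f:
        data = f.read()
    assert int.from_bytes(data[:4], 'big') == 2051  # IDX magic number for images
    length = int.from_bytes(data[4:8], 'big')
    rows = int.from_bytes(data[8:12], 'big')
    cols = int.from_bytes(data[12:16], 'big')
    pixels = np.frombuffer(data, dtype=np.uint8, offset=16)
    return torch.from_numpy(pixels.copy()).view(length, rows, cols)
```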
-------------------------------------------------------------------------------- /utils/feature_extraction_test.py: --------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | import torch
3 | from torch.autograd import Variable
4 | from utils.data_loader import get_data_loader
5 | from sklearn.metrics import accuracy_score
6 | from sklearn.linear_model import LogisticRegression
7 | 
8 | '''
9 | Running the feature extraction test for a trained GAN model, e.g. on
10 | cifar-10: $ python main.py --dataroot datasets/cifar --dataset cifar --load_D trained_models/dcgan/cifar/discriminator.pkl --load_G trained_models/dcgan/cifar/generator.pkl
11 | '''
12 | 
13 | class FeatureExtractionTest():
14 | 
15 |     def __init__(self, train_loader, test_loader, cuda_flag, batch_size):
16 |         self.train_loader = train_loader
17 |         self.test_loader = test_loader
18 |         print("Train length: {}".format(len(self.train_loader)))
19 |         print("Test length: {}".format(len(self.test_loader)))
20 |         self.batch_size = batch_size
21 | 
22 |         # Remove the fully connected layer and extract a 2048-dim vector as the feature representation of an image
23 |         self.model = models.resnet152(pretrained=True).cuda()
24 |         self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
25 | 
26 | 
27 |     # Feature extraction test #1: flattening the image
28 |     def flatten_images(self):
29 |         """
30 |         Flattened image pixels as the image representation.
31 |         Input is an image; output is a flattened self.channels*32*32 dimensional numpy array
32 |         """
33 |         x_train, y_train = [], []
34 |         x_test, y_test = [], []
35 | 
36 |         # flatten pixels of train images
37 |         for i, (images, labels) in enumerate(self.train_loader):
38 |             if i == len(self.train_loader) // self.batch_size:
39 |                 break
40 |             images = images.numpy()
41 |             labels = labels.numpy()
42 | 
43 |             # Iterate over the batch and save image features and labels as numpy arrays
44 |             for j in range(self.batch_size):
45 |                 x_train.append(images[j].flatten())
46 |                 y_train.append(labels[j])
47 | 
48 |         for i, (images, labels) in enumerate(self.test_loader):
49 |             if i == len(self.test_loader) // self.batch_size:
50 |                 break
51 | 
52 |             images = images.numpy()
53 |             labels = labels.numpy()
54 | 
55 |             # Iterate over the batch and save image features and labels as numpy arrays
56 |             for j in range(self.batch_size):
57 |                 x_test.append(images[j].flatten())
58 |                 y_test.append(labels[j])
59 | 
60 |         return x_train, y_train, x_test, y_test
61 | 
62 |     # Feature extraction test #4: transfer learning with a pretrained ResNet152
63 |     # Resize images to 224x224 for pretrained models
64 |     def inception_feature_extraction(self):
65 |         """
66 |         Extract features from images with ResNet152 pretrained on ImageNet, with the fully-connected layer removed.
67 |         Input is an image; output is a flattened 2048-dimensional numpy array
68 |         """
69 |         x_train, y_train = [], []
70 |         x_test, y_test = [], []
71 | 
72 |         for i, (images, labels) in enumerate(self.train_loader):
73 |             if i == len(self.train_loader) // self.batch_size:
74 |                 break
75 | 
76 |             images = Variable(images).cuda()
77 | 
78 |             # Feature extraction with ResNet152, resulting in a 2048-dimensional feature vector
79 |             outputs = self.model(images)
80 | 
81 |             # Convert FloatTensors to numpy array
82 |             features = outputs.data.cpu().numpy()
83 |             labels = labels.numpy()
84 | 
85 |             # Iterate over the batch and save image features and labels as numpy arrays
86 |             for j in range(self.batch_size):
87 |                 x_train.append(features[j].flatten())
88 |                 y_train.append(labels[j])
89 | 
90 | 
91 |         for i, (images, labels) in enumerate(self.test_loader):
92 |             if i == len(self.test_loader) // self.batch_size:
93 |                 break
94 | 
95 |             images = Variable(images).cuda()
96 | 
97 |             # Feature extraction with ResNet152, resulting in a 2048-dimensional feature vector
98 |             outputs = self.model(images)
99 | 
100 |             # Convert FloatTensors to numpy array
101 |             features = outputs.data.cpu().numpy()
102 |             labels = labels.numpy()
103 | 
104 |             # Iterate over the batch and save image features and labels as numpy arrays
105 |             for j in range(self.batch_size):
106 |                 x_test.append(features[j].flatten())
107 |                 y_test.append(labels[j])
108 | 
109 |         return x_train, y_train, x_test, y_test
110 | 
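A quick shape check of the truncated backbone built in `__init__`. This sketch runs on CPU with `pretrained=False` so it works without downloading weights; the repo itself uses `pretrained=True` on GPU:

```python
import torch
import torchvision.models as models

backbone = torch.nn.Sequential(*list(models.resnet152(pretrained=False).children())[:-1])
out = backbone(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 2048, 1, 1]) -> .flatten() gives the 2048-d feature
```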
111 |     # Feature extraction: GAN model discriminator output 1024x4x4
112 |     def GAN_feature_extraction(self, discriminator):
113 |         """
114 |         Extract features from images with the trained discriminator of a GAN model.
115 |         Input is an image; output is a flattened 16384-dimensional numpy array (1024x4x4)
116 |         discriminator -- Trained discriminator of GAN model
117 |         """
118 |         x_train, y_train = [], []
119 |         x_test, y_test = [], []
120 |         for i, (images, labels) in enumerate(self.train_loader):
121 |             if i == len(self.train_loader) // self.batch_size:
122 |                 break
123 | 
124 |             images = Variable(images).cuda()
125 |             # Feature extraction: DCGAN discriminator output 1024x4x4
126 |             outputs = discriminator.feature_extraction(images)
127 | 
128 |             # Convert FloatTensors to numpy array
129 |             features = outputs.data.cpu().numpy()
130 |             labels = labels.numpy()
131 | 
132 |             # Iterate over the batch and save image features and labels as numpy arrays
133 |             for j in range(self.batch_size):
134 |                 x_train.append(features[j].flatten())
135 |                 y_train.append(labels[j])
136 | 
137 |         for i, (images, labels) in enumerate(self.test_loader):
138 |             if i == len(self.test_loader) // self.batch_size:
139 |                 break
140 | 
141 |             images = Variable(images).cuda()
142 |             outputs = discriminator.feature_extraction(images)
143 | 
144 |             # Convert FloatTensors to numpy array
145 |             features = outputs.data.cpu().numpy()
146 |             labels = labels.numpy()
147 | 
148 |             # Iterate over the batch and save image features and labels as numpy arrays
149 |             for j in range(self.batch_size):
150 |                 x_test.append(features[j].flatten())
151 |                 y_test.append(labels[j])
152 | 
153 |         return x_train, y_train, x_test, y_test
154 | 
155 | 
156 |     def calculate_score(self, args):
157 |         """
158 |         Calculate the accuracy score by fitting the feature representations to a linear classifier (LinearSVC or LogisticRegression)
159 |         """
160 |         mean_score = 0
161 |         for i in range(10):
162 |             # This way the data is reshuffled every iteration
163 |             self.train_loader, self.test_loader = get_data_loader(args)
164 | 
165 |             x_train, y_train, x_test, y_test = self.inception_feature_extraction()
166 |             # x_train, y_train, x_test, y_test = self.GAN_feature_extraction(model.D)
167 |             # x_train, y_train, x_test, y_test = self.flatten_images()
168 | 
169 |             # clf = LinearSVC()
170 |             clf = LogisticRegression()
171 |             clf.fit(x_train, y_train)
172 | 
173 |             predicted = clf.predict(x_test)
174 |             score = accuracy_score(y_test, predicted)
175 |             print("Accuracy score: {}".format(score))
176 |             mean_score += score
177 |         print("Mean score: {}".format(float(mean_score) / float(10)))
178 |         return float(mean_score) / float(10)
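The file below computes the Inception score IS = exp(E_x[KL(p(y|x) || p(y))]). A toy check of that formula on random softmax outputs, which stand in for Inception predictions and should score low because near-random predictions carry little class information:

```python
import numpy as np
from scipy.stats import entropy

preds = np.random.dirichlet(np.ones(10), size=1000)  # fake p(y|x) rows
py = preds.mean(axis=0)                               # marginal p(y)
kl = np.mean([entropy(p, py) for p in preds])         # E_x[KL(p(y|x) || p(y))]
print(np.exp(kl))  # small (~1.5); confident, class-balanced predictions approach 10
```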
-------------------------------------------------------------------------------- /utils/inception_score.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 | import torch.utils.data
6 | from torchvision.models.inception import inception_v3
7 | import numpy as np
8 | from scipy.stats import entropy
9 | 
10 | 
11 | def get_inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
12 |     """
13 |     Computes the inception score of the generated images imgs
14 |     imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
15 |     cuda -- whether or not to run on GPU
16 |     batch_size -- batch size for feeding into Inception v3
17 |     splits -- number of splits
18 |     """
19 |     N = len(imgs)
20 | 
21 |     assert batch_size > 0
22 |     assert N > batch_size
23 | 
24 |     # Set up dtype
25 |     if cuda:
26 |         dtype = torch.cuda.FloatTensor
27 |     else:
28 |         if torch.cuda.is_available():
29 |             print("WARNING: You have a CUDA device, so you should probably set cuda=True")
30 |         dtype = torch.FloatTensor
31 | 
32 |     # Set up dataloader
33 |     dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
34 | 
35 |     # Load inception model
36 |     inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
37 |     inception_model.eval()
38 |     up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
39 |     def get_pred(x):
40 |         if resize:
41 |             x = up(x)
42 |         x = inception_model(x)
43 |         return F.softmax(x, dim=1).data.cpu().numpy()
44 | 
45 |     # Get predictions
46 |     preds = np.zeros((N, 1000))
47 | 
48 |     for i, batch in enumerate(dataloader, 0):
49 |         batch = batch.type(dtype)
50 |         batchv = Variable(batch)
51 |         batch_size_i = batch.size()[0]
52 | 
53 |         preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
54 | 
55 |     # Now compute the mean KL-divergence
56 |     split_scores = []
57 | 
58 |     for k in range(splits):
59 |         part = preds[k * (N // splits): (k+1) * (N // splits), :]
60 |         py = np.mean(part, axis=0)
61 |         scores = []
62 |         for i in range(part.shape[0]):
63 |             pyx = part[i, :]
64 |             scores.append(entropy(pyx, py))
65 |         split_scores.append(np.exp(np.mean(scores)))
66 | 
67 |     return np.mean(split_scores), np.std(split_scores)
68 | 
-------------------------------------------------------------------------------- /utils/tensorboard_logger.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | 
5 | class Logger(object):
6 |     def __init__(self, log_dir):
7 |         """Create a summary writer logging to log_dir."""
8 |         self.writer = tf.summary.create_file_writer(log_dir)
9 | 
10 |     def scalar_summary(self, tag, value, step):
11 |         """Log a scalar variable."""
12 |         with self.writer.as_default():
13 |             tf.summary.scalar(tag, data=value, step=step)
14 | 
15 |     def image_summary(self, tag, images, step):
16 |         """Log a list of images.
17 |         Args: images: numpy array of shape (Batch x C x H x W) in the range [-1.0, 1.0]
18 |         """
19 |         with self.writer.as_default():
20 |             imgs = None
21 |             for i, j in enumerate(images):
22 |                 img = ((j*0.5+0.5)*255).round().astype('uint8')
23 |                 if len(img.shape) == 3:
24 |                     img = img.transpose(1, 2, 0)
25 |                 else:
26 |                     img = img[:, :, np.newaxis]
27 |                 img = img[np.newaxis, :]
28 |                 if imgs is not None:
29 |                     imgs = np.append(imgs, img, axis=0)
30 |                 else:
31 |                     imgs = img
32 |             tf.summary.image('{}'.format(tag), imgs, max_outputs=len(imgs), step=step)
33 | 
34 |     def histo_summary(self, tag, values, step, bins=1000):
35 |         """Log a histogram of the tensor of values."""
36 |         with self.writer.as_default():
37 |             tf.summary.histogram('{}'.format(tag), values, buckets=bins, step=step)
38 | 
--------------------------------------------------------------------------------
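A short usage sketch for the `Logger` above; the log directory, tag names, and values are placeholders for this example:

```python
import numpy as np
from utils.tensorboard_logger import Logger

logger = Logger('./logs/demo')
logger.scalar_summary('Loss D', 0.42, step=1)

# Batch of fake generator outputs in [-1, 1], shape (Batch x C x H x W)
fake_batch = np.random.uniform(-1.0, 1.0, size=(4, 1, 32, 32)).astype('float32')
logger.image_summary('generated_images', fake_batch, step=1)
# Inspect with: tensorboard --logdir ./logs
```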