├── samples ├── training_images.jpg ├── gan_generated_tiles_2.jpg ├── gan_evolution_training_epoch.jpg └── gan_latent_space_interpolation.jpg ├── source ├── utils.py ├── interpolation.py ├── data.py ├── train.py ├── test.py ├── torchsummary.py └── gan.py ├── LICENSE ├── dataset └── prepare_tiles.pde └── README.md /samples/training_images.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floboc/tiles-gan/HEAD/samples/training_images.jpg -------------------------------------------------------------------------------- /samples/gan_generated_tiles_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floboc/tiles-gan/HEAD/samples/gan_generated_tiles_2.jpg -------------------------------------------------------------------------------- /samples/gan_evolution_training_epoch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floboc/tiles-gan/HEAD/samples/gan_evolution_training_epoch.jpg -------------------------------------------------------------------------------- /samples/gan_latent_space_interpolation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floboc/tiles-gan/HEAD/samples/gan_latent_space_interpolation.jpg -------------------------------------------------------------------------------- /source/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | #can be used as a transform to add random noise 5 | class RandomNoise(object): 6 | def __init__(self, std): 7 | self.std = std 8 | 9 | def __call__(self, tensor): 10 | return tensor + self.std * torch.randn(tensor.size()) 11 | 12 | -------------------------------------------------------------------------------- /source/interpolation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def lerp(val, low, high): 4 | """Linear interpolation""" 5 | return low + (high - low) * val 6 | 7 | def slerp(val, low, high): 8 | """Spherical interpolation. val has a range of 0 to 1.""" 9 | if val <= 0: 10 | return low 11 | elif val >= 1: 12 | return high 13 | elif np.allclose(low, high): 14 | return low 15 | omega = np.arccos(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high))) 16 | so = np.sin(omega) 17 | return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega)/so * high 18 | 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 floboc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /source/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | import random 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torch.autograd import Variable 9 | 10 | class ImageDataset(torch.utils.data.Dataset): 11 | def __init__(self, path, extension, transforms_list, recursive=True): 12 | self.path = path 13 | self.transform = transforms.Compose(transforms_list) 14 | if recursive: 15 | self.files = sorted(glob.glob(path + '/**/*.' + extension, recursive=True)) 16 | else: 17 | self.files = sorted(glob.glob(path + '/*.' + extension)) 18 | 19 | def __len__(self): 20 | return len(self.files) 21 | 22 | def __getitem__(self, index): 23 | filepath = self.files[index % len(self.files)]; 24 | image = self.transform(Image.open(filepath).convert('RGB')) 25 | return image 26 | 27 | class ReplayBuffer(): 28 | def __init__(self, max_size=50, replace_probability = 0.5): 29 | assert (max_size > 0), 'Buffer size must be superior to 0' 30 | self.max_size = max_size 31 | self.replace_probability = replace_probability 32 | self.data = [] 33 | 34 | def push_and_pop(self, data): 35 | to_return = [] 36 | for element in data.data: 37 | element = torch.unsqueeze(element, 0) 38 | if len(self.data) < self.max_size: 39 | self.data.append(element) 40 | to_return.append(element) 41 | else: 42 | if random.uniform(0,1) < self.replace_probability: 43 | i = random.randint(0, self.max_size-1) 44 | to_return.append(self.data[i].clone()) 45 | self.data[i] = element 46 | else: 47 | to_return.append(element) 48 | return Variable(torch.cat(to_return)) 49 | 50 | -------------------------------------------------------------------------------- /dataset/prepare_tiles.pde: -------------------------------------------------------------------------------- 1 | ArrayList images = new ArrayList 2 | 3 | (); 4 | 5 | void setup() 6 | { 7 | String input_folder = "F:/Projets/MachineLearning/Tiles/data/raw/"; //folder containing tilesets (must not contain any other file) 8 | String output_folder = "F:/Projets/MachineLearning/Tiles/data/tiles/"; //folder that will contain individual tiles (must exist) 9 | int size = 32; //Size of each tile (square) 10 | 11 | // Load each tileset 12 | File[] files = listFiles(input_folder); 13 | for (int i = 0; i < files.length; i++) 14 | { 15 | File f = files[i]; 16 | println("Processing " + f.getName() + "..."); 17 | if (!f.isDirectory()) 18 | { 19 | PImage img = loadImage(f.getAbsolutePath()); 20 | int cols = img.width / size; 21 | int rows = img.height / size; 22 | for (int c = 0; c < cols; c++) 23 | { 24 | for (int r = 0; r < rows; r++) 25 | { 26 | PImage tile_img = createImage(size, size, ARGB); 27 | tile_img.set(0, 0, img.get(c * size, r * size, size, size)); 28 | 29 | // Check that the tile is not empty 30 | color col = tile_img.get(0, 0); 31 | boolean is_not_empty = false; 32 | for (int x = 0; x < size; x++) 33 | { 34 | for (int y = 0; y < size; y++) 35 | { 36 | if (tile_img.get(x,y) != col) 37 | { 38 | is_not_empty = true; 39 | break; 40 | } 41 | } 42 | if (is_not_empty) 43 | break; 44 | } 45 | 46 | if (is_not_empty) 47 | tile_img.save(output_folder + "/" + f.getName() + (r * cols + c) + ".png"); 48 | } 49 | } 50 | } 51 | } 52 | 53 | println("Done!"); 54 | } -------------------------------------------------------------------------------- /source/train.py: -------------------------------------------------------------------------------- 1 | #repository files 2 | from utils import RandomNoise 3 | from gan import GANOptions, GAN 4 | from torchsummary import summary 5 | 6 | #other dependencies 7 | import random 8 | import numpy 9 | import torch 10 | import torchvision.transforms as transforms 11 | from datetime import datetime 12 | from PIL import Image 13 | 14 | #Set random seem for reproducibility 15 | rnd_seed = 123456789 #datetime.now() 16 | random.seed(rnd_seed) 17 | numpy.random.seed(rnd_seed) 18 | torch.manual_seed(rnd_seed) 19 | print("Random Seed: ", rnd_seed) 20 | 21 | #Change the settings to your own dataset paths 22 | opts = GANOptions(data_path = "D:/GAN/Databases/tiles/", file_extension = "png", 23 | output_path = "D:/GAN/Results/tiles/", 24 | image_size = 32) 25 | 26 | #The settings I used for my article (https://playerone-studio.com/gan-2d-tiles-generative-adversarial-network) 27 | opts.generator_batchnorm = True 28 | opts.discriminator_batchnorm = True 29 | opts.discriminator_dropout = 0.3 30 | opts.generator_frequency = 1 31 | opts.generator_dropout = 0.3 32 | opts.label_softness = 0.2 33 | opts.batch_size = 128 34 | opts.epoch_number = 300 35 | 36 | #Note: if you are using jupyter notebook you might need to disable workers: 37 | opts.workers_nbr = 0; 38 | 39 | #List of tranformations applied to each input image 40 | opts.transforms = [ 41 | transforms.Resize(int(opts.image_size), Image.BICUBIC), 42 | transforms.ToTensor(), #do not forget to transform image into a tensor 43 | transforms.Normalize((.5, .5, .5), (.5, .5, .5)) 44 | #RandomNoise(0.01) #adding noise might prevent the discriminator from over-fitting 45 | ] 46 | 47 | #Build GAN 48 | model = GAN(opts) 49 | 50 | #Display GAN architecture (note: only work if cuda is enabled) 51 | summary(model.generator.cuda(), input_size=(opts.latent_dim, 1, 1)) 52 | summary(model.discriminator.cuda(), input_size=(opts.channels_nbr, opts.image_size, opts.image_size)) 53 | 54 | #Start training 55 | model.train() 56 | 57 | #Save model 58 | model.save(opts.output_path) 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generating 2D map tiles with Generative Adversarial Network (GAN) 2 | 3 | ## Introduction 4 | 5 | This repository contains all the source code used to generate small 2D tiles commonly used in video games for building maps. 6 | Generating tiles was achieved using a **Deep Convolutional Generative Adversarial Network** (DCGAN). 7 | This code uses PyTorch as backend. 8 | 9 | This is an example of what it can generate:\ 10 | ![](https://github.com/floboc/tiles-gan/blob/master/samples/gan_generated_tiles_2.jpg) 11 | 12 | And when interpolating in its latent space:\ 13 | ![](https://github.com/floboc/tiles-gan/blob/master/samples/gan_latent_space_interpolation.jpg) 14 | 15 | **You can find more information and details about this source code in this article:** 16 | 17 | 18 | ## Dependencies 19 | 20 | There are few dependencies: 21 | - PyTorch () 22 | - Numpy 23 | - Python Imaging Library (PIL) 24 | 25 | 26 | ## How to use it 27 | 28 | ### 1. Create your own dataset: 29 | - You can use the simple Processing () script provided in this repository to convert downloaded tilesets into individual tiles 30 | - All tilesets must be in a folder with no other file, and tiles should be of the same size in all tilesets (here 32x32) 31 | - Tiles will be saved as individual PNG files. Empty tiles will be omitted. 32 | 33 | Here are some samples of tiles I used for my dataset:\ 34 | ![](https://github.com/floboc/tiles-gan/blob/master/samples/training_images.jpg) 35 | 36 | ### 2. Train your model: 37 | - Edit train.py so that the paths match that of your dataset (images have to be power of 2) 38 | - Also adjust any settings as you want. The settings are detailed in the file gan.py 39 | - run using "python train.py" 40 | 41 | You should be patient as etting the first results can take some time. Here are some results during training:\ 42 | ![](https://github.com/floboc/tiles-gan/blob/master/samples/gan_evolution_training_epoch.jpg) 43 | 44 | ### 3. Test your model: 45 | - Edit test.py to match the path where you saved your model 46 | - Also adjust latent space dimension if required 47 | - run using "python test.py" to generate some test images in your output folder 48 | - images will be saved to the same path as your model 49 | 50 | 51 | **Enjoy!** 52 | -------------------------------------------------------------------------------- /source/test.py: -------------------------------------------------------------------------------- 1 | from gan import GANOptions, GAN 2 | from interpolation import lerp, slerp 3 | from torchsummary import summary 4 | 5 | #other dependencies 6 | import random 7 | import numpy 8 | import torch 9 | import torchvision.utils 10 | import torchvision.transforms as transforms 11 | from datetime import datetime 12 | 13 | 14 | #This is ugly but for now it does the work... 15 | #Change this so that it matches your model options 16 | opts = GANOptions() 17 | opts.latent_dim = 50 18 | opts.image_size = 32 19 | opts.channels_nbr = 3 20 | opts.output_path = "D:/GAN/Results/tiles/models" 21 | 22 | 23 | #Set random seem for reproducibility (if required) 24 | rnd_seed = 123456789 25 | random.seed(rnd_seed) 26 | numpy.random.seed(rnd_seed) 27 | torch.manual_seed(rnd_seed) 28 | print("Random Seed: ", rnd_seed) 29 | 30 | 31 | #CUDA 32 | if torch.cuda.is_available(): 33 | Tensor = torch.cuda.FloatTensor 34 | print("CUDA is available") 35 | else: 36 | Tensor = torch.Tensor 37 | print("CUDA is NOT available") 38 | 39 | #load model 40 | gan = GAN(opts) 41 | gan.load(opts.output_path) 42 | 43 | #display model info 44 | summary(gan.generator.cuda(), input_size=(opts.latent_dim, 1, 1)) 45 | summary(gan.discriminator.cuda(), input_size=(opts.channels_nbr, opts.image_size, opts.image_size)) 46 | 47 | 48 | #IMAGE GENERATION 49 | 50 | #test model 51 | nrows = 16 52 | ncols = 16 53 | latent_vectors = Tensor(numpy.random.normal(0.0, 1.0, (nrows * ncols, opts.latent_dim, 1, 1))) 54 | gen = gan.generator(latent_vectors).detach() 55 | 56 | #save image 57 | torchvision.utils.save_image(gen, opts.output_path + '/generated.png', nrow=ncols, normalize=True, padding=0) 58 | 59 | 60 | 61 | #LATENT SPACE INTERPOLATION 62 | 63 | #source and target latent representations 64 | source = numpy.random.normal(0.0, 1.0, opts.latent_dim) 65 | target = numpy.random.normal(0.0, 1.0, opts.latent_dim) 66 | steps = 16 67 | #interpolation in latent space 68 | latents = [] 69 | for i in range(steps): 70 | factor = i / (steps - 1) 71 | val = slerp(factor, source, target) 72 | z = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(Tensor(val), 0),2),3) 73 | latents.append(z) 74 | latents = torch.cat(latents) 75 | 76 | #apply model 77 | gen = gan.generator(latents).detach() 78 | 79 | #save image 80 | torchvision.utils.save_image(gen, opts.output_path + '/interpolation.png', nrow=steps, normalize=True, padding=0) 81 | -------------------------------------------------------------------------------- /source/torchsummary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | from collections import OrderedDict 6 | import numpy as np 7 | 8 | 9 | def summary(model, input_size, batch_size=-1, device="cuda"): 10 | 11 | def register_hook(module): 12 | 13 | def hook(module, input, output): 14 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 15 | module_idx = len(summary) 16 | 17 | m_key = "%s-%i" % (class_name, module_idx + 1) 18 | summary[m_key] = OrderedDict() 19 | summary[m_key]["input_shape"] = list(input[0].size()) 20 | summary[m_key]["input_shape"][0] = batch_size 21 | if isinstance(output, (list, tuple)): 22 | summary[m_key]["output_shape"] = [ 23 | [-1] + list(o.size())[1:] for o in output 24 | ] 25 | else: 26 | summary[m_key]["output_shape"] = list(output.size()) 27 | summary[m_key]["output_shape"][0] = batch_size 28 | 29 | params = 0 30 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 31 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 32 | summary[m_key]["trainable"] = module.weight.requires_grad 33 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 34 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 35 | summary[m_key]["nb_params"] = params 36 | 37 | if ( 38 | not isinstance(module, nn.Sequential) 39 | and not isinstance(module, nn.ModuleList) 40 | and not (module == model) 41 | ): 42 | hooks.append(module.register_forward_hook(hook)) 43 | 44 | device = device.lower() 45 | assert device in [ 46 | "cuda", 47 | "cpu", 48 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 49 | 50 | if device == "cuda" and torch.cuda.is_available(): 51 | dtype = torch.cuda.FloatTensor 52 | else: 53 | dtype = torch.FloatTensor 54 | 55 | # multiple inputs to the network 56 | if isinstance(input_size, tuple): 57 | input_size = [input_size] 58 | 59 | # batch_size of 2 for batchnorm 60 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 61 | # print(type(x[0])) 62 | 63 | # create properties 64 | summary = OrderedDict() 65 | hooks = [] 66 | 67 | # register hook 68 | model.apply(register_hook) 69 | 70 | # make a forward pass 71 | # print(x.shape) 72 | model(*x) 73 | 74 | # remove these hooks 75 | for h in hooks: 76 | h.remove() 77 | 78 | print("----------------------------------------------------------------") 79 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 80 | print(line_new) 81 | print("================================================================") 82 | total_params = 0 83 | total_output = 0 84 | trainable_params = 0 85 | for layer in summary: 86 | # input_shape, output_shape, trainable, nb_params 87 | line_new = "{:>20} {:>25} {:>15}".format( 88 | layer, 89 | str(summary[layer]["output_shape"]), 90 | "{0:,}".format(summary[layer]["nb_params"]), 91 | ) 92 | total_params += summary[layer]["nb_params"] 93 | total_output += np.prod(summary[layer]["output_shape"]) 94 | if "trainable" in summary[layer]: 95 | if summary[layer]["trainable"] == True: 96 | trainable_params += summary[layer]["nb_params"] 97 | print(line_new) 98 | 99 | # assume 4 bytes/number (float on cuda). 100 | total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) 101 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 102 | total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) 103 | total_size = total_params_size + total_output_size + total_input_size 104 | 105 | print("================================================================") 106 | print("Total params: {0:,}".format(total_params)) 107 | print("Trainable params: {0:,}".format(trainable_params)) 108 | print("Non-trainable params: {0:,}".format(total_params - trainable_params)) 109 | print("----------------------------------------------------------------") 110 | print("Input size (MB): %0.2f" % total_input_size) 111 | print("Forward/backward pass size (MB): %0.2f" % total_output_size) 112 | print("Params size (MB): %0.2f" % total_params_size) 113 | print("Estimated Total Size (MB): %0.2f" % total_size) 114 | print("----------------------------------------------------------------") 115 | # return summary 116 | -------------------------------------------------------------------------------- /source/gan.py: -------------------------------------------------------------------------------- 1 | #repositiory 2 | from data import ImageDataset, ReplayBuffer 3 | 4 | #other dependencies 5 | from torch.autograd import Variable 6 | from torch.utils.data import DataLoader 7 | import os 8 | import math 9 | import numpy 10 | import random 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.utils 14 | from torchvision.utils import save_image 15 | 16 | class GANOptions: 17 | def __init__(self, image_size = 32, n_epochs = 10000, latent_dim = 50, batch_size = 128, data_path = "", file_extension = "", output_path = "", workers_nbr = 8, transform_list = []): 18 | self.epoch_number = n_epochs #number of epochs 19 | self.latent_dim = latent_dim #size of latent space 20 | self.batch_size = batch_size #size of mini-batch 21 | self.label_softness = 0.1 #amplitude of the noise added to label when training discriminator 22 | 23 | self.data_path = data_path #path of the data 24 | self.file_extension = file_extension #file extension 25 | self.workers_nbr = workers_nbr #number of workers for data loading 26 | self.image_size = image_size #size of the image 27 | self.channels_nbr = 3 #number of channels in the image 28 | self.transforms = transform_list 29 | 30 | self.output_path = output_path #path for output 31 | self.save_interval = 100 #number of batches between saves 32 | 33 | self.generator_frequency = 1 #number of training steps for the generator (not recommended, keep it to 1) 34 | self.generator_dropout = 0.2 #drop-out probability for generator (0 = none) 35 | self.generator_batchnorm = True #use batch-normalization in generator 36 | self.generator_lr = 2e-4 #generator learning rate 37 | self.generator_beta1 = 0.5 #generator beta1 parameter (adam) 38 | 39 | self.discriminator_dropout = 0.2 #drop-out probability for discriminator (0 = none) 40 | self.discriminator_batchnorm = True #use batch-normalization in discriminator 41 | self.discriminator_lr = 2e-4 #discriminator learning rate 42 | self.discriminator_beta1 = 0.5 #discriminator beta1 parameter (adam) 43 | self.discriminator_kernel = 5 #kernel size for discriminator 44 | 45 | 46 | #custom weight initialization 47 | def gan_weights_init(m): 48 | classname = m.__class__.__name__ 49 | if classname.find('Conv') != -1: 50 | nn.init.normal_(m.weight.data, 0.0, 0.02) 51 | elif classname.find('BatchNorm') != -1: 52 | nn.init.normal_(m.weight.data, 1.0, 0.02) 53 | nn.init.constant_(m.bias.data, 0) 54 | 55 | #GAN generator 56 | class Generator(torch.nn.Module): 57 | 58 | #latent_dim: size of latent vectors 59 | #image_size: size of generated images (must be a power of 2) 60 | #channels_nbr: number of channels in the image (3 for RGB) 61 | #features_nbr: number of filters at the last stage 62 | #dropout: unit drop-out probability (0 for no drop-out) 63 | #batch_norm: if true, batch normalization will be used, otherise, instance normalization will be applied 64 | def __init__(self, latent_dim, image_size, channels_nbr = 3, features_nbr = 64, dropout = 0.2, batch_norm = True): 65 | super(Generator, self).__init__() 66 | 67 | layers = [] 68 | 69 | #Compute model size 70 | in_features = latent_dim 71 | num_blocks = round(math.log2(image_size)) - 3 72 | out_features = features_nbr * (2 ** num_blocks) 73 | 74 | #First expansion layer (to 4x4 size) 75 | layers += [ 76 | torch.nn.ConvTranspose2d(in_channels=in_features, out_channels=out_features, kernel_size=4, stride=1, padding=0, bias=False), 77 | torch.nn.BatchNorm2d(out_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) if batch_norm else None, 78 | torch.nn.LeakyReLU(0.1, inplace=True) 79 | ] 80 | in_features = out_features 81 | out_features = out_features // 2 82 | 83 | #Transpose convolution blocks 84 | for i in range(num_blocks): 85 | layers += [ 86 | torch.nn.Upsample(scale_factor=2), 87 | torch.nn.ReflectionPad2d(1), 88 | torch.nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=3, stride=1, padding=0), 89 | torch.nn.Dropout2d(p=0.3, inplace=True) if dropout > 0 else None, 90 | torch.nn.BatchNorm2d(out_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) if batch_norm else None, 91 | torch.nn.InstanceNorm2d(out_features) if (not batch_norm) else None, 92 | torch.nn.LeakyReLU(0.1, inplace=True) 93 | ] 94 | in_features = out_features 95 | out_features = out_features // 2 96 | 97 | #Last conversion layer 98 | layers += [ 99 | torch.nn.Upsample(scale_factor=2), 100 | torch.nn.ReflectionPad2d(1), 101 | torch.nn.Conv2d(in_channels=in_features, out_channels=channels_nbr, kernel_size=3, stride=1, padding=0), 102 | torch.nn.Tanh() 103 | ] 104 | 105 | #create sequential nn from layers list 106 | self.net = torch.nn.Sequential(*filter(lambda x: x is not None, layers)) 107 | 108 | 109 | def forward(self, x): 110 | return self.net(x) 111 | 112 | #GAN discriminator 113 | class Discriminator(torch.nn.Module): 114 | 115 | #image_size: size of generated images (must be a power of 2) 116 | #channels_nbr: number of channels in the image (3 for RGB) 117 | #features_nbr: number of filters in the first stage 118 | #kernel_size: size of the convolution kernels 119 | #dropout: unit drop-out probability (0 for no drop-out) 120 | #batch_norm: if true, batch normalization will be used, otherwise, instance normalization will be applied 121 | def __init__(self, image_size, channels_nbr = 3, features_nbr = 64, kernel_size = 5, dropout = 0.2, batch_norm = True): 122 | super(Discriminator, self).__init__() 123 | 124 | layers = [] 125 | 126 | #Compute model size 127 | in_features = channels_nbr 128 | out_features = features_nbr 129 | num_blocks = round(math.log2(image_size)) - 3 130 | 131 | #First layer (input is C x S x S, output is F x (S/2) x (S/2)) 132 | layers += [ 133 | torch.nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2, bias=False), 134 | torch.nn.LeakyReLU(0.2, inplace=True) 135 | ] 136 | in_features = out_features 137 | out_features *= 2 138 | 139 | #Transpose convolution blocks 140 | for i in range(num_blocks): 141 | layers += [ 142 | torch.nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2, bias=False), 143 | torch.nn.Dropout2d(p=0.3, inplace=True) if dropout > 0 else None, 144 | torch.nn.BatchNorm2d(out_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) if batch_norm else None, 145 | torch.nn.InstanceNorm2d(out_features) if (not batch_norm) else None, 146 | torch.nn.LeakyReLU(0.2, inplace=True) 147 | ] 148 | in_features = out_features 149 | out_features *= 2 150 | 151 | #Last layer (input is N x 4 x 4, output is 1 x 1 x 1) 152 | layers += [ 153 | torch.nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False), 154 | torch.nn.Sigmoid() 155 | ] 156 | 157 | #create sequential nn from layers list 158 | self.net = torch.nn.Sequential(*filter(lambda x: x is not None, layers)) 159 | 160 | def forward(self, x): 161 | return self.net(x) 162 | 163 | #GAN class used for training 164 | class GAN: 165 | #options a GANOption obect 166 | def __init__(self, options): 167 | self.discriminator = Discriminator(image_size=options.image_size, channels_nbr=options.channels_nbr, 168 | features_nbr = 64, kernel_size = options.discriminator_kernel, 169 | dropout=options.discriminator_dropout, batch_norm = options.discriminator_batchnorm) 170 | self.generator = Generator(latent_dim=options.latent_dim, image_size=options.image_size, 171 | channels_nbr=options.channels_nbr, features_nbr = 64, 172 | dropout=options.generator_dropout, batch_norm = options.generator_batchnorm) 173 | self.options = options 174 | 175 | #saves the generator and discriminator to the path given 176 | #note: for now it doesn't save any option (like latent space dimension) 177 | def save(self, path): 178 | torch.save(self.generator.state_dict(), '%s/generator.model' % path) 179 | torch.save(self.discriminator.state_dict(), '%s/discriminator.model' % path) 180 | 181 | #loads a previously saved GAN model FOR EVALUATION 182 | def load(self, path): 183 | self.generator.load_state_dict(torch.load('%s/generator.model' % path)) 184 | self.discriminator.load_state_dict(torch.load('%s/discriminator.model' % path)) 185 | 186 | #set dropout and batch normalization layers to evaluation mode 187 | self.generator.eval() 188 | self.discriminator.eval() 189 | 190 | #starts training using the options given at initialization 191 | def train(self): 192 | #create folders 193 | os.makedirs('%s/images' % self.options.output_path, exist_ok=True) 194 | os.makedirs('%s/models' % self.options.output_path, exist_ok=True) 195 | 196 | #define dataset 197 | dataset = ImageDataset(path=self.options.data_path, extension=self.options.file_extension, transforms_list=self.options.transforms) 198 | data_loader = DataLoader(dataset, batch_size=self.options.batch_size, shuffle=True, num_workers=self.options.workers_nbr) 199 | print('Dataset has %s samples (%s batches)' % (len(dataset), len(data_loader))) 200 | 201 | #define losses 202 | discriminator_loss = torch.nn.BCELoss() 203 | 204 | #CUDA conversion 205 | #Note: this must be done before defining optimizers 206 | if torch.cuda.is_available(): 207 | Tensor = torch.cuda.FloatTensor 208 | discriminator_loss = discriminator_loss.cuda() 209 | self.generator = self.generator.cuda() 210 | self.discriminator = self.discriminator.cuda() 211 | print("CUDA is available") 212 | else: 213 | Tensor = torch.Tensor 214 | print("CUDA is NOT available") 215 | 216 | #Optimizers 217 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.options.generator_lr, betas=(self.options.generator_beta1, 0.999)) 218 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.options.discriminator_lr, betas=(self.options.discriminator_beta1, 0.999)) 219 | 220 | #Initialize weights 221 | self.generator.apply(gan_weights_init) 222 | self.discriminator.apply(gan_weights_init) 223 | 224 | #Buffer for Experience Replay 225 | fake_buffer = ReplayBuffer(max_size=self.options.batch_size * 2, replace_probability=0.5) 226 | 227 | #Note: softening the labels by a random amount for the DISCRIMINATOR training avoid approaching 0 loss too rapidly 228 | label_softness = 0.1 229 | 230 | for epoch in range(self.options.epoch_number): 231 | for (i, batch) in enumerate(data_loader): 232 | 233 | # --- DISCRIMINATOR --- 234 | 235 | #reset discriminator's gradient 236 | discriminator_optimizer.zero_grad() 237 | 238 | #TODO: occasionally flip labels 239 | 240 | #backward on real data 241 | real = Variable(batch.type(Tensor)) 242 | batch_size = real.shape[0] #get batch-size 243 | real_labels = Variable(Tensor(numpy.random.uniform(0.0, self.options.label_softness, (batch_size, 1, 1, 1)))) 244 | real_loss = discriminator_loss(self.discriminator(real), real_labels) 245 | real_loss.backward() 246 | 247 | #backward on fake data 248 | #z = Variable(Tensor(numpy.random.uniform(0.0, 1.0, (batch_size, self.options.latent_dim, 1, 1)))) 249 | z = Variable(Tensor(numpy.random.normal(0.0, 1.0, (batch_size, self.options.latent_dim, 1, 1)))) 250 | fake_labels = Variable(Tensor(numpy.random.uniform(1.0 - self.options.label_softness, 1.0, (batch_size, 1, 1, 1)))) 251 | fake = self.generator(z) 252 | fake_buffered = fake_buffer.push_and_pop(fake) 253 | fake_loss = discriminator_loss(self.discriminator(fake_buffered.detach()), fake_labels) 254 | fake_loss.backward() 255 | 256 | #train discriminator 257 | discriminator_optimizer.step() 258 | 259 | 260 | # --- GENERATOR --- 261 | 262 | #train the generator multiple times to force it to get better 263 | #Note: this is not recommended 264 | for substep in range(self.options.generator_frequency): 265 | #reset generator's gradient 266 | generator_optimizer.zero_grad() 267 | 268 | #backward on generator (we want fake ones to be considered real so we use real label) 269 | #Discriminator loss converges rapidly to zero thus preventing the Generator from learning 270 | # => Use a different training noise sample for the Generator 271 | z2 = Variable(Tensor(numpy.random.normal(0.0, 1.0, (batch_size, self.options.latent_dim, 1, 1)))) 272 | fake2 = self.generator(z2) 273 | real_labels2 = Variable(Tensor(0.0 * numpy.ones((batch_size, 1, 1, 1)))) 274 | g_loss = discriminator_loss(self.discriminator(fake2), real_labels2) 275 | g_loss.backward() 276 | 277 | #train generator 278 | generator_optimizer.step() 279 | 280 | #log info 281 | print("epoch %i - batch %i: D=%f G=%f" % (epoch, i, (0.5 * (real_loss + fake_loss)), g_loss)) 282 | 283 | if i % self.options.save_interval == 0: 284 | #save some results 285 | z = Tensor(numpy.random.normal(0.0, 1.0, (64, self.options.latent_dim, 1, 1))) 286 | gen = self.generator(z) 287 | save_image(gen, '%s/images/%s-%s.png' % (self.options.output_path, epoch, i), nrow=8, normalize=True) 288 | 289 | --------------------------------------------------------------------------------