├── README.md ├── data.py ├── figs ├── 0_input.jpeg ├── 0_real.jpeg ├── 1_input.jpeg ├── 1_real.jpeg ├── 22epoch.jpeg ├── 5_epochs_in.jpeg ├── 5_epochs_out.jpeg ├── GAN_0epoch.jpeg ├── GAN_1epoch.jpeg ├── L2_0epoch.jpeg ├── L2_1epoch.jpeg ├── c_0.jpg ├── c_1.jpg ├── c_2.jpg ├── c_3.jpg ├── c_4.jpg ├── c_5.jpg ├── c_6.jpg ├── c_7.jpg ├── comments.jpeg ├── dataset.jpeg ├── encoder.jpeg ├── facenn.jpg ├── generator.jpg ├── test.jpg ├── test.txt └── training.jpeg ├── main.py ├── network.py ├── pretrained ├── GoT_test │ ├── 001.jpg │ ├── 001 │ │ └── varys.jpg │ ├── 002.jpg │ ├── 002 │ │ └── shae.jpg │ ├── 003.jpg │ ├── 003 │ │ └── bran.jpg │ ├── 004.jpg │ └── 004 │ │ └── sansa.jpg ├── Readme.md ├── faces.gif ├── generator_v0.pt ├── test-GoT.jpg └── test-Pie.jpg └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch implementation of a face frontalization GAN 2 | 3 | ## Introduction 4 | 5 | Screenwriters never cease to amuse us with bizarre portrayals of the tech industry, ranging [from cringeworthy to hilarious](https://www.reddit.com/r/ZoomEnhance/). With the current advances in artificial intelligence, however, some of the most unrealistic technologies from the TV screens are coming to life. For example, the Enhance software from *CSI: NY* (or *Les Experts : Manhattan* for the francophone readers) has already been outshone by the state-of-the-art [Super Resolution neural networks](https://medium.com/@jonathan_hui/gan-super-resolution-gan-srgan-b471da7270ec). On a more extreme side of the imagination, there is *Enemy of the state*: 6 | 7 | [![](https://img.youtube.com/vi/3EwZQddc3kY/0.jpg)](https://www.youtube.com/watch?v=3EwZQddc3kY?) 8 | 9 | "Rotating [a video surveillance footage] 75 degrees around the vertical" must have seemed completely nonsensical long after 1998 when the movie came out, evinced by the youtube comments below this particular excerpt: 10 | ![](figs/comments.jpeg) 11 | 12 | Despite the apparent pessimism of the audience, thanks to machine learning today anyone with a little bit of Python knowledge and a large enough dataset can take a stab at writing a sci-fi drama worthy program. 13 | 14 | ## The face frontalization problem 15 | 16 | Forget MNIST, forget the boring cat vs. dog classifiers, today we are going to learn how to do something far more exciting! This project was inspired by the impressive work by R. Huang et al. [(Beyond Face Rotation: Global and Local Perception GAN for Photorealistic and Identity Preserving Frontal View Synthesis)](https://arxiv.org/abs/1704.04086), in which the authors synthesise frontal views of people's faces given their images at various angles. 
Below is Figure 3 from that paper, in which they compare their results [\[1\]](https://arxiv.org/abs/1704.04086) to previous work [2-6]: 17 | 18 | | ![](figs/c_0.jpg) |![](figs/c_1.jpg) | ![](figs/c_2.jpg) | ![](figs/c_3.jpg) | ![](figs/c_4.jpg) | ![](figs/c_5.jpg) |![](figs/c_6.jpg) |![](figs/c_7.jpg) | 19 | |---|---|---|---|---|---|---|---| 20 | | Input | [\[1\]](https://arxiv.org/abs/1704.04086) | [\[2\]](http://openaccess.thecvf.com/content_cvpr_2017/papers/Tran_Disentangled_Representation_Learning_CVPR_2017_paper.pdf) | [\[3\]](https://ieeexplore.ieee.org/document/7298667) | [\[4\]](https://arxiv.org/abs/1511.08446) | [\[5\]](https://ieeexplore.ieee.org/document/7298679) | [\[6\]](https://arxiv.org/abs/1411.7964) | Actual frontal | 21 | |Comparison of multiple face frontalization methods [\[1\]](https://arxiv.org/abs/1704.04086)| 22 | 23 | We are not going to try to reproduce the state-of-the-art model by R. Huang et al. Instead, we will construct and train a face frontalization model, producing reasonable results in a single afternoon: 24 | 25 | | input |![5epochs](figs/5_epochs_in.jpeg) | 26 | |---|---| 27 | | generated output | ![5epochs](figs/5_epochs_out.jpeg) | 28 | 29 | Additionally, we will go over: 30 | 31 | 1. How to use NVIDIA's `DALI` library for highly optimized pre-processing of images on the GPU and feeding them into a deep learning model. 32 | 33 | 2. How to code a Generative Adversarial Network, praised as “the most interesting idea in the last ten years in Machine Learning” by Yann LeCun, the director of Facebook AI, in `PyTorch` 34 | 35 | You will also have your very own Generative Adversarial Network set up to be trained on a dataset of your choice. Without further ado, let's dig in! 36 | 37 | ## Setting Up Your Data 38 | 39 | At the heart of any machine learning project, lies the data. Unfortunately, Scaleway cannot provide the [CMU Multi-PIE Face Database](http://www.cs.cmu.edu/afs/cs/project/PIE/MultiPie/Multi-Pie/Home.html) that we used for training due to copyright, so we shall proceed assuming you already have a dataset that you would like to train your model on. In order to make use of [NVIDIA Data Loading Library (DALI)](https://github.com/NVIDIA/DALI), the images should be in JPEG format. The dimensions of the images do not matter, since we have DALI to resize all the inputs to the input size required by our network (128 x 128 pixels), but a 1:1 ratio is desirable to obtain the most realistic synthesised images. 40 | The advantage of using DALI over, e.g., a standard PyTorch Dataset, is that whatever pre-processing (resizing, cropping, etc) is necessary, is performed on the GPU rather than the CPU, after which pre-processed images on the GPU are fed straight into the neural network. 41 | 42 | ### Managing our dataset: 43 | 44 | For the face frontalization project, we set up our dataset in the following way: the dataset folder contains a subfolder and a target frontal image for each person (aka subject). In principle, the names of the subfolders and the target images do not have to be identical (as they are in the Figure below), but if we are to separately sort all the subfolders and all the targets alphanumerically, the ones corresponding to the same subject must appear at the same position on the two lists of names. 
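Since this pairing-by-sorted-position convention is easy to get wrong, a quick sanity check can save some debugging later. Below is a minimal sketch (the directory name is a placeholder for your own dataset root) that simply lists which target frontal each subject subfolder will be matched with:

```
import os
from os.path import join, isdir

dataset_dir = 'my_dataset_directory'  # placeholder: point this at your own dataset root

# Target frontal images and subject subfolders, both sorted alphanumerically
targets  = sorted(f for f in os.listdir(dataset_dir) if f.lower().endswith(('.jpg', '.jpeg')))
subjects = sorted(d for d in os.listdir(dataset_dir) if isdir(join(dataset_dir, d)))

assert len(targets) == len(subjects), "every subject folder needs exactly one target frontal image"
for subject, target in zip(subjects, targets):
    # With the naming scheme shown in the Figure the two names match outright;
    # in general they only need to occupy the same position in the sorted lists.
    print(subject, '->', target)
```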
45 | 46 | As you can see, subfolder `001/` corresponding to subject 001 contains images of the person pictured in `001.jpg` - these are closely cropped images of the face under different poses, lighting conditions, and varying face expressions. For the purposes of face frontalization, it is crucial to have the frontal images aligned as close to one another as possible, whereas the other (profile) images have a little bit more leeway. 47 | 48 | For instance, our target frontal images are all squares and cropped in such a way that the bottom of the person's chin is located at the bottom of the image, and the centred point between the inner corners of the eyes is situated at *0.8h* above and *0.5h* to the right of the lower left corner (*h* being the image's height). This way, once the images are resized to 128 x 128, the face features all appear at more or less the same locations on the images in the training set, and the network can learn to generate the said features and combine them together into realistic synthetic faces. 49 | 50 | ![](figs/dataset.jpeg) 51 | 52 | ### Building a DALI `Pipeline`: 53 | 54 | We are now going to build a pipeline for our dataset that is going to inherit from `nvidia.dali.pipeline.Pipeline`. At the time of writing, DALI does not directly support reading (image, image) pairs from a directory, so we will be making use of `nvidia.dali.ops.ExternalSource()` to pass the inputs and the targets to the pipeline. 55 | 56 | ##### `data.py` 57 | ``` 58 | import collections 59 | from random import shuffle 60 | import os 61 | from os import listdir 62 | from os.path import join 63 | 64 | import numpy as np 65 | from nvidia.dali.pipeline import Pipeline 66 | import nvidia.dali.ops as ops 67 | import nvidia.dali.types as types 68 | 69 | 70 | def is_jpeg(filename): 71 | return any(filename.endswith(extension) for extension in [".jpg", ".jpeg"]) 72 | 73 | 74 | def get_subdirs(directory): 75 | subdirs = sorted([join(directory,name) for name in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, name))]) 76 | return subdirs 77 | 78 | 79 | flatten = lambda l: [item for sublist in l for item in sublist] 80 | 81 | 82 | class ExternalInputIterator(object): 83 | 84 | def __init__(self, imageset_dir, batch_size, random_shuffle=False): 85 | self.images_dir = imageset_dir 86 | self.batch_size = batch_size 87 | 88 | # First, figure out what are the inputs and what are the targets in your directory structure: 89 | # Get a list of filenames for the target (frontal) images 90 | self.frontals = np.array([join(imageset_dir, frontal_file) for frontal_file in sorted(os.listdir(imageset_dir)) if is_jpeg(frontal_file)]) 91 | 92 | # Get a list of lists of filenames for the input (profile) images for each person 93 | profile_files = [[join(person_dir, profile_file) for profile_file in sorted(os.listdir(person_dir)) if is_jpeg(profile_file)] for person_dir in get_subdirs(imageset_dir)] 94 | 95 | # Build a flat list of frontal indices, corresponding to the *flattened* profile_files 96 | # The reason we are doing it this way is that we need to keep track of the multiple inputs corresponding to each target 97 | frontal_ind = [] 98 | for ind, profiles in enumerate(profile_files): 99 | frontal_ind += [ind]*len(profiles) 100 | self.frontal_indices = np.array(frontal_ind) 101 | 102 | # Now that we have built frontal_indices, we can flatten profile_files 103 | self.profiles = np.array(flatten(profile_files)) 104 | 105 | # Shuffle the (input, target) pairs if necessary: in practice, it is 
profiles and frontal_indices that get shuffled 106 | if random_shuffle: 107 | ind = np.array(range(len(self.frontal_indices))) 108 | shuffle(ind) 109 | self.profiles = self.profiles[ind] 110 | self.frontal_indices = self.frontal_indices[ind] 111 | 112 | 113 | def __iter__(self): 114 | self.i = 0 115 | self.n = len(self.frontal_indices) 116 | return self 117 | 118 | 119 | # Return a batch of (input, target) pairs 120 | def __next__(self): 121 | profiles = [] 122 | frontals = [] 123 | 124 | for _ in range(self.batch_size): 125 | profile_filename = self.profiles[self.i] 126 | frontal_filename = self.frontals[self.frontal_indices[self.i]] 127 | 128 | profile = open(profile_filename, 'rb') 129 | frontal = open(frontal_filename, 'rb') 130 | 131 | profiles.append(np.frombuffer(profile.read(), dtype = np.uint8)) 132 | frontals.append(np.frombuffer(frontal.read(), dtype = np.uint8)) 133 | 134 | profile.close() 135 | frontal.close() 136 | 137 | self.i = (self.i + 1) % self.n 138 | return (profiles, frontals) 139 | 140 | next = __next__ 141 | 142 | 143 | class ImagePipeline(Pipeline): 144 | ''' 145 | Constructor arguments: 146 | - imageset_dir: directory containing the dataset 147 | - image_size = 128: length of the square that the images will be resized to 148 | - random_shuffle = False 149 | - batch_size = 64 150 | - num_threads = 2 151 | - device_id = 0 152 | ''' 153 | 154 | def __init__(self, imageset_dir, image_size=128, random_shuffle=False, batch_size=64, num_threads=2, device_id=0): 155 | super(ImagePipeline, self).__init__(batch_size, num_threads, device_id, seed=12) 156 | eii = ExternalInputIterator(imageset_dir, batch_size, random_shuffle) 157 | self.iterator = iter(eii) 158 | self.num_inputs = len(eii.frontal_indices) 159 | 160 | # The source for the inputs and targets 161 | self.input = ops.ExternalSource() 162 | self.target = ops.ExternalSource() 163 | 164 | # nvJPEGDecoder below accepts CPU inputs, but returns GPU outputs (hence device = "mixed") 165 | self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB) 166 | 167 | # The rest of pre-processing is done on the GPU 168 | self.res = ops.Resize(device="gpu", resize_x=image_size, resize_y=image_size) 169 | self.norm = ops.NormalizePermute(device="gpu", output_dtype=types.FLOAT, 170 | mean=[128., 128., 128.], std=[128., 128., 128.], 171 | height=image_size, width=image_size) 172 | 173 | 174 | # epoch_size = number of (profile, frontal) image pairs in the dataset 175 | def epoch_size(self, name = None): 176 | return self.num_inputs 177 | 178 | 179 | # Define the flow of the data loading and pre-processing 180 | def define_graph(self): 181 | self.profiles = self.input(name="inputs") 182 | self.frontals = self.target(name="targets") 183 | profile_images = self.decode(self.profiles) 184 | profile_images = self.res(profile_images) 185 | profile_output = self.norm(profile_images) 186 | frontal_images = self.decode(self.frontals) 187 | frontal_images = self.res(frontal_images) 188 | frontal_output = self.norm(frontal_images) 189 | return (profile_output, frontal_output) 190 | 191 | 192 | def iter_setup(self): 193 | (images, targets) = self.iterator.next() 194 | self.feed_input(self.profiles, images) 195 | self.feed_input(self.frontals, targets) 196 | ``` 197 | 198 | You can now use the `ImagePipeline` class that you wrote above to load images from your dataset directory, one batch at a time. 
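One caveat: the DALI API has changed over time, so depending on the version you have installed, some of the operators above may have been renamed or deprecated (for instance, `nvJPEGDecoder` became `ImageDecoder` in later 0.x releases, and `NormalizePermute` was eventually superseded by `CropMirrorNormalize`). If the pipeline above fails to build on your installation, a substitution along these lines inside `ImagePipeline.__init__` may be needed — a sketch to adapt to your DALI version, not a drop-in guarantee:

```
# Possible replacements inside ImagePipeline.__init__ for newer DALI releases:
self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)

# CropMirrorNormalize performs the same (x - mean) / std normalization
# and the HWC -> CHW permutation that NormalizePermute used to do:
self.norm = ops.CropMirrorNormalize(device="gpu",
                                    mean=[128., 128., 128.],
                                    std=[128., 128., 128.])
```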
199 | 200 | If you are using the code from this tutorial inside a Jupyter notebook, here is how you can use an `ImagePipeline` to display the images: 201 | ``` 202 | from __future__ import division 203 | import matplotlib.gridspec as gridspec 204 | import matplotlib.pyplot as plt 205 | %matplotlib inline 206 | 207 | 208 | def show_images(image_batch, batch_size): 209 | columns = 4 210 | rows = (batch_size + 1) // (columns) 211 | fig = plt.figure(figsize = (32,(32 // columns) * rows)) 212 | gs = gridspec.GridSpec(rows, columns) 213 | 214 | for j in range(rows*columns): 215 | plt.subplot(gs[j]) 216 | plt.axis("off") 217 | plt.imshow(np.transpose(image_batch.at(j), (1,2,0))) 218 | 219 | 220 | batch_size = 8 221 | pipe = ImagePipeline('my_dataset_directory', image_size=128, batch_size=batch_size) 222 | pipe.build() 223 | profiles, frontals = pipe.run() 224 | 225 | 226 | # The images returned by ImagePipeline are currently on the GPU 227 | # We need to copy them to the CPU via the asCPU() method in order to display them 228 | show_images(profiles.asCPU(), batch_size=batch_size) 229 | show_images(frontals.asCPU(), batch_size=batch_size) 230 | ``` 231 | 232 | ## Setting Up Your Neural Network 233 | 234 | Here comes the fun part, building the network's architecture! We assume that you are already somewhat familiar with the idea behind convolutional neural networks, the architecture of choice for many computer vision applications today. 235 | Beyond that, there are two main concepts that we will need for the face Frontalization project, that we shall touch upon in this section: 236 | 237 | * the Encoder / Decoder Network(s) and 238 | * the Generative Adversarial Network. 239 | 240 | ### Encoders and Decoders 241 | 242 | #### The Encoder 243 | 244 | As mentioned above, our network takes images that are sized 128 by 128 as input. Since the images are in colour (meaning 3 colour channels for each pixel), this results in the input being 3x128x128=49152 dimensional. Perhaps we do not need all 49152 values to describe a person's face? This turns out to be correct: we can get away with a mere 512 dimensional vector (which is simply another way of saying "512 numbers") to encode all the information that we care about. This is an example of *dimensionality reduction*: the Encoder network (paired with the network responsible for the inverse process, *decoding*) learns a lower dimensional representation of the input. The architecture of the Encoder may look something like this: 245 | ![encoder](figs/encoder.jpeg) 246 | 247 | Here we start with input that is 128x128 and has 3 channels. As we pass it through convolutional layers, the size of the input gets smaller and smaller (from 128x128 to 64x64 to 16x16 etc on the Figure above) whereas the number of channels grows (from 3 to 8 to 16 and so on). This reflects the fact that the deeper the convolutional layer, the more abstract are the features that it learns. In the end we get to a layer whose output is sized 1x1, yet has a very high number of channels: 256 in the example depicted above (or 512 in our own network). 256x1 and 1x256 are really the same thing, if you think about it, so another way to put it is that the output of the Encoder is 256 dimensional (with a single channel), so we have reduced the dimensionality of the original input from 49152 to 256! Why would we want to do that? Having this lower dimensional representation helps us prevent overfitting our final model to the training set. 
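To make the shrinking spatial size concrete: a convolutional layer with kernel size k, stride s and padding p turns an input of side H into an output of side floor((H + 2p - k) / s) + 1. The Generator defined later in this tutorial uses kernel 4, stride 2, padding 1 throughout, which is exactly what halves the side at every step:

```
def conv_output_size(h, kernel=4, stride=2, padding=1):
    # Spatial size after a 2D convolution: floor((H + 2p - k) / s) + 1
    return (h + 2 * padding - kernel) // stride + 1

h = 128
for _ in range(6):       # the six Conv2d(..., 4, 2, 1) layers of the Generator
    h = conv_output_size(h)
    print(h)             # 64, 32, 16, 8, 4, 2
```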
248 | 249 | In the end, what we want is a representation (and hence, a model) that is precise enough to fit the training data well, yet does not overfit - meaning, that it can be generalised to the data it has not seen before as well. 250 | 251 | #### The Decoder 252 | 253 | As the name suggests, the Decoder's job is the inverse of that of the Encoder. In other words, it takes the low-dimensional representation output of the Encoder and has it go through *deconvolutional* layers (also known as the [transposed convolutional layers](https://datascience.stackexchange.com/questions/6107/what-are-deconvolutional-layers)). The architecture of the Decoder network is often symmetric to that of the Encoder, although this does not have to be the case. The Encoder and the Decoder are often combined into a single network, whose inputs and outputs are both images: 254 | 255 | ![generator](figs/generator.jpg) 256 | 257 | In our project this Encoder/Decoder network is called the Generator. The Generator takes in a profile image, and (if we do our job right) outputs a frontal one: 258 | 259 | ![facenn](figs/facenn.jpg) 260 | 261 | It is now time to write it using PyTorch. A [two dimensional convolutional layer](https://pytorch.org/docs/stable/nn.html#conv2d) can be created via `torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)`. You can now read off the architecture of our Generator network from the code snippet below: 262 | 263 | ##### `network.py` 264 | ``` 265 | import torch 266 | import torch.nn as nn 267 | import torch.nn.parallel 268 | import torch.optim as optim 269 | from torch.autograd import Variable 270 | 271 | 272 | def weights_init(m): 273 | classname = m.__class__.__name__ 274 | 275 | if classname.find('Conv') != -1: 276 | m.weight.data.normal_(0.0, 0.02) 277 | 278 | elif classname.find('BatchNorm') != -1: 279 | m.weight.data.normal_(1.0, 0.02) 280 | m.bias.data.fill_(0) 281 | 282 | 283 | ''' Generator network for 128x128 RGB images ''' 284 | class G(nn.Module): 285 | 286 | def __init__(self): 287 | super(G, self).__init__() 288 | 289 | self.main = nn.Sequential( 290 | # Input HxW = 128x128 291 | nn.Conv2d(3, 16, 4, 2, 1), # Output HxW = 64x64 292 | nn.BatchNorm2d(16), 293 | nn.ReLU(True), 294 | nn.Conv2d(16, 32, 4, 2, 1), # Output HxW = 32x32 295 | nn.BatchNorm2d(32), 296 | nn.ReLU(True), 297 | nn.Conv2d(32, 64, 4, 2, 1), # Output HxW = 16x16 298 | nn.BatchNorm2d(64), 299 | nn.ReLU(True), 300 | nn.Conv2d(64, 128, 4, 2, 1), # Output HxW = 8x8 301 | nn.BatchNorm2d(128), 302 | nn.ReLU(True), 303 | nn.Conv2d(128, 256, 4, 2, 1), # Output HxW = 4x4 304 | nn.BatchNorm2d(256), 305 | nn.ReLU(True), 306 | nn.Conv2d(256, 512, 4, 2, 1), # Output HxW = 2x2 307 | nn.MaxPool2d((2,2)), 308 | # At this point, we arrive at our low D representation vector, which is 512 dimensional. 
309 | 310 | nn.ConvTranspose2d(512, 256, 4, 1, 0, bias = False), # Output HxW = 4x4 311 | nn.BatchNorm2d(256), 312 | nn.ReLU(True), 313 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False), # Output HxW = 8x8 314 | nn.BatchNorm2d(128), 315 | nn.ReLU(True), 316 | nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False), # Output HxW = 16x16 317 | nn.BatchNorm2d(64), 318 | nn.ReLU(True), 319 | nn.ConvTranspose2d(64, 32, 4, 2, 1, bias = False), # Output HxW = 32x32 320 | nn.BatchNorm2d(32), 321 | nn.ReLU(True), 322 | nn.ConvTranspose2d(32, 16, 4, 2, 1, bias = False), # Output HxW = 64x64 323 | nn.BatchNorm2d(16), 324 | nn.ReLU(True), 325 | nn.ConvTranspose2d(16, 3, 4, 2, 1, bias = False), # Output HxW = 128x128 326 | nn.Tanh() 327 | ) 328 | 329 | 330 | def forward(self, input): 331 | output = self.main(input) 332 | return output 333 | ``` 334 | 335 | 336 | ### Generative Adversarial Networks (GANs) 337 | 338 | Generative Adversarial Networks (GANs) are a very exciting deep learning development, which was introduced in a [2014 paper](https://arxiv.org/pdf/1406.2661.pdf) by Ian Goodfellow and collaborators. Without getting into too much detail, here is the idea behind GANs: there are two networks, a *generator* (perhaps our name choice for the Encoder/Decoder net above makes more sense now) and a *discriminator*. The Generator's job is to generate synthetic images, but what is the Discriminator to do? The Discriminator is supposed to tell the difference between the *real* images and the *fake* ones that were synthesised by the Generator. 339 | 340 | Usually, GAN training is carried out in an unsupervised manner. There is an unlabelled dataset of, say, images in a specific domain. The generator will generate some image given random noise as input. The discriminator is then trained to recognise the images from the dataset as *real* and the output of the generator as *fake*. As far as the discriminator is concerned, the two categories comprise a labelled dataset. If this sounds like a binary classification problem to you, you won't be surprised to hear that the loss function is the [binary cross entropy](https://ml-cheatsheet.readthedocs.io/en/latest/loss_functions.html). The task of the generator is to fool the discriminator. Here is how that is done: first, the generator gives its output to the discriminator. Naturally, that output depends on what the generator's trainable parameters are. The discriminator is not being trained at this point, rather it is used for inference. Instead, it is the *generator*'s weights that are updated in a way that gets the discriminator to accept (as in, label as "real") the synthesised outputs. The updating of the generator's and the discriminator's weights is done alternatively - once each for every batch, as you will see later when we discuss training our model. 341 | 342 | Since we are not trying to simply generate faces, the architecture of our Generator is a little different from the one described above (for one thing, it takes real images as inputs, not some random noise, and tries to incorporate certain features of those inputs in its outputs). Our loss function won't be just the cross-entropy either: we have to add an additional component that compares the generator's outputs to the target ones. This could be, for instance, a pixelwise mean square error, or a mean absolute error. These matters are going to be addressed in the Training section of this tutorial. 
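Before adding the Discriminator, it is worth sanity-checking that the Generator really maps a 128x128 image to another 128x128 image, and that its outputs live in (-1, 1) as dictated by the final `Tanh`. Here is a minimal sketch, run on the CPU with random tensors standing in for real profile images:

```
import torch
from network import G

netG = G()
dummy_profiles = torch.randn(4, 3, 128, 128)   # a fake batch of 4 RGB "images"
with torch.no_grad():
    out = netG(dummy_profiles)

print(out.shape)                           # expected: torch.Size([4, 3, 128, 128])
print(out.min().item(), out.max().item())  # both within (-1, 1) thanks to Tanh
```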
343 | 344 | Before we move on, let us complete the `network.py` file by providing the code for the Discriminator: 345 | 346 | ##### `network.py` \[continued\] 347 | ``` 348 | ''' Discriminator network for 128x128 RGB images ''' 349 | class D(nn.Module): 350 | 351 | def __init__(self): 352 | super(D, self).__init__() 353 | 354 | self.main = nn.Sequential( 355 | nn.Conv2d(3, 16, 4, 2, 1), 356 | nn.LeakyReLU(0.2, inplace = True), 357 | nn.Conv2d(16, 32, 4, 2, 1), 358 | nn.BatchNorm2d(32), 359 | nn.LeakyReLU(0.2, inplace = True), 360 | nn.Conv2d(32, 64, 4, 2, 1), 361 | nn.BatchNorm2d(64), 362 | nn.LeakyReLU(0.2, inplace = True), 363 | nn.Conv2d(64, 128, 4, 2, 1), 364 | nn.BatchNorm2d(128), 365 | nn.LeakyReLU(0.2, inplace = True), 366 | nn.Conv2d(128, 256, 4, 2, 1), 367 | nn.BatchNorm2d(256), 368 | nn.LeakyReLU(0.2, inplace = True), 369 | nn.Conv2d(256, 512, 4, 2, 1), 370 | nn.BatchNorm2d(512), 371 | nn.LeakyReLU(0.2, inplace = True), 372 | nn.Conv2d(512, 1, 4, 2, 1, bias = False), 373 | nn.Sigmoid() 374 | ) 375 | 376 | 377 | def forward(self, input): 378 | output = self.main(input) 379 | return output.view(-1) 380 | ``` 381 | 382 | As you can see, the architecture of the Discriminator is rather similar to that of the Generator, except that it seems to contain only the Encoder part of the latter. Indeed, the goal of the Discriminator is not to output an image, so there is no need for something like a Decoder. Instead, the Discriminator contains layers that process an input image (much like an Encoder would), with the goal of distinguishing real images from the synthetic ones. 383 | 384 | ## From DALI to PyTorch 385 | 386 | DALI is a wonderful tool that not only pre-processes images on the fly, but also provides plugins for several popular machine learning frameworks, including PyTorch. 387 | 388 | If you used PyTorch before, you may be familiar with its `torch.utils.data.Dataset` and `torch.utils.data.DataLoader` classes meant to ease the [pre-processing and loading of the data](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html). When using DALI, we combine the aforementioned `nvidia.dali.pipeline.Pipeline` with `nvidia.dali.plugin.pytorch.DALIGenericIterator` in order to accomplish the task. 389 | 390 | At this point, we are starting to get into the third, and last, `Python` file that is a part of the face frontalization project. First, let us get the imports out of the way. We'll also set the seeds for the randomised parts of our model in order to have better control over reproducibility of the results: 391 | 392 | ##### `main.py` 393 | ``` 394 | from __future__ import print_function 395 | import time 396 | import math 397 | import random 398 | import os 399 | from os import listdir 400 | from os.path import join 401 | from PIL import Image 402 | 403 | import numpy as np 404 | import torch 405 | import torch.nn as nn 406 | import torch.nn.parallel 407 | import torch.optim as optim 408 | import torchvision.utils as vutils 409 | from torch.autograd import Variable 410 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 411 | 412 | from data import ImagePipeline 413 | import network 414 | 415 | np.random.seed(42) 416 | random.seed(10) 417 | torch.backends.cudnn.deterministic = True 418 | torch.backends.cudnn.benchmark = False 419 | torch.manual_seed(999) 420 | 421 | # Where is your training dataset at? 
422 | datapath = 'training_set' 423 | 424 | # You can also choose which GPU you want your model to be trained on below: 425 | gpu_id = 0 426 | device = torch.device("cuda", gpu_id) 427 | ``` 428 | In order to integrate the `ImagePipeline` class from `data.py` into your PyTorch model, you will need to make use of `DALIGenericIterator`. Constructing one is very straightforward: you only need to pass it a pipeline object, a list of labels for what your pipeline spits out, and the `epoch_size` of the pipeline. Here is what that looks like: 429 | 430 | ##### `main.py` \[continued\] 431 | ``` 432 | train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=30, device_id=gpu_id) 433 | train_pipe.build() 434 | m_train = train_pipe.epoch_size() 435 | print("Size of the training set: ", m_train) 436 | train_pipe_loader = DALIGenericIterator(train_pipe, ["profiles", "frontals"], m_train) 437 | ``` 438 | 439 | Now you are ready to train. 440 | 441 | ## Training The Model 442 | 443 | First, lets get ourselves some neural networks: 444 | 445 | ##### `main.py` \[continued\] 446 | ``` 447 | # Generator: 448 | netG = network.G().to(device) 449 | netG.apply(network.weights_init) 450 | 451 | # Discriminator: 452 | netD = network.D().to(device) 453 | netD.apply(network.weights_init) 454 | ``` 455 | 456 | ### The Loss Function 457 | 458 | Mathematically, training a neural network refers to updating its weights in a way that minimises the loss function. There are multiple choices to be made here, the most crucial, perhaps, being the form of the loss function. We have already touched upon it in our discussion of GANs in Step 3, so we know that we need the binary cross entropy loss for the discriminator, whose job is to classify the images as either *real* or *fake*. 459 | 460 | However, we also want a pixelwise loss function that will get the generated outputs to not only look like frontal images of people in general, but the **right** people - same ones that we see in the input profile images. The common ones to use are the so-called L1 loss and L2 loss: you might know them under the names of *Mean Absolute Error* and *Mean Squared Error* respectively. In the code below, we'll give you both (in addition to the cross entropy), together with a way to vary the relative importance you place on each of the three. 
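To summarise the code that follows in a single formula: writing x for the input profile, y for the ground truth frontal and D(·) for the discriminator's verdict, the generator is trained to minimise

$$\mathcal{L}_G = \lambda_{GAN}\,\mathrm{BCE}\big(D(G(x)),\,1\big) + \lambda_{L1}\,\mathrm{mean}\big(|y - G(x)|\big) + \lambda_{L2}\,\mathrm{mean}\big((y - G(x))^2\big),$$

where the means are taken over pixels and the three weights correspond to `GAN_factor`, `L1_factor` and `L2_factor` in the code below.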
461 | 462 | ##### `main.py` \[continued\] 463 | ``` 464 | # Here is where you set how important each component of the loss function is: 465 | L1_factor = 0 466 | L2_factor = 1 467 | GAN_factor = 0.0005 468 | 469 | criterion = nn.BCELoss() # Binary cross entropy loss 470 | 471 | # Optimizers for the generator and the discriminator (Adam is a fancier version of gradient descent with a few more bells and whistles that is used very often): 472 | optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999)) 473 | optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999), eps = 1e-8) 474 | 475 | # Create a directory for the output files 476 | try: 477 | os.mkdir('output') 478 | except OSError: 479 | pass 480 | 481 | start_time = time.time() 482 | 483 | # Let's train for 30 epochs (meaning, we go through the entire training set 30 times): 484 | for epoch in range(30): 485 | 486 | # Lets keep track of the loss values for each epoch: 487 | loss_L1 = 0 488 | loss_L2 = 0 489 | loss_gan = 0 490 | 491 | # Your train_pipe_loader will load the images one batch at a time 492 | # The inner loop iterates over those batches: 493 | 494 | for i, data in enumerate(train_pipe_loader, 0): 495 | 496 | # These are your images from the current batch: 497 | profile = data[0]['profiles'] 498 | frontal = data[0]['frontals'] 499 | 500 | # TRAINING THE DISCRIMINATOR 501 | netD.zero_grad() 502 | real = Variable(frontal).type('torch.FloatTensor').to(device) 503 | target = Variable(torch.ones(real.size()[0])).to(device) 504 | output = netD(real) 505 | # D should accept the GT images 506 | errD_real = criterion(output, target) 507 | 508 | profile = Variable(profile).type('torch.FloatTensor').to(device) 509 | generated = netG(profile) 510 | target = Variable(torch.zeros(real.size()[0])).to(device) 511 | output = netD(generated.detach()) # detach() because we are not training G here 512 | 513 | # D should reject the synthetic images 514 | errD_fake = criterion(output, target) 515 | 516 | errD = errD_real + errD_fake 517 | errD.backward() 518 | # Update D 519 | optimizerD.step() 520 | 521 | # TRAINING THE GENERATOR 522 | netG.zero_grad() 523 | target = Variable(torch.ones(real.size()[0])).to(device) 524 | output = netD(generated) 525 | 526 | # G wants to : 527 | # (a) have the synthetic images be accepted by D (= look like frontal images of people) 528 | errG_GAN = criterion(output, target) 529 | 530 | # (b) have the synthetic images resemble the ground truth frontal image 531 | errG_L1 = torch.mean(torch.abs(real - generated)) 532 | errG_L2 = torch.mean(torch.pow((real - generated), 2)) 533 | 534 | errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2 535 | 536 | loss_L1 += errG_L1.item() 537 | loss_L2 += errG_L2.item() 538 | loss_gan += errG_GAN.item() 539 | 540 | errG.backward() 541 | # Update G 542 | optimizerG.step() 543 | 544 | if epoch == 0: 545 | print('First training epoch completed in ',(time.time() - start_time),' seconds') 546 | 547 | # reset the DALI iterator 548 | train_pipe_loader.reset() 549 | 550 | # Print the absolute values of three losses to screen: 551 | print('[%d/30] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train,)) 552 | 553 | # Save the inputs, outputs, and ground truth frontals to files: 554 | vutils.save_image(profile.data, 'output/%03d_input.jpg' % epoch, normalize=True) 555 | vutils.save_image(real.data, 'output/%03d_real.jpg' % epoch, normalize=True) 556 | 
vutils.save_image(generated.data, 'output/%03d_generated.jpg' % epoch, normalize=True) 557 | 558 | # Save the pre-trained Generator as well 559 | torch.save(netG,'output/netG_%d.pt' % epoch) 560 | ``` 561 | 562 | [Training GANs is notoriously difficult](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b), but let us focus on the L2 loss, equal to the sum of squared differences between each pixel of the output and target images. Its value decreases with every epoch, and if we compare the generated images to the real ones of frontal faces, we see that our model is indeed learning to fit the training data: 563 | 564 | ![training](figs/training.jpeg) 565 | 566 | ![out](figs/22epoch.jpeg) 567 | 568 | In the figure above, the upper row are some of the inputs fed into our model during the 22nd training epoch, below are the frontal images generated by our GAN, and at the bottom is the row of the corresponding ground truth images. 569 | 570 | Next, lets see how the model performs on data that it has never seen before. 571 | 572 | ## Testing The Model 573 | 574 | We are going to test the model we trained in the previous section on the three subjects that appear in the comparison table in the paper [Beyond Face Rotation: Global and Local Perception GAN for Photorealistic and Identity Preserving Frontal View Synthesis](https://arxiv.org/abs/1704.04086). These subjects do not appear in our training set; instead, we put the corresponding images in a directory called `test_set` that has the same structure as the `training_set` above. The `test.py` code is going to load a pre-trained generator network that we saved during training, and put the input test images through it, generating the outputs: 575 | 576 | ### `test.py` 577 | ``` 578 | import torch 579 | import torchvision.utils as vutils 580 | from torch.autograd import Variable 581 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 582 | 583 | from data import ImagePipeline 584 | 585 | device = 'cuda' 586 | 587 | datapath = 'test_set' 588 | 589 | # Generate frontal images from the test set 590 | def frontalize(model, datapath, mtest): 591 | 592 | test_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=False, batch_size=mtest) 593 | test_pipe.build() 594 | test_pipe_loader = DALIGenericIterator(test_pipe, ["profiles", "frontals"], mtest) 595 | 596 | with torch.no_grad(): 597 | for data in test_pipe_loader: 598 | profile = data[0]['profiles'] 599 | frontal = data[0]['frontals'] 600 | generated = model(Variable(profile.type('torch.FloatTensor').to(device))) 601 | vutils.save_image(torch.cat((profile, generated.data, frontal)), 'output/test.jpg', nrow=mtest, padding=2, normalize=True) 602 | return 603 | 604 | # Load a pre-trained Pytorch model 605 | saved_model = torch.load("./output/netG_15.pt") 606 | 607 | frontalize(saved_model, datapath, 3) 608 | ``` 609 | Here are the results of the model above that has been trained for 15 epochs: 610 | 611 | ![out](figs/test.jpg) 612 | 613 | Again, here we see input images on top, followed by generated images in the middle and the ground truth ones at the bottom. Naturally, the agreement between the latter two is not as close as that for the images in the training set, yet we see that the network did in fact learn to pick up various facial features such as glasses, thickness of the eyebrows, eye and nose shape etc. In the end, one has to experiment with the hyperparameters of the model to see what works best. 
We managed to produce the following results in just five training epochs (which take about an hour and a half on the [Scaleway RENDER-S GPU instances](https://www.scaleway.com/gpu-instances/)): 614 | 615 | | input |![5epochs](figs/5_epochs_in.jpeg) | 616 | |---|---| 617 | | generated output | ![5epochs](figs/5_epochs_out.jpeg) | 618 | 619 | Here we have trained the GAN model with parameters as in the code above for the first three epochs, then set the `GAN_factor` to zero and continued to train only the generator, optimizing the L1 and L2 losses, for two more epochs. 620 | 621 | ## Why use GANs for supervised ML problems? 622 | 623 | Generative Adversarial Networks were initially meant for [*unsupervised machine learning*](https://machinelearningmastery.com/supervised-and-unsupervised-machine-learning-algorithms/): a typical training set consists of the so-called "real" examples of some data (which are not labeled, unlike in *supervised learning*), and the goal is to train the network to generate more examples of the same sort of data. However, GANs are increasingly being used for tasks where training inputs have corresponding outputs (i.e., supervised machine learning). Examples include [face frontalization](https://arxiv.org/abs/1704.04086), [super resolution](https://medium.com/@jonathan_hui/gan-super-resolution-gan-srgan-b471da7270ec), etc. What is the benefit of introducing a GAN architecture into such problems? 624 | 625 | Let us compare the results of training the `Generator` network above in a supervised way using only the L2 pixelwise loss, with the combined `Generator`/`Discriminator` architecture. In the latter case, the `Discriminator` is trained to accept the *real* images, reject the *synthetic* ones, and the `Generator` learns to fool the `Discriminator` in addition to optimizing its own L2 loss. 626 | 627 | In the first case, the generated images start to resemble human faces early on during training, but the fine features remain blurry not only for the test set, but also for the training samples for a relatively long time: 628 | 629 | Images generated after 20 000 mini batch evaluations (inputs from the training set): 630 | 631 | ![](figs/L2_0epoch.jpeg) 632 | 633 | Images generated after 40 000 mini batch evaluations (inputs from the training set): 634 | 635 | ![](figs/L2_1epoch.jpeg) 636 | 637 | Mathematically, this can be attributed to the small contribution of these fine features (eyelids, shadows around the mouth, etc) to the pixelwise loss. Since longer training times are required to achieve the desired accuracy on the training set, this makes such models prone to overfitting. Introducing a GAN into the picture changes things considerably: 638 | 639 | Images generated after 20 000 mini batch evaluations (inputs from the training set): 640 | 641 | ![](figs/GAN_0epoch.jpeg) 642 | 643 | Images generated after 40 000 mini batch evaluations (inputs from the training set): 644 | 645 | ![](figs/GAN_1epoch.jpeg) 646 | 647 | While our training set does not have enough unique GT frontal images to train an unsupervised GAN with the likes of [NVIDIA's Style GAN](https://thenextweb.com/artificial-intelligence/2019/02/13/thispersondoesnotexist-com-is-face-generating-ai-at-its-creepiest), the GAN architecture turns out to be very good at generating the above-mentioned fine features, even though it introduces unwanted noise elsewhere. 
Evidently, these are the details that the discriminator uses to distinguish the *real* images from the *fake* ones, facilitating the generation of [photorealistic](https://arxiv.org/abs/1704.04086) and [super-resolved](https://medium.com/@jonathan_hui/gan-super-resolution-gan-srgan-b471da7270ec) synethic outputs. 648 | 649 | For your reference, below are the profile inputs and the ground truth frontals for the images generated above: 650 | 651 | ![](figs/0_input.jpeg) 652 | 653 | ![](figs/0_real.jpeg) 654 | 655 | ![](figs/1_input.jpeg) 656 | 657 | ![](figs/1_real.jpeg) 658 | 659 | 660 | *An extended version of this article was first published on the [Scaleway blog](https://blog.scaleway.com/2019/gpu-instances-using-deep-learning-to-obtain-frontal-rendering-of-facial-images/)* 661 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from random import shuffle 3 | import os 4 | from os import listdir 5 | from os.path import join 6 | 7 | import numpy as np 8 | from nvidia.dali.pipeline import Pipeline 9 | import nvidia.dali.ops as ops 10 | import nvidia.dali.types as types 11 | 12 | 13 | def is_jpeg(filename): 14 | return any(filename.endswith(extension) for extension in [".jpg", ".jpeg"]) 15 | 16 | 17 | def get_subdirs(directory): 18 | subdirs = sorted([join(directory,name) for name in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, name))]) 19 | return subdirs 20 | 21 | 22 | flatten = lambda l: [item for sublist in l for item in sublist] 23 | 24 | 25 | class ExternalInputIterator(object): 26 | 27 | def __init__(self, imageset_dir, batch_size, random_shuffle=False): 28 | self.images_dir = imageset_dir 29 | self.batch_size = batch_size 30 | 31 | # First, figure out what are the inputs and what are the targets in your directory structure: 32 | # Get a list of filenames for the target (frontal) images 33 | self.frontals = np.array([join(imageset_dir, frontal_file) for frontal_file in sorted(os.listdir(imageset_dir)) if is_jpeg(frontal_file)]) 34 | 35 | # Get a list of lists of filenames for the input (profile) images for each person 36 | profile_files = [[join(person_dir, profile_file) for profile_file in sorted(os.listdir(person_dir)) if is_jpeg(profile_file)] for person_dir in get_subdirs(imageset_dir)] 37 | 38 | # Build a flat list of frontal indices, corresponding to the *flattened* profile_files 39 | # The reason we are doing it this way is that we need to keep track of the multiple inputs corresponding to each target 40 | frontal_ind = [] 41 | for ind, profiles in enumerate(profile_files): 42 | frontal_ind += [ind]*len(profiles) 43 | self.frontal_indices = np.array(frontal_ind) 44 | 45 | # Now that we have built frontal_indices, we can flatten profile_files 46 | self.profiles = np.array(flatten(profile_files)) 47 | 48 | # Shuffle the (input, target) pairs if necessary: in practice, it is profiles and frontal_indices that get shuffled 49 | if random_shuffle: 50 | ind = np.array(range(len(self.frontal_indices))) 51 | shuffle(ind) 52 | self.profiles = self.profiles[ind] 53 | self.frontal_indices = self.frontal_indices[ind] 54 | 55 | 56 | def __iter__(self): 57 | self.i = 0 58 | self.n = len(self.frontal_indices) 59 | return self 60 | 61 | 62 | # Return a batch of (input, target) pairs 63 | def __next__(self): 64 | profiles = [] 65 | frontals = [] 66 | 67 | for _ in range(self.batch_size): 68 | profile_filename = 
self.profiles[self.i] 69 | frontal_filename = self.frontals[self.frontal_indices[self.i]] 70 | 71 | profile = open(profile_filename, 'rb') 72 | frontal = open(frontal_filename, 'rb') 73 | 74 | profiles.append(np.frombuffer(profile.read(), dtype = np.uint8)) 75 | frontals.append(np.frombuffer(frontal.read(), dtype = np.uint8)) 76 | 77 | profile.close() 78 | frontal.close() 79 | 80 | self.i = (self.i + 1) % self.n 81 | return (profiles, frontals) 82 | 83 | next = __next__ 84 | 85 | 86 | class ImagePipeline(Pipeline): 87 | ''' 88 | Constructor arguments: 89 | - imageset_dir: directory containing the dataset 90 | - image_size = 128: length of the square that the images will be resized to 91 | - random_shuffle = False 92 | - batch_size = 64 93 | - num_threads = 2 94 | - device_id = 0 95 | ''' 96 | 97 | def __init__(self, imageset_dir, image_size=128, random_shuffle=False, batch_size=64, num_threads=2, device_id=0): 98 | super(ImagePipeline, self).__init__(batch_size, num_threads, device_id, seed=12) 99 | eii = ExternalInputIterator(imageset_dir, batch_size, random_shuffle) 100 | self.iterator = iter(eii) 101 | self.num_inputs = len(eii.frontal_indices) 102 | 103 | # The source for the inputs and targets 104 | self.input = ops.ExternalSource() 105 | self.target = ops.ExternalSource() 106 | 107 | # nvJPEGDecoder below accepts CPU inputs, but returns GPU outputs (hence device = "mixed") 108 | self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB) 109 | 110 | # The rest of pre-processing is done on the GPU 111 | self.res = ops.Resize(device="gpu", resize_x=image_size, resize_y=image_size) 112 | self.norm = ops.NormalizePermute(device="gpu", output_dtype=types.FLOAT, 113 | mean=[128., 128., 128.], std=[128., 128., 128.], 114 | height=image_size, width=image_size) 115 | 116 | 117 | # epoch_size = number of (profile, frontal) image pairs in the dataset 118 | def epoch_size(self, name = None): 119 | return self.num_inputs 120 | 121 | 122 | # Define the flow of the data loading and pre-processing 123 | def define_graph(self): 124 | self.profiles = self.input(name="inputs") 125 | self.frontals = self.target(name="targets") 126 | profile_images = self.decode(self.profiles) 127 | profile_images = self.res(profile_images) 128 | profile_output = self.norm(profile_images) 129 | frontal_images = self.decode(self.frontals) 130 | frontal_images = self.res(frontal_images) 131 | frontal_output = self.norm(frontal_images) 132 | return (profile_output, frontal_output) 133 | 134 | 135 | def iter_setup(self): 136 | (images, targets) = self.iterator.next() 137 | self.feed_input(self.profiles, images) 138 | self.feed_input(self.frontals, targets) 139 | 140 | -------------------------------------------------------------------------------- /figs/0_input.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/0_input.jpeg -------------------------------------------------------------------------------- /figs/0_real.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/0_real.jpeg -------------------------------------------------------------------------------- /figs/1_input.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/1_input.jpeg -------------------------------------------------------------------------------- /figs/1_real.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/1_real.jpeg -------------------------------------------------------------------------------- /figs/22epoch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/22epoch.jpeg -------------------------------------------------------------------------------- /figs/5_epochs_in.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/5_epochs_in.jpeg -------------------------------------------------------------------------------- /figs/5_epochs_out.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/5_epochs_out.jpeg -------------------------------------------------------------------------------- /figs/GAN_0epoch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/GAN_0epoch.jpeg -------------------------------------------------------------------------------- /figs/GAN_1epoch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/GAN_1epoch.jpeg -------------------------------------------------------------------------------- /figs/L2_0epoch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/L2_0epoch.jpeg -------------------------------------------------------------------------------- /figs/L2_1epoch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/L2_1epoch.jpeg -------------------------------------------------------------------------------- /figs/c_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_0.jpg -------------------------------------------------------------------------------- /figs/c_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_1.jpg -------------------------------------------------------------------------------- /figs/c_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_2.jpg -------------------------------------------------------------------------------- /figs/c_3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_3.jpg -------------------------------------------------------------------------------- /figs/c_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_4.jpg -------------------------------------------------------------------------------- /figs/c_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_5.jpg -------------------------------------------------------------------------------- /figs/c_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_6.jpg -------------------------------------------------------------------------------- /figs/c_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/c_7.jpg -------------------------------------------------------------------------------- /figs/comments.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/comments.jpeg -------------------------------------------------------------------------------- /figs/dataset.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/dataset.jpeg -------------------------------------------------------------------------------- /figs/encoder.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/encoder.jpeg -------------------------------------------------------------------------------- /figs/facenn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/facenn.jpg -------------------------------------------------------------------------------- /figs/generator.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/generator.jpg -------------------------------------------------------------------------------- /figs/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/test.jpg -------------------------------------------------------------------------------- /figs/test.txt: -------------------------------------------------------------------------------- 1 | test 2 | -------------------------------------------------------------------------------- /figs/training.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/figs/training.jpeg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | import math 4 | import random 5 | import os 6 | from os import listdir 7 | from os.path import join 8 | from PIL import Image 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.optim as optim 15 | import torchvision.utils as vutils 16 | from torch.autograd import Variable 17 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 18 | 19 | from data import ImagePipeline 20 | import network 21 | 22 | np.random.seed(42) 23 | random.seed(10) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | torch.manual_seed(999) 27 | 28 | # Where is your training dataset at? 29 | datapath = 'training_set' 30 | 31 | # You can also choose which GPU you want your model to be trained on below: 32 | gpu_id = 0 33 | device = torch.device("cuda", gpu_id) 34 | 35 | train_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=True, batch_size=30, device_id=gpu_id) 36 | train_pipe.build() 37 | m_train = train_pipe.epoch_size() 38 | print("Size of the training set: ", m_train) 39 | train_pipe_loader = DALIGenericIterator(train_pipe, ["profiles", "frontals"], m_train) 40 | 41 | # Generator: 42 | netG = network.G().to(device) 43 | netG.apply(network.weights_init) 44 | 45 | # Discriminator: 46 | netD = network.D().to(device) 47 | netD.apply(network.weights_init) 48 | 49 | # Here is where you set how important each component of the loss function is: 50 | L1_factor = 0 51 | L2_factor = 1 52 | GAN_factor = 0.0005 53 | 54 | criterion = nn.BCELoss() # Binary cross entropy loss 55 | 56 | # Optimizers for the generator and the discriminator (Adam is a fancier version of gradient descent with a few more bells and whistles that is used very often): 57 | optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999)) 58 | optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999), eps = 1e-8) 59 | 60 | # Create a directory for the output files 61 | try: 62 | os.mkdir('output') 63 | except OSError: 64 | pass 65 | 66 | start_time = time.time() 67 | 68 | # Let's train for 30 epochs (meaning, we go through the entire training set 30 times): 69 | for epoch in range(30): 70 | 71 | # Lets keep track of the loss values for each epoch: 72 | loss_L1 = 0 73 | loss_L2 = 0 74 | loss_gan = 0 75 | 76 | # Your train_pipe_loader will load the images one batch at a time 77 | # The inner loop iterates over those batches: 78 | 79 | for i, data in enumerate(train_pipe_loader, 0): 80 | 81 | # These are your images from the current batch: 82 | profile = data[0]['profiles'] 83 | frontal = data[0]['frontals'] 84 | 85 | # TRAINING THE DISCRIMINATOR 86 | netD.zero_grad() 87 | real = Variable(frontal).type('torch.FloatTensor').to(device) 88 | target = Variable(torch.ones(real.size()[0])).to(device) 89 | output = netD(real) 90 | # D should accept the GT images 91 | errD_real = criterion(output, target) 92 | 93 | profile = Variable(profile).type('torch.FloatTensor').to(device) 94 | generated = netG(profile) 95 | target = Variable(torch.zeros(real.size()[0])).to(device) 96 | output = netD(generated.detach()) # detach() because we are not training G here 97 | 
98 | # D should reject the synthetic images 99 | errD_fake = criterion(output, target) 100 | 101 | errD = errD_real + errD_fake 102 | errD.backward() 103 | # Update D 104 | optimizerD.step() 105 | 106 | # TRAINING THE GENERATOR 107 | netG.zero_grad() 108 | target = Variable(torch.ones(real.size()[0])).to(device) 109 | output = netD(generated) 110 | 111 | # G wants to : 112 | # (a) have the synthetic images be accepted by D (= look like frontal images of people) 113 | errG_GAN = criterion(output, target) 114 | 115 | # (b) have the synthetic images resemble the ground truth frontal image 116 | errG_L1 = torch.mean(torch.abs(real - generated)) 117 | errG_L2 = torch.mean(torch.pow((real - generated), 2)) 118 | 119 | errG = GAN_factor * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2 120 | 121 | loss_L1 += errG_L1.item() 122 | loss_L2 += errG_L2.item() 123 | loss_gan += errG_GAN.item() 124 | 125 | errG.backward() 126 | # Update G 127 | optimizerG.step() 128 | 129 | if epoch == 0: 130 | print('First training epoch completed in ',(time.time() - start_time),' seconds') 131 | 132 | # reset the DALI iterator 133 | train_pipe_loader.reset() 134 | 135 | # Print the absolute values of three losses to screen: 136 | print('[%d/30] Training absolute losses: L1 %.7f ; L2 %.7f BCE %.7f' % ((epoch + 1), loss_L1/m_train, loss_L2/m_train, loss_gan/m_train,)) 137 | 138 | # Save the inputs, outputs, and ground truth frontals to files: 139 | vutils.save_image(profile.data, 'output/%03d_input.jpg' % epoch, normalize=True) 140 | vutils.save_image(real.data, 'output/%03d_real.jpg' % epoch, normalize=True) 141 | vutils.save_image(generated.data, 'output/%03d_generated.jpg' % epoch, normalize=True) 142 | 143 | # Save the pre-trained Generator as well 144 | torch.save(netG,'output/netG_%d.pt' % epoch) 145 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | 11 | if classname.find('Conv') != -1: 12 | m.weight.data.normal_(0.0, 0.02) 13 | 14 | elif classname.find('BatchNorm') != -1: 15 | m.weight.data.normal_(1.0, 0.02) 16 | m.bias.data.fill_(0) 17 | 18 | 19 | ''' Generator network for 128x128 RGB images ''' 20 | class G(nn.Module): 21 | 22 | def __init__(self): 23 | super(G, self).__init__() 24 | 25 | self.main = nn.Sequential( 26 | # Input HxW = 128x128 27 | nn.Conv2d(3, 16, 4, 2, 1), # Output HxW = 64x64 28 | nn.BatchNorm2d(16), 29 | nn.ReLU(True), 30 | nn.Conv2d(16, 32, 4, 2, 1), # Output HxW = 32x32 31 | nn.BatchNorm2d(32), 32 | nn.ReLU(True), 33 | nn.Conv2d(32, 64, 4, 2, 1), # Output HxW = 16x16 34 | nn.BatchNorm2d(64), 35 | nn.ReLU(True), 36 | nn.Conv2d(64, 128, 4, 2, 1), # Output HxW = 8x8 37 | nn.BatchNorm2d(128), 38 | nn.ReLU(True), 39 | nn.Conv2d(128, 256, 4, 2, 1), # Output HxW = 4x4 40 | nn.BatchNorm2d(256), 41 | nn.ReLU(True), 42 | nn.Conv2d(256, 512, 4, 2, 1), # Output HxW = 2x2 43 | nn.MaxPool2d((2,2)), 44 | # At this point, we arrive at our low D representation vector, which is 512 dimensional. 
45 | 46 | nn.ConvTranspose2d(512, 256, 4, 1, 0, bias = False), # Output HxW = 4x4 47 | nn.BatchNorm2d(256), 48 | nn.ReLU(True), 49 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False), # Output HxW = 8x8 50 | nn.BatchNorm2d(128), 51 | nn.ReLU(True), 52 | nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False), # Output HxW = 16x16 53 | nn.BatchNorm2d(64), 54 | nn.ReLU(True), 55 | nn.ConvTranspose2d(64, 32, 4, 2, 1, bias = False), # Output HxW = 32x32 56 | nn.BatchNorm2d(32), 57 | nn.ReLU(True), 58 | nn.ConvTranspose2d(32, 16, 4, 2, 1, bias = False), # Output HxW = 64x64 59 | nn.BatchNorm2d(16), 60 | nn.ReLU(True), 61 | nn.ConvTranspose2d(16, 3, 4, 2, 1, bias = False), # Output HxW = 128x128 62 | nn.Tanh() 63 | ) 64 | 65 | 66 | def forward(self, input): 67 | output = self.main(input) 68 | return output 69 | 70 | 71 | ''' Discriminator network for 128x128 RGB images ''' 72 | class D(nn.Module): 73 | 74 | def __init__(self): 75 | super(D, self).__init__() 76 | self.main = nn.Sequential( 77 | nn.Conv2d(3, 16, 4, 2, 1), # Input HxW = 128x128, output 64x64 78 | nn.LeakyReLU(0.2, inplace = True), 79 | nn.Conv2d(16, 32, 4, 2, 1), # Output HxW = 32x32 80 | nn.BatchNorm2d(32), 81 | nn.LeakyReLU(0.2, inplace = True), 82 | nn.Conv2d(32, 64, 4, 2, 1), # Output HxW = 16x16 83 | nn.BatchNorm2d(64), 84 | nn.LeakyReLU(0.2, inplace = True), 85 | nn.Conv2d(64, 128, 4, 2, 1), # Output HxW = 8x8 86 | nn.BatchNorm2d(128), 87 | nn.LeakyReLU(0.2, inplace = True), 88 | nn.Conv2d(128, 256, 4, 2, 1), # Output HxW = 4x4 89 | nn.BatchNorm2d(256), 90 | nn.LeakyReLU(0.2, inplace = True), 91 | nn.Conv2d(256, 512, 4, 2, 1), # Output HxW = 2x2 92 | nn.BatchNorm2d(512), 93 | nn.LeakyReLU(0.2, inplace = True), 94 | nn.Conv2d(512, 1, 4, 2, 1, bias = False), # Output HxW = 1x1 95 | nn.Sigmoid() 96 | ) 97 | 98 | 99 | def forward(self, input): 100 | output = self.main(input) 101 | return output.view(-1) # flatten to one probability per image 102 | -------------------------------------------------------------------------------- /pretrained/GoT_test/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/001.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/001/varys.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/001/varys.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/002.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/002/shae.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/002/shae.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/003.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/003/bran.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/003/bran.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/004.jpg -------------------------------------------------------------------------------- /pretrained/GoT_test/004/sansa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/GoT_test/004/sansa.jpg -------------------------------------------------------------------------------- /pretrained/Readme.md: -------------------------------------------------------------------------------- 1 | The enclosed `generator_v0.pt` is the Generator network of the [Frontalization GAN](https://github.com/scaleway/frontalization/blob/master/network.py) that has been trained on the [CMU Multi-PIE Face Database](http://www.cs.cmu.edu/afs/cs/project/PIE/MultiPie/Multi-Pie/Home.html) for 18 epochs. All images of the following subjects were excluded from training and used for validation: those numbered `214` through `236`, and `346`. 2 | 3 | Hyperparameters used: the same as those listed in [`main.py`](https://github.com/scaleway/frontalization/blob/master/main.py), except for the following: 4 | ``` 5 | L1_factor = 1 6 | L2_factor = 1 7 | GAN_factor = 0.001 8 | ``` 9 | 10 | Additionally, `GAN_factor` was set to zero for all odd-numbered epochs. Alternating the GAN loss on and off in this way acted as a regularization mechanism and resulted in more realistic-looking images at the end of training (a minimal sketch of such a schedule is included at the end of this readme). 11 | 12 | The images in the Multi-PIE set are taken at fixed angles and under fixed lighting conditions. Although the training set contained around 650,000 profile-frontal image pairs, the number of unique subjects used for training was only around 300. This, together with the lack of diversity in the set, is the likely reason why the final model does not generalize particularly well beyond the test set taken from the Multi-PIE Database (which was, naturally, excluded during training). Within that test set it does, however, demonstrate reasonable performance even for large-angle inputs: 13 | 14 | ![](test-Pie.jpg) 15 | 16 | (Top row: inputs; middle row: model outputs; bottom row: ground truth images) 17 | 18 | Outside of Multi-PIE, the model performs best on images of dark-haired Caucasian and Asian individuals without facial hair (sadly, the latter places a severe limitation on running inference on Game of Thrones characters): 19 | 20 | ![](test-GoT.jpg) 21 | 22 | The figures above were generated using the code found in [`test.py`](https://github.com/scaleway/frontalization/blob/master/test.py). If using your own inference code, take care to use the same normalization procedures as in [`data.py`](https://github.com/scaleway/frontalization/blob/master/data.py) and to resize images to 128x128; a hedged preprocessing example is sketched below. Enjoy!
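Here is a minimal sketch of the alternating schedule described above, not the exact script used to train `generator_v0.pt`. The helper name `gan_factor_for_epoch` is made up for illustration, and whether "odd-numbered" refers to 0-based or 1-based epoch counting is an assumption on our part:

```
# Hedged sketch of the alternating GAN on/off schedule (assumes 0-based epoch counting).
def gan_factor_for_epoch(epoch, base_gan_factor=0.001):
    """Return the weight of the GAN term of the generator loss for a given epoch."""
    return base_gan_factor if epoch % 2 == 0 else 0.0

# Inside the training loop of main.py, the generator loss would then become:
#   errG = gan_factor_for_epoch(epoch) * errG_GAN + L1_factor * errG_L1 + L2_factor * errG_L2
if __name__ == '__main__':
    for epoch in range(4):
        print(epoch, gan_factor_for_epoch(epoch))  # 0.001, 0.0, 0.001, 0.0
```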
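And here is one way a quick single-image test could look without going through DALI. This is a sketch under stated assumptions rather than the project's reference inference code: the file name `my_profile.jpg` is a placeholder, the `[-1, 1]` normalization is assumed here only because the Generator ends in a `Tanh` layer (verify it against `data.py`), and loading the checkpoint requires `network.py` to be importable, since `generator_v0.pt` stores a reference to the `G` class:

```
# Hedged single-image inference sketch -- verify the normalization against data.py.
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

preprocess = transforms.Compose([
    transforms.Resize((128, 128)),                            # match the network's input size
    transforms.ToTensor(),                                    # [0, 1] range
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),   # assumed [-1, 1] scaling
])

# Run this from a directory where network.py is importable; on recent PyTorch versions
# you may also need torch.load(..., weights_only=False) to unpickle the full module.
netG = torch.load('generator_v0.pt', map_location=device)
netG.eval()  # use BatchNorm running statistics; test.py keeps the default mode instead

profile = preprocess(Image.open('my_profile.jpg').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    frontal = netG(profile)

# normalize=True rescales the Tanh output back to a viewable range
vutils.save_image(frontal, 'my_frontal.jpg', normalize=True)
```

If the output looks washed out or distorted, the most likely culprits are the assumed normalization or the `eval()` call, so try matching `data.py` and `test.py` exactly.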
23 | 24 | ![](faces.gif) 25 | -------------------------------------------------------------------------------- /pretrained/faces.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/faces.gif -------------------------------------------------------------------------------- /pretrained/generator_v0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/generator_v0.pt -------------------------------------------------------------------------------- /pretrained/test-GoT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/test-GoT.jpg -------------------------------------------------------------------------------- /pretrained/test-Pie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scaleway/frontalization/d7bb9151fd5e4ef617f261fb4ceae6e7c89c185c/pretrained/test-Pie.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.utils as vutils 3 | from torch.autograd import Variable 4 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 5 | 6 | from data import ImagePipeline 7 | 8 | device = 'cuda' 9 | 10 | datapath = 'test_set' 11 | 12 | # Generate frontal images from the test set 13 | def frontalize(model, datapath, mtest): 14 | 15 | test_pipe = ImagePipeline(datapath, image_size=128, random_shuffle=False, batch_size=mtest) 16 | test_pipe.build() 17 | test_pipe_loader = DALIGenericIterator(test_pipe, ["profiles", "frontals"], mtest) 18 | 19 | with torch.no_grad(): 20 | for data in test_pipe_loader: 21 | profile = data[0]['profiles'] 22 | frontal = data[0]['frontals'] 23 | generated = model(Variable(profile.type('torch.FloatTensor').to(device))) 24 | vutils.save_image(torch.cat((profile, generated.data, frontal)), 'output/test.jpg', nrow=mtest, padding=2, normalize=True) 25 | return 26 | 27 | # Load a pre-trained Pytorch model 28 | saved_model = torch.load("./output/netG_1.pt") 29 | 30 | frontalize(saved_model, datapath, 3) 31 | 32 | --------------------------------------------------------------------------------