├── AdaIN.py
├── README.md
├── architecture.jpg
├── model.py
├── style.py
├── tmp
│   └── PROGRESS IMAGES ARE OUTPUT HERE
├── train.py
├── train
│   ├── content
│   │   └── content.jpg
│   └── style
│       └── style.jpg
└── utils.py

/AdaIN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class AdaIN(nn.Module):
    def __init__(self):
        super().__init__()

    def mu(self, x):
        """ Takes a (n,c,h,w) tensor as input and returns the average across
        its spatial dimensions as an (n,c) tensor [See eq. 5 of paper]"""
        return torch.sum(x,(2,3))/(x.shape[2]*x.shape[3])

    def sigma(self, x):
        """ Takes a (n,c,h,w) tensor as input and returns the standard deviation
        across its spatial dimensions as an (n,c) tensor [See eq. 6 of paper] Note
        the permutations are required for broadcasting and a small constant is
        added for numerical stability"""
        return torch.sqrt((torch.sum((x.permute([2,3,0,1])-self.mu(x)).permute([2,3,0,1])**2,(2,3))+0.000000023)/(x.shape[2]*x.shape[3]))

    def forward(self, x, y):
        """ Takes a content embedding x and a style embedding y and transforms
        the mean and standard deviation of the content embedding into those of
        the style embedding [See eq. 8 of paper] Note the permutations are
        required for broadcasting"""
        return (self.sigma(y)*((x.permute([2,3,0,1])-self.mu(x))/self.sigma(x)) + self.mu(y)).permute([2,3,0,1])
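
if __name__ == "__main__":
    # Minimal sanity check, not part of the original training or inference pipeline:
    # after AdaIN, the per-channel mean and standard deviation of the output should
    # match those of the style input y (eq. 8 of the paper). The tensor shapes below
    # are arbitrary illustrative values, not ones used anywhere else in the repo.
    torch.manual_seed(0)
    ada = AdaIN()
    x = torch.randn(2, 512, 32, 32)          # stand-in content embedding (n,c,h,w)
    y = torch.randn(2, 512, 32, 32) * 2 + 1  # stand-in style embedding with shifted statistics
    out = ada(x, y)
    print(torch.allclose(ada.mu(out), ada.mu(y), atol=1e-4))        # expect True
    print(torch.allclose(ada.sigma(out), ada.sigma(y), atol=1e-3))  # expect True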
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pytorch-Adaptive-Instance-Normalization

A PyTorch implementation of the 2017 Huang et al. paper "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" [https://arxiv.org/abs/1703.06868](https://arxiv.org/abs/1703.06868).
Written from scratch with essentially no reference to Xun Huang's implementation in Lua/Torch (which can be found here: [https://github.com/xunhuang1995/AdaIN-style](https://github.com/xunhuang1995/AdaIN-style)), but I'm nonetheless incredibly grateful to Huang et al. for writing such an outstandingly beautiful paper and making their method so clear and easy to implement!
![Architecture](./architecture.jpg)

## Requirements

To run this model please install the latest versions of PyTorch, torchvision and CUDA.

## Loading Pretrained Weights

I have made a set of pretrained weights available on Google Drive if you don't want to train the model yourself. You can find them here: [https://drive.google.com/file/d/1094pChApSOA7qJZn68kEdNxKIwPWRdHn/view?usp=sharing](https://drive.google.com/file/d/1094pChApSOA7qJZn68kEdNxKIwPWRdHn/view?usp=sharing).
Once downloaded, just place the file in the root directory of the repo and you're good to go.

## Usage

To use the model for style transfer, run `python style.py <path to content image> <path to style image>`.
The styled image will be saved as `output.jpg` in the current directory.
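
For example, using the sample images bundled with this repo (this assumes the downloaded pretrained weights file `adain_model` is already in the repo root):

```
python style.py train/content/content.jpg train/style/style.jpg
```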

## Training The Model

To train the model from scratch, first download the datasets you want to use. The paper uses this [https://www.kaggle.com/c/painter-by-numbers/data](https://www.kaggle.com/c/painter-by-numbers/data) Kaggle dataset of WikiArt images as its source of style images and the MS-COCO Common Objects in Context dataset [https://cocodataset.org/](https://cocodataset.org/) for its content images.
After you've downloaded the datasets (or a subset of them, as they are both pretty large, tens of GB), place the style images in the `train/style` directory and the content images in the `train/content` directory.

To actually train the model, just run `python -i train.py`, which will start training and output previews of its progress into the `tmp` directory every few iterations.
Every epoch the model will be saved to a file called `adain_model`.

## To Do
* Add automatic GPU/CPU selection
* Add explanatory text to the loss printout
* Implement bias correction on the moving-average loss
* Update the default hyperparameters to match those of Huang
* Train the model for longer and upload better pretrained weights
* Add command line options for hyperparameters
* Make a `requirements.txt` file
* Add the more advanced runtime style interpolation and masking features from the paper
* Add some examples to this readme
--------------------------------------------------------------------------------
/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CellEight/Pytorch-Adaptive-Instance-Normalization/c7b3849b6715ff5568658288a960a6cf0ee82c7d/architecture.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models
from AdaIN import AdaIN

class AdaINStyle(nn.Module):
    def __init__(self):
        super().__init__()
        # For the encoder, load a pretrained VGG19 model and keep only the layers up to relu4_1
        self.vgg = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)
        self.vgg = nn.Sequential(*list(self.vgg.features.children())[:21])
        # Create AdaIN layer
        self.ada = AdaIN()
        # Use Sequential to define the decoder [just the reverse of vgg with pooling replaced by nearest-neighbour upscaling]
        self.dec = nn.Sequential(nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU(),
                                 nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
                                 nn.ReLU()  # Maybe change to a sigmoid to get into the (0,1) range?
                                )

    def forward(self, c, s):
        """ c is an image containing content information, s is an image
        containing style information"""
        # Compute content and style embeddings
        self.c_emb = self.vgg(c)
        self.s_emb = self.vgg(s)
        # Use the AdaIN layer to match the mean and variance of c_emb (content) to those of s_emb (style)
        self.t = self.ada(self.c_emb, self.s_emb)
        return self.dec(self.t)
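
if __name__ == "__main__":
    # Illustrative shape check, not part of the original training or inference scripts
    # (running it downloads the pretrained VGG19 weights via torch.hub on first use).
    # The encoder is VGG19 truncated at relu4_1, so a 224x224 input gives a
    # (512, 28, 28) embedding, and the decoder's three 2x nearest-neighbour
    # upsamplings bring the output back to 224x224.
    model = AdaINStyle()
    c = torch.randn(1, 3, 224, 224)  # stand-in content image batch
    s = torch.randn(1, 3, 224, 224)  # stand-in style image batch
    out = model(c, s)
    print(model.c_emb.shape)  # expected: torch.Size([1, 512, 28, 28])
    print(out.shape)          # expected: torch.Size([1, 3, 224, 224])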
--------------------------------------------------------------------------------
/style.py:
--------------------------------------------------------------------------------
import sys
import torch
from torchvision.utils import save_image
from PIL import Image
from torchvision import transforms

def loadImage(filename):
    input_image = Image.open(filename).convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    return input_tensor.unsqueeze(0)

if __name__ == "__main__":
    # Declare variables to store hooks; even though they aren't necessary for
    # inference, you need them to load the pickled model from file
    style_layers = ['1','6','10','20']
    debug_layers = [0,3,5,7]
    activations = [None]*4
    debug_activations = [None]*4
    debug_grads = [None]*4
    # Declare hook functions
    def styleHook(i, module, input, output):
        global activations
        activations[i] = output

    def debugHook(i, module, input, output):
        global debug_activations
        debug_activations[i] = output

    content_img = loadImage(sys.argv[1])
    style_img = loadImage(sys.argv[2])

    model = torch.load('adain_model')

    output_img = model(content_img, style_img)
    save_image(output_img, 'output.jpg')
--------------------------------------------------------------------------------
/tmp/PROGRESS IMAGES ARE OUTPUT HERE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CellEight/Pytorch-Adaptive-Instance-Normalization/c7b3849b6715ff5568658288a960a6cf0ee82c7d/tmp/PROGRESS IMAGES ARE OUTPUT HERE
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# Sort out the imports!
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from functools import partial
from copy import deepcopy
from model import *
from utils import *


class AdaINLoss(nn.Module):
    def __init__(self, lambda_):
        super().__init__()
        self.lambda_ = lambda_
        # lambda_ is a hyperparameter that dictates the relative importance of content vs style:
        # the greater lambda_ is, the more the model will try to preserve style;
        # the smaller lambda_ is, the more the model will try to preserve content
        # [See equation 11 of the Paper]

    def contentLoss(self, content_emb, output_emb):
        """ Takes 2 embedding tensors generated by vgg and finds the L2 norm
        (i.e. Euclidean distance) between them. [See equation 12 of the Paper]"""
        return torch.norm(content_emb-output_emb)

    def styleLoss(self, style_activations, output_activations):
        """ Takes 2 lists of activation tensors hooked from vgg layers during
        forward passes using our style image and our output image as inputs.
        Computes the L2 norm between each of their means and standard deviations
        and returns the sum. [See equation 13 of the Paper]"""
        mu_sum = 0
        sigma_sum = 0
        for style_act, output_act in zip(style_activations, output_activations):
            mu_sum += torch.norm(mu(style_act)-mu(output_act))
            sigma_sum += torch.norm(sigma(style_act)-sigma(output_act))
        return mu_sum + sigma_sum

    def totalLoss(self, content_emb, output_emb, style_activations, output_activations):
        """ Calculates the overall loss. [See equation 11 of the Paper]"""
        content_loss = self.contentLoss(content_emb, output_emb)
        style_loss = self.styleLoss(style_activations, output_activations)
        #print(content_loss.item(), style_loss.item())
        return content_loss+self.lambda_*style_loss

    def forward(self, content_emb, output_emb, style_activations, output_activations):
        """ For calculating single image loss please pass arguments with a batch size of 1. """
        return self.totalLoss(content_emb, output_emb, style_activations, output_activations)/content_emb.shape[0]
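
def _loss_smoke_test():
    # Illustrative helper, a hypothetical addition that is never called by the
    # training loop below: it shows the calling convention of AdaINLoss. With
    # identical "output" and target inputs, both the content and style terms are
    # zero, so this should print 0.0. The tensor shapes are stand-in values.
    criterion = AdaINLoss(7.5)
    emb = torch.randn(1, 512, 28, 28)                          # stand-in relu4_1 embedding
    acts = [torch.randn(1, 64, 224, 224) for _ in range(4)]    # stand-in hooked activations
    print(criterion(emb, emb, acts, acts).item())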


if __name__ == "__main__":
    # Set hyperparameters
    bs = 1
    epochs = 5
    lr = 6e-4
    wd = 0.001
    lambda_ = 7.5
    alpha = 0.1
    style_layers = ['1','6','10','20']
    debug_layers = [0,3,5,7]
    # Load content and style datasets
    content_images = DataLoader(ImageDataset('./train/content'), batch_size=bs, shuffle=True, num_workers=0)
    style_images = DataLoader(ImageDataset('./train/style'), batch_size=bs, shuffle=True, num_workers=0)
    # Set cost function
    criterion = AdaINLoss(lambda_)
    # Load model and freeze the encoder so only the decoder is trained
    model = AdaINStyle()
    for p in model.vgg.parameters():
        p.requires_grad = False
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    # declare variables to store hooks
    activations = [None]*4
    debug_activations = [None]*4
    debug_grads = [None]*4
    debug_grads_old = [None]*4
    # declare hook functions
    def styleHook(i, module, input, output):
        global activations
        activations[i] = output

    def debugHook(i, module, input, output):
        global debug_activations
        debug_activations[i] = output

    # establish hooks in vgg (for the style loss) and in the decoder (for debugging)
    for i, layer in enumerate(style_layers):
        model.vgg._modules[layer].register_forward_hook(partial(styleHook,i))
    for i, layer in enumerate(debug_layers):
        model.dec._modules[str(layer)].register_forward_hook(partial(debugHook,i))


    for epoch in range(epochs):
        i = 0
        running_loss = 0
        for content_batch, style_batch in zip(content_images, style_images):
            i += 1
            optimizer.zero_grad()
            output = model(content_batch, style_batch)
            content_emb = model.t
            style_activations = deepcopy(activations)
            output_emb = model.vgg(output)
            output_activations = activations
            #print(output, content_emb)
            loss = criterion(content_emb, output_emb, style_activations, output_activations)
            if torch.isnan(loss).item():
                print("Got nan, here's the activations:")
                print(debug_activations)
                print(debug_grads_old)
                quit()
            loss.backward()
            debug_grads_old = deepcopy(debug_grads)
            debug_grads = [model.dec[layer].weight.grad for layer in debug_layers]
            optimizer.step()
            running_loss = alpha*loss.item() + (1-alpha)*running_loss
            print(running_loss)
            #for layer in model:
            #    print(layer.grad)
            save_image(output, f"./tmp/{epoch}_{i}_o.png", 2)
            save_image(torch.sigmoid(content_batch), f"./tmp/{epoch}_{i}_c.png", 2)
            save_image(torch.sigmoid(style_batch), f"./tmp/{epoch}_{i}_s.png", 2)
            #save_image(torch.sigmoid(output), f"./tmp/o_{epoch}_{i}.png")
        print(f'Epoch: [{epoch}/{epochs}], Loss: {loss}')
        torch.save(model, 'adain_model')
--------------------------------------------------------------------------------
/train/content/content.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CellEight/Pytorch-Adaptive-Instance-Normalization/c7b3849b6715ff5568658288a960a6cf0ee82c7d/train/content/content.jpg
--------------------------------------------------------------------------------
/train/style/style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CellEight/Pytorch-Adaptive-Instance-Normalization/c7b3849b6715ff5568658288a960a6cf0ee82c7d/train/style/style.jpg
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import glob
import torch

class ImageDataset(Dataset):
    def __init__(self, path):
        super().__init__()
        self.files = glob.glob(path+'/*.jpg')

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """ Load an input image from filename and convert it to a format vgg will accept.
        Taken essentially unedited from the docs: https://pytorch.org/hub/pytorch_vision_vgg/"""
        input_image = Image.open(self.files[idx]).convert("RGB")
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(input_image)
        return input_tensor

def mu(x):
    """ Takes a (n,c,h,w) tensor as input and returns the average across
    its spatial dimensions as an (n,c) tensor [See eq. 5 of paper]"""
    return torch.sum(x,(2,3))/(x.shape[2]*x.shape[3])

def sigma(x):
    """ Takes a (n,c,h,w) tensor as input and returns the standard deviation
    across its spatial dimensions as an (n,c) tensor [See eq. 6 of paper] Note
    the permutations are required for broadcasting"""
    return torch.sqrt(torch.sum((x.permute([2,3,0,1])-mu(x)).permute([2,3,0,1])**2,(2,3))/(x.shape[2]*x.shape[3]))
--------------------------------------------------------------------------------