├── Images ├── NewYork_small.jpg ├── Out_Style1_small.jpg ├── Out_Style2_small.jpg ├── Out_Style3_small.jpg ├── Out_Style4_small.jpg ├── Out_Style5_small.jpg ├── style1_small.jpg ├── style2_small.jpg ├── style3_small.jpg ├── style4_small.jpg └── style5_small.jpg ├── LICENSE ├── README.md ├── Style Transfer Test.ipynb ├── Style Transfer Train.ipynb ├── Utils ├── __init__.py ├── networks.py └── utils.py └── examples └── architecture.jpg /Images/NewYork_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/NewYork_small.jpg -------------------------------------------------------------------------------- /Images/Out_Style1_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style1_small.jpg -------------------------------------------------------------------------------- /Images/Out_Style2_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style2_small.jpg -------------------------------------------------------------------------------- /Images/Out_Style3_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style3_small.jpg -------------------------------------------------------------------------------- /Images/Out_Style4_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style4_small.jpg -------------------------------------------------------------------------------- /Images/Out_Style5_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style5_small.jpg -------------------------------------------------------------------------------- /Images/style1_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style1_small.jpg -------------------------------------------------------------------------------- /Images/style2_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style2_small.jpg -------------------------------------------------------------------------------- /Images/style3_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style3_small.jpg -------------------------------------------------------------------------------- /Images/style4_small.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style4_small.jpg -------------------------------------------------------------------------------- /Images/style5_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style5_small.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 MAlberts99 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-AdaIN-StyleTransfer 2 | This project is an unofficial PyTorch implementation, built to run on Google Colab, of the paper: [**Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization**](https://arxiv.org/abs/1703.06868) 3 | 4 | All credit goes to: [Xun Huang](http://www.cs.cornell.edu/~xhuang/) and 5 | [Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/) 6 | 7 | 8 | ## Description 9 | The paper presents a new style transfer algorithm, which uses a fixed pretrained vgg19 (up to ReLU 4.1) to encode a style image and a content image; the style of the style image is then transferred to the content image. The novel part of the paper is the AdaIN layer used to transfer the style. This layer first normalises the encoded content image to zero mean and unit standard deviation. The content features are then scaled and shifted so that their mean and standard deviation match those of the encoded style image. Finally, the result is decoded back into an image using a decoder that mirrors the vgg19. 10 |
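In code, the AdaIN step described above boils down to only a few lines. The sketch below is purely illustrative (the implementation actually used by the notebooks lives in `Utils/utils.py` as `calc_mean_std` and `AdaIn`); it assumes the inputs are encoder feature maps of shape `(batch, channels, height, width)`:

```python
import torch

def adain(content_feat, style_feat, eps=1e-5):
    # Flatten the spatial dimensions so statistics are computed per channel
    b, c = content_feat.shape[:2]
    c_flat = content_feat.view(b, c, -1)
    s_flat = style_feat.view(b, c, -1)

    c_mean = c_flat.mean(dim=2, keepdim=True)
    c_std = (c_flat.var(dim=2, keepdim=True) + eps).sqrt()
    s_mean = s_flat.mean(dim=2, keepdim=True)
    s_std = (s_flat.var(dim=2, keepdim=True) + eps).sqrt()

    # Normalise the content features, then scale/shift them with the style statistics
    out = s_std * (c_flat - c_mean) / c_std + s_mean
    return out.view_as(content_feat)
```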

11 | 12 |

13 | 14 | ## Requirements 15 | - A Google Drive account to run the notebooks. 16 | - A pretrained vgg19 `.pth` file. I used the file provided by [Naoto Inoue](https://github.com/naoto0804/pytorch-AdaIN) in his implementation of the same paper. Link: [vgg_normalised.pth](https://drive.google.com/file/d/108uza-dsmwvbW2zv-G73jtVcMU_2Nb7Y/view). 17 | 18 | 19 | To train: 20 | - [2015 Coco Image Dataset, 13GB](http://images.cocodataset.org/zips/test2015.zip) 21 | - [WikiArt Dataset, 25.4GB](http://web.fsktm.um.edu.my/~cschan/source/ICIP2017/wikiart.zip) 22 | 23 | A note on the datasets: the free version of Google Colab only has about 30GB of usable storage while the GPU is attached, so you may have to reduce the size of the datasets. In this implementation I used 40k images from each dataset. 24 | 25 | ## Trained Model 26 | You can download my model from [here](https://drive.google.com/file/d/1-96gmMVd1wYbP0WNM7KMF2gCIqtQ6ZtA/view?usp=sharing). It was trained for 120,000 iterations and provides image quality close to the official implementation. The style weight (gamma) used was 2.0. 27 | 28 | ## Manual 29 | - Copy the contents of this repository into a Google Drive folder. Then download the pretrained vgg19 file and place it in the same folder. If you just want to play around with the network, add the pretrained model file as well. If you want to train the network from scratch, e.g. to change the style weight, download the datasets as well. 30 | #### Inference 31 | - Open the Style Transfer Test notebook. In the first cell you need to specify the directory of the folder in your Google Drive. Additionally, if you changed the image folder you also need to change the `img_dir` variable accordingly. 32 | - The next cell will load the network. 33 | - Then the images are loaded. Here you can choose your images. 34 | - In the next cell you can change the `alpha` variable. This variable controls how strongly the style image affects the content image. A minimal stand-alone inference example is sketched at the end of this section. 35 | #### Training 36 | - Open the Style Transfer Train notebook. In the first cell you need to specify the directory of the folder in your Google Drive. 37 | - Then you will have to download/import the datasets into the Colab instance. I have not implemented this step because, depending on your storage, you will need to reduce the number of images used from each dataset; one possible approach is sketched below. 38 | - Change `pathStyleImages` and `pathContentImages` to the folders containing the images. Note that each folder must contain only images; nested folders are not supported. 39 | - Then run the rest of the cells.
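The dataset-reduction step mentioned in the training instructions above is left to the user. One possible way to do it, assuming the two zip files have already been downloaded into the Colab instance (the paths and image count below are only examples, chosen to match the defaults in the training notebook), is to extract a random subset of images directly from each archive into a flat folder:

```python
import os
import random
import shutil
import zipfile

def extract_subset(zip_path, out_dir, n_images, seed=999):
    """Extract a random subset of the images in zip_path into the flat folder out_dir."""
    os.makedirs(out_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        names = [n for n in archive.namelist()
                 if n.lower().endswith((".jpg", ".jpeg", ".png"))]
        random.seed(seed)
        for name in random.sample(names, min(n_images, len(names))):
            # Flatten any nested folders: the notebook's Dataset class expects a single flat folder
            target = os.path.join(out_dir, os.path.basename(name))
            with archive.open(name) as src, open(target, "wb") as dst:
                shutil.copyfileobj(src, dst)

extract_subset("/content/test2015.zip", "/content/Data/Coco_40k", 40000)
extract_subset("/content/wikiart.zip", "/content/Data/Wiki_40k", 40000)
```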
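For reference, this is roughly what inference looks like outside the notebooks. It is only a sketch: the image and checkpoint paths are placeholders, and the calls simply mirror the cells of the Style Transfer Test notebook and the `StyleTransferNetwork` class in `Utils/networks.py`:

```python
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from Utils import networks

device = "cuda" if torch.cuda.is_available() else "cpu"
transform = transforms.Compose([transforms.Resize(512),
                                transforms.CenterCrop(256),
                                transforms.ToTensor()])

# Placeholder paths -- adjust to wherever the files live in your Drive folder
vgg_state = torch.load("vgg_normalised.pth")
network = networks.StyleTransferNetwork(device, vgg_state,
                                        train=False,
                                        load_fromstate=True,
                                        load_path="StyleTransfer Checkpoint Iter: 120000.tar")

style = transform(Image.open("style.jpg").convert("RGB")).unsqueeze(0).to(device)
content = transform(Image.open("content.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    out = network(style, content, alpha=1.0)  # alpha in [0, 1] controls the style strength

torchvision.utils.save_image(out, "stylised.jpg")
```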
40 | 41 | ## Results 42 | | Style | Generated Image | 43 | | :----: | :----: | 44 | |![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/style1_small.jpg)|![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/Out_Style1_small.jpg)| 45 | |![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/style2_small.jpg)|![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/Out_Style2_small.jpg)| 46 | |![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/style3_small.jpg)|![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/Out_Style3_small.jpg)| 47 | |![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/style4_small.jpg)|![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/Out_Style4_small.jpg)| 48 | |![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/style5_small.jpg)|![](https://github.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/blob/master/Images/Out_Style5_small.jpg)| 49 | 50 | As can be seen above, the results are not quite as good as those presented in the paper. This is largely a matter of training budget: the model in the paper was trained for 160,000 iterations, whereas I only trained mine for 120,000. Additionally, the original model was trained on 80,000 images of each type, whereas I only trained on 40,000. 51 | 52 | -------------------------------------------------------------------------------- /Style Transfer Train.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Style Transfer Train.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"0HDAX91xcOA2","colab_type":"code","colab":{}},"source":["import torch\n","import torchvision\n","import torchvision.transforms as transforms\n","from torch.utils.data import DataLoader, Dataset\n","from torch.utils.tensorboard import SummaryWriter\n","from PIL import Image\n","\n","from tqdm.notebook import tqdm\n","from google.colab import drive\n","import numpy as np\n","import os\n","import sys\n","\n","manualSeed = 999\n","torch.manual_seed(manualSeed)\n","\n","drive.mount(\"/content/gdrive\")\n","path = \"/content/gdrive/My Drive/Colab Notebooks/GANS/Style Transfer\"\n","\n","\n","sys.path.append(path)\n","from Utils import networks"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"LBxpcfaNybWC","colab_type":"code","colab":{}},"source":["# Images Dataset that returns one style and one content image. As I only trained using 40,000\n","# images each, each image is randomly sampled. The way it is implemented does not allow multi-threading. 
However\n","# as this network is relatively small and training times low, no improved class was implemented.\n","\n","class Images(Dataset): \n"," def __init__(self, root_dir1, root_dir2, transform=None):\n"," self.root_dir1 = root_dir1\n"," self.root_dir2 = root_dir2\n"," self.transform = transform\n","\n"," def __len__(self):\n"," return min(len(os.listdir(self.root_dir1)), len(os.listdir(self.root_dir2)))\n","\n"," def __getitem__(self, idx):\n"," all_names1, all_names2 = os.listdir(self.root_dir1), os.listdir(self.root_dir2)\n"," idx1, idx2 = np.random.randint(0, len(all_names1)), np.random.randint(0, len(all_names2))\n","\n"," img_name1, img_name2 = os.path.join(self.root_dir1, all_names1[idx1]), os.path.join(self.root_dir2, all_names2[idx2])\n"," image1 = Image.open(img_name1).convert(\"RGB\")\n"," image2 = Image.open(img_name2).convert(\"RGB\")\n","\n"," if self.transform:\n"," image1 = self.transform(image1)\n"," image2 = self.transform(image2)\n","\n"," return image1, image2 "],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MIVw4MQo6WKu","colab_type":"code","colab":{}},"source":["# To note is that the images are not normalised\n","transform = transforms.Compose([transforms.Resize(512), \n"," transforms.CenterCrop(256),\n"," transforms.ToTensor()])\n","\n","\n","# Specify the path to the style and content images\n","pathStyleImages = \"/content/Data/Wiki_40k\"\n","pathContentImages = \"/content/Data/Coco_40k\" \n","\n","\n","all_img = Images(pathStyleImages, pathContentImages, transform=transform)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"bGd2_znLxLIy","colab_type":"code","colab":{}},"source":["# Simple save \n","def save_state(decoder, optimiser, iters, run_dir):\n"," \n"," name = \"StyleTransfer Checkpoint Iter: {}.tar\".format(iters)\n"," torch.save({\"Decoder\" : decoder,\n"," \"Optimiser\" : optimiser,\n"," \"iters\": iters\n"," }, os.path.join(path, name))\n"," print(\"Saved : {} succesfully\".format(name))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"nNf9XPp2HgCO","colab_type":"code","colab":{},"executionInfo":{"status":"ok","timestamp":1600861736262,"user_tz":-120,"elapsed":825,"user":{"displayName":"Marvin Alberts","photoUrl":"","userId":"07281461137485771510"}}},"source":["def training_loop(network, # StyleTransferNetwork\n"," dataloader_comb, # DataLoader\n"," n_epochs, # Number of Epochs\n"," run_dir # Directory in which the checkpoints and tensorboard files are saved\n"," ):\n"," \n","\n"," writer = SummaryWriter(os.path.join(path, run_dir))\n"," # Fixed images to compare over time\n"," fixed_batch_style, fixed_batch_content = all_img[0]\n"," fixed_batch_style, fixed_batch_content = fixed_batch_style.unsqueeze(0).to(device), fixed_batch_content.unsqueeze(0).to(device) # Move images to device\n","\n"," writer.add_image(\"Style\", torchvision.utils.make_grid(fixed_batch_style))\n"," writer.add_image(\"Content\", torchvision.utils.make_grid(fixed_batch_content))\n","\n"," iters = network.iters\n","\n"," for epoch in range(1, n_epochs+1):\n"," tqdm_object = tqdm(dataloader_comb, total=len(dataloader_comb))\n","\n"," for style_imgs, content_imgs in tqdm_object:\n"," network.adjust_learning_rate(network.optimiser, iters)\n"," style_imgs = style_imgs.to(device)\n"," content_imgs = content_imgs.to(device)\n","\n"," loss_comb, content_loss, style_loss = network(style_imgs, content_imgs)\n","\n"," network.optimiser.zero_grad()\n"," loss_comb.backward()\n"," 
network.optimiser.step()\n","\n"," # Update status bar, add Loss, add Images\n"," tqdm_object.set_postfix_str(\"Combined Loss: {:.3f}, Style Loss: {:.3f}, Content Loss: {:.3f}\".format(\n"," loss_comb.item()*100, style_loss.item()*100, content_loss.item()*100))\n"," \n"," if iters % 25 == 0:\n"," writer.add_scalar(\"Combined Loss\", loss_comb*1000, iters)\n"," writer.add_scalar(\"Style Loss\", style_loss*1000, iters)\n"," writer.add_scalar(\"Content Loss\", content_loss*1000, iters)\n","\n"," if (iters+1) % 2000 == 1:\n"," with torch.no_grad():\n"," network.set_train(False)\n"," images = network(fixed_batch_style, fixed_batch_content)\n"," img_grid = torchvision.utils.make_grid(images)\n"," writer.add_image(\"Progress Iter: {}\".format(iters), img_grid)\n"," network.set_train(True)\n","\n"," if (iters+1) % 4000 == 1:\n"," save_state(network.decoder.state_dict(), network.optimiser.state_dict(), iters, run_dir)\n"," writer.close()\n"," writer = SummaryWriter(os.path.join(path, run_dir))\n","\n"," iters += 1"],"execution_count":1,"outputs":[]},{"cell_type":"code","metadata":{"id":"iKg_C0K6leFp","colab_type":"code","colab":{}},"source":["device = (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","learning_rate = 1e-4\n","learning_rate_decay = 5e-5\n","\n","dataloader_comb = DataLoader(all_img, batch_size=5, shuffle=True, num_workers=0, drop_last=True)\n","gamma = torch.tensor([2]).to(device) # Style weight\n","\n","n_epochs = 5\n","run_dir = \"runs/Run 1\" # Change if you want to save the checkpoints/tensorboard files in a different directory\n","\n","state_encoder = torch.load(os.path.join(path, \"vgg_normalised.pth\"))\n","network = networks.StyleTransferNetwork(device,\n"," state_encoder,\n"," learning_rate,\n"," learning_rate_decay,\n"," gamma,\n"," load_fromstate=False,\n"," load_path=os.path.join(path, \"StyleTransfer Checkpoint Iter: 120000.tar\"))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5_huZgs0ClUJ","colab_type":"code","colab":{}},"source":["training_loop(network, dataloader_comb, n_epochs, run_dir)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4m-uIZxqXeoD","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /Utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["utils", "networks"] 2 | -------------------------------------------------------------------------------- /Utils/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import os 6 | from .utils import * 7 | 8 | # The style transfer network 9 | class StyleTransferNetwork(nn.Module): 10 | def __init__(self, 11 | device, # "cpu" for cpu, "cuda" for gpu 12 | enc_state_dict, # The state dict of the pretrained vgg19 13 | learning_rate=1e-4, 14 | learning_rate_decay=5e-5, # Decay parameter for the learning rate 15 | gamma=2.0, # Controls importance of StyleLoss vs ContentLoss, Loss = gamma*StyleLoss + ContentLoss 16 | train=True, # Wether or not network is training 17 | load_fromstate=False, # Load from checkpoint? 
18 | load_path=None # Path to load checkpoint 19 | ): 20 | super().__init__() 21 | 22 | assert device in ["cpu", "cuda"] 23 | if load_fromstate and not os.path.isfile(load_path): 24 | raise ValueError("Checkpoint file not found") 25 | 26 | 27 | self.learning_rate = learning_rate 28 | self.learning_rate_decay = learning_rate_decay 29 | self.train = train 30 | self.gamma = gamma 31 | 32 | self.encoder = Encoder(enc_state_dict, device) # A pretrained vgg19 is used as the encoder 33 | self.decoder = Decoder().to(device) 34 | 35 | self.optimiser = optim.Adam(self.decoder.parameters(), lr=self.learning_rate) 36 | self.iters = 0 37 | 38 | if load_fromstate: 39 | state = torch.load(load_path) 40 | self.decoder.load_state_dict(state["Decoder"]) 41 | self.optimiser.load_state_dict(state["Optimiser"]) 42 | self.iters = state["iters"] 43 | 44 | 45 | def set_train(self, boolean): # Change state of network 46 | assert type(boolean) == bool 47 | self.train = boolean 48 | 49 | def adjust_learning_rate(self, optimiser, iters): # Simple learning rate decay 50 | lr = self.learning_rate / (1.0 + self.learning_rate_decay * iters) 51 | for param_group in optimiser.param_groups: 52 | param_group['lr'] = lr 53 | 54 | def forward(self, style, content, alpha=1.0): # Alpha can be used while testing to control the importance of the transferred style 55 | 56 | # Encode style and content 57 | layers_style = self.encoder(style, self.train) # if train: returns all states 58 | layer_content = self.encoder(content, False) # for the content only the last layer is important 59 | 60 | # Transfer Style 61 | if self.train: 62 | style_applied = AdaIn(layer_content, layers_style[-1]) # Last layer is "style" layer 63 | else: 64 | style_applied = alpha*AdaIn(layer_content, layers_style) + (1-alpha)*layer_content # Alpha controls magnitude of style 65 | 66 | # Scale up 67 | style_applied_upscaled = self.decoder(style_applied) 68 | if not self.train: 69 | return style_applied_upscaled # When not training return transformed image 70 | 71 | # Compute Loss 72 | layers_style_applied = self.encoder(style_applied_upscaled, self.train) 73 | 74 | content_loss = Content_loss(layers_style_applied[-1], layer_content) 75 | style_loss = Style_loss(layers_style_applied, layers_style) 76 | 77 | loss_comb = content_loss + self.gamma*style_loss 78 | 79 | return loss_comb, content_loss, style_loss 80 | 81 | # The decoder is a reversed vgg19 up to ReLU 4.1. Note that the last layer is not activated. 
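# Shape sketch (added note, assuming the 256x256 crops used in the training notebook): the encoder's
# ReLU 4.1 output is a 512 x 32 x 32 feature map, and each of the three Upsample calls below doubles
# the spatial resolution, so the decoder maps 512 x 32 x 32 back to a 3 x 256 x 256 image.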
82 | 83 | class Decoder(nn.Module): 84 | def __init__(self): 85 | super().__init__() 86 | 87 | self.padding = nn.ReflectionPad2d(padding=1) # Using reflection padding as described in vgg19 88 | self.UpSample = nn.Upsample(scale_factor=2, mode="nearest") 89 | 90 | self.conv4_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=0) 91 | 92 | self.conv3_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0) 93 | self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0) 94 | self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0) 95 | self.conv3_4 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=0) 96 | 97 | self.conv2_1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0) 98 | self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0) 99 | 100 | self.conv1_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0) 101 | self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=0) 102 | 103 | 104 | def forward(self, x): 105 | out = self.UpSample(F.relu(self.conv4_1(self.padding(x)))) 106 | 107 | out = F.relu(self.conv3_1(self.padding(out))) 108 | out = F.relu(self.conv3_2(self.padding(out))) 109 | out = F.relu(self.conv3_3(self.padding(out))) 110 | out = self.UpSample(F.relu(self.conv3_4(self.padding(out)))) 111 | 112 | out = F.relu(self.conv2_1(self.padding(out))) 113 | out = self.UpSample(F.relu(self.conv2_2(self.padding(out)))) 114 | 115 | out = F.relu(self.conv1_1(self.padding(out))) 116 | out = self.conv1_2(self.padding(out)) 117 | return out 118 | 119 | # A vgg19 Sequential which is used up to Relu 4.1. 
To note is that the 120 | # first layer is a 3,3 convolution, different from a standard vgg19 121 | 122 | class Encoder(nn.Module): 123 | def __init__(self, state_dict, device): 124 | super().__init__() 125 | self.vgg19 = nn.Sequential( 126 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1), 127 | nn.ReflectionPad2d(padding=1), 128 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3), 129 | nn.ReLU(inplace=True), # First layer from which Style Loss is calculated 130 | nn.ReflectionPad2d(padding=1), 131 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3), 132 | nn.ReLU(inplace=True), 133 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True), 134 | nn.ReflectionPad2d(padding=1), 135 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3), 136 | nn.ReLU(inplace=True), # Second layer from which Style Loss is calculated 137 | nn.ReflectionPad2d(padding=1), 138 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3), 139 | nn.ReLU(inplace=True), 140 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True), 141 | nn.ReflectionPad2d(padding=1), # Third layer from which Style Loss is calculated 142 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3), 143 | nn.ReLU(inplace=True), 144 | nn.ReflectionPad2d(padding=1), 145 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3), 146 | nn.ReLU(inplace=True), 147 | nn.ReflectionPad2d(padding=1), 148 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3), 149 | nn.ReLU(inplace=True), 150 | nn.ReflectionPad2d(padding=1), 151 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3), 152 | nn.ReLU(inplace=True), 153 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True), 154 | nn.ReflectionPad2d(padding=1), 155 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3), 156 | nn.ReLU(inplace=True), # This is Relu 4.1 The output layer of the encoder. 
157 | nn.ReflectionPad2d(padding=1), 158 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 159 | nn.ReLU(inplace=True), 160 | nn.ReflectionPad2d(padding=1), 161 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 162 | nn.ReLU(inplace=True), 163 | nn.ReflectionPad2d(padding=1), 164 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 165 | nn.ReLU(inplace=True), 166 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True), 167 | nn.ReflectionPad2d(padding=1), 168 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 169 | nn.ReLU(inplace=True), 170 | nn.ReflectionPad2d(padding=1), 171 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 172 | nn.ReLU(inplace=True), 173 | nn.ReflectionPad2d(padding=1), 174 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 175 | nn.ReLU(inplace=True), 176 | nn.ReflectionPad2d(padding=1), 177 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3), 178 | nn.ReLU(inplace=True) 179 | ).to(device) 180 | 181 | self.vgg19.load_state_dict(state_dict) 182 | 183 | encoder_children = list(self.vgg19.children()) 184 | self.EncoderList = nn.ModuleList([nn.Sequential(*encoder_children[:4]), # Up to Relu 1.1 185 | nn.Sequential(*encoder_children[4:11]), # Up to Relu 2.1 186 | nn.Sequential(*encoder_children[11:18]), # Up to Relu 3.1 187 | nn.Sequential(*encoder_children[18:31]), # Up to Relu 4.1, also the 188 | ]) # input for the decoder 189 | 190 | def forward(self, x, intermediates=False): # if training use intermediates = True, to get the output of 191 | states = [] # all the encoder layers to calculate the style loss 192 | for i in range(len(self.EncoderList)): 193 | x = self.EncoderList[i](x) 194 | 195 | if intermediates: # All intermediate states get saved in states 196 | states.append(x) 197 | if intermediates: 198 | return states 199 | return x 200 | -------------------------------------------------------------------------------- /Utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # Calculates mean and std channel-wise 5 | def calc_mean_std(input, eps=1e-5): 6 | batch_size, channels = input.shape[:2] 7 | 8 | reshaped = input.view(batch_size, channels, -1) # Reshape channel wise 9 | mean = torch.mean(reshaped, dim = 2).view(batch_size, channels, 1, 1) # Calculat mean and reshape 10 | std = torch.sqrt(torch.var(reshaped, dim=2)+eps).view(batch_size, channels, 1, 1) # Calculate variance, add epsilon (avoid 0 division), 11 | # calculate std and reshape 12 | return mean, std 13 | 14 | def AdaIn(content, style): 15 | assert content.shape[:2] == style.shape[:2] # Only first two dim, such that different image sizes is possible 16 | batch_size, n_channels = content.shape[:2] 17 | mean_content, std_content = calc_mean_std(content) 18 | mean_style, std_style = calc_mean_std(style) 19 | 20 | output = std_style*((content - mean_content) / (std_content)) + mean_style # Normalise, then modify mean and std 21 | return output 22 | 23 | def Content_loss(input, target): # Content loss is a simple MSE Loss 24 | loss = F.mse_loss(input, target) 25 | return loss 26 | 27 | def Style_loss(input, target): 28 | mean_loss, std_loss = 0, 0 29 | 30 | for input_layer, target_layer in zip(input, target): 31 | mean_input_layer, std_input_layer = calc_mean_std(input_layer) 32 | mean_target_layer, std_target_layer = calc_mean_std(target_layer) 33 | 34 | mean_loss += F.mse_loss(mean_input_layer, 
mean_target_layer) 35 | std_loss += F.mse_loss(std_input_layer, std_target_layer) 36 | 37 | return mean_loss+std_loss 38 | -------------------------------------------------------------------------------- /examples/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/examples/architecture.jpg --------------------------------------------------------------------------------