├── Images
│   ├── NewYork_small.jpg
│   ├── Out_Style1_small.jpg
│   ├── Out_Style2_small.jpg
│   ├── Out_Style3_small.jpg
│   ├── Out_Style4_small.jpg
│   ├── Out_Style5_small.jpg
│   ├── style1_small.jpg
│   ├── style2_small.jpg
│   ├── style3_small.jpg
│   ├── style4_small.jpg
│   └── style5_small.jpg
├── LICENSE
├── README.md
├── Style Transfer Test.ipynb
├── Style Transfer Train.ipynb
├── Utils
│   ├── __init__.py
│   ├── networks.py
│   └── utils.py
└── examples
    └── architecture.jpg
/Images/NewYork_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/NewYork_small.jpg
--------------------------------------------------------------------------------
/Images/Out_Style1_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style1_small.jpg
--------------------------------------------------------------------------------
/Images/Out_Style2_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style2_small.jpg
--------------------------------------------------------------------------------
/Images/Out_Style3_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style3_small.jpg
--------------------------------------------------------------------------------
/Images/Out_Style4_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style4_small.jpg
--------------------------------------------------------------------------------
/Images/Out_Style5_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/Out_Style5_small.jpg
--------------------------------------------------------------------------------
/Images/style1_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style1_small.jpg
--------------------------------------------------------------------------------
/Images/style2_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style2_small.jpg
--------------------------------------------------------------------------------
/Images/style3_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style3_small.jpg
--------------------------------------------------------------------------------
/Images/style4_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style4_small.jpg
--------------------------------------------------------------------------------
/Images/style5_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/Images/style5_small.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 MAlberts99
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-AdaIN-StyleTransfer
2 | This project is an unofficial PyTorch implementation, built to run in Google Colab, of the paper: [**Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization**](https://arxiv.org/abs/1703.06868)
3 |
4 | All credit goes to: [Xun Huang](http://www.cs.cornell.edu/~xhuang/) and
5 | [Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/)
6 |
7 |
8 | ## Description
9 | The paper presents a style transfer algorithm which uses a fixed, pretrained vgg19 (up to ReLU 4.1) to encode a style image and a content image; the style of the style image is then transferred to the content image. The novel element of the paper is the AdaIN layer used to transfer the style. This layer first normalises the content features to zero mean and unit standard deviation, and then scales and shifts them so that their mean and standard deviation match those of the style features. Finally, the result is decoded back into an image by a decoder that mirrors the vgg19.
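The AdaIN operation described above boils down to the following minimal sketch (the function name here is illustrative; the repository's own version lives in `Utils/utils.py` as `AdaIn`, and `content`/`style` are the encoder feature maps, not the raw images):

```python
import torch

def adain(content, style, eps=1e-5):
    # content, style: feature maps of shape (N, C, H, W)
    mu_c = content.mean(dim=(2, 3), keepdim=True)
    std_c = (content.var(dim=(2, 3), keepdim=True) + eps).sqrt()
    mu_s = style.mean(dim=(2, 3), keepdim=True)
    std_s = (style.var(dim=(2, 3), keepdim=True) + eps).sqrt()
    # Normalise the content features, then impose the style statistics
    return std_s * (content - mu_c) / std_c + mu_s
```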
10 |
11 |
12 |
13 |
14 | ## Requirements
15 | - A Google Drive account to run the notebooks.
16 | - A pretrained vgg19 `.pth` file. I used the file provided by [Naoto Inoue](https://github.com/naoto0804/pytorch-AdaIN) in his implementation of the same paper. Link: [vgg_normalised.pth](https://drive.google.com/file/d/108uza-dsmwvbW2zv-G73jtVcMU_2Nb7Y/view).
17 |
18 |
19 | To train:
20 | - [2015 Coco Image Dataset, 13GB](http://images.cocodataset.org/zips/test2015.zip)
21 | - [WikiArt Dataset, 25.4GB](http://web.fsktm.um.edu.my/~cschan/source/ICIP2017/wikiart.zip)
22 |
23 | A note on the datasets: the free version of Google Colab only has 30GB of usable storage while the GPU is attached, so you may have to reduce the size of the datasets. In this implementation I used 40k images from each dataset; a sketch for subsampling is shown below.
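Something along these lines works for building the reduced folders (a minimal sketch, not part of the repository; the source/destination paths are placeholders you have to adapt):

```python
import os
import random
import shutil

def sample_subset(src_dir, dst_dir, n=40_000, seed=0):
    """Copy a random subset of n images from src_dir into dst_dir."""
    os.makedirs(dst_dir, exist_ok=True)
    files = [f for f in os.listdir(src_dir)
             if f.lower().endswith((".jpg", ".jpeg", ".png"))]
    random.Random(seed).shuffle(files)
    for name in files[:n]:
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))

# e.g. sample_subset("/content/test2015", "/content/Data/Coco_40k")
```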
24 |
25 | ## Trained Model
26 | You can download my model from [here](https://drive.google.com/file/d/1-96gmMVd1wYbP0WNM7KMF2gCIqtQ6ZtA/view?usp=sharing). It has been trained for 120,000 iterations and produces image quality close to the official implementation. The style weight (gamma) used was 2.0.
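Loading the checkpoint outside the notebooks looks roughly like this (a sketch only: the Drive path is a placeholder, the constructor arguments follow `Utils/networks.py`, and the checkpoint filename is the one used in the training notebook):

```python
import os
import torch
from Utils import networks

path = "/path/to/your/Drive/folder"  # placeholder: the folder containing this repository
device = "cuda" if torch.cuda.is_available() else "cpu"

state_encoder = torch.load(os.path.join(path, "vgg_normalised.pth"))
network = networks.StyleTransferNetwork(
    device, state_encoder,
    train=False,            # inference mode
    load_fromstate=True,
    load_path=os.path.join(path, "StyleTransfer Checkpoint Iter: 120000.tar"))
```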
27 |
28 | ## Manual
29 | - Copy the contents of this repository into a Google Drive folder, then download the pretrained vgg19 file and place it in the same folder. If you want to play around with the network, add the trained model checkpoint as well. If you want to train the network from scratch, e.g. to change the style weight, also download the datasets.
30 | #### Inference
31 | - Open the Style Transfer Test notebook. In the first cell you need to specify the directory of the folder in your Google Drive. Additionally, if you changed the image folder, you also need to change the `img_dir` variable accordingly.
32 | - The next cell will load the network.
33 | - Then the images are loaded. Here you can choose your images.
34 | - In the next cell you can change the `alpha` variable, which controls how strongly the style is applied to the content image (`alpha=1.0` means full style transfer); see the sketch below.
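A minimal sketch of this step (assuming the network was loaded as shown above and that `style_img`/`content_img` are already preprocessed `(1, 3, H, W)` tensors):

```python
alpha = 0.6                   # 0.0 keeps the content, 1.0 applies the full style
network.set_train(False)      # in inference mode forward() returns the stylised image
with torch.no_grad():
    stylised = network(style_img.to(device), content_img.to(device), alpha=alpha)
```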
35 | #### Training
36 | - Open the Style Transfer Train notebook. In the first cell you need to specify the directory of the folder in your Google Drive.
37 | - Then you will have to download/import the datasets into the Colab instance. I have not implemented this step, as depending on your storage you may need to reduce the number of images used from each dataset.
38 | - Change `pathStyleImages` and `pathContentImages` to the folders containing the images. Note that each folder must contain only images; nested folders are not supported.
39 | - Then run the rest of the cells.
40 |
41 | ## Results
42 | | Style | Generated Image |
43 | | :----: | :----: |
44 | | ![Style 1](Images/style1_small.jpg) | ![Output 1](Images/Out_Style1_small.jpg) |
45 | | ![Style 2](Images/style2_small.jpg) | ![Output 2](Images/Out_Style2_small.jpg) |
46 | | ![Style 3](Images/style3_small.jpg) | ![Output 3](Images/Out_Style3_small.jpg) |
47 | | ![Style 4](Images/style4_small.jpg) | ![Output 4](Images/Out_Style4_small.jpg) |
48 | | ![Style 5](Images/style5_small.jpg) | ![Output 5](Images/Out_Style5_small.jpg) |
49 |
50 | As can be seen above, the results are not quite as good as those presented in the paper. This can be explained by the model in the paper being trained for 160,000 iterations, whereas mine was only trained for 120,000. Additionally, the original model was trained on 80,000 images from each dataset, whereas I only trained on 40,000.
51 |
52 |
--------------------------------------------------------------------------------
/Style Transfer Train.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Style Transfer Train.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"0HDAX91xcOA2","colab_type":"code","colab":{}},"source":["import torch\n","import torchvision\n","import torchvision.transforms as transforms\n","from torch.utils.data import DataLoader, Dataset\n","\n","from tqdm.notebook import tqdm\n","from google.colab import drive\n","import numpy as np\n","import os\n","import sys\n","\n","manualSeed = 999\n","torch.manual_seed(manualSeed)\n","\n","drive.mount(\"/content/gdrive\")\n","path = \"/content/gdrive/My Drive/Colab Notebooks/GANS/Style Transfer\"\n","\n","\n","sys.path.append(path)\n","from Utils import networks"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"LBxpcfaNybWC","colab_type":"code","colab":{}},"source":["# Images Dataset that returns one style and one content image. As I only trained using 40.000\n","# images each, each image is randomly sampled. The way it is implemented does not allow multi-threading. However\n","# as this network is relatively small and training times low, no improved class was implemented.\n","\n","class Images(Dataset): \n"," def __init__(self, root_dir1, root_dir2, transform=None):\n"," self.root_dir1 = root_dir1\n"," self.root_dir2 = root_dir2\n"," self.transform = transform\n","\n"," def __len__(self):\n"," return min(len(os.listdir(self.root_dir1)), len(os.listdir(self.root_dir2)))\n","\n"," def __getitem__(self, idx):\n"," all_names1, all_names2 = os.listdir(self.root_dir1), os.listdir(self.root_dir2)\n"," idx1, idx2 = np.random.randint(0, len(all_names1)), np.random.randint(0, len(all_names2))\n","\n"," img_name1, img_name2 = os.path.join(self.root_dir1, all_names1[idx1]), os.path.join(self.root_dir2, all_names2[idx2])\n"," image1 = Image.open(img_name1).convert(\"RGB\")\n"," image2 = Image.open(img_name2).convert(\"RGB\")\n","\n"," if self.transform:\n"," image1 = self.transform(image1)\n"," image2 = self.transform(image2)\n","\n"," return image1, image2 "],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MIVw4MQo6WKu","colab_type":"code","colab":{}},"source":["# To note is that the images are not normalised\n","transform = transforms.Compose([transforms.Resize(512), \n"," transforms.CenterCrop(256),\n"," transforms.ToTensor()])\n","\n","\n","# Specify the path to the style and content images\n","pathStyleImages = \"/content/Data/Wiki_40k\"\n","pathContentImages = \"/content/Data/Coco_40k\" \n","\n","\n","all_img = Images(pathStyleImages, pathContentImages, transform=transform)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"bGd2_znLxLIy","colab_type":"code","colab":{}},"source":["# Simple save \n","def save_state(decoder, optimiser, iters, run_dir):\n"," \n"," name = \"StyleTransfer Checkpoint Iter: {}.tar\".format(iters)\n"," torch.save({\"Decoder\" : decoder,\n"," \"Optimiser\" : optimiser,\n"," \"iters\": iters\n"," }, os.path.join(path, name))\n"," print(\"Saved : {} succesfully\".format(name))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"nNf9XPp2HgCO","colab_type":"code","colab":{},"executionInfo":{"status":"ok","timestamp":1600861736262,"user_tz":-120,"elapsed":825,"user":{"displayName":"Marvin Alberts","photoUrl":"","userId":"07281461137485771510"}}},"source":["def training_loop(network, # StyleTransferNetwork\n"," 
dataloader_comb, # DataLoader\n"," n_epochs, # Number of Epochs\n"," run_dir # Directory in which the checkpoints and tensorboard files are saved\n"," ):\n"," \n","\n"," writer = SummaryWriter(os.path.join(path, run_dir))\n"," # Fixed images to compare over time\n"," fixed_batch_style, fixed_batch_content = all_img[0]\n"," fixed_batch_style, fixed_batch_content = fixed_batch_style.unsqueeze(0).to(device), fixed_batch_content.unsqueeze(0).to(device) # Move images to device\n","\n"," writer.add_image(\"Style\", torchvision.utils.make_grid(fixed_batch_style))\n"," writer.add_image(\"Content\", torchvision.utils.make_grid(fixed_batch_content))\n","\n"," iters = network.iters\n","\n"," for epoch in range(1, n_epochs+1):\n"," tqdm_object = tqdm(dataloader_comb, total=len(dataloader_comb))\n","\n"," for style_imgs, content_imgs in tqdm_object:\n"," network.adjust_learning_rate(network.optimiser, iters)\n"," style_imgs = style_imgs.to(device)\n"," content_imgs = content_imgs.to(device)\n","\n"," loss_comb, content_loss, style_loss = network(style_imgs, content_imgs)\n","\n"," network.optimiser.zero_grad()\n"," loss_comb.backward()\n"," network.optimiser.step()\n","\n"," # Update status bar, add Loss, add Images\n"," tqdm_object.set_postfix_str(\"Combined Loss: {:.3f}, Style Loss: {:.3f}, Content Loss: {:.3f}\".format(\n"," loss_comb.item()*100, style_loss.item()*100, content_loss.item()*100))\n"," \n"," if iters % 25 == 0:\n"," writer.add_scalar(\"Combined Loss\", loss_comb*1000, iters)\n"," writer.add_scalar(\"Style Loss\", style_loss*1000, iters)\n"," writer.add_scalar(\"Content Loss\", content_loss*1000, iters)\n","\n"," if (iters+1) % 2000 == 1:\n"," with torch.no_grad():\n"," network.set_train(False)\n"," images = network(fixed_batch_style, fixed_batch_content)\n"," img_grid = torchvision.utils.make_grid(images)\n"," writer.add_image(\"Progress Iter: {}\".format(iters), img_grid)\n"," network.set_train(True)\n","\n"," if (iters+1) % 4000 == 1:\n"," save_state(network.decoder.state_dict(), network.optimiser.state_dict(), iters, run_dir)\n"," writer.close()\n"," writer = SummaryWriter(os.path.join(path, run_dir))\n","\n"," iters += 1"],"execution_count":1,"outputs":[]},{"cell_type":"code","metadata":{"id":"iKg_C0K6leFp","colab_type":"code","colab":{}},"source":["device = (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","learning_rate = 1e-4\n","learning_rate_decay = 5e-5\n","\n","dataloader_comb = DataLoader(all_img, batch_size=5, shuffle=True, num_workers=0, drop_last=True)\n","gamma = torch.tensor([2]).to(device) # Style weight\n","\n","n_epochs = 5\n","run_dir = \"runs/Run 1\" # Change if you want to save the checkpoints/tensorboard files in a different directory\n","\n","state_encoder = torch.load(os.path.join(path, \"vgg_normalised.pth\"))\n","network = networks.StyleTransferNetwork(device,\n"," state_encoder,\n"," learning_rate,\n"," learning_rate_decay,\n"," gamma,\n"," load_fromstate=False,\n"," load_path=os.path.join(path, \"StyleTransfer Checkpoint Iter: 120000.tar\"))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5_huZgs0ClUJ","colab_type":"code","colab":{}},"source":["training_loop(network, dataloader_comb, n_epochs, run_dir)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4m-uIZxqXeoD","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]}]}
--------------------------------------------------------------------------------
/Utils/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ["utils", "networks"]
2 |
--------------------------------------------------------------------------------
/Utils/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import os
6 | from .utils import *
7 |
8 | # The style transfer network
9 | class StyleTransferNetwork(nn.Module):
10 | def __init__(self,
11 | device, # "cpu" for cpu, "cuda" for gpu
12 | enc_state_dict, # The state dict of the pretrained vgg19
13 | learning_rate=1e-4,
14 | learning_rate_decay=5e-5, # Decay parameter for the learning rate
15 | gamma=2.0, # Controls importance of StyleLoss vs ContentLoss, Loss = gamma*StyleLoss + ContentLoss
16 | train=True, # Whether or not the network is training
17 | load_fromstate=False, # Load from checkpoint?
18 | load_path=None # Path to load checkpoint
19 | ):
20 | super().__init__()
21 |
22 | assert device in ["cpu", "cuda"]
23 | if load_fromstate and not os.path.isfile(load_path):
24 | raise ValueError("Checkpoint file not found")
25 |
26 |
27 | self.learning_rate = learning_rate
28 | self.learning_rate_decay = learning_rate_decay
29 | self.train = train
30 | self.gamma = gamma
31 |
32 | self.encoder = Encoder(enc_state_dict, device) # A pretrained vgg19 is used as the encoder
33 | self.decoder = Decoder().to(device)
34 |
35 | self.optimiser = optim.Adam(self.decoder.parameters(), lr=self.learning_rate)
36 | self.iters = 0
37 |
38 | if load_fromstate:
39 | state = torch.load(load_path)
40 | self.decoder.load_state_dict(state["Decoder"])
41 | self.optimiser.load_state_dict(state["Optimiser"])
42 | self.iters = state["iters"]
43 |
44 |
45 | def set_train(self, boolean): # Change state of network
46 | assert type(boolean) == bool
47 | self.train = boolean
48 |
49 | def adjust_learning_rate(self, optimiser, iters): # Simple learning rate decay
50 | lr = self.learning_rate / (1.0 + self.learning_rate_decay * iters)
51 | for param_group in optimiser.param_groups:
52 | param_group['lr'] = lr
53 |
54 | def forward(self, style, content, alpha=1.0): # Alpha can be used while testing to control the importance of the transferred style
55 |
56 | # Encode style and content
57 | layers_style = self.encoder(style, self.train) # if train: returns all states
58 | layer_content = self.encoder(content, False) # for the content only the last layer is important
59 |
60 | # Transfer Style
61 | if self.train:
62 | style_applied = AdaIn(layer_content, layers_style[-1]) # Last layer is "style" layer
63 | else:
64 | style_applied = alpha*AdaIn(layer_content, layers_style) + (1-alpha)*layer_content # Alpha controls magnitude of style
65 |
66 | # Scale up
67 | style_applied_upscaled = self.decoder(style_applied)
68 | if not self.train:
69 | return style_applied_upscaled # When not training return transformed image
70 |
71 | # Compute Loss
72 | layers_style_applied = self.encoder(style_applied_upscaled, self.train)
73 |
74 | content_loss = Content_loss(layers_style_applied[-1], layer_content)
75 | style_loss = Style_loss(layers_style_applied, layers_style)
76 |
77 | loss_comb = content_loss + self.gamma*style_loss
78 |
79 | return loss_comb, content_loss, style_loss
80 |
81 | # The decoder mirrors the vgg19 encoder (up to ReLU 4.1) in reverse. Note that the last layer is not activated.
82 |
83 | class Decoder(nn.Module):
84 | def __init__(self):
85 | super().__init__()
86 |
87 | self.padding = nn.ReflectionPad2d(padding=1) # Reflection padding, as described in the paper, to avoid border artifacts
88 | self.UpSample = nn.Upsample(scale_factor=2, mode="nearest")
89 |
90 | self.conv4_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=0)
91 |
92 | self.conv3_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
93 | self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
94 | self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)
95 | self.conv3_4 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=0)
96 |
97 | self.conv2_1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0)
98 | self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0)
99 |
100 | self.conv1_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0)
101 | self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=0)
102 |
103 |
104 | def forward(self, x):
105 | out = self.UpSample(F.relu(self.conv4_1(self.padding(x))))
106 |
107 | out = F.relu(self.conv3_1(self.padding(out)))
108 | out = F.relu(self.conv3_2(self.padding(out)))
109 | out = F.relu(self.conv3_3(self.padding(out)))
110 | out = self.UpSample(F.relu(self.conv3_4(self.padding(out))))
111 |
112 | out = F.relu(self.conv2_1(self.padding(out)))
113 | out = self.UpSample(F.relu(self.conv2_2(self.padding(out))))
114 |
115 | out = F.relu(self.conv1_1(self.padding(out)))
116 | out = self.conv1_2(self.padding(out))
117 | return out
118 |
119 | # A vgg19 Sequential which is used up to ReLU 4.1. Note that the
120 | # first layer is a 1x1 convolution (3 channels in, 3 channels out), which is not part of a standard vgg19
121 |
122 | class Encoder(nn.Module):
123 | def __init__(self, state_dict, device):
124 | super().__init__()
125 | self.vgg19 = nn.Sequential(
126 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1),
127 | nn.ReflectionPad2d(padding=1),
128 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3),
129 | nn.ReLU(inplace=True), # First layer from which Style Loss is calculated
130 | nn.ReflectionPad2d(padding=1),
131 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
132 | nn.ReLU(inplace=True),
133 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
134 | nn.ReflectionPad2d(padding=1),
135 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
136 | nn.ReLU(inplace=True), # Second layer from which Style Loss is calculated
137 | nn.ReflectionPad2d(padding=1),
138 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
139 | nn.ReLU(inplace=True),
140 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
141 | nn.ReflectionPad2d(padding=1),
142 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3),
143 | nn.ReLU(inplace=True), # Third layer from which Style Loss is calculated
144 | nn.ReflectionPad2d(padding=1),
145 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3),
146 | nn.ReLU(inplace=True),
147 | nn.ReflectionPad2d(padding=1),
148 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3),
149 | nn.ReLU(inplace=True),
150 | nn.ReflectionPad2d(padding=1),
151 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3),
152 | nn.ReLU(inplace=True),
153 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
154 | nn.ReflectionPad2d(padding=1),
155 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3),
156 | nn.ReLU(inplace=True), # This is ReLU 4.1, the output layer of the encoder
157 | nn.ReflectionPad2d(padding=1),
158 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
159 | nn.ReLU(inplace=True),
160 | nn.ReflectionPad2d(padding=1),
161 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
162 | nn.ReLU(inplace=True),
163 | nn.ReflectionPad2d(padding=1),
164 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
165 | nn.ReLU(inplace=True),
166 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
167 | nn.ReflectionPad2d(padding=1),
168 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
169 | nn.ReLU(inplace=True),
170 | nn.ReflectionPad2d(padding=1),
171 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
172 | nn.ReLU(inplace=True),
173 | nn.ReflectionPad2d(padding=1),
174 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
175 | nn.ReLU(inplace=True),
176 | nn.ReflectionPad2d(padding=1),
177 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
178 | nn.ReLU(inplace=True)
179 | ).to(device)
180 |
181 | self.vgg19.load_state_dict(state_dict)
182 |
183 | encoder_children = list(self.vgg19.children())
184 | self.EncoderList = nn.ModuleList([nn.Sequential(*encoder_children[:4]), # Up to Relu 1.1
185 | nn.Sequential(*encoder_children[4:11]), # Up to Relu 2.1
186 | nn.Sequential(*encoder_children[11:18]), # Up to Relu 3.1
187 | nn.Sequential(*encoder_children[18:31]), # Up to Relu 4.1, also the
188 | ]) # input for the decoder
189 |
190 | def forward(self, x, intermediates=False): # if training use intermediates = True, to get the output of
191 | states = [] # all the encoder layers to calculate the style loss
192 | for i in range(len(self.EncoderList)):
193 | x = self.EncoderList[i](x)
194 |
195 | if intermediates: # All intermediate states get saved in states
196 | states.append(x)
197 | if intermediates:
198 | return states
199 | return x
200 |
--------------------------------------------------------------------------------
/Utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | # Calculates mean and std channel-wise
5 | def calc_mean_std(input, eps=1e-5):
6 | batch_size, channels = input.shape[:2]
7 |
8 | reshaped = input.view(batch_size, channels, -1) # Reshape channel wise
9 | mean = torch.mean(reshaped, dim=2).view(batch_size, channels, 1, 1) # Calculate mean and reshape
10 | std = torch.sqrt(torch.var(reshaped, dim=2)+eps).view(batch_size, channels, 1, 1) # Calculate variance, add epsilon (avoid 0 division),
11 | # calculate std and reshape
12 | return mean, std
13 |
14 | def AdaIn(content, style):
15 | assert content.shape[:2] == style.shape[:2] # Only check the first two dims, so that different image sizes are possible
16 | batch_size, n_channels = content.shape[:2]
17 | mean_content, std_content = calc_mean_std(content)
18 | mean_style, std_style = calc_mean_std(style)
19 |
20 | output = std_style*((content - mean_content) / (std_content)) + mean_style # Normalise, then modify mean and std
21 | return output
22 |
23 | def Content_loss(input, target): # Content loss is a simple MSE Loss
24 | loss = F.mse_loss(input, target)
25 | return loss
26 |
27 | def Style_loss(input, target):
28 | mean_loss, std_loss = 0, 0
29 |
30 | for input_layer, target_layer in zip(input, target):
31 | mean_input_layer, std_input_layer = calc_mean_std(input_layer)
32 | mean_target_layer, std_target_layer = calc_mean_std(target_layer)
33 |
34 | mean_loss += F.mse_loss(mean_input_layer, mean_target_layer)
35 | std_loss += F.mse_loss(std_input_layer, std_target_layer)
36 |
37 | return mean_loss+std_loss
38 |
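# For reference, a plain-text summary of the losses implemented above (this follows
# the code in this file and in networks.py, not necessarily the paper's exact wording):
#
#   content_loss = MSE( phi_4(g), phi_4(c) )
#   style_loss   = sum over layers i of  MSE( mu(phi_i(g)), mu(phi_i(s)) )
#                                      + MSE( sigma(phi_i(g)), sigma(phi_i(s)) )
#
# where phi_i are the four encoder layers (ReLU 1.1 to ReLU 4.1), c is the content
# image, s the style image and g the decoded (generated) image. The total training
# loss in networks.py is content_loss + gamma * style_loss.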
--------------------------------------------------------------------------------
/examples/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MAlberts99/PyTorch-AdaIN-StyleTransfer/8535d692a0b7dace72048fa76413436602254cb3/examples/architecture.jpg
--------------------------------------------------------------------------------