├── lib ├── __init__.py ├── adain │ ├── adain.py │ └── adain_model.py ├── adaconv │ ├── kernel_predictor.py │ ├── adaconv.py │ └── adaconv_model.py ├── loss.py ├── dataset.py ├── vgg.py └── lightning │ ├── datamodule.py │ └── lightningmodel.py ├── imgs ├── arch_01.png ├── arch_02.png ├── results_table_256.jpg ├── results_table_512.jpg └── results_comparison.jpg ├── test_images ├── style │ ├── style_01.jpg │ ├── style_02.jpg │ ├── style_03.jpg │ ├── style_04.jpg │ └── style_05.jpg └── content │ ├── content_01.jpg │ ├── content_02.jpg │ ├── content_03.jpg │ ├── content_04.jpg │ └── content_05.jpg ├── .gitignore ├── LICENSE ├── stylize.py ├── test.py ├── train.py └── README.rst /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/arch_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/imgs/arch_01.png -------------------------------------------------------------------------------- /imgs/arch_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/imgs/arch_02.png -------------------------------------------------------------------------------- /imgs/results_table_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/imgs/results_table_256.jpg -------------------------------------------------------------------------------- /imgs/results_table_512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/imgs/results_table_512.jpg -------------------------------------------------------------------------------- /imgs/results_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/imgs/results_comparison.jpg -------------------------------------------------------------------------------- /test_images/style/style_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/style/style_01.jpg -------------------------------------------------------------------------------- /test_images/style/style_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/style/style_02.jpg -------------------------------------------------------------------------------- /test_images/style/style_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/style/style_03.jpg -------------------------------------------------------------------------------- /test_images/style/style_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/style/style_04.jpg -------------------------------------------------------------------------------- /test_images/style/style_05.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/style/style_05.jpg -------------------------------------------------------------------------------- /test_images/content/content_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/content/content_01.jpg -------------------------------------------------------------------------------- /test_images/content/content_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/content/content_02.jpg -------------------------------------------------------------------------------- /test_images/content/content_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/content/content_03.jpg -------------------------------------------------------------------------------- /test_images/content/content_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/content/content_04.jpg -------------------------------------------------------------------------------- /test_images/content/content_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/HEAD/test_images/content/content_05.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.webp 2 | *.pyc 3 | .idea/ 4 | *.gif 5 | imgs/ 6 | */logs/ 7 | logs/ 8 | *.ckpt 9 | .logs_old/ 10 | test_images/output/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 RElbers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /lib/adain/adain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class AdaInstanceNorm2d(nn.Module): 7 | def __init__(self, mlp_features=None): 8 | super().__init__() 9 | 10 | # If mlp_features is specified, the bias and scale are estimated by transforming a code vector, 11 | # as in MUNIT (https://arxiv.org/pdf/1804.04732.pdf). 12 | if mlp_features is not None: 13 | in_features = mlp_features[0] 14 | out_features = mlp_features[1] 15 | 16 | self._scale = nn.Linear(in_features, out_features) 17 | self._bias = nn.Linear(in_features, out_features) 18 | # If mlp_features is not specified, the bias and scale are the mean and std of 2d feature maps, 19 | # as in standard AdaIN (https://arxiv.org/pdf/1703.06868.pdf). 20 | else: 21 | self._scale = self._std 22 | self._bias = self._mean 23 | 24 | def forward(self, x, y): 25 | y_scale = self._scale(y).unsqueeze(-1).unsqueeze(-1) 26 | y_bias = self._bias(y).unsqueeze(-1).unsqueeze(-1) 27 | 28 | x = F.instance_norm(x) 29 | x = (x * y_scale) + y_bias 30 | return x 31 | 32 | def _std(self, x): 33 | return torch.std(x, dim=[2, 3]) 34 | 35 | def _mean(self, x): 36 | return torch.mean(x, dim=[2, 3]) 37 | -------------------------------------------------------------------------------- /stylize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | 6 | from lib import dataset 7 | from lib.lightning.lightningmodel import LightningModel 8 | 9 | 10 | def stylize_image(model, content_file, style_file, content_size=None): 11 | device = next(model.parameters()).device 12 | 13 | content = dataset.load(content_file) 14 | style = dataset.load(style_file) 15 | 16 | content = dataset.content_transforms(content_size)(content) 17 | style = dataset.style_transforms()(style) 18 | 19 | content = content.to(device).unsqueeze(0) 20 | style = style.to(device).unsqueeze(0) 21 | 22 | output = model(content, style) 23 | return output[0].detach().cpu() 24 | 25 | 26 | def parse_args(): 27 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument('--content', type=str, default='./content.png') 29 | parser.add_argument('--style', type=str, default='./style.png') 30 | parser.add_argument('--output', type=str, default='./output.png') 31 | parser.add_argument('--model', type=str, default='./model.ckpt') 32 | 33 | return vars(parser.parse_args()) 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parse_args() 38 | 39 | model = LightningModel.load_from_checkpoint(checkpoint_path=args['model']) 40 | model = model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu")) 41 | model.eval() 42 | 43 | with torch.no_grad(): 44 | output = stylize_image(model, args['content'], args['style']) 45 | dataset.save(output, args['output']) 46 | -------------------------------------------------------------------------------- /lib/adain/adain_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from torch import nn 4 | 5 | from lib.adain.adain import AdaInstanceNorm2d 6 | from lib.vgg import VGGDecoder, VGGEncoder 7 | 8 | 9 | class AdaINModel(nn.Module): 10 | @staticmethod 11 | def add_argparse_args(parent_parser): 12 | parser = 
ArgumentParser(parents=[parent_parser], add_help=False) 13 | parser.add_argument('--alpha', type=float, default=1.0) 14 | return parser 15 | 16 | def __init__(self, alpha): 17 | super().__init__() 18 | 19 | self.encoder = VGGEncoder() 20 | self.decoder = VGGDecoder() 21 | self.adain = AdaInstanceNorm2d() 22 | self.alpha = alpha 23 | 24 | def forward(self, content, style, return_embeddings=False): 25 | self.encoder.freeze() 26 | 27 | # Encode -> Decode 28 | content_embeddings, style_embeddings = self._encode(content, style) 29 | output = self.decoder(content_embeddings[-1]) 30 | 31 | # Return embeddings if training 32 | if return_embeddings: 33 | output_embeddings = self.encoder(output) 34 | embeddings = { 35 | 'content': content_embeddings, 36 | 'style': style_embeddings, 37 | 'output': output_embeddings 38 | } 39 | return output, embeddings 40 | else: 41 | return output 42 | 43 | def _encode(self, content, style): 44 | content_embeddings = self.encoder(content) 45 | style_embeddings = self.encoder(style) 46 | 47 | t = self.adain(content_embeddings[-1], style_embeddings[-1]) 48 | t = self.alpha * t + (1 - self.alpha) * content_embeddings[-1] 49 | 50 | content_embeddings[-1] = t 51 | return content_embeddings, style_embeddings 52 | -------------------------------------------------------------------------------- /lib/adaconv/kernel_predictor.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | from torch import nn 4 | 5 | 6 | class KernelPredictor(nn.Module): 7 | def __init__(self, in_channels, out_channels, n_groups, style_channels, kernel_size): 8 | super().__init__() 9 | self.in_channels = in_channels 10 | self.out_channels = out_channels 11 | self.w_channels = style_channels 12 | self.n_groups = n_groups 13 | self.kernel_size = kernel_size 14 | 15 | padding = (kernel_size - 1) / 2 16 | self.spatial = nn.Conv2d(style_channels, 17 | in_channels * out_channels // n_groups, 18 | kernel_size=kernel_size, 19 | padding=(ceil(padding), ceil(padding)), 20 | padding_mode='reflect') 21 | self.pointwise = nn.Sequential( 22 | nn.AdaptiveAvgPool2d((1, 1)), 23 | nn.Conv2d(style_channels, 24 | out_channels * out_channels // n_groups, 25 | kernel_size=1) 26 | ) 27 | self.bias = nn.Sequential( 28 | nn.AdaptiveAvgPool2d((1, 1)), 29 | nn.Conv2d(style_channels, 30 | out_channels, 31 | kernel_size=1) 32 | ) 33 | 34 | def forward(self, w): 35 | w_spatial = self.spatial(w) 36 | w_spatial = w_spatial.reshape(len(w), 37 | self.out_channels, 38 | self.in_channels // self.n_groups, 39 | self.kernel_size, self.kernel_size) 40 | 41 | w_pointwise = self.pointwise(w) 42 | w_pointwise = w_pointwise.reshape(len(w), 43 | self.out_channels, 44 | self.out_channels // self.n_groups, 45 | 1, 1) 46 | 47 | bias = self.bias(w) 48 | bias = bias.reshape(len(w), 49 | self.out_channels) 50 | 51 | return w_spatial, w_pointwise, bias -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MSEContentLoss(nn.Module): 8 | # https://arxiv.org/pdf/1508.06576.pdf 9 | 10 | def forward(self, x, y): 11 | return F.mse_loss(x, y) 12 | 13 | 14 | class GramStyleLoss(nn.Module): 15 | # https://arxiv.org/pdf/1508.06576.pdf 16 | 17 | def forward(self, x, y): 18 | gram_diff = self.gram_matrix(x) - self.gram_matrix(y) 19 | return torch.mean(torch.sum(gram_diff ** 2, 
dim=[1, 2])) 20 | 21 | def gram_matrix(self, x): 22 | n, c, h, w = x.size() 23 | x = x.view(n, c, h * w) 24 | return x @ x.transpose(-2, -1) / (c * h * w) 25 | 26 | 27 | class MomentMatchingStyleLoss(nn.Module): 28 | # https://arxiv.org/pdf/1703.06868.pdf 29 | 30 | def forward(self, x, y): 31 | x_mean = torch.mean(x, dim=[2, 3]) 32 | y_mean = torch.mean(y, dim=[2, 3]) 33 | mean_loss = F.mse_loss(x_mean, y_mean) 34 | 35 | x_std = torch.std(x, dim=[2, 3]) 36 | y_std = torch.std(y, dim=[2, 3]) 37 | std_loss = F.mse_loss(x_std, y_std) 38 | 39 | return mean_loss + std_loss 40 | 41 | 42 | class CMDStyleLoss(nn.Module): 43 | # https://arxiv.org/pdf/2103.07208.pdf 44 | # CMDStyleLoss works with pre-activation outputs of VGG19 (without ReLU) 45 | 46 | def __init__(self, k=5): 47 | super().__init__() 48 | self.k = k 49 | 50 | def forward(self, x, y): 51 | x, y = torch.sigmoid(x), torch.sigmoid(y) 52 | 53 | loss = 0 54 | for x_k, y_k in zip(self.moments(x), self.moments(y)): 55 | loss += self.l2_dist(x_k, y_k).mean() 56 | return loss 57 | 58 | def moments(self, x): 59 | # First vectorize feature maps 60 | n, c, h, w = x.size() 61 | x = x.view(n, c, h * w) 62 | 63 | x_mean = torch.mean(x, dim=2, keepdim=True) 64 | x_centered = x - x_mean 65 | 66 | moments = [x_mean.squeeze(-1)] 67 | for n in range(2, self.k + 1): 68 | moments.append(torch.mean(x_centered ** n, dim=2)) 69 | return moments 70 | 71 | def l2_dist(self, x, y): 72 | return torch.norm(x - y, dim=1) 73 | -------------------------------------------------------------------------------- /lib/adaconv/adaconv.py: -------------------------------------------------------------------------------- 1 | from math import ceil, floor 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class AdaConv2d(nn.Module): 9 | """ 10 | Implementation of the Adaptive Convolution block. Performs a depthwise seperable adaptive convolution on its input X. 11 | The weights for the adaptive convolutions are generated by a KernelPredictor module based on the style embedding W. 12 | The adaptive convolution is followed by a normal convolution. 13 | 14 | References: 15 | https://openaccess.thecvf.com/content/CVPR2021/papers/Chandran_Adaptive_Convolutions_for_Structure-Aware_Style_Transfer_CVPR_2021_paper.pdf 16 | 17 | 18 | Args: 19 | in_channels: Number of channels in the input image. 20 | out_channels: Number of channels produced by final convolution. 21 | kernel_size: The kernel size of the final convolution. 22 | n_groups: The number of groups for the adaptive convolutions. 23 | Defaults to 1 group per channel if None. 24 | 25 | Input shape: 26 | x: Input tensor. 27 | w_spatial: Weights for the spatial adaptive convolution. 28 | w_pointwise: Weights for the pointwise adaptive convolution. 29 | bias: Bias for the pointwise adaptive convolution. 
30 | """ 31 | 32 | def __init__(self, in_channels, out_channels, kernel_size=3, n_groups=None): 33 | super().__init__() 34 | self.n_groups = in_channels if n_groups is None else n_groups 35 | self.in_channels = in_channels 36 | self.out_channels = out_channels 37 | 38 | padding = (kernel_size - 1) / 2 39 | self.conv = nn.Conv2d(in_channels=in_channels, 40 | out_channels=out_channels, 41 | kernel_size=(kernel_size, kernel_size), 42 | padding=(ceil(padding), floor(padding)), 43 | padding_mode='reflect') 44 | 45 | def forward(self, x, w_spatial, w_pointwise, bias): 46 | assert len(x) == len(w_spatial) == len(w_pointwise) == len(bias) 47 | x = F.instance_norm(x) 48 | 49 | # F.conv2d does not work with batched filters (as far as I can tell)... 50 | # Hack for inputs with > 1 sample 51 | ys = [] 52 | for i in range(len(x)): 53 | y = self._forward_single(x[i:i + 1], w_spatial[i], w_pointwise[i], bias[i]) 54 | ys.append(y) 55 | ys = torch.cat(ys, dim=0) 56 | 57 | ys = self.conv(ys) 58 | return ys 59 | 60 | def _forward_single(self, x, w_spatial, w_pointwise, bias): 61 | # Only square kernels 62 | assert w_spatial.size(-1) == w_spatial.size(-2) 63 | padding = (w_spatial.size(-1) - 1) / 2 64 | pad = (ceil(padding), floor(padding), ceil(padding), floor(padding)) 65 | 66 | x = F.pad(x, pad=pad, mode='reflect') 67 | x = F.conv2d(x, w_spatial, groups=self.n_groups) 68 | x = F.conv2d(x, w_pointwise, groups=self.n_groups, bias=bias) 69 | return x 70 | -------------------------------------------------------------------------------- /lib/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from pathlib import Path 4 | 5 | from PIL import Image 6 | from torch.utils.data import IterableDataset, Dataset 7 | from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop 8 | from torchvision.utils import save_image 9 | 10 | 11 | def files_in(dir): 12 | return list(sorted(Path(dir).glob('*'))) 13 | 14 | 15 | def save(img_tensor, file): 16 | if img_tensor.ndim == 4: 17 | assert len(img_tensor) == 1 18 | 19 | save_image(img_tensor, str(file)) 20 | 21 | 22 | def load(file): 23 | img = Image.open(str(file)) 24 | img = img.convert('RGB') 25 | return img 26 | 27 | 28 | def style_transforms(size=256): 29 | # Style images must be 256x256 for AdaConv 30 | return Compose([ 31 | Resize(size=size), # Resize to keep aspect ratio 32 | CenterCrop(size=(size, size)), # Center crop to square 33 | ToTensor()]) 34 | 35 | 36 | def content_transforms(min_size=None): 37 | # min_size is optional as content images have no size restrictions 38 | transforms = [] 39 | if min_size: 40 | transforms.append(Resize(size=min_size)) 41 | transforms.append(ToTensor()) 42 | return Compose(transforms) 43 | 44 | 45 | class StylizationDataset(Dataset): 46 | def __init__(self, content_files, style_files, content_transform=None, style_transform=None): 47 | self.content_files = content_files 48 | self.style_files = style_files 49 | 50 | id = lambda x: x 51 | self.content_transform = id if content_transform is None else content_transform 52 | self.style_transform = id if style_transform is None else style_transform 53 | 54 | def __getitem__(self, idx): 55 | content_file, style_file = self.files_at_index(idx) 56 | 57 | content_img = load(content_file) 58 | style_img = load(style_file) 59 | 60 | content_img = self.content_transform(content_img) 61 | style_img = self.style_transform(style_img) 62 | 63 | return { 64 | 'content': content_img, 65 | 'style': style_img, 66 
| } 67 | 68 | def __len__(self): 69 | return len(self.content_files) * len(self.style_files) 70 | 71 | def files_at_index(self, idx): 72 | content_idx = idx % len(self.content_files) 73 | style_idx = idx // len(self.content_files) 74 | 75 | assert 0 <= content_idx < len(self.content_files) 76 | assert 0 <= style_idx < len(self.style_files) 77 | return self.content_files[content_idx], self.style_files[style_idx] 78 | 79 | 80 | class EndlessDataset(IterableDataset): 81 | """ 82 | Wrapper for StylizationDataset which loops infinitely. 83 | Usefull when training based on iterations instead of epochs 84 | """ 85 | 86 | def __init__(self, *args, **kwargs): 87 | self.dataset = StylizationDataset(*args, **kwargs) 88 | 89 | def __iter__(self): 90 | while True: 91 | idx = random.randrange(len(self.dataset)) 92 | 93 | try: 94 | yield self.dataset[idx] 95 | except Exception as e: 96 | files = self.dataset.files_at_index(idx) 97 | warnings.warn(f'\n{str(e)}\n\tFiles: [{str(files[0])}, {str(files[1])}]') 98 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | 5 | import torch 6 | import torchvision.transforms.functional as TF 7 | from torchvision.utils import make_grid 8 | from tqdm import tqdm 9 | 10 | from lib import dataset 11 | from lib.lightning.lightningmodel import LightningModel 12 | from stylize import stylize_image 13 | 14 | 15 | def resize(img, size): 16 | c, h, w = img.size() 17 | if h < w: 18 | small_size = size[0] 19 | else: 20 | small_size = size[1] 21 | 22 | img = TF.resize(img, small_size) 23 | img = TF.center_crop(img, size) 24 | return img 25 | 26 | 27 | def parse_args(): 28 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | parser.add_argument('--content-dir', type=str, default='./test_images/content') 30 | parser.add_argument('--style-dir', type=str, default='./test_images/style') 31 | parser.add_argument('--output-dir', type=str, default='./test_images/output') 32 | parser.add_argument('--model', type=str, default='./model.ckpt') 33 | parser.add_argument('--save-as', type=str, default='png') 34 | parser.add_argument('--content-size', type=int, default=512, 35 | help='Content images are resized such that the smaller edge has this size.') 36 | 37 | return vars(parser.parse_args()) 38 | 39 | 40 | if __name__ == '__main__': 41 | args = parse_args() 42 | ext = args['save_as'] 43 | content_transform = dataset.content_transforms(args['content_size']) 44 | style_transform = dataset.style_transforms() 45 | 46 | content_files = dataset.files_in(args['content_dir']) 47 | style_files = dataset.files_in(args['style_dir']) 48 | output_dir = Path(args['output_dir']) 49 | if not output_dir.exists(): 50 | output_dir.mkdir(parents=True, exist_ok=True) 51 | 52 | model = LightningModel.load_from_checkpoint(checkpoint_path=args['model']) 53 | model = model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu")) 54 | model.eval() 55 | 56 | pbar = tqdm(total=len(content_files) * len(style_files)) 57 | with torch.no_grad(): 58 | # Add style images at top row 59 | imgs = [style_transform(dataset.load(f)) for f in style_files] 60 | 61 | for i, content in enumerate(content_files): 62 | # Add content images at left column 63 | imgs.append(content_transform(dataset.load(content))) 64 | 65 | for j, style in enumerate(style_files): 66 | # Stylize 
content-style pair 67 | output = stylize_image(model, content, style, content_size=args['content_size']) 68 | 69 | dataset.save(output, output_dir.joinpath(f'{i:02}--{j:02}.{ext}')) 70 | imgs.append(output) 71 | pbar.update(1) 72 | 73 | # Make all same size for table 74 | avg_h = int(sum([img.size(1) for img in imgs]) / len(imgs)) 75 | avg_w = int(sum([img.size(2) for img in imgs]) / len(imgs)) 76 | imgs = [resize(img, [avg_h, avg_w]) for img in imgs] 77 | imgs = [torch.ones((3, avg_h, avg_w)), *imgs] # Add empty top left square. 78 | grid = make_grid(imgs, nrow=len(style_files) + 1, padding=16, pad_value=1) 79 | dataset.save(grid, output_dir.joinpath(f'table.{ext}')) 80 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | 5 | import torch 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.callbacks import LearningRateMonitor 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | 10 | from lib import dataset 11 | from lib.lightning.datamodule import DataModule 12 | from lib.lightning.lightningmodel import LightningModel 13 | 14 | 15 | class TensorBoardImageLogger(TensorBoardLogger): 16 | """ 17 | Wrapper for TensorBoardLogger which logs images to disk, 18 | instead of the TensorBoard log file. 19 | """ 20 | 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | exp = self.experiment 24 | 25 | # if not hasattr(exp, 'add_image'): 26 | exp.add_image = self.add_image 27 | 28 | def add_image(self, tag, img_tensor, global_step): 29 | dir = Path(self.log_dir, 'images') 30 | dir.mkdir(parents=True, exist_ok=True) 31 | 32 | file = dir.joinpath(f'{tag}_{global_step:09}.jpg') 33 | dataset.save(img_tensor, file) 34 | 35 | 36 | def parse_args(): 37 | # Init parser 38 | parser = ArgumentParser() 39 | parser.add_argument('--iterations', type=int, default=160_000, 40 | help='The number of training iterations.') 41 | parser.add_argument('--log-dir', type=str, default='./', 42 | help='The directory where the logs are saved to.') 43 | parser.add_argument('--checkpoint', type=str, 44 | help='Resume training from a checkpoint file.') 45 | parser.add_argument('--val-interval', type=int, default=1000, 46 | help='How often a validation step is performed. ' 47 | 'Applies the model to several fixed images and calculate the loss.') 48 | 49 | parser = DataModule.add_argparse_args(parser) 50 | parser = LightningModel.add_argparse_args(parser) 51 | 52 | parser.formatter_class = argparse.ArgumentDefaultsHelpFormatter 53 | return vars(parser.parse_args()) 54 | 55 | 56 | if __name__ == '__main__': 57 | args = parse_args() 58 | 59 | if args['checkpoint'] is None: 60 | max_epochs = 1 61 | model = LightningModel(**args) 62 | else: 63 | # We need to increment the max_epoch variable, because PyTorch Lightning will 64 | # resume training from the beginning of the next epoch if resuming from a mid-epoch checkpoint. 
65 | max_epochs = torch.load(args['checkpoint'])['epoch'] + 1 66 | model = LightningModel.load_from_checkpoint(checkpoint_path=args['checkpoint']) 67 | datamodule = DataModule(**args) 68 | 69 | logger = TensorBoardImageLogger(args['log_dir'], name='logs') 70 | lr_monitor = LearningRateMonitor(logging_interval='step') 71 | trainer = Trainer(gpus=1, 72 | resume_from_checkpoint=args['checkpoint'], 73 | max_epochs=max_epochs, 74 | max_steps=args['iterations'], 75 | checkpoint_callback=True, 76 | val_check_interval=args['val_interval'], 77 | logger=logger, 78 | callbacks=[lr_monitor]) 79 | 80 | trainer.fit(model, datamodule=datamodule) 81 | trainer.save_checkpoint("./model.ckpt") 82 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | AdaConv 2 | ============================== 3 | 4 | Unofficial PyTorch implementation of the Adaptive Convolution architecture for image style transfer from `"Adaptive Convolutions for Structure-Aware Style Transfer" `__. 5 | I tried to be as faithful as possible to the what the paper explains of the model, but not every training detail was in the paper so I had to make some choices regarding that. 6 | If something was unclear I tried to do what AdaIn does instead. Results are at the bottom of this page. 7 | 8 | 9 | `Direct link to the adaconv module. `_ 10 | 11 | `Direct link to the kernel predictor module. `_ 12 | 13 | Usage 14 | ----- 15 | 16 | The parameters in the commands below are the default parameters and can thus be omitted unless you want to use different options. 17 | Check the help option (``-h`` or ``--help``) for more information about all parameters. 18 | To train a new model: 19 | 20 | .. code:: 21 | 22 | python train.py --content ./data/MSCOCO/train2017 --style ./data/WikiArt/train 23 | 24 | 25 | To resume training from a checkpoint (.ckpt files are saved in the log directory): 26 | 27 | .. code:: 28 | 29 | python train.py --checkpoint 30 | 31 | 32 | To apply the model on a single style-content pair: 33 | 34 | .. code:: 35 | 36 | python stylize.py --content ./content.png --style ./style.png --output ./output.png --model ./model.ckpt 37 | 38 | 39 | To apply the model on every style-content combination in a folder and create a table of outputs: 40 | 41 | .. code:: 42 | 43 | python test.py --content-dir ./test_images/content --style-dir ./test_images/style --output-dir ./test_images/output --model ./model.ckpt 44 | 45 | 46 | Weights 47 | ======= 48 | `Pretrained weights can be downloaded here. `_ 49 | Move ``model.ckpt`` to the root directory of this project and run ``stylize.py`` or ``test.py``. 50 | You can finetune the model further by loading it as a checkpoint and increasing the number of iterations. 51 | To train for an additional 40k (200k - 160k) iterations: 52 | 53 | .. code:: 54 | 55 | python train.py --checkpoint ./model.ckpt --iterations 200000 56 | 57 | 58 | Data 59 | ==== 60 | 61 | The model is trained with the `MS COCO train2017 dataset `_ for content images and the `WikiArt train dataset `_ for style images. 62 | By default the content images should be placed in ``./data/MSCOCO/train2017`` and the style images in ``./data/WikiArt/train``. 63 | You can change these directories by passing arguments when running the script. 64 | The test style and content images in the ``./test_images`` folder are taken from the `official AdaIn repository `_. 
65 | 66 | 67 | Results 68 | ======= 69 | Judging from the results I'm not convinced everything is as the original authors did, but without an official repository it's hard to compare implementations. 70 | Results after training 160k iterations: 71 | 72 | .. image:: https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/master/imgs/results_table_256.jpg 73 | 74 | Comparison with reported results in the paper: 75 | 76 | .. image:: https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/master/imgs/results_comparison.jpg 77 | 78 | 79 | Architecture (from the original paper): 80 | --------------------------------------- 81 | 82 | .. image:: https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/master/imgs/arch_01.png 83 | 84 | .. image:: https://raw.githubusercontent.com/RElbers/ada-conv-pytorch/master/imgs/arch_02.png 85 | 86 | -------------------------------------------------------------------------------- /lib/vgg.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from torch import nn 4 | from torchvision import models 5 | from torchvision.transforms import transforms 6 | 7 | 8 | class VGGEncoder(nn.Module): 9 | def __init__(self, normalize=True, post_activation=True): 10 | super().__init__() 11 | 12 | if normalize: 13 | mean = [0.485, 0.456, 0.406] 14 | std = [0.229, 0.224, 0.225] 15 | self.normalize = transforms.Normalize(mean=mean, std=std) 16 | else: 17 | self.normalize = nn.Identity() 18 | 19 | if post_activation: 20 | layer_names = {'relu1_1', 'relu2_1', 'relu3_1', 'relu4_1'} 21 | else: 22 | layer_names = {'conv1_1', 'conv2_1', 'conv3_1', 'conv4_1'} 23 | blocks, block_names, scale_factor, out_channels = extract_vgg_blocks(models.vgg19(pretrained=True).features, 24 | layer_names) 25 | 26 | self.blocks = nn.ModuleList(blocks) 27 | self.block_names = block_names 28 | self.scale_factor = scale_factor 29 | self.out_channels = out_channels 30 | 31 | def forward(self, xs): 32 | xs = self.normalize(xs) 33 | 34 | features = [] 35 | for block in self.blocks: 36 | xs = block(xs) 37 | features.append(xs) 38 | 39 | return features 40 | 41 | def freeze(self): 42 | self.eval() 43 | for parameter in self.parameters(): 44 | parameter.requires_grad = False 45 | 46 | 47 | # For AdaIn, not used in AdaConv. 
48 | class VGGDecoder(nn.Module): 49 | def __init__(self): 50 | super().__init__() 51 | 52 | layers = [ 53 | self._conv(512, 256), 54 | nn.ReLU(), 55 | self._upsample(), 56 | 57 | self._conv(256, 256), 58 | nn.ReLU(), 59 | self._conv(256, 256), 60 | nn.ReLU(), 61 | self._conv(256, 256), 62 | nn.ReLU(), 63 | self._conv(256, 128), 64 | nn.ReLU(), 65 | self._upsample(), 66 | 67 | self._conv(128, 128), 68 | nn.ReLU(), 69 | self._conv(128, 64), 70 | nn.ReLU(), 71 | self._upsample(), 72 | 73 | self._conv(64, 64), 74 | nn.ReLU(), 75 | self._conv(64, 3), 76 | ] 77 | self.layers = nn.Sequential(*layers) 78 | 79 | def forward(self, content): 80 | ys = self.layers(content) 81 | return ys 82 | 83 | @staticmethod 84 | def _conv(in_channels, out_channels, kernel_size=3, padding_mode='reflect'): 85 | padding = (kernel_size - 1) // 2 86 | return nn.Conv2d(in_channels=in_channels, 87 | out_channels=out_channels, 88 | kernel_size=kernel_size, 89 | padding=padding, 90 | padding_mode=padding_mode) 91 | 92 | @staticmethod 93 | def _upsample(scale_factor=2, mode='nearest'): 94 | return nn.Upsample(scale_factor=scale_factor, mode=mode) 95 | 96 | 97 | def extract_vgg_blocks(layers, layer_names): 98 | blocks, current_block, block_names = [], [], [] 99 | scale_factor, out_channels = -1, -1 100 | depth_idx, relu_idx, conv_idx = 1, 1, 1 101 | for layer in layers: 102 | name = '' 103 | if isinstance(layer, nn.Conv2d): 104 | name = f'conv{depth_idx}_{conv_idx}' 105 | current_out_channels = layer.out_channels 106 | layer.padding_mode = 'reflect' 107 | conv_idx += 1 108 | elif isinstance(layer, nn.ReLU): 109 | name = f'relu{depth_idx}_{relu_idx}' 110 | layer = nn.ReLU(inplace=False) 111 | relu_idx += 1 112 | elif isinstance(layer, nn.AvgPool2d) or isinstance(layer, nn.MaxPool2d): 113 | name = f'pool{depth_idx}' 114 | depth_idx += 1 115 | conv_idx = 1 116 | relu_idx = 1 117 | else: 118 | warnings.warn(f' Unexpected layer type: {type(layer)}') 119 | 120 | current_block.append(layer) 121 | if name in layer_names: 122 | blocks.append(nn.Sequential(*current_block)) 123 | block_names.append(name) 124 | scale_factor = 1 * 2 ** (depth_idx - 1) 125 | out_channels = current_out_channels 126 | current_block = [] 127 | 128 | return blocks, block_names, scale_factor, out_channels 129 | -------------------------------------------------------------------------------- /lib/lightning/datamodule.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | import pytorch_lightning as pl 5 | from sklearn.model_selection import train_test_split 6 | from torch import Tensor 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | 10 | from lib import dataset 11 | from lib.dataset import StylizationDataset, files_in, EndlessDataset 12 | 13 | 14 | class DataModule(pl.LightningDataModule): 15 | @staticmethod 16 | def add_argparse_args(parent_parser): 17 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 18 | parser.add_argument('--content', type=str, default='./data/MSCOCO/train2017', 19 | help='Directory with content images.') 20 | parser.add_argument('--style', type=str, default='./data/WikiArt/train', 21 | help='Directory with style images.') 22 | parser.add_argument('--test-content', type=str, default='./test_images/content', 23 | help='Directory with test content images (or path to single image). 
If not set, takes 5 random train content images.') 24 | parser.add_argument('--test-style', type=str, default='./test_images/style', 25 | help='Directory with test style images (or path to single image). If not set, takes 5 random train style images.') 26 | parser.add_argument('--batch-size', type=int, default=8, 27 | help='Training batch size.') 28 | 29 | return parser 30 | 31 | def __init__(self, content, style, batch_size, test_content=None, test_style=None, **_): 32 | super().__init__() 33 | if not Path(content).exists(): 34 | raise Exception(f'Path used for content images does not exist: "{Path(content)}"') 35 | if not Path(style).exists(): 36 | raise Exception(f'Path used for style images does not exist: "{Path(style)}"') 37 | 38 | content_files, test_content_files = self.get_files(content, test_content, batch_size) 39 | style_files, test_style_files = self.get_files(style, test_style, batch_size) 40 | 41 | train_transforms = self.train_transforms() 42 | self.train_dataset = EndlessDataset(content_files, style_files, 43 | style_transform=train_transforms['style'], 44 | content_transform=train_transforms['content']) 45 | 46 | test_transforms = self.test_transforms() 47 | self.test_dataset = StylizationDataset(test_content_files, test_style_files, 48 | style_transform=test_transforms['style'], 49 | content_transform=test_transforms['content']) 50 | self.batch_size = batch_size 51 | 52 | def train_transforms(self): 53 | return { 54 | 'content': transforms.Compose([ 55 | transforms.Resize(size=(512, 512)), 56 | transforms.RandomCrop(256), 57 | transforms.ToTensor(), 58 | ]), 59 | 'style': transforms.Compose([ 60 | transforms.Resize(size=(512, 512)), 61 | transforms.RandomCrop(256), 62 | transforms.ToTensor(), 63 | ]) 64 | } 65 | 66 | def test_transforms(self): 67 | return { 68 | 'content': transforms.Compose([ 69 | transforms.CenterCrop(256), 70 | dataset.content_transforms(), 71 | ]), 72 | 'style': dataset.style_transforms(), 73 | } 74 | 75 | def train_dataloader(self): 76 | return DataLoader(self.train_dataset, batch_size=self.batch_size) 77 | 78 | def val_dataloader(self): 79 | return DataLoader(self.test_dataset, batch_size=1) 80 | 81 | def test_dataloader(self): 82 | return DataLoader(self.test_dataset, batch_size=1) 83 | 84 | def transfer_batch_to_device(self, batch, device): 85 | for k, v in batch.items(): 86 | if isinstance(v, Tensor): 87 | batch[k] = v.to(device) 88 | return batch 89 | 90 | def prepare_data(self): 91 | pass 92 | 93 | def setup(self, stage=None): 94 | pass 95 | 96 | @staticmethod 97 | def get_files(train_path, test_path, test_size=5): 98 | train_files = files_in(train_path) 99 | 100 | if test_path is None: 101 | train_files, test_files = train_test_split(train_files, test_size=test_size) 102 | else: 103 | if Path(test_path).is_dir(): 104 | test_files = files_in(test_path) 105 | else: 106 | test_files = [test_path] 107 | return train_files, test_files 108 | -------------------------------------------------------------------------------- /lib/lightning/lightningmodel.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from math import sqrt 3 | from statistics import mean 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import LambdaLR 9 | from torchvision.utils import make_grid 10 | 11 | from lib.adaconv.adaconv_model import AdaConvModel 12 | from lib.adain.adain_model import AdaINModel 13 | from lib.loss import 
MomentMatchingStyleLoss, GramStyleLoss, CMDStyleLoss, MSEContentLoss 14 | 15 | 16 | class LightningModel(pl.LightningModule): 17 | @staticmethod 18 | def add_argparse_args(parent_parser): 19 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 20 | 21 | # Add params of other models 22 | parser = AdaConvModel.add_argparse_args(parser) 23 | parser = AdaINModel.add_argparse_args(parser) 24 | parser.add_argument('--model-type', type=str, default='adaconv', choices=['adain', 'adaconv']) 25 | 26 | # Losses 27 | # mm = Moment Matching, gram = Gram matrix based, cmd = Central Moment Discrepancy 28 | parser.add_argument('--style-loss', type=str, default='mm', choices=['mm', 'gram', 'cmd']) 29 | parser.add_argument('--style-weight', type=float, default=10.0) 30 | parser.add_argument('--content-loss', type=str, default='mse', choices=['mse']) 31 | parser.add_argument('--content-weight', type=float, default=1.0) 32 | 33 | # Optimizer 34 | parser.add_argument('--lr', type=float, default=0.0001) 35 | parser.add_argument('--lr-decay', type=float, default=0.00005) 36 | return parser 37 | 38 | def __init__(self, 39 | model_type, 40 | alpha, 41 | style_size, style_channels, kernel_size, 42 | style_loss, style_weight, 43 | content_loss, content_weight, 44 | lr, lr_decay, 45 | **_): 46 | super().__init__() 47 | self.save_hyperparameters() 48 | 49 | self.lr = lr 50 | self.lr_decay = lr_decay 51 | self.style_weight = style_weight 52 | self.content_weight = content_weight 53 | 54 | # Style loss 55 | if style_loss == 'mm': 56 | self.style_loss = MomentMatchingStyleLoss() 57 | elif style_loss == 'gram': 58 | self.style_loss = GramStyleLoss() 59 | elif style_loss == 'cmd': 60 | self.style_loss = CMDStyleLoss() 61 | else: 62 | raise ValueError('style_loss') 63 | 64 | # Content loss 65 | if content_loss == 'mse': 66 | self.content_loss = MSEContentLoss() 67 | else: 68 | raise ValueError('content_loss') 69 | 70 | # Model type 71 | if model_type == 'adain': 72 | self.model = AdaINModel(alpha) 73 | elif model_type == 'adaconv': 74 | self.model = AdaConvModel(style_size, style_channels, kernel_size) 75 | else: 76 | raise ValueError('model_type') 77 | 78 | def forward(self, content, style, return_embeddings=False): 79 | return self.model(content, style, return_embeddings) 80 | 81 | def training_step(self, batch, batch_idx): 82 | return self.shared_step(batch, 'train') 83 | 84 | def validation_step(self, batch, batch_idx): 85 | return self.shared_step(batch, 'val') 86 | 87 | def shared_step(self, batch, step): 88 | content, style = batch['content'], batch['style'] 89 | output, embeddings = self.model(content, style, return_embeddings=True) 90 | content_loss, style_loss = self.loss(embeddings) 91 | 92 | # Log metrics 93 | self.log(rf'{step}/loss_style', style_loss.item(), prog_bar=step == 'train') 94 | self.log(rf'{step}/loss_content', content_loss.item(), prog_bar=step == 'train') 95 | 96 | # Return output only for validation step 97 | if step == 'val': 98 | return { 99 | 'loss': content_loss + style_loss, 100 | 'output': output, 101 | } 102 | return content_loss + style_loss 103 | 104 | def validation_epoch_end(self, outputs): 105 | if self.global_step == 0: 106 | return 107 | 108 | with torch.no_grad(): 109 | imgs = [x['output'] for x in outputs] 110 | imgs = [img for triple in imgs for img in triple] 111 | nrow = int(sqrt(len(imgs))) 112 | grid = make_grid(imgs, nrow=nrow, padding=0) 113 | logger = self.logger.experiment 114 | logger.add_image(rf'val_img', grid, global_step=self.global_step + 1) 115 | 116 
| def loss(self, embeddings): 117 | # Content 118 | content_loss = self.content_loss(embeddings['content'][-1], embeddings['output'][-1]) 119 | 120 | # Style 121 | style_loss = [] 122 | for (style_features, output_features) in zip(embeddings['style'], embeddings['output']): 123 | style_loss.append(self.style_loss(style_features, output_features)) 124 | style_loss = sum(style_loss) 125 | 126 | return self.content_weight * content_loss, self.style_weight * style_loss 127 | 128 | def configure_optimizers(self): 129 | optimizer = Adam(self.parameters(), lr=self.lr) 130 | 131 | def lr_lambda(iter): 132 | return 1 / (1 + 0.0002 * iter) 133 | 134 | lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) 135 | 136 | return { 137 | 'optimizer': optimizer, 138 | 'lr_scheduler': { 139 | "scheduler": lr_scheduler, 140 | "interval": "step", 141 | "frequency": 1, 142 | }, 143 | } 144 | -------------------------------------------------------------------------------- /lib/adaconv/adaconv_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torchinfo 4 | from torch import nn 5 | 6 | from lib.adaconv.adaconv import AdaConv2d 7 | from lib.adaconv.kernel_predictor import KernelPredictor 8 | from lib.vgg import VGGEncoder 9 | 10 | 11 | class AdaConvModel(nn.Module): 12 | @staticmethod 13 | def add_argparse_args(parent_parser): 14 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 15 | parser.add_argument('--style-size', type=int, default=256, help='Size of the input style image.') 16 | parser.add_argument('--style-channels', type=int, default=512, help='Number of channels for the style descriptor.') 17 | parser.add_argument('--kernel-size', type=int, default=3, help='The size of the predicted kernels.') 18 | return parser 19 | 20 | def __init__(self, style_size, style_channels, kernel_size): 21 | super().__init__() 22 | self.encoder = VGGEncoder() 23 | 24 | style_in_shape = (self.encoder.out_channels, style_size // self.encoder.scale_factor, style_size // self.encoder.scale_factor) 25 | style_out_shape = (style_channels, kernel_size, kernel_size) 26 | self.style_encoder = GlobalStyleEncoder(in_shape=style_in_shape, out_shape=style_out_shape) 27 | self.decoder = AdaConvDecoder(style_channels=style_channels, kernel_size=kernel_size) 28 | 29 | def forward(self, content, style, return_embeddings=False): 30 | self.encoder.freeze() 31 | 32 | # Encode -> Decode 33 | content_embeddings, style_embeddings = self._encode(content, style) 34 | output = self._decode(content_embeddings[-1], style_embeddings[-1]) 35 | 36 | # Return embeddings if training 37 | if return_embeddings: 38 | output_embeddings = self.encoder(output) 39 | embeddings = { 40 | 'content': content_embeddings, 41 | 'style': style_embeddings, 42 | 'output': output_embeddings 43 | } 44 | return output, embeddings 45 | else: 46 | return output 47 | 48 | def _encode(self, content, style): 49 | content_embeddings = self.encoder(content) 50 | style_embeddings = self.encoder(style) 51 | return content_embeddings, style_embeddings 52 | 53 | def _decode(self, content_embedding, style_embedding): 54 | style_embedding = self.style_encoder(style_embedding) 55 | output = self.decoder(content_embedding, style_embedding) 56 | return output 57 | 58 | 59 | class AdaConvDecoder(nn.Module): 60 | def __init__(self, style_channels, kernel_size): 61 | super().__init__() 62 | self.style_channels = style_channels 63 | self.kernel_size = kernel_size 64 | 65 | # Inverted 
VGG with first conv in each scale replaced with AdaConv 66 | group_div = [1, 2, 4, 8] 67 | n_convs = [1, 4, 2, 2] 68 | self.layers = nn.ModuleList([ 69 | *self._make_layers(512, 256, group_div=group_div[0], n_convs=n_convs[0]), 70 | *self._make_layers(256, 128, group_div=group_div[1], n_convs=n_convs[1]), 71 | *self._make_layers(128, 64, group_div=group_div[2], n_convs=n_convs[2]), 72 | *self._make_layers(64, 3, group_div=group_div[3], n_convs=n_convs[3], final_act=False, upsample=False)]) 73 | 74 | def forward(self, content, w_style): 75 | # Checking types is a bit hacky, but it works well. 76 | for module in self.layers: 77 | if isinstance(module, KernelPredictor): 78 | w_spatial, w_pointwise, bias = module(w_style) 79 | elif isinstance(module, AdaConv2d): 80 | content = module(content, w_spatial, w_pointwise, bias) 81 | else: 82 | content = module(content) 83 | return content 84 | 85 | def _make_layers(self, in_channels, out_channels, group_div, n_convs, final_act=True, upsample=True): 86 | n_groups = in_channels // group_div 87 | 88 | layers = [] 89 | for i in range(n_convs): 90 | last = i == n_convs - 1 91 | out_channels_ = out_channels if last else in_channels 92 | if i == 0: 93 | layers += [ 94 | KernelPredictor(in_channels, in_channels, 95 | n_groups=n_groups, 96 | style_channels=self.style_channels, 97 | kernel_size=self.kernel_size), 98 | AdaConv2d(in_channels, out_channels_, n_groups=n_groups)] 99 | else: 100 | layers.append(nn.Conv2d(in_channels, out_channels_, 3, 101 | padding=1, padding_mode='reflect')) 102 | 103 | if not last or final_act: 104 | layers.append(nn.ReLU()) 105 | 106 | if upsample: 107 | layers.append(nn.Upsample(scale_factor=2, mode='nearest')) 108 | return layers 109 | 110 | 111 | class GlobalStyleEncoder(nn.Module): 112 | def __init__(self, in_shape, out_shape): 113 | super().__init__() 114 | self.in_shape = in_shape 115 | self.out_shape = out_shape 116 | channels = in_shape[0] 117 | 118 | self.downscale = nn.Sequential( 119 | nn.Conv2d(channels, channels, 3, padding=1, padding_mode='reflect'), 120 | nn.LeakyReLU(), 121 | nn.AvgPool2d(2, 2), 122 | # 123 | nn.Conv2d(channels, channels, 3, padding=1, padding_mode='reflect'), 124 | nn.LeakyReLU(), 125 | nn.AvgPool2d(2, 2), 126 | # 127 | nn.Conv2d(channels, channels, 3, padding=1, padding_mode='reflect'), 128 | nn.LeakyReLU(), 129 | nn.AvgPool2d(2, 2), 130 | ) 131 | 132 | in_features = self.in_shape[0] * (self.in_shape[1] // 8) * self.in_shape[2] // 8 133 | out_features = self.out_shape[0] * self.out_shape[1] * self.out_shape[2] 134 | self.fc = nn.Linear(in_features, out_features) 135 | 136 | def forward(self, xs): 137 | ys = self.downscale(xs) 138 | ys = ys.reshape(len(xs), -1) 139 | 140 | w = self.fc(ys) 141 | w = w.reshape(len(xs), self.out_shape[0], self.out_shape[1], self.out_shape[2]) 142 | return w 143 | --------------------------------------------------------------------------------
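
Usage note: the snippet below is a minimal sketch of how ``KernelPredictor`` (``lib/adaconv/kernel_predictor.py``) and ``AdaConv2d`` (``lib/adaconv/adaconv.py``) fit together, using channel counts that mirror the last stage of ``AdaConvDecoder`` (64 -> 3 channels, ``group_div=8``). It assumes the repository root is on ``PYTHONPATH``; the batch size and spatial sizes are illustrative only, and in the full model the style descriptor ``w`` comes from ``GlobalStyleEncoder`` rather than ``torch.randn``.

.. code:: python

    import torch

    from lib.adaconv.adaconv import AdaConv2d
    from lib.adaconv.kernel_predictor import KernelPredictor

    # Channel sizes mirroring the last decoder stage (64 -> 3, group_div=8).
    style_channels, kernel_size = 512, 3
    in_channels, out_channels = 64, 3
    n_groups = in_channels // 8

    # Predicts per-sample spatial weights, pointwise weights and biases
    # from the global style descriptor w.
    predictor = KernelPredictor(in_channels, in_channels,
                                n_groups=n_groups,
                                style_channels=style_channels,
                                kernel_size=kernel_size)
    adaconv = AdaConv2d(in_channels, out_channels, n_groups=n_groups)

    content = torch.randn(2, in_channels, 32, 32)                  # content feature map
    w = torch.randn(2, style_channels, kernel_size, kernel_size)   # style descriptor

    w_spatial, w_pointwise, bias = predictor(w)
    output = adaconv(content, w_spatial, w_pointwise, bias)        # shape: (2, 3, 32, 32)

``AdaConvModel`` repeats this predictor/conv pairing at every decoder scale, with the remaining convolutions in each scale being ordinary ``nn.Conv2d`` layers.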