├── .gitignore ├── LICENSE ├── README.md ├── function.py ├── net.py ├── requirements.txt ├── setup.py └── stylize.py /.gitignore: -------------------------------------------------------------------------------- 1 | models/*.t7 2 | models/*.pth 3 | models/*.py 4 | __pycache__* 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Most files in this directory (code/) are either directly copied from the 2 | pytorch-AdaIN repository (https://github.com/naoto0804/pytorch-AdaIN) 3 | or adapted slightly. The following license applies to these files: 4 | 5 | MIT License 6 | 7 | Copyright (c) 2018 Naoto Inoue 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stylize-datasets 2 | This repository contains code for stylizing arbitrary image datasets using [AdaIN](https://arxiv.org/abs/1703.06868). The code is a generalization of Robert Geirhos' [Stylized-ImageNet](https://github.com/rgeirhos/Stylized-ImageNet) code, which is tailored to stylizing ImageNet. Everything in this repository is based on naoto0804's [pytorch-AdaIN](https://github.com/naoto0804/pytorch-AdaIN) implementation. 3 | 4 | Given an image dataset, the script creates the specified number of stylized versions of every image while keeping the directory structure and naming scheme intact (useful for existing data loaders or if directory names include class annotations). 5 | 6 | Feel free to open an issue if you have any questions. 7 | 8 | ## Usage 9 | - Dependencies: 10 | - python >= 3.6 11 | - Pillow 12 | - torch 13 | - torchvision 14 | - tqdm 15 | - Download the models: 16 | - download the models (vgg/decoder) manually from [pytorch-AdaIN](https://github.com/naoto0804/pytorch-AdaIN) and move both files to the `models/` directory 17 | - Get style images: Download train.zip from [Kaggle's painter-by-numbers dataset](https://www.kaggle.com/c/painter-by-numbers/data) 18 | - To stylize a dataset, run `python stylize.py`. 
19 | 20 | Arguments: 21 | - `--content-dir <directory>` the top-level directory of the content image dataset (mandatory) 22 | - `--style-dir <directory>` the top-level directory of the style images (mandatory) 23 | - `--output-dir <directory>` the directory where the stylized dataset will be stored (optional, default: `output/`) 24 | - `--num-styles <N>` number of stylizations to create for each content image (optional, default: `1`) 25 | - `--alpha <value>` weight that controls the strength of stylization, should be between 0 and 1 (optional, default: `1`) 26 | - `--extensions <ext1> <ext2> ...` list of image extensions to scan the style and content directories for (optional, default: `png, jpeg, jpg`). Note: this is case sensitive, `--extensions jpg` will not scan for files ending in `.JPG`. Image types must be compatible with PIL's `Image.open()` ([Documentation](https://pillow.readthedocs.io/en/5.1.x/handbook/image-file-formats.html)) 27 | - `--content-size <N>` minimum size for content images, resulting in scaling of the shorter side of the content image to `N` (optional, default: `0`). Set this to 0 to keep the original image dimensions. 28 | - `--style-size <N>` minimum size for style images, resulting in scaling of the shorter side of the style image to `N` (optional, default: `512`). Set this to 0 to keep the original image dimensions (for large style images, this will result in high (GPU) memory consumption). 29 | - `--crop <N>` size of the center crop applied to the content image in order to create a square image (optional, default: `0`). Setting this to 0 disables cropping. 30 | 31 | Here is an example call: 32 | 33 | ``` 34 | python3 stylize.py --content-dir '/home/username/stylize-datasets/images/' --style-dir '/home/username/stylize-datasets/train/' --num-styles 10 --content-size 0 --style-size 256 35 | ``` 36 | 37 | ## Citation 38 | 39 | If you use this code, please consider citing: 40 | ``` 41 | @article{michaelis2019dragon, 42 | title={Benchmarking Robustness in Object Detection: 43 | Autonomous Driving when Winter is Coming}, 44 | author={Michaelis, Claudio and Mitzkus, Benjamin and 45 | Geirhos, Robert and Rusak, Evgenia and 46 | Bringmann, Oliver and Ecker, Alexander S. and 47 | Bethge, Matthias and Brendel, Wieland}, 48 | journal={arXiv preprint arXiv:1907.07484}, 49 | year={2019} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 | # eps is a small value added to the variance to avoid divide-by-zero. 
6 | size = feat.data.size() 7 | assert (len(size) == 4) 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def adaptive_instance_normalization(content_feat, style_feat): 16 | assert (content_feat.data.size()[:2] == style_feat.data.size()[:2]) 17 | size = content_feat.data.size() 18 | style_mean, style_std = calc_mean_std(style_feat) 19 | content_mean, content_std = calc_mean_std(content_feat) 20 | 21 | normalized_feat = (content_feat - content_mean.expand( 22 | size)) / content_std.expand(size) 23 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 24 | 25 | 26 | def _calc_feat_flatten_mean_std(feat): 27 | # takes 3D feat (C, H, W), return mean and std of array within channels 28 | assert (feat.size()[0] == 3) 29 | assert (isinstance(feat, torch.FloatTensor)) 30 | feat_flatten = feat.view(3, -1) 31 | mean = feat_flatten.mean(dim=-1, keepdim=True) 32 | std = feat_flatten.std(dim=-1, keepdim=True) 33 | return feat_flatten, mean, std 34 | 35 | 36 | def _mat_sqrt(x): 37 | U, D, V = torch.svd(x) 38 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 39 | 40 | 41 | def coral(source, target): 42 | # assume both source and target are 3D array (C, H, W) 43 | # Note: flatten -> f 44 | 45 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 46 | source_f_norm = (source_f - source_f_mean.expand_as( 47 | source_f)) / source_f_std.expand_as(source_f) 48 | source_f_cov_eye = \ 49 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 50 | 51 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 52 | target_f_norm = (target_f - target_f_mean.expand_as( 53 | target_f)) / target_f_std.expand_as(target_f) 54 | target_f_cov_eye = \ 55 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 56 | 57 | source_f_norm_transfer = torch.mm( 58 | _mat_sqrt(target_f_cov_eye), 59 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), 60 | source_f_norm) 61 | ) 62 | 63 | source_f_transfer = source_f_norm_transfer * \ 64 | target_f_std.expand_as(source_f_norm) + \ 65 | target_f_mean.expand_as(source_f_norm) 66 | 67 | return source_f_transfer.view(source.size()) 68 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | 4 | from function import adaptive_instance_normalization as adain 5 | from function import calc_mean_std 6 | 7 | decoder = nn.Sequential( 8 | nn.ReflectionPad2d((1, 1, 1, 1)), 9 | nn.Conv2d(512, 256, (3, 3)), 10 | nn.ReLU(), 11 | nn.Upsample(scale_factor=2), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(256, 256, (3, 3)), 14 | nn.ReLU(), 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(256, 256, (3, 3)), 17 | nn.ReLU(), 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(256, 256, (3, 3)), 20 | nn.ReLU(), 21 | nn.ReflectionPad2d((1, 1, 1, 1)), 22 | nn.Conv2d(256, 128, (3, 3)), 23 | nn.ReLU(), 24 | nn.Upsample(scale_factor=2), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(128, 128, (3, 3)), 27 | nn.ReLU(), 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(128, 64, (3, 3)), 30 | nn.ReLU(), 31 | nn.Upsample(scale_factor=2), 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(64, 64, (3, 3)), 34 | nn.ReLU(), 35 | 
nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(64, 3, (3, 3)), 37 | ) 38 | 39 | vgg = nn.Sequential( 40 | nn.Conv2d(3, 3, (1, 1)), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(3, 64, (3, 3)), 43 | nn.ReLU(), # relu1-1 44 | nn.ReflectionPad2d((1, 1, 1, 1)), 45 | nn.Conv2d(64, 64, (3, 3)), 46 | nn.ReLU(), # relu1-2 47 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(64, 128, (3, 3)), 50 | nn.ReLU(), # relu2-1 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(128, 128, (3, 3)), 53 | nn.ReLU(), # relu2-2 54 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(128, 256, (3, 3)), 57 | nn.ReLU(), # relu3-1 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(256, 256, (3, 3)), 60 | nn.ReLU(), # relu3-2 61 | nn.ReflectionPad2d((1, 1, 1, 1)), 62 | nn.Conv2d(256, 256, (3, 3)), 63 | nn.ReLU(), # relu3-3 64 | nn.ReflectionPad2d((1, 1, 1, 1)), 65 | nn.Conv2d(256, 256, (3, 3)), 66 | nn.ReLU(), # relu3-4 67 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 68 | nn.ReflectionPad2d((1, 1, 1, 1)), 69 | nn.Conv2d(256, 512, (3, 3)), 70 | nn.ReLU(), # relu4-1, this is the last layer used 71 | nn.ReflectionPad2d((1, 1, 1, 1)), 72 | nn.Conv2d(512, 512, (3, 3)), 73 | nn.ReLU(), # relu4-2 74 | nn.ReflectionPad2d((1, 1, 1, 1)), 75 | nn.Conv2d(512, 512, (3, 3)), 76 | nn.ReLU(), # relu4-3 77 | nn.ReflectionPad2d((1, 1, 1, 1)), 78 | nn.Conv2d(512, 512, (3, 3)), 79 | nn.ReLU(), # relu4-4 80 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 81 | nn.ReflectionPad2d((1, 1, 1, 1)), 82 | nn.Conv2d(512, 512, (3, 3)), 83 | nn.ReLU(), # relu5-1 84 | nn.ReflectionPad2d((1, 1, 1, 1)), 85 | nn.Conv2d(512, 512, (3, 3)), 86 | nn.ReLU(), # relu5-2 87 | nn.ReflectionPad2d((1, 1, 1, 1)), 88 | nn.Conv2d(512, 512, (3, 3)), 89 | nn.ReLU(), # relu5-3 90 | nn.ReflectionPad2d((1, 1, 1, 1)), 91 | nn.Conv2d(512, 512, (3, 3)), 92 | nn.ReLU() # relu5-4 93 | ) 94 | 95 | 96 | class Net(nn.Module): 97 | def __init__(self, encoder, decoder): 98 | super(Net, self).__init__() 99 | enc_layers = list(encoder.children()) 100 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 101 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 102 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 103 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 104 | self.decoder = decoder 105 | self.mse_loss = nn.MSELoss() 106 | 107 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 108 | def encode_with_intermediate(self, input): 109 | results = [input] 110 | for i in range(4): 111 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 112 | results.append(func(results[-1])) 113 | return results[1:] 114 | 115 | # extract relu4_1 from input image 116 | def encode(self, input): 117 | for i in range(4): 118 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 119 | return input 120 | 121 | def calc_content_loss(self, input, target): 122 | assert (input.data.size() == target.data.size()) 123 | assert (target.requires_grad is False) 124 | return self.mse_loss(input, target) 125 | 126 | def calc_style_loss(self, input, target): 127 | assert (input.data.size() == target.data.size()) 128 | assert (target.requires_grad is False) 129 | input_mean, input_std = calc_mean_std(input) 130 | target_mean, target_std = calc_mean_std(target) 131 | return self.mse_loss(input_mean, target_mean) + \ 132 | self.mse_loss(input_std, target_std) 133 | 134 | def forward(self, 
content, style): 135 | style_feats = self.encode_with_intermediate(style) 136 | t = adain(self.encode(content), style_feats[-1]) 137 | 138 | g_t = self.decoder(Variable(t.data, requires_grad=True)) 139 | g_t_feats = self.encode_with_intermediate(g_t) 140 | 141 | loss_c = self.calc_content_loss(g_t_feats[-1], t) 142 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 143 | for i in range(1, 4): 144 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 145 | return loss_c, loss_s 146 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requires python >= 3.6 (interpreter requirement, not a pip-installable package) 2 | Pillow 3 | torch 4 | torchvision 5 | tqdm -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='stylize-datasets', 5 | version='0.1.0', 6 | packages=[''], 7 | url='https://github.com/bethgelab/stylize-datasets', 8 | license='MIT License', 9 | author='', 10 | author_email='', 11 | description='This repository contains code for stylizing arbitrary image datasets using AdaIN. The code is a generalization of Robert Geirhos\' Stylized-ImageNet code, which is tailored to stylizing ImageNet. Everything in this repository is based on naoto0804\'s pytorch-AdaIN implementation. Given an image dataset, the script creates the specified number of stylized versions of every image while keeping the directory structure and naming scheme intact (useful for existing data loaders or if directory names include class annotations).' 12 | ) 13 | -------------------------------------------------------------------------------- /stylize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | from function import adaptive_instance_normalization 4 | import net 5 | from pathlib import Path 6 | from PIL import Image 7 | import random 8 | import torch 9 | import torch.nn as nn 10 | import torchvision.transforms 11 | from torchvision.utils import save_image 12 | from tqdm import tqdm 13 | 14 | parser = argparse.ArgumentParser(description='This script applies the AdaIN style transfer method to arbitrary datasets.') 15 | parser.add_argument('--content-dir', type=str, required=True, 16 | help='Directory path to a batch of content images') 17 | parser.add_argument('--style-dir', type=str, required=True, 18 | help='Directory path to a batch of style images') 19 | parser.add_argument('--output-dir', type=str, default='output', 20 | help='Directory to save the output images') 21 | parser.add_argument('--num-styles', type=int, default=1, help='Number of styles to \ 22 | create for each image (default: 1)') 23 | parser.add_argument('--alpha', type=float, default=1.0, 24 | help='The weight that controls the degree of \ 25 | stylization. 
Should be between 0 and 1') 26 | parser.add_argument('--extensions', nargs='+', type=str, default=['png', 'jpeg', 'jpg'], help='List of image extensions to scan the style and content directories for (case sensitive), default: png, jpeg, jpg') 27 | 28 | # Advanced options 29 | parser.add_argument('--content-size', type=int, default=0, 30 | help='New (minimum) size for the content image, \ 31 | keeping the original size if set to 0') 32 | parser.add_argument('--style-size', type=int, default=512, 33 | help='New (minimum) size for the style image, \ 34 | keeping the original size if set to 0') 35 | parser.add_argument('--crop', type=int, default=0, 36 | help='If set to anything other than 0, a center crop of this size will be applied to the content image \ 37 | after resizing in order to create a square image (default: 0)') 38 | 39 | # random.seed(131213) 40 | 41 | def input_transform(size, crop): 42 | transform_list = [] 43 | if size != 0: 44 | transform_list.append(torchvision.transforms.Resize(size)) 45 | if crop != 0: 46 | transform_list.append(torchvision.transforms.CenterCrop(crop)) 47 | transform_list.append(torchvision.transforms.ToTensor()) 48 | transform = torchvision.transforms.Compose(transform_list) 49 | return transform 50 | 51 | def style_transfer(vgg, decoder, content, style, alpha=1.0): 52 | assert (0.0 <= alpha <= 1.0) 53 | content_f = vgg(content) 54 | style_f = vgg(style) 55 | feat = adaptive_instance_normalization(content_f, style_f) 56 | feat = feat * alpha + content_f * (1 - alpha) 57 | return decoder(feat) 58 | 59 | def main(): 60 | args = parser.parse_args() 61 | 62 | # set content and style directories 63 | content_dir = Path(args.content_dir) 64 | style_dir = Path(args.style_dir) 65 | style_dir = style_dir.resolve() 66 | output_dir = Path(args.output_dir) 67 | output_dir = output_dir.resolve() 68 | assert style_dir.is_dir(), 'Style directory not found' 69 | 70 | # collect content files 71 | extensions = args.extensions 72 | assert len(extensions) > 0, 'No file extensions specified' 73 | content_dir = Path(content_dir) 74 | content_dir = content_dir.resolve() 75 | assert content_dir.is_dir(), 'Content directory not found' 76 | dataset = [] 77 | for ext in extensions: 78 | dataset += list(content_dir.rglob('*.' + ext)) 79 | 80 | assert len(dataset) > 0, 'No images with specified extensions found in content directory ' + str(content_dir) 81 | content_paths = sorted(dataset) 82 | print('Found %d content images in %s' % (len(content_paths), content_dir)) 83 | 84 | # collect style files 85 | styles = [] 86 | for ext in extensions: 87 | styles += list(style_dir.rglob('*.' 
+ ext)) 88 | 89 | assert len(styles) > 0, 'No images with specified extensions found in style directory ' + str(style_dir) 90 | styles = sorted(styles) 91 | print('Found %d style images in %s' % (len(styles), style_dir)) 92 | 93 | decoder = net.decoder 94 | vgg = net.vgg 95 | 96 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 97 | 98 | decoder.eval() 99 | vgg.eval() 100 | 101 | decoder.load_state_dict(torch.load('models/decoder.pth')) 102 | vgg.load_state_dict(torch.load('models/vgg_normalised.pth')) 103 | vgg = nn.Sequential(*list(vgg.children())[:31]) 104 | 105 | vgg.to(device) 106 | decoder.to(device) 107 | 108 | content_tf = input_transform(args.content_size, args.crop) 109 | style_tf = input_transform(args.style_size, 0) 110 | 111 | 112 | # disable decompression bomb errors 113 | Image.MAX_IMAGE_PIXELS = None 114 | skipped_imgs = [] 115 | 116 | # actual style transfer as in AdaIN 117 | with tqdm(total=len(content_paths)) as pbar: 118 | for content_path in content_paths: 119 | try: 120 | content_img = Image.open(content_path).convert('RGB') 121 | for style_path in random.sample(styles, args.num_styles): 122 | style_img = Image.open(style_path).convert('RGB') 123 | 124 | content = content_tf(content_img) 125 | style = style_tf(style_img) 126 | style = style.to(device).unsqueeze(0) 127 | content = content.to(device).unsqueeze(0) 128 | with torch.no_grad(): 129 | output = style_transfer(vgg, decoder, content, style, 130 | args.alpha) 131 | output = output.cpu() 132 | 133 | rel_path = content_path.relative_to(content_dir) 134 | out_dir = output_dir.joinpath(rel_path.parent) 135 | 136 | # create directory structure if it does not exist 137 | if not out_dir.is_dir(): 138 | out_dir.mkdir(parents=True) 139 | 140 | content_name = content_path.stem 141 | style_name = style_path.stem 142 | out_filename = content_name + '-stylized-' + style_name + content_path.suffix 143 | output_name = out_dir.joinpath(out_filename) 144 | 145 | save_image(output, output_name, padding=0) # default image padding is 2 146 | style_img.close() 147 | content_img.close() 148 | except OSError as e: 149 | print('Skipping stylization of %s due to an error: %s' % (content_path, e)) 150 | skipped_imgs.append(content_path) 151 | continue 152 | except RuntimeError as e: 153 | print('Skipping stylization of %s due to an error: %s' % (content_path, e)) 154 | skipped_imgs.append(content_path) 155 | continue 156 | finally: 157 | pbar.update(1) 158 | 159 | if(len(skipped_imgs) > 0): 160 | with open(output_dir.joinpath('skipped_imgs.txt'), 'w') as f: 161 | for item in skipped_imgs: 162 | f.write("%s\n" % item) 163 | 164 | if __name__ == '__main__': 165 | main() 166 | --------------------------------------------------------------------------------
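For quick reference, `style_transfer()` in stylize.py computes the AdaIN feature `sigma(style_f) * (content_f - mu(content_f)) / sigma(content_f) + mu(style_f)` at the relu4-1 layer of the truncated VGG encoder and blends it with the content features by `alpha` before decoding. The snippet below is a minimal sketch (not part of the repository) of how the same pieces can be combined to stylize a single content/style pair outside the batch loop of `stylize.py`. It assumes the pretrained weights have been placed in `models/` as described in the README; `content.jpg`, `style.jpg`, and `stylized.jpg` are placeholder paths.

```python
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

import net
from stylize import style_transfer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the pretrained decoder and encoder exactly as stylize.py does
decoder = net.decoder
decoder.load_state_dict(torch.load('models/decoder.pth'))
vgg = net.vgg
vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
vgg = nn.Sequential(*list(vgg.children())[:31])  # keep layers up to relu4-1
vgg.to(device).eval()
decoder.to(device).eval()

# load one content and one style image (placeholder paths)
to_tensor = transforms.ToTensor()
content = to_tensor(Image.open('content.jpg').convert('RGB')).unsqueeze(0).to(device)
style = to_tensor(Image.open('style.jpg').convert('RGB')).unsqueeze(0).to(device)

# AdaIN style transfer at full stylization strength
with torch.no_grad():
    output = style_transfer(vgg, decoder, content, style, alpha=1.0)
save_image(output.cpu(), 'stylized.jpg', padding=0)
```

As in `stylize.py`, large images should be resized beforehand (see `--content-size` / `--style-size`), since memory consumption grows with image size.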