├── Avatar-Net.ipynb ├── README.md ├── main.py ├── network.py ├── sample_images ├── content │ ├── avril.jpg │ ├── bair.jpg │ ├── blonde_girl.jpg │ ├── brad_pitt.jpg │ ├── chicago.jpg │ ├── cornell.jpg │ ├── flowers.jpg │ ├── golden_gate.jpg │ ├── green_eye.jpg │ ├── home_alone.jpg │ ├── karya.jpg │ ├── lenna.jpg │ ├── modern.jpg │ ├── newyork.jpg │ ├── sailboat.jpg │ ├── stata.jpg │ ├── taj_mahal.jpg │ ├── tubingen.jpg │ └── venice-boat.jpg ├── mask │ ├── blonde_girl_mask1.jpg │ └── blonde_girl_mask2.jpg ├── style │ ├── abstraction.jpg │ ├── brick.jpg │ ├── candy.jpg │ ├── mondrian.jpg │ ├── starry_night.jpg │ └── yellow_sunset.jpg └── test_results │ ├── content_style_interpolation.jpg │ ├── masked_stylized_image.jpg │ ├── multiple_style_interpolation.jpg │ ├── patch_size_variation.jpg │ ├── patch_stride_variation.jpg │ ├── stylization.jpg │ └── training_loss.png ├── style_decorator.py ├── test.py ├── train.py ├── utils.py └── wct.py /README.md: -------------------------------------------------------------------------------- 1 | Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration 2 | --- 3 | 4 | **Unofficial PyTorch Implementation of Avatar-Net** 5 | 6 | **Reference**: [Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration, CVPR2018](https://arxiv.org/abs/1805.03857) 7 | 8 | 9 | ![result_image](./sample_images/test_results/stylization.jpg) 10 | 11 | 12 | Requirements 13 | -- 14 | * torch (version: 1.2.0) 15 | * torchvision (version: 0.4.0) 16 | * Pillow (version: 6.1.0) 17 | * matplotlib (version: 3.1.1) 18 | 19 | Download 20 | -- 21 | * The trained model can be downloaded through the [releases](https://github.com/tyui592/Avatar-Net_Pytorch/releases/download/v0.2/check_point.pth). 22 | * [MSCOCO train2014](http://cocodataset.org/#download) is needed to train the network. 23 | 24 | Usage 25 | -- 26 | 27 | ### Arguments 28 | * `--gpu-no`: GPU device number (-1: cpu, 0~N: GPU) 29 | * `--train`: Flag for network training (default: False) 30 | * `--content-dir`: Path of the content image dataset for training 31 | * `--imsize`: Size for resizing input images (resizes the shorter side of the image) 32 | * `--cropsize`: Size for cropping input images (crops the image into a square) 33 | * `--cencrop`: Flag for cropping the center region of the image (default: random crop) 34 | * `--check-point`: Checkpoint path for loading the trained network 35 | * `--content`: Content image path to evaluate the network 36 | * `--style`: Style image path to evaluate the network 37 | * `--mask`: Mask image path for masked stylization 38 | * `--style-strength`: Content vs. style interpolation weight (1.0: style, 0.0: content, default: 1.0) 39 | * `--interpolation-weights`: Weights for multiple style interpolation (see the blending note in the test examples section) 40 | * `--patch-size`: Patch size of the style decorator (default: 3) 41 | * `--patch-stride`: Patch stride of the style decorator (default: 1) 42 | 43 | 44 | ### Train example script 45 | 46 | ``` 47 | python main.py --train --gpu-no 0 --imsize 512 --cropsize 256 --content-dir ./coco2014/ --save-path ./trained_models/ 48 | ``` 49 | 50 | ![training_loss](./sample_images/test_results/training_loss.png) 51 | 52 | 53 | ### Test example scripts and images 54 | * These figures are generated in the [jupyter notebook](Avatar-Net.ipynb). You can reproduce them yourself. 
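For reference, `--style-strength` is applied inside the style decorator as a linear interpolation between the content feature and the stylized feature, while `--interpolation-weights` and `--mask` are combined multiplicatively: each style's decorated feature is scaled by its weight and its (resized) mask, and the scaled features are summed before decoding (see `network.py`). A minimal sketch of that blending with toy tensors (not the repository's API, just the arithmetic it performs):

```python
import torch

# toy "decorated" features, one per style, shape (B, C, H, W)
f1 = torch.randn(1, 4, 8, 8)
f2 = torch.randn(1, 4, 8, 8)

w1, w2 = 0.5, 0.5                      # --interpolation-weights 0.5 0.5
m1 = torch.zeros(1, 1, 8, 8)
m1[..., :4] = 1.0                      # mask covering the left half
m2 = 1.0 - m1                          # complementary mask (--mask)

# per-style weighting and summation, as done over styles in network.py
blended = f1 * w1 * m1 + f2 * w2 * m2
```

With disjoint masks each region is driven by a single style, which is why the masked example below uses `--interpolation-weights 1.0 1.0`.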
55 | 56 | #### Generate the stylized image with a single style (Content-style interpolation) 57 | 58 | ``` 59 | python main.py --check-point ./trained_models/check_point.pth --imsize 512 --cropsize 512 --cencrop --content ./sample_images/content/blonde_girl.jpg --style ./sample_images/style/mondrian.jpg --style-strength 1.0 60 | ``` 61 | 62 | ![content_style_interpolation](./sample_images/test_results/content_style_interpolation.jpg) 63 | 64 | #### Generate the stylized image with multiple styles 65 | 66 | ``` 67 | python main.py --check-point ./trained_models/check_point.pth --imsize 512 --cropsize 512 --content ./sample_images/content/blonde_girl.jpg --style ./sample_images/style/mondrian.jpg ./sample_images/style/abstraction.jpg --interpolation-weights 0.5 0.5 68 | ``` 69 | 70 | ![multiple_style_interpolation](./sample_images/test_results/multiple_style_interpolation.jpg) 71 | 72 | 73 | #### Generate the stylized image with multiple styles and masks 74 | 75 | ``` 76 | python main.py --check-point ./trained_models/check_point.pth --imsize 512 --cropsize 512 --content ./sample_images/content/blonde_girl.jpg --style ./sample_images/style/mondrian.jpg ./sample_images/style/abstraction.jpg --mask ./sample_images/mask/blonde_girl_mask1.jpg ./sample_images/mask/blonde_girl_mask2.jpg --interpolation-weights 1.0 1.0 77 | ``` 78 | 79 | ![masked_stylization](./sample_images/test_results/masked_stylized_image.jpg) 80 | 81 | 82 | #### Generate the stylized image with varying patch size 83 | 84 | ``` 85 | python main.py --check-point ./trained_models/check_point.pth --imsize 512 --cropsize 512 --content ./sample_images/content/blonde_girl.jpg --style ./sample_images/style/mondrian.jpg --patch-size 3 86 | ``` 87 | 88 | ![patch_size_variation](./sample_images/test_results/patch_size_variation.jpg) 89 | 90 | 91 | #### Generate the stylized image with varying patch stride 92 | 93 | ``` 94 | python main.py --check-point ./trained_models/check_point.pth --imsize 512 --cropsize 512 --content ./sample_images/content/blonde_girl.jpg --style ./sample_images/style/mondrian.jpg --patch-stride 4 95 | ``` 96 | 97 | ![patch_stride_variation](./sample_images/test_results/patch_stride_variation.jpg) 98 | 99 | 100 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from train import network_train 5 | from test import network_test 6 | 7 | def build_parser(): 8 | parser = argparse.ArgumentParser() 9 | 10 | # cpu, gpu mode selection 11 | parser.add_argument('--gpu-no', type=int, 12 | help='cpu : -1, gpu : 0 ~ n ', default=0) 13 | 14 | ### arguments for network training 15 | parser.add_argument('--train', action='store_true', 16 | help='Train flag', default=False) 17 | 18 | parser.add_argument('--max-iter', type=int, 19 | help='Train iterations', default=40000) 20 | 21 | parser.add_argument('--batch-size', type=int, 22 | help='Batch size', default=16) 23 | 24 | parser.add_argument('--lr', type=float, 25 | help='Learning rate to optimize the network', default=1e-3) 26 | 27 | parser.add_argument('--check-iter', type=int, 28 | help='Number of iterations between training log checks', default=100) 29 | 30 | parser.add_argument('--imsize', type=int, 31 | help='Size for resizing images during training', default=512) 32 | 33 | parser.add_argument('--cropsize', type=int, 34 | help='Size for cropping images during training', default=None) 35 | 36 | parser.add_argument('--cencrop', 
action='store_true', 37 | help='Flag for cropping the center region of the image, default: random crop', default=False) 38 | 39 | parser.add_argument('--layers', type=int, nargs='+', 40 | help='Layer indices to extract features', default=[1, 6, 11, 20]) 41 | 42 | parser.add_argument('--feature-weight', type=float, 43 | help='Feature loss weight', default=0.1) 44 | 45 | parser.add_argument('--tv-weight', type=float, 46 | help='Total variation loss weight', default=1.0) 47 | 48 | parser.add_argument('--content-dir', type=str, 49 | help='Content data path to train the network') 50 | 51 | parser.add_argument('--save-path', type=str, 52 | help='Save path', default='./trained_models/') 53 | 54 | parser.add_argument('--check-point', type=str, 55 | help="Trained model load path") 56 | 57 | parser.add_argument('--content', type=str, 58 | help="Test content image path") 59 | 60 | parser.add_argument('--style', type=str, nargs='+', 61 | help="Test style image path") 62 | 63 | parser.add_argument('--mask', type=str, nargs='+', 64 | help="Mask image for masked stylization", default=None) 65 | 66 | parser.add_argument('--style-strength', type=float, 67 | help='Content vs style interpolation value: 1(style), 0(content)', default=1.0) 68 | 69 | parser.add_argument('--interpolation-weights', type=float, nargs='+', 70 | help='Multi-style interpolation weights', default=None) 71 | 72 | parser.add_argument('--patch-size', type=int, 73 | help='Patch size for swapping normalized content and style features', default=3) 74 | 75 | parser.add_argument('--patch-stride', type=int, 76 | help='Patch stride for swapping normalized content and style features', default=1) 77 | 78 | return parser 79 | 80 | if __name__ == '__main__': 81 | parser = build_parser() 82 | args = parser.parse_args() 83 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_no) 84 | 85 | if args.train: 86 | network_train(args) 87 | else: 88 | network_test(args) 89 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision 6 | import torchvision.models as models 7 | 8 | 9 | from style_decorator import StyleDecorator 10 | 11 | class AvatarNet(nn.Module): 12 | def __init__(self, layers=[1, 6, 11, 20]): 13 | super(AvatarNet, self).__init__() 14 | self.encoder = Encoder(layers) 15 | self.decoder = Decoder(layers) 16 | 17 | self.adain = AdaIN() 18 | self.decorator = StyleDecorator() 19 | 20 | def forward(self, content, styles, style_strength=1.0, patch_size=3, patch_stride=1, masks=None, interpolation_weights=None, train=False): 21 | if interpolation_weights is None: 22 | interpolation_weights = [1/len(styles)] * len(styles) 23 | if masks is None: 24 | masks = [1] * len(styles) 25 | 26 | # encode the content and style images 27 | content_feature = self.encoder(content) 28 | style_features = [] 29 | for style in styles: 30 | style_features.append(self.encoder(style)) 31 | 32 | if not train: 33 | transformed_feature = [] 34 | for style_feature, interpolation_weight, mask in zip(style_features, interpolation_weights, masks): 35 | if isinstance(mask, torch.Tensor): 36 | b, c, h, w = content_feature[-1].size() 37 | mask = F.interpolate(mask, size=(h, w)) 38 | transformed_feature.append(self.decorator(content_feature[-1], style_feature[-1], style_strength, patch_size, patch_stride) * interpolation_weight * mask) 39 | transformed_feature = 
sum(transformed_feature) 40 | 41 | else: 42 | transformed_feature = content_feature[-1] 43 | 44 | # re-ordering style features for transferring feature during decoding 45 | style_features = [style_feature[:-1][::-1] for style_feature in style_features] 46 | 47 | stylized_image = self.decoder(transformed_feature, style_features, masks, interpolation_weights) 48 | 49 | return stylized_image 50 | 51 | class AdaIN(nn.Module): 52 | def __init__(self): 53 | super(AdaIN, self).__init__() 54 | 55 | def forward(self, content, style, style_strength=1.0, eps=1e-5): 56 | b, c, h, w = content.size() 57 | 58 | content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True) 59 | style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True) 60 | 61 | normalized_content = (content.view(b, c, -1) - content_mean)/(content_std+eps) 62 | 63 | stylized_content = (normalized_content * style_std) + style_mean 64 | 65 | output = (1-style_strength)*content + style_strength*stylized_content.view(b, c, h, w) 66 | return output 67 | 68 | class Encoder(nn.Module): 69 | def __init__(self, layers=[1, 6, 11, 20]): 70 | super(Encoder, self).__init__() 71 | vgg = torchvision.models.vgg19(pretrained=True).features 72 | 73 | self.encoder = nn.ModuleList() 74 | temp_seq = nn.Sequential() 75 | for i in range(max(layers)+1): 76 | temp_seq.add_module(str(i), vgg[i]) 77 | if i in layers: 78 | self.encoder.append(temp_seq) 79 | temp_seq = nn.Sequential() 80 | 81 | def forward(self, x): 82 | features = [] 83 | for layer in self.encoder: 84 | x = layer(x) 85 | features.append(x) 86 | return features 87 | 88 | class Decoder(nn.Module): 89 | def __init__(self, layers=[1, 6, 11, 20], transformers=[AdaIN(), AdaIN(), AdaIN(), None]): 90 | super(Decoder, self).__init__() 91 | vgg = torchvision.models.vgg19(pretrained=False).features 92 | self.transformers = transformers 93 | 94 | self.decoder = nn.ModuleList() 95 | temp_seq = nn.Sequential() 96 | count = 0 97 | for i in range(max(layers)-1, -1, -1): 98 | if isinstance(vgg[i], nn.Conv2d): 99 | # get number of in/out channels 100 | out_channels = vgg[i].in_channels 101 | in_channels = vgg[i].out_channels 102 | kernel_size = vgg[i].kernel_size 103 | 104 | # make a [reflection pad + convolution + relu] layer 105 | temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1))) 106 | count += 1 107 | temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size)) 108 | count += 1 109 | temp_seq.add_module(str(count), nn.ReLU()) 110 | count += 1 111 | 112 | # change down-sampling(MaxPooling) --> upsampling 113 | elif isinstance(vgg[i], nn.MaxPool2d): 114 | temp_seq.add_module(str(count), nn.Upsample(scale_factor=2)) 115 | count += 1 116 | 117 | if i in layers: 118 | self.decoder.append(temp_seq) 119 | temp_seq = nn.Sequential() 120 | 121 | # append last conv layers without ReLU activation 122 | self.decoder.append(temp_seq[:-1]) 123 | 124 | def forward(self, x, styles, masks=None, interpolation_weights=None): 125 | if interpolation_weights is None: 126 | interpolation_weights = [1/len(styles)] * len(styles) 127 | if masks is None: 128 | masks = [1] * len(styles) 129 | 130 | y = x 131 | for i, layer in enumerate(self.decoder): 132 | y = layer(y) 133 | 134 | if self.transformers[i]: 135 | transformed_feature = [] 136 | for style, interpolation_weight, mask in zip(styles, interpolation_weights, masks): 137 | if isinstance(mask, torch.Tensor): 138 | b, c, h, w = y.size() 139 | mask = F.interpolate(mask, size=(h, w)) 
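# apply this stage's AdaIN transformer to each style feature, then weight by interpolation weight and (resized) mask before summing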
140 | transformed_feature.append(self.transformers[i](y, style[i]) * interpolation_weight * mask) 141 | y = sum(transformed_feature) 142 | 143 | return y 144 | -------------------------------------------------------------------------------- /sample_images/content/avril.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/avril.jpg -------------------------------------------------------------------------------- /sample_images/content/bair.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/bair.jpg -------------------------------------------------------------------------------- /sample_images/content/blonde_girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/blonde_girl.jpg -------------------------------------------------------------------------------- /sample_images/content/brad_pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/brad_pitt.jpg -------------------------------------------------------------------------------- /sample_images/content/chicago.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/chicago.jpg -------------------------------------------------------------------------------- /sample_images/content/cornell.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/cornell.jpg -------------------------------------------------------------------------------- /sample_images/content/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/flowers.jpg -------------------------------------------------------------------------------- /sample_images/content/golden_gate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/golden_gate.jpg -------------------------------------------------------------------------------- /sample_images/content/green_eye.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/green_eye.jpg -------------------------------------------------------------------------------- /sample_images/content/home_alone.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/home_alone.jpg 
-------------------------------------------------------------------------------- /sample_images/content/karya.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/karya.jpg -------------------------------------------------------------------------------- /sample_images/content/lenna.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/lenna.jpg -------------------------------------------------------------------------------- /sample_images/content/modern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/modern.jpg -------------------------------------------------------------------------------- /sample_images/content/newyork.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/newyork.jpg -------------------------------------------------------------------------------- /sample_images/content/sailboat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/sailboat.jpg -------------------------------------------------------------------------------- /sample_images/content/stata.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/stata.jpg -------------------------------------------------------------------------------- /sample_images/content/taj_mahal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/taj_mahal.jpg -------------------------------------------------------------------------------- /sample_images/content/tubingen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/tubingen.jpg -------------------------------------------------------------------------------- /sample_images/content/venice-boat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/content/venice-boat.jpg -------------------------------------------------------------------------------- /sample_images/mask/blonde_girl_mask1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/mask/blonde_girl_mask1.jpg -------------------------------------------------------------------------------- /sample_images/mask/blonde_girl_mask2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/mask/blonde_girl_mask2.jpg -------------------------------------------------------------------------------- /sample_images/style/abstraction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/style/abstraction.jpg -------------------------------------------------------------------------------- /sample_images/style/brick.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/style/brick.jpg -------------------------------------------------------------------------------- /sample_images/style/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/style/candy.jpg -------------------------------------------------------------------------------- /sample_images/style/mondrian.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/style/mondrian.jpg -------------------------------------------------------------------------------- /sample_images/style/starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/style/starry_night.jpg -------------------------------------------------------------------------------- /sample_images/style/yellow_sunset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/style/yellow_sunset.jpg -------------------------------------------------------------------------------- /sample_images/test_results/content_style_interpolation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/content_style_interpolation.jpg -------------------------------------------------------------------------------- /sample_images/test_results/masked_stylized_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/masked_stylized_image.jpg -------------------------------------------------------------------------------- /sample_images/test_results/multiple_style_interpolation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/multiple_style_interpolation.jpg -------------------------------------------------------------------------------- /sample_images/test_results/patch_size_variation.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/patch_size_variation.jpg -------------------------------------------------------------------------------- /sample_images/test_results/patch_stride_variation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/patch_stride_variation.jpg -------------------------------------------------------------------------------- /sample_images/test_results/stylization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/stylization.jpg -------------------------------------------------------------------------------- /sample_images/test_results/training_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Avatar-Net_Pytorch/0fff5054e8107946175806e6cf413383247bf933/sample_images/test_results/training_loss.png -------------------------------------------------------------------------------- /style_decorator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from wct import whitening, coloring 5 | 6 | def extract_patches(feature, patch_size, stride): 7 | ph, pw = patch_size 8 | sh, sw = stride 9 | 10 | # padding the feature 11 | padh = (ph - 1) // 2 12 | padw = (pw - 1) // 2 13 | padding_size = (padw, padw, padh, padh) 14 | feature = F.pad(feature, padding_size, 'constant', 0) 15 | 16 | # extract patches 17 | patches = feature.unfold(2, ph, sh).unfold(3, pw, sw) 18 | patches = patches.contiguous().view(*patches.size()[:-2], -1) 19 | 20 | return patches 21 | 22 | class StyleDecorator(torch.nn.Module): 23 | 24 | def __init__(self): 25 | super(StyleDecorator, self).__init__() 26 | 27 | def kernel_normalize(self, kernel, k=3): 28 | b, ch, h, w, kk = kernel.size() 29 | 30 | # calc kernel norm 31 | kernel = kernel.view(b, ch, h*w, kk).transpose(2, 1) 32 | kernel_norm = torch.norm(kernel.contiguous().view(b, h*w, ch*kk), p=2, dim=2, keepdim=True) 33 | 34 | # kernel reshape 35 | kernel = kernel.view(b, h*w, ch, k, k) 36 | kernel_norm = kernel_norm.view(b, h*w, 1, 1, 1) 37 | 38 | return kernel, kernel_norm 39 | 40 | def conv2d_with_style_kernels(self, features, kernels, patch_size, deconv_flag=False): 41 | output = list() 42 | b, c, h, w = features.size() 43 | 44 | # padding 45 | pad = (patch_size - 1) // 2 46 | padding_size = (pad, pad, pad, pad) 47 | 48 | # batch-wise convolutions with style kernels 49 | for feature, kernel in zip(features, kernels): 50 | feature = F.pad(feature.unsqueeze(0), padding_size, 'constant', 0) 51 | 52 | if deconv_flag: 53 | padding_size = patch_size - 1 54 | output.append(F.conv_transpose2d(feature, kernel, padding=padding_size)) 55 | else: 56 | output.append(F.conv2d(feature, kernel)) 57 | 58 | return torch.cat(output, dim=0) 59 | 60 | def binarize_patch_score(self, features): 61 | outputs= list() 62 | 63 | # batch-wise operation 64 | for feature in features: 65 | matching_indices = torch.argmax(feature, dim=0) 66 | one_hot_mask = torch.zeros_like(feature) 67 | 68 | h, w = 
matching_indices.size() 69 | for i in range(h): 70 | for j in range(w): 71 | ind = matching_indices[i, j] 72 | one_hot_mask[ind, i, j] = 1 73 | outputs.append(one_hot_mask.unsqueeze(0)) 74 | 75 | return torch.cat(outputs, dim=0) 76 | 77 | def norm_deconvolution(self, h, w, patch_size): 78 | mask = torch.ones((h, w)) 79 | fullmask = torch.zeros((h + patch_size - 1, w + patch_size - 1)) 80 | 81 | for i in range(patch_size): 82 | for j in range(patch_size): 83 | pad = (i, patch_size - i - 1, j, patch_size - j - 1) 84 | padded_mask = F.pad(mask, pad, 'constant', 0) 85 | fullmask += padded_mask 86 | 87 | pad_width = (patch_size - 1) // 2 88 | if pad_width == 0: 89 | deconv_norm = fullmask 90 | else: 91 | deconv_norm = fullmask[pad_width:-pad_width, pad_width:-pad_width] 92 | 93 | return deconv_norm.view(1, 1, h, w) 94 | 95 | def reassemble_feature(self, normalized_content_feature, normalized_style_feature, patch_size, patch_stride): 96 | # get patches of style feature 97 | style_kernel = extract_patches(normalized_style_feature, [patch_size, patch_size], [patch_stride, patch_stride]) 98 | 99 | # kernel normalize 100 | style_kernel, kernel_norm = self.kernel_normalize(style_kernel, patch_size) 101 | 102 | # convolution with style kernel(patch wise convolution) 103 | patch_score = self.conv2d_with_style_kernels(normalized_content_feature, style_kernel/kernel_norm, patch_size) 104 | 105 | # binarization 106 | binarized = self.binarize_patch_score(patch_score) 107 | 108 | # deconv norm 109 | deconv_norm = self.norm_deconvolution(h=binarized.size(2), w=binarized.size(3), patch_size=patch_size) 110 | 111 | # deconvolution 112 | output = self.conv2d_with_style_kernels(binarized, style_kernel, patch_size, deconv_flag=True) 113 | 114 | return output/deconv_norm.type_as(output) 115 | 116 | def forward(self, content_feature, style_feature, style_strength=1.0, patch_size=3, patch_stride=1): 117 | # 1-1. content feature projection 118 | normalized_content_feature = whitening(content_feature) 119 | 120 | # 1-2. style feature projection 121 | normalized_style_feature = whitening(style_feature) 122 | 123 | # 2. swap content and style features 124 | reassembled_feature = self.reassemble_feature(normalized_content_feature, normalized_style_feature, patch_size=patch_size, patch_stride=patch_stride) 125 | 126 | # 3. reconstruction feature with style mean and covariance matrix 127 | stylized_feature = coloring(reassembled_feature, style_feature) 128 | 129 | # 4. 
content and style interpolation 130 | result_feature = (1 - style_strength) * content_feature + style_strength * stylized_feature 131 | 132 | return result_feature 133 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from network import AvatarNet 4 | from utils import imload, imsave, maskload 5 | 6 | 7 | def network_test(args): 8 | # set device 9 | device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu') 10 | 11 | # load check point 12 | check_point = torch.load(args.check_point) 13 | 14 | # load network 15 | network = AvatarNet(args.layers) 16 | network.load_state_dict(check_point['state_dict']) 17 | network = network.to(device) 18 | 19 | # load target images 20 | content_img = imload(args.content, args.imsize, args.cropsize).to(device) 21 | style_imgs = [imload(style, args.imsize, args.cropsize, args.cencrop).to(device) for style in args.style] 22 | masks = None 23 | if args.mask: 24 | masks = [maskload(mask).to(device) for mask in args.mask] 25 | 26 | # stylize image 27 | with torch.no_grad(): 28 | stylized_img = network(content_img, style_imgs, args.style_strength, args.patch_size, args.patch_stride, 29 | masks, args.interpolation_weights, False) 30 | 31 | imsave(stylized_img, 'stylized_image.jpg') 32 | 33 | return None 34 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | from network import AvatarNet, Encoder 5 | from utils import ImageFolder, imsave, lastest_arverage_value 6 | 7 | def network_train(args): 8 | # set device 9 | device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu') 10 | 11 | # get network 12 | network = AvatarNet(args.layers).to(device) 13 | 14 | # get data set 15 | data_set = ImageFolder(args.content_dir, args.imsize, args.cropsize, args.cencrop) 16 | 17 | # get loss calculator 18 | loss_network = Encoder(args.layers).to(device) 19 | mse_loss = torch.nn.MSELoss(reduction='mean').to(device) 20 | loss_seq = {'total':[], 'image':[], 'feature':[], 'tv':[]} 21 | 22 | # get optimizer 23 | for param in network.encoder.parameters(): 24 | param.requires_grad = False 25 | optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr) 26 | 27 | # training 28 | for iteration in range(args.max_iter): 29 | data_loader = torch.utils.data.DataLoader(data_set, batch_size=args.batch_size, shuffle=True) 30 | input_image = next(iter(data_loader)).to(device) 31 | 32 | output_image = network(input_image, [input_image], train=True) 33 | 34 | # calculate losses 35 | total_loss = 0 36 | ## image reconstruction loss 37 | image_loss = mse_loss(output_image, input_image) 38 | loss_seq['image'].append(image_loss.item()) 39 | total_loss += image_loss 40 | 41 | ## feature reconstruction loss 42 | input_features = loss_network(input_image) 43 | output_features = loss_network(output_image) 44 | feature_loss = 0 45 | for output_feature, input_feature in zip(output_features, input_features): 46 | feature_loss += mse_loss(output_feature, input_feature) 47 | loss_seq['feature'].append(feature_loss.item()) 48 | total_loss += feature_loss * args.feature_weight 49 | 50 | ## total variation loss 51 | tv_loss = calc_tv_loss(output_image) 52 | loss_seq['tv'].append(tv_loss.item()) 53 | total_loss += tv_loss * args.tv_weight 54 | 55 | loss_seq['total'].append(total_loss.item()) 56 | 
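# standard update step: clear gradients, backpropagate the combined loss, and step the decoder-only optimizer (the encoder is frozen above)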
57 | optimizer.zero_grad() 58 | total_loss.backward() 59 | optimizer.step() 60 | 61 | # print loss log and save network, loss log and output images 62 | if (iteration + 1) % args.check_iter == 0: 63 | imsave(torch.cat([input_image, output_image], dim=0), args.save_path+"training_image.png") 64 | print("%s: Iteration: [%d/%d]\tImage Loss: %2.4f\tFeature Loss: %2.4f\tTV Loss: %2.4f\tTotal: %2.4f"%(time.ctime(),iteration+1, 65 | args.max_iter, lastest_arverage_value(loss_seq['image']), lastest_arverage_value(loss_seq['feature']), 66 | lastest_arverage_value(loss_seq['tv']), lastest_arverage_value(loss_seq['total']))) 67 | torch.save({'iteration': iteration+1, 68 | 'state_dict': network.state_dict(), 69 | 'loss_seq': loss_seq}, 70 | args.save_path+'check_point.pth') 71 | 72 | return network 73 | 74 | def calc_tv_loss(x): 75 | tv_loss = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) 76 | tv_loss += torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) 77 | return tv_loss 78 | 79 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | 7 | from PIL import Image 8 | 9 | def lastest_arverage_value(values, length=100): 10 | if len(values) < length: 11 | length = len(values) 12 | return sum(values[-length:])/length 13 | 14 | class ImageFolder(torch.utils.data.Dataset): 15 | def __init__(self, root_path, imsize=None, cropsize=None, cencrop=False): 16 | super(ImageFolder, self).__init__() 17 | 18 | self.file_names = sorted(os.listdir(root_path)) 19 | self.root_path = root_path 20 | self.transform = _transformer(imsize, cropsize, cencrop) 21 | 22 | def __len__(self): 23 | return len(self.file_names) 24 | 25 | def __getitem__(self, index): 26 | image = Image.open(os.path.join(self.root_path + self.file_names[index])).convert("RGB") 27 | return self.transform(image) 28 | 29 | def _normalizer(denormalize=False): 30 | # set Mean and Std of RGB channels of IMAGENET to use pre-trained VGG net 31 | MEAN = [0.485, 0.456, 0.406] 32 | STD = [0.229, 0.224, 0.225] 33 | 34 | if denormalize: 35 | MEAN = [-mean/std for mean, std in zip(MEAN, STD)] 36 | STD = [1/std for std in STD] 37 | 38 | return transforms.Normalize(mean=MEAN, std=STD) 39 | 40 | def _transformer(imsize=None, cropsize=None, cencrop=False): 41 | normalize = _normalizer() 42 | transformer = [] 43 | if imsize: 44 | transformer.append(transforms.Resize(imsize)) 45 | if cropsize: 46 | if cencrop: 47 | transformer.append(transforms.CenterCrop(cropsize)) 48 | else: 49 | transformer.append(transforms.RandomCrop(cropsize)) 50 | 51 | transformer.append(transforms.ToTensor()) 52 | transformer.append(normalize) 53 | return transforms.Compose(transformer) 54 | 55 | def imsave(tensor, path): 56 | denormalize = _normalizer(denormalize=True) 57 | if tensor.is_cuda: 58 | tensor = tensor.cpu() 59 | tensor = torchvision.utils.make_grid(tensor) 60 | torchvision.utils.save_image(denormalize(tensor).clamp_(0.0, 1.0), path) 61 | return None 62 | 63 | def imload(path, imsize=None, cropsize=None, cencrop=False): 64 | transformer = _transformer(imsize, cropsize, cencrop) 65 | return transformer(Image.open(path).convert("RGB")).unsqueeze(0) 66 | 67 | def imshow(tensor): 68 | denormalize = _normalizer(denormalize=True) 69 | if tensor.is_cuda: 70 | tensor = tensor.cpu() 71 | tensor = torchvision.utils.make_grid(denormalize(tensor.squeeze(0))) 
72 | image = transforms.functional.to_pil_image(tensor.clamp_(0.0, 1.0)) 73 | return image 74 | 75 | def maskload(path): 76 | mask = Image.open(path).convert('L') 77 | return transforms.functional.to_tensor(mask).unsqueeze(0) 78 | -------------------------------------------------------------------------------- /wct.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def covsqrt_mean(feature, inverse=False, tolerance=1e-14): 4 | # I referenced the default svd tolerance value in matlab. 5 | 6 | b, c, h, w = feature.size() 7 | 8 | mean = torch.mean(feature.view(b, c, -1), dim=2, keepdim=True) 9 | zeromean = feature.view(b, c, -1) - mean 10 | cov = torch.bmm(zeromean, zeromean.transpose(1, 2)) 11 | 12 | evals, evects = torch.symeig(cov, eigenvectors=True) 13 | 14 | p = 0.5 15 | if inverse: 16 | p *= -1 17 | 18 | covsqrt = [] 19 | for i in range(b): 20 | k = 0 21 | for j in range(c): 22 | if evals[i][j] > tolerance: 23 | k = j 24 | break 25 | covsqrt.append(torch.mm(evects[i][:, k:], 26 | torch.mm(evals[i][k:].pow(p).diag_embed(), 27 | evects[i][:, k:].t())).unsqueeze(0)) 28 | covsqrt = torch.cat(covsqrt, dim=0) 29 | 30 | return covsqrt, mean 31 | 32 | 33 | def whitening(feature): 34 | b, c, h, w = feature.size() 35 | 36 | inv_covsqrt, mean = covsqrt_mean(feature, inverse=True) 37 | 38 | normalized_feature = torch.matmul(inv_covsqrt, feature.view(b, c, -1)-mean) 39 | 40 | return normalized_feature.view(b, c, h, w) 41 | 42 | 43 | def coloring(feature, target): 44 | b, c, h, w = feature.size() 45 | 46 | covsqrt, mean = covsqrt_mean(target) 47 | 48 | colored_feature = torch.matmul(covsqrt, feature.view(b, c, -1)) + mean 49 | 50 | return colored_feature.view(b, c, h, w) 51 | --------------------------------------------------------------------------------
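A minimal sketch (not part of the repository) of what `whitening` and `coloring` in `wct.py` do: whitening removes the channel means and maps the (unnormalized) channel covariance to approximately the identity, and coloring re-imposes the target's mean and covariance. It assumes a local clone with `wct.py` on the import path and a PyTorch version that still provides `torch.symeig` (e.g. the torch 1.2.0 listed in the README).

```python
import torch
from wct import whitening, coloring  # assumes wct.py is importable

torch.manual_seed(0)
b, c, h, w = 1, 8, 16, 16
content = torch.randn(b, c, h, w)
style = torch.randn(b, c, h, w) * 2.0 + 1.0

# after whitening, the channel scatter matrix is close to the identity
white = whitening(content)
flat = white.view(b, c, -1)
print((torch.bmm(flat, flat.transpose(1, 2))[0] - torch.eye(c)).abs().max())

# after coloring, the feature picks up the style's channel statistics
colored = coloring(white, style)
print(colored.view(b, c, -1).mean(dim=2))  # close to the style channel means
print(style.view(b, c, -1).mean(dim=2))
```

The style decorator builds on exactly this pair: it whitens both features, swaps patches between them, and then colors the result with the style statistics.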