├── README.md ├── VGGdecoder.py ├── content ├── brad_pitt.jpg ├── golden_gate.jpg ├── in1.jpg ├── lenna.jpg ├── neko.jpg └── sailboat.jpg ├── feature_transformer.py ├── model.py ├── model_state ├── decoder_relu1_1.pth ├── decoder_relu2_1.pth ├── decoder_relu3_1.pth └── decoder_relu4_1.pth ├── normalisedVGG.py ├── res ├── IMG_0565_03_demo.jpg ├── IMG_0565_04_demo.jpg ├── IMG_0565_05_demo.jpg ├── IMG_0565_1_demo.jpg ├── IMG_0565_bridge_demo.jpg ├── IMG_0565_feathers_demo.jpg ├── IMG_0565_horse_demo.jpg ├── IMG_0565_hosi_demo.jpg ├── IMG_0565_hs6_demo.jpg ├── IMG_0565_picasso_seated_nude_hr_demo.jpg ├── IMG_0565_udnie_demo.jpg ├── IMG_0565_wave_demo.jpg ├── neko_hosi.jpg ├── neko_hosi_pair.jpg ├── neko_hosi_style_transfer_demo.jpg ├── neko_hosi_with_style_image.jpg ├── res1.gif ├── res3.gif ├── res4.gif └── res5.gif ├── style ├── antimonocromatismo.jpg ├── asheville.jpg ├── brushstrokes.jpg ├── contrast_of_forms.jpg ├── en_campo_gris.jpg ├── hosi.jpg ├── in1.jpg ├── in2.jpg ├── la_muse.jpg ├── mondrian.jpg ├── picasso_seated_nude_hr.jpg ├── picasso_self_portrait.jpg ├── scene_de_rue.jpg ├── sketch.png ├── trial.jpg ├── woman_in_peasant_dress_cropped.jpg └── woman_with_hat_matisse.jpg └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch_WCT 2 | 3 | Unofficial PyTorch (1.0+) implementation of the NIPS paper [Universal Style Transfer via Feature Transforms](https://arxiv.org/pdf/1705.08086.pdf). 4 | 5 | The original Torch implementation by the authors can be found [here](https://github.com/Yijunmaverick/UniversalStyleTransfer). 6 | 7 | Other implementations such as [Pytorch_implementation1](https://github.com/black-puppydog/PytorchWCT), [Pytorch_implementation2](https://github.com/sunshineatnoon/PytorchWCT), or [Pytorch_implementation3](https://github.com/pietrocarbo/deep-transfer) are also available. 8 | 9 | This repository provides a pre-trained model for you to generate your own image given a content image and a style image. 10 | 11 | If you have any questions, please feel free to contact me. (English, Japanese, and Chinese are all OK!) 12 | 13 | ## Notice 14 | I propose Structure-emphasized Multimodal Style Transfer (SEMST); feel free to use it [here](https://github.com/irasin/Structure-emphasized-Multimodal-Style-Transfer). 15 | 16 | ------ 17 | 18 | ## Requirements 19 | 20 | - Python 3.7 21 | - PyTorch 1.0+ 22 | - TorchVision 23 | - Pillow 24 | 25 | An Anaconda environment is recommended! 26 | 27 | (optional) 28 | 29 | - GPU environment 30 | 31 | 32 | 33 | ## Usage 34 | 35 | ------ 36 | 37 | ## Test 38 | 39 | 1. Clone this repository 40 | 41 | ```bash 42 | git clone https://github.com/irasin/Pytorch_WCT 43 | cd Pytorch_WCT 44 | ``` 45 | 46 | 2. Prepare your content image and style image. Some samples are provided in the `content` and `style` directories, so you can try them easily. 47 | 48 | 3. Download the pretrained models [here](https://drive.google.com/open?id=1tsaGnC7YbruBQNCp6qMmmaSTJiGuyoPA) and put them under the directory named `model_state`. 49 | 50 | 4. Generate the output image. A transferred output image, a content-output pair image, and an NST-demo-like image will be generated. 
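For example, using the sample images bundled in this repository (the alpha and GPU values here are arbitrary choices; the generic command and the full option list follow below):

```bash
python test.py -c content/neko.jpg -s style/hosi.jpg -a 0.8 -g 0
```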
51 | 52 | ```python 53 | python test.py -c content_image_path -s style_image_path 54 | ``` 55 | 56 | ``` 57 | usage: test.py [-h] 58 | [--content CONTENT] 59 | [--style STYLE] 60 | [--output_name OUTPUT_NAME] 61 | [--alpha ALPHA] 62 | [--gpu GPU] 63 | [--model_state_path MODEL_STATE_PATH] 64 | 65 | 66 | ``` 67 | 68 | If `output_name` is not given, the combination of the content image name and the style image name will be used. 69 | 70 | ------ 71 | 72 | # Result 73 | 74 | Some results using a photo of my cat (named Sora) as the content image are shown here. 75 | 76 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_03_demo.jpg) 77 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_04_demo.jpg) 78 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_05_demo.jpg) 79 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_1_demo.jpg) 80 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_bridge_demo.jpg) 81 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_feathers_demo.jpg) 82 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_horse_demo.jpg) 83 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_hosi_demo.jpg) 84 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_hs6_demo.jpg) 85 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_picasso_seated_nude_hr_demo.jpg) 86 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_udnie_demo.jpg) 87 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/IMG_0565_wave_demo.jpg) 88 | 89 | 90 | ![image](https://github.com/irasin/Pytorch_WCT/blob/master/res/neko_hosi.jpg) 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /VGGdecoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Interpolate(nn.Module): 8 | def __init__(self, scale_factor=2): 9 | super().__init__() 10 | self.scale_factor = scale_factor 11 | 12 | def forward(self, x): 13 | x = F.interpolate(x, scale_factor=self.scale_factor) 14 | return x 15 | 16 | 17 | vgg_decoder_relu5_1 = nn.Sequential( 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(512, 512, 3), 20 | nn.ReLU(), 21 | Interpolate(2), 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(512, 512, 3), 24 | nn.ReLU(), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(512, 512, 3), 27 | nn.ReLU(), 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(512, 512, 3), 30 | nn.ReLU(), 31 | nn.ReflectionPad2d((1, 1, 1, 1)), 32 | nn.Conv2d(512, 256, 3), 33 | nn.ReLU(), 34 | Interpolate(2), 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(256, 256, 3), 37 | nn.ReLU(), 38 | nn.ReflectionPad2d((1, 1, 1, 1)), 39 | nn.Conv2d(256, 256, 3), 40 | nn.ReLU(), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(256, 256, 3), 43 | nn.ReLU(), 44 | nn.ReflectionPad2d((1, 1, 1, 1)), 45 | nn.Conv2d(256, 128, 3), 46 | nn.ReLU(), 47 | Interpolate(2), 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(128, 128, 3), 50 | nn.ReLU(), 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(128, 64, 3), 53 | nn.ReLU(), 54 | Interpolate(2), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(64, 64, 3), 57 | nn.ReLU(), 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(64, 3, 3) 60 | ) 61 | 62 | 63 | class 
Decoder(nn.Module): 64 | def __init__(self, level, pretrained_path=None): 65 | super().__init__() 66 | if level == 1: 67 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-2:])) 68 | elif level == 2: 69 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-9:])) 70 | elif level == 3: 71 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-16:])) 72 | elif level == 4: 73 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-29:])) 74 | elif level == 5: 75 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children()))) 76 | else: 77 | raise ValueError('level should be between 1~5') 78 | 79 | if pretrained_path is not None: 80 | self.net.load_state_dict(torch.load(pretrained_path, map_location=lambda storage, loc: storage)) 81 | 82 | def forward(self, x): 83 | return self.net(x) 84 | -------------------------------------------------------------------------------- /content/brad_pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/content/brad_pitt.jpg -------------------------------------------------------------------------------- /content/golden_gate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/content/golden_gate.jpg -------------------------------------------------------------------------------- /content/in1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/content/in1.jpg -------------------------------------------------------------------------------- /content/lenna.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/content/lenna.jpg -------------------------------------------------------------------------------- /content/neko.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/content/neko.jpg -------------------------------------------------------------------------------- /content/sailboat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/content/sailboat.jpg -------------------------------------------------------------------------------- /feature_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def whiten_and_color(content_feature, style_feature, alpha=1): 4 | """ 5 | A WCT function can be used directly between encoder and decoder 6 | """ 7 | cf = content_feature.squeeze(0)#.double() 8 | c, ch, cw = cf.shape 9 | cf = cf.reshape(c, -1) 10 | c_mean = torch.mean(cf, 1, keepdim=True) 11 | cf = cf - c_mean 12 | c_cov = torch.mm(cf, cf.t()).div(ch*cw - 1) 13 | c_u, c_e, c_v = torch.svd(c_cov) 14 | 15 | # if necessary, use k-th largest eig-value 16 | k_c = c 17 | for i in range(c): 18 | if c_e[i] < 0.00001: 19 | k_c = i 20 | break 21 | c_d = c_e[:k_c].pow(-0.5) 22 | 23 | w_step1 = torch.mm(c_v[:, :k_c], torch.diag(c_d)) 24 | 
w_step2 = torch.mm(w_step1, (c_v[:, :k_c].t())) 25 | whitened = torch.mm(w_step2, cf) 26 | 27 | sf = style_feature.squeeze(0)#.double() 28 | c, sh, sw = sf.shape 29 | sf = sf.reshape(c, -1) 30 | s_mean = torch.mean(sf, 1, keepdim=True) 31 | sf = sf - s_mean 32 | s_cov = torch.mm(sf, sf.t()).div(sh*sw -1) 33 | s_u, s_e, s_v = torch.svd(s_cov) 34 | 35 | # if necessary, use k-th largest eig-value 36 | k_s = c 37 | for i in range(c): 38 | if s_e[i] < 0.00001: 39 | k_s = i 40 | break 41 | s_d = s_e[:k_s].pow(0.5) 42 | c_step1 = torch.mm(s_v[:, :k_s], torch.diag(s_d)) 43 | c_step2 = torch.mm(c_step1, s_v[:, :k_s].t()) 44 | colored = torch.mm(c_step2, whitened) + s_mean 45 | 46 | colored_feature = colored.reshape(c, ch, cw).unsqueeze(0).float() 47 | 48 | colored_feature = alpha * colored_feature + (1.0 - alpha) * content_feature 49 | return colored_feature 50 | 51 | 52 | # a = torch.randn(1, 64, 128, 128) 53 | # b = torch.randn(1, 64, 124, 122) 54 | # 55 | # c = whiten_and_color(a, b) 56 | # c.shape 57 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from normalisedVGG import NormalisedVGG 5 | from VGGdecoder import Decoder 6 | from feature_transformer import whiten_and_color 7 | 8 | 9 | class SingleLevelAE(nn.Module): 10 | def __init__(self, level, pretrained_path_dir='model_state'): 11 | super().__init__() 12 | self.level = level 13 | self.encoder = NormalisedVGG(f'{pretrained_path_dir}/vgg_normalised_conv5_1.pth') 14 | self.decoder = Decoder(level, f'{pretrained_path_dir}/decoder_relu{level}_1.pth') 15 | 16 | def forward(self, content_image, style_image, alpha): 17 | content_feature = self.encoder(content_image, f'relu{self.level}_1') 18 | style_feature = self.encoder(style_image, f'relu{self.level}_1') 19 | res = whiten_and_color(content_feature, style_feature, alpha) 20 | res = self.decoder(res) 21 | return res 22 | 23 | 24 | class MultiLevelAE(nn.Module): 25 | def __init__(self, pretrained_path_dir='model_state'): 26 | super().__init__() 27 | self.encoder = NormalisedVGG(f'{pretrained_path_dir}/vgg_normalised_conv5_1.pth') 28 | self.decoder1 = Decoder(1, f'{pretrained_path_dir}/decoder_relu1_1.pth') 29 | self.decoder2 = Decoder(2, f'{pretrained_path_dir}/decoder_relu2_1.pth') 30 | self.decoder3 = Decoder(3, f'{pretrained_path_dir}/decoder_relu3_1.pth') 31 | self.decoder4 = Decoder(4, f'{pretrained_path_dir}/decoder_relu4_1.pth') 32 | self.decoder5 = Decoder(5, f'{pretrained_path_dir}/decoder_relu5_1.pth') 33 | 34 | def transform_level(self, content_image, style_image, alpha, level): 35 | content_feature = self.encoder(content_image, f'relu{level}_1') 36 | style_feature = self.encoder(style_image, f'relu{level}_1') 37 | res = whiten_and_color(content_feature, style_feature, alpha) 38 | return getattr(self, f'decoder{level}')(res) 39 | 40 | def forward(self, content_image, style_image, alpha=1): 41 | r5 = self.transform_level(content_image, style_image, alpha, 5) 42 | r4 = self.transform_level(r5, style_image, alpha, 4) 43 | r3 = self.transform_level(r4, style_image, alpha, 3) 44 | r2 = self.transform_level(r3, style_image, alpha, 2) 45 | r1 = self.transform_level(r2, style_image, alpha, 1) 46 | 47 | return r1 48 | 49 | -------------------------------------------------------------------------------- /model_state/decoder_relu1_1.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/model_state/decoder_relu1_1.pth -------------------------------------------------------------------------------- /model_state/decoder_relu2_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/model_state/decoder_relu2_1.pth -------------------------------------------------------------------------------- /model_state/decoder_relu3_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/model_state/decoder_relu3_1.pth -------------------------------------------------------------------------------- /model_state/decoder_relu4_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/model_state/decoder_relu4_1.pth -------------------------------------------------------------------------------- /normalisedVGG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | normalised_vgg_relu5_1 = nn.Sequential( 5 | nn.Conv2d(3, 3, 1), 6 | nn.ReflectionPad2d((1, 1, 1, 1)), 7 | nn.Conv2d(3, 64, 3), 8 | nn.ReLU(), 9 | nn.ReflectionPad2d((1, 1, 1, 1)), 10 | nn.Conv2d(64, 64, 3), 11 | nn.ReLU(), 12 | nn.MaxPool2d(2, ceil_mode=True), 13 | nn.ReflectionPad2d((1, 1, 1, 1)), 14 | nn.Conv2d(64, 128, 3), 15 | nn.ReLU(), 16 | nn.ReflectionPad2d((1, 1, 1, 1)), 17 | nn.Conv2d(128, 128, 3), 18 | nn.ReLU(), 19 | nn.MaxPool2d(2, ceil_mode=True), 20 | nn.ReflectionPad2d((1, 1, 1, 1)), 21 | nn.Conv2d(128, 256, 3), 22 | nn.ReLU(), 23 | nn.ReflectionPad2d((1, 1, 1, 1)), 24 | nn.Conv2d(256, 256, 3), 25 | nn.ReLU(), 26 | nn.ReflectionPad2d((1, 1, 1, 1)), 27 | nn.Conv2d(256, 256, 3), 28 | nn.ReLU(), 29 | nn.ReflectionPad2d((1, 1, 1, 1)), 30 | nn.Conv2d(256, 256, 3), 31 | nn.ReLU(), 32 | nn.MaxPool2d(2, ceil_mode=True), 33 | nn.ReflectionPad2d((1, 1, 1, 1)), 34 | nn.Conv2d(256, 512, 3), 35 | nn.ReLU(), 36 | nn.ReflectionPad2d((1, 1, 1, 1)), 37 | nn.Conv2d(512, 512, 3), 38 | nn.ReLU(), 39 | nn.ReflectionPad2d((1, 1, 1, 1)), 40 | nn.Conv2d(512, 512, 3), 41 | nn.ReLU(), 42 | nn.ReflectionPad2d((1, 1, 1, 1)), 43 | nn.Conv2d(512, 512, 3), 44 | nn.ReLU(), 45 | nn.MaxPool2d(2, ceil_mode=True), 46 | nn.ReflectionPad2d((1, 1, 1, 1)), 47 | nn.Conv2d(512, 512, 3), 48 | nn.ReLU() 49 | ) 50 | 51 | 52 | class NormalisedVGG(nn.Module): 53 | """ 54 | VGG reluX_1(X = 1, 2, 3, 4, 5) can be obtained by slicing the follow vgg5_1 model. 
55 | 56 | Sequential( 57 | (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1)) 58 | (1): ReflectionPad2d((1, 1, 1, 1)) 59 | (2): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1)) 60 | (3): ReLU() # relu1_1 61 | (4): ReflectionPad2d((1, 1, 1, 1)) 62 | (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)) 63 | (6): ReLU() 64 | (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 65 | (8): ReflectionPad2d((1, 1, 1, 1)) 66 | (9): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1)) 67 | (10): ReLU() # relu2_1 68 | (11): ReflectionPad2d((1, 1, 1, 1)) 69 | (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1)) 70 | (13): ReLU() 71 | (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 72 | (15): ReflectionPad2d((1, 1, 1, 1)) 73 | (16): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1)) 74 | (17): ReLU() # relu3_1 75 | (18): ReflectionPad2d((1, 1, 1, 1)) 76 | (19): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) 77 | (20): ReLU() 78 | (21): ReflectionPad2d((1, 1, 1, 1)) 79 | (22): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) 80 | (23): ReLU() 81 | (24): ReflectionPad2d((1, 1, 1, 1)) 82 | (25): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) 83 | (26): ReLU() 84 | (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 85 | (28): ReflectionPad2d((1, 1, 1, 1)) 86 | (29): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1)) 87 | (30): ReLU()# relu4_1 88 | (31): ReflectionPad2d((1, 1, 1, 1)) 89 | (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 90 | (33): ReLU() 91 | (34): ReflectionPad2d((1, 1, 1, 1)) 92 | (35): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 93 | (36): ReLU() 94 | (37): ReflectionPad2d((1, 1, 1, 1)) 95 | (38): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 96 | (39): ReLU() 97 | (40): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 98 | (41): ReflectionPad2d((1, 1, 1, 1)) 99 | (42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 100 | (43): ReLU() # relu5_1 101 | ) 102 | """ 103 | def __init__(self, pretrained_path='vgg_normalised_conv5_1.pth'): 104 | super().__init__() 105 | self.net = normalised_vgg_relu5_1 106 | if pretrained_path is not None: 107 | self.net.load_state_dict(torch.load(pretrained_path, map_location=lambda storage, loc: storage)) 108 | 109 | def forward(self, x, target): 110 | if target == 'relu1_1': 111 | return self.net[:4](x) 112 | elif target == 'relu2_1': 113 | return self.net[:11](x) 114 | elif target == 'relu3_1': 115 | return self.net[:18](x) 116 | elif target == 'relu4_1': 117 | return self.net[:31](x) 118 | elif target == 'relu5_1': 119 | return self.net(x) 120 | else: 121 | raise ValueError(f'target should be in ["relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1"] but not {target}') 122 | -------------------------------------------------------------------------------- /res/IMG_0565_03_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_03_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_04_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_04_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_05_demo.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_05_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_1_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_1_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_bridge_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_bridge_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_feathers_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_feathers_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_horse_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_horse_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_hosi_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_hosi_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_hs6_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_hs6_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_picasso_seated_nude_hr_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_picasso_seated_nude_hr_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_udnie_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_udnie_demo.jpg -------------------------------------------------------------------------------- /res/IMG_0565_wave_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/IMG_0565_wave_demo.jpg -------------------------------------------------------------------------------- /res/neko_hosi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/neko_hosi.jpg -------------------------------------------------------------------------------- /res/neko_hosi_pair.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/neko_hosi_pair.jpg -------------------------------------------------------------------------------- /res/neko_hosi_style_transfer_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/neko_hosi_style_transfer_demo.jpg -------------------------------------------------------------------------------- /res/neko_hosi_with_style_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/neko_hosi_with_style_image.jpg -------------------------------------------------------------------------------- /res/res1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/res1.gif -------------------------------------------------------------------------------- /res/res3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/res3.gif -------------------------------------------------------------------------------- /res/res4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/res4.gif -------------------------------------------------------------------------------- /res/res5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/res/res5.gif -------------------------------------------------------------------------------- /style/antimonocromatismo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/antimonocromatismo.jpg -------------------------------------------------------------------------------- /style/asheville.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/asheville.jpg -------------------------------------------------------------------------------- /style/brushstrokes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/brushstrokes.jpg -------------------------------------------------------------------------------- /style/contrast_of_forms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/contrast_of_forms.jpg -------------------------------------------------------------------------------- /style/en_campo_gris.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/en_campo_gris.jpg -------------------------------------------------------------------------------- /style/hosi.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/hosi.jpg -------------------------------------------------------------------------------- /style/in1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/in1.jpg -------------------------------------------------------------------------------- /style/in2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/in2.jpg -------------------------------------------------------------------------------- /style/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/la_muse.jpg -------------------------------------------------------------------------------- /style/mondrian.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/mondrian.jpg -------------------------------------------------------------------------------- /style/picasso_seated_nude_hr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/picasso_seated_nude_hr.jpg -------------------------------------------------------------------------------- /style/picasso_self_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/picasso_self_portrait.jpg -------------------------------------------------------------------------------- /style/scene_de_rue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/scene_de_rue.jpg -------------------------------------------------------------------------------- /style/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/sketch.png -------------------------------------------------------------------------------- /style/trial.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/trial.jpg -------------------------------------------------------------------------------- /style/woman_in_peasant_dress_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/woman_in_peasant_dress_cropped.jpg -------------------------------------------------------------------------------- /style/woman_with_hat_matisse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_WCT/6334ed981bab630c82460a4e0368840911500975/style/woman_with_hat_matisse.jpg 
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from PIL import Image 4 | import torch 5 | from torchvision import transforms 6 | from torchvision.utils import save_image 7 | from model import MultiLevelAE 8 | 9 | 10 | trans = transforms.Compose([transforms.ToTensor()]) 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description='WCT Style Transfer by Pytorch') 15 | parser.add_argument('--content', '-c', type=str, default=None, 16 | help='Content image path e.g. content.jpg') 17 | parser.add_argument('--style', '-s', type=str, default=None, 18 | help='Style image path e.g. image.jpg') 19 | parser.add_argument('--output_name', '-o', type=str, default=None, 20 | help='Output path for generated image, no need to add ext, e.g. out') 21 | parser.add_argument('--alpha', '-a', type=float, default=1.0, 22 | help='alpha controls the fusion degree in WCT') 23 | parser.add_argument('--gpu', '-g', type=int, default=0, 24 | help='GPU ID (negative value indicates CPU)') 25 | parser.add_argument('--model_state_path', type=str, default='model_state', 26 | help='directory containing the pretrained model state files') 27 | 28 | args = parser.parse_args() 29 | 30 | # set device on GPU if available, else CPU 31 | if torch.cuda.is_available() and args.gpu >= 0: 32 | device = torch.device(f'cuda:{args.gpu}') 33 | print(f'# CUDA available: {torch.cuda.get_device_name(args.gpu)}') 34 | else: 35 | device = 'cpu' 36 | 37 | # set model 38 | model = MultiLevelAE(args.model_state_path) 39 | model = model.to(device) 40 | 41 | c = Image.open(args.content).convert('RGB') 42 | s = Image.open(args.style).convert('RGB') 43 | c_tensor = trans(c).unsqueeze(0).to(device) 44 | s_tensor = trans(s).unsqueeze(0).to(device) 45 | with torch.no_grad(): 46 | out = model(c_tensor, s_tensor, args.alpha) 47 | 48 | if args.output_name is None: 49 | c_name = os.path.splitext(os.path.basename(args.content))[0] 50 | s_name = os.path.splitext(os.path.basename(args.style))[0] 51 | args.output_name = f'{c_name}_{s_name}' 52 | 53 | save_image(out, f'{args.output_name}.jpg', nrow=1) 54 | o = Image.open(f'{args.output_name}.jpg') 55 | 56 | demo = Image.new('RGB', (c.width * 2, c.height)) 57 | o = o.resize(c.size) 58 | s = s.resize(tuple(i // 4 for i in c.size))  # PIL resize expects a tuple; style thumbnail is 1/4 of the content size 59 | 60 | demo.paste(c, (0, 0)) 61 | demo.paste(o, (c.width, 0)) 62 | demo.paste(s, (c.width, c.height - s.height)) 63 | demo.save(f'{args.output_name}_style_transfer_demo.jpg', quality=95) 64 | 65 | o.paste(s, (0, o.height - s.height)) 66 | o.save(f'{args.output_name}_with_style_image.jpg', quality=95) 67 | 68 | print(f'result saved into files starting with {args.output_name}') 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | --------------------------------------------------------------------------------
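As a closing note on `feature_transformer.py`: the commented-out lines at its end hint at exercising `whiten_and_color` on random tensors. Below is a minimal sanity-check sketch (not part of the repository; the tensor shapes and tolerances are illustrative assumptions) showing that, with `alpha=1`, the transformed content features take on the style features' per-channel mean and covariance, which is exactly what the whitening-and-coloring transform is meant to achieve.

```python
# Sanity-check sketch for whiten_and_color (run from the repository root).
# With alpha=1 the output should carry the style features' statistics.
import torch
from feature_transformer import whiten_and_color

torch.manual_seed(0)
content_feature = torch.randn(1, 64, 32, 32)  # (N, C, H, W), like an encoder output
style_feature = torch.randn(1, 64, 24, 28)    # spatial size may differ from the content

out = whiten_and_color(content_feature, style_feature, alpha=1)
assert out.shape == content_feature.shape     # WCT keeps the content's spatial layout


def channel_stats(feature):
    """Per-channel mean and channel-by-channel covariance of a (1, C, H, W) tensor."""
    x = feature.squeeze(0).reshape(feature.shape[1], -1)
    mean = x.mean(dim=1, keepdim=True)
    centered = x - mean
    cov = centered @ centered.t() / (centered.shape[1] - 1)
    return mean, cov


out_mean, out_cov = channel_stats(out)
style_mean, style_cov = channel_stats(style_feature)

# Both checks should print True: the colored output matches the style statistics
# up to small numerical error from the float32 SVDs.
print(torch.allclose(out_mean, style_mean, atol=1e-3))
print(torch.allclose(out_cov, style_cov, atol=1e-2))
```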