├── .gitignore ├── LICENSE ├── README.md ├── function.py ├── input ├── content │ ├── avril.jpg │ ├── blonde_girl.jpg │ ├── brad_pitt.jpg │ ├── chicago.jpg │ ├── cornell.jpg │ ├── flowers.jpg │ ├── golden_gate.jpg │ ├── lenna.jpg │ ├── modern.jpg │ ├── newyork.jpg │ └── sailboat.jpg ├── mask │ └── mask.png ├── style │ ├── antimonocromatismo.jpg │ ├── asheville.jpg │ ├── brushstrokes.jpg │ ├── contrast_of_forms.jpg │ ├── en_campo_gris.jpg │ ├── flower_of_life.jpg │ ├── goeritz.jpg │ ├── impronte_d_artista.jpg │ ├── la_muse.jpg │ ├── mondrian.jpg │ ├── mondrian_cropped.jpg │ ├── picasso_seated_nude_hr.jpg │ ├── picasso_self_portrait.jpg │ ├── scene_de_rue.jpg │ ├── sketch.png │ ├── the_resevoir_at_poitiers.jpg │ ├── trial.jpg │ ├── woman_in_peasant_dress.jpg │ ├── woman_in_peasant_dress_cropped.jpg │ └── woman_with_hat_matisse.jpg ├── styleexample │ ├── mondrian.jpg │ └── woman_with_hat_matisse.jpg └── videos │ └── cutBunny.mp4 ├── net.py ├── requirements.txt ├── results.png ├── sampler.py ├── test.py ├── test_video.py ├── torch_to_pytorch.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.t7 3 | output/* 4 | .idea 5 | experiments/* 6 | logs/* 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Naoto Inoue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-AdaIN 2 | 3 | This is an unofficial pytorch implementation of a paper, Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Huang+, ICCV2017]. 4 | I'm really grateful to the [original implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch by the authors, which is very useful. 5 | 6 | ![Results](results.png) 7 | 8 | ## Requirements 9 | Please install requirements by `pip install -r requirements.txt` 10 | 11 | - Python 3.5+ 12 | - PyTorch 0.4+ 13 | - TorchVision 14 | - Pillow 15 | 16 | (optional, for training) 17 | - tqdm 18 | - TensorboardX 19 | 20 | ## Usage 21 | 22 | ### Download models 23 | Download decoder.pth / vgg_normalized.pth from [release](https://github.com/naoto0804/pytorch-AdaIN/releases/tag/v0.0.0) and put them under `models/`. 
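Note that `test.py` and `train.py` load the VGG weights from `models/vgg_normalised.pth` by default, so make sure the downloaded file is saved under that name. As a minimal sketch of one way to fetch the weights (assuming the release assets are named `decoder.pth` and `vgg_normalised.pth`):
```
mkdir -p models
wget -P models https://github.com/naoto0804/pytorch-AdaIN/releases/download/v0.0.0/decoder.pth
wget -P models https://github.com/naoto0804/pytorch-AdaIN/releases/download/v0.0.0/vgg_normalised.pth
```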
24 | 
25 | ### Test
26 | Use `--content` and `--style` to provide the respective paths to the content and style images.
27 | ```
28 | CUDA_VISIBLE_DEVICES= python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
29 | ```
30 | 
31 | You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every combination of content and style images to the output directory.
32 | ```
33 | CUDA_VISIBLE_DEVICES= python test.py --content_dir input/content --style_dir input/style
34 | ```
35 | 
36 | This is an example of mixing four styles by specifying the `--style` and `--style_interpolation_weights` options.
37 | ```
38 | CUDA_VISIBLE_DEVICES= python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
39 | ```
40 | 
41 | Some other options:
42 | * `--content_size`: New (minimum) size for the content image. The original size is kept if set to 0.
43 | * `--style_size`: New (minimum) size for the style image. The original size is kept if set to 0.
44 | * `--alpha`: Adjusts the degree of stylization. It should be a value between 0.0 and 1.0 (default: 1.0).
45 | * `--preserve_color`: Preserve the color of the content image.
46 | 
47 | 
48 | ### Train
49 | Use `--content_dir` and `--style_dir` to provide the respective directories of content and style images.
50 | ```
51 | CUDA_VISIBLE_DEVICES= python train.py --content_dir <content_dir> --style_dir <style_dir>
52 | ```
53 | 
54 | For more details and parameters, please refer to the `--help` option.
55 | 
56 | I share the model trained by this code as `iter_1000000.pth`
57 | at [release](https://github.com/naoto0804/pytorch-AdaIN/releases/tag/v0.0.0).
58 | 
59 | ## References
60 | - [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization." In ICCV, 2017.
61 | - [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
62 | 
--------------------------------------------------------------------------------
/function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def calc_mean_std(feat, eps=1e-5):
5 | # eps is a small value added to the variance to avoid divide-by-zero.
6 | size = feat.size() 7 | assert (len(size) == 4) 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def adaptive_instance_normalization(content_feat, style_feat): 16 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 17 | size = content_feat.size() 18 | style_mean, style_std = calc_mean_std(style_feat) 19 | content_mean, content_std = calc_mean_std(content_feat) 20 | 21 | normalized_feat = (content_feat - content_mean.expand( 22 | size)) / content_std.expand(size) 23 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 24 | 25 | 26 | def _calc_feat_flatten_mean_std(feat): 27 | # takes 3D feat (C, H, W), return mean and std of array within channels 28 | assert (feat.size()[0] == 3) 29 | assert (isinstance(feat, torch.FloatTensor)) 30 | feat_flatten = feat.view(3, -1) 31 | mean = feat_flatten.mean(dim=-1, keepdim=True) 32 | std = feat_flatten.std(dim=-1, keepdim=True) 33 | return feat_flatten, mean, std 34 | 35 | 36 | def _mat_sqrt(x): 37 | U, D, V = torch.svd(x) 38 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 39 | 40 | 41 | def coral(source, target): 42 | # assume both source and target are 3D array (C, H, W) 43 | # Note: flatten -> f 44 | 45 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 46 | source_f_norm = (source_f - source_f_mean.expand_as( 47 | source_f)) / source_f_std.expand_as(source_f) 48 | source_f_cov_eye = \ 49 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 50 | 51 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 52 | target_f_norm = (target_f - target_f_mean.expand_as( 53 | target_f)) / target_f_std.expand_as(target_f) 54 | target_f_cov_eye = \ 55 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 56 | 57 | source_f_norm_transfer = torch.mm( 58 | _mat_sqrt(target_f_cov_eye), 59 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), 60 | source_f_norm) 61 | ) 62 | 63 | source_f_transfer = source_f_norm_transfer * \ 64 | target_f_std.expand_as(source_f_norm) + \ 65 | target_f_mean.expand_as(source_f_norm) 66 | 67 | return source_f_transfer.view(source.size()) 68 | -------------------------------------------------------------------------------- /input/content/avril.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/avril.jpg -------------------------------------------------------------------------------- /input/content/blonde_girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/blonde_girl.jpg -------------------------------------------------------------------------------- /input/content/brad_pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/brad_pitt.jpg -------------------------------------------------------------------------------- /input/content/chicago.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/chicago.jpg -------------------------------------------------------------------------------- /input/content/cornell.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/cornell.jpg -------------------------------------------------------------------------------- /input/content/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/flowers.jpg -------------------------------------------------------------------------------- /input/content/golden_gate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/golden_gate.jpg -------------------------------------------------------------------------------- /input/content/lenna.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/lenna.jpg -------------------------------------------------------------------------------- /input/content/modern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/modern.jpg -------------------------------------------------------------------------------- /input/content/newyork.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/newyork.jpg -------------------------------------------------------------------------------- /input/content/sailboat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/content/sailboat.jpg -------------------------------------------------------------------------------- /input/mask/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/mask/mask.png -------------------------------------------------------------------------------- /input/style/antimonocromatismo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/antimonocromatismo.jpg -------------------------------------------------------------------------------- /input/style/asheville.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/asheville.jpg -------------------------------------------------------------------------------- /input/style/brushstrokes.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/brushstrokes.jpg -------------------------------------------------------------------------------- /input/style/contrast_of_forms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/contrast_of_forms.jpg -------------------------------------------------------------------------------- /input/style/en_campo_gris.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/en_campo_gris.jpg -------------------------------------------------------------------------------- /input/style/flower_of_life.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/flower_of_life.jpg -------------------------------------------------------------------------------- /input/style/goeritz.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/goeritz.jpg -------------------------------------------------------------------------------- /input/style/impronte_d_artista.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/impronte_d_artista.jpg -------------------------------------------------------------------------------- /input/style/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/la_muse.jpg -------------------------------------------------------------------------------- /input/style/mondrian.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/mondrian.jpg -------------------------------------------------------------------------------- /input/style/mondrian_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/mondrian_cropped.jpg -------------------------------------------------------------------------------- /input/style/picasso_seated_nude_hr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/picasso_seated_nude_hr.jpg -------------------------------------------------------------------------------- /input/style/picasso_self_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/picasso_self_portrait.jpg -------------------------------------------------------------------------------- /input/style/scene_de_rue.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/scene_de_rue.jpg -------------------------------------------------------------------------------- /input/style/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/sketch.png -------------------------------------------------------------------------------- /input/style/the_resevoir_at_poitiers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/the_resevoir_at_poitiers.jpg -------------------------------------------------------------------------------- /input/style/trial.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/trial.jpg -------------------------------------------------------------------------------- /input/style/woman_in_peasant_dress.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/woman_in_peasant_dress.jpg -------------------------------------------------------------------------------- /input/style/woman_in_peasant_dress_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/woman_in_peasant_dress_cropped.jpg -------------------------------------------------------------------------------- /input/style/woman_with_hat_matisse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/style/woman_with_hat_matisse.jpg -------------------------------------------------------------------------------- /input/styleexample/mondrian.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/styleexample/mondrian.jpg -------------------------------------------------------------------------------- /input/styleexample/woman_with_hat_matisse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/styleexample/woman_with_hat_matisse.jpg -------------------------------------------------------------------------------- /input/videos/cutBunny.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/input/videos/cutBunny.mp4 -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from function import adaptive_instance_normalization as adain 4 | from function import calc_mean_std 5 | 6 | decoder = nn.Sequential( 7 | 
nn.ReflectionPad2d((1, 1, 1, 1)), 8 | nn.Conv2d(512, 256, (3, 3)), 9 | nn.ReLU(), 10 | nn.Upsample(scale_factor=2, mode='nearest'), 11 | nn.ReflectionPad2d((1, 1, 1, 1)), 12 | nn.Conv2d(256, 256, (3, 3)), 13 | nn.ReLU(), 14 | nn.ReflectionPad2d((1, 1, 1, 1)), 15 | nn.Conv2d(256, 256, (3, 3)), 16 | nn.ReLU(), 17 | nn.ReflectionPad2d((1, 1, 1, 1)), 18 | nn.Conv2d(256, 256, (3, 3)), 19 | nn.ReLU(), 20 | nn.ReflectionPad2d((1, 1, 1, 1)), 21 | nn.Conv2d(256, 128, (3, 3)), 22 | nn.ReLU(), 23 | nn.Upsample(scale_factor=2, mode='nearest'), 24 | nn.ReflectionPad2d((1, 1, 1, 1)), 25 | nn.Conv2d(128, 128, (3, 3)), 26 | nn.ReLU(), 27 | nn.ReflectionPad2d((1, 1, 1, 1)), 28 | nn.Conv2d(128, 64, (3, 3)), 29 | nn.ReLU(), 30 | nn.Upsample(scale_factor=2, mode='nearest'), 31 | nn.ReflectionPad2d((1, 1, 1, 1)), 32 | nn.Conv2d(64, 64, (3, 3)), 33 | nn.ReLU(), 34 | nn.ReflectionPad2d((1, 1, 1, 1)), 35 | nn.Conv2d(64, 3, (3, 3)), 36 | ) 37 | 38 | vgg = nn.Sequential( 39 | nn.Conv2d(3, 3, (1, 1)), 40 | nn.ReflectionPad2d((1, 1, 1, 1)), 41 | nn.Conv2d(3, 64, (3, 3)), 42 | nn.ReLU(), # relu1-1 43 | nn.ReflectionPad2d((1, 1, 1, 1)), 44 | nn.Conv2d(64, 64, (3, 3)), 45 | nn.ReLU(), # relu1-2 46 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 47 | nn.ReflectionPad2d((1, 1, 1, 1)), 48 | nn.Conv2d(64, 128, (3, 3)), 49 | nn.ReLU(), # relu2-1 50 | nn.ReflectionPad2d((1, 1, 1, 1)), 51 | nn.Conv2d(128, 128, (3, 3)), 52 | nn.ReLU(), # relu2-2 53 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 54 | nn.ReflectionPad2d((1, 1, 1, 1)), 55 | nn.Conv2d(128, 256, (3, 3)), 56 | nn.ReLU(), # relu3-1 57 | nn.ReflectionPad2d((1, 1, 1, 1)), 58 | nn.Conv2d(256, 256, (3, 3)), 59 | nn.ReLU(), # relu3-2 60 | nn.ReflectionPad2d((1, 1, 1, 1)), 61 | nn.Conv2d(256, 256, (3, 3)), 62 | nn.ReLU(), # relu3-3 63 | nn.ReflectionPad2d((1, 1, 1, 1)), 64 | nn.Conv2d(256, 256, (3, 3)), 65 | nn.ReLU(), # relu3-4 66 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 67 | nn.ReflectionPad2d((1, 1, 1, 1)), 68 | nn.Conv2d(256, 512, (3, 3)), 69 | nn.ReLU(), # relu4-1, this is the last layer used 70 | nn.ReflectionPad2d((1, 1, 1, 1)), 71 | nn.Conv2d(512, 512, (3, 3)), 72 | nn.ReLU(), # relu4-2 73 | nn.ReflectionPad2d((1, 1, 1, 1)), 74 | nn.Conv2d(512, 512, (3, 3)), 75 | nn.ReLU(), # relu4-3 76 | nn.ReflectionPad2d((1, 1, 1, 1)), 77 | nn.Conv2d(512, 512, (3, 3)), 78 | nn.ReLU(), # relu4-4 79 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 80 | nn.ReflectionPad2d((1, 1, 1, 1)), 81 | nn.Conv2d(512, 512, (3, 3)), 82 | nn.ReLU(), # relu5-1 83 | nn.ReflectionPad2d((1, 1, 1, 1)), 84 | nn.Conv2d(512, 512, (3, 3)), 85 | nn.ReLU(), # relu5-2 86 | nn.ReflectionPad2d((1, 1, 1, 1)), 87 | nn.Conv2d(512, 512, (3, 3)), 88 | nn.ReLU(), # relu5-3 89 | nn.ReflectionPad2d((1, 1, 1, 1)), 90 | nn.Conv2d(512, 512, (3, 3)), 91 | nn.ReLU() # relu5-4 92 | ) 93 | 94 | 95 | class Net(nn.Module): 96 | def __init__(self, encoder, decoder): 97 | super(Net, self).__init__() 98 | enc_layers = list(encoder.children()) 99 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 100 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 101 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 102 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 103 | self.decoder = decoder 104 | self.mse_loss = nn.MSELoss() 105 | 106 | # fix the encoder 107 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 108 | for param in getattr(self, name).parameters(): 109 | param.requires_grad = False 110 | 111 | # extract relu1_1, 
relu2_1, relu3_1, relu4_1 from input image 112 | def encode_with_intermediate(self, input): 113 | results = [input] 114 | for i in range(4): 115 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 116 | results.append(func(results[-1])) 117 | return results[1:] 118 | 119 | # extract relu4_1 from input image 120 | def encode(self, input): 121 | for i in range(4): 122 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 123 | return input 124 | 125 | def calc_content_loss(self, input, target): 126 | assert (input.size() == target.size()) 127 | assert (target.requires_grad is False) 128 | return self.mse_loss(input, target) 129 | 130 | def calc_style_loss(self, input, target): 131 | assert (input.size() == target.size()) 132 | assert (target.requires_grad is False) 133 | input_mean, input_std = calc_mean_std(input) 134 | target_mean, target_std = calc_mean_std(target) 135 | return self.mse_loss(input_mean, target_mean) + \ 136 | self.mse_loss(input_std, target_std) 137 | 138 | def forward(self, content, style, alpha=1.0): 139 | assert 0 <= alpha <= 1 140 | style_feats = self.encode_with_intermediate(style) 141 | content_feat = self.encode(content) 142 | t = adain(content_feat, style_feats[-1]) 143 | t = alpha * t + (1 - alpha) * content_feat 144 | 145 | g_t = self.decoder(t) 146 | g_t_feats = self.encode_with_intermediate(g_t) 147 | 148 | loss_c = self.calc_content_loss(g_t_feats[-1], t) 149 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 150 | for i in range(1, 4): 151 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 152 | return loss_c, loss_s 153 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | Pillow==10.2.0 3 | pkg-resources==0.0.0 4 | protobuf==3.18.3 5 | six==1.12.0 6 | tensorboardX==1.8 7 | torch==1.13.1 8 | torchvision==0.4.0 9 | tqdm==4.35.0 10 | opencv-python==4.4.0.46 11 | imageio==2.9.0 12 | -------------------------------------------------------------------------------- /results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naoto0804/pytorch-AdaIN/47950d0e6656a95a80a4b105c4c0f58d38ef785c/results.png -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data 3 | 4 | 5 | def InfiniteSampler(n): 6 | # i = 0 7 | i = n - 1 8 | order = np.random.permutation(n) 9 | while True: 10 | yield order[i] 11 | i += 1 12 | if i >= n: 13 | np.random.seed() 14 | order = np.random.permutation(n) 15 | i = 0 16 | 17 | 18 | class InfiniteSamplerWrapper(data.sampler.Sampler): 19 | def __init__(self, data_source): 20 | self.num_samples = len(data_source) 21 | 22 | def __iter__(self): 23 | return iter(InfiniteSampler(self.num_samples)) 24 | 25 | def __len__(self): 26 | return 2 ** 31 27 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.nn as nn 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.utils import save_image 9 | 10 | import net 11 | from function import adaptive_instance_normalization, coral 12 | 13 | 14 | def test_transform(size, 
crop): 15 | transform_list = [] 16 | if size != 0: 17 | transform_list.append(transforms.Resize(size)) 18 | if crop: 19 | transform_list.append(transforms.CenterCrop(size)) 20 | transform_list.append(transforms.ToTensor()) 21 | transform = transforms.Compose(transform_list) 22 | return transform 23 | 24 | 25 | def style_transfer(vgg, decoder, content, style, alpha=1.0, 26 | interpolation_weights=None): 27 | assert (0.0 <= alpha <= 1.0) 28 | content_f = vgg(content) 29 | style_f = vgg(style) 30 | if interpolation_weights: 31 | _, C, H, W = content_f.size() 32 | feat = torch.FloatTensor(1, C, H, W).zero_().to(device) 33 | base_feat = adaptive_instance_normalization(content_f, style_f) 34 | for i, w in enumerate(interpolation_weights): 35 | feat = feat + w * base_feat[i:i + 1] 36 | content_f = content_f[0:1] 37 | else: 38 | feat = adaptive_instance_normalization(content_f, style_f) 39 | feat = feat * alpha + content_f * (1 - alpha) 40 | return decoder(feat) 41 | 42 | 43 | parser = argparse.ArgumentParser() 44 | # Basic options 45 | parser.add_argument('--content', type=str, 46 | help='File path to the content image') 47 | parser.add_argument('--content_dir', type=str, 48 | help='Directory path to a batch of content images') 49 | parser.add_argument('--style', type=str, 50 | help='File path to the style image, or multiple style \ 51 | images separated by commas if you want to do style \ 52 | interpolation or spatial control') 53 | parser.add_argument('--style_dir', type=str, 54 | help='Directory path to a batch of style images') 55 | parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth') 56 | parser.add_argument('--decoder', type=str, default='models/decoder.pth') 57 | 58 | # Additional options 59 | parser.add_argument('--content_size', type=int, default=512, 60 | help='New (minimum) size for the content image, \ 61 | keeping the original size if set to 0') 62 | parser.add_argument('--style_size', type=int, default=512, 63 | help='New (minimum) size for the style image, \ 64 | keeping the original size if set to 0') 65 | parser.add_argument('--crop', action='store_true', 66 | help='do center crop to create squared image') 67 | parser.add_argument('--save_ext', default='.jpg', 68 | help='The extension name of the output image') 69 | parser.add_argument('--output', type=str, default='output', 70 | help='Directory to save the output image(s)') 71 | 72 | # Advanced options 73 | parser.add_argument('--preserve_color', action='store_true', 74 | help='If specified, preserve color of the content image') 75 | parser.add_argument('--alpha', type=float, default=1.0, 76 | help='The weight that controls the degree of \ 77 | stylization. Should be between 0 and 1') 78 | parser.add_argument( 79 | '--style_interpolation_weights', type=str, default='', 80 | help='The weight for blending the style of multiple style images') 81 | 82 | args = parser.parse_args() 83 | 84 | do_interpolation = False 85 | 86 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 87 | 88 | output_dir = Path(args.output) 89 | output_dir.mkdir(exist_ok=True, parents=True) 90 | 91 | # Either --content or --contentDir should be given. 92 | assert (args.content or args.content_dir) 93 | if args.content: 94 | content_paths = [Path(args.content)] 95 | else: 96 | content_dir = Path(args.content_dir) 97 | content_paths = [f for f in content_dir.glob('*')] 98 | 99 | # Either --style or --styleDir should be given. 
100 | assert (args.style or args.style_dir) 101 | if args.style: 102 | style_paths = args.style.split(',') 103 | if len(style_paths) == 1: 104 | style_paths = [Path(args.style)] 105 | else: 106 | do_interpolation = True 107 | assert (args.style_interpolation_weights != ''), \ 108 | 'Please specify interpolation weights' 109 | weights = [int(i) for i in args.style_interpolation_weights.split(',')] 110 | interpolation_weights = [w / sum(weights) for w in weights] 111 | else: 112 | style_dir = Path(args.style_dir) 113 | style_paths = [f for f in style_dir.glob('*')] 114 | 115 | decoder = net.decoder 116 | vgg = net.vgg 117 | 118 | decoder.eval() 119 | vgg.eval() 120 | 121 | decoder.load_state_dict(torch.load(args.decoder)) 122 | vgg.load_state_dict(torch.load(args.vgg)) 123 | vgg = nn.Sequential(*list(vgg.children())[:31]) 124 | 125 | vgg.to(device) 126 | decoder.to(device) 127 | 128 | content_tf = test_transform(args.content_size, args.crop) 129 | style_tf = test_transform(args.style_size, args.crop) 130 | 131 | for content_path in content_paths: 132 | if do_interpolation: # one content image, N style image 133 | style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths]) 134 | content = content_tf(Image.open(str(content_path))) \ 135 | .unsqueeze(0).expand_as(style) 136 | style = style.to(device) 137 | content = content.to(device) 138 | with torch.no_grad(): 139 | output = style_transfer(vgg, decoder, content, style, 140 | args.alpha, interpolation_weights) 141 | output = output.cpu() 142 | output_name = output_dir / '{:s}_interpolation{:s}'.format( 143 | content_path.stem, args.save_ext) 144 | save_image(output, str(output_name)) 145 | 146 | else: # process one content and one style 147 | for style_path in style_paths: 148 | content = content_tf(Image.open(str(content_path))) 149 | style = style_tf(Image.open(str(style_path))) 150 | if args.preserve_color: 151 | style = coral(style, content) 152 | style = style.to(device).unsqueeze(0) 153 | content = content.to(device).unsqueeze(0) 154 | with torch.no_grad(): 155 | output = style_transfer(vgg, decoder, content, style, 156 | args.alpha) 157 | output = output.cpu() 158 | 159 | output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format( 160 | content_path.stem, style_path.stem, args.save_ext) 161 | save_image(output, str(output_name)) 162 | -------------------------------------------------------------------------------- /test_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from PIL import Image 9 | import cv2 10 | import imageio 11 | from torchvision import transforms 12 | from torchvision.utils import save_image 13 | 14 | import net 15 | from function import adaptive_instance_normalization, coral 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | def test_transform(size, crop): 21 | transform_list = [] 22 | if size != 0: 23 | transform_list.append(transforms.Resize(size)) 24 | if crop: 25 | transform_list.append(transforms.CenterCrop(size)) 26 | transform_list.append(transforms.ToTensor()) 27 | transform = transforms.Compose(transform_list) 28 | return transform 29 | 30 | 31 | def style_transfer(vgg, decoder, content, style, alpha=1.0, 32 | interpolation_weights=None): 33 | assert (0.0 <= alpha <= 1.0) 34 | content_f = vgg(content) 35 | style_f = vgg(style) 36 | if interpolation_weights: 37 | _, C, H, W = 
content_f.size() 38 | feat = torch.FloatTensor(1, C, H, W).zero_().to(device) 39 | base_feat = adaptive_instance_normalization(content_f, style_f) 40 | for i, w in enumerate(interpolation_weights): 41 | feat = feat + w * base_feat[i:i + 1] 42 | content_f = content_f[0:1] 43 | else: 44 | feat = adaptive_instance_normalization(content_f, style_f) 45 | feat = feat * alpha + content_f * (1 - alpha) 46 | return decoder(feat) 47 | 48 | 49 | parser = argparse.ArgumentParser() 50 | # Basic options 51 | parser.add_argument('--content_video', type=str, 52 | help='File path to the content video') 53 | parser.add_argument('--style_path', type=str, 54 | help='File path to the style video or single image') 55 | parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth') 56 | parser.add_argument('--decoder', type=str, default='models/decoder.pth') 57 | 58 | # Additional options 59 | parser.add_argument('--content_size', type=int, default=512, 60 | help='New (minimum) size for the content image, \ 61 | keeping the original size if set to 0') 62 | parser.add_argument('--style_size', type=int, default=512, 63 | help='New (minimum) size for the style image, \ 64 | keeping the original size if set to 0') 65 | parser.add_argument('--crop', action='store_true', 66 | help='do center crop to create squared image') 67 | parser.add_argument('--save_ext', default='.mp4', 68 | help='The extension name of the output video') 69 | parser.add_argument('--output', type=str, default='output', 70 | help='Directory to save the output image(s)') 71 | 72 | # Advanced options 73 | parser.add_argument('--preserve_color', action='store_true', 74 | help='If specified, preserve color of the content image') 75 | parser.add_argument('--alpha', type=float, default=1.0, 76 | help='The weight that controls the degree of \ 77 | stylization. Should be between 0 and 1') 78 | parser.add_argument( 79 | '--style_interpolation_weights', type=str, default='', 80 | help='The weight for blending the style of multiple style images') 81 | 82 | args = parser.parse_args() 83 | 84 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 85 | 86 | output_dir = Path(args.output) 87 | output_dir.mkdir(exist_ok = True, parents = True) 88 | 89 | # --content_video should be given. 
90 | assert (args.content_video) 91 | if args.content_video: 92 | content_path = Path(args.content_video) 93 | 94 | # --style_path should be given 95 | assert (args.style_path) 96 | if args.style_path: 97 | style_path = Path(args.style_path) 98 | 99 | decoder = net.decoder 100 | vgg = net.vgg 101 | 102 | decoder.eval() 103 | vgg.eval() 104 | 105 | decoder.load_state_dict(torch.load(args.decoder)) 106 | vgg.load_state_dict(torch.load(args.vgg)) 107 | vgg = nn.Sequential(*list(vgg.children())[:31]) 108 | 109 | vgg.to(device) 110 | decoder.to(device) 111 | 112 | content_tf = test_transform(args.content_size, args.crop) 113 | style_tf = test_transform(args.style_size, args.crop) 114 | 115 | #get video fps & video size 116 | content_video = cv2.VideoCapture(args.content_video) 117 | fps = int(content_video.get(cv2.CAP_PROP_FPS)) 118 | content_video_length = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT)) 119 | output_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH)) 120 | output_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 121 | 122 | assert fps != 0, 'Fps is zero, Please enter proper video path' 123 | 124 | pbar = tqdm(total = content_video_length) 125 | if style_path.suffix in [".mp4", ".mpg", ".avi"]: 126 | 127 | style_video = cv2.VideoCapture(args.style_path) 128 | style_video_length = int(style_video.get(cv2.CAP_PROP_FRAME_COUNT)) 129 | 130 | assert style_video_length==content_video_length, 'Content video and style video has different number of frames' 131 | 132 | output_video_path = output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format( 133 | content_path.stem, style_path.stem, args.save_ext) 134 | writer = imageio.get_writer(output_video_path, mode='I', fps=fps) 135 | 136 | while(True): 137 | ret, content_img = content_video.read() 138 | 139 | if not ret: 140 | break 141 | _, style_img = style_video.read() 142 | 143 | content = content_tf(Image.fromarray(content_img)) 144 | style = style_tf(Image.fromarray(style_img)) 145 | 146 | if args.preserve_color: 147 | style = coral(style, content) 148 | 149 | style = style.to(device).unsqueeze(0) 150 | content = content.to(device).unsqueeze(0) 151 | with torch.no_grad(): 152 | output = style_transfer(vgg, decoder, content, style, 153 | args.alpha) 154 | output = output.cpu() 155 | output = output.squeeze(0) 156 | output = np.array(output)*255 157 | #output = np.uint8(output) 158 | output = np.transpose(output, (1,2,0)) 159 | output = cv2.resize(output, (output_width, output_height), interpolation=cv2.INTER_CUBIC) 160 | 161 | writer.append_data(np.array(output)) 162 | pbar.update(1) 163 | 164 | style_video.release() 165 | content_video.release() 166 | 167 | if style_path.suffix in [".jpg", ".png", ".JPG", ".PNG"]: 168 | 169 | output_video_path = output_dir / '{:s}_stylized_{:s}{:s}'.format( 170 | content_path.stem, style_path.stem, args.save_ext) 171 | writer = imageio.get_writer(output_video_path, mode='I', fps=fps) 172 | 173 | style_img = Image.open(style_path) 174 | while(True): 175 | ret, content_img = content_video.read() 176 | 177 | if not ret: 178 | break 179 | content = content_tf(Image.fromarray(content_img)) 180 | style = style_tf(style_img) 181 | 182 | if args.preserve_color: 183 | style = coral(style, content) 184 | 185 | style = style.to(device).unsqueeze(0) 186 | content = content.to(device).unsqueeze(0) 187 | with torch.no_grad(): 188 | output = style_transfer(vgg, decoder, content, style, 189 | args.alpha) 190 | output = output.cpu() 191 | output = output.squeeze(0) 192 | output = np.array(output)*255 193 
| #output = np.uint8(output) 194 | output = np.transpose(output, (1,2,0)) 195 | output = cv2.resize(output, (output_width, output_height), interpolation=cv2.INTER_CUBIC) 196 | 197 | writer.append_data(np.array(output)) 198 | pbar.update(1) 199 | 200 | content_video.release() -------------------------------------------------------------------------------- /torch_to_pytorch.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | from functools import reduce 5 | 6 | import torch 7 | assert torch.__version__.split('.')[0] == '0', 'Only working on PyTorch 0.x.x' 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | from torch.utils.serialization import load_lua 11 | 12 | 13 | class LambdaBase(nn.Sequential): 14 | def __init__(self, fn, *args): 15 | super(LambdaBase, self).__init__(*args) 16 | self.lambda_func = fn 17 | 18 | def forward_prepare(self, input): 19 | output = [] 20 | for module in self._modules.values(): 21 | output.append(module(input)) 22 | return output if output else input 23 | 24 | 25 | class Lambda(LambdaBase): 26 | def forward(self, input): 27 | return self.lambda_func(self.forward_prepare(input)) 28 | 29 | 30 | class LambdaMap(LambdaBase): 31 | def forward(self, input): 32 | # result is Variables list [Variable1, Variable2, ...] 33 | return list(map(self.lambda_func, self.forward_prepare(input))) 34 | 35 | 36 | class LambdaReduce(LambdaBase): 37 | def forward(self, input): 38 | # result is a Variable 39 | return reduce(self.lambda_func, self.forward_prepare(input)) 40 | 41 | 42 | def copy_param(m, n): 43 | if m.weight is not None: n.weight.data.copy_(m.weight) 44 | if m.bias is not None: n.bias.data.copy_(m.bias) 45 | if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean) 46 | if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var) 47 | 48 | 49 | def add_submodule(seq, *args): 50 | for n in args: 51 | seq.add_module(str(len(seq._modules)), n) 52 | 53 | 54 | def lua_recursive_model(module, seq): 55 | for m in module.modules: 56 | name = type(m).__name__ 57 | real = m 58 | if name == 'TorchObject': 59 | name = m._typename.replace('cudnn.', '') 60 | m = m._obj 61 | 62 | if name == 'SpatialConvolution': 63 | if not hasattr(m, 'groups'): m.groups = 1 64 | n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), 65 | (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, 66 | bias=(m.bias is not None)) 67 | copy_param(m, n) 68 | add_submodule(seq, n) 69 | elif name == 'SpatialBatchNormalization': 70 | n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, 71 | m.affine) 72 | copy_param(m, n) 73 | add_submodule(seq, n) 74 | elif name == 'ReLU': 75 | n = nn.ReLU() 76 | add_submodule(seq, n) 77 | elif name == 'SpatialMaxPooling': 78 | n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 79 | ceil_mode=m.ceil_mode) 80 | add_submodule(seq, n) 81 | elif name == 'SpatialAveragePooling': 82 | n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 83 | ceil_mode=m.ceil_mode) 84 | add_submodule(seq, n) 85 | elif name == 'SpatialUpSamplingNearest': 86 | n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor) 87 | add_submodule(seq, n) 88 | elif name == 'View': 89 | n = Lambda(lambda x: x.view(x.size(0), -1)) 90 | add_submodule(seq, n) 91 | elif name == 'Linear': 92 | # Linear in pytorch only accept 2D input 93 | n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x) 94 | n2 = nn.Linear(m.weight.size(1), m.weight.size(0), 95 | 
bias=(m.bias is not None)) 96 | copy_param(m, n2) 97 | n = nn.Sequential(n1, n2) 98 | add_submodule(seq, n) 99 | elif name == 'Dropout': 100 | m.inplace = False 101 | n = nn.Dropout(m.p) 102 | add_submodule(seq, n) 103 | elif name == 'SoftMax': 104 | n = nn.Softmax() 105 | add_submodule(seq, n) 106 | elif name == 'Identity': 107 | n = Lambda(lambda x: x) # do nothing 108 | add_submodule(seq, n) 109 | elif name == 'SpatialFullConvolution': 110 | n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), 111 | (m.dW, m.dH), (m.padW, m.padH)) 112 | add_submodule(seq, n) 113 | elif name == 'SpatialReplicationPadding': 114 | n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b)) 115 | add_submodule(seq, n) 116 | elif name == 'SpatialReflectionPadding': 117 | n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b)) 118 | add_submodule(seq, n) 119 | elif name == 'Copy': 120 | n = Lambda(lambda x: x) # do nothing 121 | add_submodule(seq, n) 122 | elif name == 'Narrow': 123 | n = Lambda( 124 | lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a)) 125 | add_submodule(seq, n) 126 | elif name == 'SpatialCrossMapLRN': 127 | lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta, 128 | m.k) 129 | n = Lambda(lambda x, lrn=lrn: lrn.forward(x)) 130 | add_submodule(seq, n) 131 | elif name == 'Sequential': 132 | n = nn.Sequential() 133 | lua_recursive_model(m, n) 134 | add_submodule(seq, n) 135 | elif name == 'ConcatTable': # output is list 136 | n = LambdaMap(lambda x: x) 137 | lua_recursive_model(m, n) 138 | add_submodule(seq, n) 139 | elif name == 'CAddTable': # input is list 140 | n = LambdaReduce(lambda x, y: x + y) 141 | add_submodule(seq, n) 142 | elif name == 'Concat': 143 | dim = m.dimension 144 | n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim)) 145 | lua_recursive_model(m, n) 146 | add_submodule(seq, n) 147 | elif name == 'TorchObject': 148 | print('Not Implement', name, real._typename) 149 | else: 150 | print('Not Implement', name) 151 | 152 | 153 | def lua_recursive_source(module): 154 | s = [] 155 | for m in module.modules: 156 | name = type(m).__name__ 157 | real = m 158 | if name == 'TorchObject': 159 | name = m._typename.replace('cudnn.', '') 160 | m = m._obj 161 | 162 | if name == 'SpatialConvolution': 163 | if not hasattr(m, 'groups'): m.groups = 1 164 | s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format( 165 | m.nInputPlane, 166 | m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 167 | 1, m.groups, m.bias is not None)] 168 | elif name == 'SpatialBatchNormalization': 169 | s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format( 170 | m.running_mean.size(0), m.eps, m.momentum, m.affine)] 171 | elif name == 'ReLU': 172 | s += ['nn.ReLU()'] 173 | elif name == 'SpatialMaxPooling': 174 | s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format( 175 | (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)] 176 | elif name == 'SpatialAveragePooling': 177 | s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format( 178 | (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)] 179 | elif name == 'SpatialUpSamplingNearest': 180 | s += ['nn.UpsamplingNearest2d(scale_factor={})'.format( 181 | m.scale_factor)] 182 | elif name == 'View': 183 | s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View'] 184 | elif name == 'Linear': 185 | s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )' 186 | s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), 187 | m.weight.size(0), 188 | (m.bias is not None)) 189 
| s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)] 190 | elif name == 'Dropout': 191 | s += ['nn.Dropout({})'.format(m.p)] 192 | elif name == 'SoftMax': 193 | s += ['nn.Softmax()'] 194 | elif name == 'Identity': 195 | s += ['Lambda(lambda x: x), # Identity'] 196 | elif name == 'SpatialFullConvolution': 197 | s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane, 198 | m.nOutputPlane, 199 | (m.kW, m.kH), 200 | (m.dW, m.dH), ( 201 | m.padW, m.padH))] 202 | elif name == 'SpatialReplicationPadding': 203 | s += ['nn.ReplicationPad2d({})'.format( 204 | (m.pad_l, m.pad_r, m.pad_t, m.pad_b))] 205 | elif name == 'SpatialReflectionPadding': 206 | s += ['nn.ReflectionPad2d({})'.format( 207 | (m.pad_l, m.pad_r, m.pad_t, m.pad_b))] 208 | elif name == 'Copy': 209 | s += ['Lambda(lambda x: x), # Copy'] 210 | elif name == 'Narrow': 211 | s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format( 212 | (m.dimension, m.index, m.length))] 213 | elif name == 'SpatialCrossMapLRN': 214 | lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format( 215 | (m.size, m.alpha, m.beta, m.k)) 216 | s += [ 217 | 'Lambda(lambda x,lrn={}: Variable(lrn.forward(x)))'.format( 218 | lrn)] 219 | 220 | elif name == 'Sequential': 221 | s += ['nn.Sequential( # Sequential'] 222 | s += lua_recursive_source(m) 223 | s += [')'] 224 | elif name == 'ConcatTable': 225 | s += ['LambdaMap(lambda x: x, # ConcatTable'] 226 | s += lua_recursive_source(m) 227 | s += [')'] 228 | elif name == 'CAddTable': 229 | s += ['LambdaReduce(lambda x,y: x+y), # CAddTable'] 230 | elif name == 'Concat': 231 | dim = m.dimension 232 | s += [ 233 | 'LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format( 234 | m.dimension)] 235 | s += lua_recursive_source(m) 236 | s += [')'] 237 | else: 238 | s += '# ' + name + ' Not Implement,\n' 239 | s = map(lambda x: '\t{}'.format(x), s) 240 | return s 241 | 242 | 243 | def simplify_source(s): 244 | s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'), 245 | s) 246 | s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s) 247 | s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s) 248 | s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s) 249 | s = map(lambda x: x.replace('),#Conv2d', ')'), s) 250 | s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s) 251 | s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s) 252 | s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s) 253 | s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s) 254 | s = map(lambda x: x.replace('),#MaxPool2d', ')'), s) 255 | s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s) 256 | s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s) 257 | s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s) 258 | s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s) 259 | 260 | s = map(lambda x: '{},\n'.format(x), s) 261 | s = map(lambda x: x[1:], s) 262 | s = reduce(lambda x, y: x + y, s) 263 | return s 264 | 265 | 266 | def torch_to_pytorch(t7_filename, outputname=None): 267 | model = load_lua(t7_filename, unknown_classes=True) 268 | if type(model).__name__ == 'hashable_uniq_dict': model = model.model 269 | model.gradInput = None 270 | slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model)) 271 | s = simplify_source(slist) 272 | header = ''' 273 | import torch 274 | import torch.nn as nn 275 | from torch.autograd import Variable 276 | from functools import reduce 277 
| 278 | class LambdaBase(nn.Sequential): 279 | def __init__(self, fn, *args): 280 | super(LambdaBase, self).__init__(*args) 281 | self.lambda_func = fn 282 | 283 | def forward_prepare(self, input): 284 | output = [] 285 | for module in self._modules.values(): 286 | output.append(module(input)) 287 | return output if output else input 288 | 289 | class Lambda(LambdaBase): 290 | def forward(self, input): 291 | return self.lambda_func(self.forward_prepare(input)) 292 | 293 | class LambdaMap(LambdaBase): 294 | def forward(self, input): 295 | return list(map(self.lambda_func,self.forward_prepare(input))) 296 | 297 | class LambdaReduce(LambdaBase): 298 | def forward(self, input): 299 | return reduce(self.lambda_func,self.forward_prepare(input)) 300 | ''' 301 | varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', 302 | '_') 303 | s = '{}\n\n{} = {}'.format(header, varname, s[:-2]) 304 | 305 | if outputname is None: outputname = varname 306 | with open(outputname + '.py', "w") as pyfile: 307 | pyfile.write(s) 308 | 309 | n = nn.Sequential() 310 | lua_recursive_model(model, n) 311 | torch.save(n.state_dict(), outputname + '.pth') 312 | 313 | 314 | parser = argparse.ArgumentParser( 315 | description='Convert torch t7 model to pytorch') 316 | parser.add_argument('--model', '-m', type=str, required=True, 317 | help='torch model file in t7 format') 318 | parser.add_argument('--output', '-o', type=str, default=None, 319 | help='output file name prefix, xxx.py xxx.pth') 320 | args = parser.parse_args() 321 | 322 | torch_to_pytorch(args.model, args.output) 323 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | from PIL import Image, ImageFile 9 | from tensorboardX import SummaryWriter 10 | from torchvision import transforms 11 | from tqdm import tqdm 12 | 13 | import net 14 | from sampler import InfiniteSamplerWrapper 15 | 16 | cudnn.benchmark = True 17 | Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError 18 | # Disable OSError: image file is truncated 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | 21 | 22 | def train_transform(): 23 | transform_list = [ 24 | transforms.Resize(size=(512, 512)), 25 | transforms.RandomCrop(256), 26 | transforms.ToTensor() 27 | ] 28 | return transforms.Compose(transform_list) 29 | 30 | 31 | class FlatFolderDataset(data.Dataset): 32 | def __init__(self, root, transform): 33 | super(FlatFolderDataset, self).__init__() 34 | self.root = root 35 | self.paths = list(Path(self.root).glob('*')) 36 | self.transform = transform 37 | 38 | def __getitem__(self, index): 39 | path = self.paths[index] 40 | img = Image.open(str(path)).convert('RGB') 41 | img = self.transform(img) 42 | return img 43 | 44 | def __len__(self): 45 | return len(self.paths) 46 | 47 | def name(self): 48 | return 'FlatFolderDataset' 49 | 50 | 51 | def adjust_learning_rate(optimizer, iteration_count): 52 | """Imitating the original implementation""" 53 | lr = args.lr / (1.0 + args.lr_decay * iteration_count) 54 | for param_group in optimizer.param_groups: 55 | param_group['lr'] = lr 56 | 57 | 58 | parser = argparse.ArgumentParser() 59 | # Basic options 60 | parser.add_argument('--content_dir', type=str, required=True, 61 | help='Directory path to a batch of content images') 62 | 
parser.add_argument('--style_dir', type=str, required=True, 63 | help='Directory path to a batch of style images') 64 | parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth') 65 | 66 | # training options 67 | parser.add_argument('--save_dir', default='./experiments', 68 | help='Directory to save the model') 69 | parser.add_argument('--log_dir', default='./logs', 70 | help='Directory to save the log') 71 | parser.add_argument('--lr', type=float, default=1e-4) 72 | parser.add_argument('--lr_decay', type=float, default=5e-5) 73 | parser.add_argument('--max_iter', type=int, default=160000) 74 | parser.add_argument('--batch_size', type=int, default=8) 75 | parser.add_argument('--style_weight', type=float, default=10.0) 76 | parser.add_argument('--content_weight', type=float, default=1.0) 77 | parser.add_argument('--n_threads', type=int, default=16) 78 | parser.add_argument('--save_model_interval', type=int, default=10000) 79 | args = parser.parse_args() 80 | 81 | device = torch.device('cuda') 82 | save_dir = Path(args.save_dir) 83 | save_dir.mkdir(exist_ok=True, parents=True) 84 | log_dir = Path(args.log_dir) 85 | log_dir.mkdir(exist_ok=True, parents=True) 86 | writer = SummaryWriter(log_dir=str(log_dir)) 87 | 88 | decoder = net.decoder 89 | vgg = net.vgg 90 | 91 | vgg.load_state_dict(torch.load(args.vgg)) 92 | vgg = nn.Sequential(*list(vgg.children())[:31]) 93 | network = net.Net(vgg, decoder) 94 | network.train() 95 | network.to(device) 96 | 97 | content_tf = train_transform() 98 | style_tf = train_transform() 99 | 100 | content_dataset = FlatFolderDataset(args.content_dir, content_tf) 101 | style_dataset = FlatFolderDataset(args.style_dir, style_tf) 102 | 103 | content_iter = iter(data.DataLoader( 104 | content_dataset, batch_size=args.batch_size, 105 | sampler=InfiniteSamplerWrapper(content_dataset), 106 | num_workers=args.n_threads)) 107 | style_iter = iter(data.DataLoader( 108 | style_dataset, batch_size=args.batch_size, 109 | sampler=InfiniteSamplerWrapper(style_dataset), 110 | num_workers=args.n_threads)) 111 | 112 | optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr) 113 | 114 | for i in tqdm(range(args.max_iter)): 115 | adjust_learning_rate(optimizer, iteration_count=i) 116 | content_images = next(content_iter).to(device) 117 | style_images = next(style_iter).to(device) 118 | loss_c, loss_s = network(content_images, style_images) 119 | loss_c = args.content_weight * loss_c 120 | loss_s = args.style_weight * loss_s 121 | loss = loss_c + loss_s 122 | 123 | optimizer.zero_grad() 124 | loss.backward() 125 | optimizer.step() 126 | 127 | writer.add_scalar('loss_content', loss_c.item(), i + 1) 128 | writer.add_scalar('loss_style', loss_s.item(), i + 1) 129 | 130 | if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter: 131 | state_dict = net.decoder.state_dict() 132 | for key in state_dict.keys(): 133 | state_dict[key] = state_dict[key].to(torch.device('cpu')) 134 | torch.save(state_dict, save_dir / 135 | 'decoder_iter_{:d}.pth.tar'.format(i + 1)) 136 | writer.close() 137 | --------------------------------------------------------------------------------
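As a quick orientation to how the modules above fit together, here is a minimal sketch condensing what `test.py` does for a single content/style pair: encode both images with the VGG encoder from `net.py` (truncated at relu4_1), apply `adaptive_instance_normalization` from `function.py`, and decode the result. It assumes the pretrained `decoder.pth` and `vgg_normalised.pth` from the release are saved under `models/` and uses sample images shipped in `input/`.
```
# Minimal single-pair AdaIN style transfer, condensing what test.py does.
# Assumes decoder.pth and vgg_normalised.pth from the release sit under models/.
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

import net
from function import adaptive_instance_normalization

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained encoder (VGG truncated at relu4_1) and the trained decoder.
vgg, decoder = net.vgg, net.decoder
vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
decoder.load_state_dict(torch.load('models/decoder.pth'))
vgg = nn.Sequential(*list(vgg.children())[:31]).to(device).eval()
decoder = decoder.to(device).eval()

tf = transforms.Compose([transforms.Resize(512), transforms.ToTensor()])
content = tf(Image.open('input/content/cornell.jpg')).unsqueeze(0).to(device)
style = tf(Image.open('input/style/woman_with_hat_matisse.jpg')).unsqueeze(0).to(device)

with torch.no_grad():
    # Align the channel-wise mean/std of the content features with those of the
    # style features, then decode the result back into image space.
    t = adaptive_instance_normalization(vgg(content), vgg(style))
    save_image(decoder(t).cpu(), 'output.jpg')
```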