├── Figs ├── framework.png └── teaserfig.png ├── LICENSE ├── README.md ├── checkpoint └── epoch_20.model ├── content ├── 6.jpg └── 9.jpg ├── loss ├── __pycache__ │ └── vgg.cpython-38.pyc └── vgg.py ├── networks ├── __pycache__ │ └── transfer_net.cpython-38.pyc └── transfer_net.py ├── test_model ├── __pycache__ │ └── utils.cpython-38.pyc ├── test │ ├── test.py │ └── test.yml └── utils.py ├── train_dataset ├── content │ └── readme.md └── style │ ├── 00.jpg │ ├── 01.jpg │ ├── 02.jpg │ ├── 04.jpg │ ├── 05.jpg │ ├── 06.jpg │ ├── 07.jpg │ ├── 08.jpg │ ├── 09.jpg │ └── 10.jpg └── train_model ├── __pycache__ └── utils.cpython-38.pyc ├── train1 ├── train.py └── train.yml ├── train2 ├── train.py └── train.yml └── utils.py /Figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/Figs/framework.png -------------------------------------------------------------------------------- /Figs/teaserfig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/Figs/teaserfig.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Chernobyllight 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ACCV 2024] Pluggable Style Representation Learning for Multi-Style Transfer 2 | 3 | Pytorch implementation of our ACCV 2024 paper ***Pluggable Style Representation Learning for Multi-Style Transfer*** [[paper](https://openaccess.thecvf.com/content/ACCV2024/papers/Liu_Pluggable_Style_Representation_Learning_for_Multi-Style_Transfer_ACCV_2024_paper.pdf)][[supple](https://openaccess.thecvf.com/content/ACCV2024/supplemental/Liu_Pluggable_Style_Representation_ACCV_2024_supplemental.pdf)]. 4 | 5 | 6 | ## :newspaper: $\mathrm{I}$ - Introduction 7 | 8 | **TL;DR:** We introduce a novel style representation learning scheme for multi-style transfer, which achieves superior inference speed and high generation quality. 
9 | 10 | Due to the high diversity of image styles, the scalability to various styles plays a critical role in real-world applications. To accommodate a large number of styles, previous multi-style transfer approaches rely on enlarging the model size while arbitrary-style transfer methods utilize heavy backbones. However, the additional computational cost introduced by more model parameters hinders these methods from being deployed on resource-limited devices. To address this challenge, in this paper, we develop a style transfer framework by decoupling the style modeling and transferring. Specifically, for style modeling, we propose a style representation learning scheme to encode the style information into a compact representation. Then, for style transferring, we develop a style-aware multi-style transfer network (SaMST) to adapt to diverse styles using pluggable style representations. In this way, our framework is able to accommodate diverse image styles in the learned style representations without introducing additional overhead during inference, thereby maintaining efficiency. Experiments show that our style representation can extract accurate style information. Moreover, qualitative and quantitative results demonstrate that our method achieves state-of-the-art performance in terms of both accuracy and efficiency. 11 | 12 | ![](Figs/framework.png) 13 | 14 | *An overview of our multi-style transfer framework.* 15 | 16 | ![](Figs/teaserfig.png) 17 | 18 | *Visual examples.* 19 | 20 | ## :wrench: $\mathrm{II}$ - Installation 21 | 22 | - Install Python 3.8.0, torch 2.0.0, CUDA 11.7 and other essential packages (Note that using other versions of packages may affect performance). 23 | - Clone this repo 24 | 25 | ``` 26 | git clone https://github.com/Chernobyllight/SaMST 27 | cd SaMST 28 | ``` 29 | 30 | ## :red_car: $\mathrm{III}$ - Test 31 | 32 | We provide a pretrained model trained on 10 styles in ***'./checkpoint/'***. The test content images are provided in ***'./content'***. 33 | 34 | - Get into evaluation codes folder ***'./test_model/test/'***: 35 | 36 | ``` 37 | cd ./test_model/test/ 38 | ``` 39 | 40 | - Specify the number of styles used to train the model in the config file ***'./test_model/test/test.yml'***. 41 | 42 | ``` 43 | style_num: 10 44 | ``` 45 | 46 | - Run 'test.py' 47 | 48 | ``` 49 | python test.py 50 | ``` 51 | 52 | The stylized results are saved in ***'./outputs/'***. 53 | 54 | ## :bullettrain_side: $\mathrm{IV}$ - Train 55 | 56 | ### :bank:Dataset Preparation 57 | 58 | We select styles from [wikiart](https://www.kaggle.com/competitions/painter-by-numbers/data) and [pixabay](https://pixabay.com/), and use [MS_COCO](https://cocodataset.org/#download) as our content dataset. The training dataset folder is ***'./train_dataset/'***, and the folder structure should be as follows: 59 | 60 | ``` 61 | train_dataset 62 | ├── style 63 | │ ├── 000001.jpg 64 | │ ├── 000002.jpg 65 | │ ├── 000003.jpg 66 | │ ├── ... 67 | ├── content 68 | │ ├── MS_COCO 69 | │ │ ├── 000001.jpg 70 | │ │ ├── 000002.jpg 71 | │ │ ├── 000003.jpg 72 | │ │ ├── ... 73 | ``` 74 | 75 | ### :running:Training 76 | 77 | 78 | We provide two training pipelines. If you have a large number of styles, please get into ***train1***. In contrast, if you train the model on only a few styles, you can get into ***train2*** for faster convergence. There is **no difference** between the two training pipelines' option settings. Here is an example of pipeline ***train1***.
79 | 80 | - Get into training codes folder ***'./train_model/train1/'***: 81 | 82 | ``` 83 | cd ./train_model/train1/ 84 | ``` 85 | 86 | - Run 'train.py' 87 | 88 | ``` 89 | python train.py 90 | ``` 91 | 92 | 93 | ## :star: $\mathrm{V}$ - Citation 94 | 95 | If you find our work useful in your research, please cite our [paper](https://openaccess.thecvf.com/content/ACCV2024/papers/Liu_Pluggable_Style_Representation_Learning_for_Multi-Style_Transfer_ACCV_2024_paper.pdf)~ Thank you! 96 | 97 | ``` 98 | @inproceedings{liu2024pluggable, 99 | title={Pluggable Style Representation Learning for Multi-style Transfer}, 100 | author={Liu, Hongda and Wang, Longguang and Guan, Weijun and Zhang, Ye and Guo, Yulan}, 101 | booktitle={Proceedings of the Asian Conference on Computer Vision}, 102 | pages={2087--2104}, 103 | year={2024} 104 | } 105 | ``` 106 | 107 | ## :yum: $\mathrm{VI}$ - Acknowledgement 108 | 109 | This repository is heavily built upon the amazing works [Stylebank](https://github.com/jxcodetw/stylebank) and [CIN](https://github.com/kewellcjj/pytorch-multiple-style-transfer). Thanks for their great effort to community. 110 | 111 | ## :e-mail: $\mathrm{VII}$ - Contact 112 | 113 | [Hongda Liu](mailto:2946428816@qq.com) 114 | -------------------------------------------------------------------------------- /checkpoint/epoch_20.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/checkpoint/epoch_20.model -------------------------------------------------------------------------------- /content/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/content/6.jpg -------------------------------------------------------------------------------- /content/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/content/9.jpg -------------------------------------------------------------------------------- /loss/__pycache__/vgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/loss/__pycache__/vgg.cpython-38.pyc -------------------------------------------------------------------------------- /loss/vgg.py: -------------------------------------------------------------------------------- 1 | # borrowed from https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/vgg.py 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torchvision import models 7 | 8 | 9 | class Vgg16(torch.nn.Module): 10 | def __init__(self, requires_grad=False): 11 | super(Vgg16, self).__init__() 12 | vgg_pretrained_features = models.vgg16(pretrained=True).features 13 | self.slice1 = torch.nn.Sequential() 14 | self.slice2 = torch.nn.Sequential() 15 | self.slice3 = torch.nn.Sequential() 16 | self.slice4 = torch.nn.Sequential() 17 | for x in range(4): 18 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 19 | for x in range(4, 9): 20 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(9, 16): 22 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(16, 23): 24 | self.slice4.add_module(str(x), 
vgg_pretrained_features[x]) 25 | if not requires_grad: 26 | for param in self.parameters(): 27 | param.requires_grad = False 28 | 29 | def forward(self, X): 30 | h = self.slice1(X) 31 | h_relu1_2 = h 32 | h = self.slice2(h) 33 | h_relu2_2 = h 34 | h = self.slice3(h) 35 | h_relu3_3 = h 36 | h = self.slice4(h) 37 | h_relu4_3 = h 38 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 39 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 40 | return out 41 | -------------------------------------------------------------------------------- /networks/__pycache__/transfer_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/networks/__pycache__/transfer_net.cpython-38.pyc -------------------------------------------------------------------------------- /networks/transfer_net.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | class TransformerNet(torch.nn.Module): 9 | def __init__(self, style_num): 10 | super(TransformerNet, self).__init__() 11 | 12 | self.style_bank = Style_bank(style_num) 13 | 14 | self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1) 15 | self.in1 = InstanceNorm2d(32) 16 | self.cm1 = condition_modulate(32) 17 | 18 | self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2) 19 | self.in2 = InstanceNorm2d(64) 20 | self.cm2 = condition_modulate(64) 21 | 22 | self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2) 23 | self.in3 = InstanceNorm2d(128) 24 | self.cm3 = condition_modulate(128) 25 | 26 | self.res1 = ResidualBlock(channels=128,dynamic_channels=128,groups=128) 27 | self.res2 = ResidualBlock(channels=128,dynamic_channels=128,groups=128) 28 | self.res3 = ResidualBlock(channels=128,dynamic_channels=128,groups=128) 29 | self.res4 = ResidualBlock(channels=128,dynamic_channels=128,groups=128) 30 | self.res5 = ResidualBlock(channels=128,dynamic_channels=128,groups=128) 31 | 32 | 33 | self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2) 34 | self.in4 = InstanceNorm2d(64) 35 | self.cm4 = condition_modulate(64) 36 | 37 | self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2) 38 | self.in5 = InstanceNorm2d(32) 39 | self.cm5 = condition_modulate(32) 40 | 41 | self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1) 42 | self.relu = torch.nn.ReLU() 43 | 44 | 45 | def forward(self, X, style_id): 46 | 47 | representation = self.style_bank(style_id) 48 | 49 | y = self.conv1(X) 50 | y = self.in1(y) 51 | y = self.cm1(y,representation) # conditional modulated 52 | y = self.relu(y) 53 | 54 | 55 | y = self.conv2(y) 56 | y = self.in2(y) 57 | y = self.cm2(y, representation) # conditional modulated 58 | y = self.relu(y) 59 | 60 | y = self.conv3(y) 61 | y = self.in3(y) 62 | y = self.cm3(y, representation) # conditional modulated 63 | y = self.relu(y) 64 | 65 | y = self.res1(y, representation) 66 | y = self.res2(y, representation) 67 | y = self.res3(y, representation) 68 | y = self.res4(y, representation) 69 | y = self.res5(y, representation) 70 | 71 | y = self.deconv1(y) 72 | y = self.in4(y) 73 | y = self.cm4(y, representation) # conditional modulated 74 | y = self.relu(y) 75 | 76 | y = self.deconv2(y) 77 | y = self.in5(y) 78 | y = self.cm5(y, representation) # conditional modulated 79 | y = self.relu(y) 80 | 81 | y = 
self.deconv3(y) 82 | 83 | 84 | 85 | return y,representation 86 | 87 | 88 | class condition_modulate(torch.nn.Module): 89 | """ 90 | Conditional Instance Normalization 91 | introduced in https://arxiv.org/abs/1610.07629 92 | created and applied based on my limited understanding, could be improved 93 | """ 94 | 95 | def __init__(self, in_channels): 96 | super(condition_modulate, self).__init__() 97 | self.compress_gamma = torch.nn.Sequential( 98 | torch.nn.Linear(32, in_channels,bias=False), 99 | torch.nn.LeakyReLU(0.1, True) 100 | ) 101 | self.compress_beta = torch.nn.Sequential( 102 | torch.nn.Linear(32, in_channels, bias=False), 103 | torch.nn.LeakyReLU(0.1, True) 104 | ) 105 | 106 | def forward(self, x,representation): 107 | gamma = self.compress_gamma(representation) 108 | beta = self.compress_beta(representation) 109 | 110 | b,c = gamma.size() 111 | 112 | gamma = gamma.view(b,c,1,1) 113 | beta = beta.view(b,c,1,1) 114 | 115 | out = x * gamma + beta 116 | return out 117 | 118 | 119 | 120 | 121 | class InstanceNorm2d(torch.nn.Module): 122 | """ 123 | Conditional Instance Normalization 124 | introduced in https://arxiv.org/abs/1610.07629 125 | created and applied based on my limited understanding, could be improved 126 | """ 127 | 128 | def __init__(self, in_channels): 129 | super(InstanceNorm2d, self).__init__() 130 | self.inns = torch.nn.InstanceNorm2d(in_channels, affine=False) 131 | 132 | def forward(self, x): 133 | out = self.inns(x) 134 | return out 135 | 136 | class Dynamic_ConvLayer2(torch.nn.Module): 137 | ''' 138 | in_channels: 输入的图像特征通道数 139 | out_channels: 输出的图像特征通道数 140 | groups: 分组数(in_channels,out_channels的公因数) 141 | ''' 142 | def __init__(self, in_channels, out_channels, kernel_size, groups): 143 | super(Dynamic_ConvLayer2, self).__init__() 144 | reflection_padding = kernel_size // 2 # same dimension after padding 145 | self.reflection_padding = reflection_padding 146 | self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) 147 | 148 | self.kernel_size = kernel_size 149 | 150 | self.compress_key = torch.nn.Sequential( 151 | torch.nn.Linear(32, out_channels * kernel_size * kernel_size, bias=False), 152 | torch.nn.LeakyReLU(0.1, True) 153 | ) 154 | self.in_channels = in_channels 155 | self.out_channels = out_channels 156 | self.groups = groups 157 | 158 | def forward(self, x,representation): 159 | out = self.reflection_pad(x) 160 | 161 | b, c, h, w = out.size() 162 | 163 | kernel = self.compress_key(representation).view(b,self.out_channels, -1, self.kernel_size, self.kernel_size) 164 | 165 | # 1,64,1,kh,kw -> 1,64,4,kh,kw 166 | features_per_group = int(self.in_channels/self.groups) 167 | kernel = kernel.repeat_interleave(features_per_group, dim=2) 168 | 169 | # 1,64,4,kh,kw 170 | k_batch,k_outputchannel,k_feature_pergroup,kh,kw = kernel.size() 171 | 172 | out = F.conv2d(out.view(1, -1, h, w), kernel.view(-1,k_feature_pergroup,kh,kw), groups=b * self.groups, padding=0) 173 | 174 | b,c,h,w = x.size() 175 | out = out.view(b, -1, h, w) 176 | 177 | return out 178 | 179 | 180 | 181 | class ConvLayer(torch.nn.Module): 182 | def __init__(self, in_channels, out_channels, kernel_size, stride): 183 | super(ConvLayer, self).__init__() 184 | reflection_padding = kernel_size // 2 # same dimension after padding 185 | self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) 186 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride) # remember this dimension 187 | 188 | def forward(self, x): 189 | out = self.reflection_pad(x) 190 | out = 
self.conv2d(out) 191 | return out 192 | 193 | 194 | class ResidualBlock(torch.nn.Module): 195 | """ResidualBlock 196 | introduced in: https://arxiv.org/abs/1512.03385 197 | recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html 198 | """ 199 | 200 | def __init__(self, channels,dynamic_channels,groups): 201 | super(ResidualBlock, self).__init__() 202 | self.conv1 = Dynamic_ConvLayer2(channels, dynamic_channels, kernel_size=3, groups=groups) 203 | self.in1 = InstanceNorm2d(dynamic_channels) 204 | self.cm1 = condition_modulate(dynamic_channels) 205 | 206 | self.conv2 = ConvLayer(dynamic_channels, channels, kernel_size=1, stride=1) 207 | self.in2 = InstanceNorm2d(channels) 208 | self.cm2 = condition_modulate(channels) 209 | 210 | self.relu = torch.nn.ReLU() 211 | 212 | representation_channels = 32 213 | feature_channels = channels 214 | self.ca = CA_layer(channels_in=representation_channels, channels_out=feature_channels, reduction=4) 215 | 216 | def forward(self, x,representation): 217 | residual = x 218 | 219 | out = self.conv1(x,representation) 220 | out = self.in1(out) 221 | out = self.cm1(out, representation) # conditional modulated 222 | 223 | out = self.relu(out) 224 | 225 | out = self.conv2(out) 226 | out = self.in2(out) 227 | out = self.cm2(out, representation) # conditional modulated 228 | 229 | # out = out + residual 230 | out = out + self.ca([residual, representation]) 231 | 232 | return out 233 | 234 | 235 | class UpsampleConvLayer(torch.nn.Module): 236 | """UpsampleConvLayer 237 | Upsamples the input and then does a convolution. This method gives better results 238 | compared to ConvTranspose2d. 239 | ref: http://distill.pub/2016/deconv-checkerboard/ 240 | """ 241 | 242 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None): 243 | super(UpsampleConvLayer, self).__init__() 244 | self.upsample = upsample 245 | if upsample: 246 | self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample) 247 | reflection_padding = kernel_size // 2 248 | self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) 249 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride) 250 | 251 | def forward(self, x): 252 | x_in = x 253 | if self.upsample: 254 | x_in = self.upsample_layer(x_in) 255 | out = self.reflection_pad(x_in) 256 | out = self.conv2d(out) 257 | return out 258 | 259 | 260 | 261 | 262 | 263 | class CA_layer(nn.Module): 264 | def __init__(self, channels_in, channels_out, reduction): 265 | super(CA_layer, self).__init__() 266 | self.conv_du = nn.Sequential( 267 | nn.Conv2d(channels_in, channels_in//reduction, 1, 1, 0, bias=False), 268 | nn.PReLU(), 269 | nn.Conv2d(channels_in // reduction, channels_out, 1, 1, 0, bias=False), 270 | nn.Sigmoid() 271 | ) 272 | 273 | def forward(self, x): 274 | ''' 275 | :param x[0]: feature map: B * C * H * W 276 | :param x[1]: degradation representation: B * C 277 | ''' 278 | att = self.conv_du(x[1][:, :, None, None]) 279 | 280 | return x[0] * att 281 | 282 | 283 | 284 | 285 | 286 | class style_representation(nn.Module): 287 | def __init__(self): 288 | super(style_representation, self).__init__() 289 | params = torch.ones(32, requires_grad=True).cuda() 290 | self.params = nn.Parameter(params) 291 | 292 | 293 | def forward(self): 294 | 295 | z = torch.normal(mean=0., std=0.1, size=(32,),requires_grad=False).cuda() # todo:加噪修改 296 | y = self.params + z # todo:加噪修改 297 | return y 298 | 299 | 300 | 301 | class Style_bank(nn.Module): 302 | def __init__(self, total_style): 
303 | super(Style_bank, self).__init__() 304 | 305 | self.total_style = total_style 306 | 307 | self.style_para_list = nn.ModuleList() 308 | for i in range(total_style + 1): 309 | params_layer = style_representation() 310 | self.style_para_list.append(params_layer) 311 | 312 | 313 | 314 | def forward(self, style_id=None): 315 | new_z = [] 316 | if style_id is not None: 317 | for idx, i in enumerate(style_id): 318 | zs = self.style_para_list[i]() 319 | new_z.append(zs) 320 | # z = torch.cat(new_z, dim=0) 321 | new_z = torch.stack(new_z,dim=0) 322 | else: 323 | print('where is your style_id?') 324 | exit(111) 325 | 326 | all_z = [self.style_para_list[i]() for i in range(len(self.style_para_list))] 327 | all_z = torch.stack(all_z,dim=0) 328 | return new_z 329 | 330 | def add_style(self,add_num): 331 | 332 | origin_style_num = len(self.style_para_list) # 包含ae representation 333 | for i in range(origin_style_num): 334 | self.style_para_list[i].params.requires_grad_(False) 335 | 336 | for i in range(add_num): 337 | print('add a style in bank, style id:',len(self.style_para_list)) 338 | params_layer = style_representation() 339 | self.style_para_list.append(params_layer) 340 | 341 | 342 | 343 | 344 | 345 | -------------------------------------------------------------------------------- /test_model/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/test_model/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /test_model/test/test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | project_root = os.path.abspath('../..') 4 | import sys 5 | sys.path.append(project_root) 6 | 7 | import yaml 8 | 9 | from test_model import utils 10 | from networks.transfer_net import TransformerNet 11 | 12 | import torch 13 | from torchvision import transforms 14 | 15 | def stylize(opt): 16 | device = torch.device("cuda" if opt['cuda'] else "cpu") 17 | 18 | content_images = os.listdir(opt['content_image_dir']) 19 | style_model = TransformerNet(style_num=opt['style_num']) 20 | state_dict = torch.load(opt['model']) 21 | style_model.load_state_dict(state_dict) 22 | style_model.to(device) 23 | 24 | if not os.path.exists(opt['output_image_dir']): 25 | os.makedirs(opt['output_image_dir']) 26 | 27 | with torch.no_grad(): 28 | for filename in content_images: 29 | 30 | print(filename) 31 | 32 | file_path = os.path.join(opt['content_image_dir'],filename) 33 | content_image = utils.load_image(filename=file_path, scale=opt['content_scale']) 34 | content_transform = transforms.Compose([ 35 | transforms.ToTensor(), 36 | transforms.Lambda(lambda x: x.mul(255)) 37 | ]) 38 | content_image = content_transform(content_image) 39 | content_image = content_image.unsqueeze(0).to(device) 40 | 41 | print(content_image.shape) 42 | 43 | for i in range(0, opt['style_num'] + 1): 44 | output, embedding = style_model(content_image, style_id=[i]) 45 | output = output.cpu() 46 | utils.save_image(opt['output_image_dir'] + 'style' + str(i) + '_' + filename, output[0]) 47 | 48 | 49 | 50 | 51 | 52 | 53 | def main(): 54 | 55 | with open('test.yml', 'r') as stream: 56 | opt = yaml.load(stream, Loader=yaml.FullLoader) 57 | 58 | stylize(opt) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | 64 | 65 | 
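# ------------------------------------------------------------------------------
# Hedged usage sketch (not part of the original script). test.py above loops over
# every style id from 0 to style_num, where id 0 selects the extra auto-encoder
# representation trained with the reconstruction loss and ids 1..style_num select
# the learned styles. The helper below shows how a single image could be stylized
# with one chosen style; the default paths mirror test.yml and are assumptions.
# Note: style_representation() in networks/transfer_net.py calls .cuda() internally,
# so a CUDA device is required.
def stylize_single(content_path='../../content/6.jpg', style_id=3,
                   model_path='../../checkpoint/epoch_20.model', style_num=10):
    device = torch.device("cuda")
    model = TransformerNet(style_num=style_num)
    model.load_state_dict(torch.load(model_path))
    model.to(device).eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),  # the network expects [0, 255] inputs
    ])
    content = to_tensor(utils.load_image(content_path)).unsqueeze(0).to(device)

    with torch.no_grad():
        # style_id is a list with one entry per batch element
        stylized, _representation = model(content, style_id=[style_id])

    utils.save_image('single_style' + str(style_id) + '_output.jpg', stylized[0].cpu())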
-------------------------------------------------------------------------------- /test_model/test/test.yml: -------------------------------------------------------------------------------- 1 | content_image_dir: '../../content' 2 | content_scale: ~ 3 | output_image_dir: '../../outputs/' 4 | 5 | model: '../../checkpoint/epoch_20.model' 6 | 7 | style_num: 10 8 | 9 | cuda: 1 10 | 11 | -------------------------------------------------------------------------------- /test_model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import cv2 4 | 5 | def load_image(filename, size=None, scale=None): 6 | img = Image.open(filename) 7 | # print(type(img)) 8 | # img.show() 9 | # cv2.waitKey() 10 | # exit(11) 11 | if size is not None: 12 | img = img.resize((size, size), Image.LANCZOS) 13 | elif scale is not None: 14 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS) 15 | return img 16 | 17 | 18 | 19 | def load_image_cv(filename, size=None, scale=None): 20 | img = cv2.imread(filename) 21 | img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 22 | if size is not None: 23 | img = img.resize((size, size), Image.LANCZOS) 24 | elif scale is not None: 25 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS) 26 | return img 27 | 28 | 29 | 30 | def save_image(filename, data): 31 | img = data.clone().clamp(0, 255).numpy() 32 | img = img.transpose(1, 2, 0).astype("uint8") 33 | img = Image.fromarray(img) 34 | img.save(filename) 35 | 36 | 37 | def gram_matrix(y): 38 | (b, ch, h, w) = y.size() 39 | features = y.view(b, ch, w * h) 40 | features_t = features.transpose(1, 2) # swapped ch and w*h, transpose share storage with original 41 | gram = features.bmm(features_t) / (ch * h * w) 42 | return gram 43 | 44 | 45 | def normalize_batch(batch): 46 | # normalize using imagenet mean and std 47 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) # new_tensor for same dimension of tensor 48 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 49 | batch = batch.div_(255.0) # back to tensor within 0, 1 50 | return (batch - mean) / std -------------------------------------------------------------------------------- /train_dataset/content/readme.md: -------------------------------------------------------------------------------- 1 | download MS_COCO dataset from: 2 | 3 | https://cocodataset.org/#download -------------------------------------------------------------------------------- /train_dataset/style/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/00.jpg -------------------------------------------------------------------------------- /train_dataset/style/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/01.jpg -------------------------------------------------------------------------------- /train_dataset/style/02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/02.jpg -------------------------------------------------------------------------------- /train_dataset/style/04.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/04.jpg -------------------------------------------------------------------------------- /train_dataset/style/05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/05.jpg -------------------------------------------------------------------------------- /train_dataset/style/06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/06.jpg -------------------------------------------------------------------------------- /train_dataset/style/07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/07.jpg -------------------------------------------------------------------------------- /train_dataset/style/08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/08.jpg -------------------------------------------------------------------------------- /train_dataset/style/09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/09.jpg -------------------------------------------------------------------------------- /train_dataset/style/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_dataset/style/10.jpg -------------------------------------------------------------------------------- /train_model/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chernobyllight/SaMST/2af3bf1a6ab4e8fbc86e4e3c66f14ee5e46d7a6c/train_model/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /train_model/train1/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | project_root = os.path.abspath('../..') 3 | import sys 4 | sys.path.append(project_root) 5 | 6 | 7 | 8 | import random 9 | 10 | import yaml 11 | import os 12 | import sys 13 | import numpy as np 14 | import time 15 | 16 | import torch 17 | from torchvision import transforms 18 | from torchvision import datasets 19 | from torch.utils.data import DataLoader 20 | from torch.optim import Adam 21 | 22 | from networks.transfer_net import TransformerNet 23 | from loss.vgg import Vgg16 24 | from train_model import utils 25 | 26 | 27 | def check_paths(opt): 28 | try: 29 | if not os.path.exists(opt['save_model_dir']): 30 | os.makedirs(opt['save_model_dir']) 31 | 32 | if opt['checkpoint_model_dir'] is not None and not (os.path.exists(opt['checkpoint_model_dir'])): 33 | os.makedirs(opt['checkpoint_model_dir']) 34 | except OSError as e: 35 | print(e) 36 | sys.exit(1) 37 | 38 | 39 | def train(opt): 40 | 
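    # Added note: this pipeline ("train1") reloads the sampled style images and recomputes
    # their VGG Gram matrices inside every training iteration, which keeps memory usage flat
    # when the style bank is large. The "train2" pipeline instead precomputes the Gram
    # matrices of all styles once before the epoch loop, which is faster when only a few
    # styles are trained.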
device = torch.device("cuda" if opt['cuda'] else "cpu") 41 | 42 | np.random.seed(opt['seed']) 43 | torch.manual_seed(opt['seed']) 44 | 45 | transform = transforms.Compose([ 46 | transforms.Resize(opt['image_size']), # the shorter side is resize to match image_size 47 | transforms.CenterCrop(opt['image_size']), 48 | transforms.ToTensor(), # to tensor [0,1] 49 | transforms.Lambda(lambda x: x.mul(255)) # convert back to [0, 255] 50 | ]) 51 | train_dataset = datasets.ImageFolder(opt['dataset'], transform) 52 | train_loader = DataLoader(train_dataset, batch_size=opt['batch_size'], shuffle=True) # to provide a batch loader 53 | 54 | 55 | style_image = [f for f in os.listdir(opt['style_image'])] 56 | style_num = len(style_image) 57 | print('total style number:',style_num) 58 | print(style_image) 59 | 60 | labels = [i for i in range(0,style_num+1)] 61 | labels = torch.Tensor(labels).cuda() 62 | # print(labels) 63 | # exit(123213) 64 | 65 | 66 | transformer = TransformerNet(style_num=style_num) 67 | print('# MODEL parameters:', sum(param.numel() for param in transformer.parameters()), '\n') 68 | begin_epoch = 0 69 | if opt['begin_checkpoint'] is not None: 70 | state_dict = torch.load(opt['begin_checkpoint']) 71 | transformer.load_state_dict(state_dict) 72 | print("load checkpoint model to train") 73 | begin_epoch = opt['begin_epoch'] 74 | transformer = transformer.to(device) 75 | 76 | 77 | optimizer = Adam(transformer.parameters(), opt['lr']) 78 | 79 | mse_loss = torch.nn.MSELoss() 80 | 81 | vgg = Vgg16(requires_grad=False).to(device) 82 | style_transform = transforms.Compose([ 83 | transforms.Resize(opt['style_size']), 84 | transforms.CenterCrop(opt['style_size']), 85 | transforms.ToTensor(), 86 | transforms.Lambda(lambda x: x.mul(255)) 87 | ]) 88 | 89 | 90 | 91 | content_weight = float(opt['content_weight']) 92 | style_weight = float(opt['style_weight']) 93 | ae_weight = float(opt['ae_weight']) 94 | 95 | total_epochs = opt['epochs'] 96 | for e in range(begin_epoch + 1,total_epochs+1): 97 | 98 | transformer.train() 99 | agg_content_loss = 0. 100 | agg_style_loss = 0. 101 | agg_ae_loss = 0. 
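        # Added note on the batch construction below: each content image gets a random style id
        # in [1, style_num]; n_batch copies of id 0 are then appended and the content batch is
        # repeated, so the first half of the forward pass is optimized with the content/style
        # losses while the second half, driven by the auto-encoder representation (id 0), is
        # optimized with the reconstruction (ae) loss.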
102 | 103 | count = 0 104 | for batch_id, (x, _) in enumerate(train_loader): 105 | n_batch = len(x) 106 | 107 | if n_batch < opt['batch_size']: 108 | break # skip to next epoch when no enough images left in the last batch of current epoch 109 | 110 | count += n_batch 111 | optimizer.zero_grad() # initialize with zero gradients 112 | 113 | batch_style_id = [random.randint(1, style_num) for i in range(count - n_batch, count)] 114 | style_batch = [] 115 | for i in batch_style_id: 116 | style = utils.load_image(opt['style_image'] + style_image[i-1], size=opt['style_size']) 117 | style = style_transform(style) 118 | style_batch.append(style) 119 | 120 | style = torch.stack(style_batch).to(device) 121 | features_style = vgg(utils.normalize_batch(style)) 122 | gram_style = [utils.gram_matrix(y) for y in features_style] 123 | 124 | for i in range(n_batch): 125 | batch_style_id.append(0) 126 | 127 | x = x.repeat(2,1,1,1) 128 | y,embedding = transformer(x.to(device), style_id=batch_style_id) 129 | 130 | 131 | y = utils.normalize_batch(y) 132 | x = utils.normalize_batch(x) 133 | 134 | 135 | y = torch.split(y, n_batch , dim=0) # 按照4这个维度去分,每大块包含2个小块 136 | y1 = y[0] 137 | y2 = y[1] 138 | 139 | x = torch.split(x, n_batch , dim=0) 140 | x1 = x[0] 141 | x2 = x[1] 142 | 143 | 144 | features_y = vgg(y1.to(device)) 145 | features_x = vgg(x1.to(device)) 146 | 147 | content_loss = content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2) 148 | 149 | style_loss = 0. 150 | for ft_y, gm_s in zip(features_y, gram_style): 151 | gm_y = utils.gram_matrix(ft_y) 152 | style_loss += mse_loss(gm_y, gm_s) 153 | style_loss *= style_weight 154 | 155 | ae_loss = ae_weight * mse_loss(y2.to(device),x2.to(device)) 156 | 157 | total_loss = content_loss + style_loss + ae_loss 158 | total_loss.backward() 159 | optimizer.step() 160 | 161 | agg_content_loss += content_loss.item() 162 | agg_style_loss += style_loss.item() 163 | agg_ae_loss += ae_loss.item() 164 | 165 | if (batch_id + 1) % opt['log_interval'] == 0: 166 | mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\tae: {:.6f}\ttotal: {:.6f}".format( 167 | time.ctime(), e, count, len(train_dataset), 168 | agg_content_loss / (batch_id + 1), 169 | agg_style_loss / (batch_id + 1), 170 | agg_ae_loss / (batch_id + 1), 171 | (agg_content_loss + agg_style_loss) / (batch_id + 1) 172 | ) 173 | print(mesg) 174 | 175 | if opt['checkpoint_model_dir'] is not None and (batch_id + 1) % opt['checkpoint_interval'] == 0: 176 | transformer.eval().cpu() 177 | ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth" 178 | ckpt_model_path = os.path.join(opt['checkpoint_model_dir'], ckpt_model_filename) 179 | torch.save(transformer.state_dict(), ckpt_model_path) 180 | transformer.to(device).train() 181 | 182 | if e % opt['save_interval'] == 0: 183 | # save model 184 | transformer.eval().cpu() 185 | save_model_filename = "epoch_" + str(e) + ".model" 186 | save_model_path = os.path.join(opt['save_model_dir'], save_model_filename) 187 | torch.save(transformer.state_dict(), save_model_path) 188 | print("\ntrained model saved at", save_model_path) 189 | transformer.to(device).train() 190 | 191 | 192 | if e % opt['step_size'] == 0: 193 | lr = opt['lr'] * (opt['weight_decay'] ** (e // opt['step_size'])) 194 | for param_group in optimizer.param_groups: 195 | param_group['lr'] = lr 196 | print('now learning rate: ',optimizer.state_dict()['param_groups'][0]['lr']) 197 | 198 | 199 | 200 | 201 | 202 | def main(): 203 | 204 | with open('train.yml', 'r') as stream: 
205 | opt = yaml.load(stream, Loader=yaml.FullLoader) 206 | 207 | random.seed(7) 208 | check_paths(opt) 209 | train(opt) 210 | 211 | 212 | if __name__ == "__main__": 213 | main() 214 | 215 | 216 | 217 | 218 | -------------------------------------------------------------------------------- /train_model/train1/train.yml: -------------------------------------------------------------------------------- 1 | 2 | epochs: 100 3 | batch_size: 8 4 | dataset: '../../train_dataset/content/' 5 | style_image: '../../train_dataset/style/' 6 | 7 | save_model_dir: '../../checkpoint/' 8 | image_size: 256 9 | style_size: 512 10 | cuda: 1 11 | seed: 7 12 | content_weight: 1e5 13 | style_weight: 1e10 14 | ae_weight: 1e3 15 | 16 | lr: 0.001 17 | weight_decay: 0.5 18 | step_size: 25 19 | save_interval: 10 20 | 21 | log_interval: 50 22 | checkpoint_interval: 100 23 | checkpoint_model_dir: ~ 24 | 25 | 26 | begin_checkpoint: ~ 27 | begin_epoch: ~ -------------------------------------------------------------------------------- /train_model/train2/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | project_root = os.path.abspath('../..') 3 | import sys 4 | sys.path.append(project_root) 5 | 6 | 7 | 8 | import random 9 | 10 | import yaml 11 | import os 12 | import sys 13 | import numpy as np 14 | import time 15 | 16 | import torch 17 | from torchvision import transforms 18 | from torchvision import datasets 19 | from torch.utils.data import DataLoader 20 | from torch.optim import Adam 21 | 22 | from networks.transfer_net import TransformerNet 23 | from loss.vgg import Vgg16 24 | from train_model import utils 25 | 26 | def check_paths(opt): 27 | try: 28 | if not os.path.exists(opt['save_model_dir']): 29 | os.makedirs(opt['save_model_dir']) 30 | 31 | if opt['checkpoint_model_dir'] is not None and not (os.path.exists(opt['checkpoint_model_dir'])): 32 | os.makedirs(opt['checkpoint_model_dir']) 33 | except OSError as e: 34 | print(e) 35 | sys.exit(1) 36 | 37 | 38 | def train(opt): 39 | device = torch.device("cuda" if opt['cuda'] else "cpu") 40 | 41 | np.random.seed(opt['seed']) 42 | torch.manual_seed(opt['seed']) 43 | 44 | transform = transforms.Compose([ 45 | transforms.Resize(opt['image_size']), # the shorter side is resize to match image_size 46 | transforms.CenterCrop(opt['image_size']), 47 | transforms.ToTensor(), # to tensor [0,1] 48 | transforms.Lambda(lambda x: x.mul(255)) # convert back to [0, 255] 49 | ]) 50 | train_dataset = datasets.ImageFolder(opt['dataset'], transform) 51 | train_loader = DataLoader(train_dataset, batch_size=opt['batch_size'], shuffle=True) # to provide a batch loader 52 | 53 | style_image = [f for f in os.listdir(opt['style_image'])] 54 | style_num = len(style_image) 55 | print('total style number:',style_num) 56 | 57 | labels = [i for i in range(0,style_num+1)] 58 | labels = torch.Tensor(labels).cuda() 59 | 60 | 61 | transformer = TransformerNet(style_num=style_num) 62 | print('# MODEL parameters:', sum(param.numel() for param in transformer.parameters()), '\n') 63 | begin_epoch = 0 64 | if opt['begin_checkpoint'] is not None: 65 | state_dict = torch.load(opt['begin_checkpoint']) 66 | transformer.load_state_dict(state_dict) 67 | print("load checkpoint model to train") 68 | begin_epoch = opt['begin_epoch'] 69 | transformer = transformer.to(device) 70 | 71 | 72 | optimizer = Adam(transformer.parameters(), opt['lr']) 73 | 74 | mse_loss = torch.nn.MSELoss() 75 | 76 | 77 | vgg = Vgg16(requires_grad=False).to(device) 78 | style_transform = 
transforms.Compose([ 79 | transforms.Resize(opt['style_size']), 80 | transforms.CenterCrop(opt['style_size']), 81 | transforms.ToTensor(), 82 | transforms.Lambda(lambda x: x.mul(255)) 83 | ]) 84 | 85 | style_batch = [] 86 | 87 | for i in range(style_num): 88 | style = utils.load_image(opt['style_image'] + style_image[i], size=opt['style_size']) 89 | style = style_transform(style) 90 | style_batch.append(style) 91 | 92 | style = torch.stack(style_batch).to(device) 93 | 94 | features_style = vgg(utils.normalize_batch(style)) 95 | gram_style = [utils.gram_matrix(y) for y in features_style] 96 | 97 | 98 | content_weight = float(opt['content_weight']) 99 | style_weight = float(opt['style_weight']) 100 | ae_weight = float(opt['ae_weight']) 101 | 102 | total_epochs = opt['epochs'] 103 | for e in range(begin_epoch + 1,total_epochs+1): 104 | 105 | transformer.train() 106 | agg_content_loss = 0. 107 | agg_style_loss = 0. 108 | agg_ae_loss = 0. 109 | 110 | count = 0 111 | for batch_id, (x, _) in enumerate(train_loader): 112 | n_batch = len(x) 113 | 114 | if n_batch < opt['batch_size']: 115 | break # skip to next epoch when no enough images left in the last batch of current epoch 116 | 117 | count += n_batch 118 | optimizer.zero_grad() # initialize with zero gradients 119 | 120 | batch_style_id = [random.randint(1,style_num) for i in range(count - n_batch, count)] 121 | 122 | 123 | for i in range(n_batch): 124 | batch_style_id.append(0) 125 | 126 | x = x.repeat(2,1,1,1) 127 | # print(batch_style_id) 128 | y,embedding = transformer(x.to(device), style_id=batch_style_id) 129 | 130 | 131 | y = utils.normalize_batch(y) 132 | x = utils.normalize_batch(x) 133 | 134 | 135 | y = torch.split(y, n_batch , dim=0) # 按照4这个维度去分,每大块包含2个小块 136 | y1 = y[0] 137 | y2 = y[1] 138 | 139 | x = torch.split(x, n_batch , dim=0) 140 | x1 = x[0] 141 | x2 = x[1] 142 | 143 | 144 | features_y = vgg(y1.to(device)) 145 | features_x = vgg(x1.to(device)) 146 | 147 | content_loss = content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2) 148 | 149 | style_loss = 0. 
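            # Added note: gram_style was precomputed for all styles before the epoch loop, so the
            # sampled ids (1..style_num) of the first half of the batch are shifted to 0-based
            # indices here to pick each sample's target Gram matrix from that cache.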
150 | batch_style_id = batch_style_id[0:n_batch] 151 | for i in range(len(batch_style_id)): 152 | batch_style_id[i] -= 1 153 | for ft_y, gm_s in zip(features_y, gram_style): 154 | gm_y = utils.gram_matrix(ft_y) 155 | style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :]) 156 | style_loss *= style_weight 157 | 158 | ae_loss = ae_weight * mse_loss(y2.to(device),x2.to(device)) 159 | 160 | 161 | total_loss = content_loss + style_loss + ae_loss 162 | total_loss.backward() 163 | optimizer.step() 164 | 165 | agg_content_loss += content_loss.item() 166 | agg_style_loss += style_loss.item() 167 | agg_ae_loss += ae_loss.item() 168 | 169 | if (batch_id + 1) % opt['log_interval'] == 0: 170 | mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\tae: {:.6f}\ttotal: {:.6f}".format( 171 | time.ctime(), e, count, len(train_dataset), 172 | agg_content_loss / (batch_id + 1), 173 | agg_style_loss / (batch_id + 1), 174 | agg_ae_loss / (batch_id + 1), 175 | (agg_content_loss + agg_style_loss) / (batch_id + 1) 176 | ) 177 | print(mesg) 178 | 179 | if opt['checkpoint_model_dir'] is not None and (batch_id + 1) % opt['checkpoint_interval'] == 0: 180 | transformer.eval().cpu() 181 | ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth" 182 | ckpt_model_path = os.path.join(opt['checkpoint_model_dir'], ckpt_model_filename) 183 | torch.save(transformer.state_dict(), ckpt_model_path) 184 | transformer.to(device).train() 185 | 186 | if e % opt['save_interval'] == 0: 187 | # save model 188 | transformer.eval().cpu() 189 | save_model_filename = "epoch_" + str(e) + ".model" 190 | save_model_path = os.path.join(opt['save_model_dir'], save_model_filename) 191 | torch.save(transformer.state_dict(), save_model_path) 192 | print("\ntrained model saved at", save_model_path) 193 | transformer.to(device).train() 194 | 195 | 196 | if e % opt['step_size'] == 0: 197 | lr = opt['lr'] * (opt['weight_decay'] ** (e // opt['step_size'])) 198 | for param_group in optimizer.param_groups: 199 | param_group['lr'] = lr 200 | print('now learning rate: ',optimizer.state_dict()['param_groups'][0]['lr']) 201 | 202 | 203 | 204 | 205 | 206 | def main(): 207 | 208 | with open('train.yml', 'r') as stream: 209 | opt = yaml.load(stream, Loader=yaml.FullLoader) 210 | 211 | random.seed(7) 212 | check_paths(opt) 213 | train(opt) 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /train_model/train2/train.yml: -------------------------------------------------------------------------------- 1 | 2 | epochs: 100 3 | batch_size: 8 4 | dataset: '../../train_dataset/content/' 5 | style_image: '../../train_dataset/style/' 6 | 7 | save_model_dir: '../../checkpoint/' 8 | image_size: 256 9 | style_size: 512 10 | cuda: 1 11 | seed: 7 12 | content_weight: 1e5 13 | style_weight: 1e10 14 | ae_weight: 1e3 15 | 16 | 17 | lr: 0.001 18 | weight_decay: 0.5 19 | step_size: 25 20 | save_interval: 10 21 | 22 | log_interval: 50 23 | checkpoint_interval: 100 24 | checkpoint_model_dir: ~ 25 | 26 | 27 | begin_checkpoint: ~ 28 | begin_epoch: ~ -------------------------------------------------------------------------------- /train_model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | 4 | 5 | def load_image(filename, size=None, scale=None): 6 | img = Image.open(filename) 7 | if size is not None: 8 | img = img.resize((size, size), 
Image.LANCZOS) 9 | elif scale is not None: 10 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS) 11 | return img 12 | 13 | 14 | def save_image(filename, data): 15 | img = data.clone().clamp(0, 255).numpy() 16 | img = img.transpose(1, 2, 0).astype("uint8") 17 | img = Image.fromarray(img) 18 | img.save(filename) 19 | 20 | 21 | def gram_matrix(y): 22 | (b, ch, h, w) = y.size() 23 | features = y.view(b, ch, w * h) 24 | features_t = features.transpose(1, 2) # swapped ch and w*h, transpose share storage with original 25 | gram = features.bmm(features_t) / (ch * h * w) 26 | return gram 27 | 28 | 29 | def normalize_batch(batch): 30 | # normalize using imagenet mean and std 31 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) # new_tensor for same dimension of tensor 32 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 33 | batch = batch.div_(255.0) # back to tensor within 0, 1 34 | return (batch - mean) / std --------------------------------------------------------------------------------
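A closing note on the pluggable style bank: ***'Style_bank.add_style'*** in ***'networks/transfer_net.py'*** freezes the existing style representations and appends new trainable ones, which suggests how extra styles could be plugged into an already trained model. The snippet below is a hypothetical sketch of that workflow rather than a script shipped with this repository; the checkpoint path, the choice to freeze the whole backbone, and the optimizer settings are assumptions.

```
import torch
from networks.transfer_net import TransformerNet

# Load the released 10-style model (ids 0..10, where 0 is the auto-encoder representation).
model = TransformerNet(style_num=10)
model.load_state_dict(torch.load('checkpoint/epoch_20.model'))
model.cuda()

# Freeze everything, then register one new pluggable style representation (it becomes id 11).
for p in model.parameters():
    p.requires_grad_(False)
model.style_bank.add_style(1)

# Only the newly added 32-dimensional representation stays trainable, so fine-tuning it on a
# new style image leaves the backbone and all existing styles untouched.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
print(sum(p.numel() for p in trainable))  # 32 parameters per added style
```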