├── .gitignore
├── LICENSE
├── README.md
├── create_chainer_model.py
├── generate.py
├── models
│   ├── .gitkeep
│   ├── composition.model
│   └── seurat.model
├── net.py
├── sample_images
│   ├── output_0.jpg
│   ├── output_1.jpg
│   ├── output_2.jpg
│   ├── output_keep_colors_1.jpg
│   ├── output_keep_colors_2.jpg
│   ├── style_0.jpg
│   ├── style_1.png
│   ├── style_2.png
│   └── tubingen.jpg
├── setup_model.sh
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask instance folder
instance/

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 Yusuke Tomoto

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Chainer implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"
Fast artistic style transfer with a feed-forward network.

**Check out the [resize-conv](https://github.com/yusuketomoto/chainer-fast-neuralstyle/tree/resize-conv) branch, which provides better results.**

(Sample style and output images: see `sample_images/`.)

- input image size: 1024x768
- process time (CPU): 17.78 sec (Core i7-5930K)
- process time (GPU): 0.994 sec (TitanX)

## Requirement
- [Chainer](https://github.com/pfnet/chainer)
```
$ pip install chainer
```

## Prerequisite
Download the VGG16 model and convert it into a smaller file that keeps only the convolutional layers, which are about 10% of the entire model:
```
sh setup_model.sh
```

## Train
A separate image transformation network must be trained for each style target. According to the paper, the models are trained on the [Microsoft COCO dataset](http://mscoco.org/dataset/#download).
```
python train.py -s <style_image_path> -d <dataset_directory> -g <gpu_id>
```

## Generate
```
python generate.py <input_image_path> -m <model_path> -o <output_image_path> -g <gpu_id>
```

This repo includes pretrained models as examples.

- example:
```
python generate.py sample_images/tubingen.jpg -m models/composition.model -o sample_images/output.jpg
```
or
```
python generate.py sample_images/tubingen.jpg -m models/seurat.model -o sample_images/output.jpg
```

#### Transfer only style, not color (**--keep_colors** option)
`python generate.py <input_image_path> -m <model_path> -o <output_image_path> -g <gpu_id> --keep_colors`

(Examples: `sample_images/output_keep_colors_1.jpg`, `sample_images/output_keep_colors_2.jpg`.)

## A collection of pre-trained models
Fashizzle Dizzle has created a collection of pre-trained models, [chainer-fast-neuralstyle-models](https://github.com/gafr/chainer-fast-neuralstyle-models), where you can find a variety of models.

## Difference from paper
- Convolution kernel size 4 instead of 3.
- Training with batch size n >= 2 gives unstable results.

## No Backward Compatibility
##### Jul. 19, 2016
This version is not compatible with previous versions: models trained with the earlier implementation cannot be used. Sorry for the inconvenience!

## License
MIT

## Reference
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](http://arxiv.org/abs/1603.08155)

The code in this repository is based on the following nice works; thanks to their authors.

- [chainer-gogh](https://github.com/mattya/chainer-gogh.git): a Chainer implementation of neural-style, which I referenced heavily.
- [chainer-cifar10](https://github.com/mitmul/chainer-cifar10): the residual block implementation is based on this.
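#### End-to-end example
Assuming the MS COCO images have been extracted into a local `dataset/` directory (the default for `-d`), a full run from setup to stylization looks like the following; the style image and GPU id are illustrative:
```
sh setup_model.sh
python train.py -s sample_images/style_0.jpg -d dataset -g 0
python generate.py sample_images/tubingen.jpg -m models/style_0.model -o sample_images/output.jpg -g 0
```
`train.py` names its snapshots after the style image (here `models/style_0.model`) unless `-o` supplies a different prefix.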
--------------------------------------------------------------------------------
/create_chainer_model.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from chainer import link
from chainer.links.caffe import CaffeFunction
from chainer import serializers
from net import *

# Copy parameters between two Chains wherever the child link types, parameter
# names, and shapes match; see http://qiita.com/tabe2314/items/6c0c1b769e12ab1e2614
def copy_model(src, dst):
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    for child in src.children():
        if child.name not in dst.__dict__: continue
        dst_child = dst[child.name]
        if type(child) != type(dst_child): continue
        if isinstance(child, link.Chain):
            copy_model(child, dst_child)
        if isinstance(child, link.Link):
            match = True
            for a, b in zip(child.namedparams(), dst_child.namedparams()):
                if a[0] != b[0]:
                    match = False
                    break
                if a[1].data.shape != b[1].data.shape:
                    match = False
                    break
            if not match:
                print('Ignore %s because of parameter mismatch' % child.name)
                continue
            for a, b in zip(child.namedparams(), dst_child.namedparams()):
                b[1].data = a[1].data
            print('Copy %s' % child.name)

print('load VGG16 caffemodel')
ref = CaffeFunction('VGG_ILSVRC_16_layers.caffemodel')
vgg = VGG()
print('copy weights')
copy_model(ref, vgg)

print('save "vgg16.model"')
serializers.save_npz('vgg16.model', vgg)
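A quick way to verify the conversion, as a minimal sketch; it assumes the script above has already written `vgg16.model` into the working directory:

```python
from chainer import serializers
from net import VGG

vgg = VGG()
serializers.load_npz('vgg16.model', vgg)  # raises if layer names or shapes mismatch
print(sum(p.data.size for p in vgg.params()), 'parameters loaded')
```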
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import argparse
from PIL import Image, ImageFilter
import time

import chainer
from chainer import cuda, Variable, serializers
from net import *

parser = argparse.ArgumentParser(description='Real-time style transfer image generator')
parser.add_argument('input')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--model', '-m', default='models/style.model', type=str)
parser.add_argument('--out', '-o', default='out.jpg', type=str)
parser.add_argument('--median_filter', default=3, type=int)
parser.add_argument('--padding', default=50, type=int)
parser.add_argument('--keep_colors', action='store_true')
parser.set_defaults(keep_colors=False)
args = parser.parse_args()

# From 6o6o's fork: https://github.com/6o6o/chainer-fast-neuralstyle/blob/master/generate.py
# Keep the original hue and saturation; take only the value channel from the stylized image.
def original_colors(original, stylized):
    h, s, v = original.convert('HSV').split()
    hs, ss, vs = stylized.convert('HSV').split()
    return Image.merge('HSV', (h, s, vs)).convert('RGB')

model = FastStyleNet()
serializers.load_npz(args.model, model)
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
xp = np if args.gpu < 0 else cuda.cupy

start = time.time()
original = Image.open(args.input).convert('RGB')
image = np.asarray(original, dtype=np.float32).transpose(2, 0, 1)
image = image.reshape((1,) + image.shape)
if args.padding > 0:
    image = np.pad(image, [[0, 0], [0, 0], [args.padding, args.padding], [args.padding, args.padding]], 'symmetric')
image = xp.asarray(image)
x = Variable(image)

y = model(x)
result = cuda.to_cpu(y.data)

if args.padding > 0:
    result = result[:, :, args.padding:-args.padding, args.padding:-args.padding]
result = np.uint8(result[0].transpose((1, 2, 0)))
med = Image.fromarray(result)
if args.median_filter > 0:
    med = med.filter(ImageFilter.MedianFilter(args.median_filter))
if args.keep_colors:
    med = original_colors(original, med)
print(time.time() - start, 'sec')

med.save(args.out)
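Note the `--padding` round trip above: `generate.py` mirror-pads the input by `--padding` pixels before running the network and crops the same margin afterwards, which keeps convolution border effects outside the visible image. A minimal numpy-only sketch of that round trip; the shape is arbitrary:

```python
import numpy as np

pad = 50
image = np.zeros((1, 3, 256, 256), dtype=np.float32)  # NCHW layout, as in generate.py
padded = np.pad(image, [[0, 0], [0, 0], [pad, pad], [pad, pad]], 'symmetric')
# ... the stylization network would run on `padded` here ...
cropped = padded[:, :, pad:-pad, pad:-pad]
assert cropped.shape == image.shape  # the margin is fully removed
```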
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/models/.gitkeep
--------------------------------------------------------------------------------
/models/composition.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/models/composition.model
--------------------------------------------------------------------------------
/models/seurat.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/models/seurat.model
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import math

import numpy as np
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Variable

class ResidualBlock(chainer.Chain):
    def __init__(self, n_in, n_out, stride=1, ksize=3):
        w = math.sqrt(2)
        super(ResidualBlock, self).__init__(
            c1=L.Convolution2D(n_in, n_out, ksize, stride, 1, w),
            c2=L.Convolution2D(n_out, n_out, ksize, 1, 1, w),
            b1=L.BatchNormalization(n_out),
            b2=L.BatchNormalization(n_out)
        )

    def __call__(self, x, test):
        h = F.relu(self.b1(self.c1(x), test=test))
        h = self.b2(self.c2(h), test=test)
        if x.data.shape != h.data.shape:
            # Zero-pad channels and downsample the shortcut when shapes differ.
            xp = chainer.cuda.get_array_module(x.data)
            n, c, hh, ww = x.data.shape
            pad_c = h.data.shape[1] - c
            p = xp.zeros((n, pad_c, hh, ww), dtype=xp.float32)
            p = chainer.Variable(p, volatile=test)
            x = F.concat((p, x))
            if x.data.shape[2:] != h.data.shape[2:]:
                x = F.average_pooling_2d(x, 1, 2)
        return h + x

class FastStyleNet(chainer.Chain):
    def __init__(self):
        super(FastStyleNet, self).__init__(
            c1=L.Convolution2D(3, 32, 9, stride=1, pad=4),
            c2=L.Convolution2D(32, 64, 4, stride=2, pad=1),
            c3=L.Convolution2D(64, 128, 4, stride=2, pad=1),
            r1=ResidualBlock(128, 128),
            r2=ResidualBlock(128, 128),
            r3=ResidualBlock(128, 128),
            r4=ResidualBlock(128, 128),
            r5=ResidualBlock(128, 128),
            d1=L.Deconvolution2D(128, 64, 4, stride=2, pad=1),
            d2=L.Deconvolution2D(64, 32, 4, stride=2, pad=1),
            d3=L.Deconvolution2D(32, 3, 9, stride=1, pad=4),
            b1=L.BatchNormalization(32),
            b2=L.BatchNormalization(64),
            b3=L.BatchNormalization(128),
            b4=L.BatchNormalization(64),
            b5=L.BatchNormalization(32),
        )

    def __call__(self, x, test=False):
        h = self.b1(F.elu(self.c1(x)), test=test)
        h = self.b2(F.elu(self.c2(h)), test=test)
        h = self.b3(F.elu(self.c3(h)), test=test)
        h = self.r1(h, test=test)
        h = self.r2(h, test=test)
        h = self.r3(h, test=test)
        h = self.r4(h, test=test)
        h = self.r5(h, test=test)
        h = self.b4(F.elu(self.d1(h)), test=test)
        h = self.b5(F.elu(self.d2(h)), test=test)
        y = self.d3(h)
        # Map the tanh output from [-1, 1] into the pixel range [0, 255].
        return (F.tanh(y) + 1) * 127.5

# Loss network: the convolutional layers of VGG16, returning intermediate activations.
class VGG(chainer.Chain):
    def __init__(self):
        super(VGG, self).__init__(
            conv1_1=L.Convolution2D(3, 64, 3, stride=1, pad=1),
            conv1_2=L.Convolution2D(64, 64, 3, stride=1, pad=1),

            conv2_1=L.Convolution2D(64, 128, 3, stride=1, pad=1),
            conv2_2=L.Convolution2D(128, 128, 3, stride=1, pad=1),

            conv3_1=L.Convolution2D(128, 256, 3, stride=1, pad=1),
            conv3_2=L.Convolution2D(256, 256, 3, stride=1, pad=1),
            conv3_3=L.Convolution2D(256, 256, 3, stride=1, pad=1),

            conv4_1=L.Convolution2D(256, 512, 3, stride=1, pad=1),
            conv4_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv4_3=L.Convolution2D(512, 512, 3, stride=1, pad=1),

            conv5_1=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_3=L.Convolution2D(512, 512, 3, stride=1, pad=1)
        )
        self.train = False
        self.mean = np.asarray(120, dtype=np.float32)

    def preprocess(self, image):
        return np.rollaxis(image - self.mean, 2)

    def __call__(self, x):
        y1 = F.relu(self.conv1_2(F.relu(self.conv1_1(x))))
        h = F.max_pooling_2d(y1, 2, stride=2)
        y2 = F.relu(self.conv2_2(F.relu(self.conv2_1(h))))
        h = F.max_pooling_2d(y2, 2, stride=2)
        y3 = F.relu(self.conv3_3(F.relu(self.conv3_2(F.relu(self.conv3_1(h))))))
        h = F.max_pooling_2d(y3, 2, stride=2)
        y4 = F.relu(self.conv4_3(F.relu(self.conv4_2(F.relu(self.conv4_1(h))))))
        return [y1, y2, y3, y4]
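FastStyleNet is a symmetric encoder-decoder: two stride-2 convolutions downsample the input 4x, five residual blocks transform it at 128 channels, and two stride-2 deconvolutions restore the original resolution, while the final `(tanh + 1) * 127.5` maps outputs into [0, 255]. A minimal shape-and-range sanity check, assuming the pre-2.0 Chainer API this code targets and a side length divisible by 4 (as with the 256x256 training crops):

```python
import numpy as np
from chainer import Variable
from net import FastStyleNet

model = FastStyleNet()  # untrained weights are fine for a shape check
x = Variable(np.zeros((1, 3, 256, 256), dtype=np.float32))
y = model(x, test=True)
print(y.data.shape)  # (1, 3, 256, 256): spatial size is preserved
assert 0.0 <= y.data.min() and y.data.max() <= 255.0  # (tanh+1)*127.5 range
```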
--------------------------------------------------------------------------------
/sample_images/output_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_0.jpg
--------------------------------------------------------------------------------
/sample_images/output_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_1.jpg
--------------------------------------------------------------------------------
/sample_images/output_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_2.jpg
--------------------------------------------------------------------------------
/sample_images/output_keep_colors_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_keep_colors_1.jpg
--------------------------------------------------------------------------------
/sample_images/output_keep_colors_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_keep_colors_2.jpg
--------------------------------------------------------------------------------
/sample_images/style_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/style_0.jpg
--------------------------------------------------------------------------------
/sample_images/style_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/style_1.png
--------------------------------------------------------------------------------
/sample_images/style_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/style_2.png
--------------------------------------------------------------------------------
/sample_images/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/tubingen.jpg
--------------------------------------------------------------------------------
/setup_model.sh:
--------------------------------------------------------------------------------
if [ ! -f VGG_ILSVRC_16_layers.caffemodel ]; then
  wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel
fi

python create_chainer_model.py

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
import numpy as np
import os, re
import argparse
from PIL import Image

from chainer import cuda, Variable, optimizers, serializers
from net import *

# Upscale the shorter side to `size` if needed, then center-crop to size x size.
def load_image(path, size):
    image = Image.open(path).convert('RGB')
    w, h = image.size
    if w < h:
        if w < size:
            image = image.resize((size, size*h//w))
            w, h = image.size
    else:
        if h < size:
            image = image.resize((size*w//h, size))
            w, h = image.size
    image = image.crop(((w-size)*0.5, (h-size)*0.5, (w+size)*0.5, (h+size)*0.5))
    return xp.asarray(image, dtype=np.float32).transpose(2, 0, 1)

def gram_matrix(y):
    b, ch, h, w = y.data.shape
    features = F.reshape(y, (b, ch, w*h))
    gram = F.batch_matmul(features, features, transb=True)/np.float32(ch*w*h)
    return gram
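# For activations y of shape (b, ch, h, w), gram_matrix flattens each sample to
# F of shape (ch, h*w) and returns G = F F^T / (ch*h*w): entry G[i, j] is the
# normalized inner product between channels i and j. Style similarity is later
# measured as the mean squared error between such Gram matrices.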
def total_variation(x):
    xp = cuda.get_array_module(x.data)
    b, ch, h, w = x.data.shape
    # Fixed per-channel difference kernels: wh differences vertically adjacent
    # pixels, ww horizontally adjacent ones.
    wh = Variable(xp.asarray([[[[1], [-1]], [[0], [0]], [[0], [0]]], [[[0], [0]], [[1], [-1]], [[0], [0]]], [[[0], [0]], [[0], [0]], [[1], [-1]]]], dtype=np.float32), volatile=x.volatile)
    ww = Variable(xp.asarray([[[[1, -1]], [[0, 0]], [[0, 0]]], [[[0, 0]], [[1, -1]], [[0, 0]]], [[[0, 0]], [[0, 0]], [[1, -1]]]], dtype=np.float32), volatile=x.volatile)
    return F.sum(F.convolution_2d(x, W=wh) ** 2) + F.sum(F.convolution_2d(x, W=ww) ** 2)

parser = argparse.ArgumentParser(description='Real-time style transfer')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--dataset', '-d', default='dataset', type=str,
                    help='dataset directory path (according to the paper, use MSCOCO 80k images)')
parser.add_argument('--style_image', '-s', type=str, required=True,
                    help='style image path')
parser.add_argument('--batchsize', '-b', type=int, default=1,
                    help='batch size (default value is 1)')
parser.add_argument('--initmodel', '-i', default=None, type=str,
                    help='initialize the model from given file')
parser.add_argument('--resume', '-r', default=None, type=str,
                    help='resume the optimization from snapshot')
parser.add_argument('--output', '-o', default=None, type=str,
                    help='output model file path without extension')
parser.add_argument('--lambda_tv', default=1e-6, type=float,
                    help='weight of total variation regularization; the paper suggests a value between 1e-4 and 1e-6.')
parser.add_argument('--lambda_feat', default=1.0, type=float)
parser.add_argument('--lambda_style', default=5.0, type=float)
parser.add_argument('--epoch', '-e', default=2, type=int)
parser.add_argument('--lr', '-l', default=1e-3, type=float)
parser.add_argument('--checkpoint', '-c', default=0, type=int)
parser.add_argument('--image_size', default=256, type=int)
args = parser.parse_args()

batchsize = args.batchsize
image_size = args.image_size
n_epoch = args.epoch
lambda_tv = args.lambda_tv
lambda_f = args.lambda_feat
lambda_s = args.lambda_style
style_prefix, _ = os.path.splitext(os.path.basename(args.style_image))
output = style_prefix if args.output is None else args.output
fs = os.listdir(args.dataset)
imagepaths = []
for fn in fs:
    base, ext = os.path.splitext(fn)
    if ext == '.jpg' or ext == '.png':
        imagepath = os.path.join(args.dataset, fn)
        imagepaths.append(imagepath)
n_data = len(imagepaths)
print('num training images:', n_data)
n_iter = n_data // batchsize
print(n_iter, 'iterations,', n_epoch, 'epochs')

model = FastStyleNet()
vgg = VGG()
serializers.load_npz('vgg16.model', vgg)
if args.initmodel:
    print('load model from', args.initmodel)
    serializers.load_npz(args.initmodel, model)
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
    vgg.to_gpu()
xp = np if args.gpu < 0 else cuda.cupy

O = optimizers.Adam(alpha=args.lr)
O.setup(model)
if args.resume:
    print('load optimizer state from', args.resume)
    serializers.load_npz(args.resume, O)

# Precompute the style image's Gram matrices once, outside the training loop.
style = vgg.preprocess(np.asarray(Image.open(args.style_image).convert('RGB').resize((image_size, image_size)), dtype=np.float32))
style = xp.asarray(style, dtype=xp.float32)
style_b = xp.zeros((batchsize,) + style.shape, dtype=xp.float32)
for i in range(batchsize):
    style_b[i] = style
feature_s = vgg(Variable(style_b, volatile=True))
gram_s = [gram_matrix(y) for y in feature_s]

for epoch in range(n_epoch):
    print('epoch', epoch)
    for i in range(n_iter):
        model.zerograds()
        vgg.zerograds()

        x = xp.zeros((batchsize, 3, image_size, image_size), dtype=xp.float32)
        for j in range(batchsize):
            x[j] = load_image(imagepaths[i*batchsize + j], image_size)

        xc = Variable(x.copy(), volatile=True)
        x = Variable(x)

        y = model(x)

        # Shift both the content batch and the generated images by the same
        # scalar mean (120) that VGG.preprocess subtracts.
        xc -= 120
        y -= 120

        feature = vgg(xc)
        feature_hat = vgg(y)

        L_feat = lambda_f * F.mean_squared_error(Variable(feature[2].data), feature_hat[2])  # content loss on the conv3_3 output only

        L_style = Variable(xp.zeros((), dtype=np.float32))
        for f, f_hat, g_s in zip(feature, feature_hat, gram_s):
            L_style += lambda_s * F.mean_squared_error(gram_matrix(f_hat), Variable(g_s.data))

        L_tv = lambda_tv * total_variation(y)
        L = L_feat + L_style + L_tv

        print('(epoch {}) batch {}/{}... training loss is...{}'.format(epoch, i, n_iter, L.data))

        L.backward()
        O.update()

        if args.checkpoint > 0 and i % args.checkpoint == 0:
            serializers.save_npz('models/{}_{}_{}.model'.format(output, epoch, i), model)
            serializers.save_npz('models/{}_{}_{}.state'.format(output, epoch, i), O)

    print('save "models/{}_{}.model"'.format(output, epoch))
    serializers.save_npz('models/{}_{}.model'.format(output, epoch), model)
    serializers.save_npz('models/{}_{}.state'.format(output, epoch), O)

serializers.save_npz('models/{}.model'.format(output), model)
serializers.save_npz('models/{}.state'.format(output), O)

--------------------------------------------------------------------------------
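For reference, the objective assembled in the training loop above can be written as

$$ L \;=\; \lambda_{feat}\,\mathrm{MSE}\bigl(\phi_3(x_c),\,\phi_3(\hat y)\bigr) \;+\; \lambda_{style}\sum_{l=1}^{4}\mathrm{MSE}\bigl(G(\phi_l(\hat y)),\,G(\phi_l(x_s))\bigr) \;+\; \lambda_{tv}\,\mathrm{TV}(\hat y) $$

where $\hat y$ is the stylized batch, $x_c$ the content batch, $x_s$ the style image, $\phi_1,\dots,\phi_4$ the four VGG16 activations returned by `VGG.__call__` (the content term uses only $\phi_3$, i.e. the conv3_3 output), $G$ is `gram_matrix`, and each MSE averages over elements as `F.mean_squared_error` does.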