├── .gitignore
├── LICENSE
├── README.md
├── create_chainer_model.py
├── generate.py
├── models
│   ├── .gitkeep
│   ├── composition.model
│   └── seurat.model
├── net.py
├── sample_images
│   ├── output_0.jpg
│   ├── output_1.jpg
│   ├── output_2.jpg
│   ├── output_keep_colors_1.jpg
│   ├── output_keep_colors_2.jpg
│   ├── style_0.jpg
│   ├── style_1.png
│   ├── style_2.png
│   └── tubingen.jpg
├── setup_model.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask instance folder
instance/

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 Yusuke Tomoto

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Chainer implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"
Fast artistic style transfer using a feed-forward network.

**Check out the [resize-conv](https://github.com/yusuketomoto/chainer-fast-neuralstyle/tree/resize-conv) branch, which provides better results.**

- Input image size: 1024x768
- Processing time (CPU): 17.78 sec (Core i7-5930K)
- Processing time (GPU): 0.994 sec (TitanX)
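
## Requirement
- [Chainer](https://github.com/pfnet/chainer) 1.x. The code uses APIs that were removed in Chainer v2 (the `volatile` flag of `Variable` and the `test` argument of `BatchNormalization`), so install a 1.x release.
- Pillow (imported as `PIL`) for image loading and saving
```
$ pip install "chainer<2" pillow
```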

## Prerequisite
Download the VGG16 Caffe model and convert it into a smaller Chainer model file that keeps only the convolutional layers, which account for about 10% of the entire model.
```
sh setup_model.sh
```
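The "about 10%" figure can be checked with a parameter count. The sketch below counts weights and biases for the 13 3x3 convolutions defined in `net.py`'s `VGG` class and for the three fully connected layers of the ImageNet VGG16 model. The fully connected sizes (25088 to 4096 to 4096 to 1000) are assumed from the standard VGG16 architecture; they are not defined anywhere in this repo.

```python
# Parameter count of VGG16: the 13 conv layers listed in net.py's VGG class
# plus the 3 fully connected layers of the ImageNet model (assumed standard sizes).
CONV_CHANNELS = [
    (3, 64), (64, 64),                      # conv1_x
    (64, 128), (128, 128),                  # conv2_x
    (128, 256), (256, 256), (256, 256),     # conv3_x
    (256, 512), (512, 512), (512, 512),     # conv4_x
    (512, 512), (512, 512), (512, 512),     # conv5_x
]
FC_SIZES = [(512 * 7 * 7, 4096), (4096, 4096), (4096, 1000)]  # fc6, fc7, fc8 (assumed)


def conv_params(n_in, n_out, ksize=3):
    """Weights (n_out * n_in * k * k) plus one bias per output channel."""
    return n_out * n_in * ksize * ksize + n_out


def fc_params(n_in, n_out):
    """Weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out


conv_total = sum(conv_params(i, o) for i, o in CONV_CHANNELS)
fc_total = sum(fc_params(i, o) for i, o in FC_SIZES)
ratio = conv_total / (conv_total + fc_total)
print(conv_total, fc_total, '{:.1%}'.format(ratio))  # 14714688 123642856 10.6%
```

The convolutional layers hold roughly 14.7M of about 138M parameters, so dropping the fully connected layers is what shrinks the converted file.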

## Train
You need to train one image transformation network per style target.
As in the paper, the models are trained on the [Microsoft COCO dataset](http://mscoco.org/dataset/#download). Pass a negative GPU ID to `-g` to train on the CPU.
```
python train.py -s <style_image_path> -d <dataset_dir> -g <gpu_id>
```
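For reference, `train.py` minimizes a weighted sum of three terms: a feature loss on the VGG `conv3_3` activations, a style loss comparing Gram matrices of four VGG layers with those of the style image, and a total variation (TV) regularizer. The style Gram matrices are computed once from the single `--style_image`, so a trained network is tied to that style. The NumPy sketch below mirrors `gram_matrix` and `total_variation` from `train.py` on plain arrays so the formulas can be inspected without Chainer; it is an illustration, not code the repo runs.

```python
import numpy as np


def gram_matrix(y):
    """Per-sample Gram matrix of NCHW features, normalized by ch*h*w as in train.py."""
    b, ch, h, w = y.shape
    features = y.reshape(b, ch, h * w)
    return np.matmul(features, features.transpose(0, 2, 1)) / np.float32(ch * h * w)


def total_variation(x):
    """Sum of squared differences between vertically and horizontally adjacent pixels.

    train.py computes the same quantity with fixed [1, -1] convolution kernels per
    channel; np.diff gives the differences directly (the sign vanishes when squared).
    """
    return np.sum(np.diff(x, axis=2) ** 2) + np.sum(np.diff(x, axis=3) ** 2)


feats = np.ones((1, 2, 3, 3), dtype=np.float32)
print(gram_matrix(feats))       # every entry is 9 / (2 * 9) = 0.5

ramp = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
print(total_variation(ramp))    # 12 vertical diffs of 4 and 12 horizontal diffs of 1
```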

## Generate
```
python generate.py <input_image_path> -m <model_path> -o <output_image_path> -g <gpu_id>
```
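`generate.py` also accepts two post-processing options: `--padding` (default 50) mirror-pads the input by that many pixels on each side before it enters the network and crops the same border from the output, and `--median_filter` (default 3) applies a PIL median filter of that size to the result. A value of 0 disables either step. The code does not state why it pads; a plausible reason is to keep border effects of the convolutions out of the final crop. The NumPy sketch below reproduces only the pad/crop bookkeeping; `pad_input` and `crop_output` are hypothetical names.

```python
import numpy as np


def pad_input(image, padding):
    """Mirror-pad the H and W axes of an NCHW batch ('symmetric' mode, as in generate.py)."""
    if padding <= 0:
        return image
    pad = [(0, 0), (0, 0), (padding, padding), (padding, padding)]
    return np.pad(image, pad, 'symmetric')


def crop_output(result, padding):
    """Drop the padded border from the network output."""
    if padding <= 0:
        return result
    return result[:, :, padding:-padding, padding:-padding]


x = np.random.rand(1, 3, 8, 6).astype(np.float32)
padded = pad_input(x, 2)
print(padded.shape)                                 # (1, 3, 12, 10)
print(np.array_equal(crop_output(padded, 2), x))    # True
```

In `symmetric` mode the edge row itself is repeated, so the row just outside the image is a copy of the first image row.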

This repo includes pretrained models as examples:
```
python generate.py sample_images/tubingen.jpg -m models/composition.model -o sample_images/output.jpg
```
or
```
python generate.py sample_images/tubingen.jpg -m models/seurat.model -o sample_images/output.jpg
```

#### Transfer only style but not color (**--keep_colors option**)
`python generate.py <input_image_path> -m <model_path> -o <output_image_path> -g <gpu_id> --keep_colors`
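In `generate.py`, `--keep_colors` is handled by `original_colors`, which converts the input and the stylized output to HSV with PIL and merges the hue and saturation of the input with the value (brightness) channel of the stylized image. The sketch below applies the same merge to a single pixel using the standard library's `colorsys` module instead of PIL, with RGB components as floats in [0, 1]. `keep_colors_pixel` is a hypothetical helper, and its results can differ slightly from PIL's 8-bit HSV conversion.

```python
import colorsys


def keep_colors_pixel(original_rgb, stylized_rgb):
    """Hue and saturation from the original pixel, value from the stylized pixel."""
    h, s, _ = colorsys.rgb_to_hsv(*original_rgb)
    _, _, v = colorsys.rgb_to_hsv(*stylized_rgb)
    return colorsys.hsv_to_rgb(h, s, v)


# A pure red content pixel whose stylized counterpart came out mid-gray
# keeps its red hue but takes the gray pixel's brightness.
print(keep_colors_pixel((1.0, 0.0, 0.0), (0.5, 0.5, 0.5)))   # (0.5, 0.0, 0.0)
```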
## A collection of pre-trained models
Fashizzle Dizzle created a repository of pre-trained models, [chainer-fast-neuralstyle-models](https://github.com/gafr/chainer-fast-neuralstyle-models), where you can find a variety of models.

## Difference from paper
- Kernel size 4 instead of 3 for the stride-2 convolution and deconvolution layers.
- Training with a batch size of 2 or more produces unstable results.
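One effect of the kernel-size change can be checked with the usual output-size formulas for convolution and deconvolution (the ones Chainer's `get_conv_outsize` and `get_deconv_outsize` compute with `cover_all=False`): `out = (in + 2*pad - k) // stride + 1` and `out = stride * (in - 1) + k - 2*pad`. With k=4, stride 2, pad 1 each deconvolution doubles its input exactly, while k=3 with the same stride and pad gives 2n - 1. Whether this was the motivation is not stated here. `round_trip` is a hypothetical helper that follows the layer sequence of `FastStyleNet` in `net.py` (the residual blocks keep the size).

```python
def conv_out(size, k, stride, pad):
    """Output size of a convolution (no cover_all)."""
    return (size + 2 * pad - k) // stride + 1


def deconv_out(size, k, stride, pad):
    """Output size of a deconvolution (transposed convolution)."""
    return stride * (size - 1) + k - 2 * pad


def round_trip(size):
    """Spatial size after FastStyleNet's c1..c3 and d1..d3 layers."""
    size = conv_out(size, 9, 1, 4)     # c1
    size = conv_out(size, 4, 2, 1)     # c2
    size = conv_out(size, 4, 2, 1)     # c3
    size = deconv_out(size, 4, 2, 1)   # d1
    size = deconv_out(size, 4, 2, 1)   # d2
    return deconv_out(size, 9, 1, 4)   # d3


print(deconv_out(64, 4, 2, 1), deconv_out(64, 3, 2, 1))  # 128 127
print(round_trip(256), round_trip(869))                  # 256 868
```

Sizes that are multiples of 4 survive the round trip; others come back slightly smaller. With `generate.py`'s default padding of 50 per side, a 768-pixel-high input becomes 868, which is divisible by 4. A smaller output would, for example, make the `--keep_colors` merge fail, since PIL's `Image.merge` expects bands of equal size.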

## No Backward Compatibility
##### Jul. 19, 2016
This version is not compatible with previous versions; models trained with the previous implementation cannot be used. Sorry for the inconvenience!

## License
MIT

## Reference
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](http://arxiv.org/abs/1603.08155)

The code in this repository is based on the following works. Thanks to their authors.

- [chainer-gogh](https://github.com/mattya/chainer-gogh.git): Chainer implementation of neural-style. I heavily referenced it.
- [chainer-cifar10](https://github.com/mitmul/chainer-cifar10): the residual block implementation is based on it.
--------------------------------------------------------------------------------
/create_chainer_model.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from chainer import link
from chainer.links.caffe import CaffeFunction
from chainer import serializers
from net import *

# Based on http://qiita.com/tabe2314/items/6c0c1b769e12ab1e2614
def copy_model(src, dst):
    """Recursively copy parameters from chain `src` to chain `dst`.

    A child link is copied only if `dst` has a child with the same name and
    type and all parameter names and shapes match; otherwise it is skipped.
    """
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    for child in src.children():
        if child.name not in dst.__dict__: continue
        dst_child = dst[child.name]
        if type(child) != type(dst_child): continue
        if isinstance(child, link.Chain):
            copy_model(child, dst_child)
        if isinstance(child, link.Link):
            # Check that parameter names and shapes agree before copying.
            match = True
            for a, b in zip(child.namedparams(), dst_child.namedparams()):
                if a[0] != b[0]:
                    match = False
                    break
                if a[1].data.shape != b[1].data.shape:
                    match = False
                    break
            if not match:
                print('Ignore %s because of parameter mismatch' % child.name)
                continue
            for a, b in zip(child.namedparams(), dst_child.namedparams()):
                b[1].data = a[1].data
            print('Copy %s' % child.name)

# Load the full Caffe VGG16 model and copy only the layers that VGG() defines
# (the convolutional layers); the fully connected layers have no counterpart
# in VGG() and are skipped.
print('load VGG16 caffemodel')
ref = CaffeFunction('VGG_ILSVRC_16_layers.caffemodel')
vgg = VGG()
print('copy weights')
copy_model(ref, vgg)

print('save "vgg16.model"')
serializers.save_npz('vgg16.model', vgg)
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import argparse
from PIL import Image, ImageFilter
import time

import chainer
from chainer import cuda, Variable, serializers
from net import *

parser = argparse.ArgumentParser(description='Real-time style transfer image generator')
parser.add_argument('input')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--model', '-m', default='models/style.model', type=str)
parser.add_argument('--out', '-o', default='out.jpg', type=str)
parser.add_argument('--median_filter', default=3, type=int)
parser.add_argument('--padding', default=50, type=int)
parser.add_argument('--keep_colors', action='store_true')
parser.set_defaults(keep_colors=False)
args = parser.parse_args()

# from 6o6o's fork. https://github.com/6o6o/chainer-fast-neuralstyle/blob/master/generate.py
def original_colors(original, stylized):
    # Keep hue and saturation of the original; take only the value
    # (brightness) channel of the stylized image.
    h, s, v = original.convert('HSV').split()
    hs, ss, vs = stylized.convert('HSV').split()
    return Image.merge('HSV', (h, s, vs)).convert('RGB')

model = FastStyleNet()
serializers.load_npz(args.model, model)
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
xp = np if args.gpu < 0 else cuda.cupy

start = time.time()
original = Image.open(args.input).convert('RGB')
# HWC image -> NCHW float32 batch containing one image
image = np.asarray(original, dtype=np.float32).transpose(2, 0, 1)
image = image.reshape((1,) + image.shape)
if args.padding > 0:
    # Mirror-pad the spatial axes; the border is cropped off again below.
    image = np.pad(image, [[0, 0], [0, 0], [args.padding, args.padding], [args.padding, args.padding]], 'symmetric')
image = xp.asarray(image)
x = Variable(image)

y = model(x)
result = cuda.to_cpu(y.data)

if args.padding > 0:
    result = result[:, :, args.padding:-args.padding, args.padding:-args.padding]
# NCHW float -> HWC uint8 image
result = np.uint8(result[0].transpose((1, 2, 0)))
med = Image.fromarray(result)
if args.median_filter > 0:
    med = med.filter(ImageFilter.MedianFilter(args.median_filter))
if args.keep_colors:
    med = original_colors(original, med)
print(time.time() - start, 'sec')

med.save(args.out)
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/models/.gitkeep
--------------------------------------------------------------------------------
/models/composition.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/models/composition.model
--------------------------------------------------------------------------------
/models/seurat.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/models/seurat.model
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import math

import numpy as np
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Variable

class ResidualBlock(chainer.Chain):
    def __init__(self, n_in, n_out, stride=1, ksize=3):
        w = math.sqrt(2)
        super(ResidualBlock, self).__init__(
            c1=L.Convolution2D(n_in, n_out, ksize, stride, 1, w),
            c2=L.Convolution2D(n_out, n_out, ksize, 1, 1, w),
            b1=L.BatchNormalization(n_out),
            b2=L.BatchNormalization(n_out)
        )

    def __call__(self, x, test):
        h = F.relu(self.b1(self.c1(x), test=test))
        h = self.b2(self.c2(h), test=test)
        if x.data.shape != h.data.shape:
            # Match the shortcut to h: zero-pad the missing channels and,
            # if the spatial size differs, subsample by 2.
            xp = chainer.cuda.get_array_module(x.data)
            n, c, hh, ww = x.data.shape
            pad_c = h.data.shape[1] - c
            p = xp.zeros((n, pad_c, hh, ww), dtype=xp.float32)
            p = chainer.Variable(p, volatile=test)
            x = F.concat((p, x))
            if x.data.shape[2:] != h.data.shape[2:]:
                x = F.average_pooling_2d(x, 1, 2)
        return h + x

class FastStyleNet(chainer.Chain):
    def __init__(self):
        super(FastStyleNet, self).__init__(
            c1=L.Convolution2D(3, 32, 9, stride=1, pad=4),
            c2=L.Convolution2D(32, 64, 4, stride=2, pad=1),
            c3=L.Convolution2D(64, 128, 4, stride=2, pad=1),
            r1=ResidualBlock(128, 128),
            r2=ResidualBlock(128, 128),
            r3=ResidualBlock(128, 128),
            r4=ResidualBlock(128, 128),
            r5=ResidualBlock(128, 128),
            d1=L.Deconvolution2D(128, 64, 4, stride=2, pad=1),
            d2=L.Deconvolution2D(64, 32, 4, stride=2, pad=1),
            d3=L.Deconvolution2D(32, 3, 9, stride=1, pad=4),
            b1=L.BatchNormalization(32),
            b2=L.BatchNormalization(64),
            b3=L.BatchNormalization(128),
            b4=L.BatchNormalization(64),
            b5=L.BatchNormalization(32),
        )

    def __call__(self, x, test=False):
        # Downsampling: 9x9 conv, then two stride-2 convs (roughly H/4 x W/4).
        h = self.b1(F.elu(self.c1(x)), test=test)
        h = self.b2(F.elu(self.c2(h)), test=test)
        h = self.b3(F.elu(self.c3(h)), test=test)
        # Five residual blocks at 128 channels.
        h = self.r1(h, test=test)
        h = self.r2(h, test=test)
        h = self.r3(h, test=test)
        h = self.r4(h, test=test)
        h = self.r5(h, test=test)
        # Upsampling: two stride-2 deconvs, then a 9x9 deconv back to 3 channels.
        h = self.b4(F.elu(self.d1(h)), test=test)
        h = self.b5(F.elu(self.d2(h)), test=test)
        y = self.d3(h)
        # Map the tanh output from [-1, 1] to the pixel range [0, 255].
        return (F.tanh(y)+1)*127.5

class VGG(chainer.Chain):
    def __init__(self):
        super(VGG, self).__init__(
            conv1_1=L.Convolution2D(3, 64, 3, stride=1, pad=1),
            conv1_2=L.Convolution2D(64, 64, 3, stride=1, pad=1),

            conv2_1=L.Convolution2D(64, 128, 3, stride=1, pad=1),
            conv2_2=L.Convolution2D(128, 128, 3, stride=1, pad=1),

            conv3_1=L.Convolution2D(128, 256, 3, stride=1, pad=1),
            conv3_2=L.Convolution2D(256, 256, 3, stride=1, pad=1),
            conv3_3=L.Convolution2D(256, 256, 3, stride=1, pad=1),

            conv4_1=L.Convolution2D(256, 512, 3, stride=1, pad=1),
            conv4_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv4_3=L.Convolution2D(512, 512, 3, stride=1, pad=1),

            # conv5_* are defined (and their weights copied) but not used in __call__.
            conv5_1=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_3=L.Convolution2D(512, 512, 3, stride=1, pad=1)
        )
        self.train = False
        self.mean = np.asarray(120, dtype=np.float32)

    def preprocess(self, image):
        # Subtract the mean and convert HWC to CHW.
        return np.rollaxis(image - self.mean, 2)

    def __call__(self, x):
        # Return the ReLU outputs of conv1_2, conv2_2, conv3_3 and conv4_3.
        y1 = F.relu(self.conv1_2(F.relu(self.conv1_1(x))))
        h = F.max_pooling_2d(y1, 2, stride=2)
        y2 = F.relu(self.conv2_2(F.relu(self.conv2_1(h))))
        h = F.max_pooling_2d(y2, 2, stride=2)
        y3 = F.relu(self.conv3_3(F.relu(self.conv3_2(F.relu(self.conv3_1(h))))))
        h = F.max_pooling_2d(y3, 2, stride=2)
        y4 = F.relu(self.conv4_3(F.relu(self.conv4_2(F.relu(self.conv4_1(h))))))
        return [y1, y2, y3, y4]
--------------------------------------------------------------------------------
/sample_images/output_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_0.jpg
--------------------------------------------------------------------------------
/sample_images/output_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_1.jpg
--------------------------------------------------------------------------------
/sample_images/output_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_2.jpg
--------------------------------------------------------------------------------
/sample_images/output_keep_colors_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_keep_colors_1.jpg
--------------------------------------------------------------------------------
/sample_images/output_keep_colors_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/output_keep_colors_2.jpg
--------------------------------------------------------------------------------
/sample_images/style_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/style_0.jpg
--------------------------------------------------------------------------------
/sample_images/style_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/style_1.png
--------------------------------------------------------------------------------
/sample_images/style_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/style_2.png
--------------------------------------------------------------------------------
/sample_images/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yusuketomoto/chainer-fast-neuralstyle/cf4a8466d95641b66976034c9411605754f2cf1c/sample_images/tubingen.jpg
--------------------------------------------------------------------------------
/setup_model.sh:
--------------------------------------------------------------------------------
if [ ! -f VGG_ILSVRC_16_layers.caffemodel ]; then
    wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel
fi

python create_chainer_model.py
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
import numpy as np
import os
import argparse
from PIL import Image

from chainer import cuda, Variable, optimizers, serializers
import chainer.functions as F
from net import *

def load_image(path, size):
    """Load an RGB image, upscale it if its shorter side is smaller than `size`,
    center-crop a size x size square and return it as a CHW float32 array."""
    image = Image.open(path).convert('RGB')
    w, h = image.size
    if w < h:
        if w < size:
            image = image.resize((size, size*h//w))
            w, h = image.size
    else:
        if h < size:
            image = image.resize((size*w//h, size))
            w, h = image.size
    # Integer crop box so the result is exactly size x size.
    left, top = (w - size) // 2, (h - size) // 2
    image = image.crop((left, top, left + size, top + size))
    return xp.asarray(image, dtype=np.float32).transpose(2, 0, 1)

def gram_matrix(y):
    """Gram matrix of each feature map in the batch, normalized by ch*h*w."""
    b, ch, h, w = y.data.shape
    features = F.reshape(y, (b, ch, w*h))
    gram = F.batch_matmul(features, features, transb=True)/np.float32(ch*w*h)
    return gram

def total_variation(x):
    """Sum of squared differences between vertically and horizontally adjacent
    pixels, computed per channel with fixed [1, -1] convolution kernels."""
    xp = cuda.get_array_module(x.data)
    b, ch, h, w = x.data.shape
    wh = Variable(xp.asarray([[[[1], [-1]], [[0], [0]], [[0], [0]]], [[[0], [0]], [[1], [-1]], [[0], [0]]], [[[0], [0]], [[0], [0]], [[1], [-1]]]], dtype=np.float32), volatile=x.volatile)
    ww = Variable(xp.asarray([[[[1, -1]], [[0, 0]], [[0, 0]]], [[[0, 0]], [[1, -1]], [[0, 0]]], [[[0, 0]], [[0, 0]], [[1, -1]]]], dtype=np.float32), volatile=x.volatile)
    return F.sum(F.convolution_2d(x, W=wh) ** 2) + F.sum(F.convolution_2d(x, W=ww) ** 2)

parser = argparse.ArgumentParser(description='Real-time style transfer')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--dataset', '-d', default='dataset', type=str,
                    help='dataset directory path (according to the paper, use MSCOCO 80k images)')
parser.add_argument('--style_image', '-s', type=str, required=True,
                    help='style image path')
parser.add_argument('--batchsize', '-b', type=int, default=1,
                    help='batch size (default value is 1)')
parser.add_argument('--initmodel', '-i', default=None, type=str,
                    help='initialize the model from given file')
parser.add_argument('--resume', '-r', default=None, type=str,
                    help='resume the optimization from snapshot')
parser.add_argument('--output', '-o', default=None, type=str,
                    help='output model file path without extension')
parser.add_argument('--lambda_tv', default=1e-6, type=float,
                    help='weight of total variation regularization; the paper uses values between 1e-6 and 1e-4')
parser.add_argument('--lambda_feat', default=1.0, type=float)
parser.add_argument('--lambda_style', default=5.0, type=float)
parser.add_argument('--epoch', '-e', default=2, type=int)
parser.add_argument('--lr', '-l', default=1e-3, type=float)
parser.add_argument('--checkpoint', '-c', default=0, type=int)
parser.add_argument('--image_size', default=256, type=int)
args = parser.parse_args()

batchsize = args.batchsize

image_size = args.image_size
n_epoch = args.epoch
lambda_tv = args.lambda_tv
lambda_f = args.lambda_feat
lambda_s = args.lambda_style
style_prefix, _ = os.path.splitext(os.path.basename(args.style_image))
output = style_prefix if args.output is None else args.output
fs = os.listdir(args.dataset)
imagepaths = []
for fn in fs:
    base, ext = os.path.splitext(fn)
    if ext == '.jpg' or ext == '.png':
        imagepath = os.path.join(args.dataset, fn)
        imagepaths.append(imagepath)
n_data = len(imagepaths)
print('num training images:', n_data)
n_iter = n_data // batchsize
print(n_iter, 'iterations,', n_epoch, 'epochs')

model = FastStyleNet()
vgg = VGG()
serializers.load_npz('vgg16.model', vgg)
if args.initmodel:
    print('load model from', args.initmodel)
    serializers.load_npz(args.initmodel, model)
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
    vgg.to_gpu()
xp = np if args.gpu < 0 else cuda.cupy

O = optimizers.Adam(alpha=args.lr)
O.setup(model)
if args.resume:
    print('load optimizer state from', args.resume)
    serializers.load_npz(args.resume, O)

# Style targets: Gram matrices of the style image's VGG features, computed once.
style = vgg.preprocess(np.asarray(Image.open(args.style_image).convert('RGB').resize((image_size,image_size)), dtype=np.float32))
style = xp.asarray(style, dtype=xp.float32)
style_b = xp.zeros((batchsize,) + style.shape, dtype=xp.float32)
for i in range(batchsize):
    style_b[i] = style
feature_s = vgg(Variable(style_b, volatile=True))
gram_s = [gram_matrix(y) for y in feature_s]

for epoch in range(n_epoch):
    print('epoch', epoch)
    for i in range(n_iter):
        model.zerograds()
        vgg.zerograds()

        # Build a batch of content images.
        x = xp.zeros((batchsize, 3, image_size, image_size), dtype=xp.float32)
        for j in range(batchsize):
            x[j] = load_image(imagepaths[i*batchsize + j], image_size)

        xc = Variable(x.copy(), volatile=True)
        x = Variable(x)

        y = model(x)

        # Subtract the same mean as VGG.mean before feeding VGG.
        xc -= 120
        y -= 120

        feature = vgg(xc)
        feature_hat = vgg(y)

        L_feat = lambda_f * F.mean_squared_error(Variable(feature[2].data), feature_hat[2]) # compute for only the output of layer conv3_3

        L_style = Variable(xp.zeros((), dtype=np.float32))
        for f, f_hat, g_s in zip(feature, feature_hat, gram_s):
            L_style += lambda_s * F.mean_squared_error(gram_matrix(f_hat), Variable(g_s.data))

        L_tv = lambda_tv * total_variation(y)
        loss = L_feat + L_style + L_tv

        print('(epoch {}) batch {}/{}... training loss is...{}'.format(epoch, i, n_iter, loss.data))

        loss.backward()
        O.update()

        if args.checkpoint > 0 and i % args.checkpoint == 0:
            serializers.save_npz('models/{}_{}_{}.model'.format(output, epoch, i), model)
            serializers.save_npz('models/{}_{}_{}.state'.format(output, epoch, i), O)

    print('save "models/{}_{}.model"'.format(output, epoch))
    serializers.save_npz('models/{}_{}.model'.format(output, epoch), model)
    serializers.save_npz('models/{}_{}.state'.format(output, epoch), O)

serializers.save_npz('models/{}.model'.format(output), model)
serializers.save_npz('models/{}.state'.format(output), O)
--------------------------------------------------------------------------------