├── pbs
│   └── Robert_D.pb
├── images
│   ├── content
│   │   └── tubingen.jpg
│   ├── styles
│   │   └── Robert_D.jpg
│   └── output
│       └── RobertD_output.jpg
├── .gitignore
├── freeze_graph.py
├── custom_vgg16.py
├── generate.py
├── README.md
├── net.py
└── train.py
/pbs/Robert_D.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/pbs/Robert_D.pb
--------------------------------------------------------------------------------
/images/content/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/images/content/tubingen.jpg
--------------------------------------------------------------------------------
/images/styles/Robert_D.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/images/styles/Robert_D.jpg
--------------------------------------------------------------------------------
/images/output/RobertD_output.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/images/output/RobertD_output.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask instance folder
57 | instance/
58 |
59 | # Scrapy stuff:
60 | .scrapy
61 |
62 | # Sphinx documentation
63 | docs/_build/
64 |
65 | # PyBuilder
66 | target/
67 |
68 | # IPython Notebook
69 | .ipynb_checkpoints
70 |
71 | # pyenv
72 | .python-version
73 |
74 | # celery beat schedule file
75 | celerybeat-schedule
76 |
77 | # dotenv
78 | .env
79 |
80 | # virtualenv
81 | venv/
82 | ENV/
83 |
84 | # Spyder project settings
85 | .spyderproject
86 |
87 | # Rope project settings
88 | .ropeproject
--------------------------------------------------------------------------------
/freeze_graph.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import tensorflow as tf
3 | from tensorflow.python.framework import graph_util
4 |
5 | dir = os.path.dirname(os.path.realpath(__file__))
6 |
7 | def freeze_graph(model_folder, output_graph, output_node_name):
 8 |     # Retrieve the full path of the latest checkpoint
9 | checkpoint = tf.train.get_checkpoint_state(model_folder)
10 | input_checkpoint = checkpoint.model_checkpoint_path
11 |
12 |
13 |     # The folder containing the checkpoint files (the frozen graph's path itself is given by output_graph)
14 | absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
15 | # output_graph = absolute_model_folder + "/{}.pb".format(style_name)
16 |
17 |     # Before exporting the graph, we need to specify the output node(s).
18 |     # This is how TF decides which part of the graph it has to keep and which part it can discard.
19 | # NOTE: this variable is plural, because you can have multiple output nodes
20 | output_node_names = output_node_name
21 |
22 |     # Clear device assignments so that TensorFlow can decide on which device to load operations
23 | clear_devices = True
24 |
25 | # We import the meta graph and retrieve a Saver
26 | saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
27 |
28 | # We retrieve the protobuf graph definition
29 | graph = tf.get_default_graph()
30 | input_graph_def = graph.as_graph_def()
31 |
32 | # We start a session and restore the graph weights
33 | with tf.Session() as sess:
34 | saver.restore(sess, input_checkpoint)
35 |
36 | # We use a built-in TF helper to export variables to constants
37 | output_graph_def = graph_util.convert_variables_to_constants(
38 | sess, # The session is used to retrieve the weights
39 | input_graph_def, # The graph_def is used to retrieve the nodes
40 |             output_node_names.split(",") # The output node names are used to select the useful nodes
41 | )
42 |
43 | # Finally we serialize and dump the output graph to the filesystem
44 | with tf.gfile.GFile(output_graph, "wb") as f:
45 | f.write(output_graph_def.SerializeToString())
46 | print("{} ops in the final graph.".format(len(output_graph_def.node)))
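47 | 
48 | # A minimal CLI sketch (an assumption; train.py imports and calls freeze_graph()
49 | # directly, so this block only matters for standalone use):
50 | if __name__ == '__main__':
51 |     parser = argparse.ArgumentParser(description='Freeze a checkpoint into a .pb graph')
52 |     parser.add_argument('--model_folder', type=str, help='checkpoint directory, e.g. ./ckpts/<style_name>/')
53 |     parser.add_argument('--output_graph', type=str, help='path of the .pb file to write, e.g. ./pbs/<style_name>.pb')
54 |     parser.add_argument('--output_node_name', type=str, default='output', help='name of the output node')
55 |     args = parser.parse_args()
56 |     freeze_graph(args.model_folder, args.output_graph, args.output_node_name)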
--------------------------------------------------------------------------------
/custom_vgg16.py:
--------------------------------------------------------------------------------
1 | import os, sys, inspect
2 | import tensorflow as tf
3 |
4 | import numpy as np
5 | import time
6 | from tensorflow_vgg import vgg16
7 |
8 | VGG_MEAN = [103.939, 116.779, 123.68]
9 |
10 | def loadWeightsData(vgg16_npy_path=None):
11 | if vgg16_npy_path is None:
12 |         path = inspect.getfile(vgg16.Vgg16)
13 | path = os.path.abspath(os.path.join(path, os.pardir))
14 | path = os.path.join(path, "vgg16.npy")
15 | vgg16_npy_path = path
16 | print (vgg16_npy_path)
17 | return np.load(vgg16_npy_path, encoding='latin1').item()
18 |
19 | class custom_Vgg16(vgg16.Vgg16):
20 |     # Input should be an RGB image [batch, height, width, 3]
21 |     # with values already scaled to [0, 255] (no rescaling is done here).
22 |
23 | def __init__(self, rgb, data_dict, train=False):
24 |         # The weights data is shared and used by various
25 |         # member functions.
26 | self.data_dict = data_dict
27 |
28 | # start_time = time.time()
29 |
30 | # rgb_scaled = rgb * 255.0
31 | rgb_scaled = rgb
32 | # Convert RGB to BGR
33 | red, green, blue = tf.split(rgb_scaled, 3, 3)
34 |
35 | bgr = tf.concat([blue - VGG_MEAN[0],
36 | green - VGG_MEAN[1],
37 | red - VGG_MEAN[2]],
38 | 3)
39 |
40 | self.conv1_1 = self.conv_layer(bgr, "conv1_1")
41 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
42 | self.pool1 = self.max_pool(self.conv1_2, 'pool1')
43 |
44 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
45 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
46 | self.pool2 = self.max_pool(self.conv2_2, 'pool2')
47 |
48 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
49 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
50 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
51 | self.pool3 = self.max_pool(self.conv3_3, 'pool3')
52 |
53 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
54 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
55 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
56 | self.pool4 = self.max_pool(self.conv4_3, 'pool4')
57 |
58 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
59 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
60 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
61 | self.pool5 = self.max_pool(self.conv5_3, 'pool5')
62 |
63 | # self.data_dict = None
64 | # print ("build model finished: %ds" % (time.time() - start_time))
65 |
66 | def debug(self):
67 | pass
68 |
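69 | # Usage sketch (mirrors train.py; assumes vgg16.npy sits in the repo root):
70 | #   data_dict = loadWeightsData('./vgg16.npy')
71 | #   vgg = custom_Vgg16(images, data_dict=data_dict)  # images: [batch, h, w, 3] in [0, 255]
72 | #   features = [vgg.conv1_2, vgg.conv2_2, vgg.conv3_3, vgg.conv4_3, vgg.conv5_3]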
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import tensorflow as tf
4 | import os
5 | from PIL import Image
6 |
7 | parser = argparse.ArgumentParser(description='Real-time style transfer image generator')
 8 | parser.add_argument('--input', '-i', type=str, required=True, help='content image path')
9 | parser.add_argument('--gpu', '-g', default=-1, type=int,
10 | help='GPU ID (negative value indicates CPU)')
11 | parser.add_argument('--style', '-s', default=None, type=str, help='style model name')
12 | parser.add_argument('--ckpt', '-c', default=-1, type=int, help='checkpoint to be loaded')
13 | parser.add_argument('--out', '-o', default='stylized_image.jpg', type=str, help='stylized image\'s name')
14 | parser.add_argument('--pb', '-pb', default=0, type=int, help='load with pb (1) instead of checkpoint files (0)')
15 |
16 | args = parser.parse_args()
17 |
18 | if not os.path.exists('./images/output/'):
19 | os.makedirs('./images/output/')
20 |
21 | outfile_path = './images/output/' + args.out
22 | content_image_path = args.input
23 | style_name = args.style
24 | ckpt = args.ckpt
25 | load_with_pb = args.pb
26 | gpu = args.gpu
27 |
28 | original_image = Image.open(content_image_path).convert('RGB')
29 |
30 | img = np.asarray(original_image.resize((224, 224)), dtype=np.float32)
31 | shaped_input = img.reshape((1,) + img.shape)
32 |
33 | if gpu > -1:
34 | device = '/gpu:{}'.format(gpu)
35 | else:
36 | device = '/cpu:0'
37 |
38 |
39 | with tf.device(device):
40 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
41 | if load_with_pb:
42 | from tensorflow.core.framework import graph_pb2
43 | graph_def = graph_pb2.GraphDef()
44 | with open('./pbs/{}.pb'.format(style_name), "rb") as f:
45 | graph_def.ParseFromString(f.read())
46 | input_image, output = tf.import_graph_def(graph_def, return_elements=['input:0', 'output:0'])
47 |
48 | else:
49 | if ckpt < 0:
50 | checkpoint = tf.train.get_checkpoint_state('./ckpts/{}/'.format(style_name))
51 | input_checkpoint = checkpoint.model_checkpoint_path
52 | else:
53 | input_checkpoint = './ckpts/{}/{}-{}'.format(style_name, style_name, ckpt)
54 | saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
55 | saver.restore(sess, input_checkpoint)
56 | graph = tf.get_default_graph()
57 |
58 | input_image = graph.get_tensor_by_name('input:0')
59 | output = graph.get_tensor_by_name('output:0')
60 |
61 | out = sess.run(output, feed_dict={input_image: shaped_input})
62 |
63 | out = out.reshape((out.shape[1:]))
64 | im = Image.fromarray(np.uint8(out))
65 |
66 | im = im.resize(original_image.size, resample=Image.LANCZOS)
67 | im.save(outfile_path)
68 |
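69 | # Example with the files shipped in this repo (an assumption: a model named
70 | # Robert_D has been trained, i.e. ./pbs/Robert_D.pb exists):
71 | #   python generate.py -i ./images/content/tubingen.jpg -s Robert_D -o RobertD_output.jpg -pb 1
72 | # This reads ./pbs/Robert_D.pb and writes ./images/output/RobertD_output.jpg.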
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensorflow implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"
 2 | Fast artistic style transfer using a feed-forward network.
3 |
 4 | 
 5 | (Example images in this repo: content `images/content/tubingen.jpg`, style `images/styles/Robert_D.jpg`, stylized result `images/output/RobertD_output.jpg`.)
 6 | 
9 | - input image size: 1024x768
10 | - process time (CPU): 2.246 sec (Core i5-5257U)
11 | - process time (GPU): 1.728 sec (GRID K520)
12 |
13 |
14 | ## Requirement
15 | - [TensorFlow 1.0](https://github.com/tensorflow/tensorflow)
16 | - [Pillow](https://github.com/python-pillow/Pillow)
17 | - [NumPy](https://github.com/numpy/numpy)
18 | - [SciPy](https://github.com/scipy/scipy)
19 |
20 |
21 | ## Prerequisite
22 | In this implementation, the VGG model part is based on [Tensorflow VGG16 and VGG19](https://github.com/machrisaa/tensorflow-vgg). Please add it as a submodule and follow the instructions there to obtain the VGG16 model. Make sure the name of the module in your project matches the one imported in line 6 of `custom_vgg16.py`; a minimal setup sketch follows.
23 |
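24 | A minimal setup sketch (the submodule path `tensorflow_vgg` is an assumption, chosen to match the import in `custom_vgg16.py`):
25 | ```
26 | git submodule add https://github.com/machrisaa/tensorflow-vgg tensorflow_vgg
27 | git submodule update --init
28 | ```
29 | Then download `vgg16.npy` as described in that repository and place it in the project root, since `train.py` loads `./vgg16.npy`.
30 | 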
24 | ## Train a style model
25 | One image transformation network has to be trained per style target.
26 | According to the paper, the models are trained on the [Microsoft COCO dataset](http://mscoco.org/dataset/#download).
27 | Training also saves the transformation model, including the trained weights, for later use (e.g. in C++) to the ```pbs``` directory, while the checkpoint files are saved in ```ckpts/<style_name>/```.
28 |
29 |
30 | - A ```.pb``` model is saved by default. To turn this off, add the argument ```-pb 0```.
31 | - To train a model from scratch:
32 | ```
33 | python train.py -s <style_image_path> -d <dataset_dir> -g 0
34 | ```
35 | - To resume from a pre-trained model, specify the checkpoint to load. A negative checkpoint value means the latest checkpoint is used:
36 | ```
37 | python train.py -s <style_image_path> -d <dataset_dir> -c <checkpoint>
38 | ```
40 |
41 | ## Generate a stylized image
42 |
43 | ### Load with .pb file
44 | ```
45 | python generate.py -i <content_image_path> -o <output_image_name> -s <style_name> -pb 1
46 | ```
47 |
48 | ### Load with checkpoint files
49 | - By default, the latest checkpoint file is used (the checkpoint argument defaults to a negative value).
50 | ```
51 | python generate.py -i <content_image_path> -o <output_image_name> -s <style_name> -c <checkpoint>
52 | ```
53 |
54 | ## Differences between this implementation and the paper
55 | - Convolution kernel size 4 instead of 3 in the downsampling/upsampling layers.
56 | - Training with batch size n >= 2 causes unstable results.
57 |
58 | ## License
59 | MIT
60 |
61 | ## Reference
62 | - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
63 |
64 | The code in this repository is based on the following nice work; thanks to the author.
65 |
66 | - [Chainer implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"](https://github.com/yusuketomoto/chainer-fast-neuralstyle)
67 |
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def weight_variable(shape, name=None):
 5 |     # initialize weight variables from a truncated normal (stddev 0.001).
6 | initial = tf.truncated_normal(shape, stddev=0.001)
7 | return tf.Variable(initial, name=name)
8 |
9 | def conv2d(x, W, strides=[1, 1, 1, 1], p='SAME', name=None):
10 |     # 2-D convolution with configurable strides and padding.
11 | assert isinstance(x, tf.Tensor)
12 | return tf.nn.conv2d(x, W, strides=strides, padding=p, name=name)
13 |
14 | def batch_norm(x):
15 | assert isinstance(x, tf.Tensor)
16 |     # compute mean and variance over the H, W, C axes of each sample (per-image statistics).
17 | mean, var = tf.nn.moments(x, axes=[1, 2, 3])
18 | return tf.nn.batch_normalization(x, mean, var, 0, 1, 1e-5)
19 |
20 | def relu(x):
21 | assert isinstance(x, tf.Tensor)
22 | return tf.nn.relu(x)
23 |
24 | def deconv2d(x, W, strides=[1, 1, 1, 1], p='SAME', name=None):
25 | assert isinstance(x, tf.Tensor)
26 | _, _, c, _ = W.get_shape().as_list()
27 | b, h, w, _ = x.get_shape().as_list()
28 | return tf.nn.conv2d_transpose(x, W, [b, strides[1] * h, strides[1] * w, c], strides=strides, padding=p, name=name)
29 |
30 | def max_pool_2x2(x):
31 | assert isinstance(x, tf.Tensor)
32 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
33 |
34 |
35 | class ResidualBlock():
36 | def __init__(self, idx, ksize=3, train=False, data_dict=None):
37 | self.W1 = weight_variable([ksize, ksize, 128, 128], name='R'+ str(idx) + '_conv1_w')
38 | self.W2 = weight_variable([ksize, ksize, 128, 128], name='R'+ str(idx) + '_conv2_w')
39 | def __call__(self, x, idx, strides=[1, 1, 1, 1]):
40 | h = relu(batch_norm(conv2d(x, self.W1, strides, name='R' + str(idx) + '_conv1')))
41 | h = batch_norm(conv2d(h, self.W2, name='R' + str(idx) + '_conv2'))
42 | return x + h
43 |
44 |
45 | class FastStyleNet():
46 | def __init__(self, train=True, data_dict=None):
47 | self.c1 = weight_variable([9, 9, 3, 32], name='t_conv1_w')
48 | self.c2 = weight_variable([4, 4, 32, 64], name='t_conv2_w')
49 | self.c3 = weight_variable([4, 4, 64, 128], name='t_conv3_w')
50 | self.r1 = ResidualBlock(1, train=train)
51 | self.r2 = ResidualBlock(2, train=train)
52 | self.r3 = ResidualBlock(3, train=train)
53 | self.r4 = ResidualBlock(4, train=train)
54 | self.r5 = ResidualBlock(5, train=train)
55 | self.d1 = weight_variable([4, 4, 64, 128], name='t_dconv1_w')
56 | self.d2 = weight_variable([4, 4, 32, 64], name='t_dconv2_w')
57 | self.d3 = weight_variable([9, 9, 3, 32], name='t_dconv3_w')
58 | def __call__(self, h):
59 | h = batch_norm(relu(conv2d(h, self.c1, name='t_conv1')))
60 | h = batch_norm(relu(conv2d(h, self.c2, strides=[1, 2, 2, 1], name='t_conv2')))
61 | h = batch_norm(relu(conv2d(h, self.c3, strides=[1, 2, 2, 1], name='t_conv3')))
62 |
63 | h = self.r1(h, 1)
64 | h = self.r2(h, 2)
65 | h = self.r3(h, 3)
66 | h = self.r4(h, 4)
67 | h = self.r5(h, 5)
68 |
69 | h = batch_norm(relu(deconv2d(h, self.d1, strides=[1, 2, 2, 1], name='t_deconv1')))
70 | h = batch_norm(relu(deconv2d(h, self.d2, strides=[1, 2, 2, 1], name='t_deconv2')))
71 | y = deconv2d(h, self.d3, name='t_deconv3')
72 | return tf.multiply((tf.tanh(y) + 1), tf.constant(127.5, tf.float32, shape=y.get_shape()), name='output')
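73 | 
74 | # Shape walk-through for a [b, 224, 224, 3] input (a sketch following the kernel
75 | # shapes and strides defined above):
76 | #   conv1 (9x9, stride 1)   -> [b, 224, 224, 32]
77 | #   conv2 (4x4, stride 2)   -> [b, 112, 112, 64]
78 | #   conv3 (4x4, stride 2)   -> [b, 56, 56, 128]
79 | #   5 residual blocks       -> [b, 56, 56, 128]
80 | #   deconv1 (4x4, stride 2) -> [b, 112, 112, 64]
81 | #   deconv2 (4x4, stride 2) -> [b, 224, 224, 32]
82 | #   deconv3 (9x9, stride 1) -> [b, 224, 224, 3], then (tanh(y) + 1) * 127.5 maps to [0, 255]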
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, sys
3 | import argparse
4 | from PIL import Image
5 | from freeze_graph import freeze_graph
6 | import tensorflow as tf
7 | import time
8 |
9 | from net import *
10 | sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "./"))
11 | from custom_vgg16 import *
12 |
13 |
14 | # gram matrix per layer
15 | def gram_matrix(x):
16 | assert isinstance(x, tf.Tensor)
17 | b, h, w, ch = x.get_shape().as_list()
18 | features = tf.reshape(x, [b, h*w, ch])
19 | # gram = tf.batch_matmul(features, features, adj_x=True)/tf.constant(ch*w*h, tf.float32)
20 | gram = tf.matmul(features, features, adjoint_a=True)/tf.constant(ch*w*h, tf.float32)
21 | return gram
22 |
23 | # total variation denoising
24 | def total_variation_regularization(x, beta=1):
25 | assert isinstance(x, tf.Tensor)
26 | wh = tf.constant([[[[ 1], [ 1], [ 1]]], [[[-1], [-1], [-1]]]], tf.float32)
27 | ww = tf.constant([[[[ 1], [ 1], [ 1]], [[-1], [-1], [-1]]]], tf.float32)
28 | tvh = lambda x: conv2d(x, wh, p='SAME')
29 | tvw = lambda x: conv2d(x, ww, p='SAME')
30 | dh = tvh(x)
31 | dw = tvw(x)
32 | tv = (tf.add(tf.reduce_sum(dh**2, [1, 2, 3]), tf.reduce_sum(dw**2, [1, 2, 3]))) ** (beta / 2.)
33 | return tv
34 |
35 | parser = argparse.ArgumentParser(description='Real-time style transfer')
36 | parser.add_argument('--gpu', '-g', default=-1, type=int,
37 | help='GPU ID (negative value indicates CPU)')
38 | parser.add_argument('--dataset', '-d', default='dataset', type=str,
39 | help='dataset directory path (according to the paper, use MSCOCO 80k images)')
40 | parser.add_argument('--style_image', '-s', type=str, required=True,
41 | help='style image path')
42 | parser.add_argument('--batchsize', '-b', type=int, default=1,
43 | help='batch size (default value is 1)')
44 | parser.add_argument('--ckpt', '-c', default=None, type=int,
45 | help='the global step of checkpoint file desired to restore.')
46 | parser.add_argument('--lambda_tv', '-l_tv', default=10e-4, type=float,
47 |                     help='weight of total variation regularization; according to the paper, set it between 10e-4 and 10e-6.')
48 | parser.add_argument('--lambda_feat', '-l_feat', default=1e0, type=float)
49 | parser.add_argument('--lambda_style', '-l_style', default=1e1, type=float)
50 | parser.add_argument('--epoch', '-e', default=2, type=int)
51 | parser.add_argument('--lr', '-l', default=1e-3, type=float)
52 | parser.add_argument('--pb', '-pb', default=1, type=int, help='also save the model in pb format (1) or not (0).')
53 | args = parser.parse_args()
54 |
55 | data_dict = loadWeightsData('./vgg16.npy')
56 |
57 | batchsize = args.batchsize
58 | gpu = args.gpu
59 | dataset = args.dataset
60 | epochs = args.epoch
61 | learning_rate = args.lr
62 | ckpt = args.ckpt
63 | lambda_tv = args.lambda_tv
64 | lambda_f = args.lambda_feat
65 | lambda_s = args.lambda_style
66 | style_image = args.style_image
67 | save_pb = args.pb
69 |
70 | style_name, _ = os.path.splitext(style_image.split(os.sep)[-1])
71 |
72 | fpath = os.listdir(args.dataset)
73 | imagepaths = []
74 | for fn in fpath:
75 | base, ext = os.path.splitext(fn)
76 | if ext == '.jpg' or ext == '.png':
77 | imagepath = os.path.join(dataset, fn)
78 | imagepaths.append(imagepath)
79 | data_len = len(imagepaths)
80 | iterations = int(data_len / batchsize)
81 | print ('Number of training images: {}'.format(data_len))
82 | print ('{} epochs, {} iterations per epoch'.format(epochs, iterations))
83 |
84 | style_np = np.asarray(Image.open(style_image).convert('RGB').resize((224, 224)), dtype=np.float32)
85 | styles_np = [style_np for x in range(batchsize)]
86 |
87 | if gpu > -1:
88 | device = '/gpu:{}'.format(gpu)
89 | else:
90 | device = '/cpu:0'
91 |
92 | with tf.device(device):
93 |
94 | inputs = tf.placeholder(tf.float32, shape=[batchsize, 224, 224, 3], name='input')
95 | net = FastStyleNet()
96 | saver = tf.train.Saver(restore_sequentially=True)
97 | saver_def = saver.as_saver_def()
98 |
99 |
100 | target = tf.placeholder(tf.float32, shape=[batchsize, 224, 224, 3])
101 | outputs = net(inputs)
102 |
103 | # style target feature
104 |     # compute gram matrix of style target
105 | vgg_s = custom_Vgg16(target, data_dict=data_dict)
106 | feature_ = [vgg_s.conv1_2, vgg_s.conv2_2, vgg_s.conv3_3, vgg_s.conv4_3, vgg_s.conv5_3]
107 | gram_ = [gram_matrix(l) for l in feature_]
108 |
109 | # content target feature
110 | vgg_c = custom_Vgg16(inputs, data_dict=data_dict)
111 | feature_ = [vgg_c.conv1_2, vgg_c.conv2_2, vgg_c.conv3_3, vgg_c.conv4_3, vgg_c.conv5_3]
112 |
113 | # feature after transformation
114 | vgg = custom_Vgg16(outputs, data_dict=data_dict)
115 | feature = [vgg.conv1_2, vgg.conv2_2, vgg.conv3_3, vgg.conv4_3, vgg.conv5_3]
116 |
117 | # compute feature loss
118 | loss_f = tf.zeros(batchsize, tf.float32)
119 | for f, f_ in zip(feature, feature_):
120 | loss_f += lambda_f * tf.reduce_mean(tf.subtract(f, f_) ** 2, [1, 2, 3])
121 |
122 | # compute style loss
123 | gram = [gram_matrix(l) for l in feature]
124 | loss_s = tf.zeros(batchsize, tf.float32)
125 | for g, g_ in zip(gram, gram_):
126 | loss_s += lambda_s * tf.reduce_mean(tf.subtract(g, g_) ** 2, [1, 2])
127 |
128 | # total variation denoising
129 | loss_tv = lambda_tv * total_variation_regularization(outputs)
130 |
131 | # total loss
132 | loss = loss_s + loss_f + loss_tv
133 |
134 | # optimizer
135 | train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
136 |
137 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
138 |
139 | ckpt_directory = './ckpts/{}/'.format(style_name)
140 | if not os.path.exists(ckpt_directory):
141 | os.makedirs(ckpt_directory)
142 |
143 | # training
144 | tf.global_variables_initializer().run()
145 |
146 | if ckpt:
147 | if ckpt < 0:
148 | checkpoint = tf.train.get_checkpoint_state(ckpt_directory)
149 | input_checkpoint = checkpoint.model_checkpoint_path
150 | else:
151 | input_checkpoint = ckpt_directory + style_name + '-{}'.format(ckpt)
152 | saver.restore(sess, input_checkpoint)
153 | print ('Checkpoint {} restored.'.format(ckpt))
154 |
155 | for epoch in range(1, epochs + 1):
156 | imgs = np.zeros((batchsize, 224, 224, 3), dtype=np.float32)
157 | for i in range(iterations):
158 | for j in range(batchsize):
159 | p = imagepaths[i * batchsize + j]
160 | imgs[j] = np.asarray(Image.open(p).convert('RGB').resize((224, 224)), np.float32)
161 | feed_dict = {inputs: imgs, target: styles_np}
162 |             loss_, _ = sess.run([loss, train_step], feed_dict=feed_dict)
163 | print('[epoch {}/{}] batch {}/{}... loss: {}'.format(epoch, epochs, i + 1, iterations, loss_[0]))
164 | saver.save(sess, ckpt_directory + style_name, global_step=epoch)
165 |
166 | if save_pb:
167 | if not os.path.exists('./pbs'):
168 | os.makedirs('./pbs')
169 | freeze_graph(ckpt_directory, './pbs/{}.pb'.format(style_name), 'output')
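170 | 
171 | # Example invocation (paths are assumptions based on this repo's layout):
172 | #   python train.py -s ./images/styles/Robert_D.jpg -d ./dataset -g 0
173 | # This saves checkpoints to ./ckpts/Robert_D/ and, by default, a frozen graph to
174 | # ./pbs/Robert_D.pb, which generate.py can load with `-s Robert_D -pb 1`.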
--------------------------------------------------------------------------------