├── pbs
│   └── Robert_D.pb
├── images
│   ├── content
│   │   └── tubingen.jpg
│   ├── styles
│   │   └── Robert_D.jpg
│   └── output
│       └── RobertD_output.jpg
├── .gitignore
├── freeze_graph.py
├── custom_vgg16.py
├── generate.py
├── README.md
├── net.py
└── train.py

/pbs/Robert_D.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/pbs/Robert_D.pb
--------------------------------------------------------------------------------

/images/content/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/images/content/tubingen.jpg
--------------------------------------------------------------------------------

/images/styles/Robert_D.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/images/styles/Robert_D.jpg
--------------------------------------------------------------------------------

/images/output/RobertD_output.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antlerros/tensorflow-fast-neuralstyle/HEAD/images/output/RobertD_output.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask instance folder
instance/

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------

/freeze_graph.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(model_folder, output_graph, output_node_name):
    # Retrieve the full path of the latest checkpoint.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Directory that holds the checkpoint files of our frozen graph.
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    # output_graph = absolute_model_folder + "/{}.pb".format(style_name)

    # Before exporting the graph, we need to specify the output node.
    # This is how TF decides which part of the graph it has to keep and
    # which part it can discard.
    # NOTE: the variable name is plural, because there can be multiple
    # output nodes, passed as a comma-separated string.
    output_node_names = output_node_name

    # Clear device assignments so TensorFlow is free to decide on which
    # device it will load each operation.
    clear_devices = True

    # Import the meta graph and retrieve a Saver.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # Retrieve the protobuf graph definition.
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # Use a built-in TF helper to convert the variables to constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,                         # the session is used to retrieve the weights
            input_graph_def,              # the graph_def is used to retrieve the nodes
            output_node_names.split(",")  # the output node names select the useful nodes
        )

        # Finally, serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("{} ops in the final graph.".format(len(output_graph_def.node)))
--------------------------------------------------------------------------------
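`train.py` calls this helper after training to export the frozen model. A minimal standalone sketch of the same call (not part of the original file), assuming a style named `Robert_D` has already produced checkpoints under `./ckpts/Robert_D/`:

```
# Sketch: freeze the latest Robert_D checkpoint into a standalone .pb file.
# 'output' is the name FastStyleNet (net.py) gives to its final op.
from freeze_graph import freeze_graph

freeze_graph('./ckpts/Robert_D/', './pbs/Robert_D.pb', 'output')
```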
/custom_vgg16.py:
--------------------------------------------------------------------------------
import os, sys, inspect
import tensorflow as tf

import numpy as np
import time
from tensorflow_vgg import vgg16

# Mean pixel values (BGR order) of the dataset VGG16 was trained on.
VGG_MEAN = [103.939, 116.779, 123.68]

def loadWeightsData(vgg16_npy_path=None):
    if vgg16_npy_path is None:
        # default to the vgg16.npy sitting next to the vgg16 module
        path = inspect.getfile(vgg16.Vgg16)
        path = os.path.abspath(os.path.join(path, os.pardir))
        path = os.path.join(path, "vgg16.npy")
        vgg16_npy_path = path
        print(vgg16_npy_path)
    return np.load(vgg16_npy_path, encoding='latin1').item()

class custom_Vgg16(vgg16.Vgg16):
    # Input should be an RGB image [batch, height, width, 3] with values
    # in [0, 255] (the [0, 1] scaling of the base class is disabled below).

    def __init__(self, rgb, data_dict, train=False):
        # The weights data is shared and used by the inherited
        # member functions (conv_layer, max_pool).
        self.data_dict = data_dict

        # start_time = time.time()

        # rgb_scaled = rgb * 255.0
        rgb_scaled = rgb

        # Convert RGB to BGR and subtract the per-channel VGG means.
        red, green, blue = tf.split(rgb_scaled, 3, 3)
        bgr = tf.concat([blue - VGG_MEAN[0],
                         green - VGG_MEAN[1],
                         red - VGG_MEAN[2]],
                        3)

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self.max_pool(self.conv3_3, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
        self.pool4 = self.max_pool(self.conv4_3, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.pool5 = self.max_pool(self.conv5_3, 'pool5')

        # self.data_dict = None
        # print("build model finished: %ds" % (time.time() - start_time))

    def debug(self):
        pass
--------------------------------------------------------------------------------
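For orientation, this mirrors how `train.py` uses the class: feed a batch of 0-255 RGB images through VGG16 and collect one activation per convolution block for the perceptual losses. The `images` placeholder here is an assumption for the sketch:

```
# Sketch: extract the VGG16 feature maps used by the perceptual losses.
images = tf.placeholder(tf.float32, shape=[1, 224, 224, 3])  # assumed input
data_dict = loadWeightsData('./vgg16.npy')
vgg = custom_Vgg16(images, data_dict=data_dict)
features = [vgg.conv1_2, vgg.conv2_2, vgg.conv3_3, vgg.conv4_3, vgg.conv5_3]
```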
/generate.py:
--------------------------------------------------------------------------------
import numpy as np
import argparse
import tensorflow as tf
import os
from PIL import Image

parser = argparse.ArgumentParser(description='Real-time style transfer image generator')
parser.add_argument('--input', '-i', type=str, help='content image')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--style', '-s', default=None, type=str, help='style model name')
parser.add_argument('--ckpt', '-c', default=-1, type=int, help='checkpoint to be loaded')
parser.add_argument('--out', '-o', default='stylized_image.jpg', type=str, help='stylized image\'s name')
# NOTE: argparse's type=bool treats any non-empty string (even "0") as
# True, so the flag is parsed as an int: -pb 1 loads the frozen .pb graph.
parser.add_argument('--pb', '-pb', default=0, type=int, help='load with pb (1) instead of checkpoint files (0)')

args = parser.parse_args()

if not os.path.exists('./images/output/'):
    os.makedirs('./images/output/')

outfile_path = './images/output/' + args.out
content_image_path = args.input
style_name = args.style
ckpt = args.ckpt
load_with_pb = args.pb
gpu = args.gpu

original_image = Image.open(content_image_path).convert('RGB')

# The transformation network expects a fixed 224x224 input, so resize
# first and add a batch dimension.
img = np.asarray(original_image.resize((224, 224)), dtype=np.float32)
shaped_input = img.reshape((1,) + img.shape)

if gpu > -1:
    device = '/gpu:{}'.format(gpu)
else:
    device = '/cpu:0'


with tf.device(device):
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if load_with_pb:
            # Load the frozen graph; the weights are baked in as constants.
            from tensorflow.core.framework import graph_pb2
            graph_def = graph_pb2.GraphDef()
            with open('./pbs/{}.pb'.format(style_name), "rb") as f:
                graph_def.ParseFromString(f.read())
            input_image, output = tf.import_graph_def(graph_def, return_elements=['input:0', 'output:0'])

        else:
            # Restore from checkpoint files; a negative --ckpt value
            # means "use the latest checkpoint".
            if ckpt < 0:
                checkpoint = tf.train.get_checkpoint_state('./ckpts/{}/'.format(style_name))
                input_checkpoint = checkpoint.model_checkpoint_path
            else:
                input_checkpoint = './ckpts/{}/{}-{}'.format(style_name, style_name, ckpt)
            saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
            saver.restore(sess, input_checkpoint)
            graph = tf.get_default_graph()

            input_image = graph.get_tensor_by_name('input:0')
            output = graph.get_tensor_by_name('output:0')

        out = sess.run(output, feed_dict={input_image: shaped_input})

# Drop the batch dimension, convert back to an 8-bit image, and scale up
# to the original resolution.
out = out.reshape(out.shape[1:])
im = Image.fromarray(np.uint8(out))

im = im.resize(original_image.size, resample=Image.LANCZOS)
im.save(outfile_path)
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Tensorflow implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"
Fast artistic style transfer using a feed-forward network.

- input image size: 1024x768
- process time (CPU): 2.246 sec (Core i5-5257U)
- process time (GPU): 1.728 sec (GRID K520)


## Requirements
- [Tensorflow 1.0](https://github.com/tensorflow/tensorflow)
- [Pillow](https://github.com/python-pillow/Pillow)
- [Numpy](https://github.com/numpy/numpy)
- [Scipy](https://github.com/scipy/scipy)


## Prerequisite
In this implementation, the VGG model part is based on [Tensorflow VGG16 and VGG19](https://github.com/machrisaa/tensorflow-vgg). Please add it as a submodule and follow the instructions there to obtain the vgg16 model. Make sure the name of the submodule in your project matches the module imported in `custom_vgg16.py` (`from tensorflow_vgg import vgg16`).
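For example, the setup could look like this (a sketch: the directory name `tensorflow_vgg` is chosen here to match the import above, and `vgg16.npy` is downloaded as described in that repository):

```
git submodule add https://github.com/machrisaa/tensorflow-vgg tensorflow_vgg
git submodule update --init
# place the downloaded vgg16.npy in the repository root (train.py loads ./vgg16.npy)
```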
## Train a style model
You need to train one image transformation network per style target.
According to the paper, the models are trained on the [Microsoft COCO dataset](http://mscoco.org/dataset/#download).
Training also saves the transformation model, including the trained weights, for later use (e.g. in C++) in the ```pbs``` directory, while the checkpoint files are saved in ```ckpts/<style_name>/```.


- A ```.pb``` is saved by default. To turn this off, add the argument ```-pb 0```.
- To train a model from scratch:
```
python train.py -s <style_image_path> -d <dataset_path> -g 0
```
- To continue from a pre-trained model, specify the checkpoint to load. A negative checkpoint value means the latest checkpoint is used.
```
python train.py -s <style_image_path> -d <dataset_path> -c <checkpoint>
```

## Generate a stylized image

### Load with .pb file
```
python generate.py -i <input_image_path> -o <output_image_name> -s <style_name> -pb 1
```

### Load with checkpoint files
- By default, the latest checkpoint file is used (negative value for the checkpoint argument).
```
python generate.py -i <input_image_path> -s <style_name> -o <output_image_name> -c <checkpoint>
```
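As a concrete illustration with the files shipped in this repository (the dataset path is a placeholder for wherever MS COCO was unpacked):

```
# train a model for the Robert_D style
python train.py -s ./images/styles/Robert_D.jpg -d <path_to_coco_images> -g 0
# stylize the bundled content image with the frozen model
python generate.py -i ./images/content/tubingen.jpg -s Robert_D -o RobertD_output.jpg -pb 1
```

The style name is derived from the style image's filename, so `-s Robert_D` refers to `./pbs/Robert_D.pb`, and the result lands in `./images/output/RobertD_output.jpg`.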
## Difference between implementation and paper
- Convolution kernel size 4 instead of 3.
- Training with batch size n >= 2 gives unstable results.

## License
MIT

## Reference
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)

The code in this repository is based on the following nice work; thanks to the author.

- [Chainer implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"](https://github.com/yusuketomoto/chainer-fast-neuralstyle)
--------------------------------------------------------------------------------

/net.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def weight_variable(shape, name=None):
    # initialize weight variables with small truncated-normal noise.
    initial = tf.truncated_normal(shape, stddev=0.001)
    return tf.Variable(initial, name=name)

def conv2d(x, W, strides=[1, 1, 1, 1], p='SAME', name=None):
    # set up a convolution layer.
    assert isinstance(x, tf.Tensor)
    return tf.nn.conv2d(x, W, strides=strides, padding=p, name=name)

def batch_norm(x):
    assert isinstance(x, tf.Tensor)
    # reduce over axes 1, 2, 3 (height, width, channels), producing one
    # mean and one variance per sample.
    mean, var = tf.nn.moments(x, axes=[1, 2, 3])
    return tf.nn.batch_normalization(x, mean, var, 0, 1, 1e-5)

def relu(x):
    assert isinstance(x, tf.Tensor)
    return tf.nn.relu(x)

def deconv2d(x, W, strides=[1, 1, 1, 1], p='SAME', name=None):
    # transposed convolution; the output shape is derived from the static
    # input shape, scaled up by the stride.
    assert isinstance(x, tf.Tensor)
    _, _, c, _ = W.get_shape().as_list()
    b, h, w, _ = x.get_shape().as_list()
    return tf.nn.conv2d_transpose(x, W, [b, strides[1] * h, strides[1] * w, c], strides=strides, padding=p, name=name)

def max_pool_2x2(x):
    assert isinstance(x, tf.Tensor)
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


class ResidualBlock():
    def __init__(self, idx, ksize=3, train=False, data_dict=None):
        self.W1 = weight_variable([ksize, ksize, 128, 128], name='R' + str(idx) + '_conv1_w')
        self.W2 = weight_variable([ksize, ksize, 128, 128], name='R' + str(idx) + '_conv2_w')

    def __call__(self, x, idx, strides=[1, 1, 1, 1]):
        # two convolutions with a skip connection around them.
        h = relu(batch_norm(conv2d(x, self.W1, strides, name='R' + str(idx) + '_conv1')))
        h = batch_norm(conv2d(h, self.W2, name='R' + str(idx) + '_conv2'))
        return x + h


class FastStyleNet():
    def __init__(self, train=True, data_dict=None):
        # three downsampling convolutions, five residual blocks,
        # three upsampling (transposed) convolutions.
        self.c1 = weight_variable([9, 9, 3, 32], name='t_conv1_w')
        self.c2 = weight_variable([4, 4, 32, 64], name='t_conv2_w')
        self.c3 = weight_variable([4, 4, 64, 128], name='t_conv3_w')
        self.r1 = ResidualBlock(1, train=train)
        self.r2 = ResidualBlock(2, train=train)
        self.r3 = ResidualBlock(3, train=train)
        self.r4 = ResidualBlock(4, train=train)
        self.r5 = ResidualBlock(5, train=train)
        self.d1 = weight_variable([4, 4, 64, 128], name='t_dconv1_w')
        self.d2 = weight_variable([4, 4, 32, 64], name='t_dconv2_w')
        self.d3 = weight_variable([9, 9, 3, 32], name='t_dconv3_w')

    def __call__(self, h):
        h = batch_norm(relu(conv2d(h, self.c1, name='t_conv1')))
        h = batch_norm(relu(conv2d(h, self.c2, strides=[1, 2, 2, 1], name='t_conv2')))
        h = batch_norm(relu(conv2d(h, self.c3, strides=[1, 2, 2, 1], name='t_conv3')))

        h = self.r1(h, 1)
        h = self.r2(h, 2)
        h = self.r3(h, 3)
        h = self.r4(h, 4)
        h = self.r5(h, 5)

        h = batch_norm(relu(deconv2d(h, self.d1, strides=[1, 2, 2, 1], name='t_deconv1')))
        h = batch_norm(relu(deconv2d(h, self.d2, strides=[1, 2, 2, 1], name='t_deconv2')))
        y = deconv2d(h, self.d3, name='t_deconv3')
        # map the tanh output from [-1, 1] to [0, 255].
        return tf.multiply((tf.tanh(y) + 1), tf.constant(127.5, tf.float32, shape=y.get_shape()), name='output')
--------------------------------------------------------------------------------
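For orientation (a sketch mirroring `train.py`, not part of the original file): this is how the network is wired up, and the `input`/`output` tensor names are what `generate.py` and `freeze_graph.py` look up later.

```
import tensorflow as tf
from net import FastStyleNet

# A fixed-size batch; the name 'input' is what generate.py expects.
inputs = tf.placeholder(tf.float32, shape=[1, 224, 224, 3], name='input')
net = FastStyleNet()
outputs = net(inputs)  # final op is named 'output'; values lie in [0, 255]
```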
/train.py:
--------------------------------------------------------------------------------
import numpy as np
import os, sys
import argparse
from PIL import Image
from freeze_graph import freeze_graph
import tensorflow as tf
import time

from net import *
sys.path.append(os.path.join(os.path.dirname(sys.path[0]), "./"))
from custom_vgg16 import *


# gram matrix per layer, normalized by the layer's total number of values
def gram_matrix(x):
    assert isinstance(x, tf.Tensor)
    b, h, w, ch = x.get_shape().as_list()
    features = tf.reshape(x, [b, h * w, ch])
    # gram = tf.batch_matmul(features, features, adj_x=True)/tf.constant(ch*w*h, tf.float32)
    gram = tf.matmul(features, features, adjoint_a=True) / tf.constant(ch * w * h, tf.float32)
    return gram

# total variation denoising: penalize differences between neighboring
# pixels along the height (wh) and width (ww) axes
def total_variation_regularization(x, beta=1):
    assert isinstance(x, tf.Tensor)
    wh = tf.constant([[[[ 1], [ 1], [ 1]]], [[[-1], [-1], [-1]]]], tf.float32)
    ww = tf.constant([[[[ 1], [ 1], [ 1]], [[-1], [-1], [-1]]]], tf.float32)
    tvh = lambda x: conv2d(x, wh, p='SAME')
    tvw = lambda x: conv2d(x, ww, p='SAME')
    dh = tvh(x)
    dw = tvw(x)
    tv = (tf.add(tf.reduce_sum(dh**2, [1, 2, 3]), tf.reduce_sum(dw**2, [1, 2, 3]))) ** (beta / 2.)
    return tv

parser = argparse.ArgumentParser(description='Real-time style transfer')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--dataset', '-d', default='dataset', type=str,
                    help='dataset directory path (according to the paper, use MSCOCO 80k images)')
parser.add_argument('--style_image', '-s', type=str, required=True,
                    help='style image path')
parser.add_argument('--batchsize', '-b', type=int, default=1,
                    help='batch size (default value is 1)')
parser.add_argument('--ckpt', '-c', default=None, type=int,
                    help='the global step of the checkpoint file to restore')
parser.add_argument('--lambda_tv', '-l_tv', default=10e-4, type=float,
                    help='weight of total variation regularization; the paper suggests a value between 10e-4 and 10e-6')
parser.add_argument('--lambda_feat', '-l_feat', default=1e0, type=float)
parser.add_argument('--lambda_style', '-l_style', default=1e1, type=float)
parser.add_argument('--epoch', '-e', default=2, type=int)
parser.add_argument('--lr', '-l', default=1e-3, type=float)
# NOTE: parsed as an int rather than a bool, since argparse's type=bool
# treats any non-empty string (including "0") as True.
parser.add_argument('--pb', '-pb', default=1, type=int, help='save a pb format as well (1) or not (0)')
args = parser.parse_args()

data_dict = loadWeightsData('./vgg16.npy')

batchsize = args.batchsize
gpu = args.gpu
dataset = args.dataset
epochs = args.epoch
learning_rate = args.lr
ckpt = args.ckpt
lambda_tv = args.lambda_tv
lambda_f = args.lambda_feat
lambda_s = args.lambda_style
style_image = args.style_image
save_pb = args.pb

# the style (and model) name is derived from the style image's filename
style_name, _ = os.path.splitext(style_image.split(os.sep)[-1])

fpath = os.listdir(args.dataset)
imagepaths = []
for fn in fpath:
    base, ext = os.path.splitext(fn)
    if ext == '.jpg' or ext == '.png':
        imagepath = os.path.join(dataset, fn)
        imagepaths.append(imagepath)
data_len = len(imagepaths)
iterations = int(data_len / batchsize)
print('Number of training images: {}'.format(data_len))
print('{} epochs, {} iterations per epoch'.format(epochs, iterations))

style_np = np.asarray(Image.open(style_image).convert('RGB').resize((224, 224)), dtype=np.float32)
styles_np = [style_np for x in range(batchsize)]

if gpu > -1:
    device = '/gpu:{}'.format(gpu)
else:
    device = '/cpu:0'

with tf.device(device):

    inputs = tf.placeholder(tf.float32, shape=[batchsize, 224, 224, 3], name='input')
    net = FastStyleNet()
    saver = tf.train.Saver(restore_sequentially=True)
    saver_def = saver.as_saver_def()


    target = tf.placeholder(tf.float32, shape=[batchsize, 224, 224, 3])
    outputs = net(inputs)

    # style target features:
    # compute the gram matrix of each style-target layer
    vgg_s = custom_Vgg16(target, data_dict=data_dict)
    feature_ = [vgg_s.conv1_2, vgg_s.conv2_2, vgg_s.conv3_3, vgg_s.conv4_3, vgg_s.conv5_3]
    gram_ = [gram_matrix(l) for l in feature_]

    # content target features
    vgg_c = custom_Vgg16(inputs, data_dict=data_dict)
    feature_ = [vgg_c.conv1_2, vgg_c.conv2_2, vgg_c.conv3_3, vgg_c.conv4_3, vgg_c.conv5_3]

    # features after transformation
    vgg = custom_Vgg16(outputs, data_dict=data_dict)
    feature = [vgg.conv1_2, vgg.conv2_2, vgg.conv3_3, vgg.conv4_3, vgg.conv5_3]
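    # The per-sample objective assembled below mirrors the paper's
    # perceptual loss (with means instead of sums over each feature map,
    # as implemented in the following loops):
    #
    #   L = lambda_f  * sum_l mean((phi_l(output) - phi_l(content))^2)   feature loss
    #     + lambda_s  * sum_l mean((G_l(output)  - G_l(style))^2)        style loss
    #     + lambda_tv * TV(output)                                       smoothness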
    # compute feature (content) loss
    loss_f = tf.zeros(batchsize, tf.float32)
    for f, f_ in zip(feature, feature_):
        loss_f += lambda_f * tf.reduce_mean(tf.subtract(f, f_) ** 2, [1, 2, 3])

    # compute style loss
    gram = [gram_matrix(l) for l in feature]
    loss_s = tf.zeros(batchsize, tf.float32)
    for g, g_ in zip(gram, gram_):
        loss_s += lambda_s * tf.reduce_mean(tf.subtract(g, g_) ** 2, [1, 2])

    # total variation denoising
    loss_tv = lambda_tv * total_variation_regularization(outputs)

    # total loss
    loss = loss_s + loss_f + loss_tv

    # optimizer
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

    ckpt_directory = './ckpts/{}/'.format(style_name)
    if not os.path.exists(ckpt_directory):
        os.makedirs(ckpt_directory)

    # training
    tf.global_variables_initializer().run()

    if ckpt:
        if ckpt < 0:
            checkpoint = tf.train.get_checkpoint_state(ckpt_directory)
            input_checkpoint = checkpoint.model_checkpoint_path
        else:
            input_checkpoint = ckpt_directory + style_name + '-{}'.format(ckpt)
        saver.restore(sess, input_checkpoint)
        print('Checkpoint {} restored.'.format(ckpt))

    for epoch in range(1, epochs + 1):
        imgs = np.zeros((batchsize, 224, 224, 3), dtype=np.float32)
        for i in range(iterations):
            for j in range(batchsize):
                p = imagepaths[i * batchsize + j]
                imgs[j] = np.asarray(Image.open(p).convert('RGB').resize((224, 224)), np.float32)
            feed_dict = {inputs: imgs, target: styles_np}
            loss_, _ = sess.run([loss, train_step], feed_dict=feed_dict)
            print('[epoch {}/{}] batch {}/{}... loss: {}'.format(epoch, epochs, i + 1, iterations, loss_[0]))
        # save a checkpoint once per epoch
        saver.save(sess, ckpt_directory + style_name, global_step=epoch)

    if save_pb:
        if not os.path.exists('./pbs'):
            os.makedirs('./pbs')
        freeze_graph(ckpt_directory, './pbs/{}.pb'.format(style_name), 'output')
--------------------------------------------------------------------------------