├── examples
│   ├── style
│   │   ├── wave.jpg
│   │   ├── africa.jpg
│   │   ├── bango.jpg
│   │   ├── udnie.jpg
│   │   ├── aquarelle.jpg
│   │   ├── hampson.jpg
│   │   ├── la_muse.jpg
│   │   ├── the_scream.jpg
│   │   ├── chinese_style.jpg
│   │   ├── rain_princess.jpg
│   │   └── the_shipwreck_of_the_minotaur.jpg
│   └── content
│       ├── fox.mp4
│       ├── chicago.jpg
│       └── stata.jpg
├── .gitignore
├── setup.sh
├── utils.py
├── transform_video.py
├── main.py
├── combine_videos.py
├── evaluate.py
├── solver.py
├── README.md
├── tf_utils.py
└── style_transfer.py

--------------------------------------------------------------------------------
/examples/style/wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/wave.jpg
--------------------------------------------------------------------------------
/examples/content/fox.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/content/fox.mp4
--------------------------------------------------------------------------------
/examples/style/africa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/africa.jpg
--------------------------------------------------------------------------------
/examples/style/bango.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/bango.jpg
--------------------------------------------------------------------------------
/examples/style/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/udnie.jpg
--------------------------------------------------------------------------------
/examples/content/chicago.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/content/chicago.jpg
--------------------------------------------------------------------------------
/examples/content/stata.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/content/stata.jpg
--------------------------------------------------------------------------------
/examples/style/aquarelle.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/aquarelle.jpg
--------------------------------------------------------------------------------
/examples/style/hampson.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/hampson.jpg
--------------------------------------------------------------------------------
/examples/style/la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/la_muse.jpg
--------------------------------------------------------------------------------
/examples/style/the_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/the_scream.jpg
--------------------------------------------------------------------------------
/examples/style/chinese_style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/chinese_style.jpg
--------------------------------------------------------------------------------
/examples/style/rain_princess.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/rain_princess.jpg
--------------------------------------------------------------------------------
/examples/style/the_shipwreck_of_the_minotaur.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/the_shipwreck_of_the_minotaur.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
__pycache__/
checkpoints/
temp/
examples/temp
examples/test
examples/results
logs/
checkpoints_original/
checkpoints_v2/
github_img/
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Logan Engstrom
# ------------------------------------------------------------

mkdir data
cd data
wget http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat
mkdir bin
wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
unzip train2014.zip
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import os
import sys
import scipy.misc
import numpy as np


def imread(path, is_gray_scale=False, img_size=None):
    if is_gray_scale:
        img = scipy.misc.imread(path, flatten=True).astype(np.float32)
    else:
        img = scipy.misc.imread(path, mode='RGB').astype(np.float32)

    if not (img.ndim == 3 and img.shape[2] == 3):
        img = np.dstack((img, img, img))

    if img_size is not None:
        img = scipy.misc.imresize(img, img_size)

    return img


def imsave(path, img):
    img = np.clip(img, 0, 255).astype(np.uint8)
    scipy.misc.imsave(path, img)


def all_files_under(path, extension=None, append_path=True, sort=True):
    if append_path:
        if extension is None:
            filenames = [os.path.join(path, fname) for fname in os.listdir(path)]
        else:
            filenames = [os.path.join(path, fname)
                         for fname in os.listdir(path) if fname.endswith(extension)]
    else:
        if extension is None:
            filenames = [os.path.basename(fname) for fname in os.listdir(path)]
        else:
            filenames = [os.path.basename(fname)
                         for fname in os.listdir(path) if fname.endswith(extension)]

    if sort:
        filenames = sorted(filenames)

    return filenames


def exists(p, msg):
    assert os.path.exists(p), msg


def print_metrics(itr, kargs):
    print("*** Iteration {} ====> ".format(itr))
    for name, value in kargs.items():
        print("{} : {}, ".format(name, value))
    print("")
    sys.stdout.flush()

--------------------------------------------------------------------------------
/transform_video.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin, based on code from Logan Engstrom
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import os
import numpy as np
import tensorflow as tf
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
from moviepy.video.io.VideoFileClip import VideoFileClip

import utils as utils
from style_transfer import Transfer

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0')
tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints/la_muse',
                       'dir to read checkpoint in, default: ./checkpoints/la_muse')
tf.flags.DEFINE_string('in_path', None, 'input video path')
tf.flags.DEFINE_string('out_path', None, 'path to save processed video to')


def feed_forward_video(path_in, path_out, checkpoint_dir):
    # initialize video capture
    video_cap = VideoFileClip(path_in, audio=False)
    # initialize writer
    video_writer = ffmpeg_writer.FFMPEG_VideoWriter(path_out, video_cap.size, video_cap.fps, codec='libx264',
                                                    preset='medium', bitrate='2000k', audiofile=path_in,
                                                    threads=None, ffmpeg_params=None)

    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with g.as_default(), tf.Session(config=soft_config) as sess:
        batch_shape = (None, video_cap.size[1], video_cap.size[0], 3)
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape, name='img_placeholder')

        model = Transfer()
        pred = model(img_placeholder)
        saver = tf.train.Saver()

        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception('No checkpoint found...')
        else:
            saver.restore(sess, checkpoint_dir)

        frame_id = 0
        for frame in video_cap.iter_frames():
            print('frame id: {}'.format(frame_id))
            _pred = sess.run(pred, feed_dict={img_placeholder: np.asarray([frame]).astype(np.float32)})
            # _pred carries a leading batch dimension; write the single frame
            video_writer.write_frame(np.clip(_pred[0], 0, 255).astype(np.uint8))
            frame_id += 1

        video_writer.close()


def check_opts(flags):
    utils.exists(flags.checkpoint_dir, 'checkpoint_dir not found!')
    utils.exists(flags.in_path, 'in_path not found!')


def main(_):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index
    check_opts(FLAGS)

    feed_forward_video(FLAGS.in_path, FLAGS.out_path, FLAGS.checkpoint_dir)


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import os
import tensorflow as tf

import utils as utils
from solver import Solver

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0')
tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints', 'dir to save checkpoint in, default: ./checkpoints')

tf.flags.DEFINE_string('style_img', 'examples/style/la_muse.jpg',
                       'style image path, default: ./examples/style/la_muse.jpg')
tf.flags.DEFINE_string('train_path', '../Data/coco/img/train2014',
                       'path to training images folder, default: ../Data/coco/img/train2014')
tf.flags.DEFINE_string('test_path', 'examples/content',
                       'test image path, default: ./examples/content')
tf.flags.DEFINE_string('test_dir', './examples/temp', 'test image save dir, default: ./examples/temp')

tf.flags.DEFINE_integer('epochs', 2, 'number of epochs for training data, default: 2')
tf.flags.DEFINE_integer('batch_size', 4, 'batch size for single feed forward, default: 4')

tf.flags.DEFINE_string('vgg_path', '../Models_zoo/imagenet-vgg-verydeep-19.mat',
                       'path to VGG19 network, default: ../Models_zoo/imagenet-vgg-verydeep-19.mat')
tf.flags.DEFINE_float('content_weight', 7.5, 'content weight, default: 7.5')
tf.flags.DEFINE_float('style_weight', 100., 'style weight, default: 100.')
tf.flags.DEFINE_float('tv_weight', 200., 'total variation regularization weight, default: 200.')
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning rate, default: 1e-3')

tf.flags.DEFINE_integer('print_freq', 100, 'print loss frequency, default: 100')
tf.flags.DEFINE_integer('sample_freq', 2000, 'sample frequency, default: 2000')


def check_opts(flags):
    utils.exists(flags.style_img, 'style path not found!')
    utils.exists(flags.train_path, 'train path not found!')
    utils.exists(flags.test_path, 'test image path not found!')
    utils.exists(flags.vgg_path, 'vgg network data not found!')

    assert flags.epochs > 0
    assert flags.batch_size > 0
    assert flags.print_freq > 0
    assert flags.sample_freq > 0
    assert flags.content_weight >= 0
    assert flags.style_weight >= 0
    assert flags.tv_weight >= 0
    assert flags.learning_rate >= 0

    print(flags.style_img)
    print(flags.style_img.split('/')[-1][:-4])

    style_img_name = flags.style_img.split('/')[-1][:-4]  # extract style image name
    fold_name = os.path.join(flags.checkpoint_dir, style_img_name)
    if not os.path.isdir(fold_name):
        os.makedirs(fold_name)

    fold_name = os.path.join(flags.test_dir, style_img_name)
    if not os.path.isdir(fold_name):
        os.makedirs(fold_name)


def main(_):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index
    check_opts(FLAGS)

    solver = Solver(FLAGS)
    solver.train()


if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------
/combine_videos.py:
--------------------------------------------------------------------------------
import os
import cv2
import argparse
import numpy as np

parser = argparse.ArgumentParser(description='')
parser.add_argument('--resize_ratio', dest='resize_ratio', type=float, default=0.4, help='resize image')
parser.add_argument('--delay', dest='delay', type=int, default=1, help='interval between two frames')
parser.add_argument('--style_size', dest='style_size', type=int, default=64,
                    help='style image size in video')
args = parser.parse_args()


def main():
    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter('./examples/results/output.mp4', fourcc, 20.0, (768, 1024))

    n_rows, n_cols = 4, 3
    video_path = './examples/{}/fox.mp4'
    style_path = './examples/style/'
    video_file = ['content', 'results/africa', 'results/aquarelle',
                  'results/bango', 'results/chinese_style', 'results/hampson',
                  'results/la_muse', 'results/rain_princess', 'results/the_scream',
                  'results/the_shipwreck_of_the_minotaur', 'results/udnie', 'results/wave']
    img_file = ['africa.jpg', 'aquarelle.jpg',
                'bango.jpg', 'chinese_style.jpg', 'hampson.jpg',
                'la_muse.jpg', 'rain_princess.jpg', 'the_scream.jpg',
                'the_shipwreck_of_the_minotaur.jpg', 'udnie.jpg', 'wave.jpg']

    # initialize video captures & style images
    caps, styles = [], []
    for file in video_file:
        caps.append(cv2.VideoCapture(video_path.format(file)))

    for file in img_file:
        styles.append(cv2.imread(os.path.join(style_path, file)))

    cv2.namedWindow('Show')
    cv2.moveWindow('Show', 0, 0)
    while True:
        # read frames
        frames = []
        for idx in range(len(video_file)):
            ret, frame = caps[idx].read()

            if ret is False:
                print('Cannot find frame!')
                break
            else:
                # resize original frame; note cv2.resize expects (width, height)
                resized_frame = cv2.resize(frame, (int(frame.shape[1] * args.resize_ratio),
                                                   int(frame.shape[0] * args.resize_ratio)))

                # paste style image
                if idx >= 1:
                    img = styles[idx-1]
                    resized_img = cv2.resize(img, (args.style_size, args.style_size))
                    resized_frame[-args.style_size:, 0:args.style_size, :] = resized_img

                frames.append(resized_frame)

        # stop when any capture has run out of frames
        if len(frames) < len(video_file):
            break

        # initialize canvas
        height, width, channel = frames[0].shape
        canvas = np.zeros((n_rows * height, n_cols * width, channel), dtype=np.uint8)

        for row in range(n_rows):
            for col in range(n_cols):
                canvas[row*height:(row+1)*height, col*width:(col+1)*width, :] = frames[row * n_cols + col]

        cv2.imshow('Show', canvas)
        if cv2.waitKey(args.delay) & 0xFF == 27:
            break

        # write the new frame
        out.write(canvas)

    # When everything is done, release the captures
    for idx in range(len(caps)):
        caps[idx].release()
    out.release()

    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin, based on code from Logan Engstrom
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import os
import time
import numpy as np
import tensorflow as tf
from collections import defaultdict

from style_transfer import Transfer
import utils as utils

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0')
tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints/la_muse',
                       'dir to read checkpoint in, default: ./checkpoints/la_muse')

tf.flags.DEFINE_string('in_path', './examples/test', 'test image path, default: ./examples/test')
tf.flags.DEFINE_string('out_path', './examples/results',
                       'destination dir of transformed files, default: ./examples/results')


def feed_transform(data_in, paths_out, checkpoint_dir):
    img_shape = utils.imread(data_in[0]).shape

    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with g.as_default(), tf.Session(config=soft_config) as sess:
        img_placeholder = tf.placeholder(tf.float32, shape=[None, *img_shape], name='img_placeholder')

        model = Transfer()
        pred = model(img_placeholder)

        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception('No checkpoint found...')
        else:
            saver.restore(sess, checkpoint_dir)

        img = np.asarray([utils.imread(data_in[0])]).astype(np.float32)
        start_tic = time.time()
        _pred = sess.run(pred, feed_dict={img_placeholder: img})
        end_toc = time.time()
        print('Processing time: {:.2f} msec.\n'.format((end_toc - start_tic) * 1000))
        utils.imsave(paths_out[0], _pred[0])  # paths_out and _pred are lists


def feed_forward(in_paths, out_paths, checkpoint_dir):
    in_path_of_shape = defaultdict(list)
    out_path_of_shape = defaultdict(list)

    for idx in range(len(in_paths)):
        in_image = in_paths[idx]
        out_image = out_paths[idx]

        shape = "%dx%dx%d" % utils.imread(in_image).shape
        in_path_of_shape[shape].append(in_image)
        out_path_of_shape[shape].append(out_image)

    for shape in in_path_of_shape:
        print('Processing images of shape {}'.format(shape))
        feed_transform(in_path_of_shape[shape], out_path_of_shape[shape], checkpoint_dir)


def check_opts(flags):
    utils.exists(flags.checkpoint_dir, 'checkpoint_dir not found!')
    utils.exists(flags.in_path, 'in_path not found!')

    style_name = FLAGS.checkpoint_dir.split('/')[-1]
    if not os.path.isdir(os.path.join(flags.out_path, style_name)):
        os.makedirs(os.path.join(flags.out_path, style_name))


def main(_):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index
    check_opts(FLAGS)

    style_name = FLAGS.checkpoint_dir.split('/')[-1]
    img_paths = utils.all_files_under(FLAGS.in_path)
    out_paths = [os.path.join(FLAGS.out_path, style_name, file)
                 for file in utils.all_files_under(FLAGS.in_path, append_path=False)]

    feed_forward(img_paths, out_paths, FLAGS.checkpoint_dir)


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import os
import random
import numpy as np
import tensorflow as tf
from datetime import datetime

import tf_utils as tf_utils
import utils as utils
from style_transfer import StyleTranser, Transfer


class Solver(object):
    def __init__(self, flags):
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.style_img_name = flags.style_img.split('/')[-1][:-4]
        self.content_target_paths = utils.all_files_under(self.flags.train_path)

        self.test_targets = utils.all_files_under(self.flags.test_path, extension='.jpg')

        self.test_target_names = utils.all_files_under(self.flags.test_path, append_path=False, extension='.jpg')
        self.test_save_paths = [os.path.join(self.flags.test_dir, self.style_img_name, file[:-4])
                                for file in self.test_target_names]

        self.num_contents = len(self.content_target_paths)
        self.num_iters = int(self.num_contents / self.flags.batch_size) * self.flags.epochs

        self.model = StyleTranser(self.sess, self.flags, self.num_iters)

        self.train_writer = tf.summary.FileWriter('logs/{}'.format(self.style_img_name), graph_def=self.sess.graph_def)
        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
        tf_utils.show_all_variables()

    def train(self):
        random.seed(datetime.now())  # set random seed

        for iter_time in range(self.num_iters):
            # sample images and save them
            self.sample(iter_time)

            # read batch data and feed forward
            batch_imgs = self.next_batch()
            loss, summary = self.model.train_step(batch_imgs)

            # write log to tensorboard
            self.train_writer.add_summary(summary, iter_time)
            self.train_writer.flush()

            # print loss information
            self.model.print_info(loss, iter_time)

        # save model at the end
        self.save_model()

    def save_model(self):
        model_name = 'model'
        self.saver.save(self.sess, os.path.join(self.flags.checkpoint_dir, self.style_img_name, model_name))
        print('=====================================')
        print('            Model saved!             ')
        print('=====================================\n')

    def sample(self, iter_time):
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            self.save_model()  # save model before sampling examples

            for idx in range(len(self.test_save_paths)):
                save_path = (self.test_save_paths[idx] + '_%s.png' % iter_time)

                print('save path: {}'.format(save_path))
                print('test_target: {}'.format(self.test_targets[idx]))

                self.feed_transform([self.test_targets[idx]], [save_path])

    def next_batch(self):
        batch_imgs = []
        batch_files = np.random.choice(self.content_target_paths, self.flags.batch_size, replace=False)

        for batch_file in batch_files:
            img = utils.imread(batch_file, img_size=(256, 256, 3))
            batch_imgs.append(img)

        return np.asarray(batch_imgs)

    def feed_transform(self, data_in, paths_out):
        checkpoint_dir = os.path.join(self.flags.checkpoint_dir, self.style_img_name, 'model')
        img_shape = utils.imread(data_in[0]).shape

        g = tf.Graph()
        soft_config = tf.ConfigProto(allow_soft_placement=True)
        soft_config.gpu_options.allow_growth = True

        with g.as_default(), tf.Session(config=soft_config) as sess:
            img_placeholder = tf.placeholder(tf.float32, shape=[None, *img_shape], name='img_placeholder')

            model = Transfer()
            pred = model(img_placeholder)

            saver = tf.train.Saver()
            if os.path.isdir(checkpoint_dir):
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    raise Exception('No checkpoint found...')
            else:
                saver.restore(sess, checkpoint_dir)

            img = np.asarray([utils.imread(data_in[0])]).astype(np.float32)
            _pred = sess.run(pred, feed_dict={img_placeholder: img})
            utils.imsave(paths_out[0], _pred[0])  # paths_out and _pred are lists
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Real-Time Style Transfer in [TensorFlow](https://github.com/tensorflow/tensorflow)
This repository is a TensorFlow implementation of Johnson's [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155).

*(header images omitted: stylized renderings of the MIT Stata Center)*

It takes 385 ms on a GTX1080Ti to style the MIT Stata Center (1024x680).

## Requirements
- tensorflow 1.8.0
- python 3.5.3
- numpy 1.14.2
- scipy 0.19.0
- moviepy 0.2.3.2

## Video Stylization
Here we transformed every frame in a video using various style images, then combined the results; a condensed sketch of the per-frame loop is shown below. [Click to go to the full demo on YouTube!](https://www.youtube.com/watch?v=HpKMLA19zkg&feature=youtu.be)
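
The loop below is a condensed, illustrative sketch of what `transform_video.py` does: restore the trained transform network once, then push every frame of the clip through it. The checkpoint and file paths here are placeholders:

```
# condensed from transform_video.py; paths and checkpoint are illustrative
import numpy as np
import tensorflow as tf
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
from moviepy.video.io.VideoFileClip import VideoFileClip
from style_transfer import Transfer

video_cap = VideoFileClip('examples/content/fox.mp4', audio=False)
video_writer = ffmpeg_writer.FFMPEG_VideoWriter('examples/results/wave/fox.mp4', video_cap.size,
                                                video_cap.fps, codec='libx264')

with tf.Session() as sess:
    img_ph = tf.placeholder(tf.float32, shape=(None, video_cap.size[1], video_cap.size[0], 3))
    pred = Transfer()(img_ph)  # build the transform net for this frame size
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('checkpoints/wave'))

    for frame in video_cap.iter_frames():
        stylized = sess.run(pred, feed_dict={img_ph: np.asarray([frame], dtype=np.float32)})
        video_writer.write_frame(np.clip(stylized[0], 0, 255).astype(np.uint8))

video_writer.close()
```

The real script additionally takes the checkpoint from `--checkpoint_dir` and sets an encoder preset and bitrate for the output video.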

*(demo preview images of the stylized `fox.mp4` clips omitted)*
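
For the side-by-side demo, the stylized clips are tiled onto a single canvas with OpenCV; below is a minimal sketch of the tiling step used in `combine_videos.py`:

```
# tile same-sized frames onto an n_rows x n_cols canvas (see combine_videos.py)
import numpy as np

def tile_frames(frames, n_rows=4, n_cols=3):
    height, width, channel = frames[0].shape
    canvas = np.zeros((n_rows * height, n_cols * width, channel), dtype=np.uint8)
    for row in range(n_rows):
        for col in range(n_cols):
            canvas[row * height:(row + 1) * height, col * width:(col + 1) * width, :] = frames[row * n_cols + col]
    return canvas
```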

## Image Stylization
Various style paintings were applied to a photo of Chicago; see the ./examples/style folder for the full set of style images. **More results can be found [here](https://www.dropbox.com/sh/8x0jkgywxhrtb7o/AAC2YmgDr08D0FQ6PzeIF5fwa?dl=0)**. A programmatic sketch follows the images below.

*(the photo of Chicago rendered in the different styles omitted)*
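
The same stylization can be run from Python instead of the command line; a minimal sketch built on `feed_forward` from `evaluate.py` (the checkpoint path is illustrative):

```
# stylize one image programmatically with an existing checkpoint
import os
from evaluate import feed_forward

# feed_forward writes to the given paths, so the output folder must exist
os.makedirs('examples/results/la_muse', exist_ok=True)
feed_forward(in_paths=['examples/content/chicago.jpg'],
             out_paths=['examples/results/la_muse/chicago.jpg'],
             checkpoint_dir='checkpoints/la_muse')
```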

## Implementation Details
The implementation uses TensorFlow to train a real-time style transfer network. The same transformation network is used as described in Johnson, except that batch normalization is replaced with Ulyanov's instance normalization, zero padding is replaced by reflected padding to reduce boundary artifacts, and the scaling/offset of the output `tanh` layer is slightly different.

We follow [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer) in using a loss function close to the one described in Gatys, using VGG19 instead of VGG16 and typically using "shallower" layers than in Johnson's implementation (e.g. `relu1_1` is used rather than `relu1_2`). The style term compares Gram matrices of VGG feature maps; a minimal NumPy sketch of the computation used in `style_transfer.py`:
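
```
# Gram-matrix style loss for a single VGG layer, as in style_transfer.py:
# flatten the feature map to (pixels, channels), take normalized feature
# correlations, and penalize the squared distance to the style target's Gram.
import numpy as np

def gram_matrix(features):  # features: (H, W, C)
    feats = features.reshape(-1, features.shape[-1])
    return feats.T.dot(feats) / feats.size

def style_layer_loss(pred_features, style_features):
    gram_pred = gram_matrix(pred_features)
    gram_style = gram_matrix(style_features)
    return np.sum((gram_pred - gram_style) ** 2) / gram_style.size
```

In `style_transfer.py` the same quantity is computed in TensorFlow for every batch element, summed over the style layers, and weighted by `--style_weight`.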

## Documentation
### Training Style Transfer Networks
Use `main.py` to train a new style transform network. Training takes 6-8 hours on a GTX 1080Ti. **Before you run this, you should run `setup.sh`**. Example usage:

```
python main.py --style_img path/to/style/img.jpg \
               --train_path path/to/training/data/fold \
               --test_path path/to/test/data/fold \
               --vgg_path path/to/vgg19/imagenet-vgg-verydeep-19.mat
```
- `--gpu_index`: gpu index, default: `0`
- `--checkpoint_dir`: dir to save checkpoint in, default: `./checkpoints`
- `--style_img`: style image path, default: `./examples/style/la_muse.jpg`
- `--train_path`: path to training images folder, default: `../Data/coco/img/train2014`
- `--test_path`: test image path, default: `./examples/content`
- `--test_dir`: test image save dir, default: `./examples/temp`
- `--epochs`: number of epochs for training data, default: `2`
- `--batch_size`: batch size for single feed forward, default: `4`
- `--vgg_path`: path to VGG19 network, default: `../Models_zoo/imagenet-vgg-verydeep-19.mat`
- `--content_weight`: content weight, default: `7.5`
- `--style_weight`: style weight, default: `100.`
- `--tv_weight`: total variation regularization weight, default: `200.`
- `--print_freq`: print loss frequency, default: `100`
- `--sample_freq`: sample frequency, default: `2000`

### Evaluating Style Transfer Networks
Use `evaluate.py` to evaluate a style transfer network. Evaluation takes 300 ms per frame on a GTX 1080Ti and several seconds per frame on a CPU. **Models for evaluation are [located here](https://www.dropbox.com/sh/wh067d88o0ylcha/AABpYBTnufQiMQeVHXqYhdZXa?dl=0)**. Example usage:
```
python evaluate.py --checkpoint_dir path/to/checkpoint \
                   --in_path path/to/test/image/folder
```
- `--gpu_index`: gpu index, default: `0`
- `--checkpoint_dir`: dir to read checkpoint in, default: `./checkpoints/la_muse`
- `--in_path`: test image path, default: `./examples/test`
- `--out_path`: destination dir of transformed files, default: `./examples/results`

### Stylizing Video
Use `transform_video.py` to transfer style into a video. Requires `moviepy`. Example usage:
```
python transform_video.py --checkpoint_dir path/to/checkpoint \
                          --in_path path/to/input/video.mp4 \
                          --out_path path/to/write/predicted_video.mp4
```
- `--gpu_index`: gpu index, default: `0`
- `--checkpoint_dir`: dir to read checkpoint in, default: `./checkpoints/la_muse`
- `--in_path`: input video path, default: `None`
- `--out_path`: path to save processed video to, default: `None`

### Citation
```
@misc{chengbinjin2018realtimestyletransfer,
    author = {Cheng-Bin Jin},
    title = {Real-Time Style Transfer},
    year = {2018},
    howpublished = {\url{https://github.com/ChengBinJin/Real-time-style-transfer/}},
    note = {commit xxxxxxx}
}
```

### Attributions/Thanks
- This project borrowed some code from [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer) and Anish's [Neural Style](https://github.com/anishathalye/neural-style/)
- Some readme formatting was borrowed from [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer)
- The image of the MIT Stata Center at the very beginning of the README was taken by [Juan Paulo](https://juanpaulo.me/)

## License
Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.
--------------------------------------------------------------------------------
/tf_utils.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Tensorflow Utils Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.training import moving_averages


def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='pad2d'):
    if pad_type == 'REFLECT':
        return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], 'REFLECT', name=name)


def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, padding='SAME', name='conv2d', is_print=True):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, x.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(x, w, strides=[1, d_h, d_w, 1], padding=padding)

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
        conv = tf.nn.bias_add(conv, biases)

        if is_print:
            print_activations(conv)

        return conv


def deconv2d(x, k, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, padding_='SAME', output_size=None,
             name='deconv2d', with_w=False):
    with tf.variable_scope(name):
        input_shape = x.get_shape().as_list()

        # calculate output size
        h_output, w_output = None, None
        if not output_size:
            h_output, w_output = input_shape[1] * 2, input_shape[2] * 2
        # output_shape = [input_shape[0], h_output, w_output, k]  # error when batch_size is not defined
        output_shape = [tf.shape(x)[0], h_output, w_output, k]

        # conv2d transpose
        w = tf.get_variable('w', [k_h, k_w, k, input_shape[3]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1],
                                        padding=padding_)

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.bias_add(deconv, biases)

        if with_w:
            return deconv, w, biases
        else:
            return deconv


def upsampling2d(x, size=(2, 2), name='upsampling2d'):
    with tf.name_scope(name):
        shape = x.get_shape().as_list()
        return tf.image.resize_nearest_neighbor(x, size=(size[0] * shape[1], size[1] * shape[2]))


def linear(x, output_size, bias_start=0.0, with_w=False, name='fc'):
    shape = x.get_shape().as_list()
    # print('shape: ', shape)

    with tf.variable_scope(name):
        matrix = tf.get_variable(name="matrix", shape=[shape[1], output_size],
                                 dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable(name="bias", shape=[output_size],
                               initializer=tf.constant_initializer(bias_start))
        if with_w:
            return tf.matmul(x, matrix) + bias, matrix, bias
        else:
            return tf.matmul(x, matrix) + bias


def norm(x, name, _type, _ops, is_train=True):
    if _type == 'batch':
        return batch_norm(x, name=name, _ops=_ops, is_train=is_train)
    elif _type == 'instance':
        return instance_norm(x, name=name)
    else:
        raise NotImplementedError


def batch_norm(x, name, _ops, is_train=True):
    """Batch normalization."""
    with tf.variable_scope(name):
        params_shape = [x.get_shape()[-1]]

        beta = tf.get_variable('beta', params_shape, tf.float32,
                               initializer=tf.constant_initializer(0.0, tf.float32))
        gamma = tf.get_variable('gamma', params_shape, tf.float32,
                                initializer=tf.constant_initializer(1.0, tf.float32))

        if is_train is True:
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

            moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                          initializer=tf.constant_initializer(0.0, tf.float32),
                                          trainable=False)
            moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32,
                                              initializer=tf.constant_initializer(1.0, tf.float32),
                                              trainable=False)

            _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9))
            _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9))
        else:
            mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                   initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
            variance = tf.get_variable('moving_variance', params_shape, tf.float32, trainable=False)

        # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
        y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5)
        y.set_shape(x.get_shape())

        return y


def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5):
    with tf.variable_scope(name):
        depth = x.get_shape()[3]
        scale = tf.get_variable(
            'scale', [depth], tf.float32,
            initializer=tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32))
        offset = tf.get_variable('offset', [depth], initializer=tf.constant_initializer(0.0))

        # calculate mean and variance per instance
        mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)

        # normalization
        inv = tf.rsqrt(variance + epsilon)
        normalized = (x - mean) * inv

        return scale * normalized + offset


def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=False):
    output = None
    for idx in range(1, num_blocks+1):
        output = res_block(x, x.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train,
                           name='res{}'.format(idx))
        x = output

    if is_print:
        print_activations(output)

    return output


# norm(x, name, _type, _ops, is_train=True)
def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None):
    with tf.variable_scope(name):
        conv1, conv2 = None, None

        # 3x3 Conv-Batch-Relu S1
        with tf.variable_scope('layer1'):
            if pad_type is None:
                conv1 = conv2d(x, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv')
            elif pad_type == 'REFLECT':
                padded1 = padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='padding')
                conv1 = conv2d(padded1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv')
            normalized1 = norm(conv1, name='norm', _type=norm_, _ops=_ops, is_train=is_train)
            relu1 = tf.nn.relu(normalized1)

        # 3x3 Conv-Batch S1
        with tf.variable_scope('layer2'):
            if pad_type is None:
                conv2 = conv2d(relu1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv')
            elif pad_type == 'REFLECT':
                padded2 = padding2d(relu1, p_h=1, p_w=1, pad_type='REFLECT', name='padding')
                conv2 = conv2d(padded2, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv')
            normalized2 = norm(conv2, name='norm', _type=norm_, _ops=_ops, is_train=is_train)

        # sum layer1 and layer2
        output = x + normalized2
        return output


def identity(x, name='identity', is_print=False):
    output = tf.identity(x, name=name)
    if is_print:
        print_activations(output)

    return output


def max_pool_2x2(x, name='max_pool'):
    with tf.name_scope(name):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def sigmoid(x, name='sigmoid', is_print=False):
    output = tf.nn.sigmoid(x, name=name)
    if is_print:
        print_activations(output)

    return output


def tanh(x, name='tanh', is_print=False):
    output = tf.nn.tanh(x, name=name)
    if is_print:
        print_activations(output)

    return output


def relu(x, name='relu', is_print=False):
    output = tf.nn.relu(x, name=name)
    if is_print:
        print_activations(output)

    return output


def lrelu(x, leak=0.2, name='lrelu', is_print=False):
    output = tf.maximum(x, leak*x, name=name)
    if is_print:
        print_activations(output)

    return output


def xavier_init(in_dim):
    print('in_dim: ', in_dim)
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return xavier_stddev


def print_activations(t):
    print(t.op.name, ' ', t.get_shape().as_list())


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def batch_convert2int(images):
    # images: 4D float tensor (batch_size, image_size, image_size, depth)
    return tf.map_fn(convert2int, images, dtype=tf.uint8)


def convert2int(image):
    # transform from float tensor ([-1.,1.]) to int image ([0,255])
    return tf.image.convert_image_dtype((image + 1.0) / 2.0, tf.uint8)
--------------------------------------------------------------------------------
/style_transfer.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------
# Real-Time Style Transfer Implementation
# Licensed under The MIT License [see LICENSE for details]
# Written by Cheng-Bin Jin, based on code from Logan Engstrom
# Email: sbkim0407@gmail.com
# ------------------------------------------------------------
import functools
import collections
import tensorflow as tf
import numpy as np
import scipy.io
from operator import mul

import tf_utils as tf_utils
import utils as utils


class StyleTranser(object):
    def __init__(self, sess, flags, num_iters):
        self.sess = sess
        self.flags = flags
        self.num_iters = num_iters

        self.style_target = np.asarray([utils.imread(self.flags.style_img)])  # [H, W, C] -> [1, H, W, C]
        self.style_shape = self.style_target.shape
        self.content_shape = [None, 256, 256, 3]

        self.style_layers = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
        # self.style_layers = ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3')  # original paper
        self.content_layer = 'relu4_2'
        # self.content_layer = 'relu2_2'  # original paper

        self.style_target_gram = {}
        self.content_loss, self.style_loss, self.tv_loss = None, None, None

        self._build_net()
        self._tensorboard()

    def _build_net(self):
        # ph: tensorflow placeholder
        self.style_img_ph = tf.placeholder(tf.float32, shape=self.style_shape, name='style_img')
        self.content_img_ph = tf.placeholder(tf.float32, shape=self.content_shape, name='content_img')

        self.transfer = Transfer()
        self.vgg = VGG19(self.flags.vgg_path)

        # step 1: extract style_target feature
        vgg_dic = self.vgg(self.style_img_ph)
        for layer in self.style_layers:
            features = self.sess.run(vgg_dic[layer], feed_dict={self.style_img_ph: self.style_target})
            features = np.reshape(features, (-1, features.shape[3]))
            gram = np.matmul(features.T, features) / features.size
            self.style_target_gram[layer] = gram

        # step 2: extract content_target feature
        content_target_feature = {}
        vgg_content_dic = self.vgg(self.content_img_ph, is_reuse=True)
        content_target_feature[self.content_layer] = vgg_content_dic[self.content_layer]

        # step 3: transfer content image to predicted image
        self.preds = self.transfer(self.content_img_ph/255.0)
        # step 4: extract vgg feature of the predicted image
        preds_dict = self.vgg(self.preds, is_reuse=True)

        # self.sample_pred = self.transfer(self.sample_img_ph/255.0, is_reuse=True)

        self.content_loss_func(preds_dict, content_target_feature)
        self.style_loss_func(preds_dict)
        self.tv_loss_func(self.preds)

        self.total_loss = self.content_loss + self.style_loss + self.tv_loss
        self.optim = tf.train.AdamOptimizer(learning_rate=self.flags.learning_rate).minimize(self.total_loss)

    def _tensorboard(self):
        tf.summary.scalar('loss/content_loss', self.content_loss)
        tf.summary.scalar('loss/style_loss', self.style_loss)
        tf.summary.scalar('loss/tv_loss', self.tv_loss)
        tf.summary.scalar('loss/total_loss', self.total_loss)

        self.summary_op = tf.summary.merge_all()

    def content_loss_func(self, preds_dict, content_target_feature):
        # calculate content size and check the feature dimension between content and predicted image
        content_size = self._tensor_size(content_target_feature[self.content_layer]) * self.flags.batch_size
        assert self._tensor_size(content_target_feature[self.content_layer]) == self._tensor_size(
            preds_dict[self.content_layer])

        self.content_loss = self.flags.content_weight * (2 * tf.nn.l2_loss(
            preds_dict[self.content_layer] - content_target_feature[self.content_layer]) / content_size)

    def style_loss_func(self, preds_dict):
        style_losses = []
        for style_layer in self.style_layers:
            layer = preds_dict[style_layer]
            _, height, width, num_filters = map(lambda i: i.value, layer.get_shape())
            feature_size = height * width * num_filters
            feats = tf.reshape(layer, (tf.shape(layer)[0], height * width, num_filters))
            feats_trans = tf.transpose(feats, perm=[0, 2, 1])
            grams = tf.matmul(feats_trans, feats) / feature_size
            style_gram = self.style_target_gram[style_layer]
            style_losses.append(2 * tf.nn.l2_loss(grams - style_gram) / style_gram.size)

        self.style_loss = self.flags.style_weight * functools.reduce(tf.add, style_losses) / self.flags.batch_size

    def tv_loss_func(self, preds):
        # total variation denoising
        tv_y_size = self._tensor_size(preds[:, 1:, :, :])
        tv_x_size = self._tensor_size(preds[:, :, 1:, :])

        y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :self.content_shape[1]-1, :, :])
        x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :self.content_shape[2]-1, :])
        self.tv_loss = self.flags.tv_weight * 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / self.flags.batch_size

    @staticmethod
    def _tensor_size(tensor):
        return functools.reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)

    def train_step(self, imgs):
        ops = [self.optim, self.content_loss, self.style_loss, self.tv_loss, self.total_loss, self.summary_op]
        feed_dict = {self.content_img_ph: imgs}
        _, content_loss, style_loss, tv_loss, total_loss, summary = self.sess.run(ops, feed_dict=feed_dict)

        return [content_loss, style_loss, tv_loss, total_loss], summary

    def sample_img(self, img):
        return self.sess.run(self.preds, feed_dict={self.content_img_ph: img})

    def print_info(self, loss, iter_time):
        if np.mod(iter_time, self.flags.print_freq) == 0:
            ord_output = collections.OrderedDict([('cur_iter', iter_time), ('tar_iter', self.num_iters),
                                                  ('batch_size', self.flags.batch_size),
                                                  ('content_loss', loss[0]), ('style_loss', loss[1]),
                                                  ('tv_loss', loss[2]),
                                                  ('total_loss', loss[3]),
                                                  ('gpu_index', self.flags.gpu_index)])

            utils.print_metrics(iter_time, ord_output)


class Transfer(object):
    def __call__(self, img, name='transfer', is_reuse=False):
        with tf.variable_scope(name, reuse=is_reuse):
            # [H, W, C] -> [H, W, 32]
            conv1 = self._conv_layer(img, num_filters=32, filter_size=9, strides=1, name='conv1')
            # [H, W, 32] -> [H/2, W/2, 64]
            conv2 = self._conv_layer(conv1, num_filters=64, filter_size=3, strides=2, name='conv2')
            # [H/2, W/2, 64] -> [H/4, W/4, 128]
            conv3 = self._conv_layer(conv2, num_filters=128, filter_size=3, strides=2, name='conv3')
            # [H/4, W/4, 128] -> [H/4, W/4, 128]
            resid = self.n_res_blocks(conv3, num_blocks=5, name='res_blocks')
            # [H/4, W/4, 128] -> [H/2, W/2, 64]
            conv_t1 = self._conv_tranpose_layer(resid, num_filters=64, filter_size=3, strides=2, name='trans_conv1')
            # [H/2, W/2, 64] -> [H, W, 32]
            conv_t2 = self._conv_tranpose_layer(conv_t1, num_filters=32, filter_size=3, strides=2, name='trans_conv2')
            # [H, W, 32] -> [H, W, 3]
            conv_t3 = self._conv_layer(conv_t2, num_filters=3, filter_size=9, strides=1, relu=False, name='conv4')
            preds = tf.nn.tanh(conv_t3) * 150 + 255. / 2

            return preds

    @staticmethod
    def _conv_layer(input_, num_filters=32, filter_size=3, strides=1, relu=True, name=None):
        with tf.variable_scope(name):
            input_ = tf_utils.padding2d(input_, p_h=int(filter_size/2), p_w=int(filter_size/2), pad_type='REFLECT')
            input_ = tf_utils.conv2d(input_, output_dim=num_filters, k_h=filter_size, k_w=filter_size,
                                     d_h=strides, d_w=strides, padding='VALID')
            input_ = tf_utils.instance_norm(input_)

            if relu:
                input_ = tf.nn.relu(input_)

            return input_

    @staticmethod
    def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=False, name=None):
        with tf.variable_scope(name):
            output = None
            for idx in range(1, num_blocks + 1):
                output = tf_utils.res_block(x, x.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train,
                                            pad_type='REFLECT', name='res{}'.format(idx))
                x = output

            if is_print:
                tf_utils.print_activations(output)

            return output

    @staticmethod
    def _conv_tranpose_layer(input_, num_filters=32, filter_size=3, strides=2, name=None):
        with tf.variable_scope(name):
            input_ = tf_utils.deconv2d(input_, num_filters, k_h=filter_size, k_w=filter_size, d_h=strides, d_w=strides)
            input_ = tf_utils.instance_norm(input_)

            return tf.nn.relu(input_)


class VGG19(object):
    def __init__(self, data_path):
        self.data = scipy.io.loadmat(data_path)
        self.weights = self.data['layers'][0]

        self.mean_pixel = np.asarray([123.68, 116.779, 103.939])
        self.layers = (
            'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4'
        )

    def __call__(self, img, name='vgg', is_reuse=False):
        with tf.variable_scope(name, reuse=is_reuse):
            img_pre = self.preprocess(img)

            net_dic = {}
            current = img_pre
            for i, name in enumerate(self.layers):
                kind = name[:4]
                if kind == 'conv':
                    kernels, bias = self.weights[i][0][0][0][0]
                    # matconvnet: weights are [width, height, in_channels, out_channels]
                    # tensorflow: weights are [height, width, in_channels, out_channels]
                    kernels = np.transpose(kernels, (1, 0, 2, 3))
                    bias = bias.reshape(-1)
                    current = self._conv_layer(current, kernels, bias)
                elif kind == 'relu':
                    current = tf.nn.relu(current)
                elif kind == 'pool':
                    current = self._pool_layer(current)

                net_dic[name] = current

            assert len(net_dic) == len(self.layers)

            return net_dic

    @staticmethod
    def _conv_layer(input_, weights, bias):
        conv = tf.nn.conv2d(input_, tf.constant(weights), strides=(1, 1, 1, 1), padding='SAME')
        return tf.nn.bias_add(conv, bias)

    @staticmethod
    def _pool_layer(input_):
        return tf.nn.max_pool(input_, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')

    def preprocess(self, img):
        return img - self.mean_pixel

    def unprocess(self, img):
        return img + self.mean_pixel
--------------------------------------------------------------------------------