├── examples ├── style │ ├── wave.jpg │ ├── africa.jpg │ ├── bango.jpg │ ├── udnie.jpg │ ├── aquarelle.jpg │ ├── hampson.jpg │ ├── la_muse.jpg │ ├── the_scream.jpg │ ├── chinese_style.jpg │ ├── rain_princess.jpg │ └── the_shipwreck_of_the_minotaur.jpg └── content │ ├── fox.mp4 │ ├── chicago.jpg │ └── stata.jpg ├── .gitignore ├── setup.sh ├── utils.py ├── transform_video.py ├── main.py ├── combine_videos.py ├── evaluate.py ├── solver.py ├── README.md ├── tf_utils.py └── style_transfer.py /examples/style/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/wave.jpg -------------------------------------------------------------------------------- /examples/content/fox.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/content/fox.mp4 -------------------------------------------------------------------------------- /examples/style/africa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/africa.jpg -------------------------------------------------------------------------------- /examples/style/bango.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/bango.jpg -------------------------------------------------------------------------------- /examples/style/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/udnie.jpg -------------------------------------------------------------------------------- /examples/content/chicago.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/content/chicago.jpg -------------------------------------------------------------------------------- /examples/content/stata.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/content/stata.jpg -------------------------------------------------------------------------------- /examples/style/aquarelle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/aquarelle.jpg -------------------------------------------------------------------------------- /examples/style/hampson.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/hampson.jpg -------------------------------------------------------------------------------- /examples/style/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/la_muse.jpg -------------------------------------------------------------------------------- /examples/style/the_scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/the_scream.jpg -------------------------------------------------------------------------------- /examples/style/chinese_style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/chinese_style.jpg -------------------------------------------------------------------------------- /examples/style/rain_princess.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/rain_princess.jpg -------------------------------------------------------------------------------- /examples/style/the_shipwreck_of_the_minotaur.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengBinJin/Real-time-style-transfer/HEAD/examples/style/the_shipwreck_of_the_minotaur.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | checkpoints/ 4 | temp/ 5 | examples/temp 6 | examples/test 7 | examples/results 8 | logs/ 9 | checkpoints_original/ 10 | checkpoints_v2/ 11 | github_img/ 12 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Real-Time Style Transfer Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Logan Engstrom 5 | # ------------------------------------------------------------ 6 | #! /bin/bash 7 | 8 | mkdir data 9 | cd data 10 | wget http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat 11 | mkdir bin 12 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip 13 | unzip train2014.zip 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Real-Time Style Transfer Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # ------------------------------------------------------------ 7 | import os 8 | import sys 9 | import scipy.misc 10 | import numpy as np 11 | 12 | 13 | def imread(path, is_gray_scale=False, img_size=None): 14 | if is_gray_scale: 15 | img = scipy.misc.imread(path, flatten=True).astype(np.float32) 16 | else: 17 | img = scipy.misc.imread(path, mode='RGB').astype(np.float32) 18 | 19 | if not (img.ndim == 3 and img.shape[2] == 3): 20 | img = np.dstack((img, img, img)) 21 | 22 | if img_size is not None: 23 | img = scipy.misc.imresize(img, img_size) 24 | 25 | return img 26 | 27 | 28 | def imsave(path, img): 29 | img = np.clip(img, 0, 255).astype(np.uint8) 30 | scipy.misc.imsave(path, img) 31 | 32 | 33 | def all_files_under(path, extension=None, append_path=True, sort=True): 34 | if append_path: 35 | if extension is None: 36 | filenames = [os.path.join(path, fname) for fname in os.listdir(path)] 37 | else: 38 | filenames = [os.path.join(path, fname) 39 | for fname in os.listdir(path) if fname.endswith(extension)] 40 | else: 41 | if extension is None: 42 | filenames = [os.path.basename(fname) for fname in os.listdir(path)] 43 | else: 44 | filenames = [os.path.basename(fname) 45 | for fname in os.listdir(path) if fname.endswith(extension)] 46 | 47 | if sort: 48 | filenames = sorted(filenames) 49 | 50 | return filenames 51 | 52 | 53 | def exists(p, msg): 54 | assert os.path.exists(p), msg 55 | 56 | 57 | def print_metrics(itr, kargs): 58 | print("*** Iteration {} ====> ".format(itr)) 59 | for name, value in kargs.items(): 60 | print("{} : {}, ".format(name, value)) 61 | print("") 62 | sys.stdout.flush() 63 | 64 | -------------------------------------------------------------------------------- /transform_video.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Real-Time Style Transfer Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin, based on code from Logan Engstrom 5 | # Email: sbkim0407@gmail.com 6 | # ------------------------------------------------------------ 7 | import os 8 | import numpy as np 9 | import tensorflow as tf 10 | import moviepy.video.io.ffmpeg_writer as ffmpeg_writer 11 | from moviepy.video.io.VideoFileClip import VideoFileClip 12 | 13 | import utils as utils 14 | from style_transfer import Transfer 15 | 16 | FLAGS = tf.flags.FLAGS 17 | tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0') 18 | tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints/la_muse', 19 | 'dir to read checkpoint in, default: ./checkpoints/la_muse') 20 | tf.flags.DEFINE_string('in_path', None, 'input video path') 21 | tf.flags.DEFINE_string('out_path', None, 'path to save processeced video to') 22 | 23 | 24 | def feed_forward_video(path_in, path_out, checkpoint_dir): 25 | # initialize video cap 26 | video_cap = VideoFileClip(path_in, audio=False) 27 | # initialize writer 28 | video_writer = ffmpeg_writer.FFMPEG_VideoWriter(path_out, video_cap.size, video_cap.fps, codec='libx264', 29 | preset='medium', bitrate='2000k', audiofile=path_in, 30 | threads=None, ffmpeg_params=None) 31 | 32 | g = tf.Graph() 33 | soft_config = tf.ConfigProto(allow_soft_placement=True) 34 | soft_config.gpu_options.allow_growth = True 35 | 36 | with g.as_default(), tf.Session(config=soft_config) as sess: 37 | batch_shape = (None, video_cap.size[1], video_cap.size[0], 3) 38 | img_placeholder = tf.placeholder(tf.float32, shape=batch_shape, name='img_placeholder') 39 | 40 | model = Transfer() 41 | pred = model(img_placeholder) 42 | saver = tf.train.Saver() 43 | 44 | if os.path.isdir(checkpoint_dir): 45 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 46 | if ckpt and ckpt.model_checkpoint_path: 47 | saver.restore(sess, ckpt.model_checkpoint_path) 48 | else: 49 | raise Exception('No checkpoint found...') 50 | else: 51 | saver.restore(sess, checkpoint_dir) 52 | 53 | frame_id = 0 54 | for frame in video_cap.iter_frames(): 55 | print('frame id: {}'.format(frame_id)) 56 | _pred = sess.run(pred, feed_dict={img_placeholder: np.asarray([frame]).astype(np.float32)}) 57 | video_writer.write_frame(np.clip(_pred, 0, 255).astype(np.uint8)) 58 | frame_id += 1 59 | 60 | video_writer.close() 61 | 62 | 63 | def check_opts(flags): 64 | utils.exists(flags.checkpoint_dir, 'checkpoint_dir not found!') 65 | utils.exists(flags.in_path, 'in_path not found!') 66 | 67 | 68 | def main(_): 69 | os.environ['CUDA_AVAILABLE_DEVICES'] = FLAGS.gpu_index 70 | check_opts(FLAGS) 71 | 72 | feed_forward_video(FLAGS.in_path, FLAGS.out_path, FLAGS.checkpoint_dir) 73 | 74 | 75 | if __name__ == '__main__': 76 | tf.app.run() 77 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Real-Time Style Transfer Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # ------------------------------------------------------------ 7 | import os 8 | import tensorflow as tf 9 | 10 | import utils as utils 11 | from solver import Solver 12 | 13 | FLAGS = tf.flags.FLAGS 14 | tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0') 15 | tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints', 'dir to save checkpoint in, default: ./checkpoints') 16 | 17 | tf.flags.DEFINE_string('style_img', 'examples/style/la_muse.jpg', 18 | 'style image path, default: ./examples/style/la_muse.jpg') 19 | tf.flags.DEFINE_string('train_path', '../Data/coco/img/train2014', 20 | 'path to training images folder, default: ../Data/coco/img/train2014') 21 | tf.flags.DEFINE_string('test_path', 'examples/content', 22 | 'test image path, default: ./examples/content') 23 | tf.flags.DEFINE_string('test_dir', './examples/temp', 'test image save dir, default: ./examples/temp') 24 | 25 | tf.flags.DEFINE_integer('epochs', 2, 'number of epochs for training data, default: 2') 26 | tf.flags.DEFINE_integer('batch_size', 4, 'batch size for single feed forward, default: 4') 27 | 28 | tf.flags.DEFINE_string('vgg_path', '../Models_zoo/imagenet-vgg-verydeep-19.mat', 29 | 'path to VGG19 network, default: ../Models_zoo/imagenet-vgg-verydeep-19.mat') 30 | tf.flags.DEFINE_float('content_weight', 7.5, 'content weight, default: 7.5') 31 | tf.flags.DEFINE_float('style_weight', 100., 'style weight, default: 100.') 32 | tf.flags.DEFINE_float('tv_weight', 200., 'total variation regularization weight, default: 200.') 33 | tf.flags.DEFINE_float('learning_rate', 0.001, 'learning rate, default: 1e-3') 34 | 35 | tf.flags.DEFINE_integer('print_freq', 100, 'print loss frequency, defalut: 100') 36 | tf.flags.DEFINE_integer('sample_freq', 2000, 'sample frequency, default: 2000') 37 | 38 | 39 | def check_opts(flags): 40 | utils.exists(flags.style_img, 'style path not found!') 41 | utils.exists(flags.train_path, 'train path not found!') 42 | utils.exists(flags.test_path, 'test image path not found!') 43 | utils.exists(flags.vgg_path, 'vgg network data not found!') 44 | 45 | assert flags.epochs > 0 46 | assert flags.batch_size > 0 47 | assert flags.print_freq > 0 48 | assert flags.sample_freq > 0 49 | assert flags.content_weight >= 0 50 | assert flags.style_weight >= 0 51 | assert flags.tv_weight >= 0 52 | assert flags.learning_rate >= 0 53 | 54 | print(flags.style_img) 55 | print(flags.style_img.split('/')[-1][:-4]) 56 | 57 | style_img_name = flags.style_img.split('/')[-1][:-4] # extract style image name 58 | fold_name = os.path.join(flags.checkpoint_dir, style_img_name) 59 | if not os.path.isdir(fold_name): 60 | os.makedirs(fold_name) 61 | 62 | fold_name = os.path.join(flags.test_dir, style_img_name) 63 | if not os.path.isdir(fold_name): 64 | os.makedirs(fold_name) 65 | 66 | 67 | def main(_): 68 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index 69 | check_opts(FLAGS) 70 | 71 | solver = Solver(FLAGS) 72 | solver.train() 73 | 74 | 75 | if __name__ == '__main__': 76 | tf.app.run() 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /combine_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import numpy as np 5 | 6 | parser = argparse.ArgumentParser(description='') 7 | parser.add_argument('--resize_ratio', dest='resize_ratio', type=float, default=0.4, help='resize iamge') 8 | parser.add_argument('--delay', dest='delay', type=int, default=1, help='interval between two frames') 9 | parser.add_argument('--style_size', dest='style_size', type=int, default=64, 10 | help='sylte image size in video') 11 | args = parser.parse_args() 12 | 13 | 14 | def main(): 15 | # Define the codec and create VideoWriter object 16 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 17 | out = cv2.VideoWriter('./examples/results/output.mp4', fourcc, 20.0, (768, 1024)) 18 | 19 | n_rows, n_cols = 4, 3 20 | video_path = './examples/{}/fox.mp4' 21 | style_path = './examples/style/' 22 | video_file = ['content', 'results/africa', 'results/aquarelle', 23 | 'results/bango', 'results/chinese_style', 'results/hampson', 24 | 'results/la_muse', 'results/rain_princess', 'results/the_scream', 25 | 'results/the_shipwreck_of_the_minotaur', 'results/udnie', 'results/wave'] 26 | img_file = ['africa.jpg', 'aquarelle.jpg', 27 | 'bango.jpg', 'chinese_style.jpg', 'hampson.jpg', 28 | 'la_muse.jpg', 'rain_princess.jpg', 'the_scream.jpg', 29 | 'the_shipwreck_of_the_minotaur.jpg', 'udnie.jpg', 'wave.jpg'] 30 | 31 | # initialize video captures & sylte images 32 | caps, styles = [], [] 33 | for file in video_file: 34 | caps.append(cv2.VideoCapture(video_path.format(file))) 35 | 36 | for file in img_file: 37 | styles.append(cv2.imread(os.path.join(style_path, file))) 38 | 39 | cv2.namedWindow('Show') 40 | cv2.moveWindow('Show', 0, 0) 41 | while True: 42 | # read frames 43 | frames = [] 44 | for idx in range(len(video_file)): 45 | rest, frame = caps[idx].read() 46 | 47 | if rest is False: 48 | print('Can not find frame!') 49 | break 50 | else: 51 | # resize original frame 52 | resized_frame = cv2.resize(frame, (int(frame.shape[0] * args.resize_ratio), 53 | int(frame.shape[1] * args.resize_ratio))) 54 | 55 | # past style image 56 | if idx >= 1: 57 | img = styles[idx-1] 58 | resized_img = cv2.resize(img, (args.style_size, args.style_size)) 59 | resized_frame[-args.style_size:, 0:args.style_size, :] = resized_img 60 | 61 | frames.append(resized_frame) 62 | 63 | # initialize canvas 64 | height, width, channel = frames[0].shape 65 | canvas = np.zeros((n_rows * height, n_cols * width, channel), dtype=np.uint8) 66 | 67 | for row in range(n_rows): 68 | for col in range(n_cols): 69 | canvas[row*height:(row+1)*height, col*width:(col+1)*width, :] = frames[row * n_cols + col] 70 | 71 | cv2.imshow('Show', canvas) 72 | if cv2.waitKey(args.delay) & 0xFF == 27: 73 | break 74 | 75 | # write the new frame 76 | out.write(canvas) 77 | 78 | # When everyting done, release the capture 79 | for idx in range(len(caps)): 80 | caps[idx].release() 81 | out.release() 82 | 83 | cv2.destroyAllWindows() 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Real-Time Style Transfer Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin, based on code from Logan Engstrom 5 | # Email: sbkim0407@gmail.com 6 | # ------------------------------------------------------------ 7 | import os 8 | import time 9 | import numpy as np 10 | import tensorflow as tf 11 | from collections import defaultdict 12 | 13 | from style_transfer import Transfer 14 | import utils as utils 15 | 16 | FLAGS = tf.flags.FLAGS 17 | tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0') 18 | tf.flags.DEFINE_string('checkpoint_dir', 'checkpoints/la_muse', 19 | 'dir to read checkpoint in, default: ./checkpoints/la_muse') 20 | 21 | tf.flags.DEFINE_string('in_path', './examples/test', 'test image path, default: ./examples/test') 22 | tf.flags.DEFINE_string('out_path', './examples/results', 23 | 'destination dir of transformed files, default: ./examples/restuls') 24 | 25 | 26 | def feed_transform(data_in, paths_out, checkpoint_dir): 27 | img_shape = utils.imread(data_in[0]).shape 28 | 29 | g = tf.Graph() 30 | soft_config = tf.ConfigProto(allow_soft_placement=True) 31 | soft_config.gpu_options.allow_growth = True 32 | 33 | with g.as_default(), tf.Session(config=soft_config) as sess: 34 | img_placeholder = tf.placeholder(tf.float32, shape=[None, *img_shape], name='img_placeholder') 35 | 36 | model = Transfer() 37 | pred = model(img_placeholder) 38 | 39 | saver = tf.train.Saver() 40 | if os.path.isdir(checkpoint_dir): 41 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 42 | if ckpt and ckpt.model_checkpoint_path: 43 | saver.restore(sess, ckpt.model_checkpoint_path) 44 | else: 45 | raise Exception('No checkpoint found...') 46 | else: 47 | saver.restore(sess, checkpoint_dir) 48 | 49 | img = np.asarray([utils.imread(data_in[0])]).astype(np.float32) 50 | start_tic = time.time() 51 | _pred = sess.run(pred, feed_dict={img_placeholder: img}) 52 | end_toc = time.time() 53 | print('PT: {:.2f} msec.\n'.format((end_toc - start_tic) * 1000)) 54 | utils.imsave(paths_out[0], _pred[0]) # paths_out and _pred is list 55 | 56 | 57 | def feed_forward(in_paths, out_paths, checkpoint_dir): 58 | in_path_of_shape = defaultdict(list) 59 | out_path_of_shape = defaultdict(list) 60 | 61 | for idx in range(len(in_paths)): 62 | in_image = in_paths[idx] 63 | out_image = out_paths[idx] 64 | 65 | shape = "%dx%dx%d" % utils.imread(in_image).shape 66 | in_path_of_shape[shape].append(in_image) 67 | out_path_of_shape[shape].append(out_image) 68 | 69 | for shape in in_path_of_shape: 70 | print('Processing images of shape {}'.format(shape)) 71 | feed_transform(in_path_of_shape[shape], out_path_of_shape[shape], checkpoint_dir) 72 | 73 | 74 | def check_opts(flags): 75 | utils.exists(flags.checkpoint_dir, 'checkpoint_dir not found!') 76 | utils.exists(flags.in_path, 'in_path not found!') 77 | 78 | style_name = FLAGS.checkpoint_dir.split('/')[-1] 79 | if not os.path.isdir(os.path.join(flags.out_path, style_name)): 80 | os.makedirs(os.path.join(flags.out_path, style_name)) 81 | 82 | 83 | def main(_): 84 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index 85 | check_opts(FLAGS) 86 | 87 | style_name = FLAGS.checkpoint_dir.split('/')[-1] 88 | img_paths = utils.all_files_under(FLAGS.in_path) 89 | out_paths = [os.path.join(FLAGS.out_path, style_name, file) 90 | for file in utils.all_files_under(FLAGS.in_path, append_path=False)] 91 | 92 | feed_forward(img_paths, out_paths, FLAGS.checkpoint_dir) 93 | 94 | 95 | if __name__ == '__main__': 96 | tf.app.run() 97 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Real-Time Style Transfer Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # ------------------------------------------------------------ 7 | import os 8 | import random 9 | import numpy as np 10 | import tensorflow as tf 11 | from datetime import datetime 12 | 13 | import tf_utils as tf_utils 14 | import utils as utils 15 | from style_transfer import StyleTranser, Transfer 16 | 17 | 18 | class Solver(object): 19 | def __init__(self, flags): 20 | run_config = tf.ConfigProto() 21 | run_config.gpu_options.allow_growth = True 22 | self.sess = tf.Session(config=run_config) 23 | 24 | self.flags = flags 25 | self.style_img_name = flags.style_img.split('/')[-1][:-4] 26 | self.content_target_paths = utils.all_files_under(self.flags.train_path) 27 | 28 | self.test_targets = utils.all_files_under(self.flags.test_path, extension='.jpg') 29 | 30 | self.test_target_names = utils.all_files_under(self.flags.test_path, append_path=False, extension='.jpg') 31 | self.test_save_paths = [os.path.join(self.flags.test_dir, self.style_img_name, file[:-4]) 32 | for file in self.test_target_names] 33 | 34 | self.num_contents = len(self.content_target_paths) 35 | self.num_iters = int(self.num_contents / self.flags.batch_size) * self.flags.epochs 36 | 37 | self.model = StyleTranser(self.sess, self.flags, self.num_iters) 38 | 39 | self.train_writer = tf.summary.FileWriter('logs/{}'.format(self.style_img_name), graph_def=self.sess.graph_def) 40 | self.saver = tf.train.Saver() 41 | self.sess.run(tf.global_variables_initializer()) 42 | tf_utils.show_all_variables() 43 | 44 | def train(self): 45 | random.seed(datetime.now()) # set random sedd 46 | 47 | for iter_time in range(self.num_iters): 48 | # sampling images and save them 49 | self.sample(iter_time) 50 | 51 | # read batch data and feed forward 52 | batch_imgs = self.next_batch() 53 | loss, summary = self.model.train_step(batch_imgs) 54 | 55 | # write log to tensorboard 56 | self.train_writer.add_summary(summary, iter_time) 57 | self.train_writer.flush() 58 | 59 | # print loss information 60 | self.model.print_info(loss, iter_time) 61 | 62 | # save model at the end 63 | self.save_model() 64 | 65 | def save_model(self): 66 | model_name = 'model' 67 | self.saver.save(self.sess, os.path.join(self.flags.checkpoint_dir, self.style_img_name, model_name)) 68 | print('=====================================') 69 | print(' Model saved! ') 70 | print('=====================================\n') 71 | 72 | def sample(self, iter_time): 73 | if np.mod(iter_time, self.flags.sample_freq) == 0: 74 | self.save_model() # save model before sample examples 75 | 76 | for idx in range(len(self.test_save_paths)): 77 | save_path = (self.test_save_paths[idx] + '_%s.png' % iter_time) 78 | 79 | print('save path: {}'.format(save_path)) 80 | print('test_target: {}'.format(self.test_targets[idx])) 81 | 82 | self.feed_transform([self.test_targets[idx]], [save_path]) 83 | 84 | def next_batch(self): 85 | batch_imgs = [] 86 | batch_files = np.random.choice(self.content_target_paths, self.flags.batch_size, replace=False) 87 | 88 | for batch_file in batch_files: 89 | img = utils.imread(batch_file, img_size=(256, 256, 3)) 90 | batch_imgs.append(img) 91 | 92 | return np.asarray(batch_imgs) 93 | 94 | def feed_transform(self, data_in, paths_out): 95 | checkpoint_dir = os.path.join(self.flags.checkpoint_dir, self.style_img_name, 'model') 96 | img_shape = utils.imread(data_in[0]).shape 97 | 98 | g = tf.Graph() 99 | soft_config = tf.ConfigProto(allow_soft_placement=True) 100 | soft_config.gpu_options.allow_growth = True 101 | 102 | with g.as_default(), tf.Session(config=soft_config) as sess: 103 | img_placeholder = tf.placeholder(tf.float32, shape=[None, *img_shape], name='img_placeholder') 104 | 105 | model = Transfer() 106 | pred = model(img_placeholder) 107 | 108 | saver = tf.train.Saver() 109 | if os.path.isdir(checkpoint_dir): 110 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 111 | if ckpt and ckpt.model_checkpoint_path: 112 | saver.restore(sess, ckpt.model_checkpoint_path) 113 | else: 114 | raise Exception('No checkpoint found...') 115 | else: 116 | saver.restore(sess, checkpoint_dir) 117 | 118 | img = np.asarray([utils.imread(data_in[0])]).astype(np.float32) 119 | _pred = sess.run(pred, feed_dict={img_placeholder: img}) 120 | utils.imsave(paths_out[0], _pred[0]) # paths_out and _pred is list 121 | 122 | 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Real-Time Style Transfer in [TensorFlow](https://github.com/tensorflow/tensorflow) 2 | This repository is Tensorflow implementation of Johnson's [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155). 3 | 4 |
5 |
6 |
9 |
10 |
25 |
26 |
27 |
28 |
34 |
35 |