├── style ├── ink.jpg ├── wave.jpg ├── crayon.jpg ├── mosaic.jpg ├── sketch.jpg ├── starry.jpg └── feathers.jpg ├── screenshot.jpeg ├── static ├── models_image │ ├── ink.ckpt.jpg │ ├── wave.ckpt.jpg │ ├── crayon.ckpt.jpg │ ├── mosaic.ckpt.jpg │ ├── sketch.ckpt.jpg │ ├── starry.ckpt.jpg │ └── feathers.ckpt.jpg └── index.html ├── .gitignore ├── default_config.py ├── reader.py ├── README.md ├── loss.py ├── vgg.py ├── eval.py ├── transform.py ├── server.py └── train.py /style/ink.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/ink.jpg -------------------------------------------------------------------------------- /style/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/wave.jpg -------------------------------------------------------------------------------- /screenshot.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/screenshot.jpeg -------------------------------------------------------------------------------- /style/crayon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/crayon.jpg -------------------------------------------------------------------------------- /style/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/mosaic.jpg -------------------------------------------------------------------------------- /style/sketch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/sketch.jpg -------------------------------------------------------------------------------- /style/starry.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/starry.jpg -------------------------------------------------------------------------------- /style/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/style/feathers.jpg -------------------------------------------------------------------------------- /static/models_image/ink.ckpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/ink.ckpt.jpg -------------------------------------------------------------------------------- /static/models_image/wave.ckpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/wave.ckpt.jpg -------------------------------------------------------------------------------- /static/models_image/crayon.ckpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/crayon.ckpt.jpg -------------------------------------------------------------------------------- /static/models_image/mosaic.ckpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/mosaic.ckpt.jpg -------------------------------------------------------------------------------- /static/models_image/sketch.ckpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/sketch.ckpt.jpg -------------------------------------------------------------------------------- /static/models_image/starry.ckpt.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/starry.ckpt.jpg -------------------------------------------------------------------------------- /static/models_image/feathers.ckpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hijkzzz/prisma/HEAD/static/models_image/feathers.ckpt.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | tags 4 | *.pyc 5 | models/ 6 | content/ 7 | generate/ 8 | imagenet-vgg-verydeep-19.mat 9 | train2014/ -------------------------------------------------------------------------------- /default_config.py: -------------------------------------------------------------------------------- 1 | CELERY_BROKER_URL = 'redis://localhost:6379/0' 2 | CELERY_RESULT_BACKEND = 'redis://localhost:6379/0' 3 | 4 | MAIL_SERVER = 'smtp.qq.com' 5 | MAIL_PROT = 465 6 | MAIL_USE_TLS = True 7 | MAIL_USE_SSL = False 8 | MAIL_USERNAME = '' 9 | MAIL_PASSWORD = '' 10 | MAIL_DEBUG = True 11 | 12 | UPLOAD_FOLDER = 'upload/' 13 | ALLOWED_EXTENSIONS = set(['.png', '.jpg', '.jpeg']) 14 | MAX_CONTENT_LENGTH = 4 * 1024 * 1024 15 | 16 | MODEL_FOLDER = 'models/' 17 | MODEL_FILES = set(['crayon.ckpt', 'feathers.ckpt', 'sketch.ckpt', 18 | 'starry.ckpt', 'ink.ckpt', 'mosaic.ckpt', 'wave.ckpt']) 19 | OUTPUT_FOLDER = 'generate/' 20 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | from os import listdir, remove 2 | from os.path import exists, join, isfile 3 | import tensorflow as tf 4 | 5 | 6 | def preprocess(image, size): 7 | shape = tf.shape(image) 8 | size_t = tf.constant(size, tf.float64) 9 | height = tf.cast(shape[0], 
tf.float64) 10 | width = tf.cast(shape[1], tf.float64) 11 | 12 | cond_op = tf.less(height, width) 13 | 14 | # 等比例缩放 15 | new_height, new_width = tf.cond( 16 | cond_op, 17 | lambda: (size_t, (width * size_t) / height), 18 | lambda: ((height * size_t) / width, size_t)) 19 | 20 | resized_image = tf.image.resize_images( 21 | image, 22 | [tf.to_int32(new_height), tf.to_int32(new_width)], 23 | method=tf.image.ResizeMethod.BICUBIC) 24 | cropped = tf.image.resize_image_with_crop_or_pad(resized_image, size, size) 25 | 26 | return cropped 27 | 28 | 29 | def get_image(path, size): 30 | png = path.lower().endswith('png') 31 | img_bytes = tf.read_file(path) 32 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 33 | return preprocess(image, size) 34 | 35 | 36 | def image(n, size, path, epochs=2, shuffle=True, crop=True): 37 | # for macOS 38 | if exists(join(path, '.DS_Store')): 39 | remove(join(path, '.DS_Store')) 40 | 41 | filenames = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 42 | if not shuffle: 43 | filenames = sorted(filenames) 44 | 45 | png = filenames[0].lower().endswith('png') # If first file is a png, assume they all are 46 | 47 | filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle, num_epochs=epochs) 48 | reader = tf.WholeFileReader() 49 | _, img_bytes = reader.read(filename_queue) 50 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 51 | 52 | processed_image = preprocess(image, size) 53 | return tf.train.batch([processed_image], n, dynamic_pad=True) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # prisma 2 | 3 | Online fast neural style transfer 4 | 5 | ## Requirements 6 | 7 | - Python 3+ 8 | - Tensorflow 1.1.0+ 9 | - Scipy 10 | - Flask 11 | - Flask-Mail 12 | - Celery 13 | - Redis 
14 | 15 | ## Setup 16 | - Dependencies 17 | ``` 18 | pip3 install numpy pillow scipy 19 | pip3 install flask flask-mail celery redis 20 | pip3 install tensorflow # CPU version 21 | ``` 22 | - Download Models 23 | >http://pan.baidu.com/s/1pLPSXdx 24 | 25 | ``` 26 | mv models/ prisma/ 27 | ``` 28 | 29 | - Mailbox 30 | ``` 31 | default_config.py 32 | 33 | MAIL_SERVER = 'xxxxx' 34 | MAIL_PORT = xxx 35 | MAIL_USERNAME = 'xxxxxx' 36 | MAIL_PASSWORD = 'xxxxxx' 37 | ``` 38 | 39 | - Run Redis 40 | ``` 41 | default_config.py 42 | 43 | CELERY_BROKER_URL = 'redis://localhost:6379/0' 44 | ./redis-server 45 | ``` 46 | 47 | - Run Celery 48 | ``` 49 | celery -A server.celery worker 50 | ``` 51 | 52 | - Run Flask 53 | ``` 54 | python3 server.py 55 | ``` 56 | 57 | ## Training 58 | - Download the COCO dataset and the VGG19 model 59 | >[VGG19 model](http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat) 60 | 61 | >[COCO dataset](http://msvocds.blob.core.windows.net/coco2014/train2014.zip) 62 | 63 | - Put the model and dataset into "prisma/" 64 | 65 | ``` 66 | # The GPU version of TensorFlow is recommended for training 67 | python3 train.py --STYLE_IMAGES style-image.jpg --CONTENT_WEIGHT 1.0 --STYLE_WEIGHT 10.0 --MODEL_PATH models/newmodel.ckpt 68 | 69 | mv models/newmodel.ckpt-done models/newmodel.ckpt 70 | 71 | # Test 72 | python3 eval.py --CONTENT_IMAGE content-image.jpg --MODEL_PATH models/newmodel.ckpt --OUTPUT_FOLDER generate/ 73 | ``` 74 | 75 | - Add to Prisma 76 | ``` 77 | default_config.py 78 | 79 | MODEL_FILES = set(['newmodel.ckpt', ......]) 80 | 81 | # Put an example image ("newmodel.ckpt.jpg") into "static/models_image/" 82 | ``` 83 | 84 | ## Screenshot 85 | >http://localhost:5000/ 86 | 87 |  88 | 89 | ## Reference 90 | - [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576) 91 | - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) 92 | - [Instance Normalization: The Missing Ingredient for Fast
Stylization](https://arxiv.org/abs/1607.08022) 93 | - [OlavHN/fast-neural-style](https://github.com/OlavHN/fast-neural-style) 94 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import vgg 3 | import reader 4 | 5 | 6 | def gram(layer): 7 | shape = tf.shape(layer) 8 | num_images = shape[0] 9 | width = shape[1] 10 | height = shape[2] 11 | num_filters = shape[3] 12 | filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters])) 13 | grams = tf.matmul(filters, filters, transpose_a=True) / \ 14 | tf.to_float(width * height * num_filters) 15 | 16 | return grams 17 | 18 | 19 | def get_style_features(style_paths, style_layers, image_size, style_scale, vgg_path): 20 | with tf.Graph().as_default(), tf.Session() as sess: 21 | size = int(round(image_size * style_scale)) 22 | images = tf.stack( 23 | [reader.get_image(path, size) for path in style_paths]) 24 | net, _ = vgg.net(vgg_path, images - vgg.MEAN_PIXEL) 25 | features = [] 26 | for layer in style_layers: 27 | features.append(gram(net[layer])) 28 | 29 | return sess.run(features) 30 | 31 | 32 | def style_loss(net, style_features_t, style_layers): 33 | style_loss = 0 34 | for style_gram, layer in zip(style_features_t, style_layers): 35 | generated_images, _ = tf.split(net[layer], 2, 0) 36 | size = tf.size(generated_images) 37 | layer_style_loss = tf.nn.l2_loss( 38 | gram(generated_images) - style_gram) * 2 / tf.to_float(size) 39 | style_loss += layer_style_loss 40 | return style_loss 41 | 42 | 43 | def content_loss(net, content_layers): 44 | content_loss = 0 45 | for layer in content_layers: 46 | generated_images, content_images = tf.split( 47 | net[layer], 2, 0) 48 | size = tf.size(generated_images) 49 | # remain the same as in the paper 50 | content_loss += tf.nn.l2_loss(generated_images - 51 | content_images) * 2 / tf.to_float(size) 52 | return content_loss 53 
| 54 | 55 | # 全变差正则化 56 | def total_variation_loss(layer): 57 | shape = tf.shape(layer) 58 | height = shape[1] 59 | width = shape[2] 60 | y = tf.slice(layer, [0, 0, 0, 0], tf.stack( 61 | [-1, height - 1, -1, -1])) - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1]) 62 | x = tf.slice(layer, [0, 0, 0, 0], tf.stack( 63 | [-1, -1, width - 1, -1])) - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1]) 64 | loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + \ 65 | tf.nn.l2_loss(y) / tf.to_float(tf.size(y)) 66 | return loss 67 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-2016 Anish Athalye. Released under GPLv3. 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import scipy.io 6 | 7 | MEAN_PIXEL = np.array([123.68, 116.779, 103.939]) 8 | 9 | 10 | # 加载预训练的 VGG19 网络 11 | def net(data_path, input_image): 12 | layers = ( 13 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 14 | 15 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 16 | 17 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 18 | 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 19 | 20 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 21 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 22 | 23 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 24 | 'relu5_3', 'conv5_4', 'relu5_4' 25 | ) 26 | 27 | data = scipy.io.loadmat(data_path) 28 | mean = data['normalization'][0][0][0] 29 | mean_pixel = np.mean(mean, axis=(0, 1)) 30 | weights = data['layers'][0] 31 | 32 | net = {} 33 | current = input_image 34 | with tf.variable_scope('vgg19'): 35 | for i, name in enumerate(layers): 36 | with tf.variable_scope(name): 37 | kind = name[:4] 38 | if kind == 'conv': 39 | kernels, bias = weights[i][0][0][0][0] 40 | # matconvnet: weights are [width, height, in_channels, out_channels] 41 | # tensorflow: weights are [height, width, in_channels, out_channels] 42 | 
kernels = np.transpose(kernels, (1, 0, 2, 3)) 43 | bias = bias.reshape(-1) 44 | current = _conv_layer(current, kernels, bias) 45 | elif kind == 'relu': 46 | current = tf.nn.relu(current) 47 | elif kind == 'pool': 48 | current = _pool_layer(current) 49 | net[name] = current 50 | 51 | assert len(net) == len(layers) 52 | return net, mean_pixel 53 | 54 | 55 | def _conv_layer(input, weights, bias): 56 | conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), 57 | padding='SAME') 58 | return tf.nn.bias_add(conv, bias) 59 | 60 | 61 | def _pool_layer(input): 62 | return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), 63 | padding='SAME') 64 | 65 | 66 | def preprocess(image): 67 | return image - MEAN_PIXEL 68 | 69 | 70 | def unprocess(image): 71 | return image + MEAN_PIXEL 72 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from scipy import misc 4 | import vgg 5 | import transform 6 | 7 | tf.app.flags.DEFINE_string("MODEL_PATH", "models/fast-style-model.ckpt-done", "Pre-trained models") 8 | tf.app.flags.DEFINE_string("CONTENT_IMAGE", "content/content-image.png", "Path to content image") 9 | tf.app.flags.DEFINE_string("OUTPUT_FOLDER", "generate/", "Path to output image") 10 | tf.app.flags.DEFINE_integer("BATCH_SIZE", 1, "Number of concurrent images to train on") 11 | 12 | FLAGS = tf.app.flags.FLAGS 13 | 14 | 15 | def generate(): 16 | if not FLAGS.CONTENT_IMAGE: 17 | tf.logging.info("train a fast nerual style need to set the Content images path") 18 | return 19 | 20 | if not os.path.exists(FLAGS.OUTPUT_FOLDER): 21 | os.mkdir(FLAGS.OUTPUT_FOLDER) 22 | 23 | # 获取图片信息 24 | height = 0 25 | width = 0 26 | with open(FLAGS.CONTENT_IMAGE, 'rb') as img: 27 | with tf.Session().as_default() as sess: 28 | if FLAGS.CONTENT_IMAGE.lower().endswith('png'): 29 | image = 
sess.run(tf.image.decode_png(img.read())) 30 | else: 31 | image = sess.run(tf.image.decode_jpeg(img.read())) 32 | height = image.shape[0] 33 | width = image.shape[1] 34 | tf.logging.info('Image size: %dx%d' % (width, height)) 35 | 36 | with tf.Graph().as_default(), tf.Session() as sess: 37 | # 读取图片文件 38 | path = FLAGS.CONTENT_IMAGE 39 | png = path.lower().endswith('png') 40 | img_bytes = tf.read_file(path) 41 | 42 | # 图片解码 43 | content_image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 44 | content_image = tf.image.convert_image_dtype(content_image, tf.float32) * 255.0 45 | content_image = tf.expand_dims(content_image, 0) 46 | 47 | generated_images = transform.net(content_image - vgg.MEAN_PIXEL, training=False) 48 | output_format = tf.saturate_cast(generated_images, tf.uint8) 49 | 50 | # 开始转换 51 | saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1) 52 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 53 | model_path = os.path.abspath(FLAGS.MODEL_PATH) 54 | tf.logging.info('Usage model {}'.format(model_path)) 55 | saver.restore(sess, model_path) 56 | 57 | filename = os.path.basename(FLAGS.CONTENT_IMAGE) 58 | (shotname, extension) = os.path.splitext(filename) 59 | filename = shotname + '-' + os.path.basename(FLAGS.MODEL_PATH) + extension 60 | 61 | tf.logging.info("image {}".format(filename)) 62 | images_t = sess.run(output_format) 63 | 64 | assert len(images_t) == 1 65 | misc.imsave(os.path.join(FLAGS.OUTPUT_FOLDER, filename), images_t[0]) 66 | 67 | 68 | if __name__ == '__main__': 69 | tf.logging.set_verbosity(tf.logging.INFO) 70 | generate() -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'): 5 | with 
tf.variable_scope('conv') as scope: 6 | 7 | shape = [kernel, kernel, input_filters, output_filters] 8 | weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight') 9 | x_padded = tf.pad(x, [[0, 0], [kernel // 2, kernel // 2], [kernel // 2, kernel // 2], [0, 0]], mode=mode) 10 | return tf.nn.conv2d(x_padded, weight, strides=[1, strides, strides, 1], padding='VALID', name='conv') 11 | 12 | 13 | def resize_conv2d(x, input_filters, output_filters, kernel, strides, training): 14 | ''' 15 | An alternative to transposed convolution where we first resize, then convolve. 16 | See http://distill.pub/2016/deconv-checkerboard/ 17 | For some reason the shape needs to be statically known for gradient propagation 18 | through tf.image.resize_images, but we only know that for fixed image size, so we 19 | plumb through a "training" argument 20 | ''' 21 | with tf.variable_scope('conv_transpose') as scope: 22 | height = x.get_shape()[1].value if training else tf.shape(x)[1] 23 | width = x.get_shape()[2].value if training else tf.shape(x)[2] 24 | 25 | new_height = height * strides * 2 26 | new_width = width * strides * 2 27 | 28 | x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR) 29 | 30 | shape = [kernel, kernel, input_filters, output_filters] 31 | weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight') 32 | return conv2d(x_resized, input_filters, output_filters, kernel, strides) 33 | 34 | 35 | def instance_norm(x): 36 | epsilon = 1e-9 37 | 38 | mean, var = tf.nn.moments(x, [1, 2], keep_dims=True) 39 | 40 | return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon))) 41 | 42 | 43 | def residual(x, filters, kernel, strides): 44 | with tf.variable_scope('residual') as scope: 45 | conv1 = conv2d(x, filters, filters, kernel, strides) 46 | conv2 = conv2d(tf.nn.relu(conv1), filters, filters, kernel, strides) 47 | 48 | residual = x + conv2 49 | 50 | return residual 51 | 52 | 53 | # 上采样 - 深度残差网络 - 
下采样 54 | def net(image, training): 55 | # Less border effects when padding a little before passing through .. 56 | image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT') 57 | 58 | with tf.variable_scope('conv1'): 59 | conv1 = tf.nn.relu(instance_norm(conv2d(image, 3, 32, 9, 1))) 60 | with tf.variable_scope('conv2'): 61 | conv2 = tf.nn.relu(instance_norm(conv2d(conv1, 32, 64, 3, 2))) 62 | with tf.variable_scope('conv3'): 63 | conv3 = tf.nn.relu(instance_norm(conv2d(conv2, 64, 128, 3, 2))) 64 | with tf.variable_scope('res1'): 65 | res1 = residual(conv3, 128, 3, 1) 66 | with tf.variable_scope('res2'): 67 | res2 = residual(res1, 128, 3, 1) 68 | with tf.variable_scope('res3'): 69 | res3 = residual(res2, 128, 3, 1) 70 | with tf.variable_scope('res4'): 71 | res4 = residual(res3, 128, 3, 1) 72 | with tf.variable_scope('res5'): 73 | res5 = residual(res4, 128, 3, 1) 74 | with tf.variable_scope('deconv1'): 75 | deconv1 = tf.nn.relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training))) 76 | with tf.variable_scope('deconv2'): 77 | deconv2 = tf.nn.relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training))) 78 | with tf.variable_scope('deconv3'): 79 | deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1))) 80 | 81 | y = (deconv3 + 1) * 127.5 82 | 83 | # Remove border effect reducing padding. 
84 | height = tf.shape(y)[1] 85 | width = tf.shape(y)[2] 86 | y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1])) 87 | 88 | return y -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | from flask_mail import Mail, Message 3 | from celery import Celery 4 | from os import mkdir, remove 5 | from os.path import join, exists, splitext 6 | import re 7 | import time 8 | import base64 9 | import json 10 | import subprocess 11 | 12 | 13 | app = Flask(__name__) 14 | app.config.from_pyfile('default_config.py') 15 | mail = Mail(app) 16 | celery = Celery( 17 | app.name, broker=app.config['CELERY_BROKER_URL']) 18 | celery.conf.update(app.config) 19 | 20 | 21 | @app.route('/', methods=['GET']) 22 | def home(): 23 | return app.send_static_file('index.html') 24 | 25 | 26 | @app.route('/help', methods=['GET']) 27 | def help(): 28 | return jsonify(status='HELP_SUCCESS', models=list(app.config['MODEL_FILES']), \ 29 | format='/transform: {email:xxx(receive output), filename:xxx(jpg or png), model:xxx(see models), image:xxxx(base64 encode image)}') 30 | 31 | 32 | @app.route('/transform', methods=['POST']) 33 | def transform(): 34 | # 获取参数 35 | json_data = json.loads(request.get_data().decode(encoding='utf-8')) 36 | filename = json_data.get('filename') 37 | model = json_data.get('model') 38 | image = json_data.get('image') 39 | email = json_data.get('email') 40 | 41 | app.logger.info("%s %d %s %s", filename, len(image), model, email) 42 | 43 | # 检查参数 44 | if filename is None or image is None or email is None or model is None: 45 | return jsonify(status='PARAMS ERROR') 46 | if re.match("^[a-zA-Z0-9_\\-.]+$", filename) is None or splitext(filename)[1] not in app.config['ALLOWED_EXTENSIONS']: 47 | return jsonify(status='FILENAME NOT SUPPORT') 48 | if 
re.match("^.+\\@(\\[?)[a-zA-Z0-9\\-\\.]+\\.([a-zA-Z]{2,3}|[0-9]{1,3})(\\]?)$", email) is None: 49 | return jsonify(status='EMAIL FORMAT ERROR') 50 | if model not in app.config['MODEL_FILES']: 51 | return jsonify(status='MODEL NOT EXISTS') 52 | 53 | try: 54 | image = base64.b64decode(image) 55 | except TypeError: 56 | return jsonify(status='IMAGE ERROR') 57 | 58 | # 保存图片 59 | filename = str(time.time()) + '-' + filename 60 | with open(join(app.config['UPLOAD_FOLDER'], filename), 'wb') as f: 61 | f.write(image) 62 | 63 | # 异步转换 64 | transform_async.delay(filename, email, model) 65 | return jsonify(status='SUBMIT_SUCCESS') 66 | 67 | 68 | @celery.task 69 | def transform_async(filename, email, model): 70 | # 开始转换 71 | content_file_path = join(app.config['UPLOAD_FOLDER'], filename) 72 | model_file_path = join(app.config['MODEL_FOLDER'], model) 73 | output_folder = app.config['OUTPUT_FOLDER'] 74 | 75 | output_filename = filename 76 | (shotname, extension) = splitext(output_filename) 77 | output_filename = shotname + '-' + model + extension 78 | output_file_path = join(output_folder, output_filename) 79 | 80 | command = 'python eval.py --CONTENT_IMAG %s --MODEL_PATH %s --OUTPUT_FOLDER %s' % ( 81 | content_file_path, model_file_path, output_folder) 82 | status, output = subprocess.getstatusoutput(command) 83 | 84 | # 打印日志 85 | print(status, output) 86 | 87 | # 发送邮件 88 | if status == 0: 89 | with app.app_context(): 90 | msg = Message("IMAGE-STYLE-TRANSFER", 91 | sender=app.config['MAIL_USERNAME'], recipients=[email]) 92 | msg.body = filename 93 | with app.open_resource(output_file_path) as f: 94 | mime_type = 'image/jpg' if splitext( 95 | filename)[1] is not '.png' else 'image/png' 96 | msg.attach(filename, mime_type, f.read()) 97 | mail.send(msg) 98 | else: 99 | with app.app_context(): 100 | msg = Message("IMAGE-STYLE-TRANSFER", 101 | sender=app.config['MAIL_USERNAME'], recipients=[email]) 102 | msg.body = "CONVERT ERROR\n" + filename + "\n HELP - http://host:port/help" 103 
| mail.send(msg) 104 | 105 | remove_files.apply_async( 106 | args=[[content_file_path, output_file_path]], countdown=60) 107 | 108 | 109 | @celery.task 110 | def remove_files(file_list): 111 | for file in file_list: 112 | if exists(file): 113 | remove(file) 114 | 115 | 116 | if __name__ == '__main__': 117 | if not exists(app.config['UPLOAD_FOLDER']): 118 | mkdir(app.config['UPLOAD_FOLDER']) 119 | if not exists(app.config['OUTPUT_FOLDER']): 120 | mkdir(app.config['OUTPUT_FOLDER']) 121 | 122 | app.run(host='0.0.0.0') 123 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import tensorflow as tf 4 | import vgg 5 | import transform 6 | import loss 7 | import reader 8 | 9 | 10 | tf.app.flags.DEFINE_float("CONTENT_WEIGHT", 1.0, 11 | "Weight for content features loss") 12 | tf.app.flags.DEFINE_float("STYLE_WEIGHT", 10.0, 13 | "Weight for style features loss") 14 | tf.app.flags.DEFINE_float("TV_WEIGHT", 1e-6, 15 | "Weight for total variation loss") 16 | tf.app.flags.DEFINE_float("LEARNING_RATE", 1e-3, 17 | "Learning rate for training") 18 | tf.app.flags.DEFINE_integer("EPOCHS", 2, 19 | "Num epochs") 20 | tf.app.flags.DEFINE_string( 21 | "STYLE_IMAGES", "style-images/style-image.png", "Styles to train") 22 | tf.app.flags.DEFINE_float("STYLE_SCALE", 1.0, 23 | "Scale styles. 
Higher extracts smaller features") 24 | tf.app.flags.DEFINE_integer("IMAGE_SIZE", 256, "Size of output image") 25 | tf.app.flags.DEFINE_integer("BATCH_SIZE", 4, 26 | "Number of concurrent images to train on") 27 | tf.app.flags.DEFINE_string("MODEL_PATH", "models/fast-style-transfer.ckpt", 28 | "Path to read/write trained models") 29 | tf.app.flags.DEFINE_string("VGG_PATH", "imagenet-vgg-verydeep-19.mat", 30 | "Path to vgg model weights") 31 | tf.app.flags.DEFINE_string("TRAIN_IMAGES_FOLDER", "train2014/", 32 | "Path to training images") 33 | tf.app.flags.DEFINE_string("CONTENT_LAYERS", "relu4_2", 34 | "Which VGG layer to extract content loss from") 35 | tf.app.flags.DEFINE_string("STYLE_LAYERS", 36 | "relu1_1,relu2_1,relu3_1,relu4_1,relu5_1", 37 | "Which layers to extract style from") 38 | 39 | FLAGS = tf.app.flags.FLAGS 40 | 41 | # how to select GPU 42 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 43 | 44 | 45 | def optimize(): 46 | MODEL_DIR_NAME = os.path.dirname(FLAGS.MODEL_PATH) 47 | if not os.path.exists(MODEL_DIR_NAME): 48 | os.mkdir(MODEL_DIR_NAME) 49 | 50 | style_paths = FLAGS.STYLE_IMAGES.split(',') 51 | style_layers = FLAGS.STYLE_LAYERS.split(',') 52 | content_layers = FLAGS.CONTENT_LAYERS.split(',') 53 | 54 | # style gram matrix 55 | style_features_t = loss.get_style_features(style_paths, style_layers, 56 | FLAGS.IMAGE_SIZE, FLAGS.STYLE_SCALE, FLAGS.VGG_PATH) 57 | 58 | with tf.Graph().as_default(), tf.Session() as sess: 59 | # train_images 60 | images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, 61 | FLAGS.TRAIN_IMAGES_FOLDER, FLAGS.EPOCHS) 62 | 63 | generated = transform.net(images - vgg.MEAN_PIXEL, training=True) 64 | net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat([generated, images], 0) - vgg.MEAN_PIXEL) 65 | 66 | # 损失函数 67 | content_loss = loss.content_loss(net, content_layers) 68 | style_loss = loss.style_loss( 69 | net, style_features_t, style_layers) / len(style_paths) 70 | 71 | total_loss = FLAGS.STYLE_WEIGHT * style_loss + FLAGS.CONTENT_WEIGHT 
* content_loss + \ 72 | FLAGS.TV_WEIGHT * loss.total_variation_loss(generated) 73 | 74 | # 准备训练 75 | global_step = tf.Variable(0, name="global_step", trainable=False) 76 | 77 | variable_to_train = [] 78 | for variable in tf.trainable_variables(): 79 | if not variable.name.startswith('vgg19'): 80 | variable_to_train.append(variable) 81 | 82 | train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize( 83 | total_loss, global_step=global_step, var_list=variable_to_train) 84 | 85 | variables_to_restore = [] 86 | for v in tf.global_variables(): 87 | if not v.name.startswith('vgg19'): 88 | variables_to_restore.append(v) 89 | 90 | # 开始训练 91 | saver = tf.train.Saver(variables_to_restore, 92 | write_version=tf.train.SaverDef.V1) 93 | sess.run([tf.global_variables_initializer(), 94 | tf.local_variables_initializer()]) 95 | 96 | # 加载检查点 97 | ckpt = tf.train.latest_checkpoint(MODEL_DIR_NAME) 98 | if ckpt: 99 | tf.logging.info('Restoring model from {}'.format(ckpt)) 100 | saver.restore(sess, ckpt) 101 | 102 | coord = tf.train.Coordinator() 103 | threads = tf.train.start_queue_runners(coord=coord) 104 | start_time = time.time() 105 | try: 106 | while not coord.should_stop(): 107 | _, loss_t, step = sess.run([train_op, total_loss, global_step]) 108 | elapsed_time = time.time() - start_time 109 | start_time = time.time() 110 | 111 | if step % 10 == 0: 112 | tf.logging.info( 113 | 'step: %d, total loss %f, secs/step: %f' % (step, loss_t, elapsed_time)) 114 | 115 | if step % 10000 == 0: 116 | saver.save(sess, FLAGS.MODEL_PATH, global_step=step) 117 | tf.logging.info('Save model') 118 | 119 | except tf.errors.OutOfRangeError: 120 | saver.save(sess, FLAGS.MODEL_PATH + '-done') 121 | tf.logging.info('Done training -- epoch limit reached') 122 | finally: 123 | coord.request_stop() 124 | 125 | coord.join(threads) 126 | 127 | 128 | if __name__ == '__main__': 129 | tf.logging.set_verbosity(tf.logging.INFO) 130 | optimize() 131 | 
-------------------------------------------------------------------------------- /static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |Upload photo
86 |89 | Drop your photo here. 90 |
91 |Popular styles
114 |