├── .gitignore ├── LICENSE ├── README.md ├── content_imgs ├── brad_pitt.jpg ├── chicago.jpg ├── gandalf.jpg ├── golden_gate.jpg ├── hoovertowernight.jpg ├── matheus.jpg └── tubingen.jpg ├── data ├── download_coco.sh └── models │ └── download_models.sh ├── fast_style_transfer.py ├── iterative_style_transfer.py ├── layers.py ├── make_gram_dataset.py ├── make_style_dataset.py ├── model.py ├── readme_imgs ├── 1style │ ├── chicago_style_the_scream.png │ ├── golden_gate_style_feathers.png │ └── hoovertowernight_style_candy.png ├── 6styles │ ├── chicago_style_the_scream.png │ ├── golden_gate_style_feathers.png │ └── hoovertowernight_style_candy.png ├── candy_tubingen_init_content.gif └── candy_tubingen_init_random.gif ├── style_imgs ├── candy.jpg ├── composition_vii.jpg ├── escher_sphere.jpg ├── feathers.jpg ├── frida_kahlo.jpg ├── la_muse.jpg ├── mosaic.jpg ├── picasso_selfport1907.jpg ├── seated-nude.jpg ├── shipwreck.jpg ├── starry_night.jpg ├── the_scream.jpg ├── udnie.jpg ├── wave_crop.jpg └── woman-with-hat-matisse.jpg ├── train.py ├── training.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/linux,macos,python 2 | 3 | ### Linux ### 4 | *~ 5 | 6 | # temporary files which can be created if a process still has a handle open of a deleted file 7 | .fuse_hidden* 8 | 9 | # KDE directory preferences 10 | .directory 11 | 12 | # Linux trash folder which might appear on any partition or disk 13 | .Trash-* 14 | 15 | # .nfs files are created when an open file is removed but is still being accessed 16 | .nfs* 17 | 18 | ### macOS ### 19 | *.DS_Store 20 | .AppleDouble 21 | .LSOverride 22 | 23 | # Icon must end with two \r 24 | Icon 25 | 26 | 27 | # Thumbnails 28 | ._* 29 | 30 | # Files that might appear in the root of a volume 31 | .DocumentRevisions-V100 32 | .fseventsd 33 | .Spotlight-V100 34 | .TemporaryItems 35 | .Trashes 36 | .VolumeIcon.icns 37 | .com.apple.timemachine.donotpresent 38 | 39 | # Directories potentially created on remote AFP share 40 | .AppleDB 41 | .AppleDesktop 42 | Network Trash Folder 43 | Temporary Items 44 | .apdisk 45 | 46 | ### Python ### 47 | # Byte-compiled / optimized / DLL files 48 | __pycache__/ 49 | *.py[cod] 50 | *$py.class 51 | 52 | # C extensions 53 | *.so 54 | 55 | # Distribution / packaging 56 | .Python 57 | env/ 58 | build/ 59 | develop-eggs/ 60 | dist/ 61 | downloads/ 62 | eggs/ 63 | .eggs/ 64 | lib/ 65 | lib64/ 66 | parts/ 67 | sdist/ 68 | var/ 69 | wheels/ 70 | *.egg-info/ 71 | .installed.cfg 72 | *.egg 73 | 74 | # PyInstaller 75 | # Usually these files are written by a python script from a template 76 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
77 | *.manifest 78 | *.spec 79 | 80 | # Installer logs 81 | pip-log.txt 82 | pip-delete-this-directory.txt 83 | 84 | # Unit test / coverage reports 85 | htmlcov/ 86 | .tox/ 87 | .coverage 88 | .coverage.* 89 | .cache 90 | nosetests.xml 91 | coverage.xml 92 | *,cover 93 | .hypothesis/ 94 | 95 | # Translations 96 | *.mo 97 | *.pot 98 | 99 | # Django stuff: 100 | *.log 101 | local_settings.py 102 | 103 | # Flask stuff: 104 | instance/ 105 | .webassets-cache 106 | 107 | # Scrapy stuff: 108 | .scrapy 109 | 110 | # Sphinx documentation 111 | docs/_build/ 112 | 113 | # PyBuilder 114 | target/ 115 | 116 | # Jupyter Notebook 117 | .ipynb_checkpoints 118 | 119 | # pyenv 120 | .python-version 121 | 122 | # celery beat schedule file 123 | celerybeat-schedule 124 | 125 | # dotenv 126 | .env 127 | 128 | # virtualenv 129 | .venv/ 130 | venv/ 131 | ENV/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # End of https://www.gitignore.io/api/linux,macos,python 140 | # Ignore the dataset 141 | data/coco 142 | # Ignore .h5 files 143 | *.h5 144 | *.hdf5 145 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Roberto de Moura Estevão Filho 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # neural-style-keras 2 | Implementation of neural style transfer methods using Keras and TensorFlow. This includes the iterative stylization from [1], the fast style transfer from [2] and the multi-style nets from [3]. All models include instance normalization [4]. 3 | 4 | ## Fast style transfer 5 | To use a pre-trained net, simply use the script 6 | 7 | ```bash 8 | python fast_style_transfer.py --checkpoint_path data/models/candy.h5 --input_path content_imgs \ 9 | --output_path pastiche_output --use_style_name 10 | ``` 11 | 12 | All the images in the `input_path` will be stylized and the resulting images will be saved in the `output_path`. If the network was trained on multiple styles, all styles will be applied.
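A trained checkpoint stores the network weights together with the training arguments and the style names, so a model can be inspected before it is used. Below is a minimal sketch of reading those attributes with h5py; the checkpoint path is assumed to be one downloaded by `data/models/download_models.sh`, and the attribute names are the ones `train.py` writes and `fast_style_transfer.py` reads.

```python
import h5py
import yaml

# Checkpoints written by train.py keep the training args (a YAML-dumped
# argparse namespace) and the style names as HDF5 attributes.
with h5py.File('data/models/6_styles.h5', 'r') as f:
    train_args = yaml.load(f.attrs['args'])
    print('Styles: %s' % list(f.attrs['style_names']))
    print('Width factor: %d, number of styles: %d'
          % (train_args.width_factor, train_args.nb_classes))
```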
13 | 14 | ## Training a pastiche network 15 | #### Set up coco 16 | You may download [the coco dataset](http://mscoco.org/home/) [5] with the provided script (it may take a while): 17 | ```bash 18 | cd data; sh download_coco.sh 19 | ``` 20 | Afterwards, preprocess it into an hdf5 file. Note that `train.py` expects the file at `data/coco/ms-coco-256.h5` by default, while `make_style_dataset.py` writes to `data/ms-coco-256.h5`, so the output path is set explicitly here: 21 | ```bash 22 | python make_style_dataset.py --output_file data/coco/ms-coco-256.h5 23 | ``` 24 | #### Make the gram dataset 25 | You should then make the gram dataset, which pre-computes the gram matrices for the styles the net will be trained on. At this stage you should specify the size at which each style image is used when computing its gram matrix. For more information on the usage see [this section](#make_gram_dataset). Below is an example: 26 | ```bash 27 | python make_gram_dataset.py --gram_dataset_path candy_grams.h5 --style_dir style_imgs \ 28 | --style_imgs candy.jpg --style_img_size 384 --gpu 0 29 | ``` 30 | #### Train the network 31 | You may then train the network with the `train.py` script. For more usage information see [this section](#train). Below is an example: 32 | ```bash 33 | python train.py --content_weight 1 --style_weight 10 --tv_weight 1e-4 --norm_by_channels \ 34 | --gram_dataset_path candy_grams.h5 --checkpoint_path candy.h5 35 | ``` 36 | 37 | ### Stylized images 38 | Below are some examples of images stylized by trained networks. Each row contains, respectively, the style image, the content image, and images stylized by a single-style net and by a 6-style net. All images in the fourth column were stylized by the same network. 39 | 40 |
41 | [image row: the_scream style, chicago content, 1-style result, 6-style result]
47 | [image row: feathers style, golden_gate content, 1-style result, 6-style result]
53 | [image row: candy style, hoovertowernight content, 1-style result, 6-style result]
58 | 59 | ## Script usage 60 | ### make_gram_dataset 61 | This script pre-computes the gram matrices of the style images and stores them in an hdf5 file. All gram matrices are computed without channel normalization and will be normalized before training if necessary. The options are: 62 | * `--style_dir [gram_imgs]`: directory that contains the style images. 63 | * `--gram_dataset_path [grams.h5]`: where to save the output file. 64 | * `--style_imgs [None]`: image file names that will be used; can be a list. If `None`, all images in the directory will be used. 65 | * `--style_img_size [None]`: largest size of the image; can be a single size that will be applied to all images or a list with a size for each image. If `None`, the image will not be resized. 66 | * `--gpu`: which gpu to run the code on. If `-1`, the code will run on cpu. 67 | * `--allow_growth`: flag that stops tensorflow from allocating all gpu memory at the start of the session. 68 | 69 | ### train 70 | Script that trains the pastiche network. You should have preprocessed the coco dataset and created a gram dataset before using it. Below are all the options: 71 | * `--lr [0.001]`: learning rate used by the Adam optimizer to update the network weights. 72 | * `--content_weight [1.]`: weight of the content loss. Can be a single value that will be used for all styles or a list with a different weight for each style. 73 | * `--style_weight [1e-4]`: weight of the style loss. Can be a single value that will be used for all styles or a list with a different weight for each style. Note that the default value is not reasonable if gram matrices are normalized by channels; in that case, it should be around 10. 74 | * `--tv_weight [1e-4]`: weight of the total variation loss, which is used to improve the local coherence of the images. 75 | * `--width_factor [2]`: how wide the convolutional layers are. This is a multiplicative factor. Use 2 for the original implementation of [2] and if reproducing [3]. Use 1 if reproducing results from [2] with instance normalization. 76 | * `--nb_classes [1]`: the number of styles the network will learn. 77 | * `--norm_by_channels`: flag that sets whether the gram matrices will be normalized by the number of channels. Use this to reproduce [2]; do not use it to reproduce [3]. 78 | * `--num_iterations [40000]`: how many iterations the model should be trained for. 79 | * `--save_every [500]`: how often the model will be saved to the checkpoint. 80 | * `--batch_size [4]`: batch size that will be used during training. 81 | * `--coco_path [data/coco/ms-coco-256.h5]`: path to the coco dataset hdf5 file. 82 | * `--gram_dataset_path [grams.h5]`: path to the gram dataset. 83 | * `--checkpoint_path [checkpoint.h5]`: where to save the checkpoints. The checkpoint includes the network weights as well as the training args. 84 | * `--gpu`: which gpu to run the code on. If `-1`, the code will run on cpu. It is not recommended to run this on cpu. 85 | * `--allow_growth`: flag that stops tensorflow from allocating all gpu memory at the start of the session. 86 | 87 | 88 | ## Iterative style transfer 89 | You can stylize an image by iteratively updating the input image so as to minimize the loss.
90 | ```bash 91 | python iterative_style_transfer.py \ 92 | --content_image_path content_imgs/tubingen.jpg \ 93 | --style_image_path style_imgs/starry_night.jpg \ 94 | --output_path starry_tubingen.png 95 | ``` 96 | 97 | The full set of options is: 98 | * `--content_image_path`: path to the image that will be used as content. 99 | * `--style_image_path`: path to the image that will be used as style. Can be a list. 100 | * `--output_path`: where to save the output. 101 | * `--lr [10.]`: changes the learning rate of the Adam optimizer used to update the image. 102 | * `--num_iterations [1000]`: number of iterations performed. 103 | * `--content_weight [1.]`: weight of the content loss. 104 | * `--style_weight [1e-4]`: weight of the style loss. 105 | * `--tv_weight [1e-4]`: weight of the total variation loss. 106 | * `--content_layers`: list of layers used for the content loss. 107 | * `--style_layers`: list of layers used for the style loss. 108 | * `--norm_by_channels`: whether the gram matrices will be normalized by the number of channels. Use this to compute gram matrices as in [2]; do not use it to compute them as in [3]. If using this, `style_weight` should be on the order of `10`. 109 | * `--img_size [512]`: largest size of the generated image; the aspect ratio is determined by the content image used. The larger the size, the longer it takes. 110 | * `--style_img_size [None]`: largest size of the style image when computing the gram matrix. Changing this will change how the style is captured. If `None`, the original size is used. 111 | * `--print_and_save [100]`: how often (in iterations) the losses are printed and the current images are saved. 112 | * `--init [random]`: type of initialization: `random` for noise initialization or `content` for initializing with the content image. 113 | * `--std_init [0.001]`: standard deviation for the noise initialization. 114 | * `--gpu`: which gpu to run it on. Use `-1` to run on cpu. It is highly recommended to run this on a gpu. 115 | * `--allow_growth`: flag that stops tensorflow from allocating all gpu memory at the start of the session. 116 | 117 | #### Examples 118 | **Initializing with noise** 119 | 120 | ```bash 121 | python iterative_style_transfer.py \ 122 | --content_image_path content_imgs/tubingen.jpg \ 123 | --style_image_path style_imgs/candy.jpg \ 124 | --output_path candy_tubingen.png \ 125 | --content_weight 5 --style_weight 30 \ 126 | --tv_weight 1e-3 --norm_by_channels \ 127 | --style_img_size 384 --gpu 0 128 | ``` 129 | 130 |
131 | [animation: readme_imgs/candy_tubingen_init_random.gif, iterative stylization from random init] 132 |
133 | 134 | **Initializing with content** 135 | 136 | ```bash 137 | python iterative_style_transfer.py \ 138 | --content_image_path content_imgs/tubingen.jpg \ 139 | --style_image_path style_imgs/candy.jpg \ 140 | --output_path candy_tubingen_content.png \ 141 | --content_weight 5 --style_weight 100 \ 142 | --norm_by_channels --style_img_size 384 \ 143 | --gpu 0 --init content 144 | ``` 145 | 146 |
147 | [animation: readme_imgs/candy_tubingen_init_content.gif, iterative stylization from content init] 148 |
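As a reference for the `--norm_by_channels` flag used above: the two Gram matrix conventions differ only in the normalizing denominator. Below is a minimal numpy sketch of the computation that `training.gram_matrix` performs on Keras tensors; the helper `gram_matrix_np` is our own illustration, not part of the repository.

```python
import numpy as np

def gram_matrix_np(feats, norm_by_channels=False):
    # feats: one activation map of shape (H, W, C) from a VGG layer.
    H, W, C = feats.shape
    F = feats.reshape(H * W, C)
    gram = F.T.dot(F)  # (C, C) matrix of channel co-activations
    # [2] divides by C*H*W, while [3] divides by H*W only. This factor-of-C
    # difference is why --style_weight must be rescaled (around 10 instead
    # of 1e-4) when --norm_by_channels is set.
    denominator = C * H * W if norm_by_channels else H * W
    return gram / float(denominator)
```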
149 | 150 | ## Requirements 151 | * keras 152 | * tensorflow 153 | * h5py 154 | * pyyaml 155 | 156 | ## References 157 | * [1]: L. A. Gatys, A. S. Ecker and M. Bethge. "A Neural Algorithm for Artistic Style". [Arxiv](https://arxiv.org/abs/1508.06576). 158 | * [2]: J. Johnson, A. Alahi and L. Fei-Fei. "Perceptual Losses for Real-Time Style Transfer and Super-Resolution". [Paper](http://cs.stanford.edu/people/jcjohns/papers/eccv16/JohnsonECCV16.pdf) [Github](https://github.com/jcjohnson/fast-neural-style) 159 | * [3]: V. Dumoulin, J. Shlens and M. Kudlur. "A Learned Representation for Artistic Style". [Arxiv](https://arxiv.org/abs/1610.07629) [Github](https://github.com/tensorflow/magenta/tree/master/magenta/models/image_stylization) 160 | * [4]: D. Ulyanov, A. Vedaldi and V. Lempitsky. "Instance Normalization: The Missing Ingredient for Fast Stylization". [Arxiv](https://arxiv.org/abs/1607.08022) 161 | * [5]: T. Lin et al. "Microsoft COCO: Common Objects in Context". [Arxiv](https://arxiv.org/abs/1405.0312) [Website](http://mscoco.org/home/) 162 | -------------------------------------------------------------------------------- /content_imgs/brad_pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/brad_pitt.jpg -------------------------------------------------------------------------------- /content_imgs/chicago.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/chicago.jpg -------------------------------------------------------------------------------- /content_imgs/gandalf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/gandalf.jpg -------------------------------------------------------------------------------- /content_imgs/golden_gate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/golden_gate.jpg -------------------------------------------------------------------------------- /content_imgs/hoovertowernight.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/hoovertowernight.jpg -------------------------------------------------------------------------------- /content_imgs/matheus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/matheus.jpg -------------------------------------------------------------------------------- /content_imgs/tubingen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/tubingen.jpg -------------------------------------------------------------------------------- /data/download_coco.sh: -------------------------------------------------------------------------------- 1 | mkdir -p coco/images 2 | cd 
coco/ 3 | wget -c http://msvocds.blob.core.windows.net/coco2014/train2014.zip 4 | unzip train2014.zip -d images/ 5 | wget -c http://msvocds.blob.core.windows.net/coco2014/val2014.zip 6 | unzip val2014.zip -d images/ 7 | -------------------------------------------------------------------------------- /data/models/download_models.sh: -------------------------------------------------------------------------------- 1 | wget -c -O candy.h5 https://www.dropbox.com/s/jj4g1k11hfmuylx/candy.h5?dl=0 2 | wget -c -O feathers.h5 https://www.dropbox.com/s/im79c6ni2a4hl19/feathers.h5?dl=0 3 | wget -c -O muse.h5 https://www.dropbox.com/s/2d45czsn07arfpa/muse.h5?dl=0 4 | wget -c -O scream.h5 https://www.dropbox.com/s/ku8r2h1nz1h8kow/scream.h5?dl=0 5 | wget -c -O udnie.h5 https://www.dropbox.com/s/bmwxg6ny9dg80dm/udnie.h5?dl=0 6 | wget -c -O candy_feathers.h5 https://www.dropbox.com/s/5bjwqgkrumsv20w/candy_feathers.h5?dl=0 7 | wget -c -O 6_styles.h5 https://www.dropbox.com/s/6hwusy6qifcwg8r/6_styles.h5?dl=0 8 | -------------------------------------------------------------------------------- /fast_style_transfer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Use a trained pastiche net to stylize images. 3 | ''' 4 | 5 | from __future__ import print_function 6 | import os 7 | import argparse 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import keras 12 | import keras.backend as K 13 | 14 | from utils import config_gpu, preprocess_image_scale, deprocess_image 15 | import h5py 16 | import yaml 17 | import time 18 | from scipy.misc import imsave 19 | 20 | from model import pastiche_model 21 | 22 | if __name__ == '__main__': 23 | 24 | parser = argparse.ArgumentParser(description='Use a trained pastiche network.') 25 | parser.add_argument('--checkpoint_path', type=str, default='checkpoint') 26 | parser.add_argument('--img_size', type=int, default=1024) 27 | parser.add_argument('--batch_size', type=int, default=4) 28 | parser.add_argument('--input_path', type=str, default='pastiche_input') 29 | parser.add_argument('--output_path', type=str, default='pastiche_output') 30 | parser.add_argument('--use_style_name', default=False, action='store_true') 31 | parser.add_argument('--gpu', type=str, default='') 32 | parser.add_argument('--allow_growth', default=False, action='store_true') 33 | 34 | args = parser.parse_args() 35 | 36 | config_gpu(args.gpu, args.allow_growth) 37 | 38 | # Strip the extension if there is one 39 | checkpoint_path = os.path.splitext(args.checkpoint_path)[0] 40 | 41 | with h5py.File(checkpoint_path + '.h5', 'r') as f: 42 | model_args = yaml.load(f.attrs['args']) 43 | style_names = f.attrs['style_names'] 44 | 45 | print('Creating pastiche model...') 46 | class_targets = K.placeholder(shape=(None,), dtype=tf.int32) 47 | # Instantiate the model using information stored in the yaml file 48 | pastiche_net = pastiche_model(None, width_factor=model_args.width_factor, 49 | nb_classes=model_args.nb_classes, 50 | targets=class_targets) 51 | with h5py.File(checkpoint_path + '.h5', 'r') as f: 52 | pastiche_net.load_weights_from_hdf5_group(f['model_weights']) 53 | 54 | inputs = [pastiche_net.input, class_targets, K.learning_phase()] 55 | 56 | transfer_style = K.function(inputs, [pastiche_net.output]) 57 | 58 | num_batches = int(np.ceil(model_args.nb_classes / float(args.batch_size))) 59 | 60 | for img_name in os.listdir(args.input_path): 61 | print('Processing %s' %img_name) 62 | img = preprocess_image_scale(os.path.join(args.input_path, img_name),
63 | img_size=args.img_size) 64 | imgs = np.repeat(img, model_args.nb_classes, axis=0) 65 | out_name = os.path.splitext(os.path.split(img_name)[-1])[0] 66 | 67 | for batch_idx in range(num_batches): 68 | idx = batch_idx * args.batch_size 69 | 70 | batch = imgs[idx:idx + args.batch_size] 71 | indices = batch_idx * args.batch_size + np.arange(batch.shape[0]) 72 | 73 | if args.use_style_name: 74 | names = style_names[idx:idx + args.batch_size] 75 | else: 76 | names = indices 77 | print(' Processing styles %s' %str(names)) 78 | 79 | out = transfer_style([batch, indices, 0.])[0] 80 | 81 | for name, im in zip(names, out): 82 | print('Saving file %s_style_%s.png' %(out_name, str(name))) 83 | imsave(os.path.join(args.output_path, '%s_style_%s.png' %(out_name, str(name))), 84 | deprocess_image(im[None, :, :, :].copy())) 85 | -------------------------------------------------------------------------------- /iterative_style_transfer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Original neural style transfer algorithm that iteratively updates the input 3 | image in order to minimize the loss. Script inspired by the Keras example 4 | available at 5 | https://github.com/fchollet/keras/blob/master/examples/neural_style_transfer.py 6 | ''' 7 | 8 | import os 9 | import argparse 10 | import time 11 | 12 | import numpy as np 13 | import tensorflow 14 | import keras 15 | import keras.backend as K 16 | from keras.optimizers import Adam 17 | from keras.applications import vgg19 18 | from keras.layers import Input 19 | 20 | from training import get_content_features, get_style_features, get_content_losses, get_style_losses, tv_loss 21 | from utils import config_gpu, preprocess_image_scale, deprocess_image 22 | 23 | from scipy.misc import imsave 24 | 25 | 26 | if __name__ == '__main__': 27 | def_cl = ['block4_conv2'] 28 | def_sl = ['block1_conv1', 'block2_conv1', 29 | 'block3_conv1', 'block4_conv1', 30 | 'block5_conv1'] 31 | 32 | parser = argparse.ArgumentParser(description='Iterative style transfer.') 33 | parser.add_argument('--content_image_path', type=str, 34 | default='content_imgs/tubingen.jpg', 35 | help='Path to the image to transform.') 36 | parser.add_argument('--style_image_path', type=str, 37 | default=['style_imgs/starry_night.jpg'], nargs='+', 38 | help='Path to the style reference image. Can be a list.') 39 | parser.add_argument('--output_path', type=str, default='pastiche_output') 40 | parser.add_argument('--lr', help='Learning rate.', type=float, default=10.) 41 | parser.add_argument('--num_iterations', type=int, default=1000) 42 | parser.add_argument('--content_weight', type=float, default=1.)
43 | parser.add_argument('--style_weight', type=float, default=1e-4) 44 | parser.add_argument('--tv_weight', type=float, default=1e-4) 45 | parser.add_argument('--content_layers', type=str, nargs='+', default=def_cl) 46 | parser.add_argument('--style_layers', type=str, nargs='+', default=def_sl) 47 | parser.add_argument('--norm_by_channels', default=False, action='store_true') 48 | parser.add_argument('--img_size', type=int, default=512, 49 | help='Maximum height/width of generated image.') 50 | parser.add_argument('--style_img_size', type=int, default=None, 51 | help='Maximum height/width of the style images.') 52 | parser.add_argument('--print_and_save', type=int, default=100, 53 | help='Print and save images every chosen number of iterations.') 54 | parser.add_argument('--init', type=str, default='random', 55 | help='How to initialize the pastiche images.') 56 | parser.add_argument('--std_init', type=float, default=0.001, 57 | help='Standard deviation for random init.') 58 | parser.add_argument('--gpu', type=str, default='') 59 | parser.add_argument('--allow_growth', default=False, action='store_true') 60 | 61 | args = parser.parse_args() 62 | # Arguments parsed 63 | 64 | # Split the extension 65 | output_path, ext = os.path.splitext(args.output_path) 66 | if ext == '': 67 | ext = '.png' 68 | config_gpu(args.gpu, args.allow_growth) 69 | 70 | ## Precomputing the targets for content and style 71 | # Load content and style images 72 | content_image = preprocess_image_scale(args.content_image_path, 73 | img_size=args.img_size) 74 | style_images = [preprocess_image_scale(img, img_size=args.style_img_size) 75 | for img in args.style_image_path] 76 | nb_styles = len(style_images) 77 | 78 | model = vgg19.VGG19(weights='imagenet', include_top=False) 79 | outputs_dict = dict([(layer.name, layer.output) for layer in model.layers]) 80 | 81 | content_features = get_content_features(outputs_dict, args.content_layers) 82 | style_features = get_style_features(outputs_dict, args.style_layers, 83 | norm_by_channels=args.norm_by_channels) 84 | 85 | get_content_fun = K.function([model.input], content_features) 86 | get_style_fun = K.function([model.input], style_features) 87 | 88 | content_targets = get_content_fun([content_image]) 89 | # List of list of features 90 | style_targets_list = [get_style_fun([img]) for img in style_images] 91 | 92 | # List of batched features 93 | style_targets = [] 94 | for l in range(len(args.style_layers)): 95 | batched_features = [] 96 | for i in range(nb_styles): 97 | batched_features.append(style_targets_list[i][l][None]) 98 | style_targets.append(np.concatenate(batched_features)) 99 | 100 | if args.init == 'content': 101 | pastiche_image = K.variable(np.repeat(content_image, nb_styles, axis=0)) 102 | else: 103 | if args.init != 'random': 104 | print('Could not recognize init arg \'%s\'. Falling back to random.'
%args.init) 105 | pastiche_image = K.variable(args.std_init*np.random.randn(nb_styles, *content_image.shape[1:])) 106 | 107 | # Store targets as variables 108 | content_targets_dict = {k: K.variable(v) for k, v in zip(args.content_layers, content_targets)} 109 | style_targets_dict = {k: K.variable(v) for k, v in zip(args.style_layers, style_targets)} 110 | 111 | model = vgg19.VGG19(weights='imagenet', include_top=False, input_tensor=Input(tensor=pastiche_image)) 112 | outputs_dict = dict([(layer.name, layer.output) for layer in model.layers]) 113 | 114 | content_losses = get_content_losses(outputs_dict, content_targets_dict, 115 | args.content_layers) 116 | style_losses = get_style_losses(outputs_dict, style_targets_dict, 117 | args.style_layers, 118 | norm_by_channels=args.norm_by_channels) 119 | 120 | # Total variation loss is used to improve local coherence 121 | total_var_loss = tv_loss(pastiche_image) 122 | 123 | # Compute total loss 124 | weighted_style_losses = [] 125 | weighted_content_losses = [] 126 | 127 | total_loss = K.variable(0.) 128 | for loss in style_losses: 129 | weighted_loss = args.style_weight * K.mean(loss) 130 | weighted_style_losses.append(weighted_loss) 131 | total_loss += weighted_loss 132 | for loss in content_losses: 133 | weighted_loss = args.content_weight * K.mean(loss) 134 | weighted_content_losses.append(weighted_loss) 135 | total_loss += weighted_loss 136 | weighted_tv_loss = args.tv_weight * K.mean(total_var_loss) 137 | total_loss += weighted_tv_loss 138 | 139 | opt = Adam(lr=args.lr) 140 | updates = opt.get_updates([pastiche_image], {}, total_loss) 141 | # List of outputs 142 | outputs = [total_loss] + weighted_content_losses + weighted_style_losses + [weighted_tv_loss] 143 | 144 | # Function that makes a step after backpropping to the image 145 | make_step = K.function([], outputs, updates) 146 | 147 | 148 | # Perform optimization steps and save the results 149 | start_time = time.time() 150 | 151 | for i in range(args.num_iterations): 152 | out = make_step([]) 153 | if (i + 1) % args.print_and_save == 0: 154 | print('Iteration %d/%d' %(i + 1, args.num_iterations)) 155 | N = len(content_losses) 156 | for j, l in enumerate(out[1:N+1]): 157 | print(' Content loss %d: %g' %(j, l)) 158 | for j, l in enumerate(out[N+1:-1]): 159 | print(' Style loss %d: %g' %(j, l)) 160 | 161 | print(' Total style loss: %g' %(sum(out[N+1:-1]))) 162 | print(' TV loss: %g' %(out[-1])) 163 | print(' Total loss: %g' %out[0]) 164 | stop_time = time.time() 165 | print('Did %d iterations in %.2fs.' %(args.print_and_save, stop_time - start_time)) 166 | x = K.get_value(pastiche_image) 167 | for s in range(nb_styles): 168 | fname = output_path + '_style%d_%d%s' %(s, (i + 1) / args.print_and_save, ext) 169 | print('Saving image to %s.\n' %fname) 170 | img = deprocess_image(x[s:s+1]) 171 | imsave(fname, img) 172 | start_time = time.time() 173 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Custom Keras layers used on the pastiche model. 3 | ''' 4 | 5 | import tensorflow as tf 6 | import keras 7 | from keras import initializations 8 | from keras.layers import ZeroPadding2D, Layer, InputSpec 9 | 10 | # Extending the ZeroPadding2D layer to do reflection padding instead. 
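# Reflection padding mirrors border pixels instead of zero-filling, which
# avoids the dark edge artifacts that zero padding produces in generated images.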
11 | class ReflectionPadding2D(ZeroPadding2D): 12 | def call(self, x, mask=None): 13 | pattern = [[0, 0], 14 | [self.top_pad, self.bottom_pad], 15 | [self.left_pad, self.right_pad], 16 | [0, 0]] 17 | return tf.pad(x, pattern, mode='REFLECT') 18 | 19 | 20 | class InstanceNormalization(Layer): 21 | def __init__(self, epsilon=1e-5, weights=None, 22 | beta_init='zero', gamma_init='one', **kwargs): 23 | self.beta_init = initializations.get(beta_init) 24 | self.gamma_init = initializations.get(gamma_init) 25 | self.epsilon = epsilon 26 | super(InstanceNormalization, self).__init__(**kwargs) 27 | 28 | def build(self, input_shape): 29 | # This currently only works for 4D inputs: assuming (B, H, W, C) 30 | self.input_spec = [InputSpec(shape=input_shape)] 31 | shape = (1, 1, 1, input_shape[-1]) 32 | 33 | self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name)) 34 | self.beta = self.beta_init(shape, name='{}_beta'.format(self.name)) 35 | self.trainable_weights = [self.gamma, self.beta] 36 | 37 | self.built = True 38 | 39 | def call(self, x, mask=None): 40 | # Normalize over the spatial axes only; do not reduce over the batch axis 41 | reduction_axes = [1, 2] 42 | 43 | mean, var = tf.nn.moments(x, reduction_axes, 44 | shift=None, name=None, keep_dims=True) 45 | x_normed = tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.epsilon) 46 | return x_normed 47 | 48 | def get_config(self): 49 | config = {"epsilon": self.epsilon} 50 | base_config = super(InstanceNormalization, self).get_config() 51 | return dict(list(base_config.items()) + list(config.items())) 52 | 53 | 54 | class ConditionalInstanceNormalization(InstanceNormalization): 55 | def __init__(self, targets, nb_classes, **kwargs): 56 | self.targets = targets 57 | self.nb_classes = nb_classes 58 | super(ConditionalInstanceNormalization, self).__init__(**kwargs) 59 | 60 | def build(self, input_shape): 61 | # This currently only works for 4D inputs: assuming (B, H, W, C) 62 | self.input_spec = [InputSpec(shape=input_shape)] 63 | shape = (self.nb_classes, 1, 1, input_shape[-1]) 64 | 65 | self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name)) 66 | self.beta = self.beta_init(shape, name='{}_beta'.format(self.name)) 67 | self.trainable_weights = [self.gamma, self.beta] 68 | 69 | self.built = True 70 | 71 | def call(self, x, mask=None): 72 | # Normalize over the spatial axes only; do not reduce over the batch axis 73 | reduction_axes = [1, 2] 74 | 75 | mean, var = tf.nn.moments(x, reduction_axes, 76 | shift=None, name=None, keep_dims=True) 77 | 78 | # Get the appropriate lines of gamma and beta 79 | beta = tf.gather(self.beta, self.targets) 80 | gamma = tf.gather(self.gamma, self.targets) 81 | x_normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, self.epsilon) 82 | 83 | return x_normed 84 | -------------------------------------------------------------------------------- /make_gram_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script makes an hdf5 Gram matrix dataset from all images in a chosen directory. 3 | Gram matrices computed here are never normalized by the number of channels. 4 | Normalization is done if necessary at the training stage.
5 | ''' 6 | import numpy as np 7 | import h5py 8 | 9 | import keras 10 | import keras.backend as K 11 | from keras.applications import vgg16 12 | 13 | from training import get_style_features 14 | from utils import preprocess_image_scale, config_gpu, std_input_list 15 | 16 | import os 17 | import argparse 18 | 19 | if __name__ == "__main__": 20 | 21 | def_sl = ['block1_conv2', 'block2_conv2', 22 | 'block3_conv3', 'block4_conv3'] 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--style_dir', type=str, default='gram_imgs', 26 | help='Directory that contains the images.') 27 | parser.add_argument('--gram_dataset_path', type=str, default='grams.h5', 28 | help='Name of the output hdf5 file.') 29 | parser.add_argument('--style_imgs', type=str, default=None, nargs='+', 30 | help='Style image file names.') 31 | parser.add_argument('--style_img_size', type=int, default=[None], nargs='+', 32 | help='Largest size of the style images') 33 | parser.add_argument('--style_layers', type=str, nargs='+', default=def_sl) 34 | parser.add_argument('--gpu', type=str, default='') 35 | parser.add_argument('--allow_growth', default=False, action='store_true') 36 | args = parser.parse_args() 37 | 38 | config_gpu(args.gpu, args.allow_growth) 39 | 40 | loss_net = vgg16.VGG16(weights='imagenet', include_top=False) 41 | 42 | targets_dict = dict([(layer.name, layer.output) for layer in loss_net.layers]) 43 | 44 | s_targets = get_style_features(targets_dict, args.style_layers) 45 | 46 | get_style_target = K.function([loss_net.input], s_targets) 47 | gm_lists = [[] for l in args.style_layers] 48 | 49 | img_list = [] 50 | img_size_list = [] 51 | # Get style image names or get all images in the directory 52 | if args.style_imgs is None: 53 | args.style_imgs = os.listdir(args.style_dir) 54 | 55 | # Check the image sizes 56 | args.style_img_size = std_input_list(args.style_img_size, len(args.style_imgs), 'Image size') 57 | 58 | for img_name, img_size in zip(args.style_imgs, args.style_img_size): 59 | try: 60 | print(img_name) 61 | img = preprocess_image_scale(os.path.join(args.style_dir, img_name), 62 | img_size=img_size) 63 | s_targets = get_style_target([img]) 64 | for l, t in zip(gm_lists, s_targets): 65 | l.append(t) 66 | img_list.append(os.path.splitext(img_name)[0]) 67 | img_size_list.append(img_size) 68 | except IOError as e: 69 | print('Could not open file %s as image.' %img_name) 70 | 71 | mtx = [] 72 | for l in gm_lists: 73 | mtx.append(np.concatenate(l)) 74 | 75 | f = h5py.File(args.gram_dataset_path, 'w') 76 | 77 | f.attrs['img_names'] = img_list 78 | f.attrs['img_sizes'] = img_size_list 79 | for name, m in zip(args.style_layers, mtx): 80 | f.create_dataset(name, data=m) 81 | 82 | f.flush() 83 | f.close() 84 | -------------------------------------------------------------------------------- /make_style_dataset.py: -------------------------------------------------------------------------------- 1 | import os, json, argparse 2 | from threading import Thread 3 | from Queue import Queue 4 | 5 | import numpy as np 6 | from scipy.misc import imread, imresize 7 | import h5py 8 | 9 | """ 10 | Create an HDF5 file of images for training a feedforward style transfer model. 
11 | Original file created by Justin Johnson available at: 12 | https://github.com/jcjohnson/fast-neural-style 13 | """ 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--train_dir', default='data/coco/images/train2014') 17 | parser.add_argument('--val_dir', default='data/coco/images/val2014') 18 | parser.add_argument('--output_file', default='data/ms-coco-256.h5') 19 | parser.add_argument('--height', type=int, default=256) 20 | parser.add_argument('--width', type=int, default=256) 21 | parser.add_argument('--max_images', type=int, default=-1) 22 | parser.add_argument('--num_workers', type=int, default=2) 23 | parser.add_argument('--include_val', type=int, default=1) 24 | parser.add_argument('--max_resize', default=16, type=int) 25 | args = parser.parse_args() 26 | 27 | 28 | def add_data(h5_file, image_dir, prefix, args): 29 | # Make a list of all images in the source directory 30 | image_list = [] 31 | image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG'} 32 | for filename in os.listdir(image_dir): 33 | ext = os.path.splitext(filename)[1] 34 | if ext in image_extensions: 35 | image_list.append(os.path.join(image_dir, filename)) 36 | num_images = len(image_list) 37 | 38 | # Resize all images and copy them into the hdf5 file 39 | # We'll bravely try multithreading 40 | dset_name = os.path.join(prefix, 'images') 41 | # dset_size = (num_images, 3, args.height, args.width) 42 | dset_size = (num_images, args.height, args.width, 3) 43 | imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8) 44 | 45 | # input_queue stores (idx, filename) tuples, 46 | # output_queue stores (idx, resized_img) tuples 47 | input_queue = Queue() 48 | output_queue = Queue() 49 | 50 | # Read workers pull images off disk and resize them 51 | def read_worker(): 52 | while True: 53 | idx, filename = input_queue.get() 54 | img = imread(filename) 55 | try: 56 | # First crop the image so its size is a multiple of max_resize 57 | H, W = img.shape[0], img.shape[1] 58 | H_crop = H - H % args.max_resize 59 | W_crop = W - W % args.max_resize 60 | img = img[:H_crop, :W_crop] 61 | img = imresize(img, (args.height, args.width)) 62 | except (ValueError, IndexError) as e: 63 | print filename 64 | print img.shape, img.dtype 65 | print e 66 | input_queue.task_done() 67 | output_queue.put((idx, img)) 68 | 69 | # Write workers write resized images to the hdf5 file 70 | def write_worker(): 71 | num_written = 0 72 | while True: 73 | idx, img = output_queue.get() 74 | if img.ndim == 3: 75 | # RGB image, transpose from H x W x C to C x H x W 76 | # DO NOT TRANSPOSE 77 | imgs_dset[idx] = img 78 | elif img.ndim == 2: 79 | # Grayscale image; it is H x W so broadcasting to C x H x W will just copy 80 | # grayscale values into all channels. 81 | # COPY GRAY SCALE TO CHANNELS DIMENSION 82 | img_dtype = img.dtype 83 | imgs_dset[idx] = (img[:, :, None] * np.array([1, 1, 1])).astype(img_dtype) 84 | output_queue.task_done() 85 | num_written = num_written + 1 86 | if num_written % 100 == 0: 87 | print 'Copied %d / %d images' % (num_written, num_images) 88 | 89 | # Start the read workers. 
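# The read workers below are daemon threads: once both queues have been
# drained and join() returns in the main thread, the program simply exits.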
90 | for i in xrange(args.num_workers): 91 | t = Thread(target=read_worker) 92 | t.daemon = True 93 | t.start() 94 | 95 | # h5py locks internally, so we can only use a single write worker =( 96 | t = Thread(target=write_worker) 97 | t.daemon = True 98 | t.start() 99 | 100 | for idx, filename in enumerate(image_list): 101 | if args.max_images > 0 and idx >= args.max_images: break 102 | input_queue.put((idx, filename)) 103 | 104 | input_queue.join() 105 | output_queue.join() 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | 111 | with h5py.File(args.output_file, 'w') as f: 112 | add_data(f, args.train_dir, 'train2014', args) 113 | 114 | if args.include_val != 0: 115 | add_data(f, args.val_dir, 'val2014', args) 116 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This module contains functions for building the pastiche model. 3 | ''' 4 | 5 | import keras 6 | from keras.models import Model 7 | from keras.layers import (Convolution2D, Activation, UpSampling2D, 8 | ZeroPadding2D, Input, BatchNormalization, 9 | merge, Lambda) 10 | from layers import (ReflectionPadding2D, InstanceNormalization, 11 | ConditionalInstanceNormalization) 12 | from keras.initializations import normal 13 | 14 | # Initialize weights with normal distribution with std 0.01 15 | def weights_init(shape, name=None, dim_ordering=None): 16 | return normal(shape, scale=0.01, name=name) 17 | 18 | 19 | def conv(x, n_filters, kernel_size=3, stride=1, relu=True, nb_classes=1, targets=None): 20 | ''' 21 | Reflection padding, convolution, instance normalization and (maybe) relu. 22 | ''' 23 | if not kernel_size % 2: 24 | raise ValueError('Expected odd kernel size.') 25 | pad = (kernel_size - 1) / 2 26 | o = ReflectionPadding2D(padding=(pad, pad))(x) 27 | o = Convolution2D(n_filters, kernel_size, kernel_size, 28 | subsample=(stride, stride), init=weights_init)(o) 29 | # o = BatchNormalization()(o) 30 | if nb_classes > 1: 31 | o = ConditionalInstanceNormalization(targets, nb_classes)(o) 32 | else: 33 | o = InstanceNormalization()(o) 34 | if relu: 35 | o = Activation('relu')(o) 36 | return o 37 | 38 | 39 | def residual_block(x, n_filters, nb_classes=1, targets=None): 40 | ''' 41 | Residual block with 2 3x3 convolutions blocks. Last one is linear (no ReLU). 42 | ''' 43 | o = conv(x, n_filters) 44 | # Linear activation on second conv 45 | o = conv(o, n_filters, relu=False, nb_classes=nb_classes, targets=targets) 46 | # Shortcut connection 47 | o = merge([o, x], mode='sum') 48 | return o 49 | 50 | 51 | def upsampling(x, n_filters, nb_classes=1, targets=None): 52 | ''' 53 | Upsampling block with nearest-neighbor interpolation and a conv block. 
54 | ''' 55 | o = UpSampling2D()(x) 56 | o = conv(o, n_filters, nb_classes=nb_classes, targets=targets) 57 | return o 58 | 59 | 60 | def pastiche_model(img_size, width_factor=2, nb_classes=1, targets=None): 61 | k = width_factor 62 | x = Input(shape=(img_size, img_size, 3)) 63 | o = conv(x, 16 * k, kernel_size=9, nb_classes=nb_classes, targets=targets) 64 | o = conv(o, 32 * k, stride=2, nb_classes=nb_classes, targets=targets) 65 | o = conv(o, 64 * k, stride=2, nb_classes=nb_classes, targets=targets) 66 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets) 67 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets) 68 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets) 69 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets) 70 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets) 71 | o = upsampling(o, 32 * k, nb_classes=nb_classes, targets=targets) 72 | o = upsampling(o, 16 * k, nb_classes=nb_classes, targets=targets) 73 | o = conv(o, 3, kernel_size=9, relu=False, nb_classes=nb_classes, targets=targets) 74 | o = Activation('tanh')(o) 75 | o = Lambda(lambda x: 150*x, name='scaling')(o) 76 | pastiche_net = Model(input=x, output=o) 77 | return pastiche_net 78 | -------------------------------------------------------------------------------- /readme_imgs/1style/chicago_style_the_scream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/1style/chicago_style_the_scream.png -------------------------------------------------------------------------------- /readme_imgs/1style/golden_gate_style_feathers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/1style/golden_gate_style_feathers.png -------------------------------------------------------------------------------- /readme_imgs/1style/hoovertowernight_style_candy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/1style/hoovertowernight_style_candy.png -------------------------------------------------------------------------------- /readme_imgs/6styles/chicago_style_the_scream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/6styles/chicago_style_the_scream.png -------------------------------------------------------------------------------- /readme_imgs/6styles/golden_gate_style_feathers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/6styles/golden_gate_style_feathers.png -------------------------------------------------------------------------------- /readme_imgs/6styles/hoovertowernight_style_candy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/6styles/hoovertowernight_style_candy.png 
-------------------------------------------------------------------------------- /readme_imgs/candy_tubingen_init_content.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/candy_tubingen_init_content.gif -------------------------------------------------------------------------------- /readme_imgs/candy_tubingen_init_random.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/candy_tubingen_init_random.gif -------------------------------------------------------------------------------- /style_imgs/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/candy.jpg -------------------------------------------------------------------------------- /style_imgs/composition_vii.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/composition_vii.jpg -------------------------------------------------------------------------------- /style_imgs/escher_sphere.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/escher_sphere.jpg -------------------------------------------------------------------------------- /style_imgs/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/feathers.jpg -------------------------------------------------------------------------------- /style_imgs/frida_kahlo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/frida_kahlo.jpg -------------------------------------------------------------------------------- /style_imgs/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/la_muse.jpg -------------------------------------------------------------------------------- /style_imgs/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/mosaic.jpg -------------------------------------------------------------------------------- /style_imgs/picasso_selfport1907.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/picasso_selfport1907.jpg -------------------------------------------------------------------------------- /style_imgs/seated-nude.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/seated-nude.jpg -------------------------------------------------------------------------------- /style_imgs/shipwreck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/shipwreck.jpg -------------------------------------------------------------------------------- /style_imgs/starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/starry_night.jpg -------------------------------------------------------------------------------- /style_imgs/the_scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/the_scream.jpg -------------------------------------------------------------------------------- /style_imgs/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/udnie.jpg -------------------------------------------------------------------------------- /style_imgs/wave_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/wave_crop.jpg -------------------------------------------------------------------------------- /style_imgs/woman-with-hat-matisse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/woman-with-hat-matisse.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script can be used to train a pastiche network. 
3 | ''' 4 | 5 | from __future__ import print_function 6 | import os 7 | import argparse 8 | 9 | import time 10 | import h5py 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | import keras 15 | import keras.backend as K 16 | from keras.optimizers import Adam 17 | from model import pastiche_model 18 | from training import get_loss_net, get_content_losses, get_style_losses, tv_loss 19 | from utils import preprocess_input, config_gpu, save_checkpoint, std_input_list 20 | 21 | if __name__ == '__main__': 22 | def_cl = ['block3_conv3'] 23 | def_sl = ['block1_conv2', 'block2_conv2', 24 | 'block3_conv3', 'block4_conv3'] 25 | 26 | # Argument parser 27 | parser = argparse.ArgumentParser(description='Train a pastiche network.') 28 | parser.add_argument('--lr', help='Learning rate.', type=float, default=0.001) 29 | parser.add_argument('--content_weight', type=float, default=[1.], nargs='+') 30 | parser.add_argument('--style_weight', type=float, default=[1e-4], nargs='+') 31 | parser.add_argument('--tv_weight', type=float, default=[1e-4], nargs='+') 32 | parser.add_argument('--content_layers', type=str, nargs='+', default=def_cl) 33 | parser.add_argument('--style_layers', type=str, nargs='+', default=def_sl) 34 | parser.add_argument('--width_factor', type=int, default=2) 35 | parser.add_argument('--nb_classes', type=int, default=1) 36 | parser.add_argument('--norm_by_channels', default=False, action='store_true') 37 | parser.add_argument('--num_iterations', type=int, default=40000) 38 | parser.add_argument('--save_every', type=int, default=500) 39 | parser.add_argument('--batch_size', type=int, default=4) 40 | parser.add_argument('--coco_path', type=str, default='data/coco/ms-coco-256.h5') 41 | parser.add_argument('--gram_dataset_path', type=str, default='grams.h5') 42 | parser.add_argument('--checkpoint_path', type=str, default='checkpoint.h5') 43 | parser.add_argument('--gpu', type=str, default='') 44 | parser.add_argument('--allow_growth', default=False, action='store_true') 45 | 46 | args = parser.parse_args() 47 | # Arguments parsed 48 | 49 | # Check loss weights 50 | args.style_weight = std_input_list(args.style_weight, args.nb_classes, 'Style weight') 51 | args.content_weight = std_input_list(args.content_weight, args.nb_classes, 'Content weight') 52 | args.tv_weight = std_input_list(args.tv_weight, args.nb_classes, 'TV weight') 53 | 54 | config_gpu(args.gpu, args.allow_growth) 55 | 56 | print('Creating pastiche model...') 57 | class_targets = K.placeholder(shape=(None,), dtype=tf.int32) 58 | # The model will be trained with 256 x 256 images of the coco dataset. 
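# (The net is fully convolutional, so a checkpoint trained at 256 x 256 can
# later be rebuilt with img_size=None and applied to images of any size,
# which is what fast_style_transfer.py does.)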
59 | pastiche_net = pastiche_model(256, width_factor=args.width_factor, 60 | nb_classes=args.nb_classes, 61 | targets=class_targets) 62 | x = pastiche_net.input 63 | o = pastiche_net.output 64 | 65 | print('Loading loss network...') 66 | loss_net, outputs_dict, content_targets_dict = get_loss_net(pastiche_net.output, input_tensor=pastiche_net.input) 67 | 68 | # Placeholder sizes 69 | ph_sizes = {k : K.int_shape(content_targets_dict[k])[-1] for k in args.style_layers} 70 | 71 | # Our style targets are precomputed and are fed through these placeholders 72 | style_targets_dict = {k : K.placeholder(shape=(None, ph_sizes[k], ph_sizes[k])) for k in args.style_layers} 73 | 74 | 75 | print('Setting up training...') 76 | # Setup the loss weights as variables 77 | content_weights = K.variable(args.content_weight) 78 | style_weights = K.variable(args.style_weight) 79 | tv_weights = K.variable(args.tv_weight) 80 | 81 | style_losses = get_style_losses(outputs_dict, style_targets_dict, args.style_layers, 82 | norm_by_channels=args.norm_by_channels) 83 | 84 | content_losses = get_content_losses(outputs_dict, content_targets_dict, args.content_layers) 85 | 86 | # Use total variation to improve local coherence 87 | total_var_loss = tv_loss(pastiche_net.output) 88 | 89 | 90 | weighted_style_losses = [] 91 | weighted_content_losses = [] 92 | 93 | # Compute total loss 94 | total_loss = K.variable(0.) 95 | for loss in style_losses: 96 | weighted_loss = K.mean(K.gather(style_weights, class_targets) * loss) 97 | weighted_style_losses.append(weighted_loss) 98 | total_loss += weighted_loss 99 | 100 | for loss in content_losses: 101 | weighted_loss = K.mean(K.gather(content_weights, class_targets) * loss) 102 | weighted_content_losses.append(weighted_loss) 103 | total_loss += weighted_loss 104 | 105 | weighted_tv_loss = K.mean(K.gather(tv_weights, class_targets) * total_var_loss) 106 | total_loss += weighted_tv_loss 107 | 108 | 109 | ## Make training function 110 | 111 | # Get a list of inputs 112 | inputs = [pastiche_net.input, class_targets] + \ 113 | [style_targets_dict[k] for k in args.style_layers] + \ 114 | [K.learning_phase()] 115 | 116 | # Get trainable params 117 | params = pastiche_net.trainable_weights 118 | constraints = pastiche_net.constraints 119 | 120 | opt = Adam(lr=args.lr) 121 | updates = opt.get_updates(params, constraints, total_loss) 122 | 123 | # List of outputs 124 | outputs = [total_loss] + weighted_content_losses + weighted_style_losses + [weighted_tv_loss] 125 | 126 | f_train = K.function(inputs, outputs, updates) 127 | 128 | X = h5py.File(args.coco_path, 'r')['train2014']['images'] 129 | dataset_size = X.shape[0] 130 | batches_per_epoch = int(np.ceil(dataset_size / args.batch_size)) 131 | batch_idx = 0 132 | 133 | print('Loading Gram matrices from dataset file...') 134 | if args.norm_by_channels: 135 | print('Normalizing the stored Gram matrices by the number of channels.') 136 | Y = {} 137 | with h5py.File(args.gram_dataset_path, 'r') as f: 138 | styles = f.attrs['img_names'] 139 | style_sizes = f.attrs['img_sizes'] 140 | for k, v in f.iteritems(): 141 | Y[k] = np.array(v) 142 | if args.norm_by_channels: 143 | #Correct the Gram matrices from the dataset 144 | Y[k] /= Y[k].shape[-1] 145 | 146 | # Get a log going 147 | log = {} 148 | log['args'] = args 149 | log['style_names'] = styles[:args.nb_classes] 150 | log['style_image_sizes'] = style_sizes 151 | log['total_loss'] = [] 152 | log['style_loss'] = {k: [] for k in args.style_layers} 153 | log['content_loss'] = {k: [] for k in 
/training.py:
--------------------------------------------------------------------------------
'''
Module that defines loss functions and other auxiliary functions used when
training a pastiche model.
'''
import keras.backend as K
from keras.applications import vgg16


def gram_matrix(x, norm_by_channels=False):
    '''
    Returns the Gram matrix of the tensor x.
    '''
    if K.ndim(x) == 3:
        features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
        shape = K.shape(x)
        H, W, C = shape[0], shape[1], shape[2]
        gram = K.dot(features, K.transpose(features))
    elif K.ndim(x) == 4:
        # Swap from (B, H, W, C) to (B, C, H, W)
        x = K.permute_dimensions(x, (0, 3, 1, 2))
        shape = K.shape(x)
        B, C, H, W = shape[0], shape[1], shape[2], shape[3]
        # Reshape as a batch of 2D matrices with vectorized channels
        features = K.reshape(x, K.stack([B, C, H * W]))
        # This is a batch of Gram matrices (B, C, C).
        gram = K.batch_dot(features, features, axes=2)
    else:
        raise ValueError('The input tensor should be either a 3d (H, W, C) '
                         'or 4d (B, H, W, C) tensor.')
    # Normalize the Gram matrix
    if norm_by_channels:
        denominator = C * H * W  # Normalization from Johnson
    else:
        denominator = H * W  # Normalization from Google
    gram = gram / K.cast(denominator, x.dtype)

    return gram
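# Shape sketch for gram_matrix, with illustrative numbers: a batch of VGG
# activations of shape (4, 64, 64, 128) produces Gram matrices of shape
# (4, 128, 128). With norm_by_channels=True each entry is divided by
# C * H * W = 128 * 64 * 64; otherwise only by H * W = 64 * 64.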
def content_loss(x, target):
    '''
    Content loss is simply the MSE between activations of a layer.
    '''
    return K.mean(K.square(target - x), axis=(1, 2, 3))


def style_loss(x, target, norm_by_channels=False):
    '''
    Style loss is the MSE between Gram matrices computed from activation maps.
    '''
    x_gram = gram_matrix(x, norm_by_channels=norm_by_channels)
    return K.mean(K.square(target - x_gram), axis=(1, 2))


def tv_loss(x):
    '''
    Total variation loss is used to keep the image locally coherent.
    '''
    assert K.ndim(x) == 4
    a = K.square(x[:, :-1, :-1, :] - x[:, 1:, :-1, :])
    b = K.square(x[:, :-1, :-1, :] - x[:, :-1, 1:, :])
    return K.sum(a + b, axis=(1, 2, 3))


def get_content_features(out_dict, layer_names):
    return [out_dict[l] for l in layer_names]


def get_style_features(out_dict, layer_names, norm_by_channels=False):
    features = []
    for l in layer_names:
        layer_features = out_dict[l]
        S = gram_matrix(layer_features, norm_by_channels=norm_by_channels)
        features.append(S)
    return features


def get_loss_net(pastiche_net_output, input_tensor=None):
    '''
    Instantiates a VGG net and applies its layers on top of the pastiche net's
    output.
    '''
    loss_net = vgg16.VGG16(weights='imagenet', include_top=False,
                           input_tensor=input_tensor)
    # Activations of the original input serve as content targets
    targets_dict = dict([(layer.name, layer.output) for layer in loss_net.layers])
    i = pastiche_net_output
    # We need to apply all layers to the output of the style net
    outputs_dict = {}
    for l in loss_net.layers[1:]:  # Ignore the input layer
        i = l(i)
        outputs_dict[l.name] = i

    return loss_net, outputs_dict, targets_dict


def get_style_losses(outputs_dict, targets_dict, style_layers,
                     norm_by_channels=False):
    '''
    Returns the style loss for the desired layers.
    '''
    return [style_loss(outputs_dict[l], targets_dict[l],
                       norm_by_channels=norm_by_channels)
            for l in style_layers]


def get_content_losses(outputs_dict, targets_dict, content_layers):
    return [content_loss(outputs_dict[l], targets_dict[l])
            for l in content_layers]


def get_total_loss(content_losses, style_losses, total_var_loss,
                   content_weights, style_weights, tv_weights, class_targets):
    total_loss = K.variable(0.)

    # Accumulate the individually weighted losses
    weighted_content_losses = []
    weighted_style_losses = []

    # Compute content losses
    for loss in content_losses:
        weighted_loss = K.mean(K.gather(content_weights, class_targets) * loss)
        weighted_content_losses.append(weighted_loss)
        total_loss += weighted_loss

    # Compute style losses
    for loss in style_losses:
        weighted_loss = K.mean(K.gather(style_weights, class_targets) * loss)
        weighted_style_losses.append(weighted_loss)
        total_loss += weighted_loss

    # Compute the TV loss
    weighted_tv_loss = K.mean(K.gather(tv_weights, class_targets) *
                              total_var_loss)
    total_loss += weighted_tv_loss

    return (total_loss, weighted_content_losses, weighted_style_losses,
            weighted_tv_loss)
--------------------------------------------------------------------------------
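As a sanity check, the batched Gram computation in `gram_matrix` can be reproduced in a few lines of NumPy (`gram_matrix_np` is an illustrative name, not a function in the repo):

```python
import numpy as np

def gram_matrix_np(x, norm_by_channels=False):
    '''NumPy reference for the 4D branch of gram_matrix.'''
    B, H, W, C = x.shape
    feats = x.transpose(0, 3, 1, 2).reshape(B, C, H * W)
    gram = np.matmul(feats, feats.transpose(0, 2, 1))  # (B, C, C)
    denom = C * H * W if norm_by_channels else H * W
    return gram / denom

x = np.random.rand(2, 8, 8, 16).astype('float32')
print(gram_matrix_np(x).shape)  # (2, 16, 16)
```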
/utils.py:
--------------------------------------------------------------------------------
'''
Module containing utility functions.
'''
import tensorflow as tf
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
import h5py
import yaml
from PIL import Image
from keras.applications import vgg16
from keras import backend as K


def config_gpu(gpu, allow_growth):
    # Choosing the GPU
    if gpu == '-1':
        config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        if gpu == 'all' or gpu == '':
            gpu = ''
        config = tf.ConfigProto()
        config.gpu_options.visible_device_list = gpu
        if allow_growth:
            config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    K.set_session(session)


def save_checkpoint(checkpoint_path, pastiche_net, log):
    with h5py.File(checkpoint_path + '.h5', 'w') as f:
        g = f.create_group('model_weights')
        pastiche_net.save_weights_to_hdf5_group(g)
        g = f.create_group('log')
        g.create_dataset('total_loss', data=np.array(log['total_loss']))
        g.create_dataset('tv_loss', data=np.array(log['tv_loss']))
        g2 = g.create_group('style_loss')
        for k, v in log['style_loss'].items():
            g2.create_dataset(k, data=v)
        g2 = g.create_group('content_loss')
        for k, v in log['content_loss'].items():
            g2.create_dataset(k, data=v)
        f.attrs['args'] = yaml.dump(log['args'])
        f.attrs['style_names'] = log['style_names']
        f.attrs['style_image_sizes'] = log['style_image_sizes']
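# The checkpoint written by save_checkpoint has the following HDF5 layout:
#   /model_weights              pastiche net weights
#   /log/total_loss             one entry per training iteration
#   /log/tv_loss
#   /log/style_loss/<layer>     one dataset per style layer
#   /log/content_loss/<layer>   one dataset per content layer
# plus the file attributes 'args' (a YAML dump), 'style_names' and
# 'style_image_sizes'.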
def preprocess_input(x):
    return vgg16.preprocess_input(x.astype('float32'))


def preprocess_image_crop(image_path, img_size):
    '''
    Preprocess the image, scaling it so that its smaller side is img_size.
    The larger side is then cropped in order to produce a square image.
    '''
    img = load_img(image_path)
    scale = float(img_size) / min(img.size)
    new_size = (int(np.ceil(scale * img.size[0])), int(np.ceil(scale * img.size[1])))
    img = img.resize(new_size, resample=Image.BILINEAR)
    img = img_to_array(img)
    # Note: the crop is anchored at the bottom-right corner of the image
    crop_h = img.shape[0] - img_size
    crop_v = img.shape[1] - img_size
    img = img[crop_h:img_size + crop_h, crop_v:img_size + crop_v, :]
    img = np.expand_dims(img, axis=0)
    img = vgg16.preprocess_input(img)
    return img


def preprocess_image_scale(image_path, img_size=None):
    '''
    Preprocess the image, scaling it so that its larger side is img_size.
    This function preserves the aspect ratio.
    '''
    img = load_img(image_path)
    if img_size:
        scale = float(img_size) / max(img.size)
        new_size = (int(np.ceil(scale * img.size[0])), int(np.ceil(scale * img.size[1])))
        img = img.resize(new_size, resample=Image.BILINEAR)
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg16.preprocess_input(img)
    return img


def deprocess_image(x):
    '''
    Converts a preprocessed tensor back into a valid uint8 RGB image.
    '''
    x = x[0]
    # Remove the zero-center by the VGG mean pixel
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR' -> 'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype('uint8')
    return x


def std_input_list(input_list, nb_el, name):
    '''
    Expands a single-element list to nb_el elements, or validates its length.
    '''
    if len(input_list) == 1:
        return [input_list[0] for _ in range(nb_el)]
    elif len(input_list) != nb_el:
        raise ValueError('%s list should have length %d, found %d.'
                         % (name, nb_el, len(input_list)))
    return input_list
--------------------------------------------------------------------------------
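A quick round trip through the preprocessing helpers, using a content image shipped with the repo (the output filename is arbitrary):

```python
from PIL import Image

from utils import preprocess_image_scale, deprocess_image

x = preprocess_image_scale('content_imgs/tubingen.jpg', img_size=512)
print(x.shape)  # (1, H, W, 3), BGR and mean-centered

# deprocess_image mutates its input, so pass a copy
img = deprocess_image(x.copy())
Image.fromarray(img).save('tubingen_roundtrip.jpg')
```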