├── .gitignore
├── LICENSE
├── README.md
├── content_imgs
│   ├── brad_pitt.jpg
│   ├── chicago.jpg
│   ├── gandalf.jpg
│   ├── golden_gate.jpg
│   ├── hoovertowernight.jpg
│   ├── matheus.jpg
│   └── tubingen.jpg
├── data
│   ├── download_coco.sh
│   └── models
│       └── download_models.sh
├── fast_style_transfer.py
├── iterative_style_transfer.py
├── layers.py
├── make_gram_dataset.py
├── make_style_dataset.py
├── model.py
├── readme_imgs
│   ├── 1style
│   │   ├── chicago_style_the_scream.png
│   │   ├── golden_gate_style_feathers.png
│   │   └── hoovertowernight_style_candy.png
│   ├── 6styles
│   │   ├── chicago_style_the_scream.png
│   │   ├── golden_gate_style_feathers.png
│   │   └── hoovertowernight_style_candy.png
│   ├── candy_tubingen_init_content.gif
│   └── candy_tubingen_init_random.gif
├── style_imgs
│   ├── candy.jpg
│   ├── composition_vii.jpg
│   ├── escher_sphere.jpg
│   ├── feathers.jpg
│   ├── frida_kahlo.jpg
│   ├── la_muse.jpg
│   ├── mosaic.jpg
│   ├── picasso_selfport1907.jpg
│   ├── seated-nude.jpg
│   ├── shipwreck.jpg
│   ├── starry_night.jpg
│   ├── the_scream.jpg
│   ├── udnie.jpg
│   ├── wave_crop.jpg
│   └── woman-with-hat-matisse.jpg
├── train.py
├── training.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/linux,macos,python
2 |
3 | ### Linux ###
4 | *~
5 |
6 | # temporary files which can be created if a process still has a handle open of a deleted file
7 | .fuse_hidden*
8 |
9 | # KDE directory preferences
10 | .directory
11 |
12 | # Linux trash folder which might appear on any partition or disk
13 | .Trash-*
14 |
15 | # .nfs files are created when an open file is removed but is still being accessed
16 | .nfs*
17 |
18 | ### macOS ###
19 | *.DS_Store
20 | .AppleDouble
21 | .LSOverride
22 |
23 | # Icon must end with two \r
24 | Icon
25 |
26 |
27 | # Thumbnails
28 | ._*
29 |
30 | # Files that might appear in the root of a volume
31 | .DocumentRevisions-V100
32 | .fseventsd
33 | .Spotlight-V100
34 | .TemporaryItems
35 | .Trashes
36 | .VolumeIcon.icns
37 | .com.apple.timemachine.donotpresent
38 |
39 | # Directories potentially created on remote AFP share
40 | .AppleDB
41 | .AppleDesktop
42 | Network Trash Folder
43 | Temporary Items
44 | .apdisk
45 |
46 | ### Python ###
47 | # Byte-compiled / optimized / DLL files
48 | __pycache__/
49 | *.py[cod]
50 | *$py.class
51 |
52 | # C extensions
53 | *.so
54 |
55 | # Distribution / packaging
56 | .Python
57 | env/
58 | build/
59 | develop-eggs/
60 | dist/
61 | downloads/
62 | eggs/
63 | .eggs/
64 | lib/
65 | lib64/
66 | parts/
67 | sdist/
68 | var/
69 | wheels/
70 | *.egg-info/
71 | .installed.cfg
72 | *.egg
73 |
74 | # PyInstaller
75 | # Usually these files are written by a python script from a template
76 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
77 | *.manifest
78 | *.spec
79 |
80 | # Installer logs
81 | pip-log.txt
82 | pip-delete-this-directory.txt
83 |
84 | # Unit test / coverage reports
85 | htmlcov/
86 | .tox/
87 | .coverage
88 | .coverage.*
89 | .cache
90 | nosetests.xml
91 | coverage.xml
92 | *,cover
93 | .hypothesis/
94 |
95 | # Translations
96 | *.mo
97 | *.pot
98 |
99 | # Django stuff:
100 | *.log
101 | local_settings.py
102 |
103 | # Flask stuff:
104 | instance/
105 | .webassets-cache
106 |
107 | # Scrapy stuff:
108 | .scrapy
109 |
110 | # Sphinx documentation
111 | docs/_build/
112 |
113 | # PyBuilder
114 | target/
115 |
116 | # Jupyter Notebook
117 | .ipynb_checkpoints
118 |
119 | # pyenv
120 | .python-version
121 |
122 | # celery beat schedule file
123 | celerybeat-schedule
124 |
125 | # dotenv
126 | .env
127 |
128 | # virtualenv
129 | .venv/
130 | venv/
131 | ENV/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # End of https://www.gitignore.io/api/linux,macos,python
140 | # Ignore the dataset
141 | data/coco
142 | # Ignore .h5 files
143 | *.h5
144 | *.hdf5
145 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Roberto de Moura Estevão Filho
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # neural-style-keras
2 | Implementation of neural style transfer methods using Keras and TensorFlow. This includes the iterative stylization from [1], the fast style transfer from [2], and the multi-style nets from [3]. All models include instance normalization [4].
3 |
4 | ## Fast style transfer
5 | To use a pre-trained net, simply run the script
6 |
7 | ```bash
8 | python fast_style_transfer.py --checkpoint_path data/models/candy.h5 --input_path content_imgs \
9 | --output_path pastiche_output --use_style_name
10 | ```
11 |
12 | All the images in the `input_path` will be stylized and the results saved to the `output_path`. If the network was trained on multiple styles, every style will be applied to each image.
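
For a checkpoint trained on several styles, the same command fans out over every style in the net. A hypothetical run with the 6-style checkpoint fetched by `data/models/download_models.sh` (the path is an assumption; point it at wherever you saved the model, and make sure the output directory already exists):

```bash
python fast_style_transfer.py --checkpoint_path data/models/6_styles.h5 \
    --input_path content_imgs --output_path pastiche_output \
    --img_size 512 --use_style_name --gpu 0
```

With `--use_style_name`, outputs are named `<content>_style_<style_name>.png`; without it, the numeric style index is used instead.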
13 |
14 | ## Training a pastiche network
15 | #### Set up coco
16 | You may download [the COCO dataset](http://mscoco.org/home/) [5] with the provided script (this may take a while)
17 | ```bash
18 | cd data; sh download_coco.sh
19 | ```
20 | Afterwards, preprocess it into an hdf5 file
21 | ```bash
22 | python make_style_dataset.py
23 | ```
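
Note that `make_style_dataset.py` writes to `data/ms-coco-256.h5` by default, while `train.py` expects `data/coco/ms-coco-256.h5`; pass `--output_file` so the two agree (the other flags below simply restate the script's defaults):

```bash
python make_style_dataset.py --train_dir data/coco/images/train2014 \
    --val_dir data/coco/images/val2014 \
    --output_file data/coco/ms-coco-256.h5
```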
24 | #### Make the gram dataset
25 | Next, make the gram dataset. This pre-computes the gram matrices for the styles the net will be trained on. At this stage you specify the size at which each style image is rendered when its gram matrices are computed. For more information on the usage see [this section](#make_gram_dataset). Below is an example:
26 | ```bash
27 | python make_gram_dataset.py --gram_dataset_path candy_grams.h5 --style_dir style_imgs \
28 | --style_imgs candy.jpg --style_img_size 384 --gpu 0
29 | ```
30 | #### Train the network
31 | You may then train the network with the `train.py` script. For more usage information see [this section](#train). Below is an example:
32 | ```bash
33 | python train.py --content_weight 1 --style_weight 10 --tv_weight 1e-4 --norm_by_channels \
34 | --gram_dataset_path candy_grams.h5 --checkpoint_path candy.h5
35 | ```
36 |
37 | ### Stylized images
38 | Below are some examples of images stylized by trained networks. Each row contains, respectively, the style image, the content image, an image stylized by a single-style net, and one stylized by a 6-style net. All images in the fourth column were produced by the same network.
39 |
40 | *(Row 1: `style_imgs/the_scream.jpg`, `content_imgs/chicago.jpg`, `readme_imgs/1style/chicago_style_the_scream.png`, `readme_imgs/6styles/chicago_style_the_scream.png`)*
46 | *(Row 2: `style_imgs/feathers.jpg`, `content_imgs/golden_gate.jpg`, `readme_imgs/1style/golden_gate_style_feathers.png`, `readme_imgs/6styles/golden_gate_style_feathers.png`)*
52 | *(Row 3: `style_imgs/candy.jpg`, `content_imgs/hoovertowernight.jpg`, `readme_imgs/1style/hoovertowernight_style_candy.png`, `readme_imgs/6styles/hoovertowernight_style_candy.png`)*
58 |
59 | ## Script usage
60 | ### make_gram_dataset
61 | This script pre-computes the gram matrices of the style images and stores them in an hdf5 file. All gram matrices are computed without channel normalization and will be normalized before training if necessary. The options are:
62 | * `--style_dir [gram_imgs]`: directory that contains the style images.
63 | * `--gram_dataset_path [grams.h5]`: where to save the output file.
64 | * `--style_imgs [None]`: image file names that will be used. Can be a list. If `None`, all images in the directory will be used.
65 | * `--style_img_size [None]`: largest size of the image. Can be a single size that will be applied to all images or a list with a size for each image. If `None`, the image will not be resized.
66 | * `--gpu`: which gpu to run the code on. If `-1`, the code will run on cpu.
67 | * `--allow_growth`: flag that stops tensorflow from allocating all gpu memory at the start of the session.
68 |
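For reference, what gets stored for each style image is the Gram matrix of each chosen VGG feature map: the correlations between its channels after the spatial dimensions are vectorized. A minimal numpy sketch for a single feature map (the function name is illustrative, and the `H*W` normalization is an assumption; the extra division by the number of channels is applied later by `train.py` when `--norm_by_channels` is set):

```python
import numpy as np

def gram_matrix_np(feats):
    # feats: one VGG feature map with shape (H, W, C)
    H, W, C = feats.shape
    F = feats.reshape(H * W, C)   # vectorize the spatial dimensions
    return F.T @ F / (H * W)      # (C, C) channel correlation matrix
```
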
69 | ### train
70 | Script that trains the pastiche network. You should have preprocessed the coco dataset and built a gram dataset before using it. Below are all the options:
71 | * `--lr [0.001]`: learning rate used by the Adam optimizer to update the network weights.
72 | * `--content_weight [1.]`: weight of the content loss. Can be a single value that will be used for all styles or a list with a different weight for each style.
73 | * `--style_weight [1e-4]`: weight of the style loss. Can be a single value that will be used for all styles or a list with a different weight for each style. Note that the default value is not reasonable if gram matrices are normalized by channels; in that case it should be around 10.
74 | * `--tv_weight [1e-4]`: weight of the total variation loss, which can be used to improve the local coherence of the images.
75 | * `--width_factor [2]`: how wide the convolutional layers are. This is a multiplicative factor. Use 2 to match the original implementation in [2] and when reproducing [3]. Use 1 when reproducing results from [2] with instance normalization.
76 | * `--nb_classes [1]`: the number of styles the network will learn.
77 | * `--norm_by_channels`: flag that sets whether the gram matrices will be normalized by the number of channels. Use this to reproduce [2]; do not use it to reproduce [3].
78 | * `--num_iterations [40000]`: how many iterations the model should be trained for.
79 | * `--save_every [500]`: how often the model will be saved to the checkpoint.
80 | * `--batch_size [4]`: batch size that will be used during training.
81 | * `--coco_path [data/coco/ms-coco-256.h5]`: path to the coco dataset hdf5 file.
82 | * `--gram_dataset_path [grams.h5]`: path to the gram dataset.
83 | * `--checkpoint_path [checkpoint.h5]`: where to save the checkpoints. The checkpoint includes the network weights as well as the training args.
84 | * `--gpu`: which gpu to run the code on. If `-1`, the code will run on cpu. It is not recommended to run this on cpu.
85 | * `--allow_growth`: flag that stops tensorflow from allocating all gpu memory at the start of the session.
86 |
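To train a multi-style net, build a gram dataset from several style images and set `--nb_classes` to match (`train.py` uses the first `--nb_classes` styles stored in the gram file). A hypothetical two-style run, with illustrative file names:

```bash
python make_gram_dataset.py --gram_dataset_path candy_feathers_grams.h5 \
    --style_dir style_imgs --style_imgs candy.jpg feathers.jpg \
    --style_img_size 384 --gpu 0
python train.py --nb_classes 2 --content_weight 1 --style_weight 10 \
    --tv_weight 1e-4 --norm_by_channels \
    --gram_dataset_path candy_feathers_grams.h5 \
    --checkpoint_path candy_feathers.h5
```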
87 |
88 | ## Iterative style transfer
89 | You can stylize an image by iteratively updating the input image so as to minimize the loss.
90 | ```bash
91 | python iterative_style_transfer.py \
92 | --content_image_path content_imgs/tubingen.jpg \
93 | --style_image_path style_imgs/starry_night.jpg \
94 | --output_path starry_tubingen.png
95 | ```
96 |
97 | The full set of options includes:
98 | * `--content_image_path`: path to the image that will be used as content.
99 | * `--style_image_path`: path to the image that will be used as style. Can be a list, in which case one pastiche is generated per style.
100 | * `--output_path`: where to save the output.
101 | * `--lr [10.]`: changes the learning rate of the adam optimizer used to update the image.
102 | * `--num_iterations [1000]`: number of iterations performed.
103 | * `--content_weight [1.]`: weight of the content loss.
104 | * `--style_weight [1e-4]`: weight of the style loss.
105 | * `--tv_weight [1e-4]`: weight of the total variation loss.
106 | * `--content_layers`: list of layers used on the content loss.
107 | * `--style_layers`: list of layers used on the style loss.
108 | * `--norm_by_channels`: whether the gram matrices will be normalized by the number of channels. Use this to compute gram matrices as in [2]; do not use it to compute them as in [3]. If using this, `style_weight` should be on the order of `10`.
109 | * `--img_size [512]`: largest size of the generated image; the aspect ratio is determined by the content image used. The larger the size, the longer it takes.
110 | * `--style_img_size [None]`: largest size of the style image when computing the gram matrix. Changing this will change how the style is captured. If `None`, the original size is used.
111 | * `--print_and_save [100]`: how often, in iterations, losses are printed and the current image is saved.
112 | * `--init [random]`: type of initialization: `random` for noise initialization or `content` for initializing with the content image.
113 | * `--std_init [0.001]`: standard deviation for the noise initialization.
114 | * `--gpu`: which gpu to run it on. Use `-1` to run on cpu. It is strongly recommended to run this on a gpu.
115 | * `--allow_growth`: flag that stops tensorflow from allocating all gpu memory at the start of the session.
116 |
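Passing several images to `--style_image_path` optimizes one pastiche per style in a single run; each result is saved with a `_style<i>_` suffix before the iteration count. A hypothetical two-style invocation:

```bash
python iterative_style_transfer.py \
    --content_image_path content_imgs/tubingen.jpg \
    --style_image_path style_imgs/starry_night.jpg style_imgs/the_scream.jpg \
    --output_path multi_tubingen.png \
    --style_weight 10 --norm_by_channels \
    --style_img_size 384 --gpu 0
```
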
117 | #### Examples
118 | **Initializing with noise**
119 |
120 | ```bash
121 | python iterative_style_transfer.py \
122 | --content_image_path content_imgs/tubingen.jpg \
123 | --style_image_path style_imgs/candy.jpg \
124 | --output_path candy_tubingen.png \
125 | --content_weight 5 --style_weight 30 \
126 | --tv_weight 1e-3 --norm_by_channels \
127 | --style_img_size 384 --gpu 0
128 | ```
129 |
130 | ![Iterative stylization of tubingen with the candy style, starting from noise](readme_imgs/candy_tubingen_init_random.gif)
131 |
132 |
133 |
134 | **Initializing with content**
135 |
136 | ```bash
137 | python iterative_style_transfer.py \
138 | --content_image_path content_imgs/tubingen.jpg \
139 | --style_image_path style_imgs/candy.jpg \
140 | --output_path candy_tubingen_content.png \
141 | --content_weight 5 --style_weight 100 \
142 | --norm_by_channels --style_img_size 384 \
143 | --gpu 0 --init content
144 | ```
145 |
146 | ![Iterative stylization of tubingen with the candy style, starting from the content image](readme_imgs/candy_tubingen_init_content.gif)
147 |
148 |
149 |
150 | ## Requirements
151 | * keras
152 | * tensorflow
153 | * h5py
154 | * pyyaml
155 |
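The scripts also use numpy and `scipy.misc.imread`/`imsave` (which in turn need Pillow). Note that the code targets the Keras 1.x API (`keras.initializations`, `merge`), so a recent Keras 2.x will not work unmodified. A hypothetical environment setup:

```bash
pip install "keras<2" tensorflow h5py pyyaml scipy Pillow
```
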
156 | ## References
157 | * [1]: L. A. Gatys, A. S. Ecker and M. Bethge. "A Neural Algorithm for Artistic Style". [Arxiv](https://arxiv.org/abs/1508.06576).
158 | * [2]: J. Johnson, A. Alahi and L. Fei-Fei. "Perceptual Losses for Real-Time Style Transfer and Super-Resolution". [Paper](http://cs.stanford.edu/people/jcjohns/papers/eccv16/JohnsonECCV16.pdf) [Github](https://github.com/jcjohnson/fast-neural-style)
159 | * [3]: V. Dumoulin, J. Shlens and M. Kudlur. "A Learned Representation for Artistic Style". [Arxiv](https://arxiv.org/abs/1610.07629) [Github](https://github.com/tensorflow/magenta/tree/master/magenta/models/image_stylization)
160 | * [4]: D. Ulyanov, A. Vedaldi and V. Lempitsky. "Instance Normalization: The Missing Ingredient for Fast Stylization". [Arxiv](https://arxiv.org/abs/1607.08022)
161 | * [5]: T. Lin et al. "Microsoft COCO: Common Objects in Context". [Arxiv](https://arxiv.org/abs/1405.0312) [Website](http://mscoco.org/home/)
162 |
--------------------------------------------------------------------------------
/content_imgs/brad_pitt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/brad_pitt.jpg
--------------------------------------------------------------------------------
/content_imgs/chicago.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/chicago.jpg
--------------------------------------------------------------------------------
/content_imgs/gandalf.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/gandalf.jpg
--------------------------------------------------------------------------------
/content_imgs/golden_gate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/golden_gate.jpg
--------------------------------------------------------------------------------
/content_imgs/hoovertowernight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/hoovertowernight.jpg
--------------------------------------------------------------------------------
/content_imgs/matheus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/matheus.jpg
--------------------------------------------------------------------------------
/content_imgs/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/content_imgs/tubingen.jpg
--------------------------------------------------------------------------------
/data/download_coco.sh:
--------------------------------------------------------------------------------
1 | mkdir -p coco/images
2 | cd coco/
3 | wget -c http://msvocds.blob.core.windows.net/coco2014/train2014.zip
4 | unzip train2014.zip -d images/
5 | wget -c http://msvocds.blob.core.windows.net/coco2014/val2014.zip
6 | unzip val2014.zip -d images/
7 |
--------------------------------------------------------------------------------
/data/models/download_models.sh:
--------------------------------------------------------------------------------
1 | wget -c -O candy.h5 https://www.dropbox.com/s/jj4g1k11hfmuylx/candy.h5?dl=0
2 | wget -c -O feathers.h5 https://www.dropbox.com/s/im79c6ni2a4hl19/feathers.h5?dl=0
3 | wget -c -O muse.h5 https://www.dropbox.com/s/2d45czsn07arfpa/muse.h5?dl=0
4 | wget -c -O scream.h5 https://www.dropbox.com/s/ku8r2h1nz1h8kow/scream.h5?dl=0
5 | wget -c -O udnie.h5 https://www.dropbox.com/s/bmwxg6ny9dg80dm/udnie.h5?dl=0
6 | wget -c -O candy_feathers.h5 https://www.dropbox.com/s/5bjwqgkrumsv20w/candy_feathers.h5?dl=0
7 | wget -c -O 6_styles.h5 https://www.dropbox.com/s/6hwusy6qifcwg8r/6_styles.h5?dl=0
8 |
--------------------------------------------------------------------------------
/fast_style_transfer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Use a trained pastiche net to stylize images.
3 | '''
4 |
5 | from __future__ import print_function
6 | import os
7 | import argparse
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | import keras
12 | import keras.backend as K
13 |
14 | from utils import config_gpu, preprocess_image_scale, deprocess_image
15 | import h5py
16 | import yaml
17 | import time
18 | from scipy.misc import imsave
19 |
20 | from model import pastiche_model
21 |
22 | if __name__ == '__main__':
23 |
24 | parser = argparse.ArgumentParser(description='Use a trained pastiche network.')
25 | parser.add_argument('--checkpoint_path', type=str, default='checkpoint')
26 | parser.add_argument('--img_size', type=int, default=1024)
27 | parser.add_argument('--batch_size', type=int, default=4)
28 | parser.add_argument('--input_path', type=str, default='pastiche_input')
29 | parser.add_argument('--output_path', type=str, default='pastiche_output')
30 | parser.add_argument('--use_style_name', default=False, action='store_true')
31 | parser.add_argument('--gpu', type=str, default='')
32 | parser.add_argument('--allow_growth', default=False, action='store_true')
33 |
34 | args = parser.parse_args()
35 |
36 | config_gpu(args.gpu, args.allow_growth)
37 |
38 | # Strip the extension if there is one
39 | checkpoint_path = os.path.splitext(args.checkpoint_path)[0]
40 |
41 | with h5py.File(checkpoint_path + '.h5', 'r') as f:
42 | model_args = yaml.load(f.attrs['args'])
43 | style_names = f.attrs['style_names']
44 |
45 | print('Creating pastiche model...')
46 | class_targets = K.placeholder(shape=(None,), dtype=tf.int32)
47 | # Instantiate the model using information stored in the yaml file
48 | pastiche_net = pastiche_model(None, width_factor=model_args.width_factor,
49 | nb_classes=model_args.nb_classes,
50 | targets=class_targets)
51 | with h5py.File(checkpoint_path + '.h5', 'r') as f:
52 | pastiche_net.load_weights_from_hdf5_group(f['model_weights'])
53 |
54 | inputs = [pastiche_net.input, class_targets, K.learning_phase()]
55 |
56 | transfer_style = K.function(inputs, [pastiche_net.output])
57 |
58 | num_batches = int(np.ceil(model_args.nb_classes / float(args.batch_size)))
59 |
60 | for img_name in os.listdir(args.input_path):
61 | print('Processing %s' %img_name)
62 | img = preprocess_image_scale(os.path.join(args.input_path, img_name),
63 | img_size=args.img_size)
64 | imgs = np.repeat(img, model_args.nb_classes, axis=0)
65 | out_name = os.path.splitext(os.path.split(img_name)[-1])[0]
66 |
67 | for batch_idx in range(num_batches):
68 | idx = batch_idx * args.batch_size
69 |
70 | batch = imgs[idx:idx + args.batch_size]
71 | indices = batch_idx * args.batch_size + np.arange(batch.shape[0])
72 |
73 | if args.use_style_name:
74 | names = style_names[idx:idx + args.batch_size]
75 | else:
76 | names = indices
77 | print(' Processing styles %s' %str(names))
78 |
79 | out = transfer_style([batch, indices, 0.])[0]
80 |
81 | for name, im in zip(names, out):
82 | print('Saving file %s_style_%s.png' %(out_name, str(name)))
83 | imsave(os.path.join(args.output_path, '%s_style_%s.png' %(out_name, str(name))),
84 | deprocess_image(im[None, :, :, :].copy()))
85 |
--------------------------------------------------------------------------------
/iterative_style_transfer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Original neural style transfer algorithm that iteratively updates the input
3 | image in order to minimize the loss. Script inspired by the Keras example
4 | available at
5 | https://github.com/fchollet/keras/blob/master/examples/neural_style_transfer.py
6 | '''
7 |
8 | import os
9 | import argparse
10 | import time
11 |
12 | import numpy as np
13 | import tensorflow
14 | import keras
15 | import keras.backend as K
16 | from keras.optimizers import Adam
17 | from keras.applications import vgg19
18 | from keras.layers import Input
19 |
20 | from training import get_content_features, get_style_features, get_content_losses, get_style_losses, tv_loss
21 | from utils import config_gpu, preprocess_image_scale, deprocess_image
22 |
23 | from scipy.misc import imsave
24 |
25 |
26 | if __name__ == '__main__':
27 | def_cl = ['block4_conv2']
28 | def_sl = ['block1_conv1', 'block2_conv1',
29 | 'block3_conv1', 'block4_conv1',
30 | 'block5_conv1']
31 |
32 | parser = argparse.ArgumentParser(description='Iterative style transfer.')
33 | parser.add_argument('--content_image_path', type=str,
34 | default='content_imgs/tubingen.jpg',
35 | help='Path to the image to transform.')
36 | parser.add_argument('--style_image_path', type=str,
37 | default='style_imgs/starry_night.jpg', nargs='+',
38 | help='Path to the style reference image. Can be a list.')
39 | parser.add_argument('--output_path', type=str, default='pastiche_output')
40 | parser.add_argument('--lr', help='Learning rate.', type=float, default=10.)
41 | parser.add_argument('--num_iterations', type=int, default=1000)
42 | parser.add_argument('--content_weight', type=float, default=1.)
43 | parser.add_argument('--style_weight', type=float, default=1e-4)
44 | parser.add_argument('--tv_weight', type=float, default=1e-4)
45 | parser.add_argument('--content_layers', type=str, nargs='+', default=def_cl)
46 | parser.add_argument('--style_layers', type=str, nargs='+', default=def_sl)
47 | parser.add_argument('--norm_by_channels', default=False, action='store_true')
48 | parser.add_argument('--img_size', type=int, default=512,
49 | help='Maximum height/width of generated image.')
50 | parser.add_argument('--style_img_size', type=int, default=None,
51 | help='Maximum height/width of the style images.')
52 | parser.add_argument('--print_and_save', type=int, default=100,
53 | help='Print and save image every chosen iterations.')
54 | parser.add_argument('--init', type=str, default='random',
55 | help='How to initialize the pastiche images.')
56 | parser.add_argument('--std_init', type=float, default=0.001,
57 | help='Standard deviation for random init.')
58 | parser.add_argument('--gpu', type=str, default='')
59 | parser.add_argument('--allow_growth', default=False, action='store_true')
60 |
61 | args = parser.parse_args()
62 | # Arguments parsed
63 |
64 | # Split the extension
65 | output_path, ext = os.path.splitext(args.output_path)
66 | if ext == '':
67 | ext = '.png'
68 | config_gpu(args.gpu, args.allow_growth)
69 |
70 | ## Precomputing the targets for content and style
71 | # Load content and style images
72 | content_image = preprocess_image_scale(args.content_image_path,
73 | img_size=args.img_size)
74 | style_images = [preprocess_image_scale(img, img_size=args.style_img_size)
75 | for img in args.style_image_path]
76 | nb_styles = len(style_images)
77 |
78 | model = vgg19.VGG19(weights='imagenet', include_top=False)
79 | outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])
80 |
81 | content_features = get_content_features(outputs_dict, args.content_layers)
82 | style_features = get_style_features(outputs_dict, args.style_layers,
83 | norm_by_channels=args.norm_by_channels)
84 |
85 | get_content_fun = K.function([model.input], content_features)
86 | get_style_fun = K.function([model.input], style_features)
87 |
88 | content_targets = get_content_fun([content_image])
89 | # List of list of features
90 | style_targets_list = [get_style_fun([img]) for img in style_images]
91 |
92 | # List of batched features
93 | style_targets = []
94 | for l in range(len(args.style_layers)):
95 | batched_features = []
96 | for i in range(nb_styles):
97 | batched_features.append(style_targets_list[i][l][None])
98 | style_targets.append(np.concatenate(batched_features))
99 |
100 | if args.init == 'content':
101 | pastiche_image = K.variable(np.repeat(content_image, nb_styles, axis=0))
102 | else:
103 | if args.init != 'random':
104 | print('Could not recognize init arg \'%s\'. Falling back to random.' %args.init)
105 | pastiche_image = K.variable(args.std_init*np.random.randn(nb_styles, *content_image.shape[1:]))
106 |
107 | # Store targets as variables
108 | content_targets_dict = {k: K.variable(v) for k, v in zip(args.content_layers, content_targets)}
109 | style_targets_dict = {k: K.variable(v) for k, v in zip(args.style_layers, style_targets)}
110 |
111 | model = vgg19.VGG19(weights='imagenet', include_top=False, input_tensor=Input(tensor=pastiche_image))
112 | outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])
113 |
114 | content_losses = get_content_losses(outputs_dict, content_targets_dict,
115 | args.content_layers)
116 | style_losses = get_style_losses(outputs_dict, style_targets_dict,
117 | args.style_layers,
118 | norm_by_channels=args.norm_by_channels)
119 |
120 | # Total variation loss is used to improve local coherence
121 | total_var_loss = tv_loss(pastiche_image)
122 |
123 | # Compute total loss
124 | weighted_style_losses = []
125 | weighted_content_losses = []
126 |
127 | total_loss = K.variable(0.)
128 | for loss in style_losses:
129 | weighted_loss = args.style_weight * K.mean(loss)
130 | weighted_style_losses.append(weighted_loss)
131 | total_loss += weighted_loss
132 | for loss in content_losses:
133 | weighted_loss = args.content_weight * K.mean(loss)
134 | weighted_content_losses.append(weighted_loss)
135 | total_loss += weighted_loss
136 | weighted_tv_loss = args.tv_weight * K.mean(total_var_loss)
137 | total_loss += weighted_tv_loss
138 |
139 | opt = Adam(lr=args.lr)
140 | updates = opt.get_updates([pastiche_image], {}, total_loss)
141 | # List of outputs
142 | outputs = [total_loss] + weighted_content_losses + weighted_style_losses + [weighted_tv_loss]
143 |
144 | # Function that makes a step after backpropping to the image
145 | make_step = K.function([], outputs, updates)
146 |
147 |
148 | # Perform optimization steps and save the results
149 | start_time = time.time()
150 |
151 | for i in range(args.num_iterations):
152 | out = make_step([])
153 | if (i + 1) % args.print_and_save == 0:
154 | print('Iteration %d/%d' %(i + 1, args.num_iterations))
155 | N = len(content_losses)
156 | for j, l in enumerate(out[1:N+1]):
157 | print(' Content loss %d: %g' %(j, l))
158 | for j, l in enumerate(out[N+1:-1]):
159 | print(' Style loss %d: %g' %(j, l))
160 |
161 | print(' Total style loss: %g' %(sum(out[N+1:-1])))
162 | print(' TV loss: %g' %(out[-1]))
163 | print(' Total loss: %g' %out[0])
164 | stop_time = time.time()
165 | print('Did %d iterations in %.2fs.' %(args.print_and_save, stop_time - start_time))
166 | x = K.get_value(pastiche_image)
167 | for s in range(nb_styles):
168 | fname = output_path + '_style%d_%d%s' %(s, (i + 1) / args.print_and_save, ext)
169 | print('Saving image to %s.\n' %fname)
170 | img = deprocess_image(x[s:s+1])
171 | imsave(fname, img)
172 | start_time = time.time()
173 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | '''
2 | Custom Keras layers used on the pastiche model.
3 | '''
4 |
5 | import tensorflow as tf
6 | import keras
7 | from keras import initializations
8 | from keras.layers import ZeroPadding2D, Layer, InputSpec
9 |
10 | # Extending the ZeroPadding2D layer to do reflection padding instead.
11 | class ReflectionPadding2D(ZeroPadding2D):
12 | def call(self, x, mask=None):
13 | pattern = [[0, 0],
14 | [self.top_pad, self.bottom_pad],
15 | [self.left_pad, self.right_pad],
16 | [0, 0]]
17 | return tf.pad(x, pattern, mode='REFLECT')
18 |
19 |
20 | class InstanceNormalization(Layer):
21 | def __init__(self, epsilon=1e-5, weights=None,
22 | beta_init='zero', gamma_init='one', **kwargs):
23 | self.beta_init = initializations.get(beta_init)
24 | self.gamma_init = initializations.get(gamma_init)
25 | self.epsilon = epsilon
26 | super(InstanceNormalization, self).__init__(**kwargs)
27 |
28 | def build(self, input_shape):
29 | # This currently only works for 4D inputs: assuming (B, H, W, C)
30 | self.input_spec = [InputSpec(shape=input_shape)]
31 | shape = (1, 1, 1, input_shape[-1])
32 |
33 | self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
34 | self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
35 | self.trainable_weights = [self.gamma, self.beta]
36 |
37 | self.built = True
38 |
39 | def call(self, x, mask=None):
40 | # Do not regularize batch axis
41 | reduction_axes = [1, 2]
42 |
43 | mean, var = tf.nn.moments(x, reduction_axes,
44 | shift=None, name=None, keep_dims=True)
45 | x_normed = tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.epsilon)
46 | return x_normed
47 |
48 | def get_config(self):
49 | config = {"epsilon": self.epsilon}
50 | base_config = super(InstanceNormalization, self).get_config()
51 | return dict(list(base_config.items()) + list(config.items()))
52 |
53 |
54 | class ConditionalInstanceNormalization(InstanceNormalization):
55 | def __init__(self, targets, nb_classes, **kwargs):
56 | self.targets = targets
57 | self.nb_classes = nb_classes
58 | super(ConditionalInstanceNormalization, self).__init__(**kwargs)
59 |
60 | def build(self, input_shape):
61 | # This currently only works for 4D inputs: assuming (B, H, W, C)
62 | self.input_spec = [InputSpec(shape=input_shape)]
63 | shape = (self.nb_classes, 1, 1, input_shape[-1])
64 |
65 | self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
66 | self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
67 | self.trainable_weights = [self.gamma, self.beta]
68 |
69 | self.built = True
70 |
71 | def call(self, x, mask=None):
72 | # Do not regularize batch axis
73 | reduction_axes = [1, 2]
74 |
75 | mean, var = tf.nn.moments(x, reduction_axes,
76 | shift=None, name=None, keep_dims=True)
77 |
78 | # Get the appropriate lines of gamma and beta
79 | beta = tf.gather(self.beta, self.targets)
80 | gamma = tf.gather(self.gamma, self.targets)
81 | x_normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, self.epsilon)
82 |
83 | return x_normed
84 |
--------------------------------------------------------------------------------
/make_gram_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | This script makes a hdf5 style dataset with all images in a chosen directory.
3 | Gram matrices computed here are never normalized by the number of channels.
4 | Normalization is done if necessary on the training stage.
5 | '''
6 | import numpy as np
7 | import h5py
8 |
9 | import keras
10 | import keras.backend as K
11 | from keras.applications import vgg16
12 |
13 | from training import get_style_features
14 | from utils import preprocess_image_scale, config_gpu, std_input_list
15 |
16 | import os
17 | import argparse
18 |
19 | if __name__ == "__main__":
20 |
21 | def_sl = ['block1_conv2', 'block2_conv2',
22 | 'block3_conv3', 'block4_conv3']
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--style_dir', type=str, default='gram_imgs',
26 | help='Directory that contains the images.')
27 | parser.add_argument('--gram_dataset_path', type=str, default='grams.h5',
28 | help='Name of the output hdf5 file.')
29 | parser.add_argument('--style_imgs', type=str, default=None, nargs='+',
30 | help='Style image file names.')
31 | parser.add_argument('--style_img_size', type=int, default=[None], nargs='+',
32 | help='Largest size of the style images')
33 | parser.add_argument('--style_layers', type=str, nargs='+', default=def_sl)
34 | parser.add_argument('--gpu', type=str, default='')
35 | parser.add_argument('--allow_growth', default=False, action='store_true')
36 | args = parser.parse_args()
37 |
38 | config_gpu(args.gpu, args.allow_growth)
39 |
40 | loss_net = vgg16.VGG16(weights='imagenet', include_top=False)
41 |
42 | targets_dict = dict([(layer.name, layer.output) for layer in loss_net.layers])
43 |
44 | s_targets = get_style_features(targets_dict, args.style_layers)
45 |
46 | get_style_target = K.function([loss_net.input], s_targets)
47 | gm_lists = [[] for l in args.style_layers]
48 |
49 | img_list = []
50 | img_size_list = []
51 | # Get style image names or get all images in the directory
52 | if args.style_imgs is None:
53 | args.style_imgs = os.listdir(args.style_dir)
54 |
55 | # Check the image sizes
56 | args.style_img_size = std_input_list(args.style_img_size, len(args.style_imgs), 'Image size')
57 |
58 | for img_name, img_size in zip(args.style_imgs, args.style_img_size):
59 | try:
60 | print(img_name)
61 | img = preprocess_image_scale(os.path.join(args.style_dir, img_name),
62 | img_size=img_size)
63 | s_targets = get_style_target([img])
64 | for l, t in zip(gm_lists, s_targets):
65 | l.append(t)
66 | img_list.append(os.path.splitext(img_name)[0])
67 | img_size_list.append(img_size)
68 | except IOError as e:
69 | print('Could not open file %s as image.' %img_name)
70 |
71 | mtx = []
72 | for l in gm_lists:
73 | mtx.append(np.concatenate(l))
74 |
75 | f = h5py.File(args.gram_dataset_path, 'w')
76 |
77 | f.attrs['img_names'] = img_list
78 | f.attrs['img_sizes'] = img_size_list
79 | for name, m in zip(args.style_layers, mtx):
80 | f.create_dataset(name, data=m)
81 |
82 | f.flush()
83 | f.close()
84 |
--------------------------------------------------------------------------------
/make_style_dataset.py:
--------------------------------------------------------------------------------
1 | import os, json, argparse
2 | from threading import Thread
3 | from Queue import Queue
4 |
5 | import numpy as np
6 | from scipy.misc import imread, imresize
7 | import h5py
8 |
9 | """
10 | Create an HDF5 file of images for training a feedforward style transfer model.
11 | Original file created by Justin Johnson available at:
12 | https://github.com/jcjohnson/fast-neural-style
13 | """
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--train_dir', default='data/coco/images/train2014')
17 | parser.add_argument('--val_dir', default='data/coco/images/val2014')
18 | parser.add_argument('--output_file', default='data/ms-coco-256.h5')
19 | parser.add_argument('--height', type=int, default=256)
20 | parser.add_argument('--width', type=int, default=256)
21 | parser.add_argument('--max_images', type=int, default=-1)
22 | parser.add_argument('--num_workers', type=int, default=2)
23 | parser.add_argument('--include_val', type=int, default=1)
24 | parser.add_argument('--max_resize', default=16, type=int)
25 | args = parser.parse_args()
26 |
27 |
28 | def add_data(h5_file, image_dir, prefix, args):
29 | # Make a list of all images in the source directory
30 | image_list = []
31 | image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG'}
32 | for filename in os.listdir(image_dir):
33 | ext = os.path.splitext(filename)[1]
34 | if ext in image_extensions:
35 | image_list.append(os.path.join(image_dir, filename))
36 | num_images = len(image_list)
37 |
38 | # Resize all images and copy them into the hdf5 file
39 | # We'll bravely try multithreading
40 | dset_name = os.path.join(prefix, 'images')
41 | # dset_size = (num_images, 3, args.height, args.width)
42 | dset_size = (num_images, args.height, args.width, 3)
43 | imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8)
44 |
45 | # input_queue stores (idx, filename) tuples,
46 | # output_queue stores (idx, resized_img) tuples
47 | input_queue = Queue()
48 | output_queue = Queue()
49 |
50 | # Read workers pull images off disk and resize them
51 | def read_worker():
52 | while True:
53 | idx, filename = input_queue.get()
54 | img = imread(filename)
55 | try:
56 | # First crop the image so its size is a multiple of max_resize
57 | H, W = img.shape[0], img.shape[1]
58 | H_crop = H - H % args.max_resize
59 | W_crop = W - W % args.max_resize
60 | img = img[:H_crop, :W_crop]
61 | img = imresize(img, (args.height, args.width))
62 | except (ValueError, IndexError) as e:
63 | print filename
64 | print img.shape, img.dtype
65 | print e
66 | input_queue.task_done()
67 | output_queue.put((idx, img))
68 |
69 | # Write workers write resized images to the hdf5 file
70 | def write_worker():
71 | num_written = 0
72 | while True:
73 | idx, img = output_queue.get()
74 | if img.ndim == 3:
75 | # RGB image, transpose from H x W x C to C x H x W
76 | # DO NOT TRANSPOSE
77 | imgs_dset[idx] = img
78 | elif img.ndim == 2:
79 | # Grayscale image; it is H x W so broadcasting to C x H x W will just copy
80 | # grayscale values into all channels.
81 | # COPY GRAY SCALE TO CHANNELS DIMENSION
82 | img_dtype = img.dtype
83 | imgs_dset[idx] = (img[:, :, None] * np.array([1, 1, 1])).astype(img_dtype)
84 | output_queue.task_done()
85 | num_written = num_written + 1
86 | if num_written % 100 == 0:
87 | print 'Copied %d / %d images' % (num_written, num_images)
88 |
89 | # Start the read workers.
90 | for i in xrange(args.num_workers):
91 | t = Thread(target=read_worker)
92 | t.daemon = True
93 | t.start()
94 |
95 | # h5py locks internally, so we can only use a single write worker =(
96 | t = Thread(target=write_worker)
97 | t.daemon = True
98 | t.start()
99 |
100 | for idx, filename in enumerate(image_list):
101 | if args.max_images > 0 and idx >= args.max_images: break
102 | input_queue.put((idx, filename))
103 |
104 | input_queue.join()
105 | output_queue.join()
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 |
111 | with h5py.File(args.output_file, 'w') as f:
112 | add_data(f, args.train_dir, 'train2014', args)
113 |
114 | if args.include_val != 0:
115 | add_data(f, args.val_dir, 'val2014', args)
116 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | '''
2 | This module contains functions for building the pastiche model.
3 | '''
4 |
5 | import keras
6 | from keras.models import Model
7 | from keras.layers import (Convolution2D, Activation, UpSampling2D,
8 | ZeroPadding2D, Input, BatchNormalization,
9 | merge, Lambda)
10 | from layers import (ReflectionPadding2D, InstanceNormalization,
11 | ConditionalInstanceNormalization)
12 | from keras.initializations import normal
13 |
14 | # Initialize weights with normal distribution with std 0.01
15 | def weights_init(shape, name=None, dim_ordering=None):
16 | return normal(shape, scale=0.01, name=name)
17 |
18 |
19 | def conv(x, n_filters, kernel_size=3, stride=1, relu=True, nb_classes=1, targets=None):
20 | '''
21 | Reflection padding, convolution, instance normalization and (maybe) relu.
22 | '''
23 | if not kernel_size % 2:
24 | raise ValueError('Expected odd kernel size.')
25 | pad = (kernel_size - 1) // 2
26 | o = ReflectionPadding2D(padding=(pad, pad))(x)
27 | o = Convolution2D(n_filters, kernel_size, kernel_size,
28 | subsample=(stride, stride), init=weights_init)(o)
29 | # o = BatchNormalization()(o)
30 | if nb_classes > 1:
31 | o = ConditionalInstanceNormalization(targets, nb_classes)(o)
32 | else:
33 | o = InstanceNormalization()(o)
34 | if relu:
35 | o = Activation('relu')(o)
36 | return o
37 |
38 |
39 | def residual_block(x, n_filters, nb_classes=1, targets=None):
40 | '''
41 | Residual block with 2 3x3 convolutions blocks. Last one is linear (no ReLU).
42 | '''
43 | o = conv(x, n_filters)
44 | # Linear activation on second conv
45 | o = conv(o, n_filters, relu=False, nb_classes=nb_classes, targets=targets)
46 | # Shortcut connection
47 | o = merge([o, x], mode='sum')
48 | return o
49 |
50 |
51 | def upsampling(x, n_filters, nb_classes=1, targets=None):
52 | '''
53 | Upsampling block with nearest-neighbor interpolation and a conv block.
54 | '''
55 | o = UpSampling2D()(x)
56 | o = conv(o, n_filters, nb_classes=nb_classes, targets=targets)
57 | return o
58 |
59 |
60 | def pastiche_model(img_size, width_factor=2, nb_classes=1, targets=None):
61 | k = width_factor
62 | x = Input(shape=(img_size, img_size, 3))
63 | o = conv(x, 16 * k, kernel_size=9, nb_classes=nb_classes, targets=targets)
64 | o = conv(o, 32 * k, stride=2, nb_classes=nb_classes, targets=targets)
65 | o = conv(o, 64 * k, stride=2, nb_classes=nb_classes, targets=targets)
66 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets)
67 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets)
68 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets)
69 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets)
70 | o = residual_block(o, 64 * k, nb_classes=nb_classes, targets=targets)
71 | o = upsampling(o, 32 * k, nb_classes=nb_classes, targets=targets)
72 | o = upsampling(o, 16 * k, nb_classes=nb_classes, targets=targets)
73 | o = conv(o, 3, kernel_size=9, relu=False, nb_classes=nb_classes, targets=targets)
74 | o = Activation('tanh')(o)
75 | o = Lambda(lambda x: 150*x, name='scaling')(o)  # rescale tanh output to ~[-150, 150], roughly the range of mean-centered pixels
76 | pastiche_net = Model(input=x, output=o)
77 | return pastiche_net
78 |
--------------------------------------------------------------------------------
/readme_imgs/1style/chicago_style_the_scream.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/1style/chicago_style_the_scream.png
--------------------------------------------------------------------------------
/readme_imgs/1style/golden_gate_style_feathers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/1style/golden_gate_style_feathers.png
--------------------------------------------------------------------------------
/readme_imgs/1style/hoovertowernight_style_candy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/1style/hoovertowernight_style_candy.png
--------------------------------------------------------------------------------
/readme_imgs/6styles/chicago_style_the_scream.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/6styles/chicago_style_the_scream.png
--------------------------------------------------------------------------------
/readme_imgs/6styles/golden_gate_style_feathers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/6styles/golden_gate_style_feathers.png
--------------------------------------------------------------------------------
/readme_imgs/6styles/hoovertowernight_style_candy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/6styles/hoovertowernight_style_candy.png
--------------------------------------------------------------------------------
/readme_imgs/candy_tubingen_init_content.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/candy_tubingen_init_content.gif
--------------------------------------------------------------------------------
/readme_imgs/candy_tubingen_init_random.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/readme_imgs/candy_tubingen_init_random.gif
--------------------------------------------------------------------------------
/style_imgs/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/candy.jpg
--------------------------------------------------------------------------------
/style_imgs/composition_vii.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/composition_vii.jpg
--------------------------------------------------------------------------------
/style_imgs/escher_sphere.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/escher_sphere.jpg
--------------------------------------------------------------------------------
/style_imgs/feathers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/feathers.jpg
--------------------------------------------------------------------------------
/style_imgs/frida_kahlo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/frida_kahlo.jpg
--------------------------------------------------------------------------------
/style_imgs/la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/la_muse.jpg
--------------------------------------------------------------------------------
/style_imgs/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/mosaic.jpg
--------------------------------------------------------------------------------
/style_imgs/picasso_selfport1907.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/picasso_selfport1907.jpg
--------------------------------------------------------------------------------
/style_imgs/seated-nude.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/seated-nude.jpg
--------------------------------------------------------------------------------
/style_imgs/shipwreck.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/shipwreck.jpg
--------------------------------------------------------------------------------
/style_imgs/starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/starry_night.jpg
--------------------------------------------------------------------------------
/style_imgs/the_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/the_scream.jpg
--------------------------------------------------------------------------------
/style_imgs/udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/udnie.jpg
--------------------------------------------------------------------------------
/style_imgs/wave_crop.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/wave_crop.jpg
--------------------------------------------------------------------------------
/style_imgs/woman-with-hat-matisse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertomest/neural-style-keras/11fecd8e99228aab4851e4c00e85ed31217406db/style_imgs/woman-with-hat-matisse.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | This script can be used to train a pastiche network.
3 | '''
4 |
5 | from __future__ import print_function
6 | import os
7 | import argparse
8 |
9 | import time
10 | import h5py
11 |
12 | import numpy as np
13 | import tensorflow as tf
14 | import keras
15 | import keras.backend as K
16 | from keras.optimizers import Adam
17 | from model import pastiche_model
18 | from training import get_loss_net, get_content_losses, get_style_losses, tv_loss
19 | from utils import preprocess_input, config_gpu, save_checkpoint, std_input_list
20 |
21 | if __name__ == '__main__':
22 | def_cl = ['block3_conv3']
23 | def_sl = ['block1_conv2', 'block2_conv2',
24 | 'block3_conv3', 'block4_conv3']
25 |
26 | # Argument parser
27 | parser = argparse.ArgumentParser(description='Train a pastiche network.')
28 | parser.add_argument('--lr', help='Learning rate.', type=float, default=0.001)
29 | parser.add_argument('--content_weight', type=float, default=[1.], nargs='+')
30 | parser.add_argument('--style_weight', type=float, default=[1e-4], nargs='+')
31 | parser.add_argument('--tv_weight', type=float, default=[1e-4], nargs='+')
32 | parser.add_argument('--content_layers', type=str, nargs='+', default=def_cl)
33 | parser.add_argument('--style_layers', type=str, nargs='+', default=def_sl)
34 | parser.add_argument('--width_factor', type=int, default=2)
35 | parser.add_argument('--nb_classes', type=int, default=1)
36 | parser.add_argument('--norm_by_channels', default=False, action='store_true')
37 | parser.add_argument('--num_iterations', type=int, default=40000)
38 | parser.add_argument('--save_every', type=int, default=500)
39 | parser.add_argument('--batch_size', type=int, default=4)
40 | parser.add_argument('--coco_path', type=str, default='data/coco/ms-coco-256.h5')
41 | parser.add_argument('--gram_dataset_path', type=str, default='grams.h5')
42 | parser.add_argument('--checkpoint_path', type=str, default='checkpoint.h5')
43 | parser.add_argument('--gpu', type=str, default='')
44 | parser.add_argument('--allow_growth', default=False, action='store_true')
45 |
46 | args = parser.parse_args()
47 | # Arguments parsed
48 |
49 | # Check loss weights
50 | args.style_weight = std_input_list(args.style_weight, args.nb_classes, 'Style weight')
51 | args.content_weight = std_input_list(args.content_weight, args.nb_classes, 'Content weight')
52 | args.tv_weight = std_input_list(args.tv_weight, args.nb_classes, 'TV weight')
53 |
54 | config_gpu(args.gpu, args.allow_growth)
55 |
56 | print('Creating pastiche model...')
57 | class_targets = K.placeholder(shape=(None,), dtype=tf.int32)
58 | # The model will be trained with 256 x 256 images of the coco dataset.
59 | pastiche_net = pastiche_model(256, width_factor=args.width_factor,
60 | nb_classes=args.nb_classes,
61 | targets=class_targets)
62 | x = pastiche_net.input
63 | o = pastiche_net.output
64 |
65 | print('Loading loss network...')
66 | loss_net, outputs_dict, content_targets_dict = get_loss_net(pastiche_net.output, input_tensor=pastiche_net.input)
67 |
68 | # Placeholder sizes
69 | ph_sizes = {k : K.int_shape(content_targets_dict[k])[-1] for k in args.style_layers}
70 |
71 | # Our style targets are precomputed and are fed through these placeholders
72 | style_targets_dict = {k : K.placeholder(shape=(None, ph_sizes[k], ph_sizes[k])) for k in args.style_layers}
73 |
74 |
75 | print('Setting up training...')
76 | # Setup the loss weights as variables
77 | content_weights = K.variable(args.content_weight)
78 | style_weights = K.variable(args.style_weight)
79 | tv_weights = K.variable(args.tv_weight)
80 |
81 | style_losses = get_style_losses(outputs_dict, style_targets_dict, args.style_layers,
82 | norm_by_channels=args.norm_by_channels)
83 |
84 | content_losses = get_content_losses(outputs_dict, content_targets_dict, args.content_layers)
85 |
86 | # Use total variation to improve local coherence
87 | total_var_loss = tv_loss(pastiche_net.output)
88 |
89 |
90 | weighted_style_losses = []
91 | weighted_content_losses = []
92 |
93 | # Compute total loss
94 | total_loss = K.variable(0.)
95 | for loss in style_losses:
96 | weighted_loss = K.mean(K.gather(style_weights, class_targets) * loss)
97 | weighted_style_losses.append(weighted_loss)
98 | total_loss += weighted_loss
99 |
100 | for loss in content_losses:
101 | weighted_loss = K.mean(K.gather(content_weights, class_targets) * loss)
102 | weighted_content_losses.append(weighted_loss)
103 | total_loss += weighted_loss
104 |
105 | weighted_tv_loss = K.mean(K.gather(tv_weights, class_targets) * total_var_loss)
106 | total_loss += weighted_tv_loss
107 |
108 |
109 | ## Make training function
110 |
111 | # Get a list of inputs
112 | inputs = [pastiche_net.input, class_targets] + \
113 | [style_targets_dict[k] for k in args.style_layers] + \
114 | [K.learning_phase()]
115 |
116 | # Get trainable params
117 | params = pastiche_net.trainable_weights
118 | constraints = pastiche_net.constraints
119 |
120 | opt = Adam(lr=args.lr)
121 | updates = opt.get_updates(params, constraints, total_loss)
122 |
123 | # List of outputs
124 | outputs = [total_loss] + weighted_content_losses + weighted_style_losses + [weighted_tv_loss]
125 |
126 | f_train = K.function(inputs, outputs, updates)
127 |
128 | X = h5py.File(args.coco_path, 'r')['train2014']['images']
129 | dataset_size = X.shape[0]
130 | batches_per_epoch = int(np.ceil(dataset_size / float(args.batch_size)))
131 | batch_idx = 0
132 |
133 | print('Loading Gram matrices from dataset file...')
134 | if args.norm_by_channels:
135 | print('Normalizing the stored Gram matrices by the number of channels.')
136 | Y = {}
137 | with h5py.File(args.gram_dataset_path, 'r') as f:
138 | styles = f.attrs['img_names']
139 | style_sizes = f.attrs['img_sizes']
140 |     for k, v in f.items():  # items() works on both Python 2 and 3
141 | Y[k] = np.array(v)
142 | if args.norm_by_channels:
143 |             # Normalize the stored Gram matrices by the number of channels
144 | Y[k] /= Y[k].shape[-1]
145 |
146 | # Get a log going
147 | log = {}
148 | log['args'] = args
149 | log['style_names'] = styles[:args.nb_classes]
150 | log['style_image_sizes'] = style_sizes
151 | log['total_loss'] = []
152 | log['style_loss'] = {k: [] for k in args.style_layers}
153 | log['content_loss'] = {k: [] for k in args.content_layers}
154 | log['tv_loss'] = []
155 |
156 | # Strip the extension if there is one
157 | checkpoint_path = os.path.splitext(args.checkpoint_path)[0]
158 |
159 | start_time = time.time()
160 |
161 | for it in range(args.num_iterations):
162 | if batch_idx >= batches_per_epoch:
163 | print('Epoch done. Going back to the beginning...')
164 | batch_idx = 0
165 |
166 | # Get the batch
167 | idx = args.batch_size * batch_idx
168 | batch = X[idx:idx+args.batch_size]
169 | batch = preprocess_input(batch)
170 | batch_idx += 1
171 |
172 | # Get class information for each image on the batch
173 | batch_classes = np.random.randint(args.nb_classes, size=(args.batch_size,))
174 |
175 | batch_targets = [Y[l][batch_classes] for l in args.style_layers]
176 |
177 | # Do a step
178 | start_time2 = time.time()
179 | out = f_train([batch, batch_classes] + batch_targets + [1.])
180 | stop_time2 = time.time()
181 |
182 |     # Log the statistics
183 | log['total_loss'].append(out[0])
184 | offset = 1
185 | for i, k in enumerate(args.content_layers):
186 | log['content_loss'][k].append(out[offset + i])
187 | offset += len(args.content_layers)
188 | for i, k in enumerate(args.style_layers):
189 | log['style_loss'][k].append(out[offset + i])
190 | log['tv_loss'].append(out[-1])
191 |
192 | stop_time = time.time()
193 | print('Iteration %d/%d: loss = %f. t = %f (%f)' %(it + 1,
194 | args.num_iterations, out[0], stop_time - start_time,
195 | stop_time2 - start_time2))
196 |
197 | if not ((it + 1) % args.save_every):
198 | print('Saving checkpoint in %s.h5...' %(checkpoint_path))
199 | save_checkpoint(checkpoint_path, pastiche_net, log)
200 | print('Checkpoint saved.')
201 |
202 | start_time = time.time()
203 | save_checkpoint(checkpoint_path, pastiche_net, log)
204 |
--------------------------------------------------------------------------------
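A note on the per-class weighting used above: `K.gather` selects, for each image in the batch, the loss weight of the style class that image was assigned, so a single batch can mix styles trained with different weights. A minimal NumPy sketch of the same arithmetic (the weight and loss values below are hypothetical, not from the repo):

import numpy as np

# Per-class weights, e.g. as given via --style_weight (nb_classes = 3)
style_weights = np.array([1e-4, 5e-5, 2e-4])
# Random class assignment for a batch of 4 images, as in the training loop
batch_classes = np.array([2, 0, 0, 1])
# Per-image style loss for one layer (hypothetical values)
loss = np.array([10.0, 7.5, 3.2, 12.1])

# Equivalent of K.mean(K.gather(style_weights, class_targets) * loss)
weighted_loss = (style_weights[batch_classes] * loss).mean()
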
/training.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.applications import vgg16
3 |
4 | '''
5 | Module that defines loss functions and other auxiliary functions used when
6 | training a pastiche model.
7 | '''
8 |
9 | def gram_matrix(x, norm_by_channels=False):
10 | '''
11 | Returns the Gram matrix of the tensor x.
12 | '''
13 |     if K.ndim(x) == 3:
14 |         features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))  # (C, H*W)
15 |         shape = K.shape(x)
16 |         H, W, C = shape[0], shape[1], shape[2]  # x is (H, W, C)
17 |         gram = K.dot(features, K.transpose(features))
18 |     elif K.ndim(x) == 4:
19 |         # Swap from (B, H, W, C) to (B, C, H, W)
20 | x = K.permute_dimensions(x, (0, 3, 1, 2))
21 | shape = K.shape(x)
22 | B, C, H, W = shape[0], shape[1], shape[2], shape[3]
23 | # Reshape as a batch of 2D matrices with vectorized channels
24 | features = K.reshape(x, K.stack([B, C, H*W]))
25 | # This is a batch of Gram matrices (B, C, C).
26 | gram = K.batch_dot(features, features, axes=2)
27 | else:
28 | raise ValueError('The input tensor should be either a 3d (H, W, C) or 4d (B, H, W, C) tensor.')
29 | # Normalize the Gram matrix
30 |     if norm_by_channels:
31 |         denominator = C * H * W # Normalization from Johnson et al.
32 |     else:
33 |         denominator = H * W # Normalization from Dumoulin et al. (Google)
34 | gram = gram / K.cast(denominator, x.dtype)
35 |
36 | return gram
37 |
38 |
39 | def content_loss(x, target):
40 | '''
41 |     Content loss is the MSE between a layer's activations on the pastiche and on the content target.
42 | '''
43 | return K.mean(K.square(target - x), axis=(1, 2, 3))
44 |
45 |
46 | def style_loss(x, target, norm_by_channels=False):
47 | '''
48 | Style loss is the MSE between Gram matrices computed using activation maps.
49 | '''
50 | x_gram = gram_matrix(x, norm_by_channels=norm_by_channels)
51 | return K.mean(K.square(target - x_gram), axis=(1, 2))
52 |
53 |
54 | def tv_loss(x):
55 | '''
56 | Total variation loss is used to keep the image locally coherent
57 | '''
58 | assert K.ndim(x) == 4
59 | a = K.square(x[:, :-1, :-1, :] - x[:, 1:, :-1, :])
60 | b = K.square(x[:, :-1, :-1, :] - x[:, :-1, 1:, :])
61 | return K.sum(a + b, axis=(1, 2, 3))
62 |
63 |
64 | def get_content_features(out_dict, layer_names):
65 | return [out_dict[l] for l in layer_names]
66 |
67 |
68 | def get_style_features(out_dict, layer_names, norm_by_channels=False):
69 | features = []
70 | for l in layer_names:
71 | layer_features = out_dict[l]
72 | S = gram_matrix(layer_features, norm_by_channels=norm_by_channels)
73 | features.append(S)
74 | return features
75 |
76 |
77 | def get_loss_net(pastiche_net_output, input_tensor=None):
78 | '''
79 | Instantiates a VGG net and applies its layers on top of the pastiche net's
80 | output.
81 | '''
82 | loss_net = vgg16.VGG16(weights='imagenet', include_top=False,
83 | input_tensor=input_tensor)
84 | targets_dict = dict([(layer.name, layer.output) for layer in loss_net.layers])
85 | i = pastiche_net_output
86 | # We need to apply all layers to the output of the style net
87 | outputs_dict = {}
88 | for l in loss_net.layers[1:]: # Ignore the input layer
89 | i = l(i)
90 | outputs_dict[l.name] = i
91 |
92 | return loss_net, outputs_dict, targets_dict
93 |
94 |
95 | def get_style_losses(outputs_dict, targets_dict, style_layers,
96 | norm_by_channels=False):
97 | '''
98 | Returns the style loss for the desired layers
99 | '''
100 | return [style_loss(outputs_dict[l], targets_dict[l],
101 | norm_by_channels=norm_by_channels)
102 | for l in style_layers]
103 |
104 | def get_content_losses(outputs_dict, targets_dict, content_layers):
105 | return [content_loss(outputs_dict[l], targets_dict[l])
106 | for l in content_layers]
107 |
108 | def get_total_loss(content_losses, style_losses, total_var_loss,
109 |                    content_weights, style_weights, tv_weights, class_targets):
110 |     '''
111 |     Combines the weighted content, style and TV losses into a single scalar.
112 |     '''
113 |     total_loss = K.variable(0.)
114 |     weighted_content_losses = []
115 |     weighted_style_losses = []
116 |
117 |     # Compute content losses
118 |     for loss in content_losses:
119 |         weighted_loss = K.mean(K.gather(content_weights, class_targets) * loss)
120 |         weighted_content_losses.append(weighted_loss)
121 |         total_loss += weighted_loss
122 |
123 |     # Compute style losses
124 |     for loss in style_losses:
125 |         weighted_loss = K.mean(K.gather(style_weights, class_targets) * loss)
126 |         weighted_style_losses.append(weighted_loss)
127 |         total_loss += weighted_loss
128 |
129 |     # Compute TV loss
130 |     weighted_tv_loss = K.mean(K.gather(tv_weights, class_targets) *
131 |                               total_var_loss)
132 |     total_loss += weighted_tv_loss
133 |
134 |     return (total_loss, weighted_content_losses, weighted_style_losses,
135 |             weighted_tv_loss)
136 |
--------------------------------------------------------------------------------
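For intuition, the 4D branch of `gram_matrix` above is equivalent to the following NumPy sketch, which also makes the two normalization conventions explicit (this is an illustrative reimplementation, not part of the repo):

import numpy as np

def gram_matrix_np(x, norm_by_channels=False):
    # x has shape (B, H, W, C), like the Keras tensors in training.py
    B, H, W, C = x.shape
    features = x.transpose(0, 3, 1, 2).reshape(B, C, H * W)  # (B, C, H*W)
    gram = np.einsum('bcn,bdn->bcd', features, features)     # (B, C, C)
    # Johnson et al. divide by C*H*W; Dumoulin et al. (Google) by H*W
    denom = C * H * W if norm_by_channels else H * W
    return gram / denom

x = np.random.rand(2, 8, 8, 3).astype('float32')
assert gram_matrix_np(x).shape == (2, 3, 3)
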
/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Module that contains utility functions.
3 | '''
4 | import tensorflow as tf
5 | from keras.preprocessing.image import load_img, img_to_array
6 | import numpy as np
7 | import h5py
8 | import yaml
9 | from PIL import Image
10 | from keras.applications import vgg16
11 | from keras import backend as K
12 |
13 | def config_gpu(gpu, allow_growth):
14 |     # Choose the GPU: '-1' forces CPU, '' or 'all' exposes every GPU
15 |     if gpu == '-1':
16 |         config = tf.ConfigProto(device_count={'GPU': 0})
17 |     else:
18 |         if gpu == 'all' or gpu == '':
19 |             gpu = ''
20 |         config = tf.ConfigProto()
21 |         config.gpu_options.visible_device_list = gpu
22 |         if allow_growth:
23 |             config.gpu_options.allow_growth = True
24 | session = tf.Session(config=config)
25 | K.set_session(session)
26 |
27 | def save_checkpoint(checkpoint_path, pastiche_net, log):
28 | with h5py.File(checkpoint_path + '.h5', 'w') as f:
29 | g = f.create_group('model_weights')
30 | pastiche_net.save_weights_to_hdf5_group(g)
31 | g = f.create_group('log')
32 | g.create_dataset('total_loss', data=np.array(log['total_loss']))
33 | g.create_dataset('tv_loss', data=np.array(log['tv_loss']))
34 | g2 = g.create_group('style_loss')
35 | for k, v in log['style_loss'].items():
36 | g2.create_dataset(k, data=v)
37 | g2 = g.create_group('content_loss')
38 | for k, v in log['content_loss'].items():
39 | g2.create_dataset(k, data=v)
40 | f.attrs['args'] = yaml.dump(log['args'])
41 | f.attrs['style_names'] = log['style_names']
42 | f.attrs['style_image_sizes'] = log['style_image_sizes']
43 |
44 | def preprocess_input(x):
45 | return vgg16.preprocess_input(x.astype('float32'))
46 |
47 | def preprocess_image_crop(image_path, img_size):
48 |     '''
49 |     Preprocess the image, scaling it so that its smaller side is img_size.
50 |     The larger side is then cropped to produce a square image.
51 |     '''
52 |     img = load_img(image_path)
53 |     scale = float(img_size) / min(img.size)
54 |     new_size = (int(np.ceil(scale * img.size[0])), int(np.ceil(scale * img.size[1])))
55 |     img = img.resize(new_size, resample=Image.BILINEAR)
56 |     img = img_to_array(img)
57 |     # Crop the excess, keeping the bottom-right img_size x img_size region
58 |     crop_h = img.shape[0] - img_size
59 |     crop_v = img.shape[1] - img_size
60 |     img = img[crop_h:img_size+crop_h, crop_v:img_size+crop_v, :]
61 | img = np.expand_dims(img, axis=0)
62 | img = vgg16.preprocess_input(img)
63 | return img
64 |
65 | # util function to open, resize and format pictures into appropriate tensors
66 | def preprocess_image_scale(image_path, img_size=None):
67 | '''
68 |     Preprocess the image, scaling it so that its larger side is img_size.
69 |     Aspect ratio is preserved.
70 | '''
71 | img = load_img(image_path)
72 | if img_size:
73 | scale = float(img_size) / max(img.size)
74 | new_size = (int(np.ceil(scale * img.size[0])), int(np.ceil(scale * img.size[1])))
75 | img = img.resize(new_size, resample=Image.BILINEAR)
76 | img = img_to_array(img)
77 | img = np.expand_dims(img, axis=0)
78 | img = vgg16.preprocess_input(img)
79 | return img
80 |
81 |
82 | # util function to convert a tensor into a valid image
83 | def deprocess_image(x):
84 | x = x[0]
85 | # Remove zero-center by mean pixel
86 | x[:, :, 0] += 103.939
87 | x[:, :, 1] += 116.779
88 | x[:, :, 2] += 123.68
89 | # 'BGR'->'RGB'
90 | x = x[:, :, ::-1]
91 | x = np.clip(x, 0, 255).astype('uint8')
92 | return x
93 |
94 | def std_input_list(input_list, nb_el, name):
95 | if len(input_list) == 1:
96 | return [input_list[0] for _ in range(nb_el)]
97 | elif len(input_list) != nb_el:
98 | raise ValueError('%s list should have length %d, found %d.' %(name, nb_el, len(input_list)))
99 | return input_list
100 |
--------------------------------------------------------------------------------
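As a usage note, `std_input_list` is what lets a single command-line value stand in for all style classes; a quick sketch of its behavior (assuming utils.py is importable):

from utils import std_input_list

# A single value is broadcast to every class...
std_input_list([1e-4], 3, 'Style weight')           # -> [1e-4, 1e-4, 1e-4]
# ...a full-length list passes through unchanged...
std_input_list([1.0, 2.0, 3.0], 3, 'Style weight')  # -> [1.0, 2.0, 3.0]
# ...and any other length raises ValueError:
# std_input_list([1.0, 2.0], 3, 'Style weight')     # ValueError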