├── data
│   ├── Monet
│   │   ├── grid.png
│   │   ├── style.png
│   │   ├── result.png
│   │   ├── style_mask.png
│   │   └── target_mask.png
│   ├── Renoir
│   │   ├── grid.png
│   │   ├── creek.jpg
│   │   ├── result.png
│   │   ├── style.png
│   │   ├── comparison.jpg
│   │   ├── creek_mask.jpg
│   │   ├── style_mask.png
│   │   └── target_mask.png
│   ├── Van_Gogh
│   │   ├── seth.jpg
│   │   ├── portrait.jpg
│   │   ├── seth_mask.png
│   │   └── portrait_mask.png
│   └── pretrained
│       └── download_models.sh
├── src
│   ├── content_loss.lua
│   ├── style_loss.lua
│   └── utils.lua
├── LICENSE
├── get_mask_hdf5.py
├── README.md
└── fast_neural_doodle.lua

/data/Monet/grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Monet/grid.png
--------------------------------------------------------------------------------

/data/Monet/style.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Monet/style.png
--------------------------------------------------------------------------------

/data/Renoir/grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/grid.png
--------------------------------------------------------------------------------

/data/Monet/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Monet/result.png
--------------------------------------------------------------------------------

/data/Renoir/creek.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/creek.jpg
--------------------------------------------------------------------------------

/data/Renoir/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/result.png
--------------------------------------------------------------------------------

/data/Renoir/style.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/style.png
--------------------------------------------------------------------------------

/data/Van_Gogh/seth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Van_Gogh/seth.jpg
--------------------------------------------------------------------------------

/data/Monet/style_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Monet/style_mask.png
--------------------------------------------------------------------------------

/data/Monet/target_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Monet/target_mask.png
--------------------------------------------------------------------------------

/data/Renoir/comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/comparison.jpg
--------------------------------------------------------------------------------

/data/Renoir/creek_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/creek_mask.jpg
--------------------------------------------------------------------------------

/data/Renoir/style_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/style_mask.png
--------------------------------------------------------------------------------

/data/Renoir/target_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Renoir/target_mask.png
--------------------------------------------------------------------------------

/data/Van_Gogh/portrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Van_Gogh/portrait.jpg
--------------------------------------------------------------------------------

/data/Van_Gogh/seth_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Van_Gogh/seth_mask.png
--------------------------------------------------------------------------------

/data/Van_Gogh/portrait_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DmitryUlyanov/fast-neural-doodle/HEAD/data/Van_Gogh/portrait_mask.png
--------------------------------------------------------------------------------

/data/pretrained/download_models.sh:
--------------------------------------------------------------------------------
wget -c https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/bb2b4fe0a9bb0669211cf3d0bc949dfdda173e9e/VGG_ILSVRC_19_layers_deploy.prototxt
wget -c --no-check-certificate https://bethgelab.org/media/uploads/deeptextures/vgg_normalised.caffemodel
wget -c http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel
--------------------------------------------------------------------------------

/src/content_loss.lua:
--------------------------------------------------------------------------------
-- Pass-through module that records the weighted MSE distance between its
-- input and a fixed content target.
local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module')

function ContentLoss:__init(strength, target, normalize)
  parent.__init(self)
  self.strength = strength
  self.target = target
  self.normalize = normalize or false
  self.loss = 0
  self.crit = nn.MSECriterion()
end

function ContentLoss:updateOutput(input)
  if input:nElement() == self.target:nElement() then
    self.loss = self.crit:forward(input, self.target) * self.strength
  else
    print('WARNING: Skipping content loss')
  end
  self.output = input
  return self.output
end

function ContentLoss:updateGradInput(input, gradOutput)
  if input:nElement() == self.target:nElement() then
    self.gradInput = self.crit:backward(input, self.target)
  end
  if self.normalize then
    self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
  end
  self.gradInput:mul(self.strength)
  self.gradInput:add(gradOutput)
  return self.gradInput
end
--------------------------------------------------------------------------------
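
A minimal usage sketch of `nn.ContentLoss` (not part of the repository; the tensor shape is invented for illustration). The module passes its input through unchanged and records the weighted MSE distance to a fixed target, which is how `fast_neural_doodle.lua` wires it in:

```
require 'nn'
require 'src/content_loss'

-- Pretend these are VGG activations at the content layer (shape is made up).
local target = torch.randn(64, 32, 32)
local loss_module = nn.ContentLoss(1e-3, target, false)

loss_module:forward(target:clone())  -- identical input, so the loss is ~0
print(loss_module.loss)
```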

/LICENSE:
--------------------------------------------------------------------------------
The MIT License

Copyright (c) 2016 Dmitry Ulyanov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------

/get_mask_hdf5.py:
--------------------------------------------------------------------------------
from sklearn.cluster import KMeans
import scipy.misc
import numpy as np
import h5py
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--n_colors', type=int, help='How many distinct colors the mask has.')
parser.add_argument('--style_image', help='Path to style image.')
parser.add_argument('--target_image', help='Path to target (content) image.')
parser.add_argument('--style_mask', help='Path to mask for style.')
parser.add_argument('--target_mask', help='Path to target mask.')
parser.add_argument('--out_hdf5', default='masks.hdf5', help='Where to store the hdf5 file.')

args = parser.parse_args()

# Load images
img_style = scipy.misc.imread(args.style_image)
if args.target_image is not None:
    img_content = scipy.misc.imread(args.target_image)

# Load masks
mask_style = scipy.misc.imread(args.style_mask)
mask_target = scipy.misc.imread(args.target_mask)

# Save shapes
style_shape = mask_style.shape
target_shape = mask_target.shape
if img_style.shape != style_shape:
    raise Exception('Style image and mask have different sizes!')
if args.target_image is not None:
    if img_content.shape != target_shape:
        raise Exception('Content image and mask have different sizes!')


# Run K-Means to get rid of possible intermediate colors
style_flatten = mask_style.reshape(style_shape[0]*style_shape[1], -1)
target_flatten = mask_target.reshape(target_shape[0]*target_shape[1], -1)

kmeans = KMeans(n_clusters=args.n_colors, random_state=0).fit(style_flatten)

# Predict masks
labels_style = kmeans.predict(style_flatten.astype(float))
labels_target = kmeans.predict(target_flatten.astype(float))

style_kval = labels_style.reshape(style_shape[0], style_shape[1])
target_kval = labels_target.reshape(target_shape[0], target_shape[1])

# Dump
f = h5py.File(args.out_hdf5, 'w')

for i in range(args.n_colors):
    f['style_mask_%d' % i] = (style_kval == i).astype(float)
    f['target_mask_%d' % i] = (target_kval == i).astype(float)

# Torch-style image save
f['style_img'] = img_style.transpose(2, 0, 1).astype(float) / 255.
if args.target_image is not None:
    f['content_img'] = img_content.transpose(2, 0, 1).astype(float) / 255.
    f['has_content'] = np.array([1])
else:
    f['has_content'] = np.array([0])
f['n_colors'] = np.array([args.n_colors])  # Torch does not want to read a bare number

f.close()

print('Done!')
--------------------------------------------------------------------------------
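
For reference, the README invokes this script as follows. `fast_neural_doodle.lua` then reads back the datasets it writes: `style_mask_<i>` and `target_mask_<i>` for each of the `n_colors` clusters, plus `style_img`, `has_content`, `n_colors`, and (optionally) `content_img`:

```
python get_mask_hdf5.py --n_colors=4 --style_image=data/Renoir/style.png --style_mask=data/Renoir/style_mask.png --target_mask=data/Renoir/target_mask.png
```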

/src/style_loss.lua:
--------------------------------------------------------------------------------
-- Returns a network that computes the CxC Gram matrix from inputs
-- of size C x H x W
function GramMatrix()
  local net = nn.Sequential()
  net:add(nn.View(-1):setNumInputDims(2))
  local concat = nn.ConcatTable()
  concat:add(nn.Identity())
  concat:add(nn.Identity())
  net:add(concat)
  net:add(nn.MM(false, true))
  return net
end

-- Define an nn Module to compute style loss in-place
local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module')

function StyleLoss:__init(strength, target_grams, normalize, target_masks)
  parent.__init(self)
  self.normalize = normalize or false
  self.strength = strength
  self.target_grams = target_grams
  self.loss = 0

  self.target_masks = target_masks
  self.target_masks_means = nil
  self.target_masks_exp = nil

  self.first = true

  self.gram = GramMatrix()
  self.crit = nn.SmoothL1Criterion()

  self.gradInput = nil
end

function StyleLoss:updateOutput(input)
  -- We do everything in updateGradInput to save memory
  self.output = input
  return self.output
end

function StyleLoss:updateGradInput(input, gradOutput)
  -- Iterate over the colors and accumulate the gradient
  self.gradInput = self.gradInput or gradOutput:clone()
  self.gradInput:zero()
  self.loss = 0

  -- Expand the masks once, on the first call
  if self.first then
    self.first = false
    self.target_masks_exp = {}
    self.target_masks_means = {}

    for k, _ in ipairs(self.target_masks) do
      self.target_masks_exp[k] = self.target_masks[k]:add_dummy():expandAs(input)
      self.target_masks_means[k] = self.target_masks[k]:mean()

      -- Free the original mask
      self.target_masks[k] = nil
    end
  end

  -- Apply masks
  for k, _ in ipairs(self.target_masks_exp) do

    -- Forward
    local masked_input = torch.cmul(input, self.target_masks_exp[k])
    local G = self.gram:forward(masked_input)

    if self.target_masks_means[k] > 0 then
      G:div(input:nElement() * self.target_masks_means[k])
    end

    self.loss = self.loss + self.crit:forward(G, self.target_grams[k])

    -- Backward
    local dG = self.crit:backward(G, self.target_grams[k])
    if self.target_masks_means[k] > 0 then
      dG:div(input:nElement() * self.target_masks_means[k])
    end

    local gradInput = self.gram:backward(masked_input, dG)
    if self.normalize then
      gradInput:div(torch.norm(gradInput, 1) + 1e-8)
    end
    self.gradInput:add(gradInput)

  end
  self.gradInput:add(gradOutput)

  return self.gradInput
end
--------------------------------------------------------------------------------
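
A standalone sketch (shapes and mask are invented; this is not repository code) of the masked Gram statistic that `StyleLoss` matches: features are multiplied by a per-color mask expanded over the channels, a `C x C` Gram matrix is computed, and it is normalized by the element count times the mask's mean coverage, exactly as in `updateGradInput` above:

```
require 'nn'
require 'src/style_loss'  -- for GramMatrix()

local C, H, W = 64, 32, 32
local features = torch.randn(C, H, W)
local mask = torch.rand(H, W):gt(0.5):double()  -- a fake binary region mask

-- Mask the features, then take the Gram matrix of the masked activations.
local masked = torch.cmul(features, mask:view(1, H, W):expandAs(features))
local G = GramMatrix():forward(masked)          -- C x C

-- Same normalization as StyleLoss: element count times mean mask coverage.
if mask:mean() > 0 then
  G:div(features:nElement() * mask:mean())
end
print(G:size())                                 -- 64 x 64
```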

/README.md:
--------------------------------------------------------------------------------
## Faster neural doodle

This is my take on drawing with neural networks, which is faster than [Alex J. Champandard's version](https://github.com/alexjc/neural-doodle) and similar in quality. This approach is based on the [neural artistic style method](http://arxiv.org/abs/1508.06576) (L. Gatys), whereas Alex's version uses the [CNN+MRF approach](http://arxiv.org/abs/1601.04589) of Chuan Li.

It takes several minutes to redraw the `Renoir` example on a GPU, and the process easily fits into 4 GB of GPU memory. If you were able to work with [Justin Johnson's code for artistic style](https://github.com/jcjohnson/neural-style), then this code should work for you too.

You can find an even faster version [here](https://github.com/DmitryUlyanov/online-neural-doodle).

## Requirements
- torch
- torch.cudnn (optional)
- [torch-hdf5](https://github.com/deepmind/torch-hdf5)
- python + numpy + scipy + h5py + sklearn

Tested with Python 2.7 and recent `conda` packages.

## Do it yourself

First download VGG-19.
```
cd data/pretrained && bash download_models.sh && cd ../..
```

Use this script to get intermediate representations for the masks.
```
python get_mask_hdf5.py --n_colors=4 --style_image=data/Renoir/style.png --style_mask=data/Renoir/style_mask.png --target_mask=data/Renoir/target_mask.png
```

Now run the doodle.
```
th fast_neural_doodle.lua -masks_hdf5 masks.hdf5
```

And here is the result.
![Renoir](data/Renoir/grid.png)
First row: originals; second row: results.

And Monet.
![Monet](data/Monet/grid.png)

## Multiscale

Processing the image at a low resolution first can provide a significant speed-up. You can pass a list of resolutions to use during processing. Passing `256` means that the images and masks will be resized to `256x256`; with `0` passed, no resizing is done. Here is an example of the command-line parameters:
- `-num_iterations 450,100 -resolutions 256,0`

This means: work for 450 iterations at `256x256` resolution, then for 100 iterations at the original resolution (the full command is shown below).

The `Monet` and `Renoir` examples take ~1.5 min to process with these options.
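
Putting the two options together, a complete multiscale run of the `Renoir` example looks like this (`masks.hdf5` comes from the `get_mask_hdf5.py` step above):
```
th fast_neural_doodle.lua -masks_hdf5 masks.hdf5 -num_iterations 450,100 -resolutions 256,0
```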

## Style transfer

You can also provide a target image to use in the content loss (in the same way as in the neural artistic style algorithm) via the `--target_image` option of the `get_mask_hdf5.py` script.

Example:
```
python get_mask_hdf5.py --n_colors=4 --style_image=data/Renoir/style.png --style_mask=data/Renoir/style_mask.png --target_mask=data/Renoir/creek_mask.jpg --target_image=data/Renoir/creek.jpg
th fast_neural_doodle.lua -masks_hdf5 masks.hdf5
```

![Renoir](data/Renoir/comparison.jpg)
Upper left: target image. Upper right: neural doodle with target image, i.e. both the masks and the content loss were used. Lower left: regular neural doodle without content loss. Lower right: stylization without masks, with a high style weight, obtained via the [neural style code](https://github.com/jcjohnson/neural-style). With a high style weight, stylization tends to mix unrelated parts of the image, such as the patches of grass floating in the sky in the last picture. Neural doodle with content loss generates highly stylized images without this problem.

## Misc
- Supported backends:
  - nn (CPU/GPU mode)
  - cudnn
  - clnn (not tested yet)

- When using `-backend cudnn`, do not forget to turn on `-cudnn_autotune`.

## Acknowledgement

The code is heavily based on [Justin Johnson's great code](https://github.com/jcjohnson/neural-style) for artistic style.

## Citation

If you use this code for your research, please cite [neural-style](https://github.com/jcjohnson/neural-style) and this repository.

```
@misc{Ulyanov2016fastdoodle,
  author = {Ulyanov, Dmitry},
  title = {Fast Neural Doodle},
  year = {2016},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/DmitryUlyanov/fast-neural-doodle}},
}
```
--------------------------------------------------------------------------------

/src/utils.lua:
--------------------------------------------------------------------------------
-- Prepend a singleton dimension, e.g. H x W -> 1 x H x W.
function torch.add_dummy(self)
  local sz = self:size()
  local new_sz = torch.Tensor(sz:size()+1)
  new_sz[1] = 1
  new_sz:narrow(1,2,sz:size()):copy(torch.Tensor{sz:totable()})
  return self:view(new_sz:long():storage())
end

function torch.FloatTensor:add_dummy()
  return torch.add_dummy(self)
end
function torch.DoubleTensor:add_dummy()
  return torch.add_dummy(self)
end

if params.gpu >= 0 then
  if params.backend ~= 'clnn' then
    function torch.CudaTensor:add_dummy()
      return torch.add_dummy(self)
    end
  else
    function torch.ClTensor:add_dummy()
      return torch.add_dummy(self)
    end
  end
end


function deepcopy(orig)
  local orig_type = type(orig)
  local copy
  if orig_type == 'table' then
    copy = {}
    for orig_key, orig_value in next, orig, nil do
      copy[deepcopy(orig_key)] = deepcopy(orig_value)
    end
    setmetatable(copy, deepcopy(getmetatable(orig)))
  else -- number, string, boolean, etc.
    copy = orig
  end
  return copy
end

function build_filename(output_image, iteration)
  local ext = paths.extname(output_image)
  local basename = paths.basename(output_image, ext)
  local directory = paths.dirname(output_image)
  return string.format('%s/%s_%d.%s', directory, basename, iteration, ext)
end


-- Preprocess an image before passing it to a Caffe model.
-- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
-- and subtract the mean pixel.
function preprocess(img)
  local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
  local perm = torch.LongTensor{3, 2, 1}
  img = img:index(1, perm):mul(256.0)
  mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
  img:add(-1, mean_pixel)
  return img
end


-- Undo the above preprocessing.
function deprocess(img)
  local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
  mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
  img = img + mean_pixel
  local perm = torch.LongTensor{3, 2, 1}
  img = img:index(1, perm):div(256.0)
  return img
end

function maybe_print(t, loss, style_losses)
  local verbose = (params.print_iter > 0 and t % params.print_iter == 0)
  if verbose then
    print(string.format('Iteration %d / %d', t, cur_num_iterations))
    for i, loss_module in ipairs(style_losses) do
      print(string.format('  Style %d loss: %f', i, loss_module.loss))
    end
    print(string.format('  Total loss: %f', loss))
  end
end

function maybe_save(t, img)
  local should_save = params.save_iter > 0 and t % params.save_iter == 0
  should_save = should_save or t == cur_num_iterations
  if should_save then
    local disp = deprocess(img:double())
    disp = image.minmax{tensor=disp, min=0, max=1}
    local filename = build_filename(params.output_image, t)
    if t == cur_num_iterations then
      filename = params.output_image
    end
    image.save(filename, disp)
  end
end

local TVLoss, parent = torch.class('nn.TVLoss', 'nn.Module')

function TVLoss:__init(strength)
  parent.__init(self)
  self.strength = strength
  self.x_diff = torch.Tensor()
  self.y_diff = torch.Tensor()
end

function TVLoss:updateOutput(input)
  self.output = input
  return self.output
end

-- TV loss backward pass inspired by kaishengtai/neuralart
function TVLoss:updateGradInput(input, gradOutput)
  self.gradInput:resizeAs(input):zero()
  local C, H, W = input:size(1), input:size(2), input:size(3)
  self.x_diff:resize(3, H - 1, W - 1)
  self.y_diff:resize(3, H - 1, W - 1)
  self.x_diff:copy(input[{{}, {1, -2}, {1, -2}}])
  self.x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}])
  self.y_diff:copy(input[{{}, {1, -2}, {1, -2}}])
  self.y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}])
  self.gradInput[{{}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff)
  self.gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, self.x_diff)
  self.gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, self.y_diff)
  self.gradInput:mul(self.strength)
  self.gradInput:add(gradOutput)
  return self.gradInput
end
--------------------------------------------------------------------------------
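
As a quick illustration (this snippet is not in the repository; the fake image and the `params` stub are invented), `preprocess` and `deprocess` from `src/utils.lua` are inverses of each other up to floating-point error:

```
require 'torch'
require 'image'
params = {gpu = -1, backend = 'nn'}  -- utils.lua reads this global when loaded
require 'src/utils'

local img = torch.rand(3, 8, 8):float()  -- a fake RGB image in [0, 1]
-- RGB -> BGR, [0,1] -> [0,255], subtract mean; then undo all three steps.
local restored = deprocess(preprocess(img:clone()):double())
print((restored - img:double()):abs():max())  -- ~0
```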

/fast_neural_doodle.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'image'
require 'optim'
require 'hdf5'
require 'loadcaffe'
require 'src/style_loss'
require 'src/content_loss'

local cmd = torch.CmdLine()

-- Basic options
cmd:option('-gpu', 0, 'Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1')
cmd:option('-masks_hdf5', 'masks.hdf5',
           'Path to .hdf5 file with masks. It can be obtained with get_mask_hdf5.py.')

-- Optimization options
cmd:option('-content_weight', 1e-3)
cmd:option('-style_weight', 1e0)
cmd:option('-tv_weight', 0, 'TV weight, zero works fine for me.')
cmd:option('-normalize_gradients', false)
cmd:option('-optimizer', 'lbfgs', 'lbfgs|adam')
cmd:option('-learning_rate', 1e1)

cmd:option('-num_iterations', '1000',
           'Comma-separated (no spaces) list of iteration counts for the corresponding resolutions.')
cmd:option('-resolutions', '0', 'Comma-separated (no spaces) list of resolutions. 0 for original.')

-- Output options
cmd:option('-print_iter', 50)
cmd:option('-save_iter', 50)
cmd:option('-output_image', 'out.png')

-- Other options
cmd:option('-style_scale', 1.0)
cmd:option('-pooling', 'avg', 'max|avg')
cmd:option('-proto_file', 'data/pretrained/VGG_ILSVRC_19_layers_deploy.prototxt')
cmd:option('-model_file', 'data/pretrained/VGG_ILSVRC_19_layers.caffemodel')
cmd:option('-backend', 'nn', 'nn|cudnn|clnn')
cmd:option('-cudnn_autotune', false)
cmd:option('-seed', -1)

cmd:option('-vgg_no_pad', false, 'Because of border effects it is advised to set padding to `valid`. This flag does it.')
cmd:option('-content_layers', 'relu4_2', 'Layers for content.')
cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'Layers for style.')
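
-- Example invocations (see the README; the -cudnn_autotune line follows its Misc note):
--   th fast_neural_doodle.lua -masks_hdf5 masks.hdf5
--   th fast_neural_doodle.lua -masks_hdf5 masks.hdf5 -backend cudnn -cudnn_autotune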

local function main()
  init = true

  -- Load images
  local f_data = hdf5.open(params.masks_hdf5)
  local style_img = f_data:read('style_img'):all():float()
  if cur_resolution ~= 0 then
    style_img = image.scale(style_img, cur_resolution, cur_resolution)
  end
  style_img = preprocess(style_img):float()

  local has_content = f_data:read('has_content'):all()[1] == 1
  local content_img = nil
  if has_content then
    content_img = f_data:read('content_img'):all():float()
    if cur_resolution ~= 0 then
      content_img = image.scale(content_img, cur_resolution, cur_resolution)
    end
    content_img = preprocess(content_img):float()
  else
    print('Content image is not provided, content loss will be ignored')
    params.content_weight = 0
  end

  if params.gpu >= 0 then
    if params.backend ~= 'clnn' then
      style_img = style_img:cuda()
      if has_content then
        content_img = content_img:cuda()
      end
    else
      style_img = style_img:cl()
      if has_content then
        content_img = content_img:cl()
      end
    end
  end

  local n_colors = f_data:read('n_colors'):all()[1]

  -- Load masks
  local style_masks, target_masks = {}, {}
  for k = 0, n_colors - 1 do
    local style_mask = f_data:read('style_mask_' .. k):all():float()
    local target_mask = f_data:read('target_mask_' .. k):all():float()

    -- Scale
    if cur_resolution ~= 0 then
      style_mask = image.scale(style_mask, cur_resolution, cur_resolution, 'simple')
      target_mask = image.scale(target_mask, cur_resolution, cur_resolution, 'simple')
    end
    table.insert(style_masks, style_mask)
    table.insert(target_masks, target_mask)
  end
  local target_size = target_masks[1]:size()
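  -- Note: 'simple' (nearest-neighbor) interpolation is used above so the
  -- rescaled masks stay binary; bilinear scaling would put intermediate
  -- values along region borders.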

  local content_layers = params.content_layers:split(",")
  local style_layers = params.style_layers:split(",")

  -- Set up the network, inserting style and content loss modules
  local content_losses, style_losses = {}, {}
  local next_content_idx, next_style_idx = 1, 1
  local net = nn.Sequential()

  if params.tv_weight > 0 then
    local tv_mod = nn.TVLoss(params.tv_weight):float()
    if params.gpu >= 0 then
      if params.backend ~= 'clnn' then
        tv_mod:cuda()
      else
        tv_mod:cl()
      end
    end
    net:add(tv_mod)
  end
  for i = 1, #cnn do

    if next_style_idx <= #style_layers then
      local layer = cnn:get(i)
      local name = layer.name
      local layer_type = torch.type(layer)
      local is_pooling = (layer_type == 'cudnn.SpatialMaxPooling' or layer_type == 'nn.SpatialMaxPooling')
      local is_conv = (layer_type == 'nn.SpatialConvolution' or layer_type == 'cudnn.SpatialConvolution')

      if is_pooling then

        if params.pooling == 'avg' then
          assert(layer.padW == 0 and layer.padH == 0)
          local kW, kH = layer.kW, layer.kH
          local dW, dH = layer.dW, layer.dH
          local avg_pool_layer = nn.SpatialAveragePooling(kW, kH, dW, dH):float()
          if params.gpu >= 0 then
            if params.backend ~= 'clnn' then
              avg_pool_layer:cuda()
            else
              avg_pool_layer:cl()
            end
          end
          local msg = 'Replacing max pooling at layer %d with average pooling'
          print(string.format(msg, i))

          layer = avg_pool_layer
        end

        layer:floor()
        -- Pooling halves the spatial size, so downscale the masks to match
        for k, _ in ipairs(style_masks) do
          style_masks[k] = image.scale(style_masks[k], math.floor(style_masks[k]:size(2)/2), math.floor(style_masks[k]:size(1)/2))
          target_masks[k] = image.scale(target_masks[k], math.floor(target_masks[k]:size(2)/2), math.floor(target_masks[k]:size(1)/2))
        end

        style_masks = deepcopy(style_masks)
        target_masks = deepcopy(target_masks)

      elseif is_conv then

        -- Smooth the masks with a 3x3 average, mimicking the conv receptive field
        local sap = nn.SpatialAveragePooling(3,3,1,1,1,1):float()
        for k, _ in ipairs(style_masks) do
          style_masks[k] = sap:forward(style_masks[k]:add_dummy())[1]:clone()
          target_masks[k] = sap:forward(target_masks[k]:add_dummy())[1]:clone()
        end

        -- Turn off padding
        if params.vgg_no_pad and (layer_type == 'nn.SpatialConvolution' or layer_type == 'cudnn.SpatialConvolution') then
          layer.padW = 0
          layer.padH = 0

          for k, _ in ipairs(style_masks) do
            style_masks[k] = image.crop(style_masks[k], 'c', style_masks[k]:size(2)-2, style_masks[k]:size(1)-2)
            target_masks[k] = image.crop(target_masks[k], 'c', target_masks[k]:size(2)-2, target_masks[k]:size(1)-2)
          end
          style_masks = deepcopy(style_masks)
          target_masks = deepcopy(target_masks)
        end
      end

      net:add(layer)

      -- Content
      if has_content and name == content_layers[next_content_idx] then
        print("Setting up content layer", i, ":", layer.name)
        local target = net:forward(content_img):clone()
        local norm = params.normalize_gradients
        local loss_module = nn.ContentLoss(params.content_weight, target, norm):float()
        if params.gpu >= 0 then
          if params.backend ~= 'clnn' then
            loss_module:cuda()
          else
            loss_module:cl()
          end
        end
        net:add(loss_module)
        table.insert(content_losses, loss_module)
        next_content_idx = next_content_idx + 1
      end
      -- Style
      if name == style_layers[next_style_idx] then
        print("Setting up style layer ", i, ":", layer.name)
        local gram = GramMatrix():float()
        if params.gpu >= 0 then
          if params.backend ~= 'clnn' then
            gram = gram:cuda()
          else
            gram = gram:cl()
          end
        end

        local target_features = net:forward(style_img):clone()

        -- Compute target Gram matrices
        local target_grams = {}
        for k, _ in ipairs(style_masks) do
          local layer_mask = style_masks[k]:add_dummy():expandAs(target_features)
          if params.gpu >= 0 then
            if params.backend ~= 'clnn' then
              layer_mask = layer_mask:cuda()
            else
              layer_mask = layer_mask:cl()
            end
          end
          local masked = torch.cmul(target_features, layer_mask)

          local target = gram:forward(masked):clone()

          if style_masks[k]:mean() > 0 then
            target:div(target_features:nElement() * style_masks[k]:mean())
          end

          target_grams[k] = target
        end

        local norm = params.normalize_gradients
        local loss_module = nn.StyleLoss(params.style_weight, target_grams, norm, deepcopy(target_masks)):float()
        if params.gpu >= 0 then
          if params.backend ~= 'clnn' then
            loss_module:cuda()
          else
            loss_module:cl()
          end
        end

        net:add(loss_module)
        table.insert(style_losses, loss_module)
        next_style_idx = next_style_idx + 1
      end
    end
  end
  init = false

  -- We don't need the gradient buffers of the base CNN anymore, so clean them
  -- up to save memory.
  for i = 1, #net.modules do
    local module = net.modules[i]
    if torch.type(module) == 'nn.SpatialConvolutionMM' then
      module.gradWeight = nil
      module.gradBias = nil
    end
  end
  collectgarbage()

  -- Initialize with the previous result or with noise
  if img then
    img = image.scale(img:float(), target_size[2], target_size[1])
  else
    if params.seed >= 0 then
      torch.manualSeed(params.seed)
    end
    img = torch.randn(3, target_size[1], target_size[2]):float():mul(0.001)
  end

  if params.gpu >= 0 then
    if params.backend ~= 'clnn' then
      img = img:cuda()
    else
      img = img:cl()
    end
  end

  -- Run it through the network once to get the proper size for the gradient
  -- All the gradients will come from the extra loss modules, so we just pass
  -- zeros into the top of the net on the backward pass.
  local y = net:forward(img)
  local dy = img.new(#y):zero()

  -- Declaring this here lets us access it in maybe_print
  local optim_state = nil
  if params.optimizer == 'lbfgs' then
    optim_state = {
      maxIter = cur_num_iterations,
      tolX = -1,
      tolFun = -1,
      verbose = true,
    }
  elseif params.optimizer == 'adam' then
    optim_state = {
      learningRate = params.learning_rate,
    }
  else
    error(string.format('Unrecognized optimizer "%s"', params.optimizer))
  end

  -- Function to evaluate loss and gradient. We run the net forward and
  -- backward to get the gradient, and sum up losses from the loss modules.
  -- optim.lbfgs internally handles iteration and calls this function many
  -- times, so we manually count the number of iterations to handle printing
  -- and saving intermediate results.
  local num_calls = 0
  local function feval(x)
    num_calls = num_calls + 1
    net:forward(x)
    local grad = net:updateGradInput(x, dy)
    local loss = 0
    for _, mod in ipairs(style_losses) do
      loss = loss + mod.loss
    end
    maybe_print(num_calls, loss, style_losses)
    maybe_save(num_calls, img)

    collectgarbage()
    -- optim.lbfgs expects a vector for gradients
    return loss, grad:view(grad:nElement())
  end

  -- Run optimization.
  if params.optimizer == 'lbfgs' then
    print('Running optimization with L-BFGS')
    local x, losses = optim.lbfgs(feval, img, optim_state)
  elseif params.optimizer == 'adam' then
    print('Running optimization with ADAM')
    for t = 1, cur_num_iterations do
      local x, losses = optim.adam(feval, img, optim_state)
    end
  end
end



-------------------------------------------------------------

params = cmd:parse(arg)

-- Load libs
if params.gpu >= 0 then
  if params.backend ~= 'clnn' then
    require 'cutorch'
    require 'cunn'
    cutorch.setDevice(params.gpu + 1)
  else
    require 'clnn'
    require 'cltorch'
    cltorch.setDevice(params.gpu + 1)
  end
else
  params.backend = 'nn'
end
require 'src/utils'

if params.backend == 'cudnn' then
  require 'cudnn'
  if params.cudnn_autotune then
    cudnn.benchmark = true
  end
end

-- Load VGG
local loadcaffe_backend = params.backend
if params.backend == 'clnn' then loadcaffe_backend = 'nn' end
cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):float()
if params.gpu >= 0 then
  if params.backend ~= 'clnn' then
    cnn:cuda()
  else
    cnn:cl()
  end
end

-- Drop the last 9 modules (the fully-connected classifier part of VGG-19);
-- only the convolutional features are needed.
for i = 1, 9 do
  cnn:remove()
end


-- Run at different resolutions
local resolutions = params.resolutions:split(",")
local num_iterations = params.num_iterations:split(",")
assert(#resolutions == #num_iterations, 'Resolutions and iteration counts must have the same length.')

img = nil
for res = 1, #resolutions do
  cur_resolution = tonumber(resolutions[res])
  cur_num_iterations = tonumber(num_iterations[res])

  main(params)
end
--------------------------------------------------------------------------------