├── .gitignore ├── images ├── all.jpg ├── all11.jpg ├── figure1.jpg └── stylized │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ ├── 8.jpg │ └── 9.jpg ├── .editorconfig ├── experiments ├── images │ ├── 9styles │ │ ├── wave.jpg │ │ ├── candy.jpg │ │ ├── mosaic.jpg │ │ ├── udnie.jpg │ │ ├── feathers.jpg │ │ ├── la_muse.jpg │ │ ├── the_scream.jpg │ │ ├── starry_night.jpg │ │ └── composition_vii.jpg │ └── content │ │ ├── noise.jpg │ │ ├── flowers.jpg │ │ ├── shenyang.jpg │ │ ├── shenyang3.jpg │ │ └── venice-boat.jpg ├── models │ ├── download_models.sh │ └── hang.lua ├── extractGram.lua ├── utils │ ├── preprocess.lua │ ├── DataLoader.lua │ ├── getImages.lua │ └── utils.lua ├── opts.lua ├── test.lua ├── webcam.lua └── main.lua ├── modules ├── utils.lua ├── TotalVariation.lua ├── Calibrate.lua ├── StyleLoss.lua ├── ContentLoss.lua ├── GramMatrix.lua ├── InstanceNormalization.lua ├── Inspiration.lua ├── layer_utils.lua └── PerceptualCriterion.lua ├── texture-scm-1.rockspec ├── init.lua ├── Training.md ├── CMakeLists.txt ├── README.md └── cmake └── select_compute_arch.cmake /.gitignore: -------------------------------------------------------------------------------- 1 | build.luarocks/ 2 | *.DS_Store 3 | *.swp 4 | *.t7 5 | -------------------------------------------------------------------------------- /images/all.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/all.jpg -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = tab 5 | indent_size = 2 6 | -------------------------------------------------------------------------------- /images/all11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/all11.jpg -------------------------------------------------------------------------------- /images/figure1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/figure1.jpg -------------------------------------------------------------------------------- /images/stylized/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/1.jpg -------------------------------------------------------------------------------- /images/stylized/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/2.jpg -------------------------------------------------------------------------------- /images/stylized/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/3.jpg -------------------------------------------------------------------------------- /images/stylized/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/4.jpg -------------------------------------------------------------------------------- /images/stylized/5.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/5.jpg -------------------------------------------------------------------------------- /images/stylized/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/6.jpg -------------------------------------------------------------------------------- /images/stylized/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/7.jpg -------------------------------------------------------------------------------- /images/stylized/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/8.jpg -------------------------------------------------------------------------------- /images/stylized/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/9.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/wave.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/candy.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/mosaic.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/udnie.jpg -------------------------------------------------------------------------------- /experiments/images/content/noise.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/noise.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/feathers.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/la_muse.jpg -------------------------------------------------------------------------------- /experiments/images/content/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/flowers.jpg 
-------------------------------------------------------------------------------- /experiments/images/content/shenyang.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/shenyang.jpg -------------------------------------------------------------------------------- /experiments/images/content/shenyang3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/shenyang3.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/the_scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/the_scream.jpg -------------------------------------------------------------------------------- /experiments/images/content/venice-boat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/venice-boat.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/starry_night.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/composition_vii.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/composition_vii.jpg -------------------------------------------------------------------------------- /experiments/models/download_models.sh: -------------------------------------------------------------------------------- 1 | cd models 2 | wget -O vgg16.t7 "http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7" 3 | wget -O model_9styles.t7 "https://www.dropbox.com/s/50bn53g0ok26aac/model_9styles.t7?dl=1" 4 | cd ../ 5 | -------------------------------------------------------------------------------- /modules/utils.lua: -------------------------------------------------------------------------------- 1 | -- adding first dummy dimension (by Dmitry Ulyanov) 2 | 3 | function torch.FloatTensor:add_dummy() 4 | local sz = self:size() 5 | local new_sz = torch.Tensor(sz:size()+1) 6 | new_sz[1] = 1 7 | new_sz:narrow(1,2,sz:size()):copy(torch.Tensor{sz:totable()}) 8 | 9 | if self:isContiguous() then 10 | return self:view(new_sz:long():storage()) 11 | else 12 | return self:reshape(new_sz:long():storage()) 13 | end 14 | end 15 | 16 | torch.Tensor.add_dummy = torch.FloatTensor.add_dummy 17 | if cutorch then 18 | torch.CudaTensor.add_dummy = torch.FloatTensor.add_dummy 19 | end 20 | -------------------------------------------------------------------------------- /texture-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "texture" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/zhanghang1989/MSG-Net.git", 6 | tag = "master" 7 | } 8 | 9 | description = { 10 | summary = "Texture Master Network", 11 | detailed = [[ 12 | Texture Master Network 13 | ]], 14 | homepage = "https://github.com/zhanghang1989/MSG-Net" 15 | } 16 | 17 | 
dependencies = { 18 | "torch >= 7.0", 19 | "cutorch >= 1.0" 20 | } 21 | 22 | build = { 23 | type = "cmake", 24 | variables = { 25 | CMAKE_BUILD_TYPE="Release", 26 | CMAKE_PREFIX_PATH="$(LUA_BINDIR)/..", 27 | CMAKE_INSTALL_PREFIX="$(PREFIX)" 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | require 'nn' 17 | require 'cunn' 18 | require 'cudnn' 19 | 20 | -- load packages from perceptual-loss 21 | require 'texture.utils' 22 | require 'texture.GramMatrix' 23 | require 'texture.ContentLoss' 24 | require 'texture.StyleLoss' 25 | require 'texture.TotalVariation' 26 | require 'texture.PerceptualCriterion' 27 | require 'texture.InstanceNormalization' 28 | 29 | -- load MSG-Net dependencies 30 | require 'texture.Calibrate' 31 | require 'texture.Inspiration' 32 | -------------------------------------------------------------------------------- /modules/TotalVariation.lua: -------------------------------------------------------------------------------- 1 | local TVLoss, parent = torch.class('nn.TotalVariation', 'nn.Module') 2 | 3 | 4 | function TVLoss:__init(strength) 5 | parent.__init(self) 6 | self.strength = strength 7 | self.x_diff = torch.Tensor() 8 | self.y_diff = torch.Tensor() 9 | end 10 | 11 | 12 | function TVLoss:updateOutput(input) 13 | self.output = input 14 | return self.output 15 | end 16 | 17 | 18 | -- TV loss backward pass inspired by kaishengtai/neuralart 19 | function TVLoss:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(input):zero() 21 | local N, C = input:size(1), input:size(2) 22 | local H, W = input:size(3), input:size(4) 23 | self.x_diff:resize(N, 3, H - 1, W - 1) 24 | self.y_diff:resize(N, 3, H - 1, W - 1) 25 | self.x_diff:copy(input[{{}, {}, {1, -2}, {1, -2}}]) 26 | self.x_diff:add(-1, input[{{}, {}, {1, -2}, {2, -1}}]) 27 | self.y_diff:copy(input[{{}, {}, {1, -2}, {1, -2}}]) 28 | self.y_diff:add(-1, input[{{}, {}, {2, -1}, {1, -2}}]) 29 | self.gradInput[{{}, {}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff) 30 | self.gradInput[{{}, {}, {1, -2}, {2, -1}}]:add(-1, self.x_diff) 31 | self.gradInput[{{}, {}, {2, -1}, {1, -2}}]:add(-1, self.y_diff) 32 | self.gradInput:mul(self.strength) 33 | self.gradInput:add(gradOutput) 34 | return self.gradInput 35 | end 36 | 37 | -------------------------------------------------------------------------------- /modules/Calibrate.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | 
-- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local Calibrate, parent = torch.class('nn.Calibrate', 'nn.Module') 17 | 18 | function Calibrate:__init() 19 | parent.__init(self) 20 | self.gram = torch.Tensor() 21 | self.calibrator = nn.GramMatrix() 22 | end 23 | 24 | function Calibrate:updateOutput(input) 25 | assert(self and input) 26 | self.gram = self.calibrator:forward(input):squeeze() 27 | return input 28 | end 29 | 30 | function Calibrate:updateGradInput(input, gradOutput) 31 | assert(self and gradOutput) 32 | self.gradInput = gradOutput 33 | return self.gradInput 34 | end 35 | 36 | function Calibrate:getGram() 37 | return self.gram:clone() 38 | end 39 | -------------------------------------------------------------------------------- /Training.md: -------------------------------------------------------------------------------- 1 | ### Install Dependencies 2 | - Install the Python dependencies 3 | ```bash 4 | sudo apt-get -y install python2.7-dev 5 | sudo apt-get install libhdf5-dev 6 | cd experiments 7 | pip install --user -r requirements.txt 8 | ``` 9 | - Install deepmind/torch-hdf5, which provides HDF5 bindings for Torch: 10 | ```bash 11 | luarocks install https://raw.githubusercontent.com/deepmind/torch-hdf5/master/hdf5-0-0.rockspec 12 | ``` 13 | ### Download and Prepare the Data 14 | - Download the dataset 15 | ```bash 16 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip 17 | wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip 18 | unzip train2014.zip 19 | unzip val2014.zip 20 | ``` 21 | - Create the HDF5 file 22 | ```bash 23 | python make_style_dataset.py \ 24 | --train_dir train2014 \ 25 | --val_dir val2014 \ 26 | --output_file data.h5 27 | ``` 28 | - Download the VGG-16 Torch model 29 | ```bash 30 | bash models/download_models.sh 31 | ``` 32 | 33 | ### Train the Model 34 | - Train a multi-style model. The command below trains on the nine styles in ``images/9styles``; for customized styles, point ``-style_image_folder`` at a folder containing your style images. 35 | ```bash 36 | th main.lua \ 37 | -h5_file data.h5 \ 38 | -style_image_folder images/9styles \ 39 | -style_image_size 512 \ 40 | -checkpoint_name 9styles \ 41 | -gpu 0 42 | ``` 43 | - Test the model 44 | ```bash 45 | th test.lua \ 46 | -premodel 9styles.t7 \ 47 | -input_image images/content/venice-boat.jpg 48 | ``` 49 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## Free to reuse and distribute this software for research or 8 | ## non-profit purpose, subject to the following conditions: 9 | ## 1. The code must retain the above copyright notice, this list of 10 | ## conditions. 11 | ## 2. Original authors' names are not deleted. 12 | ## 3.
The authors' names are not used to endorse or promote products 13 | ## derived from this software 14 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) 17 | CMAKE_POLICY(VERSION 2.8) 18 | 19 | # Find Torch and CUDA 20 | FIND_PACKAGE(Torch REQUIRED) 21 | FIND_PACKAGE(CUDA 6.5 REQUIRED) 22 | 23 | # Detect CUDA architecture and get best NVCC flags 24 | IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS) 25 | INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake) 26 | ENDIF() 27 | CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST}) 28 | LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) 29 | 30 | INCLUDE_DIRECTORIES( 31 | ./include 32 | ${CMAKE_CURRENT_SOURCE_DIR} 33 | "${Torch_INSTALL_INCLUDE}/THC" 34 | ) 35 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}") 36 | 37 | # ADD lua source files 38 | FILE(GLOB src *.c *.cu) 39 | FILE(GLOB luasrc *.lua modules/*.lua) 40 | 41 | # ADD the torch package and link denpendencies 42 | ADD_TORCH_PACKAGE(texture "" "${luasrc}") 43 | TARGET_LINK_LIBRARIES(texture 44 | ) 45 | -------------------------------------------------------------------------------- /modules/StyleLoss.lua: -------------------------------------------------------------------------------- 1 | local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module') 2 | 3 | 4 | function StyleLoss:__init(strength) 5 | parent.__init(self) 6 | self.strength = strength or 1.0 7 | self.loss = 0 8 | self.target = torch.Tensor() 9 | 10 | self.agg = nn.GramMatrix() 11 | self.agg_out = nil 12 | self.mode = 'none' 13 | self.crit = nn.MSECriterion() 14 | end 15 | 16 | 17 | function StyleLoss:updateOutput(input) 18 | self.agg_out = self.agg:forward(input) 19 | if self.mode == 'capture' then 20 | self.target:resizeAs(self.agg_out):copy(self.agg_out) 21 | elseif self.mode == 'loss' then 22 | local target = self.target 23 | if self.agg_out:size(1) > 1 and self.target:size(1) == 1 then 24 | -- Handle minibatch inputs 25 | target = target:expandAs(self.agg_out) 26 | end 27 | self.loss = self.strength * self.crit(self.agg_out, target) 28 | self._target = target 29 | end 30 | self.output = input 31 | return self.output 32 | end 33 | 34 | 35 | function StyleLoss:updateGradInput(input, gradOutput) 36 | if self.mode == 'capture' or self.mode == 'none' then 37 | self.gradInput = gradOutput 38 | elseif self.mode == 'loss' then 39 | self.crit:backward(self.agg_out, self._target) 40 | self.crit.gradInput:mul(self.strength) 41 | self.agg:backward(input, self.crit.gradInput) 42 | self.gradInput:add(self.agg.gradInput, gradOutput) 43 | end 44 | return self.gradInput 45 | end 46 | 47 | 48 | function StyleLoss:setMode(mode) 49 | if mode ~= 'capture' and mode ~= 'loss' and mode ~= 'none' then 50 | error(string.format('Invalid mode "%s"', mode)) 51 | end 52 | self.mode = mode 53 | end 54 | -------------------------------------------------------------------------------- /modules/ContentLoss.lua: -------------------------------------------------------------------------------- 1 | local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module') 2 | 3 | 4 | --[[ 5 | Module to compute content loss in-place. 6 | 7 | The module can be in one of three modes: "none", "capture", or "loss", which 8 | behave as follows: 9 | - "none": This module does nothing; it is basically nn.Identity(). 10 | - "capture": On the forward pass, inputs are captured as targets; otherwise it 11 | is the same as an nn.Identity(). 
12 | - "loss": On the forward pass, compute the distance between input and 13 | self.target, store the result in self.loss, and return input. On the backward 14 | pass, add compute the gradient of self.loss with respect to the inputs, and 15 | add this value to the upstream gradOutput to produce gradInput. 16 | --]] 17 | 18 | function ContentLoss:__init(strength) 19 | parent.__init(self) 20 | self.strength = strength or 1.0 21 | self.loss = 0 22 | self.target = torch.Tensor() 23 | 24 | self.mode = 'none' 25 | self.crit = nn.MSECriterion() 26 | 27 | end 28 | 29 | 30 | function ContentLoss:updateOutput(input) 31 | if self.mode == 'capture' then 32 | self.target:resizeAs(input):copy(input) 33 | elseif self.mode == 'loss' then 34 | self.loss = self.strength * self.crit:forward(input, self.target) 35 | end 36 | self.output = input 37 | return self.output 38 | end 39 | 40 | 41 | function ContentLoss:updateGradInput(input, gradOutput) 42 | if self.mode == 'capture' or self.mode == 'none' then 43 | self.gradInput = gradOutput 44 | elseif self.mode == 'loss' then 45 | self.gradInput = self.crit:backward(input, self.target) 46 | self.gradInput:mul(self.strength) 47 | self.gradInput:add(gradOutput) 48 | end 49 | return self.gradInput 50 | end 51 | 52 | 53 | function ContentLoss:setMode(mode) 54 | if mode ~= 'capture' and mode ~= 'loss' and mode ~= 'none' then 55 | error(string.format('Invalid mode "%s"', mode)) 56 | end 57 | self.mode = mode 58 | end 59 | -------------------------------------------------------------------------------- /experiments/extractGram.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | require 'texture' 17 | require 'image' 18 | require 'optim' 19 | 20 | require 'utils.DataLoader' 21 | local utils = require 'utils.utils' 22 | local preprocess = require 'utils.preprocess' 23 | local opts = require 'opts' 24 | local imgLoader = require 'utils.getImages' 25 | 26 | local M = {} 27 | 28 | function M.exec(opt) 29 | local styleLoader = imgLoader(opt.style_image_folder) 30 | if not preprocess[opt.preprocessing] then 31 | local msg = 'invalid -preprocessing "%s"; must be "vgg" or "resnet"' 32 | error(string.format(msg, opt.preprocessing)) 33 | end 34 | preprocess = preprocess[opt.preprocessing] 35 | 36 | models = require('models/' .. opt.model) 37 | local cnet = models.createCNets(opt) 38 | 39 | cnet:cuda() 40 | cnet:evaluate() 41 | 42 | local feat = {} 43 | for i = 1,styleLoader:size() do 44 | feat[i] = {} 45 | local style_image = styleLoader:get(i) 46 | style_image = image.scale(style_image, opt.style_image_size) 47 | style_image = preprocess.preprocess(style_image:add_dummy()) 48 | feat[i] = cnet:calibrate(style_image:cuda()) 49 | end 50 | 51 | local filename = opt.style_image_folder .. 
'/feat.t7' 52 | torch.save(filename, feat) 53 | print('Feats have been written to ', filename) 54 | end 55 | 56 | return M 57 | -------------------------------------------------------------------------------- /experiments/utils/preprocess.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | 3 | 4 | local M = {} 5 | 6 | 7 | local function check_input(img) 8 | assert(img:dim() == 4, 'img must be N x C x H x W') 9 | assert(img:size(2) == 3, 'img must have three channels') 10 | end 11 | 12 | 13 | M.resnet = {} 14 | 15 | local resnet_mean = {0.485, 0.456, 0.406} 16 | local resnet_std = {0.229, 0.224, 0.225} 17 | 18 | 19 | --[[ 20 | Preprocess an image before passing to a ResNet model. The preprocessing is easy: 21 | we just need to subtract the mean and divide by the standard deviation. These 22 | constants are taken from fb.resnet.torch: 23 | 24 | https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua 25 | 26 | Input: 27 | - img: Tensor of shape (N, C, H, W) giving a batch of images. Images are RGB 28 | in the range [0, 1]. 29 | ]] 30 | function M.resnet.preprocess(img) 31 | check_input(img) 32 | local mean = img.new(resnet_mean):view(1, 3, 1, 1):expandAs(img) 33 | local std = img.new(resnet_std):view(1, 3, 1, 1):expandAs(img) 34 | return (img - mean):cdiv(std) 35 | end 36 | 37 | -- Undo ResNet preprocessing. 38 | function M.resnet.deprocess(img) 39 | check_input(img) 40 | local mean = img.new(resnet_mean):view(1, 3, 1, 1):expandAs(img) 41 | local std = img.new(resnet_std):view(1, 3, 1, 1):expandAs(img) 42 | return torch.cmul(img, std):add(mean) 43 | end 44 | 45 | 46 | M.vgg = {} 47 | 48 | local vgg_mean = {103.939, 116.779, 123.68} 49 | 50 | --[[ 51 | Preprocess an image before passing to a VGG model. We need to rescale from 52 | [0, 1] to [0, 255], convert from RGB to BGR, and subtract the mean. 53 | 54 | Input: 55 | - img: Tensor of shape (N, C, H, W) giving a batch of images. 
Images 56 | ]] 57 | function M.vgg.preprocess(img) 58 | check_input(img) 59 | local mean = img.new(vgg_mean):view(1, 3, 1, 1):expandAs(img) 60 | local perm = torch.LongTensor{3, 2, 1} 61 | return img:index(2, perm):mul(255):add(-1, mean) 62 | end 63 | 64 | 65 | -- Undo VGG preprocessing 66 | function M.vgg.deprocess(img) 67 | check_input(img) 68 | local mean = img.new(vgg_mean):view(1, 3, 1, 1):expandAs(img) 69 | local perm = torch.LongTensor{3, 2, 1} 70 | --img = img + mean 71 | return (img + mean):div(255):index(2, perm) 72 | end 73 | 74 | 75 | return M 76 | -------------------------------------------------------------------------------- /experiments/utils/DataLoader.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'hdf5' 3 | 4 | local utils = require 'utils.utils' 5 | local preprocess = require 'utils.preprocess' 6 | 7 | local DataLoader = torch.class('DataLoader') 8 | 9 | 10 | function DataLoader:__init(opt) 11 | assert(opt.h5_file, 'Must provide h5_file') 12 | assert(opt.batch_size, 'Must provide batch size') 13 | self.preprocess_fn = preprocess[opt.preprocessing].preprocess 14 | 15 | self.h5_file = hdf5.open(opt.h5_file, 'r') 16 | self.batch_size = opt.batch_size 17 | 18 | self.split_idxs = { 19 | train = 1, 20 | val = 1, 21 | } 22 | 23 | self.image_paths = { 24 | train = '/train2014/images', 25 | val = '/val2014/images', 26 | } 27 | 28 | local train_size = self.h5_file:read(self.image_paths.train):dataspaceSize() 29 | self.split_sizes = { 30 | train = train_size[1], 31 | val = self.h5_file:read(self.image_paths.val):dataspaceSize()[1], 32 | } 33 | self.num_channels = train_size[2] 34 | self.image_height = train_size[3] 35 | self.image_width = train_size[4] 36 | 37 | self.num_minibatches = {} 38 | for k, v in pairs(self.split_sizes) do 39 | self.num_minibatches[k] = math.floor(v / self.batch_size) 40 | end 41 | 42 | if opt.max_train and opt.max_train > 0 then 43 | self.split_sizes.train = opt.max_train 44 | end 45 | 46 | self.rgb_to_gray = torch.FloatTensor{0.2989, 0.5870, 0.1140} 47 | end 48 | 49 | 50 | function DataLoader:reset(split) 51 | self.split_idxs[split] = 1 52 | end 53 | 54 | 55 | function DataLoader:getBatch(split) 56 | local path = self.image_paths[split] 57 | 58 | local start_idx = self.split_idxs[split] 59 | local end_idx = math.min(start_idx + self.batch_size - 1, 60 | self.split_sizes[split]) 61 | 62 | -- Load images out of the HDF5 file 63 | local images = self.h5_file:read(path):partial( 64 | {start_idx, end_idx}, 65 | {1, self.num_channels}, 66 | {1, self.image_height}, 67 | {1, self.image_width}):float():div(255) 68 | 69 | -- Advance counters, maybe rolling back to the start 70 | self.split_idxs[split] = end_idx + 1 71 | if self.split_idxs[split] > self.split_sizes[split] then 72 | self.split_idxs[split] = 1 73 | end 74 | 75 | -- Preprocess images 76 | images_pre = self.preprocess_fn(images) 77 | 78 | return images_pre, images_pre 79 | end 80 | 81 | -------------------------------------------------------------------------------- /modules/GramMatrix.lua: -------------------------------------------------------------------------------- 1 | local Gram, parent = torch.class('nn.GramMatrix', 'nn.Module') 2 | 3 | 4 | --[[ 5 | A layer to compute the Gram Matrix of inputs. 6 | 7 | Input: 8 | - features: A tensor of shape (N, C, H, W) or (C, H, W) giving features for 9 | either a single image or a minibatch of images. 
10 | 11 | Output: 12 | - gram: A tensor of shape (N, C, C) or (C, C) giving Gram matrix for input. 13 | --]] 14 | 15 | 16 | function Gram:__init(normalize) 17 | parent.__init(self) 18 | if normalize ~= nil then 19 | assert(type(normalize) == 'boolean', 'normalize has to be true/false') 20 | self.normalize = normalize 21 | else 22 | self.normalize = true 23 | end 24 | self.buffer = torch.Tensor() 25 | end 26 | 27 | 28 | function Gram:updateOutput(input) 29 | local C, H, W 30 | if input:dim() == 3 then 31 | C, H, W = input:size(1), input:size(2), input:size(3) 32 | local x_flat = input:view(C, H * W) 33 | self.output:resize(C, C) 34 | self.output:mm(x_flat, x_flat:t()) 35 | elseif input:dim() == 4 then 36 | local N = input:size(1) 37 | C, H, W = input:size(2), input:size(3), input:size(4) 38 | self.output:resize(N, C, C) 39 | local x_flat = input:view(N, C, H * W) 40 | self.output:resize(N, C, C) 41 | self.output:bmm(x_flat, x_flat:transpose(2, 3)) 42 | end 43 | if self.normalize then 44 | -- print('in gram forward; dividing by ', C * H * W) 45 | self.output:div(C * H * W) 46 | end 47 | return self.output 48 | end 49 | 50 | 51 | function Gram:updateGradInput(input, gradOutput) 52 | self.gradInput:resizeAs(input):zero() 53 | local C, H, W 54 | if input:dim() == 3 then 55 | C, H, W = input:size(1), input:size(2), input:size(3) 56 | local x_flat = input:view(C, H * W) 57 | self.buffer:resize(C, H * W) 58 | self.buffer:mm(gradOutput, x_flat) 59 | self.buffer:addmm(gradOutput:t(), x_flat) 60 | self.gradInput = self.buffer:view(C, H, W) 61 | elseif input:dim() == 4 then 62 | local N = input:size(1) 63 | C, H, W = input:size(2), input:size(3), input:size(4) 64 | local x_flat = input:view(N, C, H * W) 65 | self.buffer:resize(N, C, H * W) 66 | self.buffer:bmm(gradOutput, x_flat) 67 | self.buffer:baddbmm(gradOutput:transpose(2, 3), x_flat) 68 | self.gradInput = self.buffer:view(N, C, H, W) 69 | end 70 | if self.normalize then 71 | self.buffer:div(C * H * W) 72 | end 73 | assert(self.gradInput:isContiguous()) 74 | return self.gradInput 75 | end 76 | 77 | -------------------------------------------------------------------------------- /modules/InstanceNormalization.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Implements instance normalization as described in the paper 5 | 6 | Instance Normalization: The Missing Ingredient for Fast Stylization 7 | Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky 8 | https://arxiv.org/abs/1607.08022 9 | 10 | This implementation is based on 11 | https://github.com/DmitryUlyanov/texture_nets 12 | ]] 13 | 14 | local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 15 | 'nn.Module') 16 | 17 | 18 | function InstanceNormalization:__init(nOutput, eps) 19 | parent.__init(self) 20 | 21 | self.eps = eps or 1e-5 22 | 23 | self.nOutput = nOutput 24 | self.prev_N = -1 25 | 26 | self.weight = torch.Tensor(nOutput):uniform() 27 | self.bias = torch.Tensor(nOutput):zero() 28 | self.gradWeight = torch.Tensor(nOutput) 29 | self.gradBias = torch.Tensor(nOutput) 30 | end 31 | 32 | 33 | function InstanceNormalization:updateOutput(input) 34 | local N, C = input:size(1), input:size(2) 35 | local H, W = input:size(3), input:size(4) 36 | assert(C == self.nOutput) 37 | 38 | if N ~= self.prev_N or (self.bn and self:type() ~= self.bn:type()) then 39 | self.bn = nn.SpatialBatchNormalization(N * C, self.eps) 40 | self.bn:type(self:type()) 41 | self.prev_N = N 42 | end 43 | 44 | -- Set params for BN 45 | 
self.bn.weight:repeatTensor(self.weight, N) 46 | self.bn.bias:repeatTensor(self.bias, N) 47 | 48 | local input_view = input:view(1, N * C, H, W) 49 | self.bn:training() 50 | self.output = self.bn:forward(input_view):viewAs(input) 51 | 52 | return self.output 53 | end 54 | 55 | 56 | function InstanceNormalization:updateGradInput(input, gradOutput) 57 | local N, C = input:size(1), input:size(2) 58 | local H, W = input:size(3), input:size(4) 59 | assert(self.bn) 60 | 61 | local input_view = input:view(1, N * C, H, W) 62 | local gradOutput_view = gradOutput:view(1, N * C, H, W) 63 | 64 | self.bn.gradWeight:zero() 65 | self.bn.gradBias:zero() 66 | 67 | self.bn:training() 68 | self.gradInput = self.bn:backward(input_view, gradOutput_view):viewAs(input) 69 | 70 | self.gradWeight:add(self.bn.gradWeight:view(N, C):sum(1)) 71 | self.gradBias:add(self.bn.gradBias:view(N, C):sum(1)) 72 | return self.gradInput 73 | end 74 | 75 | 76 | function InstanceNormalization:clearState() 77 | self.output = self.output.new() 78 | self.gradInput = self.gradInput.new() 79 | self.bn:clearState() 80 | end 81 | 82 | -------------------------------------------------------------------------------- /experiments/opts.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. 
The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local M={} 17 | 18 | function M.parse(arg) 19 | local cmd = torch.CmdLine() 20 | 21 | cmd:text() 22 | cmd:text('Options:') 23 | 24 | -- Generic options 25 | cmd:option('-model', 'hang') 26 | cmd:option('-model_nres', '9') 27 | cmd:option('-use_instance_norm', 1) 28 | cmd:option('-h5_file', 'data.h5') 29 | cmd:option('-padding_type', 'reflect-start') 30 | cmd:option('-tanh_constant', 150) 31 | cmd:option('-preprocessing', 'vgg') 32 | cmd:option('-resume', '') 33 | 34 | -- Style loss function options 35 | --cmd:option('-percep_loss_weight', 1.0) 36 | cmd:option('-tv_strength', 1e-6) 37 | 38 | -- Options for feature reconstruction loss 39 | cmd:option('-content_weights', '1.0') 40 | cmd:option('-content_layers', '16') 41 | cmd:option('-loss_network', 'models/vgg16.t7') 42 | 43 | -- Options for style reconstruction loss 44 | cmd:option('-style_image_folder', 'images/9styles') 45 | cmd:option('-style_image_size', 512) 46 | cmd:option('-style_iter', 20) 47 | cmd:option('-style_weights', '5.0') 48 | cmd:option('-style_layers', '4,9,16,23') 49 | 50 | -- Optimization 51 | cmd:option('-num_iterations', 80000) 52 | cmd:option('-max_train', -1) 53 | cmd:option('-batch_size', 4) 54 | cmd:option('-learning_rate', 1e-3) 55 | cmd:option('-lr_decay_every', -1) 56 | cmd:option('-lr_decay_factor', 0.5) 57 | cmd:option('-weight_decay', 0) 58 | 59 | -- Checkpointing 60 | cmd:option('-checkpoint_name', 'checkpoint') 61 | cmd:option('-checkpoint_every', 1000) 62 | cmd:option('-num_val_batches', 10) 63 | 64 | -- Backend options 65 | cmd:option('-gpu', 0) 66 | cmd:option('-use_cudnn', 1) 67 | cmd:option('-backend', 'cuda', 'cuda|opencl') 68 | 69 | -- Test options 70 | cmd:option('-premodel', 'models/model_9styles.t7') 71 | cmd:option('-input_image', 'images/content/venice-boat.jpg') 72 | cmd:option('-output_dir', 'stylized') 73 | cmd:option('-image_size', 512) 74 | 75 | -- Webcam options 76 | cmd:option('-webcam_idx', 0) 77 | cmd:option('-webcam_fps', 30) 78 | cmd:option('-height', 480) 79 | cmd:option('-width', 640) 80 | 81 | local opt = cmd:parse(arg or {}) 82 | return opt 83 | end 84 | 85 | return M 86 | -------------------------------------------------------------------------------- /experiments/test.lua: -------------------------------------------------------------------------------- 1 | require 'texture' 2 | require 'image' 3 | 4 | local utils = require 'utils.utils' 5 | local preprocess = require 'utils.preprocess' 6 | local imgLoader = require 'utils.getImages' 7 | local opts = require 'opts' 8 | 9 | local function main() 10 | local opt = opts.parse(arg) 11 | opt.style_layers, opt.style_weights = utils.parse_layers(opt.style_layers, 12 | opt.style_weights) 13 | if (opt.input_image == '') then 14 | error('Must give exactly one of -input_image') 15 | end 16 | 17 | local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1) 18 | local ok, checkpoint = pcall(function() return torch.load(opt.premodel) end) 19 | if not ok then 20 | print('ERROR: Could not load model from ' .. 
opt.premodel) 21 | print('You may need to download the pretrained models by running') 22 | return 23 | end 24 | local model = checkpoint.model 25 | model:evaluate() 26 | model:type(dtype) 27 | if use_cudnn then 28 | cudnn.convert(model, cudnn) 29 | if opt.cudnn_benchmark == 0 then 30 | cudnn.benchmark = false 31 | cudnn.fastest = true 32 | end 33 | end 34 | 35 | local preprocess_method = checkpoint.opt.preprocessing or 'vgg' 36 | local preprocess = preprocess[preprocess_method] 37 | local styleLoader = imgLoader(opt.style_image_folder) 38 | 39 | local featpath = opt.style_image_folder .. '/feat.t7' 40 | if not paths.filep(featpath) then 41 | local extractor = require "extractGram" 42 | extractor.exec(opt) 43 | end 44 | 45 | local feat = torch.load(featpath) 46 | feat = nn.utils.recursiveType(feat, 'torch.CudaTensor') 47 | 48 | local function run_image(in_path, out_dir, styleLoader, feat) 49 | if not path.isdir(out_dir) then 50 | paths.mkdir(out_dir) 51 | end 52 | local img = image.load(in_path, 3) 53 | if opt.image_size > 0 then 54 | img = image.scale(img, opt.image_size) 55 | end 56 | local H, W = img:size(2), img:size(3) 57 | local img_pre = preprocess.preprocess(img:view(1, 3, H, W)):type(dtype) 58 | 59 | for i=1,styleLoader:size() do 60 | local style_image = styleLoader:get(i) 61 | model:setTarget(feat[i], dtype) 62 | 63 | local img_out = model:forward(img_pre) 64 | local img_out = preprocess.deprocess(img_out)[1] 65 | 66 | local out_path = paths.concat(opt.output_dir, i) .. '.jpg' 67 | local out_path_style = paths.concat(opt.output_dir, i) .. 'style.jpg' 68 | print('Writing output image to ' .. out_path) 69 | image.save(out_path, img_out) 70 | image.save(out_path_style, style_image) 71 | collectgarbage() 72 | collectgarbage() 73 | end 74 | end 75 | 76 | if opt.input_image ~= '' then 77 | if opt.output_image == '' then 78 | error('Must give -output_image with -input_image') 79 | end 80 | run_image(opt.input_image, opt.output_dir, styleLoader, feat) 81 | else 82 | error('Must provide input image') 83 | end 84 | end 85 | 86 | main() 87 | -------------------------------------------------------------------------------- /experiments/utils/getImages.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | require 'paths' 19 | require 'image' 20 | 21 | local M={} 22 | local Dataset = torch.class('texture.Dataset', M) 23 | 24 | function Dataset:_findImages(dir) 25 | local imagePath = torch.CharTensor() 26 | 27 | ---------------------------------------------------------------------- 28 | -- Options for the GNU and BSD find command 29 | local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 30 | local findOptions = ' -iname "*.' .. extensionList[1] .. 
'"' 31 | for i=2,#extensionList do 32 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 33 | end 34 | 35 | -- Find all the images using the find command 36 | local f = io.popen('find -L ' .. dir .. findOptions) 37 | 38 | local maxLength = -1 39 | local imagePaths = {} 40 | 41 | -- Generate a list of all the images 42 | while true do 43 | local line = f:read('*line') 44 | if not line then break end 45 | 46 | local filename = paths.basename(line) 47 | local path = dir .. '/' .. filename 48 | 49 | table.insert(imagePaths, path) 50 | 51 | maxLength = math.max(maxLength, #path + 1) 52 | end 53 | 54 | f:close() 55 | 56 | -- Convert the generated list to a tensor for faster loading 57 | local nImages = #imagePaths 58 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 59 | for i, path in ipairs(imagePaths) do 60 | ffi.copy(imagePath[i]:data(), path) 61 | end 62 | 63 | return imagePath 64 | end 65 | 66 | function Dataset:__init(imgDir) 67 | assert(self) 68 | assert(paths.dirp(imgDir), 'image directory not found: ' .. imgDir) 69 | self.imagePath = {self:_findImages(imgDir)} 70 | end 71 | 72 | function Dataset:_loadImage(path) 73 | local ok, input = pcall(function() 74 | return image.load(path, 3, 'float') 75 | end) 76 | 77 | -- Sometimes image.load fails because the file extension does not match the 78 | -- image format. In that case, use image.decompress on a ByteTensor. 79 | if not ok then 80 | local f = io.open(path, 'r') 81 | assert(f, 'Error reading: ' .. tostring(path)) 82 | local data = f:read('*a') 83 | f:close() 84 | 85 | local b = torch.ByteTensor(string.len(data)) 86 | ffi.copy(b:data(), data, b:size(1)) 87 | 88 | input = image.decompress(b, 3, 'float') 89 | end 90 | 91 | return input 92 | end 93 | 94 | function Dataset:get(i) 95 | assert(self and self.imagePath) 96 | i = (i-1) % self:size() + 1 97 | local path = ffi.string(self.imagePath[1][i]:data()) 98 | local image = self:_loadImage(path) 99 | return image 100 | end 101 | 102 | function Dataset:size() 103 | assert(self and self.imagePath) 104 | return self.imagePath[1]:size(1) 105 | end 106 | 107 | return M.Dataset 108 | -------------------------------------------------------------------------------- /experiments/webcam.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'texture' 3 | require 'image' 4 | require 'camera' 5 | 6 | require 'qt' 7 | require 'qttorch' 8 | require 'qtwidget' 9 | 10 | local utils = require 'utils.utils' 11 | local preprocess = require 'utils.preprocess' 12 | local imgLoader = require 'utils.getImages' 13 | local opts = require 'opts' 14 | 15 | local function main() 16 | local opt = opts.parse(arg) 17 | 18 | if (opt.input_image == '') then 19 | error('Must give exactly one of -input_image') 20 | end 21 | 22 | local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1) 23 | local ok, checkpoint = pcall(function() return torch.load(opt.premodel) end) 24 | if not ok then 25 | print('ERROR: Could not load model from ' .. 
opt.premodel) 26 | print('You may need to download the pretrained models by running') 27 | return 28 | end 29 | local model = checkpoint.model 30 | model:evaluate() 31 | model:type(dtype) 32 | if use_cudnn then 33 | cudnn.convert(model, cudnn) 34 | if opt.cudnn_benchmark == 0 then 35 | cudnn.benchmark = false 36 | cudnn.fastest = true 37 | end 38 | end 39 | 40 | local preprocess_method = checkpoint.opt.preprocessing or 'vgg' 41 | local preprocess = preprocess[preprocess_method] 42 | local styleLoader = imgLoader(opt.style_image_folder) 43 | 44 | local featpath = opt.style_image_folder .. '/feat.t7' 45 | if not paths.filep(featpath) then 46 | local extractor = require "extractGram" 47 | extractor.exec(opt) 48 | end 49 | 50 | local feat = torch.load(featpath) 51 | feat = nn.utils.recursiveType(feat, 'torch.CudaTensor') 52 | 53 | local style_image = nil 54 | local function run_image(img, feat, idx) 55 | if opt.image_size > 0 then 56 | img = image.scale(img, opt.image_size) 57 | end 58 | local H, W = img:size(2), img:size(3) 59 | local img_pre = preprocess.preprocess(img:view(1, 3, H, W)):type(dtype) 60 | 61 | -- update style image 62 | if (idx-1) % 15 == 0 then 63 | local i=torch.floor((idx-1)/15)%styleLoader:size()+1 64 | style_image = styleLoader:get(i):float() 65 | model:setTarget(feat[i], dtype) 66 | end 67 | 68 | local img_out = model:forward(img_pre) 69 | local styleSize = torch.floor(opt.image_size / 4) 70 | 71 | style_image = image.scale(style_image,styleSize,styleSize) 72 | img_out = preprocess.deprocess(img_out:float())[1] 73 | img = img:float() 74 | img:sub(1,3,21,20+styleSize,21,20+styleSize):copy(style_image) 75 | img_out = torch.cat(img,img_out,3) 76 | 77 | collectgarbage() 78 | collectgarbage() 79 | return img_out 80 | end 81 | 82 | local camera_opt = { 83 | idx = opt.webcam_idx, 84 | fps = opt.webcam_fps, 85 | height = opt.height, 86 | width = opt.width, 87 | } 88 | local cam = image.Camera(camera_opt) 89 | 90 | local win = nil 91 | local idx = 1 92 | 93 | while true do 94 | -- Grab a frame from the webcam 95 | local img = cam:forward() 96 | 97 | -- Run the model 98 | local img_disp = run_image(img, feat, idx) 99 | 100 | idx = idx + 1 101 | if not win then 102 | -- On the first call use image.display to construct a window 103 | win = image.display(img_disp) 104 | else 105 | -- Reuse the same window 106 | win.image = img_disp 107 | local size = win.window.size:totable() 108 | local qt_img = qt.QImage.fromTensor(img_disp) 109 | win.painter:image(0, 0, size.width, size.height, qt_img) 110 | end 111 | end 112 | end 113 | 114 | 115 | main() 116 | 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
|
5 | Multi-style Generative Network for Real-time Transfer [arXiv] [project] 6 | Hang Zhang, Kristin Dana 7 |
8 | @article{zhang2017multistyle,
9 | title={Multi-style Generative Network for Real-time Transfer},
10 | author={Zhang, Hang and Dana, Kristin},
11 | journal={arXiv preprint arXiv:1703.06953},
12 | year={2017}
13 | }
14 |
15 | |
16 | ![]() |
17 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | [[More Example Results](Examples.md)]
56 |
57 | ### Train Your Own Model
58 | Please follow [this tutorial to train a new model](Training.md).
59 |
60 | ### Release Timeline
61 | - [x] 03/20/2017 We have released the [demo video](https://www.youtube.com/watch?v=oy6pWNWBt4Y).
62 | - [x] 03/24/2017 We have released the [arXiv paper](https://arxiv.org/pdf/1703.06953.pdf) and the test code with pre-trained models.
63 | - [x] 04/09/2017 We have released the training code.
64 | - [x] 04/24/2017 Please check out our PyTorch [implementation](https://github.com/zhanghang1989/PyTorch-Style-Transfer).
65 |
66 | ### Acknowledgement
67 | The code benefits from outstanding prior work and their implementations, including:
68 | - [Texture Networks: Feed-forward Synthesis of Textures and Stylized Images](https://arxiv.org/pdf/1603.03417.pdf) by Ulyanov *et al. ICML 2016*. ([code](https://github.com/DmitryUlyanov/texture_nets))
69 | - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/pdf/1603.08155.pdf) by Johnson *et al. ECCV 2016* ([code](https://github.com/jcjohnson/fast-neural-style))
70 | - [Image Style Transfer Using Convolutional Neural Networks](http://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) by Gatys *et al. CVPR 2016* and its torch implementation [code](https://github.com/jcjohnson/neural-style) by Johnson.
71 |
--------------------------------------------------------------------------------
/modules/Inspiration.lua:
--------------------------------------------------------------------------------
1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | -- Created by: Hang Zhang
3 | -- ECE Department, Rutgers University
4 | -- Email: zhang.hang@rutgers.edu
5 | -- Copyright (c) 2017
6 | --
7 | -- Free to reuse and distribute this software for research or
8 | -- non-profit purpose, subject to the following conditions:
9 | -- 1. The code must retain the above copyright notice, this list of
10 | -- conditions.
11 | -- 2. Original authors' names are not deleted.
12 | -- 3. The authors' names are not used to endorse or promote products
13 | -- derived from this software
14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
15 |
16 | local Inspiration, parent = torch.class('nn.Inspiration', 'nn.Module')
17 |
18 | local function isint(x)
19 | return type(x) == 'number' and x == math.floor(x)
20 | end
21 |
22 | function Inspiration:__init(C)
23 | parent.__init(self)
24 | assert(self and C, 'should specify C')
25 | assert(isint(C), 'C should be an integer')
26 |
27 | self.C = C
28 | self.MM_WG = nn.MM()
29 | self.MM_PX = nn.MM(true, false)
30 | self.target = torch.Tensor(C,C)
31 | self.Weight = torch.Tensor(C,C)
32 | self.P = torch.Tensor(C,C)
33 |
34 | self.gradWeight = torch.Tensor(C, C)
35 | self.gradWG = {torch.Tensor(C, C), torch.Tensor(C, C)}
36 | self.gradPX = {torch.Tensor(), torch.Tensor()}
37 | self.gradInput = torch.Tensor()
38 | self:reset()
39 | end
40 |
41 | function Inspiration:reset(stdv)
42 | if stdv then
43 | stdv = stdv * math.sqrt(2)
44 | else
45 | stdv = 1./math.sqrt(self.C)
46 | end
47 | self.Weight:uniform(-stdv, stdv)
48 | self.target:uniform(-stdv, stdv)
49 | return self
50 | end
51 |
52 | function Inspiration:setTarget(nT)
53 | assert(self and nT)
54 | self.target = nT
55 | end
56 |
57 | function Inspiration:updateOutput(input)
58 | assert(self)
59 | -- Compute P = W * G (learnable W times target Gram G), then Y = P^T * X (nn.MM(true, false) transposes its first input)
60 | --self.output:resizeAs(input)
61 | if input:dim() == 3 then
62 | self.P = self.MM_WG:forward({self.Weight, self.target})
63 | self.output = self.MM_PX:forward({self.P, input:view(self.C,-1)}):viewAs(input)
64 | elseif input:dim() == 4 then
65 | local B = input:size(1)
66 | self.P = self.MM_WG:forward({self.Weight, self.target})
67 | self.output = self.MM_PX:forward({self.P:add_dummy():expand(B,self.C,self.C), input:view(B,self.C,-1)}):viewAs(input)
68 | else
69 | error('Unsupported dimension for Inspiration layer')
70 | end
71 | return self.output
72 | end
73 |
74 | function Inspiration:updateGradInput(input, gradOutput)
75 | assert(self and self.gradInput)
76 |
77 | --self.gradInput:resizeAs(input):fill(0)
78 | if input:dim() == 3 then
79 | self.gradPX = self.MM_PX:backward({self.P, input:view(self.C,-1)}, gradOutput:view(self.C,-1))
80 | elseif input:dim() == 4 then
81 | local B = input:size(1)
82 | self.gradPX = self.MM_PX:backward({self.P:add_dummy():expand(B,self.C,self.C), input:view(B,self.C,-1)}, gradOutput:view(B,self.C,-1))
83 | else
84 | error('Unsupported dimension for Inspiration layer')
85 | end
86 |
87 | self.gradInput = self.gradPX[2]:viewAs(input)
88 | return self.gradInput
89 | end
90 |
91 | function Inspiration:accGradParameters(input, gradOutput, scale)
92 | assert(self)
93 | scale = scale or 1
94 |
95 | if input:dim() == 3 then
96 | self.gradWG = self.MM_WG:backward({self.Weight, self.target}, self.gradPX[1])
97 | self.gradWeight = scale * self.gradWG[1]
98 | elseif input:dim() == 4 then
99 | self.gradWG = self.MM_WG:backward({self.Weight, self.target}, self.gradPX[1]:sum(1):squeeze())
100 | self.gradWeight = scale * self.gradWG[1]
101 | else
102 | error('Unsupported dimension for Inspiration layer')
103 | end
104 | end
105 |
106 | function Inspiration:__tostring__()
107 | return torch.type(self) ..
108 | string.format(
109 | '(%dxHxW -> %dxHxW)',
110 | self.C, self.C
111 | )
112 | end
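   | -- Illustrative usage (an editor's sketch, not part of the original file;
   | -- `gram` stands for a hypothetical C x C target Gram matrix, e.g. obtained
   | -- via nn.Calibrate:getGram() on style features):
   | --   local insp = nn.Inspiration(512)
   | --   insp:setTarget(gram)      -- swap in a new style target at runtime
   | --   local y = insp:forward(x) -- x: 512 x H x W or B x 512 x H x W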
113 |
--------------------------------------------------------------------------------
/experiments/utils/utils.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | local cjson = require 'cjson'
4 |
5 |
6 | local M = {}
7 |
8 |
9 | -- Parse a string of comma-separated numbers
10 | -- For example convert "1.0,3.14" to {1.0, 3.14}
11 | function M.parse_num_list(s)
12 | local nums = {}
13 | for _, ss in ipairs(s:split(',')) do
14 | table.insert(nums, tonumber(ss))
15 | end
16 | return nums
17 | end
18 |
19 |
20 | -- Parse a layer string and associated weights string.
21 | -- The layers string is a string of comma-separated layer strings, and the
22 | -- weight string contains comma-separated numbers. If the weights string
23 | -- contains only a single number it is duplicated to be the same length as the
24 | -- layers.
25 | function M.parse_layers(layers_string, weights_string)
26 | local layers = layers_string:split(',')
27 | local weights = M.parse_num_list(weights_string)
28 | if #weights == 1 and #layers > 1 then
29 | -- Duplicate the same weight for all layers
30 | local w = weights[1]
31 | weights = {}
32 | for i = 1, #layers do
33 | table.insert(weights, w)
34 | end
35 | elseif #weights ~= #layers then
36 | local msg = 'size mismatch between layers "%s" and weights "%s"'
37 | error(string.format(msg, layers_string, weights_string))
38 | end
39 | return layers, weights
40 | end
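   | -- Illustrative call (an editor's sketch): M.parse_layers('4,9,16,23', '5.0')
   | -- returns layers = {'4', '9', '16', '23'} and weights = {5.0, 5.0, 5.0, 5.0};
   | -- a single weight is duplicated to match the number of layers.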
41 |
42 |
43 | function M.setup_gpu(gpu, backend, use_cudnn)
44 | local dtype = 'torch.FloatTensor'
45 | if gpu >= 0 then
46 | if backend == 'cuda' then
47 | require 'cutorch'
48 | require 'cunn'
49 | cutorch.setDevice(gpu + 1)
50 | dtype = 'torch.CudaTensor'
51 | if use_cudnn then
52 | require 'cudnn'
53 | cudnn.benchmark = true
54 | end
55 | elseif backend == 'opencl' then
56 | require 'cltorch'
57 | require 'clnn'
58 | cltorch.setDevice(gpu + 1)
59 | dtype = torch.Tensor():cl():type()
60 | use_cudnn = false
61 | end
62 | else
63 | use_cudnn = false
64 | end
65 | return dtype, use_cudnn
66 | end
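   | -- Illustrative call (an editor's sketch): M.setup_gpu(0, 'cuda', true)
   | -- returns 'torch.CudaTensor', true; with gpu = -1 it falls back to the CPU
   | -- and returns 'torch.FloatTensor', false.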
67 |
68 |
69 | function M.clear_gradients(m)
70 | if torch.isTypeOf(m, nn.Container) then
71 | m:applyToModules(M.clear_gradients)
72 | end
73 | if m.weight and m.gradWeight then
74 | m.gradWeight = m.gradWeight.new()
75 | end
76 | if m.bias and m.gradBias then
77 | m.gradBias = m.gradBias.new()
78 | end
79 | end
80 |
81 |
82 | function M.restore_gradients(m)
83 | if torch.isTypeOf(m, nn.Container) then
84 | m:applyToModules(M.restore_gradients)
85 | end
86 | if m.weight and m.gradWeight then
87 | m.gradWeight = m.gradWeight.new(#m.weight):zero()
88 | end
89 | if m.bias and m.gradBias then
90 | m.gradBias = m.gradBias.new(#m.bias):zero()
91 | end
92 | end
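   | -- Editor's note: clear_gradients replaces gradWeight/gradBias with empty
   | -- tensors and restore_gradients reallocates them zero-filled at the
   | -- parameter shapes; clearing is presumably done before torch.save to keep
   | -- checkpoints small.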
93 |
94 |
95 | function M.read_json(path)
96 | local file = io.open(path, 'r')
97 | local text = file:read('*a') -- read the whole file, not just the first line
98 | file:close()
99 | local info = cjson.decode(text)
100 | return info
101 | end
102 |
103 |
104 | function M.write_json(path, j)
105 | cjson.encode_sparse_array(true, 2, 10)
106 | local text = cjson.encode(j)
107 | local file = io.open(path, 'w')
108 | file:write(text)
109 | file:close()
110 | end
111 |
112 | local IMAGE_EXTS = {'jpg', 'jpeg', 'png', 'ppm', 'pgm'}
113 | function M.is_image_file(filename)
114 | -- Hidden files are not images
115 | if string.sub(filename, 1, 1) == '.' then
116 | return false
117 | end
118 | -- Check against a list of known image extensions
119 | local ext = string.lower(paths.extname(filename) or "")
120 | for _, image_ext in ipairs(IMAGE_EXTS) do
121 | if ext == image_ext then
122 | return true
123 | end
124 | end
125 | return false
126 | end
127 |
128 |
129 | function M.median_filter(img, r)
130 | local u = img:unfold(2, r, 1):contiguous()
131 | u = u:unfold(3, r, 1):contiguous()
132 | local HH, WW = u:size(2), u:size(3)
133 | local dtype = u:type()
134 | -- Median is not defined for CudaTensors, cast to float and back
135 | local med = u:view(3, HH, WW, r * r):float():median():type(dtype)
136 | return med[{{}, {}, {}, 1}]
137 | end
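-- Usage sketch: a 3x3 median filter over a 3 x H x W image. The output is
-- (H - r + 1) x (W - r + 1) because the unfolding above uses no padding:
--   local smoothed = M.median_filter(img, 3)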
138 |
139 |
140 | return M
141 |
142 |
--------------------------------------------------------------------------------
/modules/layer_utils.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 |
3 | --[[
4 | Utility functions for getting and inserting layers into models composed of
5 | hierarchies of nn Modules and nn Containers. In such a model, we can uniquely
6 | address each module with a unique "layer string", which is a series of integers
7 | separated by dashes. This is easiest to understand with an example: consider
8 | the following network; we have labeled each module with its layer string:
9 |
10 | nn.Sequential {
11 | (1) nn.SpatialConvolution
12 | (2) nn.Sequential {
13 | (2-1) nn.SpatialConvolution
14 | (2-2) nn.SpatialConvolution
15 | }
16 | (3) nn.Sequential {
17 | (3-1) nn.SpatialConvolution
18 | (3-2) nn.Sequential {
19 | (3-2-1) nn.SpatialConvolution
20 | (3-2-2) nn.SpatialConvolution
21 | (3-2-3) nn.SpatialConvolution
22 | }
23 | (3-3) nn.SpatialConvolution
24 | }
25 | (4) nn.View
26 | (5) nn.Linear
27 | }
28 |
29 | Any layers that have the instance variable _ignore set to true are ignored
30 | when computing layer strings. This way, we can insert new layers into a
31 | network without changing the layer strings of existing layers.
32 | --]]
33 | local M = {}
34 |
35 |
36 | --[[
37 | Convert a layer string to an array of integers.
38 |
39 | For example, layer_string_to_nums("1-23-4") = {1, 23, 4}.
40 | --]]
41 | function M.layer_string_to_nums(layer_string)
42 | local nums = {}
43 | for _, s in ipairs(layer_string:split('-')) do
44 | table.insert(nums, tonumber(s))
45 | end
46 | return nums
47 | end
48 |
49 |
50 | --[[
51 | Comparison function for layer strings that is compatible with table.sort.
52 | In this comparison scheme, 2-3 comes AFTER 2-3-X for all X.
53 |
54 | Input:
55 | - s1, s2: Two layer strings.
56 |
57 | Output:
58 | - true if s1 should come before s2 in sorted order; false otherwise.
59 | --]]
60 | function M.compare_layer_strings(s1, s2)
61 | local left = M.layer_string_to_nums(s1)
62 | local right = M.layer_string_to_nums(s2)
63 | local out = nil
64 | for i = 1, math.min(#left, #right) do
65 | if left[i] < right[i] then
66 | out = true
67 | elseif left[i] > right[i] then
68 | out = false
69 | end
70 | if out ~= nil then break end
71 | end
72 |
73 | if out == nil then
74 | out = (#left > #right)
75 | end
76 | return out
77 | end
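-- Usage sketch: with this comparator, deeper addresses sort before their
-- own prefix (illustrative values):
--   local strs = {'2-3', '2-3-1', '1', '2'}
--   table.sort(strs, M.compare_layer_strings)
--   -- strs == {'1', '2-3-1', '2-3', '2'}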
78 |
79 |
80 | --[[
81 | Get a layer from the network net using a layer string.
82 | --]]
83 | function M.get_layer(net, layer_string)
84 | local nums = M.layer_string_to_nums(layer_string)
85 | local layer = net
86 | for i, num in ipairs(nums) do
87 | local count = 0
88 | for j = 1, #layer do
89 | if not layer:get(j)._ignore then
90 | count = count + 1
91 | end
92 | if count == num then
93 | layer = layer:get(j)
94 | break
95 | end
96 | end
97 | end
98 | return layer
99 | end
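-- Usage sketch, addressing the example network in the header comment:
--   local conv = M.get_layer(net, '3-2-1')
--   -- conv is the first nn.SpatialConvolution inside the nested
--   -- nn.Sequential at address 3-2.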
100 |
101 |
102 | -- Insert a new layer immediately after the layer specified by a layer string.
103 | -- Layers inserted this way are flagged with _ignore so existing layer strings stay valid.
104 | function M.insert_after(net, layer_string, new_layer)
105 | new_layer._ignore = true
106 | local nums = M.layer_string_to_nums(layer_string)
107 | local container = net
108 | for i = 1, #nums do
109 | local count = 0
110 | for j = 1, #container do
111 | if not container:get(j)._ignore then
112 | count = count + 1
113 | end
114 | if count == nums[i] then
115 | if i < #nums then
116 | container = container:get(j)
117 | break
118 | elseif i == #nums then
119 | container:insert(new_layer, j + 1)
120 | return
121 | end
122 | end
123 | end
124 | end
125 | end
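-- Usage sketch: insert a loss module after layer 2-2; because the new module
-- is flagged with _ignore, the address 2-2 still refers to the original layer:
--   M.insert_after(net, '2-2', nn.ContentLoss(1.0))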
126 |
127 |
128 | -- Remove every layer of the network that occurs after the last inserted (_ignore) layer
129 | function M.trim_network(net)
130 | local function contains_ignore(layer)
131 | if torch.isTypeOf(layer, nn.Container) then
132 | local found = false
133 | for i = 1, layer:size() do
134 | found = found or contains_ignore(layer:get(i))
135 | end
136 | return found
137 | else
138 | return layer._ignore == true
139 | end
140 | end
141 | local last_layer = 0
142 | for i = 1, #net do
143 | if contains_ignore(net:get(i)) then
144 | last_layer = i
145 | end
146 | end
147 | local num_to_remove = #net - last_layer
148 | for i = 1, num_to_remove do
149 | net:remove()
150 | end
151 | return net
152 | end
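-- Usage sketch: after inserting loss modules, trim the tail of the network so
-- the forward pass stops at the last layer that is actually needed:
--   M.trim_network(net)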
153 |
154 |
155 | return M
156 |
157 |
--------------------------------------------------------------------------------
/modules/PerceptualCriterion.lua:
--------------------------------------------------------------------------------
1 | local layer_utils = require 'texture.layer_utils'
2 |
3 | local crit, parent = torch.class('nn.PerceptualCriterion', 'nn.Criterion')
4 |
5 | --[[
6 | Input: args is a table with the following keys:
7 | - cnn: A network giving the base CNN.
8 | - content_layers: An array of layer strings
9 | - content_weights: A list of the same length as content_layers
10 | - style_layers: An array of layer strings
11 | - style_weights: A list of the same length as style_layers
12 |
13 | - deepdream_layers: Array of layer strings
14 | - deepdream_weights: List of the same length as deepdream_layers
15 | --]]
16 | function crit:__init(args)
17 | args.content_layers = args.content_layers or {}
18 | args.style_layers = args.style_layers or {}
19 | args.deepdream_layers = args.deepdream_layers or {}
20 |
21 | self.net = args.cnn
22 | self.net:evaluate()
23 | self.content_loss_layers = {}
24 | self.style_loss_layers = {}
25 | self.deepdream_loss_layers = {}
26 |
27 | -- Set up content loss layers
28 | for i, layer_string in ipairs(args.content_layers) do
29 | local weight = args.content_weights[i]
30 | local content_loss_layer = nn.ContentLoss(weight)
31 | layer_utils.insert_after(self.net, layer_string, content_loss_layer)
32 | table.insert(self.content_loss_layers, content_loss_layer)
33 | end
34 |
35 | -- Set up style loss layers
36 | for i, layer_string in ipairs(args.style_layers) do
37 | local weight = args.style_weights[i]
38 | local style_loss_layer = nn.StyleLoss(weight)
39 | layer_utils.insert_after(self.net, layer_string, style_loss_layer)
40 | table.insert(self.style_loss_layers, style_loss_layer)
41 | end
42 |
43 | layer_utils.trim_network(self.net)
44 | self.grad_net_output = torch.Tensor()
45 |
46 | end
47 |
48 |
49 | --[[
50 | target: Tensor of shape (1, 3, H, W) giving pixels for style target image
51 | --]]
52 | function crit:setStyleTarget(target)
53 | for i, content_loss_layer in ipairs(self.content_loss_layers) do
54 | content_loss_layer:setMode('none')
55 | end
56 | for i, style_loss_layer in ipairs(self.style_loss_layers) do
57 | style_loss_layer:setMode('capture')
58 | end
59 | self.net:forward(target)
60 | end
61 |
62 |
63 | --[[
64 | target: Tensor of shape (N, 3, H, W) giving pixels for content target images
65 | --]]
66 | function crit:setContentTarget(target)
67 | for i, style_loss_layer in ipairs(self.style_loss_layers) do
68 | style_loss_layer:setMode('none')
69 | end
70 | for i, content_loss_layer in ipairs(self.content_loss_layers) do
71 | content_loss_layer:setMode('capture')
72 | end
73 | self.net:forward(target)
74 | end
75 |
76 |
77 | function crit:setStyleWeight(weight)
78 | for i, style_loss_layer in ipairs(self.style_loss_layers) do
79 | style_loss_layer.strength = weight
80 | end
81 | end
82 |
83 |
84 | function crit:setContentWeight(weight)
85 | for i, content_loss_layer in ipairs(self.content_loss_layers) do
86 | content_loss_layer.strength = weight
87 | end
88 | end
89 |
90 |
91 | --[[
92 | Inputs:
93 | - input: Tensor of shape (N, 3, H, W) giving pixels for generated images
94 | - target: Table with the following keys:
95 | - content_target: Tensor of shape (N, 3, H, W)
96 | - style_target: Tensor of shape (1, 3, H, W)
97 | --]]
98 | function crit:updateOutput(input, target)
99 | if target.content_target then
100 | self:setContentTarget(target.content_target)
101 | end
102 | if target.style_target then
103 | self:setStyleTarget(target.style_target)
104 | end
105 |
106 | -- Make sure to set all content and style loss layers to loss mode before
107 | -- running the image forward.
108 | for i, content_loss_layer in ipairs(self.content_loss_layers) do
109 | content_loss_layer:setMode('loss')
110 | end
111 | for i, style_loss_layer in ipairs(self.style_loss_layers) do
112 | style_loss_layer:setMode('loss')
113 | end
114 |
115 | local output = self.net:forward(input)
116 |
117 | -- Set up a tensor of zeros to pass as gradient to net in backward pass
118 | self.grad_net_output:resizeAs(output):zero()
119 |
120 | -- Go through and add up losses
121 | self.total_content_loss = 0
122 | self.content_losses = {}
123 | self.total_style_loss = 0
124 | self.style_losses = {}
125 | for i, content_loss_layer in ipairs(self.content_loss_layers) do
126 | self.total_content_loss = self.total_content_loss + content_loss_layer.loss
127 | table.insert(self.content_losses, content_loss_layer.loss)
128 | end
129 | for i, style_loss_layer in ipairs(self.style_loss_layers) do
130 | self.total_style_loss = self.total_style_loss + style_loss_layer.loss
131 | table.insert(self.style_losses, style_loss_layer.loss)
132 | end
133 |
134 | self.output = self.total_style_loss + self.total_content_loss
135 | return self.output
136 | end
137 |
138 |
139 | function crit:updateGradInput(input, target)
140 | self.gradInput = self.net:updateGradInput(input, self.grad_net_output)
141 | return self.gradInput
142 | end
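-- Usage sketch (hedged): 'vgg16.t7' and the layer strings are placeholders,
-- not defaults shipped with this repo.
--   local crit = nn.PerceptualCriterion{
--     cnn = torch.load('vgg16.t7'),
--     content_layers = {'14'}, content_weights = {1.0},
--     style_layers = {'4', '9'}, style_weights = {5.0, 5.0},
--   }
--   crit:setStyleTarget(style_img)   -- (1, 3, H, W)
--   local loss = crit:forward(gen, {content_target = content})
--   local grad = crit:backward(gen, {content_target = content})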
143 |
144 |
--------------------------------------------------------------------------------
/experiments/models/hang.lua:
--------------------------------------------------------------------------------
1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | -- Created by: Hang Zhang
3 | -- ECE Department, Rutgers University
4 | -- Email: zhang.hang@rutgers.edu
5 | -- Copyright (c) 2017
6 | --
7 | -- Free to reuse and distribute this software for research or
8 | -- non-profit purpose, subject to the following conditions:
9 | -- 1. The code must retain the above copyright notice, this list of
10 | -- conditions.
11 | -- 2. Original authors' names are not deleted.
12 | -- 3. The authors' names are not used to endorse or promote products
13 | -- derived from this software
14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
15 |
16 | require 'texture'
17 |
18 | local pad = nn.SpatialReflectionPadding
19 | local normalization = nn.InstanceNormalization
20 | local layer_utils = require 'texture.layer_utils'
21 |
22 | local M = {}
23 |
24 | function M.createModel(opt)
25 | -- Global variable keeping track of the input channels
26 | local iChannels
27 |
28 | -- The shortcut layer is either identity or 1x1 convolution
29 | local function shortcut(nInputPlane, nOutputPlane, stride)
30 | if nInputPlane ~= nOutputPlane then
31 | return nn.Sequential()
32 | :add(nn.SpatialConvolution(nInputPlane, nOutputPlane, 1, 1, stride, stride))
33 | else
34 | return nn.Identity()
35 | end
36 | end
37 |
38 | local function full_shortcut(nInputPlane, nOutputPlane, stride)
39 | if nInputPlane ~= nOutputPlane or stride ~= 1 then
40 | return nn.Sequential()
41 | --:add(pad(1,0,1,0))
42 | :add(nn.SpatialUpSamplingNearest(stride))
43 | :add(nn.SpatialConvolution(nInputPlane, nOutputPlane, 1, 1, 1, 1))
44 | --:add(nn.SpatialFullConvolution(nInputPlane, nOutputPlane, 1, 1, stride, stride, 1, 1, 1, 1))
45 | else
46 | return nn.Identity()
47 | end
48 | end
49 |
50 | local function basic_block(n, stride)
51 | stride = stride or 1
52 | local nInputPlane = iChannels
53 | iChannels = n
54 | -- Convolutions
55 | local conv_block = nn.Sequential()
56 |
57 | conv_block:add(normalization(nInputPlane))
58 | conv_block:add(nn.ReLU(true))
59 | conv_block:add(pad(1, 1, 1, 1))
60 | conv_block:add(nn.SpatialConvolution(nInputPlane, n, 3, 3, stride, stride, 0, 0))
61 |
62 | conv_block:add(normalization(n))
63 | conv_block:add(nn.ReLU(true))
64 | conv_block:add(pad(1, 1, 1, 1))
65 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, 1, 1, 0, 0))
66 |
67 | local concat = nn.ConcatTable():add(conv_block):add(shortcut(nInputPlane, n, stride))
68 |
69 | -- Sum
70 | local res_block = nn.Sequential()
71 | res_block:add(concat)
72 | res_block:add(nn.CAddTable())
73 | return res_block
74 | end
75 |
76 | local function bottleneck(n, stride)
77 | stride = stride or 1
78 | local nInputPlane = iChannels
79 | iChannels = 4 * n
80 | -- Convolutions
81 | local conv_block = nn.Sequential()
82 |
83 | conv_block:add(normalization(nInputPlane))
84 | conv_block:add(nn.ReLU(true))
85 | conv_block:add(nn.SpatialConvolution(nInputPlane, n, 1, 1, 1, 1, 0, 0))
86 |
87 | conv_block:add(normalization(n))
88 | conv_block:add(nn.ReLU(true))
89 | conv_block:add(pad(1, 1, 1, 1))
90 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, stride, stride, 0, 0))
91 |
92 | conv_block:add(normalization(n))
93 | conv_block:add(nn.ReLU(true))
94 | conv_block:add(nn.SpatialConvolution(n, n*4, 1, 1, 1, 1, 0, 0))
95 |
96 | local concat = nn.ConcatTable():add(conv_block):add(shortcut(nInputPlane, n*4, stride))
97 |
98 | -- Sum
99 | local res_block = nn.Sequential()
100 | res_block:add(concat)
101 | res_block:add(nn.CAddTable())
102 | return res_block
103 | end
104 |
105 | local function full_bottleneck(n, stride)
106 | stride = stride or 1
107 | local nInputPlane = iChannels
108 | iChannels = 4 * n
109 | -- Convolutions
110 | local conv_block = nn.Sequential()
111 |
112 | conv_block:add(normalization(nInputPlane))
113 | conv_block:add(nn.ReLU(true))
114 | conv_block:add(nn.SpatialConvolution(nInputPlane, n, 1, 1, 1, 1, 0, 0))
115 |
116 | conv_block:add(normalization(n))
117 | conv_block:add(nn.ReLU(true))
118 |
119 | if stride ~= 1 then
120 | conv_block:add(nn.SpatialUpSamplingNearest(stride))
121 | conv_block:add(pad(1, 1, 1, 1))
122 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, 1, 1, 0, 0))
123 | else
124 | conv_block:add(pad(1, 1, 1, 1))
125 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, 1, 1, 0, 0))
126 | end
127 | conv_block:add(normalization(n))
128 | conv_block:add(nn.ReLU(true))
129 | conv_block:add(nn.SpatialConvolution(n, n*4, 1, 1, 1, 1, 0, 0))
130 |
131 | local concat = nn.ConcatTable()
132 | :add(conv_block)
133 | :add(full_shortcut(nInputPlane, n*4, stride))
134 |
135 | -- Sum
136 | local res_block = nn.Sequential()
137 | res_block:add(concat)
138 | res_block:add(nn.CAddTable())
139 | return res_block
140 | end
141 |
142 | local function layer(block, features, count, stride)
143 | local s = nn.Sequential()
144 | for i=1,count do
145 | s:add(block(features, i==1 and stride or 1))
146 | end
147 | return s
148 | end
149 |
150 | local model = nn.Sequential()
151 | model.cNetsNum = {}
152 |
153 | -- 256x256
154 | model:add(normalization(3))
155 | model:add(pad(3, 3, 3, 3))
156 | model:add(nn.SpatialConvolution(3, 64, 7, 7, 1, 1, 0, 0))
157 | model:add(normalization(64))
158 | model:add(nn.ReLU(true))
159 |
160 | iChannels = 64
161 | local block = bottleneck -- basic_block
162 |
163 | model:add(layer(block, 32, 1, 2))
164 | model:add(layer(block, 64, 1, 2))
165 |
166 | -- 64x64x256 (for a 256x256 input, after two stride-2 bottlenecks)
167 | model:add(nn.Inspiration(iChannels))
168 | table.insert(model.cNetsNum,#model)
169 | model:add(normalization(iChannels))
170 | model:add(nn.ReLU(true))
171 |
172 | for i = 1,opt.model_nres do
173 | model:add(layer(block, 64, 1, 1))
174 | end
175 |
176 | block = full_bottleneck
177 | model:add(layer(block, 32, 1, 2))
178 | model:add(layer(block, 16, 1, 2))
179 |
180 | model:add(normalization(64))
181 | model:add(nn.ReLU(true))
182 |
183 | model:add(pad(3, 3, 3, 3))
184 | model:add(nn.SpatialConvolution(64, 3, 7, 7, 1, 1, 0, 0))
185 |
186 | model:add(nn.TotalVariation(opt.tv_strength))
187 |
188 | function model:setTarget(feat, dtype)
189 | model.modules[model.cNetsNum[1]]:setTarget(feat[3]:type(dtype))
190 | end
191 | return model
192 | end
193 |
194 | function M.createCNets(opt)
195 | -- The descriptive network in the paper
196 | local cnet = torch.load(opt.loss_network)
197 | cnet:evaluate()
198 | cnet.style_layers = {}
199 |
200 | -- Set up calibrate layers
201 | for i, layer_string in ipairs(opt.style_layers) do
202 | local calibrator = nn.Calibrate()
203 | layer_utils.insert_after(cnet, layer_string, calibrator)
204 | table.insert(cnet.style_layers, calibrator)
205 | end
206 | layer_utils.trim_network(cnet)
207 |
208 | function cnet:calibrate(input)
209 | cnet:forward(input)
210 | local feat = {}
211 | for i, calibrator in ipairs(cnet.style_layers) do
212 | table.insert(feat, calibrator:getGram())
213 | end
214 | return feat
215 | end
216 |
217 | return cnet
218 | end
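-- Usage sketch (hedged): opt carries the same fields consumed above
-- (opt.loss_network, opt.style_layers).
--   local cnet = M.createCNets(opt)
--   local grams = cnet:calibrate(style_img)  -- one Gram feature per style layer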
219 |
220 | return M
221 |
--------------------------------------------------------------------------------
/experiments/main.lua:
--------------------------------------------------------------------------------
1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | -- Created by: Hang Zhang
3 | -- ECE Department, Rutgers University
4 | -- Email: zhang.hang@rutgers.edu
5 | -- Copyright (c) 2017
6 | --
7 | -- Free to reuse and distribute this software for research or
8 | -- non-profit purpose, subject to the following conditions:
9 | -- 1. The code must retain the above copyright notice, this list of
10 | -- conditions.
11 | -- 2. Original authors' names are not deleted.
12 | -- 3. The authors' names are not used to endorse or promote products
13 | -- derived from this software
14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
15 |
16 | require 'texture'
17 | require 'image'
18 | require 'optim'
19 |
20 | require 'utils.DataLoader'
21 |
22 | local utils = require 'utils.utils'
23 | local preprocess = require 'utils.preprocess'
24 | local opts = require 'opts'
25 | local imgLoader = require 'utils.getImages'
26 |
27 | function main()
28 | local opt = opts.parse(arg)
29 | -- Parse layer strings and weights
30 | opt.content_layers, opt.content_weights =
31 | utils.parse_layers(opt.content_layers, opt.content_weights)
32 | opt.style_layers, opt.style_weights =
33 | utils.parse_layers(opt.style_layers, opt.style_weights)
34 |
35 | -- Figure out preprocessing
36 | if not preprocess[opt.preprocessing] then
37 | local msg = 'invalid -preprocessing "%s"; must be "vgg" or "resnet"'
38 | error(string.format(msg, opt.preprocessing))
39 | end
40 | preprocess = preprocess[opt.preprocessing]
41 |
42 | -- Figure out the backend
43 | local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1)
44 |
45 | -- Style images
46 | local styleLoader = imgLoader(opt.style_image_folder)
47 | local featpath = opt.style_image_folder .. '/feat.t7'
48 | if not paths.filep(featpath) then
49 | local extractor = require "extractGram"
50 | extractor.exec(opt)
51 | end
52 | local feat = torch.load(featpath)
53 | feat = nn.utils.recursiveType(feat, 'torch.CudaTensor')
54 |
55 | -- Build the model
56 | local model = nil
57 |
58 | -- Checkpoint
59 | if opt.resume ~= '' then
60 | print('Loading checkpoint from ' .. opt.resume)
61 | model = torch.load(opt.resume).model:type(dtype)
62 | else
63 | print('Initializing model from scratch')
64 | local models = require('models.' .. opt.model)
65 | model = models.createModel(opt):type(dtype)
66 | end
67 |
68 | if use_cudnn then cudnn.convert(model, cudnn) end
69 | model:training()
70 | print(model)
71 |
72 | -- Set up the perceptual loss function
73 | local percep_crit
74 | local loss_net = torch.load(opt.loss_network)
75 | local crit_args = {
76 | cnn = loss_net,
77 | style_layers = opt.style_layers,
78 | style_weights = opt.style_weights,
79 | content_layers = opt.content_layers,
80 | content_weights = opt.content_weights,
81 | }
82 | percep_crit = nn.PerceptualCriterion(crit_args):type(dtype)
83 |
84 | local loader = DataLoader(opt)
85 | local params, grad_params = model:getParameters()
86 |
87 | local function f(x)
88 | assert(x == params)
89 | grad_params:zero()
90 |
91 | local x, y = loader:getBatch('train')
92 | x, y = x:type(dtype), y:type(dtype)
93 |
94 | -- Run model forward
95 | local out = model:forward(x)
96 | local target = {content_target=y}
97 | local loss = percep_crit:forward(out, target)
98 | local grad_out = percep_crit:backward(out, target)
99 |
100 | -- Run model backward
101 | model:backward(x, grad_out)
102 |
103 | return loss, grad_params
104 | end
105 |
106 | local optim_state = {learningRate=opt.learning_rate}
107 | local train_loss_history = {}
108 | local val_loss_history = {}
109 | local val_loss_history_ts = {}
110 | local style_loss_history = {}
111 |
112 |
113 | for i, k in ipairs(opt.style_layers) do
114 | style_loss_history[string.format('style-%s', k)] = {}
115 | end
116 | for i, k in ipairs(opt.content_layers) do
117 | style_loss_history[string.format('content-%s', k)] = {}
118 | end
119 |
120 | local style_weight = opt.style_weight
121 | for t = 1, opt.num_iterations do
122 | -- Set a new style target every opt.style_iter iterations
123 | if (t-1)%opt.style_iter == 0 then
124 | --print('Setting Style Target')
125 | local idx = (t-1)/opt.style_iter % #feat + 1
126 |
127 | local style_image = styleLoader:get(idx)
128 | style_image = image.scale(style_image, opt.style_image_size)
129 | style_image = preprocess.preprocess(style_image:add_dummy())
130 | percep_crit:setStyleTarget(style_image:type(dtype))
131 |
132 | local style_image_feat = feat[idx]
133 | model:setTarget(style_image_feat, dtype)
134 | end
135 | local epoch = t / loader.num_minibatches['train']
136 |
137 | local _, loss = optim.adam(f, params, optim_state)
138 |
139 | table.insert(train_loss_history, loss[1])
140 |
141 | for i, k in ipairs(opt.style_layers) do
142 | table.insert(style_loss_history[string.format('style-%s', k)],
143 | percep_crit.style_losses[i])
144 | end
145 | for i, k in ipairs(opt.content_layers) do
146 | table.insert(style_loss_history[string.format('content-%s', k)],
147 | percep_crit.content_losses[i])
148 | end
149 |
150 | print(string.format('Epoch %f, Iteration %d / %d, loss = %f',
151 | epoch, t, opt.num_iterations, loss[1]), optim_state.learningRate)
152 |
153 | if t % opt.checkpoint_every == 0 then
154 | -- Check loss on the validation set
155 | loader:reset('val')
156 | model:evaluate()
157 | local val_loss = 0
158 | print 'Running on validation set ... '
159 | local val_batches = opt.num_val_batches
160 | for j = 1, val_batches do
161 | local x, y = loader:getBatch('val')
162 | x, y = x:type(dtype), y:type(dtype)
163 | local out = model:forward(x)
164 | --y = shave_y(x, y, out)
165 |
166 | local percep_loss = 0
167 | percep_loss = percep_crit:forward(out, {content_target=y})
168 | val_loss = val_loss + percep_loss
169 | end
170 | val_loss = val_loss / val_batches
171 | print(string.format('val loss = %f', val_loss))
172 | table.insert(val_loss_history, val_loss)
173 | table.insert(val_loss_history_ts, t)
174 | model:training()
175 |
176 | -- Save a checkpoint
177 | local checkpoint = {
178 | opt=opt,
179 | train_loss_history=train_loss_history,
180 | val_loss_history=val_loss_history,
181 | val_loss_history_ts=val_loss_history_ts,
182 | style_loss_history=style_loss_history,
183 | }
184 | local filename = string.format('%s.json', opt.checkpoint_name)
185 | paths.mkdir(paths.dirname(filename))
186 | utils.write_json(filename, checkpoint)
187 |
188 | -- Save a torch checkpoint; convert the model to float first
189 | model:clearState()
190 | if use_cudnn then
191 | cudnn.convert(model, nn)
192 | end
193 | model:float()
194 | checkpoint.model = model
195 | filename = string.format('%s.t7', opt.checkpoint_name)
196 | torch.save(filename, checkpoint)
197 |
198 | -- Convert the model back
199 | model:type(dtype)
200 | if use_cudnn then
201 | cudnn.convert(model, cudnn)
202 | end
203 | params, grad_params = model:getParameters()
204 |
205 | collectgarbage()
206 | collectgarbage()
207 | end
208 |
209 | if opt.lr_decay_every > 0 and t % opt.lr_decay_every == 0 then
210 | local new_lr = opt.lr_decay_factor * optim_state.learningRate
211 | optim_state = {learningRate = new_lr}
212 | end
213 | end
214 | end
215 |
216 |
217 | main()
218 |
219 |
--------------------------------------------------------------------------------
/cmake/select_compute_arch.cmake:
--------------------------------------------------------------------------------
1 | # Synopsis:
2 | # CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
3 | # -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
4 | # target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
5 | # - "Auto" detects local machine GPU compute arch at runtime.
6 | # - "Common" and "All" cover common and entire subsets of architectures
7 | # ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
8 | # NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
9 | # NUM: Any number. Only those pairs are currently accepted by NVCC though:
10 | # 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2
11 | # Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
12 | # Additionally, sets ${out_variable}_readable to the resulting numeric list
13 | # Example:
14 | # CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
15 | # LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
16 | #
17 | # More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
18 | #
19 |
20 | # This list will be used for CUDA_ARCH_NAME = All option
21 | set(CUDA_KNOWN_GPU_ARCHITECTURES "Fermi" "Kepler" "Maxwell")
22 |
23 | # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
24 | set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0")
25 |
26 | if (CUDA_VERSION VERSION_GREATER "6.5")
27 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
28 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
29 | endif ()
30 |
31 | if (CUDA_VERSION VERSION_GREATER "7.5")
32 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
33 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX")
34 | else()
35 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
36 | endif ()
37 |
38 |
39 |
40 | ################################################################################################
41 | # A function for automatic detection of GPUs installed (if autodetection is enabled)
42 | # Usage:
43 | # CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
44 | #
45 | function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
46 | if(NOT CUDA_GPU_DETECT_OUTPUT)
47 | set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)
48 |
49 | file(WRITE ${cufile} ""
50 | "#include