├── .gitignore ├── images ├── all.jpg ├── all11.jpg ├── figure1.jpg └── stylized │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ ├── 8.jpg │ └── 9.jpg ├── .editorconfig ├── experiments ├── images │ ├── 9styles │ │ ├── wave.jpg │ │ ├── candy.jpg │ │ ├── mosaic.jpg │ │ ├── udnie.jpg │ │ ├── feathers.jpg │ │ ├── la_muse.jpg │ │ ├── the_scream.jpg │ │ ├── starry_night.jpg │ │ └── composition_vii.jpg │ └── content │ │ ├── noise.jpg │ │ ├── flowers.jpg │ │ ├── shenyang.jpg │ │ ├── shenyang3.jpg │ │ └── venice-boat.jpg ├── models │ ├── download_models.sh │ └── hang.lua ├── extractGram.lua ├── utils │ ├── preprocess.lua │ ├── DataLoader.lua │ ├── getImages.lua │ └── utils.lua ├── opts.lua ├── test.lua ├── webcam.lua └── main.lua ├── modules ├── utils.lua ├── TotalVariation.lua ├── Calibrate.lua ├── StyleLoss.lua ├── ContentLoss.lua ├── GramMatrix.lua ├── InstanceNormalization.lua ├── Inspiration.lua ├── layer_utils.lua └── PerceptualCriterion.lua ├── texture-scm-1.rockspec ├── init.lua ├── Training.md ├── CMakeLists.txt ├── README.md └── cmake └── select_compute_arch.cmake /.gitignore: -------------------------------------------------------------------------------- 1 | build.luarocks/ 2 | *.DS_Store 3 | *.swp 4 | *.t7 5 | -------------------------------------------------------------------------------- /images/all.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/all.jpg -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = tab 5 | indent_size = 2 6 | -------------------------------------------------------------------------------- /images/all11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/all11.jpg -------------------------------------------------------------------------------- /images/figure1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/figure1.jpg -------------------------------------------------------------------------------- /images/stylized/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/1.jpg -------------------------------------------------------------------------------- /images/stylized/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/2.jpg -------------------------------------------------------------------------------- /images/stylized/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/3.jpg -------------------------------------------------------------------------------- /images/stylized/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/4.jpg -------------------------------------------------------------------------------- /images/stylized/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/5.jpg -------------------------------------------------------------------------------- /images/stylized/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/6.jpg -------------------------------------------------------------------------------- /images/stylized/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/7.jpg -------------------------------------------------------------------------------- /images/stylized/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/8.jpg -------------------------------------------------------------------------------- /images/stylized/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/images/stylized/9.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/wave.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/candy.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/mosaic.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/udnie.jpg -------------------------------------------------------------------------------- /experiments/images/content/noise.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/noise.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/feathers.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/la_muse.jpg -------------------------------------------------------------------------------- /experiments/images/content/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/flowers.jpg -------------------------------------------------------------------------------- /experiments/images/content/shenyang.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/shenyang.jpg -------------------------------------------------------------------------------- /experiments/images/content/shenyang3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/shenyang3.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/the_scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/the_scream.jpg -------------------------------------------------------------------------------- /experiments/images/content/venice-boat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/content/venice-boat.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/starry_night.jpg -------------------------------------------------------------------------------- /experiments/images/9styles/composition_vii.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StacyYang/MSG-Net/HEAD/experiments/images/9styles/composition_vii.jpg -------------------------------------------------------------------------------- /experiments/models/download_models.sh: -------------------------------------------------------------------------------- 1 | cd models 2 | wget -O vgg16.t7 "http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7" 3 | wget -O model_9styles.t7 "https://www.dropbox.com/s/50bn53g0ok26aac/model_9styles.t7?dl=1" 4 | cd ../ 5 | -------------------------------------------------------------------------------- /modules/utils.lua: -------------------------------------------------------------------------------- 1 | -- adding first dummy dimension (by Dmitry Ulyanov) 2 | 3 | function torch.FloatTensor:add_dummy() 4 | local sz = self:size() 5 | local new_sz = torch.Tensor(sz:size()+1) 6 | new_sz[1] = 1 7 | new_sz:narrow(1,2,sz:size()):copy(torch.Tensor{sz:totable()}) 8 | 9 | if self:isContiguous() then 10 | return self:view(new_sz:long():storage()) 11 | else 12 | return self:reshape(new_sz:long():storage()) 13 | end 14 | end 15 | 16 | torch.Tensor.add_dummy = torch.FloatTensor.add_dummy 17 | if cutorch then 18 | torch.CudaTensor.add_dummy = torch.FloatTensor.add_dummy 19 | end 20 | -------------------------------------------------------------------------------- /texture-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "texture" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/zhanghang1989/MSG-Net.git", 6 | tag = "master" 7 | } 8 | 9 | description = { 10 | summary = "Texture Master Network", 11 | detailed = [[ 12 | Texture Master Network 13 | ]], 14 | homepage = "https://github.com/zhanghang1989/MSG-Net" 15 | } 16 | 17 | dependencies = { 18 | "torch >= 7.0", 19 | "cutorch >= 1.0" 20 | } 21 | 22 | build = { 23 | type = "cmake", 24 | variables = { 25 | CMAKE_BUILD_TYPE="Release", 26 | CMAKE_PREFIX_PATH="$(LUA_BINDIR)/..", 27 | CMAKE_INSTALL_PREFIX="$(PREFIX)" 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | require 'nn' 17 | require 'cunn' 18 | require 'cudnn' 19 | 20 | -- load packages from perceptual-loss 21 | require 'texture.utils' 22 | require 'texture.GramMatrix' 23 | require 'texture.ContentLoss' 24 | require 'texture.StyleLoss' 25 | require 'texture.TotalVariation' 26 | require 'texture.PerceptualCriterion' 27 | require 'texture.InstanceNormalization' 28 | 29 | -- load MSG-Net dependencies 30 | require 'texture.Calibrate' 31 | require 'texture.Inspiration' 32 | -------------------------------------------------------------------------------- /modules/TotalVariation.lua: -------------------------------------------------------------------------------- 1 | local TVLoss, parent = torch.class('nn.TotalVariation', 'nn.Module') 2 | 3 | 4 | function TVLoss:__init(strength) 5 | parent.__init(self) 6 | self.strength = strength 7 | self.x_diff = torch.Tensor() 8 | self.y_diff = torch.Tensor() 9 | end 10 | 11 | 12 | function TVLoss:updateOutput(input) 13 | self.output = input 14 | return self.output 15 | end 16 | 17 | 18 | -- TV loss backward pass inspired by kaishengtai/neuralart 19 | function TVLoss:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(input):zero() 21 | local N, C = input:size(1), input:size(2) 22 | local H, W = input:size(3), input:size(4) 23 | self.x_diff:resize(N, 3, H - 1, W - 1) 24 | self.y_diff:resize(N, 3, H - 1, W - 1) 25 | self.x_diff:copy(input[{{}, {}, {1, -2}, {1, -2}}]) 26 | self.x_diff:add(-1, input[{{}, {}, {1, -2}, {2, -1}}]) 27 | self.y_diff:copy(input[{{}, {}, {1, -2}, {1, -2}}]) 28 | self.y_diff:add(-1, input[{{}, {}, {2, -1}, {1, -2}}]) 29 | self.gradInput[{{}, {}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff) 30 | self.gradInput[{{}, {}, {1, -2}, {2, -1}}]:add(-1, self.x_diff) 31 | self.gradInput[{{}, {}, {2, -1}, {1, -2}}]:add(-1, self.y_diff) 32 | self.gradInput:mul(self.strength) 33 | self.gradInput:add(gradOutput) 34 | return self.gradInput 35 | end 36 | 37 | -------------------------------------------------------------------------------- /modules/Calibrate.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local Calibrate, parent = torch.class('nn.Calibrate', 'nn.Module') 17 | 18 | function Calibrate:__init() 19 | parent.__init(self) 20 | self.gram = torch.Tensor() 21 | self.calibrator = nn.GramMatrix() 22 | end 23 | 24 | function Calibrate:updateOutput(input) 25 | assert(self and input) 26 | self.gram = self.calibrator:forward(input):squeeze() 27 | return input 28 | end 29 | 30 | function Calibrate:updateGradInput(input, gradOutput) 31 | assert(self and gradOutput) 32 | self.gradInput = gradOutput 33 | return self.gradInput 34 | end 35 | 36 | function Calibrate:getGram() 37 | return self.gram:clone() 38 | end 39 | -------------------------------------------------------------------------------- /Training.md: -------------------------------------------------------------------------------- 1 | ### Install Dependencies 2 | - Install python dependencies 3 | ```bash 4 | sudo apt-get -y install python2.7-dev 5 | sudo apt-get install libhdf5-dev 6 | cd experiments 7 | pip install --user -r requirements.txt 8 | ``` 9 | - Install the deepmind/torch-hdf5 which gives HDF5 bindings for Torch: 10 | ```bash 11 | luarocks install https://raw.githubusercontent.com/deepmind/torch-hdf5/master/hdf5-0-0.rockspec 12 | ``` 13 | ### Download and Prepare the Data 14 | - Download the dataset 15 | ```bash 16 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip 17 | wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip 18 | unzip train2014.zip 19 | unzip val2014.zip 20 | ``` 21 | - Create HDF5 file 22 | ```bash 23 | python make_style_dataset.py \ 24 | --train_dir train2014 \ 25 | --val_dir val2014 \ 26 | --output_file data.h5 27 | ``` 28 | - Download the VGG-16 Torch model 29 | ```bash 30 | bash models/download_models.sh 31 | ``` 32 | 33 | ### Train the Model 34 | - Training a 16 style model. For customized styles, set ``style_image_folder`` pointing at the folders containing the style images. 35 | ```bash 36 | th main.lua \ 37 | -h5_file data.h5 \ 38 | -style_image_folder images/9styles \ 39 | -style_image_size 512 \ 40 | -checkpoint_name 9styles \ 41 | -gpu 0 42 | ``` 43 | - Test the model 44 | ```bash 45 | th test.lua \ 46 | -premodel 9styles.t7 \ 47 | -input_image images/content/venice-boat.jpg 48 | ``` 49 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## Free to reuse and distribute this software for research or 8 | ## non-profit purpose, subject to the following conditions: 9 | ## 1. The code must retain the above copyright notice, this list of 10 | ## conditions. 11 | ## 2. Original authors' names are not deleted. 12 | ## 3. The authors' names are not used to endorse or promote products 13 | ## derived from this software 14 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) 17 | CMAKE_POLICY(VERSION 2.8) 18 | 19 | # Find Torch and CUDA 20 | FIND_PACKAGE(Torch REQUIRED) 21 | FIND_PACKAGE(CUDA 6.5 REQUIRED) 22 | 23 | # Detect CUDA architecture and get best NVCC flags 24 | IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS) 25 | INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake) 26 | ENDIF() 27 | CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST}) 28 | LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) 29 | 30 | INCLUDE_DIRECTORIES( 31 | ./include 32 | ${CMAKE_CURRENT_SOURCE_DIR} 33 | "${Torch_INSTALL_INCLUDE}/THC" 34 | ) 35 | LINK_DIRECTORIES("${Torch_INSTALL_LIB}") 36 | 37 | # ADD lua source files 38 | FILE(GLOB src *.c *.cu) 39 | FILE(GLOB luasrc *.lua modules/*.lua) 40 | 41 | # ADD the torch package and link denpendencies 42 | ADD_TORCH_PACKAGE(texture "" "${luasrc}") 43 | TARGET_LINK_LIBRARIES(texture 44 | ) 45 | -------------------------------------------------------------------------------- /modules/StyleLoss.lua: -------------------------------------------------------------------------------- 1 | local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module') 2 | 3 | 4 | function StyleLoss:__init(strength) 5 | parent.__init(self) 6 | self.strength = strength or 1.0 7 | self.loss = 0 8 | self.target = torch.Tensor() 9 | 10 | self.agg = nn.GramMatrix() 11 | self.agg_out = nil 12 | self.mode = 'none' 13 | self.crit = nn.MSECriterion() 14 | end 15 | 16 | 17 | function StyleLoss:updateOutput(input) 18 | self.agg_out = self.agg:forward(input) 19 | if self.mode == 'capture' then 20 | self.target:resizeAs(self.agg_out):copy(self.agg_out) 21 | elseif self.mode == 'loss' then 22 | local target = self.target 23 | if self.agg_out:size(1) > 1 and self.target:size(1) == 1 then 24 | -- Handle minibatch inputs 25 | target = target:expandAs(self.agg_out) 26 | end 27 | self.loss = self.strength * self.crit(self.agg_out, target) 28 | self._target = target 29 | end 30 | self.output = input 31 | return self.output 32 | end 33 | 34 | 35 | function StyleLoss:updateGradInput(input, gradOutput) 36 | if self.mode == 'capture' or self.mode == 'none' then 37 | self.gradInput = gradOutput 38 | elseif self.mode == 'loss' then 39 | self.crit:backward(self.agg_out, self._target) 40 | self.crit.gradInput:mul(self.strength) 41 | self.agg:backward(input, self.crit.gradInput) 42 | self.gradInput:add(self.agg.gradInput, gradOutput) 43 | end 44 | return self.gradInput 45 | end 46 | 47 | 48 | function StyleLoss:setMode(mode) 49 | if mode ~= 'capture' and mode ~= 'loss' and mode ~= 'none' then 50 | error(string.format('Invalid mode "%s"', mode)) 51 | end 52 | self.mode = mode 53 | end 54 | -------------------------------------------------------------------------------- /modules/ContentLoss.lua: -------------------------------------------------------------------------------- 1 | local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module') 2 | 3 | 4 | --[[ 5 | Module to compute content loss in-place. 6 | 7 | The module can be in one of three modes: "none", "capture", or "loss", which 8 | behave as follows: 9 | - "none": This module does nothing; it is basically nn.Identity(). 10 | - "capture": On the forward pass, inputs are captured as targets; otherwise it 11 | is the same as an nn.Identity(). 12 | - "loss": On the forward pass, compute the distance between input and 13 | self.target, store the result in self.loss, and return input. On the backward 14 | pass, add compute the gradient of self.loss with respect to the inputs, and 15 | add this value to the upstream gradOutput to produce gradInput. 16 | --]] 17 | 18 | function ContentLoss:__init(strength) 19 | parent.__init(self) 20 | self.strength = strength or 1.0 21 | self.loss = 0 22 | self.target = torch.Tensor() 23 | 24 | self.mode = 'none' 25 | self.crit = nn.MSECriterion() 26 | 27 | end 28 | 29 | 30 | function ContentLoss:updateOutput(input) 31 | if self.mode == 'capture' then 32 | self.target:resizeAs(input):copy(input) 33 | elseif self.mode == 'loss' then 34 | self.loss = self.strength * self.crit:forward(input, self.target) 35 | end 36 | self.output = input 37 | return self.output 38 | end 39 | 40 | 41 | function ContentLoss:updateGradInput(input, gradOutput) 42 | if self.mode == 'capture' or self.mode == 'none' then 43 | self.gradInput = gradOutput 44 | elseif self.mode == 'loss' then 45 | self.gradInput = self.crit:backward(input, self.target) 46 | self.gradInput:mul(self.strength) 47 | self.gradInput:add(gradOutput) 48 | end 49 | return self.gradInput 50 | end 51 | 52 | 53 | function ContentLoss:setMode(mode) 54 | if mode ~= 'capture' and mode ~= 'loss' and mode ~= 'none' then 55 | error(string.format('Invalid mode "%s"', mode)) 56 | end 57 | self.mode = mode 58 | end 59 | -------------------------------------------------------------------------------- /experiments/extractGram.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | require 'texture' 17 | require 'image' 18 | require 'optim' 19 | 20 | require 'utils.DataLoader' 21 | local utils = require 'utils.utils' 22 | local preprocess = require 'utils.preprocess' 23 | local opts = require 'opts' 24 | local imgLoader = require 'utils.getImages' 25 | 26 | local M = {} 27 | 28 | function M.exec(opt) 29 | local styleLoader = imgLoader(opt.style_image_folder) 30 | if not preprocess[opt.preprocessing] then 31 | local msg = 'invalid -preprocessing "%s"; must be "vgg" or "resnet"' 32 | error(string.format(msg, opt.preprocessing)) 33 | end 34 | preprocess = preprocess[opt.preprocessing] 35 | 36 | models = require('models/' .. opt.model) 37 | local cnet = models.createCNets(opt) 38 | 39 | cnet:cuda() 40 | cnet:evaluate() 41 | 42 | local feat = {} 43 | for i = 1,styleLoader:size() do 44 | feat[i] = {} 45 | local style_image = styleLoader:get(i) 46 | style_image = image.scale(style_image, opt.style_image_size) 47 | style_image = preprocess.preprocess(style_image:add_dummy()) 48 | feat[i] = cnet:calibrate(style_image:cuda()) 49 | end 50 | 51 | local filename = opt.style_image_folder .. '/feat.t7' 52 | torch.save(filename, feat) 53 | print('Feats have been written to ', filename) 54 | end 55 | 56 | return M 57 | -------------------------------------------------------------------------------- /experiments/utils/preprocess.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | 3 | 4 | local M = {} 5 | 6 | 7 | local function check_input(img) 8 | assert(img:dim() == 4, 'img must be N x C x H x W') 9 | assert(img:size(2) == 3, 'img must have three channels') 10 | end 11 | 12 | 13 | M.resnet = {} 14 | 15 | local resnet_mean = {0.485, 0.456, 0.406} 16 | local resnet_std = {0.229, 0.224, 0.225} 17 | 18 | 19 | --[[ 20 | Preprocess an image before passing to a ResNet model. The preprocessing is easy: 21 | we just need to subtract the mean and divide by the standard deviation. These 22 | constants are taken from fb.resnet.torch: 23 | 24 | https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua 25 | 26 | Input: 27 | - img: Tensor of shape (N, C, H, W) giving a batch of images. Images are RGB 28 | in the range [0, 1]. 29 | ]] 30 | function M.resnet.preprocess(img) 31 | check_input(img) 32 | local mean = img.new(resnet_mean):view(1, 3, 1, 1):expandAs(img) 33 | local std = img.new(resnet_std):view(1, 3, 1, 1):expandAs(img) 34 | return (img - mean):cdiv(std) 35 | end 36 | 37 | -- Undo ResNet preprocessing. 38 | function M.resnet.deprocess(img) 39 | check_input(img) 40 | local mean = img.new(resnet_mean):view(1, 3, 1, 1):expandAs(img) 41 | local std = img.new(resnet_std):view(1, 3, 1, 1):expandAs(img) 42 | return torch.cmul(img, std):add(mean) 43 | end 44 | 45 | 46 | M.vgg = {} 47 | 48 | local vgg_mean = {103.939, 116.779, 123.68} 49 | 50 | --[[ 51 | Preprocess an image before passing to a VGG model. We need to rescale from 52 | [0, 1] to [0, 255], convert from RGB to BGR, and subtract the mean. 53 | 54 | Input: 55 | - img: Tensor of shape (N, C, H, W) giving a batch of images. Images 56 | ]] 57 | function M.vgg.preprocess(img) 58 | check_input(img) 59 | local mean = img.new(vgg_mean):view(1, 3, 1, 1):expandAs(img) 60 | local perm = torch.LongTensor{3, 2, 1} 61 | return img:index(2, perm):mul(255):add(-1, mean) 62 | end 63 | 64 | 65 | -- Undo VGG preprocessing 66 | function M.vgg.deprocess(img) 67 | check_input(img) 68 | local mean = img.new(vgg_mean):view(1, 3, 1, 1):expandAs(img) 69 | local perm = torch.LongTensor{3, 2, 1} 70 | --img = img + mean 71 | return (img + mean):div(255):index(2, perm) 72 | end 73 | 74 | 75 | return M 76 | -------------------------------------------------------------------------------- /experiments/utils/DataLoader.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'hdf5' 3 | 4 | local utils = require 'utils.utils' 5 | local preprocess = require 'utils.preprocess' 6 | 7 | local DataLoader = torch.class('DataLoader') 8 | 9 | 10 | function DataLoader:__init(opt) 11 | assert(opt.h5_file, 'Must provide h5_file') 12 | assert(opt.batch_size, 'Must provide batch size') 13 | self.preprocess_fn = preprocess[opt.preprocessing].preprocess 14 | 15 | self.h5_file = hdf5.open(opt.h5_file, 'r') 16 | self.batch_size = opt.batch_size 17 | 18 | self.split_idxs = { 19 | train = 1, 20 | val = 1, 21 | } 22 | 23 | self.image_paths = { 24 | train = '/train2014/images', 25 | val = '/val2014/images', 26 | } 27 | 28 | local train_size = self.h5_file:read(self.image_paths.train):dataspaceSize() 29 | self.split_sizes = { 30 | train = train_size[1], 31 | val = self.h5_file:read(self.image_paths.val):dataspaceSize()[1], 32 | } 33 | self.num_channels = train_size[2] 34 | self.image_height = train_size[3] 35 | self.image_width = train_size[4] 36 | 37 | self.num_minibatches = {} 38 | for k, v in pairs(self.split_sizes) do 39 | self.num_minibatches[k] = math.floor(v / self.batch_size) 40 | end 41 | 42 | if opt.max_train and opt.max_train > 0 then 43 | self.split_sizes.train = opt.max_train 44 | end 45 | 46 | self.rgb_to_gray = torch.FloatTensor{0.2989, 0.5870, 0.1140} 47 | end 48 | 49 | 50 | function DataLoader:reset(split) 51 | self.split_idxs[split] = 1 52 | end 53 | 54 | 55 | function DataLoader:getBatch(split) 56 | local path = self.image_paths[split] 57 | 58 | local start_idx = self.split_idxs[split] 59 | local end_idx = math.min(start_idx + self.batch_size - 1, 60 | self.split_sizes[split]) 61 | 62 | -- Load images out of the HDF5 file 63 | local images = self.h5_file:read(path):partial( 64 | {start_idx, end_idx}, 65 | {1, self.num_channels}, 66 | {1, self.image_height}, 67 | {1, self.image_width}):float():div(255) 68 | 69 | -- Advance counters, maybe rolling back to the start 70 | self.split_idxs[split] = end_idx + 1 71 | if self.split_idxs[split] > self.split_sizes[split] then 72 | self.split_idxs[split] = 1 73 | end 74 | 75 | -- Preprocess images 76 | images_pre = self.preprocess_fn(images) 77 | 78 | return images_pre, images_pre 79 | end 80 | 81 | -------------------------------------------------------------------------------- /modules/GramMatrix.lua: -------------------------------------------------------------------------------- 1 | local Gram, parent = torch.class('nn.GramMatrix', 'nn.Module') 2 | 3 | 4 | --[[ 5 | A layer to compute the Gram Matrix of inputs. 6 | 7 | Input: 8 | - features: A tensor of shape (N, C, H, W) or (C, H, W) giving features for 9 | either a single image or a minibatch of images. 10 | 11 | Output: 12 | - gram: A tensor of shape (N, C, C) or (C, C) giving Gram matrix for input. 13 | --]] 14 | 15 | 16 | function Gram:__init(normalize) 17 | parent.__init(self) 18 | if normalize ~= nil then 19 | assert(type(normalize) == 'boolean', 'normalize has to be true/false') 20 | self.normalize = normalize 21 | else 22 | self.normalize = true 23 | end 24 | self.buffer = torch.Tensor() 25 | end 26 | 27 | 28 | function Gram:updateOutput(input) 29 | local C, H, W 30 | if input:dim() == 3 then 31 | C, H, W = input:size(1), input:size(2), input:size(3) 32 | local x_flat = input:view(C, H * W) 33 | self.output:resize(C, C) 34 | self.output:mm(x_flat, x_flat:t()) 35 | elseif input:dim() == 4 then 36 | local N = input:size(1) 37 | C, H, W = input:size(2), input:size(3), input:size(4) 38 | self.output:resize(N, C, C) 39 | local x_flat = input:view(N, C, H * W) 40 | self.output:resize(N, C, C) 41 | self.output:bmm(x_flat, x_flat:transpose(2, 3)) 42 | end 43 | if self.normalize then 44 | -- print('in gram forward; dividing by ', C * H * W) 45 | self.output:div(C * H * W) 46 | end 47 | return self.output 48 | end 49 | 50 | 51 | function Gram:updateGradInput(input, gradOutput) 52 | self.gradInput:resizeAs(input):zero() 53 | local C, H, W 54 | if input:dim() == 3 then 55 | C, H, W = input:size(1), input:size(2), input:size(3) 56 | local x_flat = input:view(C, H * W) 57 | self.buffer:resize(C, H * W) 58 | self.buffer:mm(gradOutput, x_flat) 59 | self.buffer:addmm(gradOutput:t(), x_flat) 60 | self.gradInput = self.buffer:view(C, H, W) 61 | elseif input:dim() == 4 then 62 | local N = input:size(1) 63 | C, H, W = input:size(2), input:size(3), input:size(4) 64 | local x_flat = input:view(N, C, H * W) 65 | self.buffer:resize(N, C, H * W) 66 | self.buffer:bmm(gradOutput, x_flat) 67 | self.buffer:baddbmm(gradOutput:transpose(2, 3), x_flat) 68 | self.gradInput = self.buffer:view(N, C, H, W) 69 | end 70 | if self.normalize then 71 | self.buffer:div(C * H * W) 72 | end 73 | assert(self.gradInput:isContiguous()) 74 | return self.gradInput 75 | end 76 | 77 | -------------------------------------------------------------------------------- /modules/InstanceNormalization.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Implements instance normalization as described in the paper 5 | 6 | Instance Normalization: The Missing Ingredient for Fast Stylization 7 | Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky 8 | https://arxiv.org/abs/1607.08022 9 | 10 | This implementation is based on 11 | https://github.com/DmitryUlyanov/texture_nets 12 | ]] 13 | 14 | local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 15 | 'nn.Module') 16 | 17 | 18 | function InstanceNormalization:__init(nOutput, eps) 19 | parent.__init(self) 20 | 21 | self.eps = eps or 1e-5 22 | 23 | self.nOutput = nOutput 24 | self.prev_N = -1 25 | 26 | self.weight = torch.Tensor(nOutput):uniform() 27 | self.bias = torch.Tensor(nOutput):zero() 28 | self.gradWeight = torch.Tensor(nOutput) 29 | self.gradBias = torch.Tensor(nOutput) 30 | end 31 | 32 | 33 | function InstanceNormalization:updateOutput(input) 34 | local N, C = input:size(1), input:size(2) 35 | local H, W = input:size(3), input:size(4) 36 | assert(C == self.nOutput) 37 | 38 | if N ~= self.prev_N or (self.bn and self:type() ~= self.bn:type()) then 39 | self.bn = nn.SpatialBatchNormalization(N * C, self.eps) 40 | self.bn:type(self:type()) 41 | self.prev_N = N 42 | end 43 | 44 | -- Set params for BN 45 | self.bn.weight:repeatTensor(self.weight, N) 46 | self.bn.bias:repeatTensor(self.bias, N) 47 | 48 | local input_view = input:view(1, N * C, H, W) 49 | self.bn:training() 50 | self.output = self.bn:forward(input_view):viewAs(input) 51 | 52 | return self.output 53 | end 54 | 55 | 56 | function InstanceNormalization:updateGradInput(input, gradOutput) 57 | local N, C = input:size(1), input:size(2) 58 | local H, W = input:size(3), input:size(4) 59 | assert(self.bn) 60 | 61 | local input_view = input:view(1, N * C, H, W) 62 | local gradOutput_view = gradOutput:view(1, N * C, H, W) 63 | 64 | self.bn.gradWeight:zero() 65 | self.bn.gradBias:zero() 66 | 67 | self.bn:training() 68 | self.gradInput = self.bn:backward(input_view, gradOutput_view):viewAs(input) 69 | 70 | self.gradWeight:add(self.bn.gradWeight:view(N, C):sum(1)) 71 | self.gradBias:add(self.bn.gradBias:view(N, C):sum(1)) 72 | return self.gradInput 73 | end 74 | 75 | 76 | function InstanceNormalization:clearState() 77 | self.output = self.output.new() 78 | self.gradInput = self.gradInput.new() 79 | self.bn:clearState() 80 | end 81 | 82 | -------------------------------------------------------------------------------- /experiments/opts.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local M={} 17 | 18 | function M.parse(arg) 19 | local cmd = torch.CmdLine() 20 | 21 | cmd:text() 22 | cmd:text('Options:') 23 | 24 | -- Generic options 25 | cmd:option('-model', 'hang') 26 | cmd:option('-model_nres', '9') 27 | cmd:option('-use_instance_norm', 1) 28 | cmd:option('-h5_file', 'data.h5') 29 | cmd:option('-padding_type', 'reflect-start') 30 | cmd:option('-tanh_constant', 150) 31 | cmd:option('-preprocessing', 'vgg') 32 | cmd:option('-resume', '') 33 | 34 | -- Style loss function options 35 | --cmd:option('-percep_loss_weight', 1.0) 36 | cmd:option('-tv_strength', 1e-6) 37 | 38 | -- Options for feature reconstruction loss 39 | cmd:option('-content_weights', '1.0') 40 | cmd:option('-content_layers', '16') 41 | cmd:option('-loss_network', 'models/vgg16.t7') 42 | 43 | -- Options for style reconstruction loss 44 | cmd:option('-style_image_folder', 'images/9styles') 45 | cmd:option('-style_image_size', 512) 46 | cmd:option('-style_iter', 20) 47 | cmd:option('-style_weights', '5.0') 48 | cmd:option('-style_layers', '4,9,16,23') 49 | 50 | -- Optimization 51 | cmd:option('-num_iterations', 80000) 52 | cmd:option('-max_train', -1) 53 | cmd:option('-batch_size', 4) 54 | cmd:option('-learning_rate', 1e-3) 55 | cmd:option('-lr_decay_every', -1) 56 | cmd:option('-lr_decay_factor', 0.5) 57 | cmd:option('-weight_decay', 0) 58 | 59 | -- Checkpointing 60 | cmd:option('-checkpoint_name', 'checkpoint') 61 | cmd:option('-checkpoint_every', 1000) 62 | cmd:option('-num_val_batches', 10) 63 | 64 | -- Backend options 65 | cmd:option('-gpu', 0) 66 | cmd:option('-use_cudnn', 1) 67 | cmd:option('-backend', 'cuda', 'cuda|opencl') 68 | 69 | -- Test options 70 | cmd:option('-premodel', 'models/model_9styles.t7') 71 | cmd:option('-input_image', 'images/content/venice-boat.jpg') 72 | cmd:option('-output_dir', 'stylized') 73 | cmd:option('-image_size', 512) 74 | 75 | -- Webcam options 76 | cmd:option('-webcam_idx', 0) 77 | cmd:option('-webcam_fps', 30) 78 | cmd:option('-height', 480) 79 | cmd:option('-width', 640) 80 | 81 | local opt = cmd:parse(arg or {}) 82 | return opt 83 | end 84 | 85 | return M 86 | -------------------------------------------------------------------------------- /experiments/test.lua: -------------------------------------------------------------------------------- 1 | require 'texture' 2 | require 'image' 3 | 4 | local utils = require 'utils.utils' 5 | local preprocess = require 'utils.preprocess' 6 | local imgLoader = require 'utils.getImages' 7 | local opts = require 'opts' 8 | 9 | local function main() 10 | local opt = opts.parse(arg) 11 | opt.style_layers, opt.style_weights = utils.parse_layers(opt.style_layers, 12 | opt.style_weights) 13 | if (opt.input_image == '') then 14 | error('Must give exactly one of -input_image') 15 | end 16 | 17 | local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1) 18 | local ok, checkpoint = pcall(function() return torch.load(opt.premodel) end) 19 | if not ok then 20 | print('ERROR: Could not load model from ' .. opt.premodel) 21 | print('You may need to download the pretrained models by running') 22 | return 23 | end 24 | local model = checkpoint.model 25 | model:evaluate() 26 | model:type(dtype) 27 | if use_cudnn then 28 | cudnn.convert(model, cudnn) 29 | if opt.cudnn_benchmark == 0 then 30 | cudnn.benchmark = false 31 | cudnn.fastest = true 32 | end 33 | end 34 | 35 | local preprocess_method = checkpoint.opt.preprocessing or 'vgg' 36 | local preprocess = preprocess[preprocess_method] 37 | local styleLoader = imgLoader(opt.style_image_folder) 38 | 39 | local featpath = opt.style_image_folder .. '/feat.t7' 40 | if not paths.filep(featpath) then 41 | local extractor = require "extractGram" 42 | extractor.exec(opt) 43 | end 44 | 45 | local feat = torch.load(featpath) 46 | feat = nn.utils.recursiveType(feat, 'torch.CudaTensor') 47 | 48 | local function run_image(in_path, out_dir, styleLoader, feat) 49 | if not path.isdir(out_dir) then 50 | paths.mkdir(out_dir) 51 | end 52 | local img = image.load(in_path, 3) 53 | if opt.image_size > 0 then 54 | img = image.scale(img, opt.image_size) 55 | end 56 | local H, W = img:size(2), img:size(3) 57 | local img_pre = preprocess.preprocess(img:view(1, 3, H, W)):type(dtype) 58 | 59 | for i=1,styleLoader:size() do 60 | local style_image = styleLoader:get(i) 61 | model:setTarget(feat[i], dtype) 62 | 63 | local img_out = model:forward(img_pre) 64 | local img_out = preprocess.deprocess(img_out)[1] 65 | 66 | local out_path = paths.concat(opt.output_dir, i) .. '.jpg' 67 | local out_path_style = paths.concat(opt.output_dir, i) .. 'style.jpg' 68 | print('Writing output image to ' .. out_path) 69 | image.save(out_path, img_out) 70 | image.save(out_path_style, style_image) 71 | collectgarbage() 72 | collectgarbage() 73 | end 74 | end 75 | 76 | if opt.input_image ~= '' then 77 | if opt.output_image == '' then 78 | error('Must give -output_image with -input_image') 79 | end 80 | run_image(opt.input_image, opt.output_dir, styleLoader, feat) 81 | else 82 | error('Must provide input image') 83 | end 84 | end 85 | 86 | main() 87 | -------------------------------------------------------------------------------- /experiments/utils/getImages.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | require 'paths' 19 | require 'image' 20 | 21 | local M={} 22 | local Dataset = torch.class('texture.Dataset', M) 23 | 24 | function Dataset:_findImages(dir) 25 | local imagePath = torch.CharTensor() 26 | 27 | ---------------------------------------------------------------------- 28 | -- Options for the GNU and BSD find command 29 | local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 30 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 31 | for i=2,#extensionList do 32 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 33 | end 34 | 35 | -- Find all the images using the find command 36 | local f = io.popen('find -L ' .. dir .. findOptions) 37 | 38 | local maxLength = -1 39 | local imagePaths = {} 40 | 41 | -- Generate a list of all the images 42 | while true do 43 | local line = f:read('*line') 44 | if not line then break end 45 | 46 | local filename = paths.basename(line) 47 | local path = dir .. '/' .. filename 48 | 49 | table.insert(imagePaths, path) 50 | 51 | maxLength = math.max(maxLength, #path + 1) 52 | end 53 | 54 | f:close() 55 | 56 | -- Convert the generated list to a tensor for faster loading 57 | local nImages = #imagePaths 58 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 59 | for i, path in ipairs(imagePaths) do 60 | ffi.copy(imagePath[i]:data(), path) 61 | end 62 | 63 | return imagePath 64 | end 65 | 66 | function Dataset:__init(imgDir) 67 | assert(self) 68 | assert(paths.dirp(imgDir), 'image directory not found: ' .. imgDir) 69 | self.imagePath = {self:_findImages(imgDir)} 70 | end 71 | 72 | function Dataset:_loadImage(path) 73 | local ok, input = pcall(function() 74 | return image.load(path, 3, 'float') 75 | end) 76 | 77 | -- Sometimes image.load fails because the file extension does not match the 78 | -- image format. In that case, use image.decompress on a ByteTensor. 79 | if not ok then 80 | local f = io.open(path, 'r') 81 | assert(f, 'Error reading: ' .. tostring(path)) 82 | local data = f:read('*a') 83 | f:close() 84 | 85 | local b = torch.ByteTensor(string.len(data)) 86 | ffi.copy(b:data(), data, b:size(1)) 87 | 88 | input = image.decompress(b, 3, 'float') 89 | end 90 | 91 | return input 92 | end 93 | 94 | function Dataset:get(i) 95 | assert(self and self.imagePath) 96 | i = (i-1) % self:size() + 1 97 | local path = ffi.string(self.imagePath[1][i]:data()) 98 | local image = self:_loadImage(path) 99 | return image 100 | end 101 | 102 | function Dataset:size() 103 | assert(self and self.imagePath) 104 | return self.imagePath[1]:size(1) 105 | end 106 | 107 | return M.Dataset 108 | -------------------------------------------------------------------------------- /experiments/webcam.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'texture' 3 | require 'image' 4 | require 'camera' 5 | 6 | require 'qt' 7 | require 'qttorch' 8 | require 'qtwidget' 9 | 10 | local utils = require 'utils.utils' 11 | local preprocess = require 'utils.preprocess' 12 | local imgLoader = require 'utils.getImages' 13 | local opts = require 'opts' 14 | 15 | local function main() 16 | local opt = opts.parse(arg) 17 | 18 | if (opt.input_image == '') then 19 | error('Must give exactly one of -input_image') 20 | end 21 | 22 | local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1) 23 | local ok, checkpoint = pcall(function() return torch.load(opt.premodel) end) 24 | if not ok then 25 | print('ERROR: Could not load model from ' .. opt.premodel) 26 | print('You may need to download the pretrained models by running') 27 | return 28 | end 29 | local model = checkpoint.model 30 | model:evaluate() 31 | model:type(dtype) 32 | if use_cudnn then 33 | cudnn.convert(model, cudnn) 34 | if opt.cudnn_benchmark == 0 then 35 | cudnn.benchmark = false 36 | cudnn.fastest = true 37 | end 38 | end 39 | 40 | local preprocess_method = checkpoint.opt.preprocessing or 'vgg' 41 | local preprocess = preprocess[preprocess_method] 42 | local styleLoader = imgLoader(opt.style_image_folder) 43 | 44 | local featpath = opt.style_image_folder .. '/feat.t7' 45 | if not paths.filep(featpath) then 46 | local extractor = require "extractGram" 47 | extractor.exec(opt) 48 | end 49 | 50 | local feat = torch.load(featpath) 51 | feat = nn.utils.recursiveType(feat, 'torch.CudaTensor') 52 | 53 | local style_image = nil 54 | local function run_image(img, feat, idx) 55 | if opt.image_size > 0 then 56 | img = image.scale(img, opt.image_size) 57 | end 58 | local H, W = img:size(2), img:size(3) 59 | local img_pre = preprocess.preprocess(img:view(1, 3, H, W)):type(dtype) 60 | 61 | -- update style image 62 | if (idx-1) % 15 == 0 then 63 | local i=torch.floor((idx-1)/15)%styleLoader:size()+1 64 | style_image = styleLoader:get(i):float() 65 | model:setTarget(feat[i], dtype) 66 | end 67 | 68 | local img_out = model:forward(img_pre) 69 | local styleSize = torch.floor(opt.image_size / 4) 70 | 71 | style_image = image.scale(style_image,styleSize,styleSize) 72 | img_out = preprocess.deprocess(img_out:float())[1] 73 | img = img:float() 74 | img:sub(1,3,21,20+styleSize,21,20+styleSize):copy(style_image) 75 | img_out = torch.cat(img,img_out,3) 76 | 77 | collectgarbage() 78 | collectgarbage() 79 | return img_out 80 | end 81 | 82 | local camera_opt = { 83 | idx = opt.webcam_idx, 84 | fps = opt.webcam_fps, 85 | height = opt.height, 86 | width = opt.width, 87 | } 88 | local cam = image.Camera(camera_opt) 89 | 90 | local win = nil 91 | local idx = 1 92 | 93 | while true do 94 | -- Grab a frame from the webcam 95 | local img = cam:forward() 96 | 97 | -- Run the model 98 | local img_disp = run_image(img, feat, idx) 99 | 100 | idx = idx + 1 101 | if not win then 102 | -- On the first call use image.display to construct a window 103 | win = image.display(img_disp) 104 | else 105 | -- Reuse the same window 106 | win.image = img_disp 107 | local size = win.window.size:totable() 108 | local qt_img = qt.QImage.fromTensor(img_disp) 109 | win.painter:image(0, 0, size.width, size.height, qt_img) 110 | end 111 | end 112 | end 113 | 114 | 115 | main() 116 | 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 16 | 17 | 18 | 19 |
5 | Multi-style Generative Network for Real-time Transfer [arXiv] [project]
6 | Hang Zhang, Kristin Dana 7 |
 8 | @article{zhang2017multistyle,
 9 | 	title={Multi-style Generative Network for Real-time Transfer},
10 | 	author={Zhang, Hang and Dana, Kristin},
11 | 	journal={arXiv preprint arXiv:1703.06953},
12 | 	year={2017}
13 | }
14 | 
15 |
20 | 21 | ### Installation 22 | We also provide [PyTorch implementation](https://github.com/zhanghang1989/PyTorch-Style-Transfer) and [MXNet implementation](https://github.com/zhanghang1989/MXNet-Gluon-Style-Transfer). 23 | Please install [Torch7](http://torch.ch/) with cuda and cudnn support. The code has been tested on Ubuntu 16.04 with Titan X Pascal and Maxwell. 24 | ```bash 25 | luarocks install https://raw.githubusercontent.com/zhanghang1989/MSG-Net/master/texture-scm-1.rockspec 26 | ``` 27 | 28 | ### Test and Demo 29 | 30 | 0. Clone the repo and download pre-trained models 31 | ```bash 32 | git clone git@github.com:zhanghang1989/MSG-Net.git 33 | cd MSG-Net/experiments 34 | bash models/download_models.sh 35 | ``` 36 | 0. Web Camera Demo 37 | ``` 38 | qlua webcam.lua 39 | ``` 40 | ![](https://raw.githubusercontent.com/zhanghang1989/PyTorch-Style-Transfer/master/images/myimage.gif) 41 | 0. Test on Image 42 | ``` 43 | th test.lua -input_image images/content/venice-boat.jpg -image_size 1024 44 | eog stylized 45 | ``` 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | [[More Example Results](Examples.md)] 56 | 57 | ### Train Your Own Model 58 | Please follow [this tutorial to train a new model](Training.md). 59 | 60 | ### Release Timeline 61 | - [x] 03/20/2017 we have released the [demo video](https://www.youtube.com/watch?v=oy6pWNWBt4Y). 62 | - [x] 03/24/2017 We have released [ArXiv paper](https://arxiv.org/pdf/1703.06953.pdf) and test code with pre-trained models. 63 | - [x] 04/09/2017 We have released the training code. 64 | - [x] 04/24/2017 Please checkout our PyTorch [implementation](https://github.com/zhanghang1989/PyTorch-Style-Transfer). 65 | 66 | ### Acknowledgement 67 | The code benefits from outstanding prior work and their implementations including: 68 | - [Texture Networks: Feed-forward Synthesis of Textures and Stylized Images](https://arxiv.org/pdf/1603.03417.pdf) by Ulyanov *et al. ICML 2016*. ([code](https://github.com/DmitryUlyanov/texture_nets)) 69 | - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/pdf/1603.08155.pdf) by Johnson *et al. ECCV 2016* ([code](https://github.com/jcjohnson/fast-neural-style)) 70 | - [Image Style Transfer Using Convolutional Neural Networks](http://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) by Gatys *et al. CVPR 2016* and its torch implementation [code](https://github.com/jcjohnson/neural-style) by Johnson. 71 | -------------------------------------------------------------------------------- /modules/Inspiration.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | local Inspiration, parent = torch.class('nn.Inspiration', 'nn.Module') 17 | 18 | local function isint(x) 19 | return type(x) == 'number' and x == math.floor(x) 20 | end 21 | 22 | function Inspiration:__init(C) 23 | parent.__init(self) 24 | assert(self and C, 'should specify C') 25 | assert(isint(C), 'C should be integers') 26 | 27 | self.C = C 28 | self.MM_WG = nn.MM() 29 | self.MM_PX = nn.MM(true, false) 30 | self.target = torch.Tensor(C,C) 31 | self.Weight = torch.Tensor(C,C) 32 | self.P = torch.Tensor(C,C) 33 | 34 | self.gradWeight = torch.Tensor(C, C) 35 | self.gradWG = {torch.Tensor(C, C), torch.Tensor(C, C)} 36 | self.gradPX = {torch.Tensor(), torch.Tensor()} 37 | self.gradInput = torch.Tensor() 38 | self:reset() 39 | end 40 | 41 | function Inspiration:reset(stdv) 42 | if stdv then 43 | stdv = stdv * math.sqrt(2) 44 | else 45 | stdv = 1./math.sqrt(self.C) 46 | end 47 | self.Weight:uniform(-stdv, stdv) 48 | self.target:uniform(-stdv, stdv) 49 | return self 50 | end 51 | 52 | function Inspiration:setTarget(nT) 53 | assert(self and image) 54 | self.target = nT 55 | end 56 | 57 | function Inspiration:updateOutput(input) 58 | assert(self) 59 | -- P=WG Y=XP 60 | --self.output:resizeAs(input) 61 | if input:dim() == 3 then 62 | self.P = self.MM_WG:forward({self.Weight, self.target}) 63 | self.output = self.MM_PX:forward({self.P, input:view(self.C,-1)}):viewAs(input) 64 | elseif input:dim() == 4 then 65 | local B = input:size(1) 66 | self.P = self.MM_WG:forward({self.Weight, self.target}) 67 | self.output = self.MM_PX:forward({self.P:add_dummy():expand(B,self.C,self.C), input:view(B,self.C,-1)}):viewAs(input) 68 | else 69 | error('Unsupported dimention for Inspiration layer') 70 | end 71 | return self.output 72 | end 73 | 74 | function Inspiration:updateGradInput(input, gradOutput) 75 | assert(self and self.gradInput) 76 | 77 | --self.gradInput:resizeAs(input):fill(0) 78 | if input:dim() == 3 then 79 | self.gradPX = self.MM_PX:backward({self.P, input:view(self.C,-1)}, gradOutput:view(self.C,-1)) 80 | elseif input:dim() == 4 then 81 | local B = input:size(1) 82 | self.gradPX = self.MM_PX:backward({self.P:add_dummy():expand(B,self.C,self.C), input:view(B,self.C,-1)}, gradOutput:view(B,self.C,-1)) 83 | else 84 | error('Unsupported dimention for Inspiration layer') 85 | end 86 | 87 | self.gradInput = self.gradPX[2]:viewAs(input) 88 | return self.gradInput 89 | end 90 | 91 | function Inspiration:accGradParameters(input, gradOutput, scale) 92 | assert(self) 93 | scale = scale or 1 94 | 95 | if input:dim() == 3 then 96 | self.gradWG = self.MM_WG:backward({self.Weight, self.target}, self.gradPX[1]) 97 | self.gradWeight = scale * self.gradWG[1] 98 | elseif input:dim() == 4 then 99 | self.gradWG = self.MM_WG:backward({self.Weight, self.target}, self.gradPX[1]:sum(1):squeeze()) 100 | self.gradWeight = scale * self.gradWG[1] 101 | else 102 | error('Unsupported dimention for Inspiration layer') 103 | end 104 | end 105 | 106 | function Inspiration:__tostring__() 107 | return torch.type(self) .. 108 | string.format( 109 | '(%dxHxW, -> %dxHxW)', 110 | self.C, self.C 111 | ) 112 | end 113 | -------------------------------------------------------------------------------- /experiments/utils/utils.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | local cjson = require 'cjson' 4 | 5 | 6 | local M = {} 7 | 8 | 9 | -- Parse a string of comma-separated numbers 10 | -- For example convert "1.0,3.14" to {1.0, 3.14} 11 | function M.parse_num_list(s) 12 | local nums = {} 13 | for _, ss in ipairs(s:split(',')) do 14 | table.insert(nums, tonumber(ss)) 15 | end 16 | return nums 17 | end 18 | 19 | 20 | -- Parse a layer string and associated weights string. 21 | -- The layers string is a string of comma-separated layer strings, and the 22 | -- weight string contains comma-separated numbers. If the weights string 23 | -- contains only a single number it is duplicated to be the same length as the 24 | -- layers. 25 | function M.parse_layers(layers_string, weights_string) 26 | local layers = layers_string:split(',') 27 | local weights = M.parse_num_list(weights_string) 28 | if #weights == 1 and #layers > 1 then 29 | -- Duplicate the same weight for all layers 30 | local w = weights[1] 31 | weights = {} 32 | for i = 1, #layers do 33 | table.insert(weights, w) 34 | end 35 | elseif #weights ~= #layers then 36 | local msg = 'size mismatch between layers "%s" and weights "%s"' 37 | error(string.format(msg, layers_string, weights_string)) 38 | end 39 | return layers, weights 40 | end 41 | 42 | 43 | function M.setup_gpu(gpu, backend, use_cudnn) 44 | local dtype = 'torch.FloatTensor' 45 | if gpu >= 0 then 46 | if backend == 'cuda' then 47 | require 'cutorch' 48 | require 'cunn' 49 | cutorch.setDevice(gpu + 1) 50 | dtype = 'torch.CudaTensor' 51 | if use_cudnn then 52 | require 'cudnn' 53 | cudnn.benchmark = true 54 | end 55 | elseif backend == 'opencl' then 56 | require 'cltorch' 57 | require 'clnn' 58 | cltorch.setDevice(gpu + 1) 59 | dtype = torch.Tensor():cl():type() 60 | use_cudnn = false 61 | end 62 | else 63 | use_cudnn = false 64 | end 65 | return dtype, use_cudnn 66 | end 67 | 68 | 69 | function M.clear_gradients(m) 70 | if torch.isTypeOf(m, nn.Container) then 71 | m:applyToModules(M.clear_gradients) 72 | end 73 | if m.weight and m.gradWeight then 74 | m.gradWeight = m.gradWeight.new() 75 | end 76 | if m.bias and m.gradBias then 77 | m.gradBias = m.gradBias.new() 78 | end 79 | end 80 | 81 | 82 | function M.restore_gradients(m) 83 | if torch.isTypeOf(m, nn.Container) then 84 | m:applyToModules(M.restore_gradients) 85 | end 86 | if m.weight and m.gradWeight then 87 | m.gradWeight = m.gradWeight.new(#m.weight):zero() 88 | end 89 | if m.bias and m.gradBias then 90 | m.gradBias = m.gradBias.new(#m.bias):zero() 91 | end 92 | end 93 | 94 | 95 | function M.read_json(path) 96 | local file = io.open(path, 'r') 97 | local text = file:read() 98 | file:close() 99 | local info = cjson.decode(text) 100 | return info 101 | end 102 | 103 | 104 | function M.write_json(path, j) 105 | cjson.encode_sparse_array(true, 2, 10) 106 | local text = cjson.encode(j) 107 | local file = io.open(path, 'w') 108 | file:write(text) 109 | file:close() 110 | end 111 | 112 | local IMAGE_EXTS = {'jpg', 'jpeg', 'png', 'ppm', 'pgm'} 113 | function M.is_image_file(filename) 114 | -- Hidden file are not images 115 | if string.sub(filename, 1, 1) == '.' then 116 | return false 117 | end 118 | -- Check against a list of known image extensions 119 | local ext = string.lower(paths.extname(filename) or "") 120 | for _, image_ext in ipairs(IMAGE_EXTS) do 121 | if ext == image_ext then 122 | return true 123 | end 124 | end 125 | return false 126 | end 127 | 128 | 129 | function M.median_filter(img, r) 130 | local u = img:unfold(2, r, 1):contiguous() 131 | u = u:unfold(3, r, 1):contiguous() 132 | local HH, WW = u:size(2), u:size(3) 133 | local dtype = u:type() 134 | -- Median is not defined for CudaTensors, cast to float and back 135 | local med = u:view(3, HH, WW, r * r):float():median():type(dtype) 136 | return med[{{}, {}, {}, 1}] 137 | end 138 | 139 | 140 | return M 141 | 142 | -------------------------------------------------------------------------------- /modules/layer_utils.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Utility functions for getting and inserting layers into models composed of 5 | hierarchies of nn Modules and nn Containers. In such a model, we can uniquely 6 | address each module with a unique "layer string", which is a series of integers 7 | separated by dashes. This is easiest to understand with an example: consider 8 | the following network; we have labeled each module with its layer string: 9 | 10 | nn.Sequential { 11 | (1) nn.SpatialConvolution 12 | (2) nn.Sequential { 13 | (2-1) nn.SpatialConvolution 14 | (2-2) nn.SpatialConvolution 15 | } 16 | (3) nn.Sequential { 17 | (3-1) nn.SpatialConvolution 18 | (3-2) nn.Sequential { 19 | (3-2-1) nn.SpatialConvolution 20 | (3-2-2) nn.SpatialConvolution 21 | (3-2-3) nn.SpatialConvolution 22 | } 23 | (3-3) nn.SpatialConvolution 24 | } 25 | (4) nn.View 26 | (5) nn.Linear 27 | } 28 | 29 | Any layers that that have the instance variable _ignore set to true are ignored 30 | when computing layer strings for layers. This way, we can insert new layers into 31 | a network without changing the layer strings of existing layers. 32 | --]] 33 | local M = {} 34 | 35 | 36 | --[[ 37 | Convert a layer string to an array of integers. 38 | 39 | For example layer_string_to_nums("1-23-4") = {1, 23, 4}. 40 | --]] 41 | function M.layer_string_to_nums(layer_string) 42 | local nums = {} 43 | for _, s in ipairs(layer_string:split('-')) do 44 | table.insert(nums, tonumber(s)) 45 | end 46 | return nums 47 | end 48 | 49 | 50 | --[[ 51 | Comparison function for layer strings that is compatible with table.sort. 52 | In this comparison scheme, 2-3 comes AFTER 2-3-X for all X. 53 | 54 | Input: 55 | - s1, s2: Two layer strings. 56 | 57 | Output: 58 | - true if s1 should come before s2 in sorted order; false otherwise. 59 | --]] 60 | function M.compare_layer_strings(s1, s2) 61 | local left = M.layer_string_to_nums(s1) 62 | local right = M.layer_string_to_nums(s2) 63 | local out = nil 64 | for i = 1, math.min(#left, #right) do 65 | if left[i] < right[i] then 66 | out = true 67 | elseif left[i] > right[i] then 68 | out = false 69 | end 70 | if out ~= nil then break end 71 | end 72 | 73 | if out == nil then 74 | out = (#left > #right) 75 | end 76 | return out 77 | end 78 | 79 | 80 | --[[ 81 | Get a layer from the network net using a layer string. 82 | --]] 83 | function M.get_layer(net, layer_string) 84 | local nums = M.layer_string_to_nums(layer_string) 85 | local layer = net 86 | for i, num in ipairs(nums) do 87 | local count = 0 88 | for j = 1, #layer do 89 | if not layer:get(j)._ignore then 90 | count = count + 1 91 | end 92 | if count == num then 93 | layer = layer:get(j) 94 | break 95 | end 96 | end 97 | end 98 | return layer 99 | end 100 | 101 | 102 | -- Insert a new layer immediately after the layer specified by a layer string. 103 | -- Any layers inserted this way are flagged with a special variable 104 | function M.insert_after(net, layer_string, new_layer) 105 | new_layer._ignore = true 106 | local nums = M.layer_string_to_nums(layer_string) 107 | local container = net 108 | for i = 1, #nums do 109 | local count = 0 110 | for j = 1, #container do 111 | if not container:get(j)._ignore then 112 | count = count + 1 113 | end 114 | if count == nums[i] then 115 | if i < #nums then 116 | container = container:get(j) 117 | break 118 | elseif i == #nums then 119 | container:insert(new_layer, j + 1) 120 | return 121 | end 122 | end 123 | end 124 | end 125 | end 126 | 127 | 128 | -- Remove the layers of the network that occur after the last _ignore 129 | function M.trim_network(net) 130 | local function contains_ignore(layer) 131 | if torch.isTypeOf(layer, nn.Container) then 132 | local found = false 133 | for i = 1, layer:size() do 134 | found = found or contains_ignore(layer:get(i)) 135 | end 136 | return found 137 | else 138 | return layer._ignore == true 139 | end 140 | end 141 | local last_layer = 0 142 | for i = 1, #net do 143 | if contains_ignore(net:get(i)) then 144 | last_layer = i 145 | end 146 | end 147 | local num_to_remove = #net - last_layer 148 | for i = 1, num_to_remove do 149 | net:remove() 150 | end 151 | return net 152 | end 153 | 154 | 155 | return M 156 | 157 | -------------------------------------------------------------------------------- /modules/PerceptualCriterion.lua: -------------------------------------------------------------------------------- 1 | local layer_utils = require 'texture.layer_utils' 2 | 3 | local crit, parent = torch.class('nn.PerceptualCriterion', 'nn.Criterion') 4 | 5 | --[[ 6 | Input: args is a table with the following keys: 7 | - cnn: A network giving the base CNN. 8 | - content_layers: An array of layer strings 9 | - content_weights: A list of the same length as content_layers 10 | - style_layers: An array of layers strings 11 | - style_weights: A list of the same length as style_layers 12 | "mean" or "gram" 13 | - deepdream_layers: Array of layer strings 14 | - deepdream_weights: List of the same length as deepdream_layers 15 | --]] 16 | function crit:__init(args) 17 | args.content_layers = args.content_layers or {} 18 | args.style_layers = args.style_layers or {} 19 | args.deepdream_layers = args.deepdream_layers or {} 20 | 21 | self.net = args.cnn 22 | self.net:evaluate() 23 | self.content_loss_layers = {} 24 | self.style_loss_layers = {} 25 | self.deepdream_loss_layers = {} 26 | 27 | -- Set up content loss layers 28 | for i, layer_string in ipairs(args.content_layers) do 29 | local weight = args.content_weights[i] 30 | local content_loss_layer = nn.ContentLoss(weight) 31 | layer_utils.insert_after(self.net, layer_string, content_loss_layer) 32 | table.insert(self.content_loss_layers, content_loss_layer) 33 | end 34 | 35 | -- Set up style loss layers 36 | for i, layer_string in ipairs(args.style_layers) do 37 | local weight = args.style_weights[i] 38 | local style_loss_layer = nn.StyleLoss(weight) 39 | layer_utils.insert_after(self.net, layer_string, style_loss_layer) 40 | table.insert(self.style_loss_layers, style_loss_layer) 41 | end 42 | 43 | layer_utils.trim_network(self.net) 44 | self.grad_net_output = torch.Tensor() 45 | 46 | end 47 | 48 | 49 | --[[ 50 | target: Tensor of shape (1, 3, H, W) giving pixels for style target image 51 | --]] 52 | function crit:setStyleTarget(target) 53 | for i, content_loss_layer in ipairs(self.content_loss_layers) do 54 | content_loss_layer:setMode('none') 55 | end 56 | for i, style_loss_layer in ipairs(self.style_loss_layers) do 57 | style_loss_layer:setMode('capture') 58 | end 59 | self.net:forward(target) 60 | end 61 | 62 | 63 | --[[ 64 | target: Tensor of shape (N, 3, H, W) giving pixels for content target images 65 | --]] 66 | function crit:setContentTarget(target) 67 | for i, style_loss_layer in ipairs(self.style_loss_layers) do 68 | style_loss_layer:setMode('none') 69 | end 70 | for i, content_loss_layer in ipairs(self.content_loss_layers) do 71 | content_loss_layer:setMode('capture') 72 | end 73 | self.net:forward(target) 74 | end 75 | 76 | 77 | function crit:setStyleWeight(weight) 78 | for i, style_loss_layer in ipairs(self.style_loss_layers) do 79 | style_loss_layer.strength = weight 80 | end 81 | end 82 | 83 | 84 | function crit:setContentWeight(weight) 85 | for i, content_loss_layer in ipairs(self.content_loss_layers) do 86 | content_loss_layer.strength = weight 87 | end 88 | end 89 | 90 | 91 | --[[ 92 | Inputs: 93 | - input: Tensor of shape (N, 3, H, W) giving pixels for generated images 94 | - target: Table with the following keys: 95 | - content_target: Tensor of shape (N, 3, H, W) 96 | - style_target: Tensor of shape (1, 3, H, W) 97 | --]] 98 | function crit:updateOutput(input, target) 99 | if target.content_target then 100 | self:setContentTarget(target.content_target) 101 | end 102 | if target.style_target then 103 | self.setStyleTarget(target.style_target) 104 | end 105 | 106 | -- Make sure to set all content and style loss layers to loss mode before 107 | -- running the image forward. 108 | for i, content_loss_layer in ipairs(self.content_loss_layers) do 109 | content_loss_layer:setMode('loss') 110 | end 111 | for i, style_loss_layer in ipairs(self.style_loss_layers) do 112 | style_loss_layer:setMode('loss') 113 | end 114 | 115 | local output = self.net:forward(input) 116 | 117 | -- Set up a tensor of zeros to pass as gradient to net in backward pass 118 | self.grad_net_output:resizeAs(output):zero() 119 | 120 | -- Go through and add up losses 121 | self.total_content_loss = 0 122 | self.content_losses = {} 123 | self.total_style_loss = 0 124 | self.style_losses = {} 125 | for i, content_loss_layer in ipairs(self.content_loss_layers) do 126 | self.total_content_loss = self.total_content_loss + content_loss_layer.loss 127 | table.insert(self.content_losses, content_loss_layer.loss) 128 | end 129 | for i, style_loss_layer in ipairs(self.style_loss_layers) do 130 | self.total_style_loss = self.total_style_loss + style_loss_layer.loss 131 | table.insert(self.style_losses, style_loss_layer.loss) 132 | end 133 | 134 | self.output = self.total_style_loss + self.total_content_loss 135 | return self.output 136 | end 137 | 138 | 139 | function crit:updateGradInput(input, target) 140 | self.gradInput = self.net:updateGradInput(input, self.grad_net_output) 141 | return self.gradInput 142 | end 143 | 144 | -------------------------------------------------------------------------------- /experiments/models/hang.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | require 'texture' 17 | 18 | local pad = nn.SpatialReflectionPadding 19 | local normalization = nn.InstanceNormalization 20 | local layer_utils = require 'texture.layer_utils' 21 | 22 | local M = {} 23 | 24 | function M.createModel(opt) 25 | -- Global variable keeping track of the input channels 26 | local iChannels 27 | 28 | -- The shortcut layer is either identity or 1x1 convolution 29 | local function shortcut(nInputPlane, nOutputPlane, stride) 30 | if nInputPlane ~= nOutputPlane then 31 | return nn.Sequential() 32 | :add(nn.SpatialConvolution(nInputPlane, nOutputPlane, 1, 1, stride, stride)) 33 | else 34 | return nn.Identity() 35 | end 36 | end 37 | 38 | local function full_shortcut(nInputPlane, nOutputPlane, stride) 39 | if nInputPlane ~= nOutputPlane or stride ~= 1 then 40 | return nn.Sequential() 41 | --:add(pad(1,0,1,0)) 42 | :add(nn.SpatialUpSamplingNearest(stride)) 43 | :add(nn.SpatialConvolution(nInputPlane, nOutputPlane, 1, 1, 1, 1)) 44 | --:add(nn.SpatialFullConvolution(nInputPlane, nOutputPlane, 1, 1, stride, stride, 1, 1, 1, 1)) 45 | else 46 | return nn.Identity() 47 | end 48 | end 49 | 50 | local function basic_block(n, stride) 51 | stride = stride or 1 52 | local nInputPlane = iChannels 53 | iChannels = n 54 | -- Convolutions 55 | local conv_block = nn.Sequential() 56 | 57 | conv_block:add(normalization(nInputPlane)) 58 | conv_block:add(nn.ReLU(true)) 59 | conv_block:add(pad(1, 1, 1, 1)) 60 | conv_block:add(nn.SpatialConvolution(nInputPlane, n, 3, 3, stride, stride, 0, 0)) 61 | 62 | conv_block:add(normalization(n)) 63 | conv_block:add(nn.ReLU(true)) 64 | conv_block:add(pad(1, 1, 1, 1)) 65 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, 1, 1, 0, 0)) 66 | 67 | local concat = nn.ConcatTable():add(conv_block):add(shortcut(nInputPlane, n, stride)) 68 | 69 | -- Sum 70 | local res_block = nn.Sequential() 71 | res_block:add(concat) 72 | res_block:add(nn.CAddTable()) 73 | return res_block 74 | end 75 | 76 | local function bottleneck(n, stride) 77 | stride = stride or 1 78 | local nInputPlane = iChannels 79 | iChannels = 4 * n 80 | -- Convolutions 81 | local conv_block = nn.Sequential() 82 | 83 | conv_block:add(normalization(nInputPlane)) 84 | conv_block:add(nn.ReLU(true)) 85 | conv_block:add(nn.SpatialConvolution(nInputPlane, n, 1, 1, 1, 1, 0, 0)) 86 | 87 | conv_block:add(normalization(n)) 88 | conv_block:add(nn.ReLU(true)) 89 | conv_block:add(pad(1, 1, 1, 1)) 90 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, stride, stride, 0, 0)) 91 | 92 | conv_block:add(normalization(n)) 93 | conv_block:add(nn.ReLU(true)) 94 | conv_block:add(nn.SpatialConvolution(n, n*4, 1, 1, 1, 1, 0, 0)) 95 | 96 | local concat = nn.ConcatTable():add(conv_block):add(shortcut(nInputPlane, n*4, stride)) 97 | 98 | -- Sum 99 | local res_block = nn.Sequential() 100 | res_block:add(concat) 101 | res_block:add(nn.CAddTable()) 102 | return res_block 103 | end 104 | 105 | local function full_bottleneck(n, stride) 106 | stride = stride or 1 107 | local nInputPlane = iChannels 108 | iChannels = 4 * n 109 | -- Convolutions 110 | local conv_block = nn.Sequential() 111 | 112 | conv_block:add(normalization(nInputPlane)) 113 | conv_block:add(nn.ReLU(true)) 114 | conv_block:add(nn.SpatialConvolution(nInputPlane, n, 1, 1, 1, 1, 0, 0)) 115 | 116 | conv_block:add(normalization(n)) 117 | conv_block:add(nn.ReLU(true)) 118 | 119 | if stride~=1 then 120 | conv_block:add(nn.SpatialUpSamplingNearest(stride)) 121 | conv_block:add(pad(1, 1, 1, 1)) 122 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, 1, 1, 0, 0)) 123 | else 124 | conv_block:add(pad(1, 1, 1, 1)) 125 | conv_block:add(nn.SpatialConvolution(n, n, 3, 3, 1, 1, 0, 0)) 126 | end 127 | conv_block:add(normalization(n)) 128 | conv_block:add(nn.ReLU(true)) 129 | conv_block:add(nn.SpatialConvolution(n, n*4, 1, 1, 1, 1, 0, 0)) 130 | 131 | local concat = nn.ConcatTable() 132 | :add(conv_block) 133 | :add(full_shortcut(nInputPlane, n*4, stride)) 134 | 135 | -- Sum 136 | local res_block = nn.Sequential() 137 | res_block:add(concat) 138 | res_block:add(nn.CAddTable()) 139 | return res_block 140 | end 141 | 142 | local function layer(block, features, count, stride) 143 | local s = nn.Sequential() 144 | for i=1,count do 145 | s:add(block(features, i==1 and stride or 1)) 146 | end 147 | return s 148 | end 149 | 150 | local model = nn.Sequential() 151 | model.cNetsNum = {} 152 | 153 | -- 256x256 154 | model:add(normalization(3)) 155 | model:add(pad(3, 3, 3, 3)) 156 | model:add(nn.SpatialConvolution(3, 64, 7, 7, 1, 1, 0, 0)) 157 | model:add(normalization(64)) 158 | model:add(nn.ReLU(true)) 159 | 160 | iChannels = 64 161 | local block = bottleneck -- basic_block 162 | 163 | model:add(layer(block, 32, 1, 2)) 164 | model:add(layer(block, 64, 1, 2)) 165 | 166 | -- 32x32x512 167 | model:add(nn.Inspiration(iChannels)) 168 | table.insert(model.cNetsNum,#model) 169 | model:add(normalization(iChannels)) 170 | model:add(nn.ReLU(true)) 171 | 172 | for i = 1,opt.model_nres do 173 | model:add(layer(block, 64, 1, 1)) 174 | end 175 | 176 | block = full_bottleneck 177 | model:add(layer(block, 32, 1, 2)) 178 | model:add(layer(block, 16, 1, 2)) 179 | 180 | model:add(normalization(64)) 181 | model:add(nn.ReLU(true)) 182 | 183 | model:add(pad(3, 3, 3, 3)) 184 | model:add(nn.SpatialConvolution(64, 3, 7, 7, 1, 1, 0, 0)) 185 | 186 | model:add(nn.TotalVariation(opt.tv_strength)) 187 | 188 | function model:setTarget(feat, dtype) 189 | model.modules[model.cNetsNum[1]]:setTarget(feat[3]:type(dtype)) 190 | end 191 | return model 192 | end 193 | 194 | function M.createCNets(opt) 195 | -- The descriptive network in the paper 196 | local cnet = torch.load(opt.loss_network) 197 | cnet:evaluate() 198 | cnet.style_layers = {} 199 | 200 | -- Set up calibrate layers 201 | for i, layer_string in ipairs(opt.style_layers) do 202 | local calibrator = nn.Calibrate() 203 | layer_utils.insert_after(cnet, layer_string, calibrator) 204 | table.insert(cnet.style_layers, calibrator) 205 | end 206 | layer_utils.trim_network(cnet) 207 | 208 | function cnet:calibrate(input) 209 | cnet:forward(input) 210 | local feat = {} 211 | for i, calibrator in ipairs(cnet.style_layers) do 212 | table.insert(feat, calibrator:getGram()) 213 | end 214 | return feat 215 | end 216 | 217 | return cnet 218 | end 219 | 220 | return M 221 | -------------------------------------------------------------------------------- /experiments/main.lua: -------------------------------------------------------------------------------- 1 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | -- Created by: Hang Zhang 3 | -- ECE Department, Rutgers University 4 | -- Email: zhang.hang@rutgers.edu 5 | -- Copyright (c) 2017 6 | -- 7 | -- Free to reuse and distribute this software for research or 8 | -- non-profit purpose, subject to the following conditions: 9 | -- 1. The code must retain the above copyright notice, this list of 10 | -- conditions. 11 | -- 2. Original authors' names are not deleted. 12 | -- 3. The authors' names are not used to endorse or promote products 13 | -- derived from this software 14 | --+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 15 | 16 | require 'texture' 17 | require 'image' 18 | require 'optim' 19 | 20 | require 'utils.DataLoader' 21 | 22 | local utils = require 'utils.utils' 23 | local preprocess = require 'utils.preprocess' 24 | local opts = require 'opts' 25 | local imgLoader = require 'utils.getImages' 26 | 27 | function main() 28 | local opt = opts.parse(arg) 29 | -- Parse layer strings and weights 30 | opt.content_layers, opt.content_weights = 31 | utils.parse_layers(opt.content_layers, opt.content_weights) 32 | opt.style_layers, opt.style_weights = 33 | utils.parse_layers(opt.style_layers, opt.style_weights) 34 | 35 | -- Figure out preprocessing 36 | if not preprocess[opt.preprocessing] then 37 | local msg = 'invalid -preprocessing "%s"; must be "vgg" or "resnet"' 38 | error(string.format(msg, opt.preprocessing)) 39 | end 40 | preprocess = preprocess[opt.preprocessing] 41 | 42 | -- Figure out the backend 43 | local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1) 44 | 45 | -- Style images 46 | local styleLoader = imgLoader(opt.style_image_folder) 47 | local featpath = opt.style_image_folder .. '/feat.t7' 48 | if not paths.filep(featpath) then 49 | local extractor = require "extractGram" 50 | extractor.exec(opt) 51 | end 52 | local feat = torch.load(featpath) 53 | feat = nn.utils.recursiveType(feat, 'torch.CudaTensor') 54 | 55 | -- Build the model 56 | local model = nil 57 | 58 | -- Checkpoint 59 | if opt.resume ~= '' then 60 | print('Loading checkpoint from ' .. opt.resume) 61 | model = torch.load(opt.resume).model:type(dtype) 62 | else 63 | print('Initializing model from scratch') 64 | models = require('models.' .. opt.model) 65 | model = models.createModel(opt):type(dtype) 66 | end 67 | 68 | if use_cudnn then cudnn.convert(model, cudnn) end 69 | model:training() 70 | print(model) 71 | 72 | -- Set up the perceptual loss function 73 | local percep_crit 74 | local loss_net = torch.load(opt.loss_network) 75 | local crit_args = { 76 | cnn = loss_net, 77 | style_layers = opt.style_layers, 78 | style_weights = opt.style_weights, 79 | content_layers = opt.content_layers, 80 | content_weights = opt.content_weights, 81 | } 82 | percep_crit = nn.PerceptualCriterion(crit_args):type(dtype) 83 | 84 | local loader = DataLoader(opt) 85 | local params, grad_params = model:getParameters() 86 | 87 | local function f(x) 88 | assert(x == params) 89 | grad_params:zero() 90 | 91 | local x, y = loader:getBatch('train') 92 | x, y = x:type(dtype), y:type(dtype) 93 | 94 | -- Run model forward 95 | local out = model:forward(x) 96 | local target = {content_target=y} 97 | local loss = percep_crit:forward(out, target) 98 | local grad_out = percep_crit:backward(out, target) 99 | 100 | -- Run model backward 101 | model:backward(x, grad_out) 102 | 103 | return loss, grad_params 104 | end 105 | 106 | local optim_state = {learningRate=opt.learning_rate} 107 | local train_loss_history = {} 108 | local val_loss_history = {} 109 | local val_loss_history_ts = {} 110 | local style_loss_history = nil 111 | 112 | style_loss_history = {} 113 | for i, k in ipairs(opt.style_layers) do 114 | style_loss_history[string.format('style-%d', k)] = {} 115 | end 116 | for i, k in ipairs(opt.content_layers) do 117 | style_loss_history[string.format('content-%d', k)] = {} 118 | end 119 | 120 | local style_weight = opt.style_weight 121 | for t = 1, opt.num_iterations do 122 | -- set Target Here 123 | if (t-1)%opt.style_iter == 0 then 124 | --print('Setting Style Target') 125 | local idx = (t-1)/opt.style_iter % #feat + 1 126 | 127 | local style_image = styleLoader:get(idx) 128 | style_image = image.scale(style_image, opt.style_image_size) 129 | style_image = preprocess.preprocess(style_image:add_dummy()) 130 | percep_crit:setStyleTarget(style_image:type(dtype)) 131 | 132 | local style_image_feat = feat[idx] 133 | model:setTarget(style_image_feat, dtype) 134 | end 135 | local epoch = t / loader.num_minibatches['train'] 136 | 137 | local _, loss = optim.adam(f, params, optim_state) 138 | 139 | table.insert(train_loss_history, loss[1]) 140 | 141 | for i, k in ipairs(opt.style_layers) do 142 | table.insert(style_loss_history[string.format('style-%d', k)], 143 | percep_crit.style_losses[i]) 144 | end 145 | for i, k in ipairs(opt.content_layers) do 146 | table.insert(style_loss_history[string.format('content-%d', k)], 147 | percep_crit.content_losses[i]) 148 | end 149 | 150 | print(string.format('Epoch %f, Iteration %d / %d, loss = %f', 151 | epoch, t, opt.num_iterations, loss[1]), optim_state.learningRate) 152 | 153 | if t % opt.checkpoint_every == 0 then 154 | -- Check loss on the validation set 155 | loader:reset('val') 156 | model:evaluate() 157 | local val_loss = 0 158 | print 'Running on validation set ... ' 159 | local val_batches = opt.num_val_batches 160 | for j = 1, val_batches do 161 | local x, y = loader:getBatch('val') 162 | x, y = x:type(dtype), y:type(dtype) 163 | local out = model:forward(x) 164 | --y = shave_y(x, y, out) 165 | 166 | local percep_loss = 0 167 | percep_loss = percep_crit:forward(out, {content_target=y}) 168 | val_loss = val_loss + percep_loss 169 | end 170 | val_loss = val_loss / val_batches 171 | print(string.format('val loss = %f', val_loss)) 172 | table.insert(val_loss_history, val_loss) 173 | table.insert(val_loss_history_ts, t) 174 | model:training() 175 | 176 | -- Save a checkpoint 177 | local checkpoint = { 178 | opt=opt, 179 | train_loss_history=train_loss_history, 180 | val_loss_history=val_loss_history, 181 | val_loss_history_ts=val_loss_history_ts, 182 | style_loss_history=style_loss_history, 183 | } 184 | local filename = string.format('%s.json', opt.checkpoint_name) 185 | paths.mkdir(paths.dirname(filename)) 186 | utils.write_json(filename, checkpoint) 187 | 188 | -- Save a torch checkpoint; convert the model to float first 189 | model:clearState() 190 | if use_cudnn then 191 | cudnn.convert(model, nn) 192 | end 193 | model:float() 194 | checkpoint.model = model 195 | filename = string.format('%s.t7', opt.checkpoint_name) 196 | torch.save(filename, checkpoint) 197 | 198 | -- Convert the model back 199 | model:type(dtype) 200 | if use_cudnn then 201 | cudnn.convert(model, cudnn) 202 | end 203 | params, grad_params = model:getParameters() 204 | 205 | collectgarbage() 206 | collectgarbage() 207 | end 208 | 209 | if opt.lr_decay_every > 0 and t % opt.lr_decay_every == 0 then 210 | local new_lr = opt.lr_decay_factor * optim_state.learningRate 211 | optim_state = {learningRate = new_lr} 212 | end 213 | end 214 | end 215 | 216 | 217 | main() 218 | 219 | -------------------------------------------------------------------------------- /cmake/select_compute_arch.cmake: -------------------------------------------------------------------------------- 1 | # Synopsis: 2 | # CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures]) 3 | # -- Selects GPU arch flags for nvcc based on target_CUDA_architectures 4 | # target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...) 5 | # - "Auto" detects local machine GPU compute arch at runtime. 6 | # - "Common" and "All" cover common and entire subsets of architectures 7 | # ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX 8 | # NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal 9 | # NUM: Any number. Only those pairs are currently accepted by NVCC though: 10 | # 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 11 | # Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable} 12 | # Additionally, sets ${out_variable}_readable to the resulting numeric list 13 | # Example: 14 | # CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell) 15 | # LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) 16 | # 17 | # More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA 18 | # 19 | 20 | # This list will be used for CUDA_ARCH_NAME = All option 21 | set(CUDA_KNOWN_GPU_ARCHITECTURES "Fermi" "Kepler" "Maxwell") 22 | 23 | # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default) 24 | set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0") 25 | 26 | if (CUDA_VERSION VERSION_GREATER "6.5") 27 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra") 28 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2") 29 | endif () 30 | 31 | if (CUDA_VERSION VERSION_GREATER "7.5") 32 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal") 33 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX") 34 | else() 35 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX") 36 | endif () 37 | 38 | 39 | 40 | ################################################################################################ 41 | # A function for automatic detection of GPUs installed (if autodetection is enabled) 42 | # Usage: 43 | # CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE) 44 | # 45 | function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE) 46 | if(NOT CUDA_GPU_DETECT_OUTPUT) 47 | set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu) 48 | 49 | file(WRITE ${cufile} "" 50 | "#include \n" 51 | "int main()\n" 52 | "{\n" 53 | " int count = 0;\n" 54 | " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" 55 | " if (count == 0) return -1;\n" 56 | " for (int device = 0; device < count; ++device)\n" 57 | " {\n" 58 | " cudaDeviceProp prop;\n" 59 | " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" 60 | " std::printf(\"%d.%d \", prop.major, prop.minor);\n" 61 | " }\n" 62 | " return 0;\n" 63 | "}\n") 64 | 65 | execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${cufile}" 66 | "-ccbin" ${CMAKE_CXX_COMPILER} 67 | WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" 68 | RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out 69 | ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) 70 | 71 | if(nvcc_res EQUAL 0) 72 | string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}") 73 | set(CUDA_GPU_DETECT_OUTPUT ${nvcc_out} CACHE INTERNAL "Returned GPU architetures from detect_gpus tool" FORCE) 74 | endif() 75 | endif() 76 | 77 | if(NOT CUDA_GPU_DETECT_OUTPUT) 78 | message(STATUS "Automatic GPU detection failed. Building for common architectures.") 79 | set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE) 80 | else() 81 | set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT} PARENT_SCOPE) 82 | endif() 83 | endfunction() 84 | 85 | 86 | ################################################################################################ 87 | # Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list 88 | # Usage: 89 | # SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs]) 90 | function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) 91 | set(CUDA_ARCH_LIST "${ARGN}") 92 | 93 | if("X${CUDA_ARCH_LIST}" STREQUAL "X" ) 94 | set(CUDA_ARCH_LIST "Auto") 95 | endif() 96 | 97 | set(cuda_arch_bin) 98 | set(cuda_arch_ptx) 99 | 100 | if("${CUDA_ARCH_LIST}" STREQUAL "All") 101 | set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES}) 102 | elseif("${CUDA_ARCH_LIST}" STREQUAL "Common") 103 | set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES}) 104 | elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto") 105 | CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST) 106 | message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}") 107 | endif() 108 | 109 | # Now process the list and look for names 110 | string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") 111 | list(REMOVE_DUPLICATES CUDA_ARCH_LIST) 112 | foreach(arch_name ${CUDA_ARCH_LIST}) 113 | set(arch_bin) 114 | set(add_ptx FALSE) 115 | # Check to see if we are compiling PTX 116 | if(arch_name MATCHES "(.*)\\+PTX$") 117 | set(add_ptx TRUE) 118 | set(arch_name ${CMAKE_MATCH_1}) 119 | endif() 120 | if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$") 121 | set(arch_bin ${CMAKE_MATCH_1}) 122 | set(arch_ptx ${arch_bin}) 123 | else() 124 | # Look for it in our list of known architectures 125 | if(${arch_name} STREQUAL "Fermi") 126 | set(arch_bin "2.0 2.1(2.0)") 127 | elseif(${arch_name} STREQUAL "Kepler+Tegra") 128 | set(arch_bin 3.2) 129 | elseif(${arch_name} STREQUAL "Kepler+Tesla") 130 | set(arch_bin 3.7) 131 | elseif(${arch_name} STREQUAL "Kepler") 132 | set(arch_bin 3.0 3.5) 133 | set(arch_ptx 3.5) 134 | elseif(${arch_name} STREQUAL "Maxwell+Tegra") 135 | set(arch_bin 5.3) 136 | elseif(${arch_name} STREQUAL "Maxwell") 137 | set(arch_bin 5.0 5.2) 138 | set(arch_ptx 5.2) 139 | elseif(${arch_name} STREQUAL "Pascal") 140 | set(arch_bin 6.0 6.1) 141 | set(arch_ptx 6.1) 142 | else() 143 | message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") 144 | endif() 145 | endif() 146 | if(NOT arch_bin) 147 | message(SEND_ERROR "arch_bin wasn't set for some reason") 148 | endif() 149 | list(APPEND cuda_arch_bin ${arch_bin}) 150 | if(add_ptx) 151 | if (NOT arch_ptx) 152 | set(arch_ptx ${arch_bin}) 153 | endif() 154 | list(APPEND cuda_arch_ptx ${arch_ptx}) 155 | endif() 156 | endforeach() 157 | 158 | # remove dots and convert to lists 159 | string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") 160 | string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}") 161 | string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}") 162 | string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}") 163 | 164 | if(cuda_arch_bin) 165 | list(REMOVE_DUPLICATES cuda_arch_bin) 166 | endif() 167 | if(cuda_arch_ptx) 168 | list(REMOVE_DUPLICATES cuda_arch_ptx) 169 | endif() 170 | 171 | set(nvcc_flags "") 172 | set(nvcc_archs_readable "") 173 | 174 | # Tell NVCC to add binaries for the specified GPUs 175 | foreach(arch ${cuda_arch_bin}) 176 | if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)") 177 | # User explicitly specified ARCH for the concrete CODE 178 | list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) 179 | list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1}) 180 | else() 181 | # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE 182 | list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch}) 183 | list(APPEND nvcc_archs_readable sm_${arch}) 184 | endif() 185 | endforeach() 186 | 187 | # Tell NVCC to add PTX intermediate code for the specified architectures 188 | foreach(arch ${cuda_arch_ptx}) 189 | list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch}) 190 | list(APPEND nvcc_archs_readable compute_${arch}) 191 | endforeach() 192 | 193 | string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}") 194 | set(${out_variable} ${nvcc_flags} PARENT_SCOPE) 195 | set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE) 196 | endfunction() 197 | --------------------------------------------------------------------------------