├── LICENSE ├── README.md ├── cnn └── vgg19 │ ├── .gitignore │ └── download_models.sh ├── dreamer.style ├── ezstyle.lua ├── gram.style ├── images ├── .DS_Store ├── _results │ ├── dreamer.png │ ├── gram.png │ ├── masked_gram.png │ └── mrf.png ├── data │ └── renoir_gram_mask.t7 ├── ford.png ├── lohan.png ├── picasso.png ├── renoir.png ├── renoir_style_mask.png ├── renoir_target_mask.png ├── trump.png └── winter.png ├── lib ├── amplayer.lua ├── caffe_image.lua ├── cleanup_model.lua ├── contentloss.lua ├── gramloss.lua ├── masked_gramloss.lua ├── mrfloss.lua ├── randlayer.lua └── tvloss.lua ├── masked_gram.style ├── mrf.style ├── styled_cnn.lua └── tools └── buildMask.lua /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Zhou Chang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # easyStyle 2 | All kinds of neural style transformer. 3 | 4 | This project collects many kinds of nerual style transformer , including 5 | 6 | * dreamer mode 7 | 8 | * gram matrix style 9 | 10 | * MRF style 11 | 12 | * guided style/patch transformer 13 | 14 | 15 | My project is foccus on clean and simple implementation of all kinds of algorithm. 16 | 17 | 18 | --- 19 | 20 | ## 1. install and setup 21 | 22 | Install following package of Torch7. 23 | 24 | ``` 25 | cunn 26 | loadcaffe 27 | cudnn 28 | ``` 29 | 30 | This project needs a GPU with 4G memory at least. Firstly you should download VGG19 caffe model. 31 | 32 | ``` 33 | cd cnn/vgg19 34 | source ./download_models.sh 35 | ``` 36 | 37 | ## 2. Quick demo 38 | 39 | The .style files descript the arch of network used in style transformer, they are based on Lua language. 40 | The .style files is very simple ,all the paramters are configed in thease files. 41 | 42 | ### 2.1 dreamer mode 43 | 44 | ``` 45 | th ezstyle ./dreamer.style 46 | ``` 47 |

48 | 49 | 50 |

51 | 52 | 53 | 54 | ### 2.2 gram matrix mode 55 | 56 | ``` 57 | th ezstyle ./gram.style 58 | ``` 59 |

60 | 61 | 62 | 63 |

64 | 65 | 66 | ### 2.3 MRF mode 67 | 68 | ``` 69 | th ezstyle ./mrf.style 70 | ``` 71 |

72 | 73 | 74 | 75 |

76 | 77 | 78 | ### 2.4 guided mode 1 ( masked gram style transform ) 79 | 80 | ``` 81 | th ezstyle ./masked_gram.style 82 | ``` 83 |

84 | 85 | 86 | 87 |

88 | 89 | 90 | ### 2.5 guided mode 2 ( masked mrf ) 91 | 92 | [WIP] 93 | 94 | 95 | ### 3. Resources 96 | 97 | All of code is coming from following projects, I have make them more simpler and stupid :). 98 | 99 | https://github.com/chuanli11/CNNMRF 100 | 101 | https://github.com/alexjc/neural-doodle 102 | 103 | https://github.com/awentzonline/image-analogies 104 | 105 | https://github.com/jcjohnson/neural-style 106 | 107 | https://github.com/DmitryUlyanov/fast-neural-doodle 108 | 109 | 110 | -------------------------------------------------------------------------------- /cnn/vgg19/.gitignore: -------------------------------------------------------------------------------- 1 | VGG_ILSVRC_19_layers.caffemodel 2 | VGG_ILSVRC_19_layers_deploy.prototxt 3 | VGG_ILSVRC_19_layers_deploy.prototxt.lua 4 | VGG_ILSVRC_19_layers_deploy.prototxt.cpu.lua 5 | VGG_ILSVRC_19_layers_deploy.prototxt.opencl.lua 6 | vgg_normalised.caffemodel 7 | -------------------------------------------------------------------------------- /cnn/vgg19/download_models.sh: -------------------------------------------------------------------------------- 1 | cd models 2 | wget -c https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/bb2b4fe0a9bb0669211cf3d0bc949dfdda173e9e/VGG_ILSVRC_19_layers_deploy.prototxt 3 | wget -c --no-check-certificate https://bethgelab.org/media/uploads/deeptextures/vgg_normalised.caffemodel 4 | wget -c http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel 5 | cd .. 6 | -------------------------------------------------------------------------------- /dreamer.style: -------------------------------------------------------------------------------- 1 | local net = { 2 | image_list = {'./images/winter.png', './output.png'}, 3 | input = 1, 4 | output = 2, 5 | 6 | convergence = false, 7 | maxIterate = 10, 8 | step = 0.0001, 9 | 10 | cnn = { 11 | proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt', 12 | caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel' 13 | }, 14 | 15 | net = { 16 | { 17 | layer = 'relu2_1', 18 | type = 'amp', 19 | ratio = 1.0, 20 | }, 21 | { 22 | layer = 'relu4_4', 23 | type = 'amp', 24 | ratio = 2.0, 25 | } 26 | } 27 | } 28 | 29 | return net 30 | -------------------------------------------------------------------------------- /ezstyle.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | require('nngraph') 3 | require('loadcaffe') 4 | require('xlua') 5 | require('image') 6 | require('optim') 7 | require('cunn') 8 | 9 | require('./lib/tvloss') 10 | require('./lib/contentloss') 11 | require('./lib/gramloss') 12 | require('./lib/mrfloss') 13 | require('./lib/masked_gramloss') 14 | require('./lib/amplayer') 15 | require('./lib/randlayer') 16 | 17 | local caffeImage = require('./lib/caffe_image') 18 | 19 | ---------------------------------------------------------------------------------------- 20 | local g = {} 21 | 22 | local doRevert = function() 23 | local currentImage = g.x 24 | for i = 1, g.conf.maxIterate do 25 | local inout = g.net:forward(currentImage) 26 | inout = g.net:backward(currentImage, g.dy) 27 | 28 | currentImage:add(inout * (-1 * g.conf.step) ); 29 | currentImage:clamp(-128,128) 30 | 31 | collectgarbage() 32 | xlua.progress(i, g.conf.maxIterate) 33 | end 34 | end 35 | 36 | local doConvergence = function() 37 | local optim_state = { 38 | maxIter = g.conf.maxIterate, 39 | verbose = true, 40 | } 41 | 42 | local num_calls = 0 43 | local function feval(x) 44 | 45 | num_calls = num_calls + 1 46 | 47 | g.net:forward(x) 48 | local grad = g.net:updateGradInput(x, g.dy) 49 | local loss = 0 50 | 51 | for _, mod in ipairs(g.modifier) do 52 | loss = loss + mod.loss 53 | end 54 | 55 | print(">>>>>>>>>" .. loss) 56 | --xlua.progress(num_calls, optim_state.maxIter) 57 | 58 | collectgarbage() 59 | -- optim.lbfgs expects a vector for gradients 60 | return loss, grad:view(grad:nElement()) 61 | end 62 | 63 | local x, losses = optim.lbfgs(feval, g.x, optim_state) 64 | end 65 | 66 | local main = function() 67 | torch.setdefaulttensortype('torch.FloatTensor') 68 | torch.manualSeed(1979) 69 | if ( #arg < 1) then 70 | print("Please input config file!") 71 | os.exit(0) 72 | end 73 | 74 | -- init 75 | g.conf = dofile(arg[1]) 76 | g.cnn = loadCNN(g.conf.cnn) 77 | g.net, g.modifier = buildNetwork(g.conf, g.cnn) 78 | g.x = loadInput(g.conf) 79 | g.dy = torch.zeros( g.net:forward(g.x):size() ) 80 | 81 | -- cuda 82 | g.net:cuda() 83 | g.x = g.x:cuda() 84 | g.dy = g.dy:cuda() 85 | 86 | print(g.net) 87 | 88 | collectgarbage() 89 | if (g.conf.convergence) then 90 | doConvergence() 91 | else 92 | doRevert() 93 | end 94 | 95 | local img = caffeImage.caffe2img(g.x:float()) 96 | image.savePNG(g.conf.image_list[g.conf.output], img) 97 | end 98 | 99 | ----------------------------------------------------------------------------------------- 100 | -- helper functions 101 | 102 | string.startsWith = function(self, str) 103 | return self:find('^' .. str) ~= nil 104 | end 105 | 106 | function loadInput(conf) 107 | local img = nil 108 | if ( conf.image_list[conf.input] == nil ) then 109 | img = torch.rand(3, conf.height, conf.width) 110 | else 111 | img = image.load(conf.image_list[conf.input], 3) 112 | end 113 | 114 | img = caffeImage.img2caffe(img) 115 | return img 116 | end 117 | 118 | function loadCNN(cnnFiles) 119 | local fullModel = loadcaffe.load(cnnFiles.proto, cnnFiles.caffemodel, 'nn') 120 | local cnn = nn.Sequential(); 121 | for i = 1, #fullModel do 122 | local name = fullModel:get(i).name 123 | if ( name:startsWith('relu') or name:startsWith('conv') or name:startsWith('pool') ) then 124 | cnn:add( fullModel:get(i) ) 125 | else 126 | break 127 | end 128 | end 129 | fullModel = nil 130 | collectgarbage() 131 | 132 | return cnn 133 | end 134 | 135 | function buildNetwork(conf, cnn) 136 | local net = nn.Sequential() 137 | local modifier = {} 138 | 139 | local nindex = 1 140 | if ( conf.net[1].layer == 'input') then 141 | local layer = buildLayer(net, conf, 1, cnn) 142 | net:add(layer) 143 | nindex = 2 144 | end 145 | 146 | for i = 1, #cnn do 147 | local name = cnn:get(i).name 148 | net:add(cnn:get(i)) 149 | 150 | if ( name == conf.net[nindex].layer ) then 151 | local layer = buildLayer(net, conf, nindex, cnn) 152 | net:add(layer) 153 | 154 | table.insert(modifier, layer) 155 | 156 | nindex = nindex + 1 157 | if ( nindex > #conf.net ) then 158 | break 159 | end 160 | end 161 | collectgarbage() 162 | end 163 | 164 | return net, modifier 165 | end 166 | 167 | function buildLayer(net, conf, nindex, cnn) 168 | local layer = nil 169 | 170 | if ( conf.net[nindex].type == "tvloss" ) then 171 | layer = nn.TVLoss(conf.net[nindex].weight) 172 | elseif ( conf.net[nindex].type == "amp") then 173 | layer = nn.AmpLayer(conf.net[nindex].ratio) 174 | elseif ( conf.net[nindex].type == "rand") then 175 | layer = nn.RandLayer() 176 | elseif ( conf.net[nindex].type == "content") then 177 | local targetImage = conf.image_list[ conf.net[nindex].target] 178 | targetImage = image.load(targetImage,3) 179 | local targetCaffe = caffeImage.img2caffe(targetImage) 180 | local target = net:forward(targetCaffe) 181 | 182 | layer = nn.ContentLoss(conf.net[nindex].weight, target) 183 | elseif ( conf.net[nindex].type == "gram") then 184 | local targetImage = conf.image_list[ conf.net[nindex].target] 185 | targetImage = image.load(targetImage,3) 186 | local targetCaffe = caffeImage.img2caffe(targetImage) 187 | local target = net:forward(targetCaffe) 188 | 189 | layer = nn.GramLoss(conf.net[nindex].weight, target) 190 | elseif ( conf.net[nindex].type == "mrf") then 191 | local targetImage = conf.image_list[ conf.net[nindex].target] 192 | targetImage = image.load(targetImage,3) 193 | local targetCaffe = caffeImage.img2caffe(targetImage) 194 | local target = net:forward(targetCaffe):clone() 195 | 196 | local inputImage = conf.image_list[ conf.input] 197 | inputImage = image.load(inputImage, 3) 198 | local inputCaffe = caffeImage.img2caffe(inputImage) 199 | local input = net:forward(inputCaffe) 200 | 201 | layer = nn.MRFLoss(conf.net[nindex].weight, input, target) 202 | elseif ( conf.net[nindex].type == 'mask_gram') then 203 | local styleImage = conf.image_list[ conf.net[nindex].style] 204 | styleImage = image.load(styleImage,3) 205 | local styleCaffe = caffeImage.img2caffe(styleImage) 206 | local style = net:forward(styleCaffe):clone() 207 | 208 | local masks = torch.load ( conf.image_list[ conf.net[nindex].mask], 'ascii') 209 | 210 | layer = nn.MaskedGramLoss(conf.net[nindex].weight, style, masks) 211 | end 212 | 213 | return layer 214 | end 215 | 216 | ----------------------------------------------------------------------------------------- 217 | main() 218 | 219 | -------------------------------------------------------------------------------- /gram.style: -------------------------------------------------------------------------------- 1 | local net = { 2 | image_list = {'./images/trump.png', './images/picasso.png', './output.png'}, 3 | input = 1, 4 | output = 3, 5 | 6 | convergence = true, 7 | maxIterate = 500, 8 | 9 | cnn = { 10 | proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt', 11 | caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel' 12 | }, 13 | 14 | net = { 15 | { 16 | layer = 'input', 17 | type = 'tvloss', 18 | weight = 0.001, 19 | }, 20 | { 21 | layer = 'relu1_1', 22 | type = 'gram', 23 | weight = 1, 24 | target = 2, 25 | }, 26 | { 27 | layer = 'relu2_1', 28 | type = 'gram', 29 | weight = 1, 30 | target = 2, 31 | }, 32 | { 33 | layer = 'relu3_1', 34 | type = 'gram', 35 | weight = 1, 36 | target = 2, 37 | }, 38 | { 39 | layer = 'relu4_1', 40 | type = 'gram', 41 | weight = 1, 42 | target = 2, 43 | }, 44 | { 45 | layer = 'relu4_2', 46 | type = 'content', 47 | weight = 1000.0, 48 | target = 1, 49 | }, 50 | { 51 | layer = 'relu5_1', 52 | type = 'gram', 53 | weight = 1, 54 | target = 2, 55 | } 56 | } 57 | } 58 | 59 | return net 60 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/.DS_Store -------------------------------------------------------------------------------- /images/_results/dreamer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/dreamer.png -------------------------------------------------------------------------------- /images/_results/gram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/gram.png -------------------------------------------------------------------------------- /images/_results/masked_gram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/masked_gram.png -------------------------------------------------------------------------------- /images/_results/mrf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/mrf.png -------------------------------------------------------------------------------- /images/ford.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/ford.png -------------------------------------------------------------------------------- /images/lohan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/lohan.png -------------------------------------------------------------------------------- /images/picasso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/picasso.png -------------------------------------------------------------------------------- /images/renoir.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/renoir.png -------------------------------------------------------------------------------- /images/renoir_style_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/renoir_style_mask.png -------------------------------------------------------------------------------- /images/renoir_target_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/renoir_target_mask.png -------------------------------------------------------------------------------- /images/trump.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/trump.png -------------------------------------------------------------------------------- /images/winter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/winter.png -------------------------------------------------------------------------------- /lib/amplayer.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | 3 | local AmpLayer, parent = torch.class('nn.AmpLayer', 'nn.Module') 4 | 5 | function AmpLayer:__init(ratio) 6 | parent.__init(self) 7 | self.ratio = ratio 8 | self.loss = 0 9 | end 10 | 11 | function AmpLayer:updateOutput(input) 12 | self.output = input 13 | return self.output 14 | end 15 | 16 | function AmpLayer:updateGradInput(input, gradOutput) 17 | self.gradInput:resizeAs(input):copy(input) 18 | 19 | self.gradInput:mul(-1*self.ratio) 20 | self.gradInput:add(gradOutput) 21 | return self.gradInput 22 | end 23 | 24 | -------------------------------------------------------------------------------- /lib/caffe_image.lua: -------------------------------------------------------------------------------- 1 | require('image') 2 | 3 | local caffeImage = {} 4 | 5 | caffeImage.img2caffe = function(img) 6 | local mean_pixel = torch.Tensor({103.939, 116.779, 123.68}) 7 | local perm = torch.LongTensor{3, 2, 1} 8 | img = img:index(1, perm):mul(256.0) 9 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 10 | img:add(-1, mean_pixel) 11 | return img 12 | end 13 | 14 | caffeImage.caffe2img = function(img) 15 | local mean_pixel = torch.Tensor({103.939, 116.779, 123.68}) 16 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 17 | img = img + mean_pixel 18 | local perm = torch.LongTensor{3, 2, 1} 19 | img = img:index(1, perm):div(256.0) 20 | return img 21 | end 22 | 23 | return caffeImage 24 | -------------------------------------------------------------------------------- /lib/cleanup_model.lua: -------------------------------------------------------------------------------- 1 | -- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049 2 | 3 | local function zeroDataSize(data) 4 | if type(data) == 'table' then 5 | for i = 1, #data do 6 | data[i] = zeroDataSize(data[i]) 7 | end 8 | elseif type(data) == 'userdata' then 9 | data = torch.Tensor():typeAs(data) 10 | end 11 | return data 12 | end 13 | -- Resize the output, gradInput, etc temporary tensors to zero (so that the 14 | -- on disk size is smaller) 15 | local function cleanupModel(node) 16 | if node.output ~= nil then 17 | node.output = zeroDataSize(node.output) 18 | end 19 | if node.gradInput ~= nil then 20 | node.gradInput = zeroDataSize(node.gradInput) 21 | end 22 | if node.finput ~= nil then 23 | node.finput = zeroDataSize(node.finput) 24 | end 25 | if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then 26 | if node.negative ~= nil then 27 | node.negative = zeroDataSize(node.negative) 28 | end 29 | end 30 | if tostring(node) == "nn.Dropout" then 31 | if node.noise ~= nil then 32 | node.noise = zeroDataSize(node.noise) 33 | end 34 | end 35 | -- Recurse on nodes with 'modules' 36 | if (node.modules ~= nil) then 37 | if (type(node.modules) == 'table') then 38 | for i = 1, #node.modules do 39 | local child = node.modules[i] 40 | cleanupModel(child) 41 | end 42 | end 43 | end 44 | end 45 | 46 | return cleanupModel 47 | -------------------------------------------------------------------------------- /lib/contentloss.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | 3 | -- Define an nn Module to compute content loss in-place 4 | local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module') 5 | 6 | function ContentLoss:__init(strength, target, normalize) 7 | parent.__init(self) 8 | self.strength = strength 9 | if ( target ~= nil) then 10 | self.target = target:clone() 11 | end 12 | self.normalize = normalize or false 13 | self.loss = 0 14 | self.crit = nn.MSECriterion() 15 | end 16 | 17 | function ContentLoss:setTarget(target) 18 | self.target = target:clone() 19 | end 20 | 21 | function ContentLoss:updateOutput(input) 22 | if self.target and input:nElement() == self.target:nElement() then 23 | self.loss = self.crit:forward(input, self.target) * self.strength 24 | end 25 | 26 | self.output = input 27 | return self.output 28 | end 29 | 30 | function ContentLoss:updateGradInput(input, gradOutput) 31 | if input:nElement() == self.target:nElement() then 32 | self.gradInput = self.crit:backward(input, self.target) 33 | end 34 | if self.normalize then 35 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 36 | end 37 | self.gradInput:mul(self.strength) 38 | 39 | self.gradInput:add(gradOutput) 40 | return self.gradInput 41 | end 42 | 43 | 44 | -------------------------------------------------------------------------------- /lib/gramloss.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | 3 | -- Returns a network that computes the CxC Gram matrix from inputs 4 | -- of size C x H x W 5 | function GramMatrix() 6 | local net = nn.Sequential() 7 | net:add(nn.View(-1):setNumInputDims(2)) 8 | local concat = nn.ConcatTable() 9 | concat:add(nn.Identity()) 10 | concat:add(nn.Identity()) 11 | net:add(concat) 12 | net:add(nn.MM(false, true)) 13 | return net 14 | end 15 | 16 | 17 | -- Define an nn Module to compute style loss in-place 18 | local GramLoss, parent = torch.class('nn.GramLoss', 'nn.Module') 19 | 20 | function GramLoss:__init(strength, target, normalize) 21 | parent.__init(self) 22 | self.normalize = normalize or false 23 | self.strength = strength 24 | self.loss = 0 25 | 26 | self.gram = GramMatrix() 27 | self.G = nil 28 | self.crit = nn.MSECriterion() 29 | 30 | local tsize = target:size() 31 | local img_size = tsize[2] * tsize[3] 32 | 33 | self.target = self.gram:forward(target):clone() 34 | self.target:div( img_size ) 35 | end 36 | 37 | function GramLoss:updateOutput(input) 38 | local tsize = input:size() 39 | local img_size = tsize[2] * tsize[3] 40 | 41 | self.G = self.gram:forward(input) 42 | self.G:div(img_size) 43 | 44 | self.loss = self.crit:forward(self.G, self.target) 45 | self.loss = self.loss * self.strength 46 | self.output = input 47 | return self.output 48 | end 49 | 50 | function GramLoss:updateGradInput(input, gradOutput) 51 | local dG = self.crit:backward(self.G, self.target) 52 | 53 | local tsize = input:size() 54 | local img_size = tsize[2] * tsize[3] 55 | dG:div(img_size) 56 | 57 | self.gradInput = self.gram:backward(input, dG) 58 | if self.normalize then 59 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 60 | end 61 | self.gradInput:mul(self.strength) 62 | self.gradInput:add(gradOutput) 63 | return self.gradInput 64 | end 65 | 66 | -------------------------------------------------------------------------------- /lib/masked_gramloss.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | require('image') 3 | 4 | -- Returns a network that computes the CxC Gram matrix from inputs 5 | -- of size C x H x W 6 | function GramMatrix() 7 | local net = nn.Sequential() 8 | net:add(nn.View(-1):setNumInputDims(2)) 9 | local concat = nn.ConcatTable() 10 | concat:add(nn.Identity()) 11 | concat:add(nn.Identity()) 12 | net:add(concat) 13 | net:add(nn.MM(false, true)) 14 | return net 15 | end 16 | 17 | 18 | -- Define an nn Module to compute style loss in-place 19 | local MaskedGramLoss, parent = torch.class('nn.MaskedGramLoss', 'nn.Module') 20 | 21 | function MaskedGramLoss:__init(strength, style, masks) 22 | parent.__init(self) 23 | self.strength = strength 24 | self.loss = 0 25 | self.crit = nn.MSECriterion() 26 | 27 | local channel = style:size()[1] 28 | local hei = style:size()[2] 29 | local wid = style:size()[3] 30 | 31 | local gram = GramMatrix() 32 | 33 | local allGramTarget = {} 34 | 35 | local maskedLoss = nn.ConcatTable() 36 | for i = 1, #masks.style do 37 | local style_mask = image.scale(masks.style[i], wid, hei):float() 38 | 39 | style_mask = style_mask:view(1, hei, wid):expandAs(style) 40 | style_mask = torch.cmul(style_mask, style) 41 | 42 | allGramTarget[i] = gram:forward(style_mask):clone() 43 | allGramTarget[i]:div(wid*hei) 44 | 45 | local target_mask = image.scale(masks.target[i], wid, hei):float() 46 | target_mask = target_mask:view(1, hei, wid):expandAs(style) 47 | 48 | local mask_net = nn.Sequential() 49 | local cmul = nn.CMul(style:size()) 50 | cmul.weight:copy( target_mask) 51 | 52 | mask_net:add(cmul) 53 | mask_net:add(GramMatrix()) 54 | maskedLoss:add(mask_net) 55 | end 56 | 57 | self.allGramTarget = allGramTarget 58 | self.maskedLoss = maskedLoss 59 | 60 | end 61 | 62 | function MaskedGramLoss:updateOutput(input) 63 | local tsize = input:size() 64 | local img_size = tsize[2] * tsize[3] 65 | 66 | self.loss = 0 67 | local maskedGram = self.maskedLoss:forward(input) 68 | for i = 1, #maskedGram do 69 | maskedGram[i]:div(img_size) 70 | self.loss = self.loss + self.crit:forward(maskedGram[i], self.allGramTarget[i]) 71 | end 72 | self.maskedGram = maskedGram 73 | 74 | self.loss = self.loss * self.strength 75 | self.output = input 76 | return self.output 77 | end 78 | 79 | function MaskedGramLoss:updateGradInput(input, gradOutput) 80 | local tsize = input:size() 81 | local img_size = tsize[2] * tsize[3] 82 | 83 | local dG = {} 84 | for i = 1, #self.allGramTarget do 85 | dG[i] = self.crit:backward(self.maskedGram[i], self.allGramTarget[i]):clone() 86 | dG[i]:mul(img_size) 87 | end 88 | 89 | self.gradInput = self.maskedLoss:backward(input, dG) 90 | self.gradInput:mul(self.strength) 91 | self.gradInput:add(gradOutput) 92 | return self.gradInput 93 | end 94 | 95 | -------------------------------------------------------------------------------- /lib/mrfloss.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | 3 | local buildMRFTarget = function(input, ref) 4 | -- this is a simple version , no scale, no rotation 5 | local channel = input:size()[1] 6 | local height = input:size()[2] 7 | local width = input:size()[3] 8 | local width_ref = ref:size()[3] 9 | 10 | local normConv = nn.SpatialConvolutionMM(channel, 1, 3, 3, 1, 1, 0, 0) 11 | normConv.weight:fill(1.0) 12 | normConv.bias:fill(0) 13 | local normValue = normConv:forward(ref:abs()) 14 | normValue:div(9*channel) 15 | 16 | local target = input:clone() 17 | 18 | local x, y = 1, 1 19 | while true do 20 | -- processing patch by patch 21 | if ( y > height - 2) then 22 | break 23 | end 24 | 25 | local conv = nn.SpatialConvolution(channel, 1, 3, 3, 1, 1, 0, 0) 26 | conv.weight[1]:copy( input[{{}, {y,y+2}, {x, x+2}}] ) 27 | conv.bias:fill(0) 28 | 29 | local scores = conv:forward(ref) 30 | scores:cdiv(normValue) 31 | 32 | -- find best match patch from reference images 33 | local _, pos = scores:view(-1):max(1) 34 | pos = pos[1] - 1 35 | local bestX = pos % ( width_ref - 2) + 1 36 | local bestY = math.floor( pos / ( width_ref - 2) ) + 1 37 | 38 | target[{{}, {y,y+2}, {x, x+2}}]:copy( ref[{{},{bestY, bestY+2},{bestX, bestX+2}}]) 39 | 40 | x = x + 3 41 | if ( x > width - 2) then 42 | x = 1 43 | y = y + 3 44 | end 45 | collectgarbage() 46 | end 47 | 48 | return target 49 | end 50 | 51 | 52 | local MRFLoss, parent = torch.class('nn.MRFLoss', 'nn.Module') 53 | 54 | function MRFLoss:__init(strength, input, ref, normalize) 55 | parent.__init(self) 56 | self.normalize = normalize or false 57 | self.strength = strength 58 | self.loss = 0 59 | self.crit = nn.MSECriterion() 60 | 61 | self.target = buildMRFTarget(input, ref) 62 | end 63 | 64 | 65 | function MRFLoss:updateOutput(input) 66 | if ( self.target:nElement() == input:nElement() ) then 67 | self.loss = self.crit:forward(input, self.target) * self.strength 68 | end 69 | 70 | self.output = input 71 | return self.output 72 | end 73 | 74 | function MRFLoss:updateGradInput(input, gradOutput) 75 | self.gradInput:resizeAs(input):zero() 76 | 77 | if input:nElement() == self.target:nElement() then 78 | self.gradInput = self.crit:backward(input, self.target) 79 | end 80 | if self.normalize then 81 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 82 | end 83 | self.gradInput:mul(self.strength) 84 | self.gradInput:add(gradOutput) 85 | return self.gradInput 86 | end 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /lib/randlayer.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | 3 | local RandLayer, parent = torch.class('nn.RandLayer', 'nn.Module') 4 | 5 | function RandLayer:__init() 6 | parent.__init(self) 7 | self.randMap = nil 8 | self.loss = 0 9 | end 10 | 11 | function RandLayer:updateOutput(input) 12 | self.output = input 13 | 14 | if ( self.randMap == nil or self.randMap:isSameSizeAs(input)) then 15 | self.randMap = torch.rand(input:size()) * -1 16 | end 17 | 18 | return self.output 19 | end 20 | 21 | function RandLayer:updateGradInput(input, gradOutput) 22 | self.gradInput:resizeAs(input):copy(input) 23 | 24 | self.gradInput:cmul(self.randMap) 25 | self.gradInput:add(gradOutput) 26 | return self.gradInput 27 | end 28 | 29 | -------------------------------------------------------------------------------- /lib/tvloss.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | 3 | local TVLoss, parent = torch.class('nn.TVLoss', 'nn.Module') 4 | 5 | function TVLoss:__init(strength) 6 | parent.__init(self) 7 | self.strength = strength 8 | self.x_diff = torch.Tensor() 9 | self.y_diff = torch.Tensor() 10 | end 11 | 12 | function TVLoss:updateOutput(input) 13 | self.output = input 14 | return self.output 15 | end 16 | 17 | -- TV loss backward pass inspired by kaishengtai/neuralart 18 | function TVLoss:updateGradInput(input, gradOutput) 19 | self.gradInput:resizeAs(input):zero() 20 | local C, H, W = input:size(1), input:size(2), input:size(3) 21 | self.x_diff:resize(3, H - 1, W - 1) 22 | self.y_diff:resize(3, H - 1, W - 1) 23 | self.x_diff:copy(input[{{}, {1, -2}, {1, -2}}]) 24 | self.x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}]) 25 | self.y_diff:copy(input[{{}, {1, -2}, {1, -2}}]) 26 | self.y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}]) 27 | self.gradInput[{{}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff) 28 | self.gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, self.x_diff) 29 | self.gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, self.y_diff) 30 | self.gradInput:mul(self.strength) 31 | self.gradInput:add(gradOutput) 32 | return self.gradInput 33 | end 34 | 35 | 36 | -------------------------------------------------------------------------------- /masked_gram.style: -------------------------------------------------------------------------------- 1 | local net = { 2 | image_list = {'./images/renoir.png', './images/data/renoir_gram_mask.t7', './output.png'}, 3 | input = -1, 4 | output = 3, 5 | 6 | width = 512, 7 | height = 320, 8 | 9 | convergence = true, 10 | maxIterate = 1500, 11 | 12 | cnn = { 13 | proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt', 14 | caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel' 15 | }, 16 | 17 | net = { 18 | { 19 | layer = 'input', 20 | type = 'tvloss', 21 | weight = 0.001, 22 | }, 23 | { 24 | layer = 'relu1_1', 25 | type = 'mask_gram', 26 | weight = 1, 27 | style = 1, 28 | mask = 2 29 | }, 30 | { 31 | layer = 'relu2_1', 32 | type = 'mask_gram', 33 | weight = 1, 34 | style = 1, 35 | mask = 2 36 | }, 37 | { 38 | layer = 'relu3_1', 39 | type = 'mask_gram', 40 | weight = 1, 41 | style = 1, 42 | mask = 2 43 | }, 44 | { 45 | layer = 'relu4_1', 46 | type = 'mask_gram', 47 | weight = 1, 48 | style = 1, 49 | mask = 2 50 | }, 51 | { 52 | layer = 'relu5_1', 53 | type = 'mask_gram', 54 | weight = 1, 55 | style = 1, 56 | mask = 2 57 | } 58 | } 59 | } 60 | 61 | return net 62 | -------------------------------------------------------------------------------- /mrf.style: -------------------------------------------------------------------------------- 1 | local net = { 2 | image_list = {'./images/ford.png', './images/lohan.png', './output.png'}, 3 | input = 1, 4 | output = 3, 5 | 6 | convergence = true, 7 | maxIterate = 1000, 8 | 9 | cnn = { 10 | proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt', 11 | caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel' 12 | }, 13 | 14 | net = { 15 | { 16 | layer = 'input', 17 | type = 'tvloss', 18 | weight = 0.001, 19 | }, 20 | { 21 | layer = 'relu4_1', 22 | type = 'mrf', 23 | weight = 10.0, 24 | target = 2, 25 | }, 26 | { 27 | layer = 'relu4_2', 28 | type = 'content', 29 | weight = 1.0, 30 | target = 1, 31 | } 32 | } 33 | } 34 | 35 | return net 36 | -------------------------------------------------------------------------------- /styled_cnn.lua: -------------------------------------------------------------------------------- 1 | require('nn') 2 | require('nngraph') 3 | require('loadcaffe') 4 | require('xlua') 5 | require('image') 6 | require('optim') 7 | require('cunn') 8 | require('cudnn') 9 | 10 | require('./lib/tvloss') 11 | require('./lib/contentloss') 12 | require('./lib/gramloss') 13 | require('./lib/mrfloss') 14 | require('./lib/masked_gramloss') 15 | require('./lib/amplayer') 16 | require('./lib/randlayer') 17 | 18 | 19 | local cleanupModel = require('./lib/cleanup_model') 20 | local caffeImage = require('./lib/caffe_image') 21 | local g = {} 22 | g.styleImage = './images/picasso.png' 23 | 24 | g.trainImages_Path = './scene/' 25 | g.trainImages_Number = 16657 26 | 27 | 28 | ----------------------------------------------------------------------------------------- 29 | -- helper functions 30 | 31 | string.startsWith = function(self, str) 32 | return self:find('^' .. str) ~= nil 33 | end 34 | 35 | function loadVGG() 36 | local proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt' 37 | local caffeModel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel' 38 | 39 | local fullModel = loadcaffe.load(proto, caffeModel, 'nn') 40 | local cnn = nn.Sequential() 41 | for i = 1, #fullModel do 42 | local name = fullModel:get(i).name 43 | if ( name:startsWith('relu') or name:startsWith('conv') or name:startsWith('pool') ) then 44 | cnn:add( fullModel:get(i) ) 45 | else 46 | break 47 | end 48 | end 49 | 50 | fullModel = nil 51 | collectgarbage() 52 | return cnn 53 | end 54 | 55 | function loadTrainData() 56 | local randSeq = torch.randperm(g.trainImages_Number) 57 | 58 | local trainSplit = math.floor(g.trainImages_Number * 0.85) 59 | 60 | g.trainSet = {} 61 | g.trainSet.data = {} 62 | g.trainSet.index = 1 63 | for i = 1, trainSplit do 64 | g.trainSet.data[i] = g.trainImages_Path .. '/' .. randSeq[i] .. '.png' 65 | end 66 | 67 | g.testSet = {} 68 | g.testSet.data = {} 69 | g.testSet.index = 1 70 | for i = trainSplit + 1, g.trainImages_Number do 71 | g.testSet.data[i] = g.trainImages_Path .. '/' .. randSeq[i] .. '.png' 72 | end 73 | end 74 | 75 | function loadBatch(set, batch_size) 76 | 77 | local batch = {} 78 | batch.x = torch.Tensor(batch_size, 3, 256, 256) 79 | 80 | for i = 1, batch_size do 81 | local sampleIndex = i + set.index 82 | sampleIndex = sampleIndex % #set.data + 1 83 | 84 | local rgb = image.loadPNG( set.data[sampleIndex], 3) 85 | batch.x[i]:copy( caffeImage.img2caffe(rgb) ) 86 | end 87 | 88 | set.index = (set.index + batch_size) % #set.data + 1 89 | 90 | return batch 91 | end 92 | 93 | ----------------------------------------------------------------------------------------- 94 | -- worker functions 95 | function buildLossNet () 96 | local gramLoss = {'relu1_2', 'relu2_2', 'relu3_2', 'relu4_1'} 97 | local contentLoss = {'relu4_2'} 98 | 99 | local styleCaffeImage = caffeImage.img2caffe( image.loadPNG(g.styleImage, 3) ) 100 | 101 | local modifier = {} 102 | local cindex = -1 103 | 104 | local net = nn.Sequential() 105 | net:add(nn.TVLoss(0.001)) 106 | 107 | local gram_index = 1 108 | local content_index = 1 109 | for i = 1, #g.vgg do 110 | if ( gram_index > #gramLoss and content_index > #contentLoss) then 111 | break 112 | end 113 | 114 | local name = g.vgg:get(i).name 115 | net:add(g.vgg:get(i)) 116 | 117 | if ( name == gramLoss[ gram_index ] ) then 118 | local target = net:forward( styleCaffeImage ) 119 | local layer = nn.GramLoss(0.01, target, false) 120 | net:add(layer) 121 | table.insert(modifier, layer) 122 | 123 | gram_index = gram_index + 1 124 | end 125 | 126 | if ( name == contentLoss[content_index] ) then 127 | local layer = nn.ContentLoss(1.0, nil, nil) 128 | net:add(layer) 129 | table.insert(modifier, layer) 130 | 131 | cindex = #modifier 132 | content_index = content_index + 1 133 | end 134 | end 135 | 136 | local lossNet = {} 137 | lossNet.net = net 138 | lossNet.modifier = modifier 139 | lossNet.cindex = cindex 140 | return lossNet 141 | end 142 | 143 | function buildStyledNet() 144 | local model = nn.Sequential() 145 | 146 | model:add(cudnn.SpatialConvolution(3, 32, 3, 3, 1, 1, 1, 1)) 147 | model:add(nn.SpatialBatchNormalization(32)) 148 | model:add(nn.LeakyReLU(0.1)) 149 | model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 1, 1)) 150 | model:add(nn.SpatialBatchNormalization(32)) 151 | model:add(nn.LeakyReLU(0.1)) 152 | model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 1, 1)) 153 | model:add(nn.SpatialBatchNormalization(64)) 154 | model:add(nn.LeakyReLU(0.1)) 155 | model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 1, 1)) 156 | model:add(nn.SpatialBatchNormalization(128)) 157 | model:add(nn.LeakyReLU(0.1)) 158 | model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 1, 1)) 159 | model:add(nn.SpatialBatchNormalization(128)) 160 | model:add(nn.LeakyReLU(0.1)) 161 | model:add(cudnn.SpatialConvolution(128, 3, 3, 3, 1, 1, 1, 1)) 162 | model:add(nn.Tanh()) 163 | model:add(nn.MulConstant(128)) 164 | 165 | return model 166 | end 167 | 168 | function doTrain() 169 | g.lossNet.net:cuda() 170 | g.styledNet:cuda() 171 | g.zeroLoss = g.zeroLoss:cuda() 172 | 173 | g.styledNet:training() 174 | g.lossNet.net:evaluate() 175 | 176 | local batchSize = 4 177 | local oneEpoch = math.floor( #g.trainSet.data / batchSize ) 178 | g.trainSet.index = 1 179 | 180 | local batch = nil 181 | local dyhat = torch.zeros(batchSize, 3, 256, 256):cuda() 182 | local parameters,gradParameters = g.styledNet:getParameters() 183 | 184 | local feval = function(x) 185 | -- get new parameters 186 | if x ~= parameters then 187 | parameters:copy(x) 188 | end 189 | -- reset gradients 190 | gradParameters:zero() 191 | 192 | local loss = 0 193 | local yhat = g.styledNet:forward( batch.x ) 194 | 195 | for i = 1, batchSize do 196 | g.lossNet.net:forward( batch.x[i] ) 197 | local contentTarget = g.lossNet.modifier[g.lossNet.cindex].output 198 | g.lossNet.modifier[g.lossNet.cindex]:setTarget(contentTarget) 199 | 200 | g.lossNet.net:forward(yhat[i]) 201 | local dy = g.lossNet.net:backward(yhat[i], g.zeroLoss) 202 | dyhat[i]:copy(dy) 203 | 204 | for _, mod in ipairs(g.lossNet.modifier) do 205 | loss = loss + mod.loss 206 | end 207 | end 208 | 209 | g.styledNet:backward(batch.x, dyhat) 210 | 211 | return loss/batchSize, gradParameters 212 | end 213 | 214 | local minValue = -1 215 | for j = 1, oneEpoch do 216 | batch = loadBatch(g.trainSet, batchSize) 217 | batch.x = batch.x:cuda() 218 | 219 | local _, err = optim.adam(feval, parameters, g.optimState) 220 | 221 | print(">>>>>>>>> err = " .. err[1]); 222 | 223 | if ( j % 100 == 0) then 224 | torch.save('./model/style_' .. err[1] .. '.t7', g.styledNet) 225 | end 226 | 227 | collectgarbage(); 228 | end 229 | 230 | end 231 | 232 | function doTest() 233 | 234 | 235 | end 236 | 237 | 238 | function doForward() 239 | local net = torch.load( arg[1] ) 240 | local img = image.loadPNG( arg[2] , 3) 241 | 242 | local img = caffeImage.img2caffe(img) 243 | local x = torch.Tensor(1, img:size(1), img:size(2), img:size(3)) 244 | x[1]:copy(img) 245 | x = x:cuda() 246 | 247 | local outImg = net:forward(x) 248 | outImg = outImg:float() 249 | outImg = caffeImage.caffe2img(outImg[1]) 250 | 251 | image.savePNG('./output.png', outImg) 252 | end 253 | 254 | 255 | ----------------------------------------------------------------------------------------- 256 | function main() 257 | torch.setdefaulttensortype('torch.FloatTensor') 258 | torch.manualSeed(1979) 259 | 260 | if ( #arg == 2) then 261 | doForward() 262 | return 263 | end 264 | 265 | 266 | -- build net 267 | g.vgg = loadVGG() 268 | g.lossNet = buildLossNet() 269 | local tempImage = torch.rand(3, 256, 256) 270 | local tempOutput = g.lossNet.net:forward(tempImage) 271 | g.zeroLoss = torch.zeros( tempOutput:size()) 272 | 273 | g.styledNet = buildStyledNet() 274 | g.optimState = { 275 | learningRate = 0.0005, 276 | } 277 | 278 | -- load data 279 | loadTrainData() 280 | 281 | -- trainging() 282 | for i = 1, 4 do 283 | doTrain() 284 | doTest() 285 | end 286 | end 287 | 288 | main() 289 | -------------------------------------------------------------------------------- /tools/buildMask.lua: -------------------------------------------------------------------------------- 1 | -- stupid tools only support: red, green, blue and white 4 colors masks 2 | -- 3 | 4 | require('image') 5 | 6 | function main() 7 | if ( #arg ~= 3) then 8 | print("Please input [style_mask_png_file], [target_mask_png_file], [output_mask_file] ") 9 | print("Only support red, green, blue and white 4 colors masks") 10 | return 11 | end 12 | 13 | local style_mask = image.load(arg[1], 3) 14 | local target_mask = image.load(arg[2], 3) 15 | 16 | if ( style_mask:size()[2] ~= target_mask:size()[2] 17 | or style_mask:size()[3] ~= target_mask:size()[3] ) then 18 | print("Error: style_mask and target_mask must be same size") 19 | return 20 | end 21 | 22 | local width = style_mask:size()[3] 23 | local height = style_mask:size()[2] 24 | 25 | -- only support 4 channels 26 | local masks = {} 27 | masks.style = {} 28 | masks.target = {} 29 | 30 | -- white only 31 | local whiteMask = torch.zeros(height, width):byte() 32 | for i = 1, 3 do 33 | whiteMask:add(style_mask[i]:le(0.5) * (-1) + 1) 34 | end 35 | whiteMask = whiteMask:le(2.5) * (-1) + 1 36 | masks.style[1] = whiteMask:clone() 37 | whiteMask:zero() 38 | for i = 1, 3 do 39 | whiteMask:add(target_mask[i]:le(0.5) * (-1) + 1) 40 | end 41 | whiteMask = whiteMask:le(2.5) * (-1) + 1 42 | masks.target[1] = whiteMask:clone() 43 | 44 | -- Red, Green, Blue 45 | for i = 1, 3 do 46 | local j = (i + 1) % 3 + 1 47 | 48 | local colorMask = (target_mask[i]:le(0.5) * (-1) + 1) 49 | colorMask:cmul( target_mask[j]:le(0.5) ) 50 | masks.target[i+1] = colorMask:clone() 51 | 52 | colorMask = (style_mask[i]:le(0.5) * (-1) + 1) 53 | colorMask:cmul( style_mask[j]:le(0.5) ) 54 | masks.style[i+1] = colorMask:clone() 55 | end 56 | 57 | torch.save(arg[3], masks, 'ascii') 58 | end 59 | 60 | main() 61 | --------------------------------------------------------------------------------