├── LICENSE
├── README.md
├── cnn
└── vgg19
│ ├── .gitignore
│ └── download_models.sh
├── dreamer.style
├── ezstyle.lua
├── gram.style
├── images
├── .DS_Store
├── _results
│ ├── dreamer.png
│ ├── gram.png
│ ├── masked_gram.png
│ └── mrf.png
├── data
│ └── renoir_gram_mask.t7
├── ford.png
├── lohan.png
├── picasso.png
├── renoir.png
├── renoir_style_mask.png
├── renoir_target_mask.png
├── trump.png
└── winter.png
├── lib
├── amplayer.lua
├── caffe_image.lua
├── cleanup_model.lua
├── contentloss.lua
├── gramloss.lua
├── masked_gramloss.lua
├── mrfloss.lua
├── randlayer.lua
└── tvloss.lua
├── masked_gram.style
├── mrf.style
├── styled_cnn.lua
└── tools
└── buildMask.lua
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Zhou Chang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # easyStyle
2 | All kinds of neural style transformer.
3 |
4 | This project collects many kinds of neural style transformers, including
5 |
6 | * dreamer mode
7 |
8 | * gram matrix style
9 |
10 | * MRF style
11 |
12 | * guided style/patch transformer
13 |
14 |
15 | This project focuses on a clean and simple implementation of all of these algorithms.
16 |
17 |
18 | ---
19 |
20 | ## 1. install and setup
21 |
22 | Install following package of Torch7.
23 |
24 | ```
25 | cunn
26 | loadcaffe
27 | cudnn
28 | ```
29 |
30 | This project needs a GPU with 4G memory at least. Firstly you should download VGG19 caffe model.
31 |
32 | ```
33 | cd cnn/vgg19
34 | source ./download_models.sh
35 | ```
36 |
37 | ## 2. Quick demo
38 |
39 | The .style files describe the architecture of the network used in the style transformer; they are written in the Lua language.
40 | The .style files are very simple; all the parameters are configured in these files.
41 |
42 | ### 2.1 dreamer mode
43 |
44 | ```
45 | th ezstyle.lua ./dreamer.style
46 | ```
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | ### 2.2 gram matrix mode
55 |
56 | ```
57 | th ezstyle.lua ./gram.style
58 | ```
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | ### 2.3 MRF mode
67 |
68 | ```
69 | th ezstyle.lua ./mrf.style
70 | ```
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 | ### 2.4 guided mode 1 ( masked gram style transform )
79 |
80 | ```
81 | th ezstyle.lua ./masked_gram.style
82 | ```
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 | ### 2.5 guided mode 2 ( masked mrf )
91 |
92 | [WIP]
93 |
94 |
95 | ### 3. Resources
96 |
97 | All of the code comes from the following projects; I have made it simpler and more straightforward :).
98 |
99 | https://github.com/chuanli11/CNNMRF
100 |
101 | https://github.com/alexjc/neural-doodle
102 |
103 | https://github.com/awentzonline/image-analogies
104 |
105 | https://github.com/jcjohnson/neural-style
106 |
107 | https://github.com/DmitryUlyanov/fast-neural-doodle
108 |
109 |
110 |
--------------------------------------------------------------------------------
/cnn/vgg19/.gitignore:
--------------------------------------------------------------------------------
1 | VGG_ILSVRC_19_layers.caffemodel
2 | VGG_ILSVRC_19_layers_deploy.prototxt
3 | VGG_ILSVRC_19_layers_deploy.prototxt.lua
4 | VGG_ILSVRC_19_layers_deploy.prototxt.cpu.lua
5 | VGG_ILSVRC_19_layers_deploy.prototxt.opencl.lua
6 | vgg_normalised.caffemodel
7 |
--------------------------------------------------------------------------------
/cnn/vgg19/download_models.sh:
--------------------------------------------------------------------------------
1 | # Download the VGG19 model files into the current directory.
2 | # BUGFIX: removed "cd models" / "cd ..": the README runs this script from
3 | # cnn/vgg19, where no "models" subdirectory exists, and the .gitignore in
4 | # this directory lists the downloaded files at this level.
5 | wget -c https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/bb2b4fe0a9bb0669211cf3d0bc949dfdda173e9e/VGG_ILSVRC_19_layers_deploy.prototxt
6 | wget -c --no-check-certificate https://bethgelab.org/media/uploads/deeptextures/vgg_normalised.caffemodel
7 | wget -c http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel
8 |
--------------------------------------------------------------------------------
/dreamer.style:
--------------------------------------------------------------------------------
1 | local net = {
2 | image_list = {'./images/winter.png', './output.png'},  -- [1] input image, [2] output path
3 | input = 1,  -- index into image_list
4 | output = 2,
5 |
6 | convergence = false,  -- false: plain gradient steps (doRevert) instead of L-BFGS
7 | maxIterate = 10,
8 | step = 0.0001,  -- gradient step size used by doRevert
9 |
10 | cnn = {
11 | proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt',
12 | caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel'
13 | },
14 |
15 | net = {  -- amplification layers, listed in CNN layer order
16 | {
17 | layer = 'relu2_1',
18 | type = 'amp',
19 | ratio = 1.0,
20 | },
21 | {
22 | layer = 'relu4_4',
23 | type = 'amp',
24 | ratio = 2.0,
25 | }
26 | }
27 | }
28 |
29 | return net
30 |
--------------------------------------------------------------------------------
/ezstyle.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 | require('nngraph')
3 | require('loadcaffe')
4 | require('xlua')
5 | require('image')
6 | require('optim')
7 | require('cunn')
8 |
9 | require('./lib/tvloss')
10 | require('./lib/contentloss')
11 | require('./lib/gramloss')
12 | require('./lib/mrfloss')
13 | require('./lib/masked_gramloss')
14 | require('./lib/amplayer')
15 | require('./lib/randlayer')
16 |
17 | local caffeImage = require('./lib/caffe_image')
18 |
19 | ----------------------------------------------------------------------------------------
20 | local g = {}
21 |
22 | local doRevert = function()  -- plain gradient-descent loop, used when conf.convergence is false (e.g. dreamer mode)
23 | local currentImage = g.x
24 | for i = 1, g.conf.maxIterate do
25 | local inout = g.net:forward(currentImage)
26 | inout = g.net:backward(currentImage, g.dy)  -- g.dy is all zeros; the loss/amp layers inject their gradients via :add(gradOutput)
27 |
28 | currentImage:add(inout * (-1 * g.conf.step) );  -- step against the gradient, scaled by conf.step
29 | currentImage:clamp(-128,128)  -- keep pixels in the mean-subtracted caffe value range
30 |
31 | collectgarbage()
32 | xlua.progress(i, g.conf.maxIterate)
33 | end
34 | end
35 |
36 | local doConvergence = function()  -- L-BFGS optimization of g.x, used when conf.convergence is true
37 | local optim_state = {
38 | maxIter = g.conf.maxIterate,
39 | verbose = true,
40 | }
41 |
42 | local num_calls = 0
43 | local function feval(x)  -- returns (total loss, flattened gradient) as required by optim.lbfgs
44 |
45 | num_calls = num_calls + 1
46 |
47 | g.net:forward(x)
48 | local grad = g.net:updateGradInput(x, g.dy)  -- g.dy is all zeros; loss layers add their own gradients
49 | local loss = 0
50 |
51 | for _, mod in ipairs(g.modifier) do  -- sum the losses the loss layers recorded during forward
52 | loss = loss + mod.loss
53 | end
54 |
55 | print(">>>>>>>>>" .. loss)
56 | --xlua.progress(num_calls, optim_state.maxIter)
57 |
58 | collectgarbage()
59 | -- optim.lbfgs expects a vector for gradients
60 | return loss, grad:view(grad:nElement())
61 | end
62 |
63 | local x, losses = optim.lbfgs(feval, g.x, optim_state)  -- optimizes g.x in place
64 | end
65 |
66 | local main = function()  -- entry point: load the .style config from arg[1], build the net, optimize, save the result
67 | torch.setdefaulttensortype('torch.FloatTensor')
68 | torch.manualSeed(1979)  -- fixed seed for reproducible random inputs/layers
69 | if ( #arg < 1) then
70 | print("Please input config file!")
71 | os.exit(0)
72 | end
73 |
74 | -- init
75 | g.conf = dofile(arg[1])  -- .style files are Lua scripts returning a config table
76 | g.cnn = loadCNN(g.conf.cnn)
77 | g.net, g.modifier = buildNetwork(g.conf, g.cnn)
78 | g.x = loadInput(g.conf)
79 | g.dy = torch.zeros( g.net:forward(g.x):size() )  -- zero top gradient; the loss layers add theirs during backward
80 |
81 | -- cuda
82 | g.net:cuda()
83 | g.x = g.x:cuda()
84 | g.dy = g.dy:cuda()
85 |
86 | print(g.net)
87 |
88 | collectgarbage()
89 | if (g.conf.convergence) then
90 | doConvergence()
91 | else
92 | doRevert()
93 | end
94 |
95 | local img = caffeImage.caffe2img(g.x:float())
96 | image.savePNG(g.conf.image_list[g.conf.output], img)  -- conf.output indexes into conf.image_list
97 | end
98 |
99 | -----------------------------------------------------------------------------------------
100 | -- helper functions
101 |
102 | string.startsWith = function(self, str)  -- true if this string begins with str (note: str is interpreted as a Lua pattern)
103 | return self:find('^' .. str) ~= nil
104 | end
105 |
106 | function loadInput(conf)  -- load the input image as a caffe-format tensor, or start from random noise
107 | local img = nil
108 | if ( conf.image_list[conf.input] == nil ) then
109 | img = torch.rand(3, conf.height, conf.width)  -- no input image configured; conf must provide width/height here
110 | else
111 | img = image.load(conf.image_list[conf.input], 3)
112 | end
113 |
114 | img = caffeImage.img2caffe(img)  -- RGB->BGR, rescale, mean-subtract
115 | return img
116 | end
117 |
118 | function loadCNN(cnnFiles)  -- load the caffe VGG model, keeping only the leading conv/relu/pool layers
119 | local fullModel = loadcaffe.load(cnnFiles.proto, cnnFiles.caffemodel, 'nn')
120 | local cnn = nn.Sequential();
121 | for i = 1, #fullModel do
122 | local name = fullModel:get(i).name
123 | if ( name:startsWith('relu') or name:startsWith('conv') or name:startsWith('pool') ) then
124 | cnn:add( fullModel:get(i) )
125 | else
126 | break  -- stop at the first other layer (drops the fully-connected classifier head)
127 | end
128 | end
129 | fullModel = nil
130 | collectgarbage()
131 |
132 | return cnn
133 | end
134 |
135 | function buildNetwork(conf, cnn)  -- interleave the CNN layers with the loss/modifier layers listed in conf.net
136 | local net = nn.Sequential()
137 | local modifier = {}  -- loss layers in network order; doConvergence sums their .loss fields
138 |
139 | local nindex = 1
140 | if ( conf.net[1].layer == 'input') then  -- a layer bound to 'input' (e.g. tvloss) goes before the CNN
141 | local layer = buildLayer(net, conf, 1, cnn)
142 | net:add(layer)
143 | nindex = 2
144 | end
145 |
146 | for i = 1, #cnn do
147 | local name = cnn:get(i).name
148 | net:add(cnn:get(i))
149 |
150 | if ( name == conf.net[nindex].layer ) then  -- conf.net entries must appear in CNN layer order
151 | local layer = buildLayer(net, conf, nindex, cnn)
152 | net:add(layer)
153 |
154 | table.insert(modifier, layer)
155 |
156 | nindex = nindex + 1
157 | if ( nindex > #conf.net ) then
158 | break  -- CNN layers beyond the last configured loss layer are not needed
159 | end
160 | end
161 | collectgarbage()
162 | end
163 |
164 | return net, modifier
165 | end
166 |
167 | function buildLayer(net, conf, nindex, cnn)  -- construct the loss/modifier layer described by conf.net[nindex]
168 | local layer = nil
169 |
170 | if ( conf.net[nindex].type == "tvloss" ) then
171 | layer = nn.TVLoss(conf.net[nindex].weight)
172 | elseif ( conf.net[nindex].type == "amp") then
173 | layer = nn.AmpLayer(conf.net[nindex].ratio)  -- dreamer mode: amplifies activations via the backward pass
174 | elseif ( conf.net[nindex].type == "rand") then
175 | layer = nn.RandLayer()
176 | elseif ( conf.net[nindex].type == "content") then
177 | local targetImage = conf.image_list[ conf.net[nindex].target]
178 | targetImage = image.load(targetImage,3)
179 | local targetCaffe = caffeImage.img2caffe(targetImage)
180 | local target = net:forward(targetCaffe)  -- target features at the layers built so far
181 |
182 | layer = nn.ContentLoss(conf.net[nindex].weight, target)
183 | elseif ( conf.net[nindex].type == "gram") then
184 | local targetImage = conf.image_list[ conf.net[nindex].target]
185 | targetImage = image.load(targetImage,3)
186 | local targetCaffe = caffeImage.img2caffe(targetImage)
187 | local target = net:forward(targetCaffe)
188 |
189 | layer = nn.GramLoss(conf.net[nindex].weight, target)
190 | elseif ( conf.net[nindex].type == "mrf") then
191 | local targetImage = conf.image_list[ conf.net[nindex].target]
192 | targetImage = image.load(targetImage,3)
193 | local targetCaffe = caffeImage.img2caffe(targetImage)
194 | local target = net:forward(targetCaffe):clone()  -- cloned: the net:forward below would overwrite it
195 |
196 | local inputImage = conf.image_list[ conf.input]
197 | inputImage = image.load(inputImage, 3)
198 | local inputCaffe = caffeImage.img2caffe(inputImage)
199 | local input = net:forward(inputCaffe)
200 |
201 | layer = nn.MRFLoss(conf.net[nindex].weight, input, target)
202 | elseif ( conf.net[nindex].type == 'mask_gram') then
203 | local styleImage = conf.image_list[ conf.net[nindex].style]
204 | styleImage = image.load(styleImage,3)
205 | local styleCaffe = caffeImage.img2caffe(styleImage)
206 | local style = net:forward(styleCaffe):clone()
207 |
208 | local masks = torch.load ( conf.image_list[ conf.net[nindex].mask], 'ascii')  -- table with masks.style / masks.target lists (see tools/buildMask.lua)
209 |
210 | layer = nn.MaskedGramLoss(conf.net[nindex].weight, style, masks)
211 | end
212 |
213 | return layer
214 | end
215 |
216 | -----------------------------------------------------------------------------------------
217 | main()
218 |
219 |
--------------------------------------------------------------------------------
/gram.style:
--------------------------------------------------------------------------------
1 | local net = {
2 | image_list = {'./images/trump.png', './images/picasso.png', './output.png'},  -- [1] content, [2] style, [3] output
3 | input = 1,
4 | output = 3,
5 |
6 | convergence = true,  -- optimize with L-BFGS (doConvergence)
7 | maxIterate = 500,
8 |
9 | cnn = {
10 | proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt',
11 | caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel'
12 | },
13 |
14 | net = {  -- loss layers, listed in CNN layer order
15 | {
16 | layer = 'input',
17 | type = 'tvloss',
18 | weight = 0.001,
19 | },
20 | {
21 | layer = 'relu1_1',
22 | type = 'gram',
23 | weight = 1,
24 | target = 2,
25 | },
26 | {
27 | layer = 'relu2_1',
28 | type = 'gram',
29 | weight = 1,
30 | target = 2,
31 | },
32 | {
33 | layer = 'relu3_1',
34 | type = 'gram',
35 | weight = 1,
36 | target = 2,
37 | },
38 | {
39 | layer = 'relu4_1',
40 | type = 'gram',
41 | weight = 1,
42 | target = 2,
43 | },
44 | {
45 | layer = 'relu4_2',
46 | type = 'content',
47 | weight = 1000.0,  -- content weight dominates the unit style weights
48 | target = 1,
49 | },
50 | {
51 | layer = 'relu5_1',
52 | type = 'gram',
53 | weight = 1,
54 | target = 2,
55 | }
56 | }
57 | }
58 |
59 | return net
60 |
--------------------------------------------------------------------------------
/images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/.DS_Store
--------------------------------------------------------------------------------
/images/_results/dreamer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/dreamer.png
--------------------------------------------------------------------------------
/images/_results/gram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/gram.png
--------------------------------------------------------------------------------
/images/_results/masked_gram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/masked_gram.png
--------------------------------------------------------------------------------
/images/_results/mrf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/_results/mrf.png
--------------------------------------------------------------------------------
/images/ford.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/ford.png
--------------------------------------------------------------------------------
/images/lohan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/lohan.png
--------------------------------------------------------------------------------
/images/picasso.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/picasso.png
--------------------------------------------------------------------------------
/images/renoir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/renoir.png
--------------------------------------------------------------------------------
/images/renoir_style_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/renoir_style_mask.png
--------------------------------------------------------------------------------
/images/renoir_target_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/renoir_target_mask.png
--------------------------------------------------------------------------------
/images/trump.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/trump.png
--------------------------------------------------------------------------------
/images/winter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Teaonly/easyStyle/24b7dd4edf0a55d58b7c298ae7ad9c9707e4782b/images/winter.png
--------------------------------------------------------------------------------
/lib/amplayer.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 |
3 | local AmpLayer, parent = torch.class('nn.AmpLayer', 'nn.Module')  -- identity on forward; backward amplifies this layer's own activations (deep-dream style)
4 |
5 | function AmpLayer:__init(ratio)
6 | parent.__init(self)
7 | self.ratio = ratio  -- amplification strength
8 | self.loss = 0  -- no measurable loss; field kept for interface consistency with the loss layers
9 | end
10 |
11 | function AmpLayer:updateOutput(input)
12 | self.output = input
13 | return self.output
14 | end
15 |
16 | function AmpLayer:updateGradInput(input, gradOutput)
17 | self.gradInput:resizeAs(input):copy(input)
18 |
19 | self.gradInput:mul(-1*self.ratio)  -- negated activations: stepping against this gradient boosts them
20 | self.gradInput:add(gradOutput)
21 | return self.gradInput
22 | end
23 |
24 |
--------------------------------------------------------------------------------
/lib/caffe_image.lua:
--------------------------------------------------------------------------------
1 | require('image')
2 |
3 | local caffeImage = {}  -- conversions between torch images (RGB, [0,1]) and caffe VGG inputs (BGR, mean-subtracted)
4 |
5 | caffeImage.img2caffe = function(img)
6 | local mean_pixel = torch.Tensor({103.939, 116.779, 123.68})  -- BGR channel means used by the VGG caffe models
7 | local perm = torch.LongTensor{3, 2, 1}
8 | img = img:index(1, perm):mul(256.0)  -- RGB -> BGR, scale up to the caffe value range
9 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
10 | img:add(-1, mean_pixel)
11 | return img
12 | end
13 |
14 | caffeImage.caffe2img = function(img)  -- inverse of img2caffe
15 | local mean_pixel = torch.Tensor({103.939, 116.779, 123.68})
16 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
17 | img = img + mean_pixel
18 | local perm = torch.LongTensor{3, 2, 1}
19 | img = img:index(1, perm):div(256.0)  -- BGR -> RGB, back to roughly [0,1]
20 | return img
21 | end
22 |
23 | return caffeImage
24 |
24 |
--------------------------------------------------------------------------------
/lib/cleanup_model.lua:
--------------------------------------------------------------------------------
1 | -- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
2 |
3 | local function zeroDataSize(data)  -- recursively replace tensors with empty tensors of the same type
4 | if type(data) == 'table' then
5 | for i = 1, #data do
6 | data[i] = zeroDataSize(data[i])
7 | end
8 | elseif type(data) == 'userdata' then  -- torch tensors are userdata
9 | data = torch.Tensor():typeAs(data)
10 | end
11 | return data
12 | end
13 | -- Resize the output, gradInput, etc temporary tensors to zero (so that the
14 | -- on disk size is smaller)
15 | local function cleanupModel(node)
16 | if node.output ~= nil then
17 | node.output = zeroDataSize(node.output)
18 | end
19 | if node.gradInput ~= nil then
20 | node.gradInput = zeroDataSize(node.gradInput)
21 | end
22 | if node.finput ~= nil then
23 | node.finput = zeroDataSize(node.finput)
24 | end
25 | if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then  -- these module types cache extra buffers
26 | if node.negative ~= nil then
27 | node.negative = zeroDataSize(node.negative)
28 | end
29 | end
30 | if tostring(node) == "nn.Dropout" then
31 | if node.noise ~= nil then
32 | node.noise = zeroDataSize(node.noise)
33 | end
34 | end
35 | -- Recurse on nodes with 'modules'
36 | if (node.modules ~= nil) then
37 | if (type(node.modules) == 'table') then
38 | for i = 1, #node.modules do
39 | local child = node.modules[i]
40 | cleanupModel(child)
41 | end
42 | end
43 | end
44 | end
45 |
46 | return cleanupModel
47 |
47 |
--------------------------------------------------------------------------------
/lib/contentloss.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 |
3 | -- Define an nn Module to compute content loss in-place:
4 | -- identity on forward, records MSE(input, target) * strength in self.loss,
5 | -- and adds the corresponding gradient to gradOutput during backward.
6 | local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module')
7 |
8 | -- strength: weight of this loss; target: feature map to match (may be nil and
9 | -- set later via setTarget); normalize: L1-normalize the gradient before scaling.
10 | function ContentLoss:__init(strength, target, normalize)
11 | parent.__init(self)
12 | self.strength = strength
13 | if ( target ~= nil) then
14 | self.target = target:clone()
15 | end
16 | self.normalize = normalize or false
17 | self.loss = 0
18 | self.crit = nn.MSECriterion()
19 | end
20 |
21 | function ContentLoss:setTarget(target)
22 | self.target = target:clone()
23 | end
24 |
25 | function ContentLoss:updateOutput(input)
26 | if self.target and input:nElement() == self.target:nElement() then
27 | self.loss = self.crit:forward(input, self.target) * self.strength
28 | end
29 |
30 | self.output = input
31 | return self.output
32 | end
33 |
34 | function ContentLoss:updateGradInput(input, gradOutput)
35 | -- BUGFIX: start from a zeroed gradient sized like the input (as mrfloss.lua
36 | -- does) and guard against a nil target; the original indexed self.target
37 | -- unconditionally and left self.gradInput unsized on a size mismatch, which
38 | -- crashed on the :add(gradOutput) below.
39 | self.gradInput:resizeAs(input):zero()
40 | if self.target and input:nElement() == self.target:nElement() then
41 | self.gradInput = self.crit:backward(input, self.target)
42 | end
43 | if self.normalize then
44 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
45 | end
46 | self.gradInput:mul(self.strength)
47 |
48 | self.gradInput:add(gradOutput)
49 | return self.gradInput
50 | end
51 |
--------------------------------------------------------------------------------
/lib/gramloss.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 |
3 | -- Returns a network that computes the CxC Gram matrix from inputs
4 | -- of size C x H x W
5 | function GramMatrix()
6 | local net = nn.Sequential()
7 | net:add(nn.View(-1):setNumInputDims(2))  -- flatten each channel: C x H x W -> C x (H*W)
8 | local concat = nn.ConcatTable()
9 | concat:add(nn.Identity())
10 | concat:add(nn.Identity())
11 | net:add(concat)
12 | net:add(nn.MM(false, true))  -- F * F^T -> C x C
13 | return net
14 | end
15 |
16 |
17 | -- Define an nn Module to compute style loss in-place
18 | local GramLoss, parent = torch.class('nn.GramLoss', 'nn.Module')
19 |
20 | function GramLoss:__init(strength, target, normalize)  -- target: style feature map (C x H x W) at this layer
21 | parent.__init(self)
22 | self.normalize = normalize or false
23 | self.strength = strength
24 | self.loss = 0
25 |
26 | self.gram = GramMatrix()
27 | self.G = nil  -- Gram matrix of the most recent forward input
28 | self.crit = nn.MSECriterion()
29 |
30 | local tsize = target:size()
31 | local img_size = tsize[2] * tsize[3]
32 |
33 | self.target = self.gram:forward(target):clone()
34 | self.target:div( img_size )  -- normalize by spatial size so differently-sized layers are comparable
35 | end
36 |
37 | function GramLoss:updateOutput(input)
38 | local tsize = input:size()
39 | local img_size = tsize[2] * tsize[3]
40 |
41 | self.G = self.gram:forward(input)
42 | self.G:div(img_size)
43 |
44 | self.loss = self.crit:forward(self.G, self.target)
45 | self.loss = self.loss * self.strength
46 | self.output = input  -- identity pass-through
47 | return self.output
48 | end
49 |
50 | function GramLoss:updateGradInput(input, gradOutput)
51 | local dG = self.crit:backward(self.G, self.target)
52 |
53 | local tsize = input:size()
54 | local img_size = tsize[2] * tsize[3]
55 | dG:div(img_size)  -- chain rule for the :div(img_size) applied in forward
56 |
57 | self.gradInput = self.gram:backward(input, dG)
58 | if self.normalize then
59 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
60 | end
61 | self.gradInput:mul(self.strength)
62 | self.gradInput:add(gradOutput)
63 | return self.gradInput
64 | end
65 |
66 |
--------------------------------------------------------------------------------
/lib/masked_gramloss.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 | require('image')
3 |
4 | -- Returns a network that computes the CxC Gram matrix from inputs
5 | -- of size C x H x W
6 | function GramMatrix()
7 | local net = nn.Sequential()
8 | net:add(nn.View(-1):setNumInputDims(2))
9 | local concat = nn.ConcatTable()
10 | concat:add(nn.Identity())
11 | concat:add(nn.Identity())
12 | net:add(concat)
13 | net:add(nn.MM(false, true))
14 | return net
15 | end
16 |
17 |
18 | -- Define an nn Module to compute style loss in-place, restricted to mask
19 | -- regions: for each mask i the Gram matrix of the target-masked input is
20 | -- matched against the Gram matrix of the style-masked style features.
21 | local MaskedGramLoss, parent = torch.class('nn.MaskedGramLoss', 'nn.Module')
22 |
23 | -- strength: loss weight; style: style feature map (C x H x W) at this layer;
24 | -- masks: table with parallel lists masks.style[i] / masks.target[i].
25 | function MaskedGramLoss:__init(strength, style, masks)
26 | parent.__init(self)
27 | self.strength = strength
28 | self.loss = 0
29 | self.crit = nn.MSECriterion()
30 |
31 | local channel = style:size()[1]
32 | local hei = style:size()[2]
33 | local wid = style:size()[3]
34 |
35 | local gram = GramMatrix()
36 |
37 | local allGramTarget = {}
38 |
39 | local maskedLoss = nn.ConcatTable()
40 | for i = 1, #masks.style do
41 | local style_mask = image.scale(masks.style[i], wid, hei):float()
42 |
43 | style_mask = style_mask:view(1, hei, wid):expandAs(style)
44 | style_mask = torch.cmul(style_mask, style)
45 |
46 | allGramTarget[i] = gram:forward(style_mask):clone()
47 | allGramTarget[i]:div(wid*hei)  -- normalized by spatial size, like gramloss.lua
48 |
49 | local target_mask = image.scale(masks.target[i], wid, hei):float()
50 | target_mask = target_mask:view(1, hei, wid):expandAs(style)
51 |
52 | -- branch i: multiply the input by target mask i, then take its Gram matrix
53 | local mask_net = nn.Sequential()
54 | local cmul = nn.CMul(style:size())
55 | cmul.weight:copy( target_mask)
56 |
57 | mask_net:add(cmul)
58 | mask_net:add(GramMatrix())
59 | maskedLoss:add(mask_net)
60 | end
61 |
62 | self.allGramTarget = allGramTarget
63 | self.maskedLoss = maskedLoss
64 |
65 | end
66 |
67 | function MaskedGramLoss:updateOutput(input)
68 | local tsize = input:size()
69 | local img_size = tsize[2] * tsize[3]
70 |
71 | self.loss = 0
72 | local maskedGram = self.maskedLoss:forward(input)
73 | for i = 1, #maskedGram do
74 | maskedGram[i]:div(img_size)
75 | self.loss = self.loss + self.crit:forward(maskedGram[i], self.allGramTarget[i])
76 | end
77 | self.maskedGram = maskedGram
78 |
79 | self.loss = self.loss * self.strength
80 | self.output = input  -- identity pass-through
81 | return self.output
82 | end
83 |
84 | function MaskedGramLoss:updateGradInput(input, gradOutput)
85 | local tsize = input:size()
86 | local img_size = tsize[2] * tsize[3]
87 |
88 | local dG = {}
89 | for i = 1, #self.allGramTarget do
90 | dG[i] = self.crit:backward(self.maskedGram[i], self.allGramTarget[i]):clone()
91 | -- BUGFIX: was dG[i]:mul(img_size); the forward pass divides each Gram
92 | -- matrix by img_size, so the chain rule requires dividing here as well
93 | -- (this matches GramLoss:updateGradInput in gramloss.lua).
94 | dG[i]:div(img_size)
95 | end
96 |
97 | self.gradInput = self.maskedLoss:backward(input, dG)
98 | self.gradInput:mul(self.strength)
99 | self.gradInput:add(gradOutput)
100 | return self.gradInput
101 | end
102 |
103 |
--------------------------------------------------------------------------------
/lib/mrfloss.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 |
3 | -- Build the MRF target: for each non-overlapping 3x3 patch of `input`, copy in
4 | -- the best-matching (normalized cross-correlation) 3x3 patch from `ref`.
5 | local buildMRFTarget = function(input, ref)
6 | -- this is a simple version , no scale, no rotation
7 | local channel = input:size()[1]
8 | local height = input:size()[2]
9 | local width = input:size()[3]
10 | local width_ref = ref:size()[3]
11 |
12 | -- per-position normalizer: mean |ref| over each 3x3 patch
13 | local normConv = nn.SpatialConvolutionMM(channel, 1, 3, 3, 1, 1, 0, 0)
14 | normConv.weight:fill(1.0)
15 | normConv.bias:fill(0)
16 | -- BUGFIX: ref:abs() mutates ref in place, so the correlation and the patch
17 | -- copies below would see |ref| instead of ref; torch.abs(ref) leaves ref
18 | -- untouched. (With ReLU features the values are already non-negative, so
19 | -- this keeps behavior identical for the shipped configs.)
20 | local normValue = normConv:forward(torch.abs(ref))
21 | normValue:div(9*channel)
22 |
23 | local target = input:clone()
24 |
25 | local x, y = 1, 1
26 | while true do
27 | -- processing patch by patch
28 | if ( y > height - 2) then
29 | break
30 | end
31 |
32 | -- correlate the current input patch against every position of ref
33 | local conv = nn.SpatialConvolution(channel, 1, 3, 3, 1, 1, 0, 0)
34 | conv.weight[1]:copy( input[{{}, {y,y+2}, {x, x+2}}] )
35 | conv.bias:fill(0)
36 |
37 | local scores = conv:forward(ref)
38 | scores:cdiv(normValue)
39 |
40 | -- find best match patch from reference images
41 | local _, pos = scores:view(-1):max(1)
42 | pos = pos[1] - 1
43 | local bestX = pos % ( width_ref - 2) + 1
44 | local bestY = math.floor( pos / ( width_ref - 2) ) + 1
45 |
46 | target[{{}, {y,y+2}, {x, x+2}}]:copy( ref[{{},{bestY, bestY+2},{bestX, bestX+2}}])
47 |
48 | -- advance in stride-3 scanline order
49 | x = x + 3
50 | if ( x > width - 2) then
51 | x = 1
52 | y = y + 3
53 | end
54 | collectgarbage()
55 | end
56 |
57 | return target
58 | end
59 |
60 |
61 | -- MRF loss: identity on forward; pulls the input features toward the
62 | -- precomputed patch-matched target via MSE.
63 | local MRFLoss, parent = torch.class('nn.MRFLoss', 'nn.Module')
64 |
65 | function MRFLoss:__init(strength, input, ref, normalize)
66 | parent.__init(self)
67 | self.normalize = normalize or false
68 | self.strength = strength
69 | self.loss = 0
70 | self.crit = nn.MSECriterion()
71 |
72 | self.target = buildMRFTarget(input, ref)
73 | end
74 |
75 |
76 | function MRFLoss:updateOutput(input)
77 | if ( self.target:nElement() == input:nElement() ) then
78 | self.loss = self.crit:forward(input, self.target) * self.strength
79 | end
80 |
81 | self.output = input
82 | return self.output
83 | end
84 |
85 | function MRFLoss:updateGradInput(input, gradOutput)
86 | self.gradInput:resizeAs(input):zero()
87 |
88 | if input:nElement() == self.target:nElement() then
89 | self.gradInput = self.crit:backward(input, self.target)
90 | end
91 | if self.normalize then
92 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
93 | end
94 | self.gradInput:mul(self.strength)
95 | self.gradInput:add(gradOutput)
96 | return self.gradInput
97 | end
98 |
99 |
--------------------------------------------------------------------------------
/lib/randlayer.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 |
3 | -- RandLayer: identity on forward; backward adds a random negative scaling of
4 | -- the input activations to the incoming gradient.
5 | local RandLayer, parent = torch.class('nn.RandLayer', 'nn.Module')
6 |
7 | function RandLayer:__init()
8 | parent.__init(self)
9 | self.randMap = nil  -- cached random map, rebuilt when the input size changes
10 | self.loss = 0  -- no loss contribution; field kept for interface consistency
11 | end
12 |
13 | function RandLayer:updateOutput(input)
14 | self.output = input
15 |
16 | -- BUGFIX: the original condition (`== nil or self.randMap:isSameSizeAs(input)`)
17 | -- was inverted: it regenerated the map on every forward when sizes matched,
18 | -- yet kept a stale, wrongly-sized map after a size change, which would make
19 | -- the cmul in updateGradInput fail. Rebuild only when missing or mismatched.
20 | if ( self.randMap == nil or not self.randMap:isSameSizeAs(input)) then
21 | self.randMap = torch.rand(input:size()) * -1
22 | end
23 |
24 | return self.output
25 | end
26 |
27 | function RandLayer:updateGradInput(input, gradOutput)
28 | self.gradInput:resizeAs(input):copy(input)
29 |
30 | self.gradInput:cmul(self.randMap)
31 | self.gradInput:add(gradOutput)
32 | return self.gradInput
33 | end
34 |
35 |
--------------------------------------------------------------------------------
/lib/tvloss.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 |
3 | -- Total-variation loss: identity on forward (no loss value tracked); backward
4 | -- adds the TV gradient, which smooths the generated image.
5 | local TVLoss, parent = torch.class('nn.TVLoss', 'nn.Module')
6 |
7 | function TVLoss:__init(strength)
8 | parent.__init(self)
9 | self.strength = strength
10 | self.x_diff = torch.Tensor()
11 | self.y_diff = torch.Tensor()
12 | end
13 |
14 | function TVLoss:updateOutput(input)
15 | self.output = input
16 | return self.output
17 | end
18 |
19 | -- TV loss backward pass inspired by kaishengtai/neuralart
20 | function TVLoss:updateGradInput(input, gradOutput)
21 | self.gradInput:resizeAs(input):zero()
22 | local C, H, W = input:size(1), input:size(2), input:size(3)
23 | -- BUGFIX: use the actual channel count C instead of the hard-coded 3, so the
24 | -- layer also works for feature maps with other channel counts.
25 | self.x_diff:resize(C, H - 1, W - 1)
26 | self.y_diff:resize(C, H - 1, W - 1)
27 | self.x_diff:copy(input[{{}, {1, -2}, {1, -2}}])
28 | self.x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}])
29 | self.y_diff:copy(input[{{}, {1, -2}, {1, -2}}])
30 | self.y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}])
31 | self.gradInput[{{}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff)
32 | self.gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, self.x_diff)
33 | self.gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, self.y_diff)
34 | self.gradInput:mul(self.strength)
35 | self.gradInput:add(gradOutput)
36 | return self.gradInput
37 | end
38 |
39 |
--------------------------------------------------------------------------------
/masked_gram.style:
--------------------------------------------------------------------------------
-- Recipe: masked Gram-matrix style reconstruction.
-- A table of options plus per-layer loss specs, returned to the driver script.
local net = {
    -- [1] style image, [2] precomputed spatial-mask table (torch .t7),
    -- [3] output path
    image_list = {'./images/renoir.png', './images/data/renoir_gram_mask.t7', './output.png'},
    input = -1,    -- NOTE(review): -1 presumably means "start from random noise" -- confirm in the driver
    output = 3,    -- image_list index the result is written to

    -- working resolution of the synthesized image
    width = 512,
    height = 320,

    convergence = true,    -- NOTE(review): likely "stop early on convergence" -- confirm in the driver
    maxIterate = 1500,     -- hard cap on optimizer iterations

    -- VGG-19 definition and weights used for feature extraction
    cnn = {
        proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt',
        caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel'
    },

    -- Loss stack: TV smoothness on the input plus masked Gram losses on five
    -- VGG activations. `style`/`mask` are indices into image_list.
    net = {
        {
            layer = 'input',
            type = 'tvloss',
            weight = 0.001,
        },
        {
            layer = 'relu1_1',
            type = 'mask_gram',
            weight = 1,
            style = 1,
            mask = 2
        },
        {
            layer = 'relu2_1',
            type = 'mask_gram',
            weight = 1,
            style = 1,
            mask = 2
        },
        {
            layer = 'relu3_1',
            type = 'mask_gram',
            weight = 1,
            style = 1,
            mask = 2
        },
        {
            layer = 'relu4_1',
            type = 'mask_gram',
            weight = 1,
            style = 1,
            mask = 2
        },
        {
            layer = 'relu5_1',
            type = 'mask_gram',
            weight = 1,
            style = 1,
            mask = 2
        }
    }
}

return net
62 |
--------------------------------------------------------------------------------
/mrf.style:
--------------------------------------------------------------------------------
-- Recipe: MRF (patch-matching) style transfer.
-- A table of options plus per-layer loss specs, returned to the driver script.
local net = {
    -- [1] content image, [2] style image, [3] output path
    image_list = {'./images/ford.png', './images/lohan.png', './output.png'},
    input = 1,     -- start optimization from image_list[1] (the content image)
    output = 3,    -- image_list index the result is written to

    convergence = true,    -- NOTE(review): likely "stop early on convergence" -- confirm in the driver
    maxIterate = 1000,     -- hard cap on optimizer iterations

    -- VGG-19 definition and weights used for feature extraction
    cnn = {
        proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt',
        caffemodel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel'
    },

    -- Loss stack: TV smoothness on the input, an MRF patch loss matching the
    -- style image at relu4_1, and a content loss against image 1 at relu4_2.
    -- `target` is an index into image_list.
    net = {
        {
            layer = 'input',
            type = 'tvloss',
            weight = 0.001,
        },
        {
            layer = 'relu4_1',
            type = 'mrf',
            weight = 10.0,
            target = 2,
        },
        {
            layer = 'relu4_2',
            type = 'content',
            weight = 1.0,
            target = 1,
        }
    }
}

return net
36 |
--------------------------------------------------------------------------------
/styled_cnn.lua:
--------------------------------------------------------------------------------
1 | require('nn')
2 | require('nngraph')
3 | require('loadcaffe')
4 | require('xlua')
5 | require('image')
6 | require('optim')
7 | require('cunn')
8 | require('cudnn')
9 |
10 | require('./lib/tvloss')
11 | require('./lib/contentloss')
12 | require('./lib/gramloss')
13 | require('./lib/mrfloss')
14 | require('./lib/masked_gramloss')
15 | require('./lib/amplayer')
16 | require('./lib/randlayer')
17 |
18 |
local cleanupModel = require('./lib/cleanup_model')  -- NOTE(review): not referenced in this file -- confirm before removing
local caffeImage = require('./lib/caffe_image')      -- RGB <-> caffe BGR/mean-subtracted conversion helpers
-- g: module-wide mutable state shared by all worker functions below.
local g = {}
g.styleImage = './images/picasso.png'    -- style exemplar fed to the Gram losses

g.trainImages_Path = './scene/'          -- directory holding training images named <n>.png
g.trainImages_Number = 16657             -- total number of training images
26 |
27 |
28 | -----------------------------------------------------------------------------------------
29 | -- helper functions
30 |
--- Prefix test attached to the string library so `name:startsWith('conv')` works.
-- BUGFIX: the previous implementation used `self:find('^' .. str)`, which
-- treated the prefix as a Lua pattern -- a prefix containing magic characters
-- ('.', '-', '%', ...) matched incorrectly. A plain substring comparison
-- treats every prefix literally. An empty prefix still matches everything.
string.startsWith = function(self, str)
    return self:sub(1, #str) == str
end
34 |
--- Load VGG-19 through loadcaffe and keep only the feature-extraction prefix.
-- Layers are copied in order until the first one whose name does not begin
-- with conv/relu/pool (i.e. the classifier head), which is where copying stops.
-- @return nn.Sequential containing the truncated network
function loadVGG()
    local proto = './cnn/vgg19/VGG_ILSVRC_19_layers_deploy.prototxt'
    local caffeModel = './cnn/vgg19/VGG_ILSVRC_19_layers.caffemodel'

    local fullModel = loadcaffe.load(proto, caffeModel, 'nn')
    local cnn = nn.Sequential()
    for i = 1, #fullModel do
        local layer = fullModel:get(i)
        local keep = layer.name:startsWith('conv')
            or layer.name:startsWith('relu')
            or layer.name:startsWith('pool')
        if not keep then
            break
        end
        cnn:add(layer)
    end

    -- drop the full model so the classifier weights can be collected
    fullModel = nil
    collectgarbage()
    return cnn
end
54 |
--- Shuffle the image indices and split them 85/15 into train/test path lists.
-- Populates g.trainSet and g.testSet, each of shape
-- { data = { '<path>/<n>.png', ... }, index = <cursor used by loadBatch> }.
function loadTrainData()
    local randSeq = torch.randperm(g.trainImages_Number)

    local trainSplit = math.floor(g.trainImages_Number * 0.85)

    g.trainSet = {}
    g.trainSet.data = {}
    g.trainSet.index = 1
    for i = 1, trainSplit do
        g.trainSet.data[i] = g.trainImages_Path .. '/' .. randSeq[i] .. '.png'
    end

    g.testSet = {}
    g.testSet.data = {}
    g.testSet.index = 1
    for i = trainSplit + 1, g.trainImages_Number do
        -- BUGFIX: re-base keys at 1 so data is a proper sequence. Previously
        -- the key was `i` (starting at trainSplit+1), which left data with no
        -- array part, so #g.testSet.data was 0 and loadBatch on the test set
        -- would divide/modulo by zero.
        g.testSet.data[i - trainSplit] = g.trainImages_Path .. '/' .. randSeq[i] .. '.png'
    end
end
74 |
--- Load batch_size images from `set` into a (B, 3, 256, 256) tensor in caffe
-- color space.
-- @param set g.trainSet or g.testSet: { data = {paths}, index = cursor }
-- @param batch_size number of images to load
-- @return table with field x holding the batch tensor
-- NOTE(review): images are assumed to already be 256x256 -- no resize is done.
function loadBatch(set, batch_size)

    local batch = {}
    batch.x = torch.Tensor(batch_size, 3, 256, 256)

    for i = 1, batch_size do
        -- pick the i-th entry after the running cursor, wrapping around
        -- NOTE(review): the "% #set.data + 1" arithmetic shifts the window by
        -- one and can skip/repeat entries at the wrap boundary -- confirm this
        -- sampling is intentional.
        local sampleIndex = i + set.index
        sampleIndex = sampleIndex % #set.data + 1

        local rgb = image.loadPNG( set.data[sampleIndex], 3)
        batch.x[i]:copy( caffeImage.img2caffe(rgb) )
    end

    -- advance the cursor past the consumed block (with wraparound)
    set.index = (set.index + batch_size) % #set.data + 1

    return batch
end
92 |
93 | -----------------------------------------------------------------------------------------
94 | -- worker functions
--- Assemble the perceptual loss network: VGG layers interleaved with Gram
-- (style) and content loss modules, preceded by a TV smoothness layer.
-- Gram targets are captured eagerly by forwarding the style image through the
-- partially-built net each time a tapped layer is appended; the content
-- target is left nil and set per-sample during training via setTarget.
-- @return table { net = nn.Sequential,
--                 modifier = { loss layers in insertion order },
--                 cindex = index of the content loss within modifier (-1 if absent) }
function buildLossNet ()
    -- VGG layer names after which style (Gram) and content losses are tapped
    local gramLoss = {'relu1_2', 'relu2_2', 'relu3_2', 'relu4_1'}
    local contentLoss = {'relu4_2'}

    local styleCaffeImage = caffeImage.img2caffe( image.loadPNG(g.styleImage, 3) )

    local modifier = {}
    local cindex = -1

    local net = nn.Sequential()
    net:add(nn.TVLoss(0.001))

    local gram_index = 1
    local content_index = 1
    for i = 1, #g.vgg do
        -- stop once every requested loss has been attached
        if ( gram_index > #gramLoss and content_index > #contentLoss) then
            break
        end

        local name = g.vgg:get(i).name
        net:add(g.vgg:get(i))

        if ( name == gramLoss[ gram_index ] ) then
            -- forward through the net built so far; the output is exactly the
            -- style activation at this depth, used as the fixed Gram target
            local target = net:forward( styleCaffeImage )
            local layer = nn.GramLoss(0.01, target, false)
            net:add(layer)
            table.insert(modifier, layer)

            gram_index = gram_index + 1
        end

        if ( name == contentLoss[content_index] ) then
            -- content target is nil here; doTrain sets it per input image
            local layer = nn.ContentLoss(1.0, nil, nil)
            net:add(layer)
            table.insert(modifier, layer)

            cindex = #modifier
            content_index = content_index + 1
        end
    end

    local lossNet = {}
    lossNet.net = net
    lossNet.modifier = modifier
    lossNet.cindex = cindex
    return lossNet
end
142 |
--- Build the feed-forward transformation network: five identical
-- conv(3x3, stride 1, pad 1) + batch-norm + LeakyReLU(0.1) stages with
-- widths 3->32->32->64->128->128, capped by a conv down to 3 channels,
-- Tanh, and a x128 scale back into the caffe value range.
-- @return nn.Sequential
function buildStyledNet()
    local model = nn.Sequential()

    -- {in-channels, out-channels} for each conv+BN+LeakyReLU stage
    local stages = { {3, 32}, {32, 32}, {32, 64}, {64, 128}, {128, 128} }
    for _, ch in ipairs(stages) do
        model:add(cudnn.SpatialConvolution(ch[1], ch[2], 3, 3, 1, 1, 1, 1))
        model:add(nn.SpatialBatchNormalization(ch[2]))
        model:add(nn.LeakyReLU(0.1))
    end

    -- project back to an image and scale Tanh's [-1,1] into [-128,128]
    model:add(cudnn.SpatialConvolution(128, 3, 3, 3, 1, 1, 1, 1))
    model:add(nn.Tanh())
    model:add(nn.MulConstant(128))

    return model
end
167 |
--- Run one epoch of training for the feed-forward style network.
-- For each batch: forward the originals through the (frozen) loss net to
-- capture per-image content targets, forward the stylized outputs through the
-- same loss net, backprop a zero gradient so only the internal loss layers'
-- gradients flow, and feed the resulting gradient back into the style net.
function doTrain()
    g.lossNet.net:cuda()
    g.styledNet:cuda()
    g.zeroLoss = g.zeroLoss:cuda()

    -- style net learns; loss net is a fixed feature extractor
    g.styledNet:training()
    g.lossNet.net:evaluate()

    local batchSize = 4
    local oneEpoch = math.floor( #g.trainSet.data / batchSize )
    g.trainSet.index = 1

    -- `batch` is shared with feval via upvalue; set fresh each iteration
    local batch = nil
    local dyhat = torch.zeros(batchSize, 3, 256, 256):cuda()
    local parameters,gradParameters = g.styledNet:getParameters()

    -- closure evaluated by optim.adam: returns (loss, dLoss/dParams)
    local feval = function(x)
        -- get new parameters
        if x ~= parameters then
            parameters:copy(x)
        end
        -- reset gradients
        gradParameters:zero()

        local loss = 0
        local yhat = g.styledNet:forward( batch.x )

        for i = 1, batchSize do
            -- forward the ORIGINAL image to capture its content activation,
            -- then pin it as this sample's content target
            g.lossNet.net:forward( batch.x[i] )
            local contentTarget = g.lossNet.modifier[g.lossNet.cindex].output
            g.lossNet.modifier[g.lossNet.cindex]:setTarget(contentTarget)

            -- forward the STYLIZED image; backprop a zero upstream gradient so
            -- the gradient consists purely of the loss layers' contributions
            g.lossNet.net:forward(yhat[i])
            local dy = g.lossNet.net:backward(yhat[i], g.zeroLoss)
            dyhat[i]:copy(dy)

            -- accumulate the scalar losses recorded by each loss module
            for _, mod in ipairs(g.lossNet.modifier) do
                loss = loss + mod.loss
            end
        end

        g.styledNet:backward(batch.x, dyhat)

        return loss/batchSize, gradParameters
    end

    -- NOTE(review): minValue is never used -- dead local, candidate for removal
    local minValue = -1
    for j = 1, oneEpoch do
        batch = loadBatch(g.trainSet, batchSize)
        batch.x = batch.x:cuda()

        local _, err = optim.adam(feval, parameters, g.optimState)

        print(">>>>>>>>> err = " .. err[1]);

        -- periodic checkpoint, named after the current loss value
        if ( j % 100 == 0) then
            torch.save('./model/style_' .. err[1] .. '.t7', g.styledNet)
        end

        collectgarbage();
    end

end
231 |
--- Placeholder: evaluation over g.testSet is not implemented yet.
function doTest()


end
236 |
237 |
--- Inference entry point: arg[1] = saved model (.t7), arg[2] = input PNG.
-- Runs the image through the network on the GPU and writes ./output.png.
function doForward()
    local net = torch.load(arg[1])
    local rgb = image.loadPNG(arg[2], 3)

    -- convert to caffe space and wrap in a batch of one
    local caffeImg = caffeImage.img2caffe(rgb)
    local x = torch.Tensor(1, caffeImg:size(1), caffeImg:size(2), caffeImg:size(3))
    x[1]:copy(caffeImg)
    x = x:cuda()

    local out = net:forward(x):float()
    image.savePNG('./output.png', caffeImage.caffe2img(out[1]))
end
253 |
254 |
255 | -----------------------------------------------------------------------------------------
--- Entry point. With exactly two CLI arguments runs inference
-- (model path, image path); otherwise builds the networks and trains.
function main()
    torch.setdefaulttensortype('torch.FloatTensor')
    torch.manualSeed(1979)   -- fixed seed for reproducible shuffles/inits

    if ( #arg == 2) then
        doForward()
        return
    end


    -- build net
    g.vgg = loadVGG()
    g.lossNet = buildLossNet()
    -- probe the loss net once to learn its output shape, so g.zeroLoss can be
    -- used as the zero upstream gradient during training
    local tempImage = torch.rand(3, 256, 256)
    local tempOutput = g.lossNet.net:forward(tempImage)
    g.zeroLoss = torch.zeros( tempOutput:size())

    g.styledNet = buildStyledNet()
    g.optimState = {
        learningRate = 0.0005,
    }

    -- load data
    loadTrainData()

    -- training: 4 epochs of train + (stub) test
    for i = 1, 4 do
        doTrain()
        doTest()
    end
end

main()
289 |
--------------------------------------------------------------------------------
/tools/buildMask.lua:
--------------------------------------------------------------------------------
1 | -- stupid tools only support: red, green, blue and white 4 colors masks
2 | --
3 |
4 | require('image')
5 |
--- Convert two color-coded mask PNGs (style + target) into a torch table of
-- per-region binary masks, saved in ascii t7 format:
-- masks.style[k] / masks.target[k], k = 1 white, 2..4 red/green/blue.
-- Usage: buildMask.lua <style_mask.png> <target_mask.png> <out.t7>
function main()
    if ( #arg ~= 3) then
        print("Please input [style_mask_png_file], [target_mask_png_file], [output_mask_file] ")
        print("Only support red, green, blue and white 4 colors masks")
        return
    end

    local style_mask = image.load(arg[1], 3)
    local target_mask = image.load(arg[2], 3)

    -- both masks must share spatial dimensions (size()[2] = H, size()[3] = W)
    if ( style_mask:size()[2] ~= target_mask:size()[2]
        or style_mask:size()[3] ~= target_mask:size()[3] ) then
        print("Error: style_mask and target_mask must be same size")
        return
    end

    local width = style_mask:size()[3]
    local height = style_mask:size()[2]

    -- only support 4 channels
    local masks = {}
    masks.style = {}
    masks.target = {}

    -- white only
    -- whiteMask accumulates, per pixel, how many of the 3 channels are
    -- bright (> 0.5). The ":le(0.5) * (-1) + 1" idiom flips a 0/1 ByteTensor
    -- (logical NOT) -- NOTE(review): this appears to rely on unsigned-byte
    -- wraparound for the *(-1); verify on the torch version in use.
    local whiteMask = torch.zeros(height, width):byte()
    for i = 1, 3 do
        whiteMask:add(style_mask[i]:le(0.5) * (-1) + 1)
    end
    -- pixels where all three channels are bright (count == 3, i.e. NOT <= 2.5)
    -- are classified as white
    whiteMask = whiteMask:le(2.5) * (-1) + 1
    masks.style[1] = whiteMask:clone()
    whiteMask:zero()
    -- same computation for the target mask, reusing the buffer
    for i = 1, 3 do
        whiteMask:add(target_mask[i]:le(0.5) * (-1) + 1)
    end
    whiteMask = whiteMask:le(2.5) * (-1) + 1
    masks.target[1] = whiteMask:clone()

    -- Red, Green, Blue
    -- For color channel i, require channel i bright AND channel j dark, where
    -- j cycles 1->3, 2->1, 3->2. NOTE(review): the third channel is never
    -- checked, so e.g. a pixel with R and G bright but B dark counts as red --
    -- confirm this shortcut is intended for the supported 4-color palette.
    for i = 1, 3 do
        local j = (i + 1) % 3 + 1

        local colorMask = (target_mask[i]:le(0.5) * (-1) + 1)
        colorMask:cmul( target_mask[j]:le(0.5) )
        masks.target[i+1] = colorMask:clone()

        colorMask = (style_mask[i]:le(0.5) * (-1) + 1)
        colorMask:cmul( style_mask[j]:le(0.5) )
        masks.style[i+1] = colorMask:clone()
    end

    -- ascii serialization keeps the file portable across architectures
    torch.save(arg[3], masks, 'ascii')
end

main()
61 |
--------------------------------------------------------------------------------