├── teaser3.png
├── input
│   ├── style
│   │   ├── s001.jpg
│   │   ├── s002.jpg
│   │   ├── s003.jpg
│   │   ├── s004.jpg
│   │   ├── s005.jpg
│   │   ├── s006.jpg
│   │   ├── s007.jpg
│   │   ├── s008.jpg
│   │   ├── s009.jpg
│   │   ├── s010.jpg
│   │   ├── s011.jpg
│   │   ├── s012.jpg
│   │   ├── s013.jpg
│   │   ├── s014.jpg
│   │   ├── s015.jpg
│   │   ├── s016.jpg
│   │   ├── s017.jpg
│   │   ├── s018.jpg
│   │   ├── s019.jpg
│   │   ├── s020.jpg
│   │   ├── s021.jpg
│   │   ├── s022.jpg
│   │   ├── s023.jpg
│   │   ├── s024.jpg
│   │   ├── s025.jpg
│   │   ├── s026.jpg
│   │   ├── s027.jpg
│   │   ├── s028.jpg
│   │   ├── s029.jpg
│   │   ├── s030.jpg
│   │   ├── s031.jpg
│   │   ├── s032.jpg
│   │   ├── s033.jpg
│   │   ├── s034.jpg
│   │   └── s035.jpg
│   └── content
│       ├── c001.jpg
│       ├── c002.jpg
│       ├── c003.jpg
│       ├── c004.jpg
│       ├── c005.jpg
│       ├── c006.jpg
│       ├── c007.jpg
│       ├── c008.jpg
│       ├── c009.jpg
│       ├── c010.jpg
│       ├── c011.jpg
│       ├── c012.jpg
│       ├── c013.jpg
│       ├── c014.jpg
│       ├── c015.jpg
│       └── c016.jpg
├── output
│   ├── c016_stylized_by_s006_mk.jpg
│   ├── c016_stylized_by_s014_mk.jpg
│   ├── c016_stylized_by_s015_mk.jpg
│   ├── c016_stylized_by_s019_mk.jpg
│   └── c016_stylized_by_s023_mk.jpg
├── demo_folder.sh
├── demo.sh
├── lib
│   ├── TVLossModule.lua
│   ├── MaxCoord.lua
│   ├── TVLossCriterion.lua
│   ├── AdaptiveInstanceNormalization.lua
│   ├── ContentLossModule.lua
│   ├── ImageLoader.lua
│   ├── InstanceNormalization.lua
│   ├── utils.lua
│   ├── NonparametricPatchAutoencoderFactory.lua
│   ├── StyleINLossModule.lua
│   ├── ImageLoaderAsync.lua
│   └── ArtisticStyleLossCriterion.lua
├── README.md
├── runCalcLoss.lua
├── optimal.lua
└── optimal_folder.lua

--------------------------------------------------------------------------------
/teaser3.png, /input/style/s001.jpg through s035.jpg, /input/content/c001.jpg through c016.jpg, /output/*.jpg:
--------------------------------------------------------------------------------
[binary image files; each is served from https://raw.githubusercontent.com/lu-m13/OptimalStyleTransfer/HEAD/ followed by its repository path]
--------------------------------------------------------------------------------
/demo_folder.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3 th optimal_folder.lua -style ./input/style -content ./input/content -alpha 0.8 -patchSize 3 -patchStride 1 -contentSize 0 -styleSize 0 -outputDir ./output/
2 |
--------------------------------------------------------------------------------
/demo.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3
th optimal.lua -style ./input/style/s001.jpg -content ./input/content/c001.jpg -alpha 0.8 -patchSize 3 -patchStride 1 -contentSize 0 -styleSize 0 -outputDir ./output/ 2 | -------------------------------------------------------------------------------- /lib/TVLossModule.lua: -------------------------------------------------------------------------------- 1 | require 'lib/TVLossCriterion' 2 | 3 | local module, parent = torch.class('nn.TVLossModule', 'nn.Module') 4 | 5 | function module:__init(strength) 6 | parent.__init(self) 7 | self.strength = strength or 1 8 | self.crit = nn.TVLossCriterion() 9 | self.loss = 0 10 | end 11 | 12 | function module:updateOutput(input) 13 | self.loss = self.crit:forward(input) 14 | self.loss = self.loss * self.strength 15 | self.output = input 16 | return self.output 17 | end 18 | 19 | function module:updateGradInput(input, gradOutput) 20 | self.gradInput = self.crit:backward(input) 21 | self.gradInput:mul(self.strength) 22 | self.gradInput:add(gradOutput) 23 | return self.gradInput 24 | end 25 | -------------------------------------------------------------------------------- /lib/MaxCoord.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Expects 3D or 4D input. Does a max over the feature channels. 3 | --]] 4 | local MaxCoord, parent = torch.class('nn.MaxCoord', 'nn.Module') 5 | 6 | function MaxCoord:__init(inplace) 7 | parent.__init(self) 8 | self.inplace = inplace or false 9 | end 10 | 11 | function MaxCoord:updateOutput(input) 12 | local nInputDim = input:nDimension() 13 | if input:nDimension() == 3 then 14 | local C,H,W = input:size(1), input:size(2), input:size(3) 15 | input = input:view(1,C,H,W) 16 | end 17 | assert(input:nDimension()==4, 'Input must be 3D or 4D (batch).') 18 | 19 | if self._type ~= 'torch.FloatTensor' then 20 | input = input:float() 21 | end 22 | 23 | local _, argmax = torch.max(input,2) 24 | 25 | if self.inplace then 26 | self.output = input:zero() 27 | else 28 | self.output = torch.FloatTensor():resizeAs(input):zero() 29 | end 30 | 31 | local N = input:size(1) 32 | 33 | for b=1,N do 34 | for i=1,self.output:size(3) do 35 | for j=1,self.output:size(4) do 36 | ind = argmax[{b,1,i,j}] 37 | self.output[{b,ind,i,j}] = 1 38 | end 39 | end 40 | end 41 | 42 | self.output = self.output:type(self._type) 43 | 44 | if nInputDim == 3 then 45 | self.output = self.output[1] 46 | end 47 | return self.output 48 | end 49 | 50 | function MaxCoord:updateGradInput(input, gradOutput) 51 | self.gradInput:resizeAs(input):zero() 52 | return self.gradInput 53 | end -------------------------------------------------------------------------------- /lib/TVLossCriterion.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local TVLossCriterion, parent = torch.class('nn.TVLossCriterion', 'nn.Criterion') 4 | 5 | function TVLossCriterion:__init() 6 | parent.__init(self) 7 | 8 | local crop_l = nn.SpatialZeroPadding(-1, 0, 0, 0) 9 | local crop_r = nn.SpatialZeroPadding(0, -1, 0, 0) 10 | local crop_t = nn.SpatialZeroPadding(0, 0, -1, 0) 11 | local crop_b = nn.SpatialZeroPadding(0, 0, 0, -1) 12 | self.target = torch.zeros(1) 13 | self.mse = nn.MSECriterion() 14 | self.mse.sizeAverage = false 15 | 16 | local lr = nn.Sequential() 17 | lr:add(nn.ConcatTable():add(crop_l):add(crop_r)) 18 | lr:add(nn.CSubTable()) 19 | local tb = nn.Sequential() 20 | tb:add(nn.ConcatTable():add(crop_t):add(crop_b)) 21 | tb:add(nn.CSubTable()) 22 | 23 | self.crit = 
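-- the criterion assembled here concatenates horizontal (lr) and vertical (tb)
-- neighbor differences and regresses each toward zero with summed MSE,
-- i.e. a squared total-variation penalty: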
nn.ConcatTable():add(lr):add(tb)
24 | end
25 |
26 | function TVLossCriterion:updateOutput(input)
27 |    local output = self.crit:forward(input)
28 |    local loss = 0
29 |    for i=1,2 do
30 |       local target
31 |       if output[i]:nDimension() == 3 then
32 |          target = self.target:view(1,1,1):expandAs(output[i])
33 |       else
34 |          target = self.target:view(1,1,1,1):expandAs(output[i])
35 |       end
36 |       loss = loss + self.mse:forward(output[i], target)
37 |    end
38 |    self.output = loss
39 |    return self.output
40 | end
41 |
42 | function TVLossCriterion:updateGradInput(input)
43 |    self.gradInput:resizeAs(input):zero()
44 |    local output = self.crit.output
45 |    local df_do = {}
46 |    for i=1,2 do
47 |       local target
48 |       if output[i]:nDimension() == 3 then
49 |          target = self.target:view(1,1,1):expandAs(output[i])
50 |       else
51 |          target = self.target:view(1,1,1,1):expandAs(output[i])
52 |       end
53 |       df_do[i] = self.mse:backward(output[i], target):clone()
54 |    end
55 |    local grad = self.crit:backward(input, df_do)
56 |    self.gradInput:copy(grad) -- reuse the gradient; no need to run backward a second time
57 |    return self.gradInput
58 | end
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code for A Closed-form Solution to Universal Style Transfer (https://arxiv.org/abs/1906.00668) - ICCV 2019
2 | =============
3 |
4 | This work mathematically derives a closed-form solution to universal style transfer. It is based on the theory of optimal transport and is closely related to AdaIN and WCT: AdaIN ignores the correlations between channels, while WCT does not minimize the content loss. Our solution accounts for both. Details of the derivation can be found in the paper.
5 |
6 |
7 | ![Teaser](./teaser3.png)
8 |
9 | ## Acknowledgments
10 |
11 | Link to AdaIN: https://github.com/xunhuang1995/AdaIN-style
12 |
13 | Link to WCT: https://github.com/Yijunmaverick/UniversalStyleTransfer
14 |
15 | ## Usage
16 |
17 | 1. Install Torch from http://torch.ch/
18 |
19 | 2. Download the encoders and decoders from [here](https://drive.google.com/open?id=1uv1m15RqTwgWQog7BMAW38bDVE7BkzO4) and unzip them to models/
20 |
21 | 3. To stylize a single image, see demo.sh
22 |
23 | 4. To stylize a folder of images, see demo_folder.sh
24 |
25 |
26 | ## Citation
27 |
28 | If you find this code useful in your research, please consider citing the following paper:
29 |
30 | ```
31 | @inproceedings{lu2019optimal,
32 |   title={A Closed-form Solution to Universal Style Transfer},
33 |   author={Lu, Ming and Zhao, Hao and Yao, Anbang and Chen, Yurong and Xu, Feng and Zhang, Li},
34 |   booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
35 |   year={2019}
36 | }
37 | ```
38 |
39 | ## Contemporary Works
40 |
41 | I recently found two contemporary works that reach the same conclusions; please also consider citing them.
42 | 43 | ``` 44 | @inproceedings{li2019optimal, 45 | title={Optimal Transport of Deep Feature for Image Style Transfer}, 46 | author={Li, Pan and Zhao, Lei and Xu, Duanqing and Lu, Dongming}, 47 | booktitle={Proceedings of the 2019 4th International Conference on Multimedia Systems and Signal Processing}, 48 | pages={167--171}, 49 | year={2019}, 50 | organization={ACM} 51 | } 52 | ``` 53 | 54 | ``` 55 | @article{mroueh2019wasserstein, 56 | title={Wasserstein Style Transfer}, 57 | author={Mroueh, Youssef}, 58 | journal={arXiv preprint arXiv:1905.12828}, 59 | year={2019} 60 | } 61 | ``` 62 | 63 | 64 | -------------------------------------------------------------------------------- /lib/AdaptiveInstanceNormalization.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Implements adaptive instance normalization (AdaIN) as described in the paper: 5 | 6 | Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization 7 | Xun Huang, Serge Belongie 8 | ]] 9 | 10 | local AdaptiveInstanceNormalization, parent = torch.class('nn.AdaptiveInstanceNormalization', 'nn.Module') 11 | 12 | function AdaptiveInstanceNormalization:__init(nOutput, disabled, eps) 13 | parent.__init(self) 14 | 15 | self.eps = eps or 1e-5 16 | 17 | self.nOutput = nOutput 18 | self.batchSize = -1 19 | self.disabled = disabled 20 | end 21 | 22 | function AdaptiveInstanceNormalization:updateOutput(input) --{content, style} 23 | local content = input[1] 24 | local style = input[2] 25 | 26 | if self.disabled then 27 | self.output = content 28 | return self.output 29 | end 30 | 31 | local N, Hc, Wc, Hs, Ws 32 | if content:nDimension() == 3 then 33 | assert(content:size(1) == self.nOutput) 34 | assert(style:size(1) == self.nOutput) 35 | N = 1 36 | Hc, Wc = content:size(2), content:size(3) 37 | Hs, Ws = style:size(2), style:size(3) 38 | content = content:view(1, self.nOutput, Hc, Wc) 39 | style = style:view(1, self.nOutput, Hs, Ws) 40 | elseif content:nDimension() == 4 then 41 | assert(content:size(1) == style:size(1)) 42 | assert(content:size(2) == self.nOutput) 43 | assert(style:size(2) == self.nOutput) 44 | N = content:size(1) 45 | Hc, Wc = content:size(3), content:size(4) 46 | Hs, Ws = style:size(3), style:size(4) 47 | end 48 | 49 | -- compute target mean and standard deviation from the style input 50 | local styleView = style:view(N, self.nOutput, Hs*Ws) 51 | local targetStd = styleView:std(3, true):view(-1) 52 | local targetMean = styleView:mean(3):view(-1) 53 | 54 | -- construct the internal BN layer 55 | if N ~= self.batchSize or (self.bn and self:type() ~= self.bn:type()) then 56 | self.bn = nn.SpatialBatchNormalization(N * self.nOutput, self.eps) 57 | self.bn:type(self:type()) 58 | self.batchSize = N 59 | end 60 | 61 | -- set affine params for the internal BN layer 62 | self.bn.weight:copy(targetStd) 63 | self.bn.bias:copy(targetMean) 64 | 65 | local contentView = content:view(1, N * self.nOutput, Hc, Wc) 66 | self.bn:training() 67 | self.output = self.bn:forward(contentView):viewAs(content) 68 | return self.output 69 | end 70 | 71 | function AdaptiveInstanceNormalization:updateGradInput(input, gradOutput) 72 | -- Not implemented 73 | self.gradInput = nil 74 | return self.gradInput 75 | end 76 | 77 | function AdaptiveInstanceNormalization:clearState() 78 | self.output = self.output.new() 79 | self.gradInput[1] = self.gradInput[1].new() 80 | self.gradInput[2] = self.gradInput[2].new() 81 | if self.bn then self.bn:clearState() end 82 | end 83 | 
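A note relating the module above to the rest of the repository: AdaIN matches only the per-channel mean and standard deviation of the style features, while the paper's closed-form solution transports the full channel covariance. The sketch below shows that closed-form (optimal transport) transform on flattened VGG features. It is illustrative only: `closedFormTransfer` is not a function in this repo, the `matSqrt` helper is re-declared here because the one in lib/utils.lua is file-local, and the actual implementation in optimal.lua may differ in detail.

```
-- Sketch of the closed-form (optimal transport) feature transform.
-- content, style: C x (H*W) matrices of flattened VGG activations.
local function matSqrt(x) -- symmetric PSD square root via SVD, as in lib/utils.lua
   local U, D, V = torch.svd(x)
   return U * torch.diag(torch.sqrt(D)) * V:t()
end

local function closedFormTransfer(content, style) -- illustrative name, not a repo function
   local C, HW = content:size(1), content:size(2)
   local mc = torch.mean(content, 2)                 -- content channel means, C x 1
   local ms = torch.mean(style, 2)                   -- style channel means, C x 1
   local fc = content - mc:expandAs(content)         -- centered content features
   local fs = style - ms:expand(C, style:size(2))    -- centered style features
   local eps = torch.eye(C) * 1e-5                   -- guards rank-deficient covariances
   local covC = fc * fc:t() / (HW - 1) + eps
   local covS = fs * fs:t() / (fs:size(2) - 1) + eps
   -- T = covC^(-1/2) * (covC^(1/2) * covS * covC^(1/2))^(1/2) * covC^(-1/2),
   -- the optimal transport map between the two Gaussian feature statistics
   local sqrtC = matSqrt(covC)
   local invSqrtC = torch.inverse(sqrtC)
   local T = invSqrtC * matSqrt(sqrtC * covS * sqrtC) * invSqrtC
   return T * fc + ms:expand(C, HW)                  -- restore the style means
end
```

If both covariances happen to be diagonal, `T` reduces to AdaIN's per-channel rescaling; unlike the whitening and coloring composition used by WCT, `T` is the map that perturbs the content features least while exactly matching the style covariance, which is the sense in which the paper considers both.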
-------------------------------------------------------------------------------- /lib/ContentLossModule.lua: -------------------------------------------------------------------------------- 1 | require 'lib/InstanceNormalization.lua' 2 | 3 | local module, parent = torch.class('nn.ContentLossModule', 'nn.Module') 4 | 5 | function module:__init(strength, normalize, nChannel) 6 | parent.__init(self) 7 | self.normalize = normalize or false 8 | self.strength = strength or 1 9 | self.target = nil 10 | self.loss = 0 11 | self.nC = nChannel 12 | self.crit = nn.MSECriterion() 13 | end 14 | 15 | function module:setTarget(target_features) 16 | if target_features:nDimension()==3 then 17 | local C,H,W = target_features:size(1), target_features:size(2), target_features:size(3) 18 | target_features = target_features:view(1,C,H,W) 19 | end 20 | self.target = target_features:clone() 21 | end 22 | 23 | function module:unsetTarget() 24 | self.target = nil 25 | end 26 | 27 | function module:updateOutput(input) 28 | self.output = input 29 | if self.target ~= nil then 30 | if input:nDimension() == 3 then 31 | local C,H,W = input:size(1), input:size(2), input:size(3) 32 | input = input:view(1,C,H,W) 33 | end 34 | assert(input:nDimension()==4) 35 | local N,C,H,W = self.target:size(1), self.target:size(2), self.target:size(3), self.target:size(4) 36 | assert(input:isSameSizeAs(self.target), 37 | string.format('Input size (%d x %d x %d x %d) does not match target size (%d x %d x %d x %d)', 38 | input:size(1),input:size(2),input:size(3),input:size(4), 39 | N,C,H,W)) 40 | self.loss = self.crit:forward(input, self.target) 41 | self.loss = self.loss * self.strength 42 | end 43 | return self.output 44 | end 45 | 46 | function module:updateGradInput(input, gradOutput) 47 | if self.target ~= nil then 48 | local nInputDim = input:nDimension() 49 | if input:nDimension() == 3 then 50 | local C,H,W = input:size(1), input:size(2), input:size(3) 51 | input = input:view(1,C,H,W) 52 | end 53 | assert(input:nDimension()==4) 54 | local N,C,H,W = self.target:size(1), self.target:size(2), self.target:size(3), self.target:size(4) 55 | assert(input:isSameSizeAs(self.target), 56 | string.format('Input size (%d x %d x %d x %d) does not match target size (%d x %d x %d x %d)', 57 | input:size(1),input:size(2),input:size(3),input:size(4), 58 | N,C,H,W)) 59 | self.gradInput = self.crit:backward(input, self.target):clone() 60 | 61 | if self.normalize then 62 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 63 | end 64 | 65 | if nInputDim == 3 then 66 | local C,H,W = input:size(2), input:size(3), input:size(4) 67 | self.gradInput = self.gradInput:view(C,H,W) 68 | end 69 | 70 | self.gradInput:mul(self.strength) 71 | self.gradInput:add(gradOutput) 72 | else 73 | self.gradInput = gradOutput 74 | end 75 | return self.gradInput 76 | end 77 | -------------------------------------------------------------------------------- /lib/ImageLoader.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'paths' 3 | require 'image' 4 | 5 | local ImageLoader = torch.class('ImageLoader') 6 | 7 | function ImageLoader:__init(dir) 8 | local files = paths.dir(dir) 9 | local i = 1 10 | while i <= #files do 11 | if not string.find(files[i], 'jpg$') 12 | and not string.find(files[i], 'png$') 13 | and not string.find(files[i], 'ppm$')then 14 | table.remove(files, i) 15 | else 16 | i = i +1 17 | end 18 | end 19 | self.dir = dir 20 | self.files = files 21 | self:rebatch() 22 | end 23 | 24 | function 
ImageLoader:rebatch() 25 | self.perm = torch.randperm(#self.files) 26 | self.idx = 1 27 | end 28 | 29 | function ImageLoader:next() 30 | -- load image 31 | local img = nil 32 | local name 33 | while true do 34 | if self.idx > #self.files then self:rebatch() end 35 | local i = self.perm[self.idx] 36 | name = self.files[i] 37 | local loc = paths.concat(self.dir, name) 38 | self.idx = self.idx + 1 39 | local status,err = pcall(function() img = image.load(loc,3,'float') end) 40 | if status then 41 | if self.verbose then 42 | print('Loaded ' .. self.files[i]) 43 | end 44 | break 45 | else 46 | io.stderr:write('WARN: Failed to load ' .. loc .. ' due to error: ' .. err .. '\n') 47 | end 48 | end 49 | 50 | -- preprocess 51 | local H, W = img:size(2), img:size(3) 52 | if self.len then 53 | img = image.scale(img, self.len) 54 | elseif self.max_len then 55 | if H > self.max_len or W > self.max_len then 56 | img = image.scale(img, self.max_len) 57 | end 58 | end 59 | 60 | H, W = img:size(2), img:size(3) 61 | if self.div then 62 | local Hc = math.floor(H / self.div) * self.div 63 | local Wc = math.floor(W / self.div) * self.div 64 | img = self:_randomCrop(img, Hc, Wc) 65 | end 66 | 67 | if self.bnw then 68 | img = image.rgb2yuv(img) 69 | img[2]:zero() 70 | img[3]:zero() 71 | img = image.yuv2rgb(img) 72 | end 73 | 74 | collectgarbage() 75 | return img, name 76 | end 77 | 78 | function ImageLoader:setVerbose(verbose) 79 | verbose = verbose or true 80 | self.verbose = verbose 81 | end 82 | 83 | function ImageLoader:setFitToHeightOrWidth(len) 84 | assert(len ~= nil) 85 | self.len = len 86 | self.max_len = nil 87 | end 88 | 89 | function ImageLoader:setMaximumSize(max_len) 90 | assert(max_len ~= nil) 91 | self.max_len = max_len 92 | self.len = nil 93 | end 94 | 95 | function ImageLoader:setDivisibleBy(div) 96 | assert(div ~= nil) 97 | self.div = div 98 | end 99 | 100 | function ImageLoader:_randomCrop(img, oheight, owidth) 101 | assert(img:dim()==3) 102 | local H,W = img:size(2), img:size(3) 103 | assert(oheight <= H) 104 | assert(owidth <= W) 105 | local y = torch.floor(torch.uniform(0, H-oheight+1)) 106 | local x = torch.floor(torch.uniform(0, W-owidth+1)) 107 | local crop_img = image.crop(img, x,y, x+owidth, y+oheight) 108 | return crop_img 109 | end 110 | 111 | function ImageLoader:setBlackNWhite(bool) 112 | if bool then 113 | self.bnw = true 114 | else 115 | self.bnw = false 116 | end 117 | end 118 | -------------------------------------------------------------------------------- /lib/InstanceNormalization.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Copied from https://github.com/jcjohnson/fast-neural-style . 
5 | ---------------------------------------------------------------- 6 | Implements instance normalization as described in the paper 7 | 8 | Instance Normalization: The Missing Ingredient for Fast Stylization 9 | Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky 10 | https://arxiv.org/abs/1607.08022 11 | This implementation is based on 12 | https://github.com/DmitryUlyanov/texture_nets 13 | ]] 14 | 15 | local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module') 16 | 17 | function InstanceNormalization:__init(nOutput, disabled, eps, affine) 18 | parent.__init(self) 19 | 20 | self.eps = eps or 1e-5 21 | 22 | if affine ~= nil then 23 | assert(type(affine) == 'boolean', 'affine has to be true/false') 24 | self.affine = affine 25 | else 26 | self.affine = true 27 | end 28 | 29 | if self.affine then 30 | self.weight = torch.Tensor(nOutput):uniform() 31 | self.bias = torch.Tensor(nOutput):zero() 32 | self.gradWeight = torch.Tensor(nOutput) 33 | self.gradBias = torch.Tensor(nOutput) 34 | end 35 | 36 | self.nOutput = nOutput 37 | self.prev_N = -1 38 | self.disabled = disabled 39 | end 40 | 41 | function InstanceNormalization:updateOutput(input) 42 | if self.disabled then 43 | self.output = input:clone() 44 | return self.output 45 | end 46 | 47 | local N,C,H,W 48 | if input:nDimension() == 3 then 49 | N,C,H,W = 1, input:size(1), input:size(2), input:size(3) 50 | elseif input:nDimension() == 4 then 51 | N, C = input:size(1), input:size(2) 52 | H, W = input:size(3), input:size(4) 53 | end 54 | assert(C == self.nOutput) 55 | 56 | if N ~= self.prev_N or (self.bn and self:type() ~= self.bn:type()) then 57 | self.bn = nn.SpatialBatchNormalization(N * C, self.eps, nil, self.affine) 58 | self.bn:type(self:type()) 59 | self.prev_N = N 60 | end 61 | 62 | -- Set params for BN 63 | if self.affine then 64 | self.bn.weight:repeatTensor(self.weight, N) 65 | self.bn.bias:repeatTensor(self.bias, N) 66 | end 67 | 68 | local input_view = input:view(1, N * C, H, W) 69 | self.bn:training() 70 | self.output = self.bn:forward(input_view):viewAs(input) 71 | 72 | return self.output 73 | end 74 | 75 | 76 | function InstanceNormalization:updateGradInput(input, gradOutput) 77 | local N,C,H,W 78 | if input:nDimension() == 3 then 79 | N,C,H,W = 1, input:size(1), input:size(2), input:size(3) 80 | elseif input:nDimension() == 4 then 81 | N, C = input:size(1), input:size(2) 82 | H, W = input:size(3), input:size(4) 83 | end 84 | assert(self.bn) 85 | 86 | local input_view = input:view(1, N * C, H, W) 87 | local gradOutput_view = gradOutput:view(1, N * C, H, W) 88 | 89 | if self.affine then 90 | self.bn.gradWeight:zero() 91 | self.bn.gradBias:zero() 92 | end 93 | 94 | self.bn:training() 95 | self.gradInput = self.bn:backward(input_view, gradOutput_view):viewAs(input) 96 | 97 | if self.affine then 98 | self.gradWeight:add(self.bn.gradWeight:view(N, C):sum(1)) 99 | self.gradBias:add(self.bn.gradBias:view(N, C):sum(1)) 100 | end 101 | return self.gradInput 102 | end 103 | 104 | 105 | function InstanceNormalization:clearState() 106 | self.output = self.output.new() 107 | self.gradInput = self.gradInput.new() 108 | if self.bn then self.bn:clearState() end 109 | end 110 | -------------------------------------------------------------------------------- /lib/utils.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'lfs' 3 | 4 | local function matSqrt(x) 5 | local U,D,V = torch.svd(x) 6 | local result = U*(D:pow(0.5):diag())*V:t() 7 | return result 8 
| end 9 | 10 | -- Prepares an RGB image in [0,1] for VGG 11 | function getPreprocessConv() 12 | local mean_pixel = torch.Tensor({103.939, 116.779, 123.68}) 13 | local conv = nn.SpatialConvolution(3,3, 1,1) 14 | conv.weight:zero() 15 | conv.weight[{1,3}] = 255 16 | conv.weight[{2,2}] = 255 17 | conv.weight[{3,1}] = 255 18 | conv.bias = -mean_pixel 19 | conv.gradBias = nil 20 | conv.gradWeight = nil 21 | conv.parameters = function() --[[nop]] end 22 | conv.accGradParameters = function() --[[nop]] end 23 | return conv 24 | end 25 | 26 | function extractImageNamesRecursive(dir) 27 | local files = {} 28 | print("Extracting image paths: " .. dir) 29 | 30 | local function browseFolder(root, pathTable) 31 | for entity in lfs.dir(root) do 32 | if entity~="." and entity~=".." then 33 | local fullPath=root..'/'..entity 34 | local mode=lfs.attributes(fullPath,"mode") 35 | if mode=="file" then 36 | local filepath = paths.concat(root, entity) 37 | 38 | if string.find(filepath, 'jpg$') 39 | or string.find(filepath, 'png$') 40 | or string.find(filepath, 'jpeg$') 41 | or string.find(filepath, 'JPEG$') 42 | or string.find(filepath, 'ppm$') then 43 | table.insert(pathTable, filepath) 44 | end 45 | elseif mode=="directory" then 46 | browseFolder(fullPath, pathTable); 47 | end 48 | end 49 | end 50 | end 51 | 52 | browseFolder(dir, files) 53 | return files 54 | end 55 | 56 | 57 | -- image size preprocessing 58 | function sizePreprocess(x, newSize) 59 | assert(x:dim() == 3) 60 | if newSize ~= 0 then 61 | --x = image.scale(x, '^' .. newSize) 62 | x = image.scale(x, newSize, 'bilinear') 63 | end 64 | return x 65 | end 66 | 67 | 68 | -- copied from torchx: https://github.com/nicholas-leonard/torchx/blob/master/find.lua 69 | function torch.find(tensor, val, dim) 70 | local i = 1 71 | local indice = {} 72 | if dim then 73 | assert(tensor:dim() == 2, "torch.find dim arg only supports matrices for now") 74 | assert(dim == 2, "torch.find only supports dim=2 for now") 75 | 76 | local colSize, rowSize = tensor:size(1), tensor:size(2) 77 | local rowIndice = {} 78 | tensor:apply(function(x) 79 | if x == val then 80 | table.insert(rowIndice, i) 81 | end 82 | if i == rowSize then 83 | i = 1 84 | table.insert(indice, rowIndice) 85 | rowIndice = {} 86 | else 87 | i = i + 1 88 | end 89 | end) 90 | else 91 | tensor:apply(function(x) 92 | if x == val then 93 | table.insert(indice, i) 94 | end 95 | i = i + 1 96 | end) 97 | end 98 | return indice 99 | end 100 | 101 | function torch.add_dummy(self) 102 | local sz = self:size() 103 | local new_sz = torch.Tensor(sz:size()+1) 104 | new_sz[1] = 1 105 | new_sz:narrow(1,2,sz:size()):copy(torch.Tensor{sz:totable()}) 106 | return self:view(new_sz:long():storage()) 107 | end 108 | 109 | function torch.FloatTensor:add_dummy() 110 | return torch.add_dummy(self) 111 | end 112 | function torch.DoubleTensor:add_dummy() 113 | return torch.add_dummy(self) 114 | end 115 | 116 | --[[ 117 | if params.gpu >= 0 then 118 | if params.backend ~= 'clnn' then 119 | function torch.CudaTensor:add_dummy() 120 | return torch.add_dummy(self) 121 | end 122 | else 123 | function torch.ClTensor:add_dummy() 124 | return torch.add_dummy(self) 125 | end 126 | end 127 | end 128 | --]] 129 | ------------------------------------------------ 130 | -------------------------------------------------------------------------------- /lib/NonparametricPatchAutoencoderFactory.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local NonparametricPatchAutoencoderFactory = 
torch.class('NonparametricPatchAutoencoderFactory') 4 | 5 | function NonparametricPatchAutoencoderFactory.buildAutoencoder(target_img, patch_size, stride, shuffle, normalize, interpolate) 6 | local nDim = 3 7 | assert(target_img:nDimension() == nDim, 'target image must be of dimension 3.') 8 | 9 | patch_size = patch_size or 3 10 | stride = stride or 1 11 | 12 | local type = target_img:type() 13 | local C = target_img:size(nDim-2) 14 | local target_patches = NonparametricPatchAutoencoderFactory._extract_patches(target_img, patch_size, stride, shuffle) 15 | local npatches = target_patches:size(1) 16 | 17 | local conv_enc, conv_dec = NonparametricPatchAutoencoderFactory._build(patch_size, stride, C, target_patches, npatches, normalize, interpolate) 18 | return conv_enc, conv_dec 19 | end 20 | 21 | function NonparametricPatchAutoencoderFactory._build(patch_size, stride, C, target_patches, npatches, normalize, interpolate) 22 | -- for each patch, divide by its L2 norm. 23 | local enc_patches = target_patches:clone() 24 | for i=1,npatches do 25 | enc_patches[i]:mul(1/(torch.norm(enc_patches[i],2)+1e-8)) 26 | end 27 | 28 | ---- Convolution for computing the semi-normalized cross correlation ---- 29 | local conv_enc = nn.SpatialConvolution(C, npatches, patch_size, patch_size, stride, stride):noBias() 30 | conv_enc.weight = enc_patches 31 | conv_enc.gradWeight = nil 32 | conv_enc.accGradParameters = __nop__ 33 | conv_enc.parameters = __nop__ 34 | 35 | if normalize then 36 | -- normalize each cross-correlation term by L2-norm of the input 37 | local aux = conv_enc:clone() 38 | aux.weight:fill(1) 39 | aux.gradWeight = nil 40 | aux.accGradParameters = __nop__ 41 | aux.parameters = __nop__ 42 | local compute_L2 = nn.Sequential() 43 | compute_L2:add(nn.Square()) 44 | compute_L2:add(aux) 45 | compute_L2:add(nn.Sqrt()) 46 | 47 | local normalized_conv_enc = nn.Sequential() 48 | local concat = nn.ConcatTable() 49 | concat:add(conv_enc) 50 | concat:add(compute_L2) 51 | normalized_conv_enc:add(concat) 52 | normalized_conv_enc:add(nn.CDivTable()) 53 | normalized_conv_enc.nInputPlane = conv_enc.nInputPlane 54 | normalized_conv_enc.nOutputPlane = conv_enc.nOutputPlane 55 | conv_enc = normalized_conv_enc 56 | end 57 | 58 | ---- Backward convolution for one patch ---- 59 | local conv_dec = nn.SpatialFullConvolution(npatches, C, patch_size, patch_size, stride, stride) --:noBias() 60 | conv_dec.weight = target_patches 61 | conv_dec.gradWeight = nil 62 | conv_dec.accGradParameters = __nop__ 63 | conv_dec.parameters = __nop__ 64 | 65 | -- normalize input so the result of each pixel location is a 66 | -- weighted combination of the backward conv filters, where 67 | -- the weights sum to one and are proportional to the input. 68 | -- the result is an interpolation of all filters. 69 | if interpolate then 70 | local aux = nn.SpatialFullConvolution(1, 1, patch_size, patch_size, stride, stride) --:noBias() 71 | aux.weight:fill(1) 72 | aux.gradWeight = nil 73 | aux.accGradParameters = __nop__ 74 | aux.parameters = __nop__ 75 | 76 | local counting = nn.Sequential() 77 | counting:add(nn.Sum(1,3)) -- sum up the channels 78 | counting:add(nn.Unsqueeze(1,2)) -- add back the channel dim 79 | counting:add(aux) 80 | counting:add(nn.Squeeze(1,3)) 81 | counting:add(nn.Replicate(C,1,2)) -- replicates the channel dim C times. 
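-- Note on the branch just built: `counting` feeds the same patch-weight map
-- through `aux`, whose kernel is filled with ones, so at every output pixel it
-- produces the total weight of all overlapping patches covering that pixel.
-- The CDivTable below divides the decoded output by this count, turning it
-- into a proper weighted average of the overlapping patch decoders.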
82 | 83 | interpolating_conv_dec = nn.Sequential() 84 | local concat = nn.ConcatTable() 85 | concat:add(conv_dec) 86 | concat:add(counting) 87 | interpolating_conv_dec:add(concat) 88 | interpolating_conv_dec:add(nn.CDivTable()) 89 | interpolating_conv_dec.nInputPlane = conv_dec.nInputPlane 90 | interpolating_conv_dec.nOutputPlane = conv_dec.nOutputPlane 91 | conv_dec = interpolating_conv_dec 92 | end 93 | 94 | return conv_enc, conv_dec 95 | end 96 | 97 | function NonparametricPatchAutoencoderFactory._extract_patches(img, patch_size, stride, shuffle) 98 | local nDim = 3 99 | assert(img:nDimension() == nDim, 'image must be of dimension 3.') 100 | local C, H, W = img:size(nDim-2), img:size(nDim-1), img:size(nDim) 101 | local nH = math.floor( (H - patch_size)/stride + 1) 102 | local nW = math.floor( (W - patch_size)/stride + 1) 103 | 104 | -- extract patches 105 | local patches = torch.Tensor(nH*nW, C, patch_size, patch_size):typeAs(img) 106 | for i=1,nH*nW do 107 | local h = math.floor((i-1)/nW) -- zero-index 108 | local w = math.floor((i-1)%nW) -- zero-index 109 | patches[i] = img[{{}, 110 | {1 + h*stride, 1 + h*stride + patch_size-1}, 111 | {1 + w*stride, 1 + w*stride + patch_size-1} 112 | }] 113 | end 114 | 115 | if shuffle then 116 | local shuf = torch.randperm(patches:size(1)):long() 117 | patches = patches:index(1,shuf) 118 | end 119 | 120 | return patches 121 | end 122 | 123 | function __nop__() 124 | -- do nothing 125 | end 126 | -------------------------------------------------------------------------------- /lib/StyleINLossModule.lua: -------------------------------------------------------------------------------- 1 | --////////////////////////////////////////////////////////////////////// 2 | require 'nn' 3 | 4 | local module, parent = torch.class('nn.StyleINLossModule', 'nn.Module') 5 | 6 | function module:__init(strength, normalize, nChannel) 7 | parent.__init(self) 8 | self.normalize = normalize or false 9 | self.strength = strength or 1 10 | self.target_mean = nil 11 | self.target_std = nil 12 | self.mean_loss = 0 -- mean loss 13 | self.std_loss = 0 -- std loss 14 | self.loss = 0 15 | self.nC = nChannel 16 | 17 | self.std_net = nn.Sequential() -- assume the input is centered 18 | self.std_net:add(nn.Square()) 19 | self.std_net:add(nn.Mean(3)) 20 | self.std_net:add(nn.Sqrt(1e-6)) 21 | self.mean_net = nn.Sequential() 22 | self.mean_net:add(nn.Mean(3)) 23 | 24 | self.mean_criterion = nn.MSECriterion() 25 | self.mean_criterion.sizeAverage = false 26 | self.std_criterion = nn.MSECriterion() 27 | self.std_criterion.sizeAverage = false 28 | 29 | self.std_net = self.std_net:cuda() 30 | self.mean_net = self.mean_net:cuda() 31 | self.mean_criterion = self.mean_criterion:cuda() 32 | self.std_criterion = self.std_criterion:cuda() 33 | end 34 | 35 | function module:clearState() 36 | self.std_net:clearState() 37 | self.mean_net:clearState() 38 | return parent.clearState(self) 39 | end 40 | 41 | --///////////////////////////////////////////////////////////// 42 | function module:setTarget(target_features) 43 | if target_features:nDimension() == 3 then 44 | local C = target_features:size(1) 45 | target_features = target_features:view(1, C, -1) 46 | elseif target_features:nDimension() == 4 then 47 | local N,C = target_features:size(1), target_features:size(2) 48 | target_features = target_features:view(N, C, -1) 49 | else 50 | error('Target must be 3D or 4D') 51 | end 52 | self.target_mean = torch.mean(target_features, 3) -- N*C*1 53 | self.target_std = torch.std(target_features, 3, true) 54 | 
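-- the stored targets are AdaIN-style summary statistics of the style
-- features: a per-channel mean and standard deviation, each sized N x C x 1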
return self 55 | end 56 | 57 | function module:unsetTarget() 58 | self.target_mean = nil 59 | self.target_std = nil 60 | return self 61 | end 62 | 63 | function module:updateOutput(input) 64 | self.output = input 65 | if self.target_mean ~= nil and self.target_std ~= nil then 66 | 67 | if input:nDimension() == 3 then 68 | local C,H,W = input:size(1), input:size(2), input:size(3) 69 | input = input:view(1,C,H,W) 70 | end 71 | assert(input:nDimension()==4) 72 | 73 | local N,C,H,W = input:size(1), input:size(2), input:size(3), input:size(4) 74 | assert(input:size(2) == self.target_mean:size(2)) 75 | assert(input:size(2) == self.target_std:size(2)) 76 | 77 | local input_view = input:view(N, C, -1) 78 | if N < self.target_mean:size(1) then 79 | self.target_mean = self.target_mean[1] 80 | self.target_std = self.target_std[1] 81 | elseif N > self.target_mean:size(1) then 82 | self.target_mean = self.target_mean:expand(N,C,1) 83 | self.target_std = self.target_std:expand(N,C,1) 84 | end 85 | 86 | self.input_mean = self.mean_net:forward(input_view) 87 | self.input_centered = torch.add(input_view, -self.input_mean:view(N, C, 1):expand(N, C, H*W)) -- centered input 88 | self.input_std = self.std_net:forward(self.input_centered) 89 | 90 | self.mean_loss = self.mean_criterion:forward(self.input_mean, self.target_mean) 91 | self.std_loss = self.std_criterion:forward(self.input_std, self.target_std) 92 | self.mean_loss = self.mean_loss / N -- normalized w.r.t. batch size 93 | self.std_loss = self.std_loss / N -- normalized w.r.t. batch size 94 | self.loss = self.mean_loss + self.std_loss 95 | self.loss = self.loss * self.strength 96 | end 97 | return self.output 98 | end 99 | 100 | function module:updateGradInput(input, gradOutput) 101 | if self.target_mean ~= nil and self.target_std ~= nil then 102 | local nInputDim = input:nDimension() 103 | if input:nDimension() == 3 then 104 | local C,H,W = input:size(1), input:size(2), input:size(3) 105 | input = input:view(1,C,H,W) 106 | end 107 | assert(input:nDimension()==4) 108 | 109 | local N,C,H,W = input:size(1), input:size(2), input:size(3), input:size(4) 110 | assert(input:size(2) == self.target_mean:size(2)) 111 | assert(input:size(2) == self.target_std:size(2)) 112 | 113 | local input_view = input:view(N, C, -1) 114 | local mean_grad = self.mean_criterion:backward(self.input_mean, self.target_mean) 115 | local std_grad = self.std_criterion:backward(self.input_std, self.target_std) 116 | self.gradInput = self.mean_net:backward(input_view, mean_grad) 117 | local std_gradInput_centered = self.std_net:backward(self.input_centered, std_grad) 118 | local std_gradInput = std_gradInput_centered:add(-std_gradInput_centered:mean(3):expand(N, C, H*W)) 119 | self.gradInput:add(std_gradInput) 120 | self.gradInput = self.gradInput:view(N,C,H,W) 121 | self.gradInput:div(N) -- normalize w.r.t. 
batch size 122 | 123 | if self.normalize then 124 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 125 | end 126 | 127 | if nInputDim == 3 then 128 | self.gradInput = self.gradInput:view(C,H,W) 129 | end 130 | 131 | self.gradInput:mul(self.strength) 132 | self.gradInput:add(gradOutput) 133 | else 134 | self.gradInput = gradOutput 135 | end 136 | return self.gradInput 137 | end 138 | -------------------------------------------------------------------------------- /lib/ImageLoaderAsync.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | 3 | local ImageLoaderAsync = torch.class('ImageLoaderAsync') 4 | 5 | local threads = require 'threads' 6 | 7 | local ImageLoader = {} 8 | local ImageLoader_mt = { __index = ImageLoader } 9 | 10 | ---- Asynchronous image loader. 11 | local result = {} 12 | local H, W 13 | local len 14 | 15 | function ImageLoaderAsync:__init(dir, batchSize, options, crop) 16 | if not batchSize then 17 | error('Predetermined batch size is required for asynchronous loader.') 18 | end 19 | options = options or {} 20 | local n = options.n or 1 21 | 22 | -- upvalues 23 | H,W = options.H, options.W 24 | len = options.len 25 | 26 | self.batchSize = batchSize 27 | self._type = 'torch.FloatTensor' 28 | 29 | -- initialize thread and its image loader 30 | self.threads = threads.Threads(n, 31 | function() 32 | imageLoader = ImageLoader:new(dir) 33 | if H ~= nil and W ~= nil then 34 | imageLoader:setWidthAndHeight(W,H) 35 | end 36 | if len ~= nil then 37 | imageLoader:setFitToHeightOrWidth(len) 38 | end 39 | imageLoader.crop = crop 40 | end) 41 | 42 | -- get size 43 | self.threads:addjob( 44 | function() return imageLoader:size() end, 45 | function(size) result[1] = size end) 46 | self.threads:dojob() 47 | self._size = result[1] 48 | result[1] = nil 49 | result[2] = nil 50 | result[3] = nil 51 | 52 | -- add job 53 | for i=1,n do 54 | self.threads:addjob(self.__getBatchFromThread, self.__pushResult, self.batchSize) 55 | end 56 | end 57 | 58 | function ImageLoaderAsync:size() 59 | return self._size 60 | end 61 | 62 | function ImageLoaderAsync:type(type) 63 | if not type then 64 | return self._type 65 | else 66 | assert(torch.Tensor():type(type), 'Invalid type ' .. type .. '?') 67 | self._type = type 68 | end 69 | return self 70 | end 71 | 72 | function ImageLoaderAsync.__getBatchFromThread(batchSize) 73 | a,b,c = imageLoader:nextBatch(batchSize) 74 | return a,b,c 75 | end 76 | 77 | function ImageLoaderAsync.__pushResult(batch, names) 78 | result[1] = batch 79 | result[2] = names 80 | end 81 | 82 | function ImageLoaderAsync:nextBatch() 83 | self.threads:addjob(self.__getBatchFromThread, self.__pushResult, self.batchSize) 84 | self.threads:dojob() 85 | local batch = result[1] 86 | result[1] = nil 87 | local names = result[2] 88 | result[2] = nil 89 | return batch:type(self._type), names 90 | end 91 | 92 | ---- Implementation of the actual image loader. 
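-- The requires inside ImageLoader:new below are intentional: each worker
-- created by threads.Threads starts with a fresh Lua state, so the thread
-- body must load its own copies of torch, paths, image and lib/utils.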
93 | function ImageLoader:new(dir) 94 | require 'torch' 95 | require 'paths' 96 | require 'image' 97 | require 'lib/utils' 98 | 99 | local imageLoader = {} 100 | setmetatable(imageLoader, ImageLoader_mt) 101 | files = extractImageNamesRecursive(dir) 102 | imageLoader.dir = dir 103 | imageLoader.files = files 104 | imageLoader.tm = torch.Timer() 105 | imageLoader.tm:reset() 106 | imageLoader:rebatch() 107 | return imageLoader 108 | end 109 | 110 | function ImageLoader:size() 111 | return #self.files 112 | end 113 | 114 | function ImageLoader:rebatch() 115 | self.perm = torch.randperm(self:size()) 116 | self.idx = 1 117 | end 118 | 119 | function ImageLoader:nextBatch(batchSize) 120 | local img, name = self:next() 121 | local batch = torch.FloatTensor(batchSize, 3, img:size(2), img:size(3)) 122 | local names = {} 123 | batch[1] = img 124 | table.insert(names, name) 125 | for i=2,batchSize do 126 | local temp, tempname = self:next() 127 | batch[i] = temp 128 | table.insert(names, tempname) 129 | end 130 | return batch, names 131 | end 132 | 133 | function ImageLoader:next() 134 | -- load image 135 | local img = nil 136 | local name 137 | local numErr = 0 138 | while true do 139 | if self.idx > self:size() then self:rebatch() end 140 | local i = self.perm[self.idx] 141 | self.idx = self.idx + 1 142 | name = self.files[i] 143 | local loc = paths.concat(self.dir, name) 144 | local status,err = pcall(function() img = image.load(loc,3,'float') end) -- load in range (0,1) 145 | if status then 146 | if self.verbose then print('Loaded ' .. self.files[i]) end 147 | break 148 | else 149 | io.stderr:write('WARNING: Failed to load ' .. loc .. ' due to error: ' .. err .. '\n') 150 | end 151 | end 152 | 153 | -- preprocess 154 | local H, W = img:size(2), img:size(3) 155 | 156 | if self.len ~= nil then -- resize without changing aspect ratio 157 | img = image.scale(img, "^" .. 
self.len) 158 | end 159 | 160 | if self.crop then 161 | img = self:_randomCrop(img, self.H, self.W) 162 | else 163 | if self.W and self.H then -- resize 164 | img = image.scale(img, self.W, self.H) 165 | elseif self.max_len then -- resize without changing aspect ratio 166 | if H > self.max_len or W > self.max_len then 167 | img = image.scale(img, self.max_len) 168 | end 169 | end 170 | end 171 | 172 | collectgarbage() 173 | return img, name, numErr 174 | end 175 | 176 | ---- Optional preprocessing 177 | function ImageLoader:setVerbose(verbose) 178 | verbose = verbose or true 179 | self.verbose = verbose 180 | end 181 | 182 | function ImageLoader:setWidthAndHeight(W,H) 183 | self.H = H 184 | self.W = W 185 | end 186 | 187 | function ImageLoader:setFitToHeightOrWidth(len) 188 | assert(len ~= nil) 189 | self.len = len 190 | self.max_len = nil 191 | end 192 | 193 | function ImageLoader:setMaximumSize(max_len) 194 | assert(max_len ~= nil) 195 | self.max_len = max_len 196 | self.len = nil 197 | end 198 | 199 | function ImageLoader:setDivisibleBy(div) 200 | assert(div ~= nil) 201 | self.div = div 202 | end 203 | 204 | function ImageLoader:_randomCrop(img, oheight, owidth) 205 | assert(img:dim()==3) 206 | local H,W = img:size(2), img:size(3) 207 | if oheight > H then 208 | print(oheight, H) 209 | error() 210 | end 211 | if owidth > W then 212 | print(owidth, W) 213 | error() 214 | end 215 | assert(oheight <= H) 216 | assert(owidth <= W) 217 | local y = torch.floor(torch.uniform(0, H-oheight+1)) 218 | local x = torch.floor(torch.uniform(0, W-owidth+1)) 219 | local crop_img = image.crop(img, x,y, x+owidth, y+oheight) 220 | return crop_img 221 | end 222 | -------------------------------------------------------------------------------- /lib/ArtisticStyleLossCriterion.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------- 2 | require 'nn' 3 | require 'lib/ContentLossModule' 4 | require 'lib/StyleINLossModule' 5 | require 'lib/TVLossModule' 6 | require 'lib/utils' 7 | ------------------------------------------------------------------------- 8 | local criterion, parent = torch.class('nn.ArtisticStyleLossCriterion', 'nn.Criterion') 9 | ------------------------------------------------------------------------- 10 | function criterion:__init(cnn, layers, weights, normalize) 11 | parent.__init(self) 12 | layers = layers or {} 13 | layers.content = layers.content or {} 14 | layers.style = layers.style or {} 15 | 16 | weights = weights or {} 17 | weights.content = weights.content or 0 18 | weights.style = weights.style or 0 19 | weights.tv = weights.tv or 0 20 | 21 | if weights.style <= 0 then 22 | layers.style = {} 23 | end 24 | if weights.content <= 0 then 25 | layers.content = {} 26 | end 27 | assert(#layers.content ==1,'Should have only one content layer') 28 | 29 | local net = nn.Sequential() 30 | local style_layers = {} 31 | local content_layers = {} 32 | local next_style_idx = 1 33 | local next_content_idx = 1 34 | -- Build encoder 35 | if weights.tv > 0 then 36 | local tv_mod = nn.TVLossModule(weights.tv) 37 | net:add(tv_mod) 38 | end 39 | local nop = function() end 40 | local prevC 41 | for i=1,cnn:size() do 42 | if next_style_idx <= #layers.style or 43 | next_content_idx <= #layers.content then -- STOP if all loss modules have been inserted 44 | local layer = cnn:get(i) 45 | local name = layer.name 46 | if torch.type(layer) == 'nn.SpatialConvolution' then 47 | -- Remove weight gradients because 
the encoder weights should be fixed. 48 | layer.accGradParameters = nop 49 | layer.gradWeight = nil 50 | layer.gradBias = nil 51 | prevC = layer.nOutputPlane 52 | end 53 | net:add(layer) 54 | -- Add loss modules 55 | if layers.style[next_style_idx] ~= nil and name == layers.style[next_style_idx] then 56 | local loss_module = nn.StyleINLossModule(weights.style, normalize, prevC) 57 | net:add(loss_module) 58 | table.insert(style_layers, loss_module) 59 | next_style_idx = next_style_idx + 1 60 | end 61 | if layers.content[next_content_idx] ~= nil and name == layers.content[next_content_idx] then 62 | local loss_module = nn.ContentLossModule(weights.content, normalize, prevC) 63 | net:add(loss_module) 64 | table.insert(content_layers, loss_module) 65 | next_content_idx = next_content_idx + 1 66 | end 67 | end 68 | end 69 | -- Error checking 70 | if next_style_idx <= #layers.style then -- '<=', not '<': with '<' a missing final style layer went undetected 71 | error('Could not find layer ' .. layers.style[next_style_idx]) 72 | end 73 | if next_content_idx <= #layers.content then 74 | error('Could not find layer ' .. layers.content[next_content_idx]) 75 | end 76 | -- Prepare 77 | self.net = net 78 | self.style_layers = style_layers 79 | self.content_layers = content_layers 80 | self.dy = torch.Tensor() 81 | end 82 | 83 | function criterion:setTargets(targets) 84 | if targets.style == nil and targets.content == nil then 85 | error('Must provide either target.style or target.content images.') 86 | end 87 | self:unsetTargets() 88 | if targets.style ~= nil then 89 | self:setStyleTarget(targets.style) 90 | end 91 | if targets.content ~= nil then 92 | self:setContentTarget(targets.content) 93 | end 94 | end 95 | 96 | function criterion:setContentTarget(target) 97 | if #self.content_layers == 0 then return end 98 | if target == nil then 99 | error('Must provide target content image.') 100 | end 101 | assert(target:nDimension()==3 or target:nDimension()==4, 'Content target must be 3D or 4D (batch).') 102 | self.targets = self.targets or {} 103 | self.targets.content = target:clone() 104 | self.net:clearState() 105 | self.net:forward(self.targets.content) 106 | for i=1,#self.content_layers do 107 | local target_features = self.content_layers[i].output 108 | self.content_layers[i]:setTarget(target_features) 109 | end 110 | end 111 | 112 | function criterion:setStyleTarget(target) 113 | if #self.style_layers <= 0 then return end 114 | if target == nil then 115 | error('Must provide target style image.') 116 | end 117 | assert(target:nDimension()==3 or target:nDimension()==4, 'Style target must be 3D or 4D (batch).') 118 | self.targets = self.targets or {} 119 | self.targets.style = target:clone() 120 | 121 | -- temporarily remove content targets, else the module 122 | -- may error out due to incorrect size.
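-- (The content targets were captured at the content image's spatial resolution;
-- forwarding a style image of a different size would hand ContentLossModule a
-- mismatched tensor to compare against.)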
123 | local content_targets = {} 124 | for i=1,#self.content_layers do 125 | content_targets[i] = self.content_layers[i].target 126 | self.content_layers[i].target = nil 127 | end 128 | 129 | self.net:clearState() 130 | self.net:forward(self.targets.style) 131 | for i=1,#self.style_layers do 132 | local target_features = self.style_layers[i].output 133 | self.style_layers[i]:setTarget(target_features) 134 | end 135 | 136 | -- reset the content targets 137 | for i=1,#self.content_layers do 138 | self.content_layers[i].target = content_targets[i] 139 | end 140 | end 141 | 142 | function criterion:unsetTargets() 143 | for i=1,#self.style_layers do 144 | self.style_layers[i]:unsetTarget() 145 | end 146 | for i=1,#self.content_layers do 147 | self.content_layers[i]:unsetTarget() 148 | end 149 | end 150 | 151 | --[[ 152 | Assumes input and target are both C x H x W images. (C=3) 153 | Batch mode optional. 154 | --]] 155 | function criterion:updateOutput(input, targets) 156 | self.recompute_gradInput = true 157 | if not self.targets then self:setTargets(targets) end 158 | self.net:forward(input) 159 | -- accumulate losses from the style loss layers 160 | local styleLoss = 0 161 | local contentLoss = 0 162 | for _, mod in ipairs(self.style_layers) do 163 | styleLoss = styleLoss + mod.loss 164 | end 165 | for _, mod in ipairs(self.content_layers) do 166 | contentLoss = contentLoss + mod.loss 167 | end 168 | self.styleLoss = styleLoss 169 | self.contentLoss = contentLoss 170 | self.output = styleLoss+contentLoss 171 | return self.output 172 | end 173 | 174 | function criterion:updateGradInput(input, targets) 175 | if self.recompute_gradInput then 176 | local dy = self.dy:typeAs(self.net.output):resizeAs(self.net.output):zero() 177 | local grad = self.net:backward(input, dy) 178 | self.gradInput = grad:clone() 179 | -- reset targets 180 | if not self.targets then self:unsetTargets() end 181 | end 182 | self.recompute_gradInput = false 183 | return self.gradInput 184 | end 185 | -------------------------------------------------------------------------------- /runCalcLoss.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'image' 4 | require 'optim' 5 | 6 | require 'loadcaffe' 7 | 8 | local cmd = torch.CmdLine() 9 | 10 | cmd:option('-style_image', 'input/s035.jpg','Style image') 11 | cmd:option('-content_image','input/c007.jpg','Content image') 12 | cmd:option('-output_image','input/c007_s035_reshuffle.jpg','Output image') 13 | 14 | cmd:option('-gpu', '0', 'Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1') 15 | cmd:option('-pooling', 'max', 'max|avg') 16 | cmd:option('-proto_file', 'models/vgg19_deploy.prototxt') 17 | cmd:option('-model_file', 'models/vgg19.caffemodel') 18 | cmd:option('-backend', 'nn', 'nn|cudnn|clnn') 19 | cmd:option('-seed', -1) 20 | 21 | cmd:option('-content_layers', 'relu4_2', 'layers for content') 22 | cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style') 23 | 24 | local function main(params) 25 | 26 | local dtype, multigpu = setup_gpu(params) 27 | local loadcaffe_backend = params.backend 28 | local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):type(dtype) 29 | 30 | local content_image = image.load(params.content_image, 3) 31 | local content_image_caffe = preprocess(content_image):float() 32 | 33 | local style_image = image.load(params.style_image,3) 34 | local style_image_caffe = preprocess(style_image):float() 35 
| 36 | local output_image = image.load(params.output_image,3) 37 | local output_image_caffe = preprocess(output_image):float() 38 | 39 | local content_layers = params.content_layers:split(",") 40 | local style_layers = params.style_layers:split(",") 41 | 42 | -- Set up the network, inserting style and content loss modules 43 | local content_losses, style_losses = {}, {} 44 | local next_content_idx, next_style_idx = 1, 1 45 | local net = nn.Sequential() 46 | 47 | for i = 1, #cnn do 48 | if next_content_idx <= #content_layers or next_style_idx <= #style_layers then 49 | local layer = cnn:get(i) 50 | local name = layer.name 51 | local layer_type = torch.type(layer) 52 | local is_pooling = (layer_type == 'cudnn.SpatialMaxPooling' or layer_type == 'nn.SpatialMaxPooling') 53 | if is_pooling and params.pooling == 'avg' then 54 | assert(layer.padW == 0 and layer.padH == 0) 55 | local kW, kH = layer.kW, layer.kH 56 | local dW, dH = layer.dW, layer.dH 57 | local avg_pool_layer = nn.SpatialAveragePooling(kW, kH, dW, dH):type(dtype) 58 | local msg = 'Replacing max pooling at layer %d with average pooling' 59 | print(string.format(msg, i)) 60 | net:add(avg_pool_layer) 61 | else 62 | net:add(layer) 63 | end 64 | if name == content_layers[next_content_idx] then 65 | print("Setting up content layer", i, ":", layer.name) 66 | 67 | local loss_module = nn.ContentLoss(1.0,false):type(dtype) 68 | 69 | net:add(loss_module) 70 | table.insert(content_losses, loss_module) 71 | next_content_idx = next_content_idx + 1 72 | end 73 | if name == style_layers[next_style_idx] then 74 | print("Setting up style layer ", i, ":", layer.name) 75 | 76 | local loss_module = nn.StyleLoss(1.0, false):type(dtype) 77 | 78 | net:add(loss_module) 79 | table.insert(style_losses, loss_module) 80 | next_style_idx = next_style_idx + 1 81 | end 82 | end 83 | end 84 | net:type(dtype) 85 | print(net) 86 | 87 | -- Capture content targets 88 | for i = 1, #content_losses do 89 | content_losses[i].mode = 'capture' 90 | end 91 | print 'Capturing content targets' 92 | content_image_caffe = content_image_caffe:type(dtype) 93 | net:forward(content_image_caffe:type(dtype)) 94 | 95 | -- Capture style targets 96 | for i = 1, #content_losses do 97 | content_losses[i].mode = 'none' 98 | end 99 | 100 | for j = 1, #style_losses do 101 | style_losses[j].mode = 'capture' 102 | style_losses[j].blend_weight = 1.0 103 | end 104 | net:forward(style_image_caffe:type(dtype)) 105 | 106 | -- Set all loss modules to loss mode 107 | for i = 1, #content_losses do 108 | content_losses[i].mode = 'loss' 109 | end 110 | for i = 1, #style_losses do 111 | style_losses[i].mode = 'loss' 112 | end 113 | 114 | -- We don't need the base CNN anymore, so clean it up to save memory. 
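-- (gradWeight/gradBias are nil'ed below because this script only ever runs
-- forward passes to read off the losses; parameter gradients are never needed.)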
115 | cnn = nil 116 | for i=1, #net.modules do 117 | local module = net.modules[i] 118 | if torch.type(module) == 'nn.SpatialConvolutionMM' then 119 | module.gradWeight = nil 120 | module.gradBias = nil 121 | end 122 | end 123 | collectgarbage() 124 | 125 | -- Initialize the image 126 | if params.seed >= 0 then 127 | torch.manualSeed(params.seed) 128 | end 129 | 130 | -- forward and print losses 131 | output_image_caffe = output_image_caffe:type(dtype) 132 | local y = net:forward(output_image_caffe) 133 | 134 | for i, loss_module in ipairs(content_losses) do 135 | print(string.format(' Content %d loss: %f', i, loss_module.loss)) 136 | end 137 | 138 | for i, loss_module in ipairs(style_losses) do 139 | print(string.format(' Style %d loss: %f', i, loss_module.loss)) 140 | end 141 | end 142 | 143 | --///// 144 | function setup_gpu(params) 145 | 146 | local multigpu = false 147 | params.gpu = tonumber(params.gpu) + 1 148 | local dtype = 'torch.FloatTensor' 149 | require 'cutorch' 150 | require 'cunn' 151 | cutorch.setDevice(params.gpu) 152 | dtype = 'torch.CudaTensor' 153 | params.backend = 'nn' 154 | return dtype, multigpu 155 | end 156 | 157 | --///// 158 | function setup_multi_gpu(net, params) 159 | local DEFAULT_STRATEGIES = { 160 | [2] = {3}, 161 | } 162 | local gpu_splits = nil 163 | if params.multigpu_strategy == '' then 164 | -- Use a default strategy 165 | gpu_splits = DEFAULT_STRATEGIES[#params.gpu] 166 | -- Offset the default strategy by one if we are using TV 167 | if params.tv_weight > 0 then 168 | for i = 1, #gpu_splits do gpu_splits[i] = gpu_splits[i] + 1 end 169 | end 170 | else 171 | -- Use the user-specified multigpu strategy 172 | gpu_splits = params.multigpu_strategy:split(',') 173 | for i = 1, #gpu_splits do 174 | gpu_splits[i] = tonumber(gpu_splits[i]) 175 | end 176 | end 177 | assert(gpu_splits ~= nil, 'Must specify -multigpu_strategy') 178 | local gpus = params.gpu 179 | 180 | local cur_chunk = nn.Sequential() 181 | local chunks = {} 182 | for i = 1, #net do 183 | cur_chunk:add(net:get(i)) 184 | if i == gpu_splits[1] then 185 | table.remove(gpu_splits, 1) 186 | table.insert(chunks, cur_chunk) 187 | cur_chunk = nn.Sequential() 188 | end 189 | end 190 | table.insert(chunks, cur_chunk) 191 | assert(#chunks == #gpus) 192 | 193 | local new_net = nn.Sequential() 194 | for i = 1, #chunks do 195 | local out_device = nil 196 | if i == #chunks then 197 | out_device = gpus[1] 198 | end 199 | new_net:add(nn.GPU(chunks[i], gpus[i], out_device)) 200 | end 201 | 202 | return new_net 203 | end 204 | 205 | --///////////////////// 206 | function build_filename(output_image, iteration) 207 | local ext = paths.extname(output_image) 208 | local basename = paths.basename(output_image, ext) 209 | local directory = paths.dirname(output_image) 210 | return string.format('%s/%s_%d.%s',directory, basename, iteration, ext) 211 | end 212 | 213 | 214 | -- Preprocess an image before passing it to a Caffe model. 215 | -- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR, 216 | -- and subtract the mean pixel. 217 | function preprocess(img) 218 | local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) 219 | local perm = torch.LongTensor{3, 2, 1} 220 | img = img:index(1, perm):mul(256.0) 221 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 222 | img:add(-1, mean_pixel) 223 | return img 224 | end 225 | 226 | 227 | -- Undo the above preprocessing. 
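-- As a quick sanity check (a minimal sketch, not part of the original script;
-- it only assumes the preprocess/deprocess helpers defined in this file),
-- deprocess should invert preprocess up to floating-point rounding:
--[[
local x = torch.rand(3, 8, 8)              -- any RGB image with values in [0, 1]
local y = deprocess(preprocess(x:clone())) -- to Caffe BGR space and back
print(torch.max(torch.abs(y - x)))         -- expect a value near 0
--]]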
228 | function deprocess(img) 229 | local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) 230 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 231 | img = img + mean_pixel 232 | local perm = torch.LongTensor{3, 2, 1} 233 | img = img:index(1, perm):div(256.0) 234 | return img 235 | end 236 | 237 | 238 | -- Combine the Y channel of the generated image and the UV channels of the 239 | -- content image to perform color-independent style transfer. 240 | function original_colors(content, generated) 241 | local generated_y = image.rgb2yuv(generated)[{{1, 1}}] 242 | local content_uv = image.rgb2yuv(content)[{{2, 3}}] 243 | return image.yuv2rgb(torch.cat(generated_y, content_uv, 1)) 244 | end 245 | 246 | 247 | -- Define an nn Module to compute content loss in-place 248 | local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module') 249 | 250 | function ContentLoss:__init(strength, normalize) 251 | parent.__init(self) 252 | self.strength = strength 253 | self.target = torch.Tensor() 254 | self.normalize = normalize or false 255 | self.loss = 0 256 | self.crit = nn.MSECriterion() 257 | self.mode = 'none' 258 | end 259 | 260 | function ContentLoss:updateOutput(input) 261 | if self.mode == 'loss' then 262 | self.loss = self.crit:forward(input, self.target) * self.strength 263 | 264 | self.loss = self.loss/input:nElement() 265 | 266 | elseif self.mode == 'capture' then 267 | self.target:resizeAs(input):copy(input) 268 | end 269 | self.output = input 270 | return self.output 271 | end 272 | 273 | function ContentLoss:updateGradInput(input, gradOutput) 274 | if self.mode == 'loss' then 275 | if input:nElement() == self.target:nElement() then 276 | self.gradInput = self.crit:backward(input, self.target) 277 | end 278 | if self.normalize then 279 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 280 | end 281 | self.gradInput:mul(self.strength) 282 | self.gradInput:add(gradOutput) 283 | else 284 | self.gradInput:resizeAs(gradOutput):copy(gradOutput) 285 | end 286 | return self.gradInput 287 | end 288 | 289 | 290 | local Gram, parent = torch.class('nn.GramMatrix', 'nn.Module') 291 | 292 | function Gram:__init() 293 | parent.__init(self) 294 | end 295 | 296 | function Gram:updateOutput(input) 297 | assert(input:dim() == 3) 298 | local C, H, W = input:size(1), input:size(2), input:size(3) 299 | local x_flat = input:view(C, H * W) 300 | self.output:resize(C, C) 301 | self.output:mm(x_flat, x_flat:t()) 302 | return self.output 303 | end 304 | 305 | function Gram:updateGradInput(input, gradOutput) 306 | assert(input:dim() == 3 and input:size(1)) 307 | local C, H, W = input:size(1), input:size(2), input:size(3) 308 | local x_flat = input:view(C, H * W) 309 | self.gradInput:resize(C, H * W):mm(gradOutput, x_flat) 310 | self.gradInput:addmm(gradOutput:t(), x_flat) 311 | self.gradInput = self.gradInput:view(C, H, W) 312 | return self.gradInput 313 | end 314 | 315 | 316 | -- Define an nn Module to compute style loss in-place 317 | local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module') 318 | 319 | function StyleLoss:__init(strength, normalize) 320 | parent.__init(self) 321 | self.normalize = normalize or false 322 | self.strength = strength 323 | self.target = torch.Tensor() 324 | self.mode = 'none' 325 | self.loss = 0 326 | 327 | self.gram = nn.GramMatrix() 328 | self.blend_weight = nil 329 | self.G = nil 330 | self.crit = nn.MSECriterion() 331 | end 332 | 333 | function StyleLoss:updateOutput(input) 334 | 335 | self.G = self.gram:forward(input) 336 | 337 | 
self.G:div(input:nElement()) -- 338 | 339 | if self.mode == 'capture' then 340 | if self.blend_weight == nil then 341 | self.target:resizeAs(self.G):copy(self.G) 342 | elseif self.target:nElement() == 0 then 343 | self.target:resizeAs(self.G):copy(self.G):mul(self.blend_weight) 344 | else 345 | self.target:add(self.blend_weight, self.G) 346 | end 347 | elseif self.mode == 'loss' then 348 | 349 | self.loss = self.strength * self.crit:forward(self.G, self.target) 350 | 351 | 352 | end 353 | self.output = input 354 | return self.output 355 | end 356 | 357 | function StyleLoss:updateGradInput(input, gradOutput) 358 | if self.mode == 'loss' then 359 | local dG = self.crit:backward(self.G, self.target) 360 | dG:div(input:nElement()) 361 | self.gradInput = self.gram:backward(input, dG) 362 | if self.normalize then 363 | self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) 364 | end 365 | self.gradInput:mul(self.strength) 366 | self.gradInput:add(gradOutput) 367 | else 368 | self.gradInput = gradOutput 369 | end 370 | return self.gradInput 371 | end 372 | 373 | 374 | local TVLoss, parent = torch.class('nn.TVLoss', 'nn.Module') 375 | 376 | function TVLoss:__init(strength) 377 | parent.__init(self) 378 | self.strength = strength 379 | self.x_diff = torch.Tensor() 380 | self.y_diff = torch.Tensor() 381 | end 382 | 383 | function TVLoss:updateOutput(input) 384 | self.output = input 385 | return self.output 386 | end 387 | 388 | -- TV loss backward pass inspired by kaishengtai/neuralart 389 | function TVLoss:updateGradInput(input, gradOutput) 390 | self.gradInput:resizeAs(input):zero() 391 | local C, H, W = input:size(1), input:size(2), input:size(3) 392 | self.x_diff:resize(3, H - 1, W - 1) 393 | self.y_diff:resize(3, H - 1, W - 1) 394 | self.x_diff:copy(input[{{}, {1, -2}, {1, -2}}]) 395 | self.x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}]) 396 | self.y_diff:copy(input[{{}, {1, -2}, {1, -2}}]) 397 | self.y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}]) 398 | self.gradInput[{{}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff) 399 | self.gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, self.x_diff) 400 | self.gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, self.y_diff) 401 | self.gradInput:mul(self.strength) 402 | self.gradInput:add(gradOutput) 403 | return self.gradInput 404 | end 405 | 406 | 407 | local params = cmd:parse(arg) 408 | main(params) 409 | -------------------------------------------------------------------------------- /optimal.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'image' 4 | require 'paths' 5 | require 'lib/NonparametricPatchAutoencoderFactory' 6 | require 'lib/MaxCoord' 7 | require 'lib/utils' 8 | require 'lib/AdaptiveInstanceNormalization' 9 | require 'nngraph' 10 | require 'cudnn' 11 | require 'cunn' 12 | 13 | local matio = require 'matio' 14 | 15 | local cmd = torch.CmdLine() 16 | cmd:option('-style', 'input/portrait_10.jpg', 'path to the style image') 17 | cmd:option('-content', 'input/13960.jpg', 'path to the content image') 18 | cmd:option('-alpha', 0.6) 19 | cmd:option('-patchSize', 3) 20 | cmd:option('-patchStride', 1) 21 | cmd:option('-vgg1', 'models/conv1_1.t7', 'Path to the VGG conv1_1') 22 | cmd:option('-vgg2', 'models/conv2_1.t7', 'Path to the VGG conv2_1') 23 | cmd:option('-vgg3', 'models/conv3_1.t7', 'Path to the VGG conv3_1') 24 | cmd:option('-vgg4', 'models/conv4_1.t7', 'Path to the VGG conv4_1') 25 | cmd:option('-vgg5', 'models/conv5_1.t7', 'Path to the VGG conv5_1') 26 | 
cmd:option('-decoder5', 'models/dec5_1.t7', 'Path to the decoder5') 27 | cmd:option('-decoder4', 'models/dec4_1.t7', 'Path to the decoder4') 28 | cmd:option('-decoder3', 'models/dec3_1.t7', 'Path to the decoder3') 29 | cmd:option('-decoder2', 'models/dec2_1.t7', 'Path to the decoder2') 30 | cmd:option('-decoder1', 'models/dec1_1.t7', 'Path to the decoder1') 31 | cmd:option('-contentSize', 768, 'New (minimum) size for the content image, keeping the original size if set to 0') 32 | cmd:option('-styleSize', 768, 'New (minimum) size for the style image, keeping the original size if set to 0') 33 | cmd:option('-outputDir', 'output/alley_1', 'Directory to save the output image(s)') 34 | opt = cmd:parse(arg) 35 | 36 | --////////////////////////////////////////////////// 37 | -- util functions 38 | --///////////////////////////////////////////////// 39 | function loadModel() 40 | vgg1 = torch.load(opt.vgg1) 41 | vgg2 = torch.load(opt.vgg2) 42 | vgg3 = torch.load(opt.vgg3) 43 | vgg4 = torch.load(opt.vgg4) 44 | vgg5 = torch.load(opt.vgg5) 45 | 46 | decoder5 = torch.load(opt.decoder5) 47 | decoder4 = torch.load(opt.decoder4) 48 | decoder3 = torch.load(opt.decoder3) 49 | decoder2 = torch.load(opt.decoder2) 50 | decoder1 = torch.load(opt.decoder1) 51 | 52 | adain5 = nn.AdaptiveInstanceNormalization(vgg5:get(#vgg5-1).nOutputPlane) 53 | adain4 = nn.AdaptiveInstanceNormalization(vgg4:get(#vgg4-1).nOutputPlane) 54 | adain3 = nn.AdaptiveInstanceNormalization(vgg3:get(#vgg3-1).nOutputPlane) 55 | adain2 = nn.AdaptiveInstanceNormalization(vgg2:get(#vgg2-1).nOutputPlane) 56 | adain1 = nn.AdaptiveInstanceNormalization(vgg1:get(#vgg1-1).nOutputPlane) 57 | 58 | print('GPU mode') 59 | vgg1:cuda() 60 | vgg2:cuda() 61 | vgg3:cuda() 62 | vgg4:cuda() 63 | vgg5:cuda() 64 | 65 | adain5:cuda() 66 | adain4:cuda() 67 | adain3:cuda() 68 | adain2:cuda() 69 | adain1:cuda() 70 | 71 | decoder1:cuda() 72 | decoder2:cuda() 73 | decoder3:cuda() 74 | decoder4:cuda() 75 | decoder5:cuda() 76 | end 77 | 78 | 79 | function normalize_features(x) 80 | local x2 = torch.pow(x, 2) 81 | local sum_x2 = torch.sum(x2, 1) 82 | local dis_x2 = torch.sqrt(sum_x2) 83 | local Nx = torch.cdiv(x, dis_x2:expandAs(x) + 1e-8) 84 | -- local Nx = torch.cdiv(x, dis_x2:expandAs(x)) 85 | dis_x2 = (dis_x2-torch.min(dis_x2))/(torch.max(dis_x2)-torch.min(dis_x2)) 86 | return Nx,dis_x2 87 | end 88 | 89 | 90 | function whitenMatrix(featureIn) 91 | local feature = featureIn:clone() -- c x hw 92 | local sz = feature:size() 93 | local ft_mean = torch.mean(feature,2) 94 | feature = feature - ft_mean:expandAs(feature) 95 | local ft_std = torch.std(feature,2) 96 | local ft_conv = torch.mm(feature,feature:t()):div(sz[2]-1) 97 | local u,e,v = torch.svd(ft_conv:float(),'A') 98 | local k_c = sz[1] 99 | for i=1,sz[1] do 100 | if e[i]<0.00001 then 101 | k_c = i-1 102 | break 103 | end 104 | end 105 | local d = e[{{1,k_c}}]:sqrt():pow(-1) 106 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 107 | return m:cuda(),ft_mean:cuda(),ft_std:cuda() 108 | end 109 | 110 | 111 | function colorMatrix(featureIn) 112 | local feature = featureIn:clone() 113 | local sz = feature:size() 114 | local ft_mean = torch.mean(feature,2) 115 | feature = feature - ft_mean:expandAs(feature) 116 | local ft_std = torch.std(feature,2) 117 | local ft_conv = torch.mm(feature,feature:t()):div(sz[2]-1) 118 | local u,e,v = torch.svd(ft_conv:float(),'A') 119 | local k_c = sz[1] 120 | for i=1,sz[1] do 121 | if e[i]<0.00001 then 122 | k_c = i-1 123 | break 124 | end 125 | end 
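-- (k_c is the effective numerical rank: eigenvalues below 1e-5 are treated as
-- zero, and only the top k_c eigen-directions enter the reconstruction below;
-- whitenMatrix uses the same cutoff so its :pow(-1) on the eigenvalues stays finite.)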
126 | local d = e[{{1,k_c}}]:sqrt() 127 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 128 | return m:cuda(),ft_mean:cuda(),ft_std:cuda() 129 | end 130 | 131 | 132 | function sqrtInvMatrix(mtx) 133 | local sz = mtx:size() 134 | local u,e,v = torch.svd(mtx:float(),'A') 135 | local k_c = sz[1] 136 | for i=1,sz[1] do 137 | if e[i]<0.00001 then 138 | k_c = i-1 139 | break 140 | end 141 | end 142 | local d = e[{{1,k_c}}]:sqrt():pow(-1) 143 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 144 | return m:cuda() 145 | end 146 | 147 | 148 | function invMatrix(mtx) 149 | local sz = mtx:size() 150 | local u,e,v = torch.svd(mtx:float(),'A') 151 | local k_c = sz[1] 152 | for i=1,sz[1] do 153 | if e[i]<0.00001 then 154 | k_c = i-1 155 | break 156 | end 157 | end 158 | local d = e[{{1,k_c}}]:pow(-1) 159 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 160 | return m:cuda() 161 | end 162 | 163 | 164 | function sqrtMatrix(mtx) 165 | local sz = mtx:size() 166 | local u,e,v = torch.svd(mtx:float(),'A') 167 | local k_c = sz[1] 168 | for i=1,sz[1] do 169 | if e[i]<0.00001 then 170 | k_c = i-1 171 | break 172 | end 173 | end 174 | local d = e[{{1,k_c}}]:sqrt() 175 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 176 | return m:cuda() 177 | end 178 | 179 | --//////////////////////////////////////////////// 180 | -- feature transform functions 181 | --//////////////////////////////////////////////// 182 | function feature_swap(contentFeature, styleFeature) 183 | 184 | local sg = contentFeature:size() 185 | local contentFeature1 = contentFeature:view(sg[1], sg[2]*sg[3]) 186 | local c_mean = torch.mean(contentFeature1, 2) 187 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 188 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(sg[2]*sg[3]-1) 189 | local c_u, c_e, c_v = torch.svd(contentCov:float(), 'A') 190 | local k_c = sg[1] 191 | for i=1, sg[1] do 192 | if c_e[i] < 0.00001 then 193 | k_c = i-1 194 | break 195 | end 196 | end 197 | 198 | local sz = styleFeature:size() 199 | local styleFeature1 = styleFeature:view(sz[1], sz[2]*sz[3]) 200 | local s_mean = torch.mean(styleFeature1, 2) 201 | styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1) 202 | local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sz[2]*sz[3]-1) 203 | local s_u, s_e, s_v = torch.svd(styleCov:float(), 'A') 204 | local k_s = sz[1] 205 | for i=1, sz[1] do 206 | if s_e[i] < 0.00001 then 207 | k_s = i-1 208 | break 209 | end 210 | end 211 | 212 | local s_d = torch.sqrt(s_e[{{1,k_s}}]):pow(-1) 213 | local whiten_styleFeature = nil 214 | whiten_styleFeature = (s_v[{{},{1,k_s}}]:cuda()) * torch.diag(s_d:cuda()) * (s_v[{{},{1,k_s}}]:t():cuda()) * styleFeature1 215 | local swap_enc, swap_dec = NonparametricPatchAutoencoderFactory.buildAutoencoder(whiten_styleFeature:resize(sz[1], sz[2], sz[3]), opt.patchSize, opt.patchStride, false, false, true) 216 | local swap = nn.Sequential() 217 | swap:add(swap_enc) 218 | swap:add(nn.MaxCoord()) 219 | swap:add(swap_dec) 220 | swap:evaluate() 221 | swap:cuda() 222 | local c_d = torch.sqrt(c_e[{{1,k_c}}]):pow(-1) 223 | local s_d1 = torch.sqrt(s_e[{{1,k_s}}]) 224 | local whiten_contentFeature = nil 225 | local targetFeature = nil 226 | whiten_contentFeature = (c_v[{{},{1,k_c}}]:cuda()) * torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) *contentFeature1 227 | local swap_latent = 
swap:forward(whiten_contentFeature:resize(sg[1], sg[2], sg[3])):clone() 228 | local swap_latent1 = swap_latent:view(sg[1], sg[2]*sg[3]) 229 | targetFeature = (s_v[{{},{1,k_s}}]:cuda()) * (torch.diag(s_d1:cuda())) * (s_v[{{},{1,k_s}}]:t():cuda()) * swap_latent1 230 | targetFeature = targetFeature + s_mean:expandAs(targetFeature) 231 | local tFeature = targetFeature:resize(sg[1], sg[2], sg[3]) 232 | return tFeature 233 | end 234 | 235 | 236 | function feature_wct(contentFeature, styleFeature) 237 | 238 | local sg = contentFeature:size() 239 | local contentFeature1 = contentFeature:view(sg[1], sg[2]*sg[3]) 240 | local c_mean = torch.mean(contentFeature1, 2) 241 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 242 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(sg[2]*sg[3]-1) 243 | local c_u, c_e, c_v = torch.svd(contentCov:float(), 'A') 244 | local k_c = sg[1] 245 | 246 | for i=1, sg[1] do 247 | if c_e[i] < 0.00001 then 248 | k_c = i-1 249 | break 250 | end 251 | end 252 | 253 | --k_c = sg[1] 254 | 255 | local sz = styleFeature:size() 256 | local styleFeature1 = styleFeature:view(sz[1], sz[2]*sz[3]) 257 | local s_mean = torch.mean(styleFeature1, 2) 258 | styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1) 259 | local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sz[2]*sz[3]-1) 260 | local s_u, s_e, s_v = torch.svd(styleCov:float(), 'A') 261 | local k_s = sz[1] 262 | for i=1, sz[1] do 263 | if s_e[i] < 0.00001 then 264 | k_s = i-1 265 | break 266 | end 267 | end 268 | 269 | local c_d = c_e[{{1,k_c}}]:sqrt():pow(-1) 270 | local s_d1 = s_e[{{1,k_s}}]:sqrt() 271 | local whiten_contentFeature = nil 272 | local targetFeature = nil 273 | 274 | -- ZCA 275 | whiten_contentFeature = (c_v[{{},{1,k_c}}]:cuda()) * torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) *contentFeature1 276 | 277 | -- PCA 278 | --whiten_contentFeature = torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) *contentFeature1 279 | 280 | -- Cholesky 281 | --[[ 282 | local chol_d = c_e[{{1,k_c}}]:pow(-1) 283 | whiten_M = (c_v[{{},{1,k_c}}]:cuda()) * (torch.diag(chol_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda())) 284 | whiten_M = whiten_M:float() 285 | whiten_M = torch.potrf(whiten_M,'L') 286 | whiten_M = whiten_M:t():cuda() 287 | whiten_contentFeature = whiten_M * contentFeature1 -- CxN 288 | --]] 289 | 290 | -- ZCA cor and PCA cor 291 | --[[ 292 | V_std = torch.std(contentFeature1,2):squeeze() 293 | V_sqrt = torch.diag(V_std) 294 | V_sqrt_inv = invMatrix(V_sqrt) 295 | V_sqrt_inv = V_sqrt_inv:float() 296 | P = V_sqrt_inv * contentCov:float() * V_sqrt_inv 297 | G, Theta, Gt = torch.svd(P, 'A') 298 | G_d = Theta[{{1,k_c}}]:sqrt():pow(-1) 299 | whiten_M = (G[{{},{1,k_c}}]:cuda()) * torch.diag(G_d:cuda()) * (G[{{},{1,k_c}}]:t():cuda()) * V_sqrt:cuda() 300 | --whiten_M = torch.diag(G_d:cuda()) * (G[{{},{1,k_c}}]:t():cuda()) * V_sqrt:cuda() 301 | whiten_contentFeature = whiten_M * contentFeature1 -- CxN 302 | --]] 303 | 304 | --whiten_contentFeature = curQ*whiten_contentFeature 305 | 306 | targetFeature = (s_v[{{},{1,k_s}}]:cuda()) * (torch.diag(s_d1:cuda())) * (s_v[{{},{1,k_s}}]:t():cuda()) * whiten_contentFeature 307 | targetFeature = targetFeature + s_mean:expandAs(targetFeature) 308 | local tFeature = targetFeature:resize(sg[1], sg[2], sg[3]) 309 | return tFeature 310 | end 311 | 312 | 313 | function feature_mk(contentFeature, styleFeature) 314 | 315 | local eps=1e-10 316 | local cDim = contentFeature:size() 317 | local contentFeature1 = 
contentFeature:view(cDim[1], cDim[2]*cDim[3]) -- cxhw 318 | local c_mean = torch.mean(contentFeature1, 2) 319 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 320 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(cDim[2]*cDim[3]-1) -- cxc 321 | 322 | local sDim = styleFeature:size() 323 | local styleFeature1 = styleFeature:view(sDim[1], sDim[2]*sDim[3]) -- cxhw 324 | local s_mean = torch.mean(styleFeature1, 2) 325 | styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1) 326 | local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sDim[2]*sDim[3]-1) -- cxc 327 | 328 | local Da2,Ua = torch.eig(contentCov:float(),'V') -- returns eigenvalues e (mx2, real/imag) and eigenvectors V (mxm) 329 | Ua = Ua:t() 330 | Da2 = Da2[{{},{1}}]:squeeze():cuda() -- keep the real parts 331 | Da2[torch.lt(Da2,0)] = 0 -- clamp tiny negative eigenvalues from numerical error 332 | Da2 = Da2+eps -- regularize the eigenvalue vector before diagonalizing, so eps never leaks into off-diagonal entries 333 | Da2 = torch.diag(Da2) 334 | local Da = Da2:sqrt():cuda() -- cxc, diagonal Cc^(1/2) in the content eigenbasis 335 | Ua = Ua:cuda() 336 | 337 | styleCov = styleCov:cuda() 338 | local C = Da*Ua:t()*styleCov*Ua*Da 339 | 340 | local Dc2,Uc = torch.eig(C:float(),'V') -- return e,V 341 | Uc = Uc:t() 342 | Dc2 = Dc2[{{},{1}}]:squeeze():cuda() 343 | Dc2[torch.lt(Dc2,0)] = 0 344 | Dc2 = Dc2+eps 345 | Dc2 = torch.diag(Dc2) 346 | local Dc = Dc2:sqrt() 347 | Uc = Uc:cuda() 348 | 349 | local Da_inv = torch.diag(torch.diag(Da):pow(-1)) -- invert the diagonal only; elementwise Da:pow(-1) would map the zero off-diagonal entries to inf 350 | 351 | local T = Ua*Da_inv*Uc*Dc*Uc:t()*Da_inv*Ua:t() -- cxc, T = Cc^(-1/2) (Cc^(1/2) Cs Cc^(1/2))^(1/2) Cc^(-1/2) 352 | 353 | local targetFeature = T*contentFeature1 354 | targetFeature = targetFeature + s_mean:expandAs(targetFeature) 355 | local resFeature = targetFeature:resize(cDim[1],cDim[2],cDim[3]) 356 | return resFeature 357 | end 358 | 359 | 360 | function feature_mk2(contentFeature, styleFeature) 361 | 362 | local eps=1e-10 363 | local cDim = contentFeature:size() 364 | local contentFeature1 = contentFeature:view(cDim[1], cDim[2]*cDim[3]) -- cxhw 365 | local c_mean = torch.mean(contentFeature1, 2) 366 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 367 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(cDim[2]*cDim[3]-1) -- cxc 368 | 369 | local sDim = styleFeature:size() 370 | local styleFeature1 = styleFeature:view(sDim[1], sDim[2]*sDim[3]) -- cxhw 371 | local s_mean = torch.mean(styleFeature1, 2) 372 | styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1) 373 | local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sDim[2]*sDim[3]-1) -- cxc 374 | 375 | local sqrtInvU = sqrtInvMatrix(contentCov) -- Cc^(-1/2) 376 | local sqrtU = sqrtMatrix(contentCov) -- Cc^(1/2) 377 | local C = sqrtU*styleCov*sqrtU -- Cc^(1/2) Cs Cc^(1/2) 378 | local sqrtC = sqrtMatrix(C) 379 | local T = sqrtInvU*sqrtC*sqrtInvU -- closed-form Monge-Kantorovich map between the two feature Gaussians 380 | local targetFeature = T*contentFeature1 381 | targetFeature = targetFeature + s_mean:expandAs(targetFeature) 382 | local resFeature = targetFeature:resize(cDim[1],cDim[2],cDim[3]) 383 | return resFeature 384 | end 385 | 386 | function feature_mk3(contentFeature, styleFeature) 387 | 388 | local eps=1e-10 389 | local cDim = contentFeature:size() -- cxN 390 | local contentFeature1 = contentFeature -- cxN 391 | local c_mean = torch.mean(contentFeature1, 2) 392 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 393 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(cDim[2]-1) -- cxc 394 | 395 | local sDim = styleFeature:size() -- cxN 396 | local styleFeature1 = styleFeature -- cxN 397 | local s_mean = torch.mean(styleFeature1, 2) 398 | styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1) 399 | local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sDim[2]-1) -- cxc 400 | 401 | local sqrtInvU =
sqrtInvMatrix(contentCov) 402 | local sqrtU = sqrtMatrix(contentCov) 403 | local C = sqrtU*styleCov*sqrtU 404 | local sqrtC = sqrtMatrix(C) 405 | local T = sqrtInvU*sqrtC*sqrtInvU 406 | local targetFeature = T*contentFeature1 407 | targetFeature = targetFeature + s_mean:expandAs(targetFeature) 408 | local resFeature = targetFeature 409 | return resFeature -- cxN 410 | 411 | end 412 | 413 | function feature_mk3_sem(contentFeature, styleFeature,maskC,maskS) 414 | 415 | local eps=1e-10 416 | 417 | maskC = maskC:cuda() 418 | maskS = maskS:cuda() 419 | 420 | local cDim = contentFeature:size() 421 | local contentFeature1 = contentFeature:view(cDim[1], cDim[2]*cDim[3]) -- cxhw 422 | local sDim = styleFeature:size() 423 | local styleFeature1 = styleFeature:view(sDim[1], sDim[2]*sDim[3]) -- cxhw 424 | 425 | local cView = maskC:view(-1) 426 | local sView = maskS:view(-1) 427 | 428 | local targetFeature1 = contentFeature1:clone():zero() 429 | 430 | for k=1,5 do 431 | local cFg = torch.LongTensor(torch.find(cView,k-1)) 432 | local sFg = torch.LongTensor(torch.find(sView,k-1)) 433 | local cFt = contentFeature1:index(2,cFg):view(cDim[1],cFg:nElement()) 434 | local sFt = styleFeature1:index(2,sFg):view(sDim[1],sFg:nElement()) 435 | local tFt = feature_mk3(cFt,sFt) 436 | targetFeature1:indexCopy(2,cFg,tFt) 437 | end 438 | 439 | targetFeature1 = targetFeature1:viewAs(contentFeature) 440 | return targetFeature1 441 | end 442 | 443 | 444 | function feature_clamp(contentFeature,styleFeature) 445 | 446 | -- check feature 447 | --[[ 448 | local cFt = contentFeature[{{1},{},{}}]:squeeze() 449 | local sFt = styleFeature[{{1},{},{}}]:squeeze() 450 | local disp = torch.cat(cFt,sFt) 451 | image.display(disp) 452 | --]] 453 | 454 | local sz_c = contentFeature:size() 455 | local sz_s = styleFeature:size() 456 | local contentFeatureView = contentFeature:view(sz_c[1],sz_c[2]*sz_c[3]) 457 | local styleFeatureView = styleFeature:view(sz_s[1],sz_s[2]*sz_s[3]) 458 | local cWhitenM,cWhitenMean,cWhitenStd = whitenMatrix(contentFeatureView) 459 | local sWhitenM,sWhitenMean,sWhitenStd = whitenMatrix(styleFeatureView) 460 | local sColorM,sColorMean,sColorStd = colorMatrix(styleFeatureView) 461 | -- whiten 462 | local contentWhiten = cWhitenM*(contentFeatureView-cWhitenMean:expandAs(contentFeatureView)) 463 | local styleWhiten = sWhitenM*(styleFeatureView-sWhitenMean:expandAs(styleFeatureView)) 464 | contentWhiten = contentWhiten:view(sz_c[1],sz_c[2],sz_c[3]) 465 | styleWhiten = styleWhiten:view(sz_s[1],sz_s[2],sz_s[3]) 466 | -- blend 467 | local gainMap = torch.cdiv(styleWhiten,contentWhiten) 468 | gainMap = torch.clamp(gainMap,0.5,1.0) 469 | local contentRemap = torch.cmul(contentWhiten,gainMap) 470 | contentRemap = contentRemap:view(sz_c[1],sz_c[2]*sz_c[3]) 471 | contentRemap = sColorM*contentRemap+sColorMean:expandAs(contentRemap) 472 | contentRemap = contentRemap:view(sz_c[1],sz_c[2],sz_c[3]) 473 | return contentRemap 474 | end 475 | 476 | 477 | function feature_blend(contentFeature,styleFeature,alpha) 478 | local szC = contentFeature:size() 479 | local szS = styleFeature:size() 480 | local contentFtView = contentFeature:view(szC[1],szC[2]*szC[3]) 481 | local styleFtView = styleFeature:view(szS[1],szS[2]*szS[3]) 482 | local contentFtN,contentFtD = normalize_features(contentFtView) 483 | local styleFtN,styleFtD = normalize_features(styleFtView) 484 | 485 | contentFtD = contentFtD - 0.05 486 | contentFtD[contentFtD:lt(0.000001)] = 0.0 487 | contentFtD[contentFtD:gt(0.000001)] = 1.0 488 | local gainMap = contentFtD*alpha 489 | 
gainMap = gainMap:view(1,szC[2],szC[3]) 490 | --image.display(gainMap:squeeze()) 491 | gainMap = gainMap:expandAs(contentFeature) 492 | 493 | --[[ 494 | contentFtD = -300.0*(contentFtD-0.05) 495 | local gainMap = torch.cinv((1+torch.exp(contentFtD))) 496 | gainMap = gainMap:view(1,szC[2],szC[3]) 497 | image.display(gainMap:squeeze()) 498 | gainMap = gainMap:expandAs(contentFeature) 499 | gainMap = alpha*gainMap 500 | --]] 501 | 502 | return torch.cmul(contentFeature,gainMap)+torch.cmul(styleFeature,1-gainMap) 503 | end 504 | 505 | --////////////////////////////////////////////////// 506 | -- style transfer functions 507 | --///////////////////////////////////////////////// 508 | local function styleTransfer_wct(content, style) 509 | 510 | loadModel() 511 | 512 | print('Start wct') 513 | 514 | content = content:cuda() 515 | style = style:cuda() 516 | local cF5 = vgg5:forward(content):clone() 517 | local sF5 = vgg5:forward(style):clone() 518 | vgg5 = nil 519 | local csF5 = nil 520 | --csF5 = feature_swap(cF5, sF5) 521 | csF5 = feature_wct(cF5, sF5) 522 | csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5 523 | local Im5 = decoder5:forward(csF5) 524 | decoder5 = nil 525 | 526 | local cF4 = vgg4:forward(Im5):clone() 527 | local sF4 = vgg4:forward(style):clone() 528 | vgg4 = nil 529 | --local csF4 = feature_swap(cF4,sF4) 530 | local csF4 = feature_wct(cF4, sF4) 531 | csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4 532 | local Im4 = decoder4:forward(csF4) 533 | decoder4 = nil 534 | 535 | local cF3 = vgg3:forward(Im4):clone() 536 | local sF3 = vgg3:forward(style):clone() 537 | vgg3 = nil 538 | local csF3 = feature_wct(cF3, sF3) 539 | csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3 540 | 541 | local Im3 = decoder3:forward(csF3) 542 | decoder3 = nil 543 | local cF2 = vgg2:forward(Im3):clone() 544 | local sF2 = vgg2:forward(style):clone() 545 | vgg2 = nil 546 | 547 | local csF2 = feature_wct(cF2, sF2) 548 | csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2 549 | local Im2 = decoder2:forward(csF2) 550 | decoder2 = nil 551 | 552 | local cF1 = vgg1:forward(Im2):clone() 553 | local sF1 = vgg1:forward(style):clone() 554 | vgg1 = nil 555 | local csF1 = feature_wct(cF1, sF1) 556 | csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1 557 | local Im1 = decoder1:forward(csF1) 558 | decoder1 = nil 559 | return Im1 560 | end 561 | 562 | 563 | local function styleTransfer_adaIn(content, style) 564 | loadModel() 565 | 566 | print('Start AdaIn') 567 | 568 | content = content:cuda() 569 | style = style:cuda() 570 | 571 | local cF5 = vgg5:forward(content):clone() 572 | local sF5 = vgg5:forward(style):clone() 573 | vgg5 = nil 574 | csF5 = adain5:forward({cF5, sF5}):squeeze() 575 | csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5 576 | local Im5 = decoder5:forward(csF5) 577 | decoder5 = nil 578 | 579 | local cF4 = vgg4:forward(Im5):clone() 580 | local sF4 = vgg4:forward(style):clone() 581 | vgg4 = nil 582 | local csF4 = adain4:forward({cF4, sF4}):squeeze() 583 | csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4 584 | local Im4 = decoder4:forward(csF4) 585 | decoder4 = nil 586 | 587 | local cF3 = vgg3:forward(Im4):clone() 588 | local sF3 = vgg3:forward(style):clone() 589 | vgg3 = nil 590 | local csF3 = adain3:forward({cF3, sF3}):squeeze() 591 | csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3 592 | local Im3 = decoder3:forward(csF3) 593 | decoder3 = nil 594 | 595 | local cF2 = vgg2:forward(Im3):clone() 596 | local sF2 = vgg2:forward(style):clone() 597 | vgg2 = nil 598 | local csF2 = adain2:forward({cF2, sF2}):squeeze() 599 | csF2 = 
opt.alpha * csF2 + (1.0-opt.alpha) * cF2 600 | local Im2 = decoder2:forward(csF2) 601 | decoder2 = nil 602 | 603 | local cF1 = vgg1:forward(Im2):clone() 604 | local sF1 = vgg1:forward(style):clone() 605 | vgg1 = nil 606 | local csF1 = adain1:forward({cF1, sF1}):squeeze() 607 | csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1 608 | local Im1 = decoder1:forward(csF1) 609 | decoder1 = nil 610 | return Im1 611 | end 612 | 613 | local function styleTransfer_clamp(content, style) 614 | 615 | loadModel() 616 | 617 | local cSz = content:size() 618 | local sSz = style:size() 619 | 620 | content = content:cuda() 621 | style = style:cuda() 622 | 623 | --[[ 624 | local cF5 = vgg5:forward(content):clone() 625 | local sF5 = vgg5:forward(style):clone() 626 | vgg5 = nil 627 | csF5 = feature_clamp(cF5,sF5) 628 | csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5 629 | local Im5 = decoder5:forward(csF5):clone() 630 | decoder5 = nil 631 | --]] 632 | 633 | vgg5 = nil 634 | decoder5 = nil 635 | local Im5 = content 636 | 637 | Im5 = image.scale(Im5:float(),cSz[3],cSz[2]) 638 | Im5 = Im5:cuda() 639 | local cF4 = vgg4:forward(Im5):clone() 640 | local sF4 = vgg4:forward(style):clone() 641 | vgg4 = nil 642 | --local csF4 = feature_clamp(cF4,sF4) 643 | local csF4 = feature_blend(cF4,sF4,0.8) 644 | csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4 645 | local Im4 = decoder4:forward(csF4):clone() 646 | decoder4 = nil 647 | 648 | Im4 = image.scale(Im4:float(),cSz[3],cSz[2]) 649 | Im4 = Im4:cuda() 650 | local cF3 = vgg3:forward(Im4):clone() 651 | local sF3 = vgg3:forward(style):clone() 652 | vgg3 = nil 653 | --local csF3 = feature_clamp(cF3,sF3) 654 | local csF3 = feature_blend(cF3,sF3,0.7) 655 | csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3 656 | local Im3 = decoder3:forward(csF3):clone() 657 | decoder3 = nil 658 | 659 | Im3 = image.scale(Im3:float(),cSz[3],cSz[2]) 660 | Im3 = Im3:cuda() 661 | local cF2 = vgg2:forward(Im3):clone() 662 | local sF2 = vgg2:forward(style):clone() 663 | vgg2 = nil 664 | --local csF2 = feature_clamp(cF2,sF2) 665 | local csF2 = feature_blend(cF2,sF2,0.6) 666 | csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2 667 | local Im2 = decoder2:forward(csF2):clone() 668 | decoder2 = nil 669 | 670 | Im2 = image.scale(Im2:float(),cSz[3],cSz[2]) 671 | Im2 = Im2:cuda() 672 | local cF1 = vgg1:forward(Im2):clone() 673 | local sF1 = vgg1:forward(style):clone() 674 | vgg1 = nil 675 | --local csF1 = feature_clamp(cF1,sF1) 676 | local csF1 = feature_blend(cF1,sF1,0.3) 677 | csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1 678 | local Im1 = decoder1:forward(csF1):clone() 679 | decoder1 = nil 680 | Im1 = image.scale(Im1:float(),cSz[3],cSz[2]) 681 | Im1 = Im1:cuda() 682 | 683 | return Im1 684 | end 685 | 686 | 687 | local function styleTransfer_mk(content, style) 688 | 689 | loadModel() 690 | 691 | print('Start MK') 692 | 693 | content = content:cuda() 694 | style = style:cuda() 695 | local cF5 = vgg5:forward(content):clone() 696 | local sF5 = vgg5:forward(style):clone() 697 | vgg5 = nil 698 | local csF5 = nil 699 | csF5 = feature_mk2(cF5, sF5) 700 | csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5 701 | local Im5 = decoder5:forward(csF5) 702 | decoder5 = nil 703 | 704 | local cF4 = vgg4:forward(Im5):clone() 705 | local sF4 = vgg4:forward(style):clone() 706 | vgg4 = nil 707 | local csF4 = feature_mk2(cF4, sF4) 708 | csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4 709 | local Im4 = decoder4:forward(csF4) 710 | decoder4 = nil 711 | 712 | local cF3 = vgg3:forward(Im4):clone() 713 | local sF3 = vgg3:forward(style):clone() 714 
| vgg3 = nil 715 | local csF3 = feature_mk2(cF3, sF3) 716 | csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3 717 | 718 | local Im3 = decoder3:forward(csF3) 719 | decoder3 = nil 720 | local cF2 = vgg2:forward(Im3):clone() 721 | local sF2 = vgg2:forward(style):clone() 722 | vgg2 = nil 723 | 724 | local csF2 = feature_mk2(cF2, sF2) 725 | csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2 726 | local Im2 = decoder2:forward(csF2) 727 | decoder2 = nil 728 | 729 | local cF1 = vgg1:forward(Im2):clone() 730 | local sF1 = vgg1:forward(style):clone() 731 | vgg1 = nil 732 | local csF1 = feature_mk2(cF1, sF1) 733 | csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1 734 | local Im1 = decoder1:forward(csF1) 735 | decoder1 = nil 736 | return Im1 737 | end 738 | 739 | 740 | local function styleTransfer_mk_sem(content, style) 741 | 742 | loadModel() 743 | content = content:cuda() 744 | style = style:cuda() 745 | 746 | --/////// 747 | local cF5 = vgg5:forward(content):clone() 748 | local sF5 = vgg5:forward(style):clone() 749 | vgg5 = nil 750 | local maskC = image.scale(masks.cMask,cF5:size(3),cF5:size(2),'simple') 751 | local maskS = image.scale(masks.sMask,sF5:size(3),sF5:size(2),'simple') 752 | local csF5 = nil 753 | csF5 = feature_mk3_sem(cF5, sF5,maskC,maskS) 754 | csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5 755 | local Im5 = decoder5:forward(csF5) 756 | decoder5 = nil 757 | 758 | --////// 759 | local cF4 = vgg4:forward(Im5):clone() 760 | local sF4 = vgg4:forward(style):clone() 761 | vgg4 = nil 762 | maskC = image.scale(masks.cMask,cF4:size(3),cF4:size(2),'simple') 763 | maskS = image.scale(masks.sMask,sF4:size(3),sF4:size(2),'simple') 764 | local csF4 = feature_mk3_sem(cF4, sF4,maskC,maskS) 765 | csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4 766 | local Im4 = decoder4:forward(csF4) 767 | decoder4 = nil 768 | 769 | --////// 770 | local cF3 = vgg3:forward(Im4):clone() 771 | local sF3 = vgg3:forward(style):clone() 772 | vgg3 = nil 773 | maskC = image.scale(masks.cMask,cF3:size(3),cF3:size(2),'simple') 774 | maskS = image.scale(masks.sMask,sF3:size(3),sF3:size(2),'simple') 775 | local csF3 = feature_mk3_sem(cF3, sF3,maskC,maskS) 776 | csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3 777 | local Im3 = decoder3:forward(csF3) 778 | decoder3 = nil 779 | 780 | --/////// 781 | local cF2 = vgg2:forward(Im3):clone() 782 | local sF2 = vgg2:forward(style):clone() 783 | vgg2 = nil 784 | maskC = image.scale(masks.cMask,cF2:size(3),cF2:size(2),'simple') 785 | maskS = image.scale(masks.sMask,sF2:size(3),sF2:size(2),'simple') 786 | local csF2 = feature_mk3_sem(cF2, sF2,maskC,maskS) 787 | csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2 788 | local Im2 = decoder2:forward(csF2) 789 | decoder2 = nil 790 | 791 | --////// 792 | local cF1 = vgg1:forward(Im2):clone() 793 | local sF1 = vgg1:forward(style):clone() 794 | vgg1 = nil 795 | maskC = image.scale(masks.cMask,cF1:size(3),cF1:size(2),'simple') 796 | maskS = image.scale(masks.sMask,sF1:size(3),sF1:size(2),'simple') 797 | local csF1 = feature_mk3_sem(cF1, sF1,maskC,maskS) 798 | csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1 799 | local Im1 = decoder1:forward(csF1) 800 | decoder1 = nil 801 | 802 | return Im1 803 | end 804 | 805 | --///////////////////////////////////////////////// 806 | -- main functions 807 | --///////////////////////////////////////////////// 808 | print('Creating save folder at ' .. 
opt.outputDir) 809 | paths.mkdir(opt.outputDir) 810 | 811 | local contentPath = opt.content 812 | local contentExt = paths.extname(contentPath) 813 | local contentName = paths.basename(contentPath,contentExt) 814 | local contentImg = image.load(contentPath, 3, 'float') 815 | contentImg = sizePreprocess(contentImg, opt.contentSize) 816 | 817 | local stylePath = opt.style 818 | local styleExt = paths.extname(stylePath) 819 | local styleName = paths.basename(stylePath,styleExt) 820 | local styleImg = image.load(stylePath, 3, 'float') 821 | styleImg = sizePreprocess(styleImg, opt.styleSize) 822 | 823 | local output = styleTransfer_mk(contentImg, styleImg) 824 | --local output = styleTransfer_wct(contentImg,styleImg) 825 | --local output = styleTransfer_adaIn(contentImg, styleImg) 826 | 827 | local savePath = paths.concat(opt.outputDir, contentName .. '_stylized_by_' .. styleName .. '.jpg') 828 | print('Output image saved at: ' .. savePath) 829 | image.save(savePath, output) 830 | 831 | -------------------------------------------------------------------------------- /optimal_folder.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'image' 4 | require 'paths' 5 | require 'lib/NonparametricPatchAutoencoderFactory' 6 | require 'lib/MaxCoord' 7 | require 'lib/utils' 8 | require 'lib/AdaptiveInstanceNormalization' 9 | require 'nngraph' 10 | require 'cudnn' 11 | require 'cunn' 12 | 13 | local matio = require 'matio' 14 | local cmd = torch.CmdLine() 15 | 16 | cmd:option('-style', 'input/portrait_10.jpg', 'path to the style image') 17 | cmd:option('-content', 'input/13960.jpg', 'path to the content image') 18 | cmd:option('-alpha', 0.6) 19 | cmd:option('-patchSize', 3) 20 | cmd:option('-patchStride', 1) 21 | cmd:option('-vgg1', 'models/conv1_1.t7', 'Path to the VGG conv1_1') 22 | cmd:option('-vgg2', 'models/conv2_1.t7', 'Path to the VGG conv2_1') 23 | cmd:option('-vgg3', 'models/conv3_1.t7', 'Path to the VGG conv3_1') 24 | cmd:option('-vgg4', 'models/conv4_1.t7', 'Path to the VGG conv4_1') 25 | cmd:option('-vgg5', 'models/conv5_1.t7', 'Path to the VGG conv5_1') 26 | cmd:option('-decoder5', 'models/dec5_1.t7', 'Path to the decoder5') 27 | cmd:option('-decoder4', 'models/dec4_1.t7', 'Path to the decoder4') 28 | cmd:option('-decoder3', 'models/dec3_1.t7', 'Path to the decoder3') 29 | cmd:option('-decoder2', 'models/dec2_1.t7', 'Path to the decoder2') 30 | cmd:option('-decoder1', 'models/dec1_1.t7', 'Path to the decoder1') 31 | 32 | cmd:option('-contentSize', 768, 'New (minimum) size for the content image, keeping the original size if set to 0') 33 | cmd:option('-styleSize', 768, 'New (minimum) size for the style image, keeping the original size if set to 0') 34 | cmd:option('-outputDir', 'output/alley_1', 'Directory to save the output image(s)') 35 | 36 | opt = cmd:parse(arg) 37 | 38 | --///////////////////////////////////////////////////////////////////// 39 | function loadModel() 40 | vgg1 = torch.load(opt.vgg1) 41 | vgg2 = torch.load(opt.vgg2) 42 | vgg3 = torch.load(opt.vgg3) 43 | vgg4 = torch.load(opt.vgg4) 44 | vgg5 = torch.load(opt.vgg5) 45 | 46 | decoder5 = torch.load(opt.decoder5) 47 | decoder4 = torch.load(opt.decoder4) 48 | decoder3 = torch.load(opt.decoder3) 49 | decoder2 = torch.load(opt.decoder2) 50 | decoder1 = torch.load(opt.decoder1) 51 | 52 | adain5 = nn.AdaptiveInstanceNormalization(vgg5:get(#vgg5-1).nOutputPlane) 53 | adain4 = nn.AdaptiveInstanceNormalization(vgg4:get(#vgg4-1).nOutputPlane) 54 | adain3 
= nn.AdaptiveInstanceNormalization(vgg3:get(#vgg3-1).nOutputPlane) 55 | adain2 = nn.AdaptiveInstanceNormalization(vgg2:get(#vgg2-1).nOutputPlane) 56 | adain1 = nn.AdaptiveInstanceNormalization(vgg1:get(#vgg1-1).nOutputPlane) 57 | 58 | print('GPU mode') 59 | vgg1:cuda() 60 | vgg2:cuda() 61 | vgg3:cuda() 62 | vgg4:cuda() 63 | vgg5:cuda() 64 | 65 | adain5:cuda() 66 | adain4:cuda() 67 | adain3:cuda() 68 | adain2:cuda() 69 | adain1:cuda() 70 | 71 | decoder1:cuda() 72 | decoder2:cuda() 73 | decoder3:cuda() 74 | decoder4:cuda() 75 | decoder5:cuda() 76 | end 77 | 78 | --///////////////////////////////////////////////////////////////////// 79 | function normalize_features(x) 80 | local x2 = torch.pow(x, 2) 81 | local sum_x2 = torch.sum(x2, 1) 82 | local dis_x2 = torch.sqrt(sum_x2) 83 | local Nx = torch.cdiv(x, dis_x2:expandAs(x) + 1e-8) 84 | -- local Nx = torch.cdiv(x, dis_x2:expandAs(x)) 85 | dis_x2 = (dis_x2-torch.min(dis_x2))/(torch.max(dis_x2)-torch.min(dis_x2)) 86 | return Nx,dis_x2 87 | end 88 | 89 | --//////////////////////////////////////////////////////////////////// 90 | function whitenMatrix(featureIn) 91 | local feature = featureIn:clone() -- c x hw 92 | local sz = feature:size() 93 | local ft_mean = torch.mean(feature,2) 94 | feature = feature - ft_mean:expandAs(feature) 95 | local ft_std = torch.std(feature,2) 96 | local ft_conv = torch.mm(feature,feature:t()):div(sz[2]-1) 97 | local u,e,v = torch.svd(ft_conv:float(),'A') 98 | local k_c = sz[1] 99 | for i=1,sz[1] do 100 | if e[i]<0.00001 then 101 | k_c = i-1 102 | break 103 | end 104 | end 105 | local d = e[{{1,k_c}}]:sqrt():pow(-1) 106 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 107 | return m:cuda(),ft_mean:cuda(),ft_std:cuda() 108 | end 109 | 110 | --/////////////////////////////////////////////////////////////////// 111 | function colorMatrix(featureIn) 112 | local feature = featureIn:clone() 113 | local sz = feature:size() 114 | local ft_mean = torch.mean(feature,2) 115 | feature = feature - ft_mean:expandAs(feature) 116 | local ft_std = torch.std(feature,2) 117 | local ft_conv = torch.mm(feature,feature:t()):div(sz[2]-1) 118 | local u,e,v = torch.svd(ft_conv:float(),'A') 119 | local k_c = sz[1] 120 | for i=1,sz[1] do 121 | if e[i]<0.00001 then 122 | k_c = i-1 123 | break 124 | end 125 | end 126 | local d = e[{{1,k_c}}]:sqrt() 127 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 128 | return m:cuda(),ft_mean:cuda(),ft_std:cuda() 129 | end 130 | 131 | --//////////////////////////////////////////////////////////////////// 132 | function sqrtInvMatrix(mtx) -- cxc 133 | local sz = mtx:size() 134 | local u,e,v = torch.svd(mtx:float(),'A') 135 | local k_c = sz[1] 136 | for i=1,sz[1] do 137 | if e[i]<0.00001 then 138 | k_c = i-1 139 | break 140 | end 141 | end 142 | local d = e[{{1,k_c}}]:sqrt():pow(-1) 143 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 144 | return m:cuda() 145 | end 146 | --////////////////////////////////////////////////////////////////// 147 | function invMatrix(mtx) -- cxc 148 | local sz = mtx:size() 149 | local u,e,v = torch.svd(mtx:float(),'A') 150 | local k_c = sz[1] 151 | for i=1,sz[1] do 152 | if e[i]<0.00001 then 153 | k_c = i-1 154 | break 155 | end 156 | end 157 | local d = e[{{1,k_c}}]:pow(-1) 158 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 159 | return m:cuda() 160 | end 161 | 162 | 
--/////////////////////////////////////////////////////////////////// 163 | function sqrtMatrix(mtx) 164 | local sz = mtx:size() 165 | local u,e,v = torch.svd(mtx:float(),'A') 166 | local k_c = sz[1] 167 | for i=1,sz[1] do 168 | if e[i]<0.00001 then 169 | k_c = i-1 170 | break 171 | end 172 | end 173 | local d = e[{{1,k_c}}]:sqrt() 174 | local m = (v[{{},{1,k_c}}]:cuda())*torch.diag(d:cuda())*(v[{{},{1,k_c}}]:t():cuda()) 175 | return m:cuda() 176 | end 177 | 178 | --///////////////////////////////////////////////////////////////////////////////////// 179 | function feature_swap(contentFeature, styleFeature) 180 | 181 | local sg = contentFeature:size() 182 | local contentFeature1 = contentFeature:view(sg[1], sg[2]*sg[3]) 183 | local c_mean = torch.mean(contentFeature1, 2) 184 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 185 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(sg[2]*sg[3]-1) 186 | local c_u, c_e, c_v = torch.svd(contentCov:float(), 'A') 187 | local k_c = sg[1] 188 | for i=1, sg[1] do 189 | if c_e[i] < 0.00001 then 190 | k_c = i-1 191 | break 192 | end 193 | end 194 | 195 | local sz = styleFeature:size() 196 | local styleFeature1 = styleFeature:view(sz[1], sz[2]*sz[3]) 197 | local s_mean = torch.mean(styleFeature1, 2) 198 | styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1) 199 | local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sz[2]*sz[3]-1) 200 | local s_u, s_e, s_v = torch.svd(styleCov:float(), 'A') 201 | local k_s = sz[1] 202 | for i=1, sz[1] do 203 | if s_e[i] < 0.00001 then 204 | k_s = i-1 205 | break 206 | end 207 | end 208 | 209 | local s_d = torch.sqrt(s_e[{{1,k_s}}]):pow(-1) 210 | local whiten_styleFeature = nil 211 | whiten_styleFeature = (s_v[{{},{1,k_s}}]:cuda()) * torch.diag(s_d:cuda()) * (s_v[{{},{1,k_s}}]:t():cuda()) * styleFeature1 212 | local swap_enc, swap_dec = NonparametricPatchAutoencoderFactory.buildAutoencoder(whiten_styleFeature:resize(sz[1], sz[2], sz[3]), opt.patchSize, opt.patchStride, false, false, true) 213 | local swap = nn.Sequential() 214 | swap:add(swap_enc) 215 | swap:add(nn.MaxCoord()) 216 | swap:add(swap_dec) 217 | swap:evaluate() 218 | swap:cuda() 219 | local c_d = torch.sqrt(c_e[{{1,k_c}}]):pow(-1) 220 | local s_d1 = torch.sqrt(s_e[{{1,k_s}}]) 221 | local whiten_contentFeature = nil 222 | local targetFeature = nil 223 | whiten_contentFeature = (c_v[{{},{1,k_c}}]:cuda()) * torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) *contentFeature1 224 | local swap_latent = swap:forward(whiten_contentFeature:resize(sg[1], sg[2], sg[3])):clone() 225 | local swap_latent1 = swap_latent:view(sg[1], sg[2]*sg[3]) 226 | targetFeature = (s_v[{{},{1,k_s}}]:cuda()) * (torch.diag(s_d1:cuda())) * (s_v[{{},{1,k_s}}]:t():cuda()) * swap_latent1 227 | targetFeature = targetFeature + s_mean:expandAs(targetFeature) 228 | local tFeature = targetFeature:resize(sg[1], sg[2], sg[3]) 229 | return tFeature 230 | end 231 | 232 | --/////////////////////////////////////////////////////////////////////// 233 | function feature_wct(contentFeature, styleFeature) 234 | 235 | local sg = contentFeature:size() 236 | local contentFeature1 = contentFeature:view(sg[1], sg[2]*sg[3]) 237 | local c_mean = torch.mean(contentFeature1, 2) 238 | contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1) 239 | local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(sg[2]*sg[3]-1) 240 | local c_u, c_e, c_v = torch.svd(contentCov:float(), 'A') 241 | local k_c = sg[1] 242 | 243 | for i=1, sg[1] 
--/////////////////////////////////////////////////////////////////////////////////////
-- Style-swap variant: whiten both feature maps, patch-swap the whitened content
-- against the whitened style, then re-color with the style covariance.
function feature_swap(contentFeature, styleFeature)
    local sg = contentFeature:size()
    local contentFeature1 = contentFeature:view(sg[1], sg[2]*sg[3])
    local c_mean = torch.mean(contentFeature1, 2)
    contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1)
    local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(sg[2]*sg[3]-1)
    local c_u, c_e, c_v = torch.svd(contentCov:float(), 'A')
    local k_c = sg[1]
    for i = 1, sg[1] do
        if c_e[i] < 0.00001 then
            k_c = i-1
            break
        end
    end

    local sz = styleFeature:size()
    local styleFeature1 = styleFeature:view(sz[1], sz[2]*sz[3])
    local s_mean = torch.mean(styleFeature1, 2)
    styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1)
    local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sz[2]*sz[3]-1)
    local s_u, s_e, s_v = torch.svd(styleCov:float(), 'A')
    local k_s = sz[1]
    for i = 1, sz[1] do
        if s_e[i] < 0.00001 then
            k_s = i-1
            break
        end
    end

    -- whiten the style feature and build the patch-swap autoencoder from it
    local s_d = torch.sqrt(s_e[{{1,k_s}}]):pow(-1)
    local whiten_styleFeature = (s_v[{{},{1,k_s}}]:cuda()) * torch.diag(s_d:cuda()) * (s_v[{{},{1,k_s}}]:t():cuda()) * styleFeature1
    local swap_enc, swap_dec = NonparametricPatchAutoencoderFactory.buildAutoencoder(whiten_styleFeature:resize(sz[1], sz[2], sz[3]), opt.patchSize, opt.patchStride, false, false, true)
    local swap = nn.Sequential()
    swap:add(swap_enc)
    swap:add(nn.MaxCoord())
    swap:add(swap_dec)
    swap:evaluate()
    swap:cuda()

    -- whiten the content feature, swap patches, then color with style statistics
    local c_d = torch.sqrt(c_e[{{1,k_c}}]):pow(-1)
    local s_d1 = torch.sqrt(s_e[{{1,k_s}}])
    local whiten_contentFeature = (c_v[{{},{1,k_c}}]:cuda()) * torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) * contentFeature1
    local swap_latent = swap:forward(whiten_contentFeature:resize(sg[1], sg[2], sg[3])):clone()
    local swap_latent1 = swap_latent:view(sg[1], sg[2]*sg[3])
    local targetFeature = (s_v[{{},{1,k_s}}]:cuda()) * torch.diag(s_d1:cuda()) * (s_v[{{},{1,k_s}}]:t():cuda()) * swap_latent1
    targetFeature = targetFeature + s_mean:expandAs(targetFeature)
    local tFeature = targetFeature:resize(sg[1], sg[2], sg[3])
    return tFeature
end

--///////////////////////////////////////////////////////////////////////
-- Whitening-and-coloring transform (WCT): whiten the content feature to identity
-- covariance, then color it with the style covariance and add the style mean.
function feature_wct(contentFeature, styleFeature)
    local sg = contentFeature:size()
    local contentFeature1 = contentFeature:view(sg[1], sg[2]*sg[3])
    local c_mean = torch.mean(contentFeature1, 2)
    contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1)
    local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(sg[2]*sg[3]-1)
    local c_u, c_e, c_v = torch.svd(contentCov:float(), 'A')
    local k_c = sg[1]
    for i = 1, sg[1] do
        if c_e[i] < 0.00001 then
            k_c = i-1
            break
        end
    end
    --k_c = sg[1]

    local sz = styleFeature:size()
    local styleFeature1 = styleFeature:view(sz[1], sz[2]*sz[3])
    local s_mean = torch.mean(styleFeature1, 2)
    styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1)
    local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sz[2]*sz[3]-1)
    local s_u, s_e, s_v = torch.svd(styleCov:float(), 'A')
    local k_s = sz[1]
    for i = 1, sz[1] do
        if s_e[i] < 0.00001 then
            k_s = i-1
            break
        end
    end
    --k_s = sz[1]

    -- (disabled experiment: multiply the whitened feature by a fixed orthogonal Q)
    --[[
    curQ = nil
    if cur_idx == 1 then
        curQ = Q.Q1
    elseif cur_idx == 2 then
        curQ = Q.Q2
    elseif cur_idx == 3 then
        curQ = Q.Q3
    elseif cur_idx == 4 then
        curQ = Q.Q4
    elseif cur_idx == 5 then
        curQ = Q.Q5
    end
    curQ = curQ:cuda()
    print('current idx = ' .. tostring(cur_idx))
    cur_idx = cur_idx - 1
    --]]

    local c_d = c_e[{{1,k_c}}]:sqrt():pow(-1)
    local s_d1 = s_e[{{1,k_s}}]:sqrt()
    local whiten_contentFeature = nil

    -- ZCA whitening
    whiten_contentFeature = (c_v[{{},{1,k_c}}]:cuda()) * torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) * contentFeature1

    -- PCA whitening
    --whiten_contentFeature = torch.diag(c_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()) * contentFeature1

    -- Cholesky whitening
    --[[
    local chol_d = c_e[{{1,k_c}}]:pow(-1)
    whiten_M = (c_v[{{},{1,k_c}}]:cuda()) * (torch.diag(chol_d:cuda()) * (c_v[{{},{1,k_c}}]:t():cuda()))
    whiten_M = whiten_M:float()
    whiten_M = torch.potrf(whiten_M, 'L')
    whiten_M = whiten_M:t():cuda()
    whiten_contentFeature = whiten_M * contentFeature1 -- CxN
    --]]

    -- ZCA-cor and PCA-cor whitening
    --[[
    V_std = torch.std(contentFeature1, 2):squeeze()
    V_sqrt = torch.diag(V_std)
    V_sqrt_inv = invMatrix(V_sqrt)
    V_sqrt_inv = V_sqrt_inv:float()
    P = V_sqrt_inv * contentCov:float() * V_sqrt_inv
    G, Theta, Gt = torch.svd(P, 'A')
    G_d = Theta[{{1,k_c}}]:sqrt():pow(-1)
    whiten_M = (G[{{},{1,k_c}}]:cuda()) * torch.diag(G_d:cuda()) * (G[{{},{1,k_c}}]:t():cuda()) * V_sqrt:cuda()
    --whiten_M = torch.diag(G_d:cuda()) * (G[{{},{1,k_c}}]:t():cuda()) * V_sqrt:cuda()
    whiten_contentFeature = whiten_M * contentFeature1 -- CxN
    --]]

    --whiten_contentFeature = curQ * whiten_contentFeature

    -- coloring: re-impose the style covariance, then restore the style mean
    local targetFeature = (s_v[{{},{1,k_s}}]:cuda()) * torch.diag(s_d1:cuda()) * (s_v[{{},{1,k_s}}]:t():cuda()) * whiten_contentFeature
    targetFeature = targetFeature + s_mean:expandAs(targetFeature)
    local tFeature = targetFeature:resize(sg[1], sg[2], sg[3])
    return tFeature
end
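-- In matrix form, with E*D*E^T the eigendecomposition of each covariance, WCT computes
--   f_cs = E_s * D_s^(1/2) * E_s^T * ( E_c * D_c^(-1/2) * E_c^T * (f_c - mu_c) ) + mu_s
-- Toy usage (a sketch, assuming cutorch is initialized; the two feature maps may have
-- different spatial sizes since only the CxC covariances interact):
--[[
local cF = torch.rand(64, 16, 16):cuda()   -- stand-in for a CxHxW VGG feature map
local sF = torch.rand(64, 20, 20):cuda()
local csF = feature_wct(cF, sF)            -- content layout, style statistics
--]]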
--////////////////////////////////////////////////////////////////////
-- Closed-form Monge-Kantorovich (optimal transport) transform between the content
-- and style feature Gaussians, computed with torch.eig.
function feature_mk(contentFeature, styleFeature)
    local eps = 1e-10
    local cDim = contentFeature:size()
    local contentFeature1 = contentFeature:view(cDim[1], cDim[2]*cDim[3]) -- CxHW
    local c_mean = torch.mean(contentFeature1, 2)
    contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1)
    local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(cDim[2]*cDim[3]-1) -- CxC

    local sDim = styleFeature:size()
    local styleFeature1 = styleFeature:view(sDim[1], sDim[2]*sDim[3]) -- CxHW
    local s_mean = torch.mean(styleFeature1, 2)
    styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1)
    local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sDim[2]*sDim[3]-1) -- CxC

    -- eigendecomposition of the content covariance; torch.eig returns the
    -- eigenvalues as an mx2 tensor (real/imaginary parts) and V (mxm)
    local Da2, Ua = torch.eig(contentCov:float(), 'V')
    Ua = Ua:t()                           -- eigenvectors as columns
    Da2 = Da2[{{},{1}}]:squeeze():cuda()  -- keep the real parts
    Da2 = torch.diag(Da2)
    Da2[torch.lt(Da2, 0)] = 0             -- clamp negative eigenvalues
    Da2 = Da2 + eps                       -- regularize to keep the matrix invertible
    local Da = Da2:sqrt():cuda()          -- CxC
    Ua = Ua:cuda()

    styleCov = styleCov:cuda()
    local C = Da*Ua:t()*styleCov*Ua*Da

    local Dc2, Uc = torch.eig(C:float(), 'V')
    Uc = Uc:t()
    Dc2 = Dc2[{{},{1}}]:squeeze():cuda()
    Dc2 = torch.diag(Dc2)
    Dc2[torch.lt(Dc2, 0)] = 0
    Dc2 = Dc2 + eps
    local Dc = Dc2:sqrt()
    Uc = Uc:cuda()

    local Da_inv = Da:pow(-1)   -- note: in-place; Da is not used again below

    local T = Ua*Da_inv*Uc*Dc*Uc:t()*Da_inv*Ua:t() -- CxC transport matrix

    local targetFeature = T*contentFeature1
    targetFeature = targetFeature + s_mean:expandAs(targetFeature)
    local resFeature = targetFeature:resize(cDim[1], cDim[2], cDim[3])
    return resFeature
end

--////////////////////////////////////////////////////////////////////
-- Same Monge-Kantorovich transform as feature_mk, written with the SVD-based
-- matrix square-root helpers: T = U^(-1/2) * (U^(1/2) * S * U^(1/2))^(1/2) * U^(-1/2),
-- with U the content covariance and S the style covariance.
function feature_mk2(contentFeature, styleFeature)
    local eps = 1e-10
    local cDim = contentFeature:size()
    local contentFeature1 = contentFeature:view(cDim[1], cDim[2]*cDim[3]) -- CxHW
    local c_mean = torch.mean(contentFeature1, 2)
    contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1)
    local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(cDim[2]*cDim[3]-1) -- CxC

    local sDim = styleFeature:size()
    local styleFeature1 = styleFeature:view(sDim[1], sDim[2]*sDim[3]) -- CxHW
    local s_mean = torch.mean(styleFeature1, 2)
    styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1)
    local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sDim[2]*sDim[3]-1) -- CxC

    local sqrtInvU = sqrtInvMatrix(contentCov)
    local sqrtU = sqrtMatrix(contentCov)
    local C = sqrtU*styleCov*sqrtU
    local sqrtC = sqrtMatrix(C)
    local T = sqrtInvU*sqrtC*sqrtInvU
    local targetFeature = T*contentFeature1
    targetFeature = targetFeature + s_mean:expandAs(targetFeature)
    local resFeature = targetFeature:resize(cDim[1], cDim[2], cDim[3])
    return resFeature
end

--//////////////////////////////////////////////////////////////////
-- Monge-Kantorovich transform for features already flattened to CxN
-- (used by the semantic-mask variant below).
function feature_mk3(contentFeature, styleFeature)
    local eps = 1e-10
    local cDim = contentFeature:size() -- CxN
    local contentFeature1 = contentFeature -- CxN
    local c_mean = torch.mean(contentFeature1, 2)
    contentFeature1 = contentFeature1 - c_mean:expandAs(contentFeature1)
    local contentCov = torch.mm(contentFeature1, contentFeature1:t()):div(cDim[2]-1) -- CxC

    local sDim = styleFeature:size() -- CxN
    local styleFeature1 = styleFeature -- CxN
    local s_mean = torch.mean(styleFeature1, 2)
    styleFeature1 = styleFeature1 - s_mean:expandAs(styleFeature1)
    local styleCov = torch.mm(styleFeature1, styleFeature1:t()):div(sDim[2]-1) -- CxC

    local sqrtInvU = sqrtInvMatrix(contentCov)
    local sqrtU = sqrtMatrix(contentCov)
    local C = sqrtU*styleCov*sqrtU
    local sqrtC = sqrtMatrix(C)
    local T = sqrtInvU*sqrtC*sqrtInvU
    local targetFeature = T*contentFeature1
    targetFeature = targetFeature + s_mean:expandAs(targetFeature)
    local resFeature = targetFeature
    return resFeature -- CxN
end
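-- The MK map is the symmetric positive definite T with T*U_c*T = U_s, i.e. the
-- optimal transport map between the two feature Gaussians: the output approximately
-- matches the style covariance while moving the content features the least in L2.
-- Quick covariance check (a sketch, assuming cutorch is initialized):
--[[
local c = torch.rand(32, 24, 24):cuda()
local s = torch.rand(32, 24, 24):cuda()
local t = feature_mk2(c, s):view(32, 24*24)
local m = torch.mean(t, 2)
t = t - m:expandAs(t)
local cov = torch.mm(t, t:t()):div(24*24 - 1)  -- should approximate the style covariance
--]]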
-- Semantic-region MK transfer: each of the five label regions (labels 0..4) in the
-- masks is transferred independently with feature_mk3. torch.find comes from the
-- torchx extension package.
function feature_mk3_sem(contentFeature, styleFeature, maskC, maskS)
    local eps = 1e-10

    maskC = maskC:cuda()
    maskS = maskS:cuda()

    local cDim = contentFeature:size()
    local contentFeature1 = contentFeature:view(cDim[1], cDim[2]*cDim[3]) -- CxHW
    local sDim = styleFeature:size()
    local styleFeature1 = styleFeature:view(sDim[1], sDim[2]*sDim[3]) -- CxHW

    local cView = maskC:view(-1)
    local sView = maskS:view(-1)

    local targetFeature1 = contentFeature1:clone():zero()

    for k = 1, 5 do
        local cFg = torch.LongTensor(torch.find(cView, k-1))
        local sFg = torch.LongTensor(torch.find(sView, k-1))
        local cFt = contentFeature1:index(2, cFg):view(cDim[1], cFg:nElement())
        local sFt = styleFeature1:index(2, sFg):view(sDim[1], sFg:nElement())
        local tFt = feature_mk3(cFt, sFt)
        targetFeature1:indexCopy(2, cFg, tFt)
    end

    targetFeature1 = targetFeature1:viewAs(contentFeature)
    return targetFeature1
end

--///////////////////////////////////////////////////////////////////
-- Whiten both features, clamp the per-element gain of style over content into
-- [0.5, 1.0], and re-color the remapped content with the style statistics.
-- (Assumes the two feature maps share the same spatial size.)
function feature_clamp(contentFeature, styleFeature)
    -- visualize the first channel of both features
    --[[
    local cFt = contentFeature[{{1},{},{}}]:squeeze()
    local sFt = styleFeature[{{1},{},{}}]:squeeze()
    local disp = torch.cat(cFt, sFt)
    image.display(disp)
    --]]

    local sz_c = contentFeature:size()
    local sz_s = styleFeature:size()
    local contentFeatureView = contentFeature:view(sz_c[1], sz_c[2]*sz_c[3])
    local styleFeatureView = styleFeature:view(sz_s[1], sz_s[2]*sz_s[3])
    local cWhitenM, cWhitenMean, cWhitenStd = whitenMatrix(contentFeatureView)
    local sWhitenM, sWhitenMean, sWhitenStd = whitenMatrix(styleFeatureView)
    local sColorM, sColorMean, sColorStd = colorMatrix(styleFeatureView)
    -- whiten
    local contentWhiten = cWhitenM*(contentFeatureView - cWhitenMean:expandAs(contentFeatureView))
    local styleWhiten = sWhitenM*(styleFeatureView - sWhitenMean:expandAs(styleFeatureView))
    contentWhiten = contentWhiten:view(sz_c[1], sz_c[2], sz_c[3])
    styleWhiten = styleWhiten:view(sz_s[1], sz_s[2], sz_s[3])
    -- blend
    local gainMap = torch.cdiv(styleWhiten, contentWhiten)
    gainMap = torch.clamp(gainMap, 0.5, 1.0)
    local contentRemap = torch.cmul(contentWhiten, gainMap)
    contentRemap = contentRemap:view(sz_c[1], sz_c[2]*sz_c[3])
    contentRemap = sColorM*contentRemap + sColorMean:expandAs(contentRemap)
    contentRemap = contentRemap:view(sz_c[1], sz_c[2], sz_c[3])
    return contentRemap
end

--/////////////////////////////////////////////////////////////////////////
-- Blend content and style features with a binary gain map derived from the
-- per-location content feature norm: strong content activations keep (alpha of)
-- the content feature, weak ones take the style feature.
-- (Assumes the two feature maps share the same spatial size.)
function feature_blend(contentFeature, styleFeature, alpha)
    local szC = contentFeature:size()
    local szS = styleFeature:size()
    local contentFtView = contentFeature:view(szC[1], szC[2]*szC[3])
    local styleFtView = styleFeature:view(szS[1], szS[2]*szS[3])
    local contentFtN, contentFtD = normalize_features(contentFtView)
    local styleFtN, styleFtD = normalize_features(styleFtView)

    -- threshold the content feature norms into a binary saliency map
    contentFtD = contentFtD - 0.05
    contentFtD[contentFtD:lt(0.000001)] = 0.0
    contentFtD[contentFtD:gt(0.000001)] = 1.0
    local gainMap = contentFtD*alpha
    gainMap = gainMap:view(1, szC[2], szC[3])
    --image.display(gainMap:squeeze())
    gainMap = gainMap:expandAs(contentFeature)

    -- (disabled: smooth sigmoid gain instead of the hard threshold)
    --[[
    contentFtD = -300.0*(contentFtD - 0.05)
    local gainMap = torch.cinv(1 + torch.exp(contentFtD))
    gainMap = gainMap:view(1, szC[2], szC[3])
    image.display(gainMap:squeeze())
    gainMap = gainMap:expandAs(contentFeature)
    gainMap = alpha*gainMap
    --]]

    return torch.cmul(contentFeature, gainMap) + torch.cmul(styleFeature, 1 - gainMap)
end
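-- Toy usage of the gain-map blend (a sketch; normalize_features is defined earlier
-- in this file and is assumed to return the normalized features together with the
-- per-location norms used above):
--[[
local cF = torch.rand(64, 8, 8):cuda()
local sF = torch.rand(64, 8, 8):cuda()
local out = feature_blend(cF, sF, 0.8)  -- out = cF.*g + sF.*(1-g), g in {0, alpha}
--]]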
--///////////////////////////////////////////////////////////////////////////////////////////
-- Coarse-to-fine WCT pipeline: stylize at relu5_1 first, decode, then repeat the
-- transform down through relu1_1. Each encoder/decoder is released after use.
local function styleTransfer_wct(content, style)
    loadModel()

    print('Start wct')

    content = content:cuda()
    style = style:cuda()
    local cF5 = vgg5:forward(content):clone()
    local sF5 = vgg5:forward(style):clone()
    vgg5 = nil
    --local csF5 = feature_swap(cF5, sF5)
    local csF5 = feature_wct(cF5, sF5)
    csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5
    local Im5 = decoder5:forward(csF5)
    decoder5 = nil

    local cF4 = vgg4:forward(Im5):clone()
    local sF4 = vgg4:forward(style):clone()
    vgg4 = nil
    --local csF4 = feature_swap(cF4, sF4)
    local csF4 = feature_wct(cF4, sF4)
    csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4
    local Im4 = decoder4:forward(csF4)
    decoder4 = nil

    local cF3 = vgg3:forward(Im4):clone()
    local sF3 = vgg3:forward(style):clone()
    vgg3 = nil
    local csF3 = feature_wct(cF3, sF3)
    csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3
    local Im3 = decoder3:forward(csF3)
    decoder3 = nil

    local cF2 = vgg2:forward(Im3):clone()
    local sF2 = vgg2:forward(style):clone()
    vgg2 = nil
    local csF2 = feature_wct(cF2, sF2)
    csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2
    local Im2 = decoder2:forward(csF2)
    decoder2 = nil

    local cF1 = vgg1:forward(Im2):clone()
    local sF1 = vgg1:forward(style):clone()
    vgg1 = nil
    local csF1 = feature_wct(cF1, sF1)
    csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1
    local Im1 = decoder1:forward(csF1)
    decoder1 = nil
    return Im1
end

--////////////////////////////////////////////////////////////////////////////////////
-- Same coarse-to-fine cascade with adaptive instance normalization (AdaIN) as the
-- per-level feature transform.
local function styleTransfer_adaIn(content, style)
    loadModel()

    print('Start AdaIn')

    content = content:cuda()
    style = style:cuda()

    local cF5 = vgg5:forward(content):clone()
    local sF5 = vgg5:forward(style):clone()
    vgg5 = nil
    local csF5 = adain5:forward({cF5, sF5}):squeeze()
    csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5
    local Im5 = decoder5:forward(csF5)
    decoder5 = nil

    local cF4 = vgg4:forward(Im5):clone()
    local sF4 = vgg4:forward(style):clone()
    vgg4 = nil
    local csF4 = adain4:forward({cF4, sF4}):squeeze()
    csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4
    local Im4 = decoder4:forward(csF4)
    decoder4 = nil

    local cF3 = vgg3:forward(Im4):clone()
    local sF3 = vgg3:forward(style):clone()
    vgg3 = nil
    local csF3 = adain3:forward({cF3, sF3}):squeeze()
    csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3
    local Im3 = decoder3:forward(csF3)
    decoder3 = nil

    local cF2 = vgg2:forward(Im3):clone()
    local sF2 = vgg2:forward(style):clone()
    vgg2 = nil
    local csF2 = adain2:forward({cF2, sF2}):squeeze()
    csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2
    local Im2 = decoder2:forward(csF2)
    decoder2 = nil

    local cF1 = vgg1:forward(Im2):clone()
    local sF1 = vgg1:forward(style):clone()
    vgg1 = nil
    local csF1 = adain1:forward({cF1, sF1}):squeeze()
    csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1
    local Im1 = decoder1:forward(csF1)
    decoder1 = nil
    return Im1
end
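-- All of the styleTransfer_* routines repeat the same five-level pattern, varying
-- only the feature transform. A generic version might look like this (a sketch;
-- `encoders` and `decoders` are hypothetical tables holding what loadModel()
-- currently assigns to the globals vgg1..vgg5 and decoder1..decoder5):
--[[
local function styleTransfer_generic(content, style, transform)
    local img = content:cuda()
    style = style:cuda()
    for lv = 5, 1, -1 do
        local cF = encoders[lv]:forward(img):clone()
        local sF = encoders[lv]:forward(style):clone()
        local csF = opt.alpha * transform(cF, sF) + (1.0 - opt.alpha) * cF
        img = decoders[lv]:forward(csF)
    end
    return img
end
--]]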
--//////////////////////////////////////////////////////////////////////////////////////
-- Cascade with the gain-map blend. Level 5 is skipped, each intermediate is rescaled
-- back to the content resolution, and the blend weight decreases toward the fine
-- levels (0.8 -> 0.3) to preserve content detail.
local function styleTransfer_clamp(content, style)
    loadModel()

    local cSz = content:size()
    local sSz = style:size()

    content = content:cuda()
    style = style:cuda()

    -- (disabled: level-5 clamp pass)
    --[[
    local cF5 = vgg5:forward(content):clone()
    local sF5 = vgg5:forward(style):clone()
    vgg5 = nil
    csF5 = feature_clamp(cF5, sF5)
    csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5
    local Im5 = decoder5:forward(csF5):clone()
    decoder5 = nil
    --]]

    vgg5 = nil
    decoder5 = nil
    local Im5 = content

    Im5 = image.scale(Im5:float(), cSz[3], cSz[2])
    Im5 = Im5:cuda()
    local cF4 = vgg4:forward(Im5):clone()
    local sF4 = vgg4:forward(style):clone()
    vgg4 = nil
    --local csF4 = feature_clamp(cF4, sF4)
    local csF4 = feature_blend(cF4, sF4, 0.8)
    csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4
    local Im4 = decoder4:forward(csF4):clone()
    decoder4 = nil

    Im4 = image.scale(Im4:float(), cSz[3], cSz[2])
    Im4 = Im4:cuda()
    local cF3 = vgg3:forward(Im4):clone()
    local sF3 = vgg3:forward(style):clone()
    vgg3 = nil
    --local csF3 = feature_clamp(cF3, sF3)
    local csF3 = feature_blend(cF3, sF3, 0.7)
    csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3
    local Im3 = decoder3:forward(csF3):clone()
    decoder3 = nil

    Im3 = image.scale(Im3:float(), cSz[3], cSz[2])
    Im3 = Im3:cuda()
    local cF2 = vgg2:forward(Im3):clone()
    local sF2 = vgg2:forward(style):clone()
    vgg2 = nil
    --local csF2 = feature_clamp(cF2, sF2)
    local csF2 = feature_blend(cF2, sF2, 0.6)
    csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2
    local Im2 = decoder2:forward(csF2):clone()
    decoder2 = nil

    Im2 = image.scale(Im2:float(), cSz[3], cSz[2])
    Im2 = Im2:cuda()
    local cF1 = vgg1:forward(Im2):clone()
    local sF1 = vgg1:forward(style):clone()
    vgg1 = nil
    --local csF1 = feature_clamp(cF1, sF1)
    local csF1 = feature_blend(cF1, sF1, 0.3)
    csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1
    local Im1 = decoder1:forward(csF1):clone()
    decoder1 = nil
    Im1 = image.scale(Im1:float(), cSz[3], cSz[2])
    Im1 = Im1:cuda()

    return Im1
end

--///////////////////////////////////////////////////////////////////////////////////////
-- Coarse-to-fine cascade with the closed-form MK transform (the default path used
-- by the driver loop at the end of this file).
local function styleTransfer_mk(content, style)
    loadModel()

    print('Start MK')

    content = content:cuda()
    style = style:cuda()
    local cF5 = vgg5:forward(content):clone()
    local sF5 = vgg5:forward(style):clone()
    vgg5 = nil
    local csF5 = feature_mk2(cF5, sF5)
    csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5
    local Im5 = decoder5:forward(csF5)
    decoder5 = nil

    local cF4 = vgg4:forward(Im5):clone()
    local sF4 = vgg4:forward(style):clone()
    vgg4 = nil
    local csF4 = feature_mk2(cF4, sF4)
    csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4
    local Im4 = decoder4:forward(csF4)
    decoder4 = nil

    local cF3 = vgg3:forward(Im4):clone()
    local sF3 = vgg3:forward(style):clone()
    vgg3 = nil
    local csF3 = feature_mk2(cF3, sF3)
    csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3
    local Im3 = decoder3:forward(csF3)
    decoder3 = nil

    local cF2 = vgg2:forward(Im3):clone()
    local sF2 = vgg2:forward(style):clone()
    vgg2 = nil
    local csF2 = feature_mk2(cF2, sF2)
    csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2
    local Im2 = decoder2:forward(csF2)
    decoder2 = nil

    local cF1 = vgg1:forward(Im2):clone()
    local sF1 = vgg1:forward(style):clone()
    vgg1 = nil
    local csF1 = feature_mk2(cF1, sF1)
    csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1
    local Im1 = decoder1:forward(csF1)
    decoder1 = nil
    return Im1
end
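-- Each level nils out its encoder/decoder as soon as it has run, so at most one VGG
-- level lives on the GPU at a time; loadModel() must therefore be called again before
-- every transfer. Torch frees CUDA storage only when the Lua garbage collector runs,
-- so on tight GPUs an explicit collect after dropping a net can help (a sketch):
--[[
vgg5 = nil
collectgarbage()
--]]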
--///////////////////////////////////////////////////////////////////////////////////////
-- Semantic variant of the MK cascade: the content and style label masks are rescaled
-- to each level's feature resolution and regions are transferred label by label.
local function styleTransfer_mk_sem(content, style)
    loadModel()
    content = content:cuda()
    style = style:cuda()

    --/////// level 5
    local cF5 = vgg5:forward(content):clone()
    local sF5 = vgg5:forward(style):clone()
    vgg5 = nil
    local maskC = image.scale(masks.cMask, cF5:size(3), cF5:size(2), 'simple')
    local maskS = image.scale(masks.sMask, sF5:size(3), sF5:size(2), 'simple')
    local csF5 = feature_mk3_sem(cF5, sF5, maskC, maskS)
    csF5 = opt.alpha * csF5 + (1.0-opt.alpha) * cF5
    local Im5 = decoder5:forward(csF5)
    decoder5 = nil

    --////// level 4
    local cF4 = vgg4:forward(Im5):clone()
    local sF4 = vgg4:forward(style):clone()
    vgg4 = nil
    maskC = image.scale(masks.cMask, cF4:size(3), cF4:size(2), 'simple')
    maskS = image.scale(masks.sMask, sF4:size(3), sF4:size(2), 'simple')
    local csF4 = feature_mk3_sem(cF4, sF4, maskC, maskS)
    csF4 = opt.alpha * csF4 + (1.0-opt.alpha) * cF4
    local Im4 = decoder4:forward(csF4)
    decoder4 = nil

    --////// level 3
    local cF3 = vgg3:forward(Im4):clone()
    local sF3 = vgg3:forward(style):clone()
    vgg3 = nil
    maskC = image.scale(masks.cMask, cF3:size(3), cF3:size(2), 'simple')
    maskS = image.scale(masks.sMask, sF3:size(3), sF3:size(2), 'simple')
    local csF3 = feature_mk3_sem(cF3, sF3, maskC, maskS)
    csF3 = opt.alpha * csF3 + (1.0-opt.alpha) * cF3
    local Im3 = decoder3:forward(csF3)
    decoder3 = nil

    --/////// level 2
    local cF2 = vgg2:forward(Im3):clone()
    local sF2 = vgg2:forward(style):clone()
    vgg2 = nil
    maskC = image.scale(masks.cMask, cF2:size(3), cF2:size(2), 'simple')
    maskS = image.scale(masks.sMask, sF2:size(3), sF2:size(2), 'simple')
    local csF2 = feature_mk3_sem(cF2, sF2, maskC, maskS)
    csF2 = opt.alpha * csF2 + (1.0-opt.alpha) * cF2
    local Im2 = decoder2:forward(csF2)
    decoder2 = nil

    --////// level 1
    local cF1 = vgg1:forward(Im2):clone()
    local sF1 = vgg1:forward(style):clone()
    vgg1 = nil
    maskC = image.scale(masks.cMask, cF1:size(3), cF1:size(2), 'simple')
    maskS = image.scale(masks.sMask, sF1:size(3), sF1:size(2), 'simple')
    local csF1 = feature_mk3_sem(cF1, sF1, maskC, maskS)
    csF1 = opt.alpha * csF1 + (1.0-opt.alpha) * cF1
    local Im1 = decoder1:forward(csF1)
    decoder1 = nil

    return Im1
end
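-- The masks table is assumed to be populated earlier in this script with HxW label
-- maps whose integer values 0..4 mark corresponding regions of the two images. A
-- hypothetical way to load them (the filenames are placeholders, not files shipped
-- with this repository):
--[[
masks = {
    cMask = image.load('content_mask.png', 1, 'byte'):float(),
    sMask = image.load('style_mask.png', 1, 'byte'):float(),
}
--]]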
--////////////////////////////////////////////////////////////////////////////////

print('Creating save folder at ' .. opt.outputDir)
paths.mkdir(opt.outputDir)
local contentPaths = extractImageNamesRecursive(opt.content)
local stylePaths = extractImageNamesRecursive(opt.style)
print('Num of content images = ' .. tostring(#contentPaths))
print('Num of style images = ' .. tostring(#stylePaths))
-- stylize every content image with every style image and save the results
for ck, cv in pairs(contentPaths) do
    for sk, sv in pairs(stylePaths) do
        local contentPath = cv
        local contentExt = paths.extname(contentPath)
        local contentName = paths.basename(contentPath, contentExt)
        local contentImg = image.load(contentPath, 3, 'float')
        contentImg = sizePreprocess(contentImg, opt.contentSize)

        local stylePath = sv
        local styleExt = paths.extname(stylePath)
        local styleName = paths.basename(stylePath, styleExt)
        local styleImg = image.load(stylePath, 3, 'float')
        styleImg = sizePreprocess(styleImg, opt.styleSize)

        -- MK is the default transform; the WCT or AdaIN cascade can be swapped in
        local output = styleTransfer_mk(contentImg, styleImg)
        --local output = styleTransfer_wct(contentImg, styleImg)
        --local output = styleTransfer_adaIn(contentImg, styleImg)
        local savePath = paths.concat(opt.outputDir, contentName .. '_stylized_by_' .. styleName .. '_mk.jpg')
        print('Output image saved at: ' .. savePath)
        image.save(savePath, output)
    end
end

--------------------------------------------------------------------------------