├── DDM.png
├── 1557.pdf
├── models_word.png
├── DeepCD_brown.png
├── DeepCD_triplet.png
├── supplementary_final.pdf
├── main
│   ├── UBCdataset
│   │   └── Readme.md
│   ├── train
│   │   ├── DataDependentModule.lua
│   │   ├── allWeightedMSECriterion.lua
│   │   ├── DistanceRatioCriterion_allW.lua
│   │   └── fun_DeepCD_2S.lua
│   ├── Readme.md
│   ├── runAllDeepCD.sh
│   ├── models
│   │   ├── model_DeepCD_2stream.lua
│   │   └── model_DeepCD_2S_binary_binary.lua
│   ├── note.txt
│   ├── evalDeepCD.sh
│   ├── utils.lua
│   └── eval
│       └── fun_evalDeepCD_2S.lua
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/DDM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/DDM.png

--------------------------------------------------------------------------------
/1557.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/1557.pdf

--------------------------------------------------------------------------------
/models_word.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/models_word.png

--------------------------------------------------------------------------------
/DeepCD_brown.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/DeepCD_brown.png

--------------------------------------------------------------------------------
/DeepCD_triplet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/DeepCD_triplet.png

--------------------------------------------------------------------------------
/supplementary_final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/supplementary_final.pdf

--------------------------------------------------------------------------------
/main/UBCdataset/Readme.md:
--------------------------------------------------------------------------------

## Download the UBC dataset (Brown dataset)

Follow https://github.com/vbalnt/UBC-Phototour-Patches-Torch
```
wget http://www.iis.ee.ic.ac.uk/~vbalnt/notredame-t7.tar.gz
wget http://www.iis.ee.ic.ac.uk/~vbalnt/liberty-t7.tar.gz
wget http://www.iis.ee.ic.ac.uk/~vbalnt/yosemite-t7.tar.gz
```
Put the unzipped .t7 files under the UBCdataset folder.

--------------------------------------------------------------------------------
/main/train/DataDependentModule.lua:
--------------------------------------------------------------------------------
local DataDependentModule, parent = torch.class('nn.DataDependentModule', 'nn.Module')

function DataDependentModule:__init(DDM_learning_rate)
   parent.__init(self)
   self.gradInput = {}
   self.DDM_learning_rate = DDM_learning_rate or 0
end

function DataDependentModule:updateOutput(input)
   self.output = input[1]
   return self.output
end

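-- Note on the backward pass below (documenting the existing code): the
-- forward pass is an identity on input[1] (the joined leading/complementary
-- distances), and only the backward pass is modified. The gradient of
-- column 2 (the complementary distance) is rescaled per sample by the
-- data-dependent vector DDM_vec produced by the sigmoid branch, and the
-- gradient flowing back into that branch is further scaled by
-- DDM_learning_rate (eta, about 1e-3~1e-4), so the fully connected layer
-- inside the DDM layer learns slowly and the layer stays approximately an
-- identity in the forward direction.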
function DataDependentModule:updateGradInput(input, gradOutput)

   self.gradInput[1] = gradOutput:clone()

   DDM_vec = input[2] -- intentionally global: the training script prints and saves it every epoch
   --print(DDM_vec)
   self.gradInput[2] = gradOutput[{{},2}]:clone()
   for i=1, self.gradInput[1]:size(1) do
      self.gradInput[1][{i,2}] = self.gradInput[1][{i,2}]*DDM_vec[{i}]
      self.gradInput[2][{i}] = self.gradInput[2][{i}]*input[1][{i,2}]*self.DDM_learning_rate
   end

   return self.gradInput
end

--------------------------------------------------------------------------------
/main/Readme.md:
--------------------------------------------------------------------------------
# DeepCD Source Code

Author: Tsun-Yi Yang 楊存毅

This folder contains the source code I used for DeepCD.

The DeepCD project is heavily inspired by pnnet https://github.com/vbalnt/pnnet

## Platform
+ Torch7
+ Matlab

## Dependencies
+ Cuda
+ Cudnn
+ matio
+ gnuplot
```
luarocks install matio
luarocks install gnuplot
```
We use MATLAB to save and analyze some information (e.g., the DDM vectors).

## Parameter concepts

+ Read note.txt

## Download the UBC dataset (Brown dataset)

Follow https://github.com/vbalnt/UBC-Phototour-Patches-Torch
```
wget http://www.iis.ee.ic.ac.uk/~vbalnt/notredame-t7.tar.gz
wget http://www.iis.ee.ic.ac.uk/~vbalnt/liberty-t7.tar.gz
wget http://www.iis.ee.ic.ac.uk/~vbalnt/yosemite-t7.tar.gz
```
Put the unzipped .t7 files under the UBCdataset folder.

## Simple training command for DeepCD
```
sh runAllDeepCD.sh
```

## Simple evaluation command
```
sh evalDeepCD.sh
```

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Tsun-Yi Yang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/main/runAllDeepCD.sh:
--------------------------------------------------------------------------------

# Training DeepCD 2-stream (lead: real, complementary: binary) with DDM
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "liberty" true 1e-4
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "notredame" true 1e-4
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "yosemite" true 1e-4

# Training DeepCD 2-stream (lead: real, complementary: binary) without DDM
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "liberty" false
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "notredame" false
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "yosemite" false


# Training DeepCD 2-stream (lead: binary, complementary: binary) with DDM
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "liberty" true 5e-4
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "notredame" true 5e-4
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "yosemite" true 1e-3

# Training DeepCD 2-stream (lead: binary, complementary: binary) without DDM
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "liberty" false
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "notredame" false
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "yosemite" false

--------------------------------------------------------------------------------
/main/models/model_DeepCD_2stream.lua:
--------------------------------------------------------------------------------
function createModel(pT)
   -- setup the CNN
   model = nn.Sequential()
   local CT = nn.ConcatTable()

   local lead_net = nn.Sequential()
   lead_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
   lead_net:add(cudnn.Tanh(true))
   lead_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
   lead_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
   lead_net:add(cudnn.Tanh(true))
   lead_net:add(nn.View(64*8*8))
   lead_net:add(nn.Linear(64*8*8, pT.dim1))
   lead_net:add(cudnn.Tanh(true))

   local com_net = nn.Sequential()
   com_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
   com_net:add(cudnn.Tanh(true))
   com_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
   com_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
   com_net:add(cudnn.Tanh(true))
   com_net:add(nn.View(64*8*8))
   com_net:add(nn.Linear(64*8*8, math.max(128,pT.bits2/2)))
   com_net:add(cudnn.Tanh(true))
   com_net:add(nn.Linear(math.max(128,pT.bits2/2), pT.bits2))
   com_net:add(nn.MulConstant(pT.scal_sigmoid, true))
   com_net:add(cudnn.Sigmoid())
   com_net:add(nn.MulConstant(pT.w_com,true))

   CT:add(lead_net)
   CT:add(com_net)
   model:add(CT)
   return model
end

--------------------------------------------------------------------------------
/main/note.txt:
--------------------------------------------------------------------------------
Parameter concepts:

1. w_com

This term balances the value range between the leading descriptor and the
completing (complementary) descriptor.

2. ws_lead = 1

We want to preserve the original optimization settings of SoftPN (for both the
loss function and the learning rate), so this term is fixed to 1.

3. ws_pro

This term raises the importance of our optimization target, the product score.
If this term were 1, the product score would be only as important as the SoftPN
loss terms, and its optimization would fail because the loss would mainly drive
the SoftPN part.

For example, if the leading descriptor is a real-valued 128-dim descriptor, it
is much stronger than the completing descriptor, a 256-bit binary descriptor.
We then set ws_pro to a large value (at least greater than 1), e.g. ws_pro=5.



------------------------------------------------------------------------------------
1. Leading: real-valued, Completing: binary

128dim1, 256bits2: w_com=1.41, ws_lead=1, ws_pro=5


2. Leading: binary, Completing: binary

512bits1, 256bits2: w_com=1.41, ws_lead=1, ws_pro=2

128bits1, 64bits2: w_com=1.41, ws_lead=1, ws_pro=2

64bits1, 32bits2: w_com=1.41, ws_lead=1, ws_pro=2

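How these weights enter the loss (a minimal sketch mirroring
train/fun_DeepCD_2S.lua, assuming the default batch_size=128; the five
columns are the distances produced by the triplet network):

    w = torch.zeros(128,5)
    w[{{},{1,2}}]:fill(ws_lead)  -- columns 1-2: leading distances
    w[{{},{3,5}}]:fill(ws_pro)   -- columns 3-5: fused (product) distances
    crit = nn.DistanceRatioCriterion_allW(w)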
--------------------------------------------------------------------------------
/main/models/model_DeepCD_2S_binary_binary.lua:
--------------------------------------------------------------------------------
function createModel(pT)
   -- setup the CNN
   model = nn.Sequential()
   local CT = nn.ConcatTable()

   local lead_net = nn.Sequential()
   lead_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
   lead_net:add(cudnn.Tanh(true))
   lead_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
   lead_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
   lead_net:add(cudnn.Tanh(true))
   lead_net:add(nn.View(64*8*8))
   lead_net:add(nn.Linear(64*8*8, math.max(128,pT.bits1/2)))
   lead_net:add(cudnn.Tanh(true))
   lead_net:add(nn.Linear(math.max(128,pT.bits1/2), pT.bits1))
   lead_net:add(nn.MulConstant(pT.scal_sigmoid, true))
   lead_net:add(cudnn.Sigmoid())

   local com_net = nn.Sequential()
   com_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
   com_net:add(cudnn.Tanh(true))
   com_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
   com_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
   com_net:add(cudnn.Tanh(true))
   com_net:add(nn.View(64*8*8))
   com_net:add(nn.Linear(64*8*8, math.max(128,pT.bits2/2)))
   com_net:add(cudnn.Tanh(true))
   com_net:add(nn.Linear(math.max(128,pT.bits2/2), pT.bits2))
   com_net:add(nn.MulConstant(pT.scal_sigmoid, true))
   com_net:add(cudnn.Sigmoid())
   com_net:add(nn.MulConstant(pT.w_com,true))

   CT:add(lead_net)
   CT:add(com_net)
   model:add(CT)
   return model
end

--------------------------------------------------------------------------------
/main/evalDeepCD.sh:
--------------------------------------------------------------------------------
# Uncomment the command you want to run.
# Example:
#----------------------------------------------------------------------------------------------------------
# th ./eval/fun_evalDeepCD_2S.lua "rb" 128 256 "liberty" "notredame" true 1
#----------------------------------------------------------------------------------------------------------
# "rb": real-valued leading descriptor, binary complementary descriptor
# 128 256: 128 dims for the leading descriptor, 256 bits for the complementary one
# "liberty" "notredame": the 1st is the training subset, the 2nd is the evaluation subset
# true: the model was trained with DDM; otherwise false
# The last argument is the epoch number of the model checkpoint.


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: real, complementary: binary) with DDM
th ./eval/fun_evalDeepCD_2S.lua "rb" 128 256 "liberty" "notredame" true 1


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: real, complementary: binary) without DDM
#th ./eval/fun_evalDeepCD_2S.lua "rb" 128 256 "liberty" "notredame" false 1


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: binary, complementary: binary) with DDM
#th ./eval/fun_evalDeepCD_2S.lua "bb" 512 256 "liberty" "notredame" true 1


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: binary, complementary: binary) without DDM
#th ./eval/fun_evalDeepCD_2S.lua "bb" 512 256 "liberty" "notredame" false 1

--------------------------------------------------------------------------------
/main/train/allWeightedMSECriterion.lua:
--------------------------------------------------------------------------------
-- MSE criterion in which both input and target are multiplied element-wise
-- by a fixed weight tensor before the squared error is computed.
local allWeightedMSECriterion, parent = torch.class('nn.allWeightedMSECriterion','nn.MSECriterion')

function allWeightedMSECriterion:__init(w)
   parent.__init(self)
   self.weight = w:clone()
end

function allWeightedMSECriterion:updateOutput(input,target)

   self.buffer1 = self.buffer1 or input.new()
   self.buffer1:resizeAs(input):copy(input)
   if input:dim() - 1 == self.weight:dim() then
      for i=1,input:size(1) do
         self.buffer1[i]:cmul(self.weight)
      end
   else
      self.buffer1:cmul(self.weight)
   end

   self.buffer2 = self.buffer2 or input.new()
   self.buffer2:resizeAs(input):copy(target)
   if input:dim() - 1 == self.weight:dim() then
      for i=1,input:size(1) do
         self.buffer2[i]:cmul(self.weight)
      end
   else
      self.buffer2:cmul(self.weight)
   end

   self.output_tensor = self.output_tensor or input.new(1)
   input.THNN.MSECriterion_updateOutput(
      self.buffer1:cdata(),
      self.buffer2:cdata(),
      self.output_tensor:cdata(),
      self.sizeAverage
   )
   self.output = self.output_tensor[1]
   return self.output
end

function allWeightedMSECriterion:updateGradInput(input, target)
   self.buffer1 = self.buffer1 or input.new()
   self.buffer1:resizeAs(input):copy(input)
   if input:dim() - 1 == self.weight:dim() then
      for i=1,input:size(1) do
         self.buffer1[i]:cmul(self.weight)
      end
   else
      self.buffer1:cmul(self.weight)
   end

   self.buffer2 = self.buffer2 or input.new()
   self.buffer2:resizeAs(input):copy(target)
   if input:dim() - 1 == self.weight:dim() then
      for i=1,input:size(1) do
         self.buffer2[i]:cmul(self.weight)
      end
   else
      self.buffer2:cmul(self.weight)
   end

   input.THNN.MSECriterion_updateGradInput(
      self.buffer1:cdata(),
      self.buffer2:cdata(),
      self.gradInput:cdata(),
      self.sizeAverage
   )
   return self.gradInput
end
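-- Usage sketch (hypothetical; the weighted-MSE identity below is quoted
-- from DistanceRatioCriterion_allW.lua):
--   local w = torch.Tensor(128,5):fill(1)
--   local crit = nn.allWeightedMSECriterion(w)
--   crit:forward(a,b) equals "sum((a*w-b*w).^2)/dim" (with sizeAverage)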
--------------------------------------------------------------------------------
/main/train/DistanceRatioCriterion_allW.lua:
--------------------------------------------------------------------------------
-- Taken from Elad Hoffer's TripletNet https://github.com/eladhoffer
-- Hinge loss ranking could also be used, see
-- https://github.com/torch/nn/blob/master/doc/criterion.md#nn.MarginRankingCriterion

local DistanceRatioCriterion_allW, parent = torch.class('nn.DistanceRatioCriterion_allW', 'nn.Criterion')

function DistanceRatioCriterion_allW:__init(w)
   parent.__init(self)
   self.SoftMax = nn.SoftMax()
   self.wMSE = nn.allWeightedMSECriterion(w)
   -- wMSE:forward(a,b) equals "sum((a*w-b*w).^2)/dim"
   self.Target = torch.Tensor()
end

function DistanceRatioCriterion_allW:createTarget(input, target)
   local target = target or 1
   self.Target:resizeAs(input):typeAs(input):zero()
   self.Target[{{},target}]:add(1)
   --self.Target[{{},target+1}]:add(1)

   self.Target[{{},3}]:add(1)
   self.Target[{{},4}]:add(1)
   --self.Target[{{},5}]:add(1)
   --self.Target[{{},7}]:add(1)
   --print(input:size())
   --print(self.Target:size())
   --os.exit()
   --[[
   print(input)
   ...
   0.2617  0.5276
   0.1764  0.3031
   0.4771  0.3169
   0.3398  0.0905
   0.3689  0.1940
   [torch.CudaTensor of size 128x2]

   print(self.Target)
   ...
   1  0
   1  0
   1  0
   1  0
   1  0
   [torch.CudaTensor of size 128x2]

   --The first column contains the negative distance
   --while the second one is the positive distance
   os.exit()
   --]]
end

function DistanceRatioCriterion_allW:updateOutput(input, target)
   if not self.Target:isSameSizeAs(input) then
      self:createTarget(input, target)
   end
   self.output = self.wMSE:updateOutput(self.SoftMax:updateOutput(input),self.Target)
   return self.output
end

function DistanceRatioCriterion_allW:updateGradInput(input, target)
   if not self.Target:isSameSizeAs(input) then
      self:createTarget(input, target)
   end

   self.gradInput = self.SoftMax:updateGradInput(input, self.wMSE:updateGradInput(self.SoftMax.output,self.Target))
   return self.gradInput
end

function DistanceRatioCriterion_allW:type(t)
   parent.type(self, t)
   self.SoftMax:type(t)
   self.wMSE:type(t)
   self.Target = self.Target:type(t)
   return self
end

--------------------------------------------------------------------------------
/main/utils.lua:
--------------------------------------------------------------------------------
-- From pnnet
-- read dataset in the wanted format
-- example here: http://vbalnt.io/notredame-torch.tar.gz
function read_brown_data(name)
   local d = torch.load(name..'.t7') -- this one is modified
   d.patches32 = d.patches32:float()
   -- labels in data are zero-indexed
   d.labels:add(1)
   return d
end

-- get the stats
function get_stats(d)
   local mi = d.patches32:mean()
   local sigma = d.patches32:std()
   local stats = {}
   stats.mi = mi
   stats.sigma = sigma
   return stats
end

-- norm data based on stats
function norm_data(d,stats)
   d.patches32:add(-stats.mi):div(stats.sigma)
end

-- following functions taken from Elad Hoffer's
-- TripletNet https://github.com/eladhoffer
function ArrangeByLabel(traind)
   local numClasses = traind.labels:max()
   local Ordered = {}
   for i=1,traind.labels:size(1) do
      -- print(i)
      if Ordered[traind.labels[i]] == nil then
         Ordered[traind.labels[i]] = {}
      end
      table.insert(Ordered[traind.labels[i]], i)
   end
   return Ordered
end
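-- Output format of the two generators below (documenting the existing code):
-- each row of generate_triplets is {anchor, negative, positive}; columns 1
-- and 3 come from the same 3D point (class c1) and column 2 from a different
-- one (class c2). generate_pairs instead returns rows {index1, index2, label}
-- with label 1 for a matching pair and -1 for a non-matching pair.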
function generate_pairs(traind, num_pairs)
   local list = torch.IntTensor(num_pairs,3)
   local pairs = torch.IntTensor(num_pairs,3)

   local Ordered = ArrangeByLabel(traind)
   local nClasses = #Ordered
   for i=1, num_pairs do
      -- print(i)
      local c1 = math.random(nClasses)
      local c2 = math.random(nClasses)
      while c2 == c1 do
         c2 = math.random(nClasses)
      end
      local n1 = math.random(#Ordered[c1])
      local n2 = math.random(#Ordered[c2])
      local n3 = math.random(#Ordered[c1])
      while n3 == n1 do
         n3 = math.random(#Ordered[c1])
      end

      list[i][1] = Ordered[c1][n1]
      list[i][2] = Ordered[c2][n2]
      list[i][3] = Ordered[c1][n3]

      lbl = math.random(0,2)
      if ((lbl==0) or (lbl==1)) then
         pairs[i][1] = list[i][1]
         pairs[i][2] = list[i][2]
         pairs[i][3] = -1
      else
         pairs[i][1] = list[i][1]
         pairs[i][2] = list[i][3]
         pairs[i][3] = 1
      end
   end

   return pairs
end

function generate_triplets(traind, num_pairs)
   local list = torch.IntTensor(num_pairs,3)

   local Ordered = ArrangeByLabel(traind)
   local nClasses = #Ordered
   for i=1, num_pairs do
      -- print(i)
      local c1 = math.random(nClasses)
      local c2 = math.random(nClasses)
      while c2 == c1 do
         c2 = math.random(nClasses)
      end
      local n1 = math.random(#Ordered[c1])
      local n2 = math.random(#Ordered[c2])
      local n3 = math.random(#Ordered[c1])
      while n3 == n1 do
         n3 = math.random(#Ordered[c1])
      end
      list[i][1] = Ordered[c1][n1]
      list[i][2] = Ordered[c2][n2]
      list[i][3] = Ordered[c1][n3]
   end

   return list
end

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepCD
Code Author: Tsun-Yi Yang

Last update: 2017/08/17 (Training and testing code are both uploaded.)

Platform: Ubuntu 14.04, Torch7

Paper
-
**[ICCV17] DeepCD: Learning Deep Complementary Descriptors for Patch Representations**

**Authors: [Tsun-Yi Yang](http://shamangary.logdown.com/), Jo-Han Hsu, [Yen-Yu Lin](https://www.citi.sinica.edu.tw/pages/yylin/index_zh.html), and [Yung-Yu Chuang](https://www.csie.ntu.edu.tw/~cyy/)**

**PDF:**
+ Link1: http://www.csie.ntu.edu.tw/~cyy/publications/papers/Yang2017DLD.pdf
+ Link2: https://github.com/shamangary/DeepCD/blob/master/1557.pdf

Code abstract
-
This is the source code of DeepCD. Training is done on the Brown dataset.

Two distinct descriptors are learned from the same network.

Product late fusion in the distance domain is performed before the final ranking.

The DeepCD project is heavily inspired by pnnet https://github.com/vbalnt/pnnet

***This repository:*** (author: Tsun-Yi Yang)
+ Brown dataset (Training and testing) https://github.com/shamangary/DeepCD/tree/master/main

***Related repositories:*** (author: Jo-Han Hsu)
+ MVS dataset (testing) https://github.com/Rohan8288/DeepCD_MVS
+ Oxford dataset (testing) https://github.com/Rohan8288/DeepCD_Oxford

Model
-

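A minimal usage sketch of the two-stream model (not part of the original scripts; it assumes the Torch7/cudnn setup above and follows models/model_DeepCD_2stream.lua and the thresholding in eval/fun_evalDeepCD_2S.lua):

```
-- hypothetical usage sketch; run from the repo root
require 'cudnn'
require 'cunn'
paths.dofile('main/models/model_DeepCD_2stream.lua')

-- hyperparameters as used in runAllDeepCD.sh ("rb" configuration)
local pT = {dim1 = 128, bits2 = 256, scal_sigmoid = 100, w_com = 1.41}
local net = createModel(pT):cuda()

-- drop the final w_com scaling of the complementary stream,
-- as the eval script does before thresholding
net:get(1):get(2):remove(12)

-- a batch of normalized 32x32 grayscale patches
local patches = torch.CudaTensor(128, 1, 32, 32):normal()
local out = net:forward(patches)
local lead, com = out[1], out[2]  -- 128-dim real-valued / 256-dim in (0,1)

-- testing stage: hard-threshold the complementary descriptor; the final
-- score is the product of the two streams' distances
local bcom = com:gt(0.5):float()
```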
Training with Data-Dependent Modulation (DDM) layer
-
+ The DDM layer dynamically adapts the learning rate of the complementary stream.

+ It considers information from the whole batch by taking both leading and complementary distances into account.

The backward gradient is scaled by a factor η (1e-3~1e-4). This step not only lets us slow down the learning of the fully connected layer inside the DDM layer, but also lets us approximately ignore the effect of the DDM layer on the forward propagation of the complementary stream and treat it as an identity operation. The update equation is basically the backward equation derived from multiplying a parameter w from the previous layer.


```
a_DDM = nn.Identity()
output_layer_DDM = nn.Linear(pT.batch_size*2,pT.batch_size)
output_layer_DDM.weight:fill(0)
output_layer_DDM.bias:fill(1)
b_DDM = nn.Sequential():add(nn.Reshape(pT.batch_size*2,false)):add(output_layer_DDM):add(nn.Sigmoid())
DDM_ct1 = nn.ConcatTable():add(a_DDM:clone()):add(b_DDM:clone())
DDM_layer = nn.Sequential():add(DDM_ct1):add(nn.DataDependentModule(pT.DDM_LR))
```
Testing stage
-
+ A **hard threshold** will be applied to the complementary descriptor before the Hamming distance calculation.

+ **The DDM layer is not involved in the testing stage** since we only need the trained model from the triplet structure.

+ **Product late fusion in the distance domain** is computed before the final ranking.

Brown dataset results
-

--------------------------------------------------------------------------------
/main/eval/fun_evalDeepCD_2S.lua:
--------------------------------------------------------------------------------
require 'cutorch'
require 'xlua'
require 'trepl'
require 'cunn'
require 'cudnn'
require 'image'
require 'nn'
require 'torch'
require 'gnuplot'
require '../utils.lua'

des_type = arg[1]
if des_type == 'rb' then
   dim1 = tonumber(arg[2])
   bits2 = tonumber(arg[3])
elseif des_type == 'bb' then
   bits1 = tonumber(arg[2])
   bits2 = tonumber(arg[3])
end
network = arg[4]
eval_data = arg[5]
if arg[6] == 'true' then
   isDDM = true
elseif arg[6] == 'false' then
   isDDM = false
end
checkEpoch = arg[7]

-- load the trained network (file names follow the training script)
-- for more details http://phototour.cs.washington.edu/patches/default.htm

if isDDM then
   addName = '_DDM'
else
   addName = ''
end

if des_type == 'rb' then
   net = torch.load('./train_epoch/DeepCD_2S'..addName..'_'..dim1..'dim1_'..bits2..'bits2_'..network..'/NET_DeepCD_2S'..addName..'_'..network..'_epoch'..checkEpoch..'.t7')
elseif des_type == 'bb' then
   net = torch.load('./train_epoch/DeepCD_2S'..addName..'_'..bits1..'bits1_'..bits2..'bits2_'..network..'/NET_DeepCD_2S'..addName..'_'..network..'_epoch'..checkEpoch..'.t7')
end
net:evaluate()

-- test on the testing gt (100k pairs from Brown's dataset)
ntest = 100000
R = torch.ones(2,ntest)
print(net)
net:get(1):get(2):remove(12)

trained = torch.load('./UBCdataset/'..network..'.t7')
dataset = torch.load('./UBCdataset/'..eval_data..'.t7')
stats = get_stats(trained)
print(stats)
norm_data(dataset,stats)
npatches = (dataset.patches32:size(1))
print(npatches)

-- normalize data
patches32 = dataset.patches32

-- split the patches in batches to avoid memory problems
BatchSize = 128
for iter=1, 2 do

   if iter == 1 then
      if des_type == 'rb' then
         dim = dim1
         isBinary = false
      elseif des_type == 'bb' then
         dim = bits1
         isBinary = true
      end
      net:add(nn.SelectTable(1))
   elseif iter == 2 then
      dim = bits2
      isBinary = true
      net:remove(2)
      net:add(nn.SelectTable(2))
   end

   local Descrs = torch.CudaTensor(npatches,dim)
   local DescrsSplit = Descrs:split(BatchSize)
   for i,v in ipairs(patches32:split(BatchSize)) do
      temp = v:clone():cuda()
      DescrsSplit[i]:copy(net:forward(temp))
      --print(net:get(1):get(2):get(1):get(1):get(2):get(11).output)
      --os.exit()
   end

   for j=1,ntest do
      l = dataset.gt100k[j]
      lbl = l[2]==l[5] and 1 or 0
      id1 = l[1]+1
      id2 = l[4]+1
      dl = Descrs[{ {id1},{} }]
      dr = Descrs[{ {id2},{} }]

      if isBinary then
         -- hard threshold before the (Hamming-like) distance computation
         dl = dl:gt(0.5):float()
         dr = dr:gt(0.5):float()
      end
      d = torch.dist(dl,dr)

      if iter == 1 then
         R[{{1},{j}}] = lbl
         R[{{2},{j}}] = R[{{2},{j}}]*d
      elseif iter ==2 then
         -- product late fusion: multiply the two streams' distances
         R[{{2},{j}}] = R[{{2},{j}}]*d
      end

      --io.write(string.format("%d %.4f \n", lbl,d))
   end
end

--FPR95 (FPR at TPR or recall 95%) and ROC curve
val_sorted, temp_id_sorted = torch.sort(R[{{2},{}}])
id_sorted = temp_id_sorted[{1,{}}]
pn_sorted = R[{{1},{}}]:index(2,id_sorted)

pos_all = torch.sum(R[{{1},{}}])
print('pos_all:'..pos_all)
neg_all = ntest-pos_all
pos_95 = torch.floor(pos_all*0.95)

pos_acc = pn_sorted:clone() --Don't forget "clone()"

for j=2,ntest do
   pos_acc[{1,j}] = pos_acc[{1,j-1}] + pos_acc[{1,j}]
end

print('pos_95:'..pos_95)
id_95 = pos_acc[{1,{}}]:eq(pos_95):nonzero():min()
print(id_95)

tpr = torch.Tensor(1,ntest)
fpr = torch.Tensor(1,ntest)
for k=1,ntest do
   tpr[{1,k}] = pos_acc[{1,k}]/pos_all
   fpr[{1,k}] = torch.abs(pos_acc[{1,k}]-k)/neg_all
end

Result = torch.cat(fpr,tpr,1)
gnuplot.plot(Result:t())
gnuplot.title('ROC curve')
gnuplot.xlabel('FPR')
gnuplot.ylabel('TPR')
FPR95 = fpr[{1,id_95}]
print('FPR95%:'..FPR95*100)
--------------------------------------------------------------------------------
/main/train/fun_DeepCD_2S.lua:
--------------------------------------------------------------------------------
require 'nn'
require '../utils.lua'
require 'image'
require 'optim'
require './DistanceRatioCriterion_allW.lua'
require 'cudnn'
require 'cutorch'
require 'cunn'
require './allWeightedMSECriterion.lua'
require './DataDependentModule.lua'

local matio = require 'matio'

des_type = arg[1]
if des_type == 'rb' then
   dim1 = tonumber(arg[2])
   bits2 = tonumber(arg[3])
elseif des_type == 'bb' then
   bits1 = tonumber(arg[2])
   bits2 = tonumber(arg[3])
end
w_com = tonumber(arg[4])
ws_lead = tonumber(arg[5])
ws_pro = tonumber(arg[6])
name = arg[7]
if arg[8] == 'true' then
   isDDM = true
elseif arg[8] == 'false' then
   isDDM = false
end
DDM_LR = tonumber(arg[9])

pT = {
   --dim1 = 128,
   --bits2 = 256,
   --w_com = 1.41,
   --ws_lead = 1,
   --ws_pro = 5,
   scal_sigmoid = 100,
   isNorm = true,
   --isSTN = false,
   --name = 'liberty',
   batch_size = 128,
   num_triplets = 1280000,
   max_epoch = 100
}

if des_type == 'rb' then
   pT.dim1 = dim1
   pT.bits2 = bits2
elseif des_type == 'bb' then
   pT.bits1 = bits1
   pT.bits2 = bits2
end
pT.w_com = w_com
pT.ws_lead = ws_lead
pT.ws_pro = ws_pro
pT.name = name
pT.isDDM = isDDM
pT.DDM_LR = DDM_LR

-- optim parameters
optimState = {
   learningRate = 0.1,
   weightDecay = 1e-4,
   momentum = 0.9,
   learningRateDecay = 1e-6
}
pT.optimState = optimState

if pT.isDDM then
   addName = 'DDM_'
else
   addName = ''
end

if des_type == 'rb' then
   folder_name = './train_epoch/DeepCD_2S_'..addName..pT.dim1..'dim1_'..pT.bits2..'bits2_'..pT.name
elseif des_type == 'bb' then
   folder_name = './train_epoch/DeepCD_2S_'..addName..pT.bits1..'bits1_'..pT.bits2..'bits2_'..pT.name
end

print(pT)
folder_name.."/DDM_vec/") 90 | end 91 | torch.save(folder_name..'/ParaTable_DeepCD_2S_'..addName..pT.name..'.t7',pT) 92 | 93 | 94 | ------------------------------------------------------------------------------------------------ 95 | 96 | 97 | -- number of threads 98 | torch.setnumthreads(13) 99 | 100 | -- read training data, save mu and sigma & normalize 101 | 102 | traind = read_brown_data('./UBCdataset/'..pT.name) 103 | stats = get_stats(traind) 104 | print(stats) 105 | if pT.isNorm then 106 | norm_data(traind,stats) 107 | end 108 | print("==> read the dataset") 109 | 110 | -- generate random triplets for training data 111 | 112 | training_triplets = generate_triplets(traind, pT.num_triplets) 113 | print("==> created the tests") 114 | ------------------------------------------------------------------------------------------------ 115 | if des_type == 'rb' then 116 | paths.dofile('../models/model_DeepCD_2stream.lua') 117 | elseif des_type == 'bb' then 118 | paths.dofile('../models/model_DeepCD_2S_binary_binary.lua') 119 | end 120 | model1 = createModel(pT) 121 | model1:training() 122 | 123 | -- clone the other two networks in the triplet 124 | model2 = model1:clone('weight', 'bias','gradWeight','gradBias') 125 | model3 = model1:clone('weight', 'bias','gradWeight','gradBias') 126 | 127 | -- add them to a parallel table 128 | prl = nn.ParallelTable() 129 | prl:add(model1) 130 | prl:add(model2) 131 | prl:add(model3) 132 | prl:cuda() 133 | ------------------------------------------------------------------------------------------------ 134 | mlp= nn.Sequential() 135 | mlp:add(prl) 136 | 137 | -- get feature distances 138 | cc = nn.ConcatTable() 139 | 140 | 141 | -- feats 1 with 2 leading 142 | cnn_left1_lead = nn.Sequential() 143 | 144 | cnnneg1_dist_lead = nn.ConcatTable() 145 | a_neg1_lead = nn.Sequential() 146 | a_neg1_lead:add(nn.SelectTable(1)) 147 | a_neg1_lead:add(nn.SelectTable(1)) 148 | b_neg1_lead = nn.Sequential() 149 | b_neg1_lead:add(nn.SelectTable(2)) 150 | b_neg1_lead:add(nn.SelectTable(1)) 151 | cnnneg1_dist_lead:add(a_neg1_lead) 152 | cnnneg1_dist_lead:add(b_neg1_lead) 153 | 154 | cnn_left1_lead:add(cnnneg1_dist_lead) 155 | cnn_left1_lead:add(nn.PairwiseDistance(2)) 156 | cnn_left1_lead:add(nn.View(pT.batch_size ,1)) 157 | cnn_left1_lead:cuda() 158 | cc:add(cnn_left1_lead) 159 | 160 | -- feats 1 with 2 completing 161 | cnn_left1_com = nn.Sequential() 162 | 163 | cnnneg1_dist_com = nn.ConcatTable() 164 | a_neg1_com = nn.Sequential() 165 | a_neg1_com:add(nn.SelectTable(1)) 166 | a_neg1_com:add(nn.SelectTable(2)) 167 | b_neg1_com = nn.Sequential() 168 | b_neg1_com:add(nn.SelectTable(2)) 169 | b_neg1_com:add(nn.SelectTable(2)) 170 | cnnneg1_dist_com:add(a_neg1_com) 171 | cnnneg1_dist_com:add(b_neg1_com) 172 | 173 | cnn_left1_com:add(cnnneg1_dist_com) 174 | cnn_left1_com:add(nn.PairwiseDistance(2)) 175 | cnn_left1_com:add(nn.View(pT.batch_size ,1)) 176 | cnn_left1_com:cuda() 177 | cc:add(cnn_left1_com) 178 | 179 | -- feats 2 with 3 leading 180 | cnn_left2_lead = nn.Sequential() 181 | cnnneg2_dist_lead = nn.ConcatTable() 182 | a_neg2_lead = nn.Sequential() 183 | a_neg2_lead:add(nn.SelectTable(2)) 184 | a_neg2_lead:add(nn.SelectTable(1)) 185 | b_neg2_lead = nn.Sequential() 186 | b_neg2_lead:add(nn.SelectTable(3)) 187 | b_neg2_lead:add(nn.SelectTable(1)) 188 | cnnneg2_dist_lead:add(a_neg2_lead) 189 | cnnneg2_dist_lead:add(b_neg2_lead) 190 | cnn_left2_lead:add(cnnneg2_dist_lead) 191 | cnn_left2_lead:add(nn.PairwiseDistance(2)) 192 | cnn_left2_lead:add(nn.View(pT.batch_size,1)) 193 | 
cnn_left2_lead:cuda()
cc:add(cnn_left2_lead)

-- feats 2 with 3 completing
cnn_left2_com = nn.Sequential()
cnnneg2_dist_com = nn.ConcatTable()
a_neg2_com = nn.Sequential()
a_neg2_com:add(nn.SelectTable(2))
a_neg2_com:add(nn.SelectTable(2))
b_neg2_com = nn.Sequential()
b_neg2_com:add(nn.SelectTable(3))
b_neg2_com:add(nn.SelectTable(2))
cnnneg2_dist_com:add(a_neg2_com)
cnnneg2_dist_com:add(b_neg2_com)
cnn_left2_com:add(cnnneg2_dist_com)
cnn_left2_com:add(nn.PairwiseDistance(2))
cnn_left2_com:add(nn.View(pT.batch_size,1))
cnn_left2_com:cuda()
cc:add(cnn_left2_com)

-- feats 1 with 3 leading
cnn_right_lead = nn.Sequential()
cnnpos_dist_lead = nn.ConcatTable()
a_pos_lead = nn.Sequential()
a_pos_lead:add(nn.SelectTable(1))
a_pos_lead:add(nn.SelectTable(1))
b_pos_lead = nn.Sequential()
b_pos_lead:add(nn.SelectTable(3))
b_pos_lead:add(nn.SelectTable(1))
cnnpos_dist_lead:add(a_pos_lead)
cnnpos_dist_lead:add(b_pos_lead)
cnn_right_lead:add(cnnpos_dist_lead)
cnn_right_lead:add(nn.PairwiseDistance(2))
cnn_right_lead:add(nn.View(pT.batch_size,1))
cnn_right_lead:cuda()
cc:add(cnn_right_lead)

-- feats 1 with 3 completing
cnn_right_com = nn.Sequential()
cnnpos_dist_com = nn.ConcatTable()
a_pos_com = nn.Sequential()
a_pos_com:add(nn.SelectTable(1))
a_pos_com:add(nn.SelectTable(2))
b_pos_com = nn.Sequential()
b_pos_com:add(nn.SelectTable(3))
b_pos_com:add(nn.SelectTable(2))
cnnpos_dist_com:add(a_pos_com)
cnnpos_dist_com:add(b_pos_com)
cnn_right_com:add(cnnpos_dist_com)
cnn_right_com:add(nn.PairwiseDistance(2))
cnn_right_com:add(nn.View(pT.batch_size,1))
cnn_right_com:cuda()
cc:add(cnn_right_com)

cc:cuda()

mlp:add(cc)
------------------------------------------------------------------------------------------------
last_layer = nn.ConcatTable()

-- select leading min negative distance inside the triplet
mined_neg = nn.Sequential()
mining_layer = nn.ConcatTable()
mining_layer:add(nn.SelectTable(1))
mining_layer:add(nn.SelectTable(3))
mined_neg:add(mining_layer)
mined_neg:add(nn.JoinTable(2))
mined_neg:add(nn.Min(2))
mined_neg:add(nn.View(pT.batch_size,1))
last_layer:add(mined_neg)

-- add leading positive distance
pos_layer = nn.Sequential()
pos_layer:add(nn.SelectTable(5))
pos_layer:add(nn.View(pT.batch_size,1))
last_layer:add(pos_layer)

------------------------------------------------------------------------------------------------
a_DDM = nn.Identity()
output_layer_DDM = nn.Linear(pT.batch_size*2,pT.batch_size)
output_layer_DDM.weight:fill(0)
output_layer_DDM.bias:fill(1)
b_DDM = nn.Sequential():add(nn.Reshape(pT.batch_size*2,false)):add(output_layer_DDM):add(nn.Sigmoid())
DDM_ct = nn.ConcatTable():add(a_DDM:clone()):add(b_DDM:clone())
DDM_layer = nn.Sequential():add(DDM_ct):add(nn.DataDependentModule(pT.DDM_LR))
------------------------------------------------------------------------------------------------

--add neg1 (real,binary) distance
neg1_RB_layer = nn.Sequential()
temp_neg1_RB = nn.ConcatTable()

S1_neg1 = nn.Sequential()
S1_neg1:add(nn.SelectTable(1))

S2_neg1 = nn.Sequential()
S2_neg1:add(nn.SelectTable(2))

temp_neg1_RB:add(S1_neg1)
temp_neg1_RB:add(S2_neg1)
neg1_RB_layer:add(temp_neg1_RB)
if pT.isDDM then
   neg1_RB_layer:add(nn.JoinTable(2))
   neg1_RB_layer:add(DDM_layer)
   neg1_RB_layer:add(nn.SplitTable(2))
end
neg1_RB_layer:add(nn.CMulTable())
neg1_RB_layer:add(nn.Sqrt())
neg1_RB_layer:add(nn.View(pT.batch_size,1))
neg1_RB_layer:cuda()
last_layer:add(neg1_RB_layer)

--add neg2 (real,binary) distance
neg2_RB_layer = nn.Sequential()
temp_neg2_RB = nn.ConcatTable()

S1_neg2 = nn.Sequential()
S1_neg2:add(nn.SelectTable(3))

S2_neg2 = nn.Sequential()
S2_neg2:add(nn.SelectTable(4))

temp_neg2_RB:add(S1_neg2)
temp_neg2_RB:add(S2_neg2)
neg2_RB_layer:add(temp_neg2_RB)
if pT.isDDM then
   neg2_RB_layer:add(nn.JoinTable(2))
   neg2_RB_layer:add(DDM_layer)
   neg2_RB_layer:add(nn.SplitTable(2))
end
neg2_RB_layer:add(nn.CMulTable())
neg2_RB_layer:add(nn.Sqrt())
neg2_RB_layer:add(nn.View(pT.batch_size,1))
neg2_RB_layer:cuda()
last_layer:add(neg2_RB_layer)

--add pos (real,binary) distance
pos_RB_layer = nn.Sequential()
temp_pos_RB = nn.ConcatTable()

S1_pos = nn.Sequential()
S1_pos:add(nn.SelectTable(5))
S2_pos = nn.Sequential()
S2_pos:add(nn.SelectTable(6))

temp_pos_RB:add(S1_pos)
temp_pos_RB:add(S2_pos)
pos_RB_layer:add(temp_pos_RB)
pos_RB_layer:add(nn.CMulTable())
pos_RB_layer:add(nn.Sqrt())

pos_RB_layer:add(nn.View(pT.batch_size,1))
pos_RB_layer:cuda()
last_layer:add(pos_RB_layer)
------------------------------------------------------------------------------------------------
mlp:add(last_layer)

mlp:add(nn.JoinTable(2))
mlp:cuda()
------------------------------------------------------------------------------------------------
-- setup the criterion: ratio of min-negative to positive
epoch = 1

x = torch.zeros(pT.batch_size,1,32,32):cuda()
y = torch.zeros(pT.batch_size,1,32,32):cuda()
z = torch.zeros(pT.batch_size,1,32,32):cuda()

parameters, gradParameters = mlp:getParameters()

-- main training loop

Loss = torch.zeros(1,pT.max_epoch)

w = torch.zeros(128,5) -- 128 rows = pT.batch_size; see note.txt for the column weights
w[{{},{1,2}}]:fill(pT.ws_lead)
w[{{},{3,5}}]:fill(pT.ws_pro)
crit = nn.DistanceRatioCriterion_allW(w):cuda()
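-- Layout of the mlp output fed to the criterion (a batch_size x 5 matrix of
-- distances, documenting the existing graph above):
--   [1] min leading negative (hardest of the two negatives, via nn.Min)
--   [2] leading positive
--   [3] fused negative 1   [4] fused negative 2   [5] fused positive,
-- where "fused" means sqrt(lead_dist * com_dist). The criterion softmaxes
-- each row and regresses it onto the target (1,0,1,1,0): negative distances
-- should dominate the row while positive ones vanish.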
for epoch=epoch, pT.max_epoch do

   Gerr = 0
   shuffle = torch.randperm(pT.num_triplets)
   nbatches = pT.num_triplets/pT.batch_size

   for k=1,nbatches-1 do
      xlua.progress(k+1, nbatches)

      s = shuffle[{ {k*pT.batch_size,k*pT.batch_size+pT.batch_size} }]
      for i=1,pT.batch_size do
         x[i] = traind.patches32[training_triplets[s[i]][1]]
         y[i] = traind.patches32[training_triplets[s[i]][2]]
         z[i] = traind.patches32[training_triplets[s[i]][3]]
      end

      local feval = function(f)
         if f ~= parameters then parameters:copy(f) end
         gradParameters:zero()
         inputs = {x,y,z}
         local outputs = mlp:forward(inputs)

         local f = crit:forward(outputs, 1)
         Gerr = Gerr+f
         local df_do = crit:backward(outputs)
         mlp:backward(inputs, df_do)
         return f,gradParameters
      end
      optim.sgd(feval, parameters, optimState)

   end
   loss = Gerr/nbatches
   Loss[{{1},{epoch}}] = loss
   if pT.isDDM then
      print(DDM_vec)
      matio.save(folder_name..'/DDM_vec/DDM_vec_epoch'..epoch..'.mat',{DDM_vec=DDM_vec:double()})
   end

   print('==> epoch '..epoch)
   print(loss)
   print('')

   --remain = math.fmod(epoch,3)
   --if epoch == 1 or remain ==0 then
   net_save = mlp:get(1):get(1):clone()
   torch.save(folder_name..'/NET_DeepCD_2S_'..addName..pT.name..'_epoch'..epoch..'.t7',net_save:clearState())
   --end

end

torch.save(folder_name..'/Loss_DeepCD_2S_'..addName..pT.name..'.t7',Loss)

--------------------------------------------------------------------------------