├── DDM.png
├── 1557.pdf
├── models_word.png
├── DeepCD_brown.png
├── DeepCD_triplet.png
├── supplementary_final.pdf
├── main
│   ├── UBCdataset
│   │   └── Readme.md
│   ├── train
│   │   ├── DataDependentModule.lua
│   │   ├── allWeightedMSECriterion.lua
│   │   ├── DistanceRatioCriterion_allW.lua
│   │   └── fun_DeepCD_2S.lua
│   ├── Readme.md
│   ├── runAllDeepCD.sh
│   ├── models
│   │   ├── model_DeepCD_2stream.lua
│   │   └── model_DeepCD_2S_binary_binary.lua
│   ├── note.txt
│   ├── evalDeepCD.sh
│   ├── utils.lua
│   └── eval
│       └── fun_evalDeepCD_2S.lua
├── LICENSE
└── README.md
/DDM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/DDM.png
--------------------------------------------------------------------------------
/1557.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/1557.pdf
--------------------------------------------------------------------------------
/models_word.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/models_word.png
--------------------------------------------------------------------------------
/DeepCD_brown.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/DeepCD_brown.png
--------------------------------------------------------------------------------
/DeepCD_triplet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/DeepCD_triplet.png
--------------------------------------------------------------------------------
/supplementary_final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shamangary/DeepCD/HEAD/supplementary_final.pdf
--------------------------------------------------------------------------------
/main/UBCdataset/Readme.md:
--------------------------------------------------------------------------------
1 |
2 | ## Download the UBC dataset (Brown dataset)
3 |
4 | Follow https://github.com/vbalnt/UBC-Phototour-Patches-Torch
5 | ```
6 | wget http://www.iis.ee.ic.ac.uk/~vbalnt/notredame-t7.tar.gz
7 | wget http://www.iis.ee.ic.ac.uk/~vbalnt/liberty-t7.tar.gz
8 | wget http://www.iis.ee.ic.ac.uk/~vbalnt/yosemite-t7.tar.gz
9 | ```
10 | Put the unzipped t7 files under UBCdataset folder
11 |
--------------------------------------------------------------------------------
/main/train/DataDependentModule.lua:
--------------------------------------------------------------------------------
local DataDependentModule, parent = torch.class('nn.DataDependentModule', 'nn.Module')

-- Data-Dependent Modulation (DDM) layer.
-- Input is a table {distances, DDM_vec}:
--   input[1]: per-sample distance matrix (column 2 is the complementary-stream distance)
--   input[2]: per-sample modulation vector produced by the DDM branch
-- Forward acts as the identity on input[1]; the modulation only affects the
-- backward pass, where it rescales the gradient of column 2.
function DataDependentModule:__init(DDM_learning_rate)
	parent.__init(self)
	self.gradInput = {}
	-- Damping factor (eta) for the gradient flowing into the DDM branch;
	-- defaults to 0, i.e. the branch receives no gradient.
	self.DDM_learning_rate = DDM_learning_rate or 0
end

function DataDependentModule:updateOutput(input)
	-- Identity on the distance part; the modulation vector is ignored forward.
	self.output = input[1]
	return self.output
end

function DataDependentModule:updateGradInput(input, gradOutput)
	self.gradInput[1] = gradOutput:clone()

	-- FIX: 'DDM_vec' was an accidental global (missing 'local'), leaking state
	-- across any other chunk using the same name.
	local DDM_vec = input[2]
	self.gradInput[2] = gradOutput[{{},2}]:clone()
	for i = 1, self.gradInput[1]:size(1) do
		-- Per-sample rescaling of the complementary-distance gradient (column 2).
		self.gradInput[1][{i,2}] = self.gradInput[1][{i,2}]*DDM_vec[{i}]
		-- Gradient w.r.t. the modulation vector, damped by eta.
		self.gradInput[2][{i}] = self.gradInput[2][{i}]*input[1][{i,2}]*self.DDM_learning_rate
	end

	return self.gradInput
end
--------------------------------------------------------------------------------
/main/Readme.md:
--------------------------------------------------------------------------------
1 | # DeepCD Source Code
2 |
3 | Author: Tsun-Yi Yang 楊存毅
4 |
5 | This folder presents the source codes I used for DeepCD.
6 |
7 | DeepCD project is heavily inspired by pnnet https://github.com/vbalnt/pnnet
8 |
9 | ## Platform
10 | + Torch7
11 | + Matlab
12 |
13 | ## Dependencies
14 | + Cuda
15 | + Cudnn
16 | + matio
17 | + gnuplot
18 | ```
19 | luarocks install matio
20 | luarocks install gnuplot
21 | ```
22 | We use MATLAB to save and analyze some information (e.g., DDM).
23 |
24 | ## Parameter concepts
25 |
26 | + Read note.txt
27 |
28 | ## Download the UBC dataset (Brown dataset)
29 |
30 | Follow https://github.com/vbalnt/UBC-Phototour-Patches-Torch
31 | ```
32 | wget http://www.iis.ee.ic.ac.uk/~vbalnt/notredame-t7.tar.gz
33 | wget http://www.iis.ee.ic.ac.uk/~vbalnt/liberty-t7.tar.gz
34 | wget http://www.iis.ee.ic.ac.uk/~vbalnt/yosemite-t7.tar.gz
35 | ```
36 | Put the unzipped t7 files under UBCdataset folder
37 |
38 | ## Simple training command for DeepCD
39 | ```
40 | sh runAllDeepCD.sh
41 | ```
42 |
43 | ## Simple evaluation command
44 | ```
45 | sh evalDeepCD.sh
46 | ```
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Tsun-Yi Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/main/runAllDeepCD.sh:
--------------------------------------------------------------------------------

# Positional arguments to fun_DeepCD_2S.lua:
#   1: descriptor type ("rb" = real-valued lead + binary complementary,
#      "bb" = binary lead + binary complementary)
#   2: leading descriptor size (dims for "rb", bits for "bb")
#   3: complementary descriptor size in bits
#   4: w_com   5: ws_lead   6: ws_pro   (balancing weights, see note.txt)
#   7: training subset name
#   8: whether to train with the DDM layer
#   9: DDM learning rate eta (only passed when DDM is enabled)

# Training DeepCD 2-stream (lead: real, complementary: binary) with DDM
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "liberty" true 1e-4
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "notredame" true 1e-4
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "yosemite" true 1e-4

# Training DeepCD 2-stream (lead: real, complementary: binary) without DDM
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "liberty" false
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "notredame" false
th ./train/fun_DeepCD_2S.lua "rb" 128 256 1.41 1 5 "yosemite" false


# Training DeepCD 2-stream (lead: binary, complementary: binary) with DDM
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "liberty" true 5e-4
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "notredame" true 5e-4
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "yosemite" true 1e-3

# Training DeepCD 2-stream (lead: binary, complementary: binary) without DDM
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "liberty" false
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "notredame" false
th ./train/fun_DeepCD_2S.lua "bb" 512 256 1.41 1 2 "yosemite" false
--------------------------------------------------------------------------------
/main/models/model_DeepCD_2stream.lua:
--------------------------------------------------------------------------------
-- Build the DeepCD two-stream network (leading: real-valued, complementary: binary).
-- pT fields used: dim1 (leading descriptor dims), bits2 (complementary bits),
-- scal_sigmoid (sharpening factor before the sigmoid), w_com (complementary weighting).
-- Returns an nn.Sequential whose output is the table
-- {leading_descriptor (dim1), complementary_descriptor (bits2)}.
function createModel(pT)
	-- setup the CNN
	-- FIX: 'model' leaked into the global namespace; declared local
	-- (callers only use the return value).
	local model = nn.Sequential()
	local CT = nn.ConcatTable()

	-- Leading stream: conv-tanh-pool-conv-tanh, then a linear projection
	-- to a real-valued dim1 descriptor in [-1, 1].
	local lead_net = nn.Sequential()
	lead_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
	lead_net:add(cudnn.Tanh(true))
	lead_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
	lead_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
	lead_net:add(cudnn.Tanh(true))
	lead_net:add(nn.View(64*8*8))
	lead_net:add(nn.Linear(64*8*8, pT.dim1))
	lead_net:add(cudnn.Tanh(true))

	-- Complementary stream: same conv trunk, then two linear layers and a
	-- sharpened sigmoid (MulConstant(scal_sigmoid) before Sigmoid pushes the
	-- outputs toward 0/1 for later binarization); finally scaled by w_com.
	local com_net = nn.Sequential()
	com_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
	com_net:add(cudnn.Tanh(true))
	com_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
	com_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
	com_net:add(cudnn.Tanh(true))
	com_net:add(nn.View(64*8*8))
	com_net:add(nn.Linear(64*8*8, math.max(128,pT.bits2/2)))
	com_net:add(cudnn.Tanh(true))
	com_net:add(nn.Linear(math.max(128,pT.bits2/2), pT.bits2))
	com_net:add(nn.MulConstant(pT.scal_sigmoid, true))
	com_net:add(cudnn.Sigmoid())
	com_net:add(nn.MulConstant(pT.w_com,true))

	CT:add(lead_net)
	CT:add(com_net)
	model:add(CT)
	return model
end
--------------------------------------------------------------------------------
/main/note.txt:
--------------------------------------------------------------------------------
1 | Parameters Concepts:
2 |
3 | 1. w_com
4 |
5 | This term balances the domain range between the leading descriptor and the completing
6 | descriptor.
7 |
8 | 2. ws_lead = 1
9 |
10 | We want to preserve the original optimization parameters of SoftPN. (For both loss
11 | function and learning rate) Therefore this term is fixed to 1.
12 |
13 | 3. ws_pro
14 |
15 | This term enhances the importance of our optimization target, product score. If this
16 | term is 1, then the product score will be equally important as the SoftPN loss
17 | function, and the optimization of product score will fail since the loss function will
18 | tend to optimize the SoftPN loss function.
19 |
20 | For example, if the leading descriptor is real-valued 128 dim descriptor, and it is way
21 | stronger than the completing descriptor which is 256 bits binary descriptor. Then we
22 | set ws_pro a big value (at least bigger than 1), for example ws_pro=5.
23 |
24 |
25 |
26 | ------------------------------------------------------------------------------------
27 | 1. Leading: real-valued, Completing: binary
28 |
29 | 128dim1, 256bits2: w_com=1.41, ws_lead=1, ws_pro=5
30 |
31 |
32 | 2. Leading: binary, Completing: binary
33 |
34 | 512bits1, 256bits2: w_com=1.41, ws_lead=1, ws_pro=2
35 |
36 | 128bits1, 64bits2: w_com=1.41, ws_lead=1, ws_pro=2
37 |
38 | 64bits1, 32bits2: w_com=1.41, ws_lead=1, ws_pro=2
39 |
40 |
41 |
--------------------------------------------------------------------------------
/main/models/model_DeepCD_2S_binary_binary.lua:
--------------------------------------------------------------------------------
-- Build the DeepCD two-stream network (leading: binary, complementary: binary).
-- pT fields used: bits1 (leading bits), bits2 (complementary bits),
-- scal_sigmoid (sharpening factor before the sigmoid), w_com (complementary weighting).
-- Returns an nn.Sequential whose output is the table
-- {leading_descriptor (bits1), complementary_descriptor (bits2)}.
function createModel(pT)
	-- setup the CNN
	-- FIX: 'model' leaked into the global namespace; declared local
	-- (callers only use the return value).
	local model = nn.Sequential()
	local CT = nn.ConcatTable()

	-- Leading stream: conv trunk, two linear layers, then a sharpened sigmoid
	-- (MulConstant(scal_sigmoid) before Sigmoid pushes outputs toward 0/1).
	local lead_net = nn.Sequential()
	lead_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
	lead_net:add(cudnn.Tanh(true))
	lead_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
	lead_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
	lead_net:add(cudnn.Tanh(true))
	lead_net:add(nn.View(64*8*8))
	lead_net:add(nn.Linear(64*8*8, math.max(128,pT.bits1/2)))
	lead_net:add(cudnn.Tanh(true))
	lead_net:add(nn.Linear(math.max(128,pT.bits1/2), pT.bits1))
	lead_net:add(nn.MulConstant(pT.scal_sigmoid, true))
	lead_net:add(cudnn.Sigmoid())

	-- Complementary stream: identical structure, bits2 wide, additionally
	-- scaled by w_com to balance the two streams during training.
	local com_net = nn.Sequential()
	com_net:add(cudnn.SpatialConvolution(1, 32, 7, 7))
	com_net:add(cudnn.Tanh(true))
	com_net:add(cudnn.SpatialMaxPooling(2,2,2,2))
	com_net:add(cudnn.SpatialConvolution(32, 64, 6, 6))
	com_net:add(cudnn.Tanh(true))
	com_net:add(nn.View(64*8*8))
	com_net:add(nn.Linear(64*8*8, math.max(128,pT.bits2/2)))
	com_net:add(cudnn.Tanh(true))
	com_net:add(nn.Linear(math.max(128,pT.bits2/2), pT.bits2))
	com_net:add(nn.MulConstant(pT.scal_sigmoid, true))
	com_net:add(cudnn.Sigmoid())
	com_net:add(nn.MulConstant(pT.w_com,true))

	CT:add(lead_net)
	CT:add(com_net)
	model:add(CT)
	return model
end
--------------------------------------------------------------------------------
/main/evalDeepCD.sh:
--------------------------------------------------------------------------------
# Uncomment the command you want to operate.
# Example:
#----------------------------------------------------------------------------------------------------------
# th ./eval/fun_evalDeepCD_2S.lua "rb" 128 256 "liberty" "notredame" true 1
#----------------------------------------------------------------------------------------------------------
# "rb": real-valued for leading and binary for complementary
# 128 256: 128 dim for leading and 256 bits for complementary
# "liberty" "notredame": the 1st one is the training subset, the 2nd one is the evaluation subset
# true: DDM training is true. Otherwise it's false
# The last input is the epoch number of the model to load.


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: real, complementary: binary) with DDM
th ./eval/fun_evalDeepCD_2S.lua "rb" 128 256 "liberty" "notredame" true 1


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: real, complementary: binary) without DDM
#th ./eval/fun_evalDeepCD_2S.lua "rb" 128 256 "liberty" "notredame" false 1


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: binary, complementary: binary) with DDM
#th ./eval/fun_evalDeepCD_2S.lua "bb" 512 256 "liberty" "notredame" true 1


#----------------------------------------------------------------------------------------------------------
# Evaluate DeepCD 2-stream (lead: binary, complementary: binary) without DDM
#th ./eval/fun_evalDeepCD_2S.lua "bb" 512 256 "liberty" "notredame" false 1
--------------------------------------------------------------------------------
/main/train/allWeightedMSECriterion.lua:
--------------------------------------------------------------------------------
local allWeightedMSECriterion, parent = torch.class('nn.allWeightedMSECriterion','nn.MSECriterion')

-- MSE criterion where BOTH input and target are element-wise multiplied by a
-- fixed weight tensor before the standard MSE forward/backward is applied:
--   forward(a, b) == MSE(a .* w, b .* w)
-- (matching the note in DistanceRatioCriterion_allW: "sum((a*w-b*w).^2)/dim").

-- Copy 'data' into 'buf' (sized like 'input') and apply the weight.
-- When the weight has one dimension fewer than the input, the input is
-- treated as a batch and the weight is applied row-by-row; otherwise the
-- weight is applied element-wise to the whole tensor.
-- (Extracted: this logic was previously copy-pasted four times.)
local function weightedCopy(buf, input, data, weight)
	buf:resizeAs(input):copy(data)
	if input:dim() - 1 == weight:dim() then
		for i = 1, input:size(1) do
			buf[i]:cmul(weight)
		end
	else
		buf:cmul(weight)
	end
end

function allWeightedMSECriterion:__init(w)
	parent.__init(self)
	-- Clone so later mutation of the caller's tensor cannot change the loss.
	self.weight = w:clone()
end

function allWeightedMSECriterion:updateOutput(input, target)
	-- Lazily allocate scratch buffers of the same tensor type as the input.
	self.buffer1 = self.buffer1 or input.new()
	self.buffer2 = self.buffer2 or input.new()
	weightedCopy(self.buffer1, input, input, self.weight)
	weightedCopy(self.buffer2, input, target, self.weight)

	self.output_tensor = self.output_tensor or input.new(1)
	input.THNN.MSECriterion_updateOutput(
		self.buffer1:cdata(),
		self.buffer2:cdata(),
		self.output_tensor:cdata(),
		self.sizeAverage
	)
	self.output = self.output_tensor[1]
	return self.output
end

function allWeightedMSECriterion:updateGradInput(input, target)
	self.buffer1 = self.buffer1 or input.new()
	self.buffer2 = self.buffer2 or input.new()
	weightedCopy(self.buffer1, input, input, self.weight)
	weightedCopy(self.buffer2, input, target, self.weight)

	-- NOTE(review): gradient is taken w.r.t. the weighted input, i.e. it is
	-- not re-scaled by the weight afterwards — preserved as-is from the original.
	input.THNN.MSECriterion_updateGradInput(
		self.buffer1:cdata(),
		self.buffer2:cdata(),
		self.gradInput:cdata(),
		self.sizeAverage
	)
	return self.gradInput
end
--------------------------------------------------------------------------------
/main/train/DistanceRatioCriterion_allW.lua:
--------------------------------------------------------------------------------
-- Taken from Elad Hoffer's TripletNet https://github.com/eladhoffer
-- Hinge loss ranking could also be used, see below
-- https://github.com/torch/nn/blob/master/doc/criterion.md#nn.MarginRankingCriterion

local DistanceRatioCriterion_allW, parent = torch.class('nn.DistanceRatioCriterion_allW', 'nn.Criterion')

-- SoftMax-ratio criterion with a per-column weighted MSE:
-- the input distance columns are soft-maxed and regressed onto a fixed 0/1
-- target via nn.allWeightedMSECriterion(w).
function DistanceRatioCriterion_allW:__init(w)
	parent.__init(self)
	self.SoftMax = nn.SoftMax()
	self.wMSE = nn.allWeightedMSECriterion(w)
	-- wMSE:forward(a,b) equals to "sum((a*w-b*w).^2)/dim"
	self.Target = torch.Tensor()
end

-- Lazily build the fixed target matrix, sized/typed like the input:
-- column 'target' (default 1) is set to 1, and columns 3 and 4 are
-- additionally hard-coded to 1.
-- NOTE(review): the hard-coded columns 3/4 require the input to have at
-- least 4 columns, yet the debug dump below shows a 128x2 input — confirm
-- which input layout this criterion is actually used with in training.
function DistanceRatioCriterion_allW:createTarget(input, target)
	local target = target or 1
	self.Target:resizeAs(input):typeAs(input):zero()
	self.Target[{{},target}]:add(1)
	--self.Target[{{},target+1}]:add(1)

	self.Target[{{},3}]:add(1)
	self.Target[{{},4}]:add(1)
	--self.Target[{{},5}]:add(1)
	--self.Target[{{},7}]:add(1)
	--print(input:size())
	--print(self.Target:size())
	--os.exit()
	--[[
	print(input)
	...
	 0.2617  0.5276
	 0.1764  0.3031
	 0.4771  0.3169
	 0.3398  0.0905
	 0.3689  0.1940
	--[torch.CudaTensor of size 128x2]

	print(self.Target)
	...
	 1  0
	 1  0
	 1  0
	 1  0
	 1  0
	[torch.CudaTensor of size 128x2]

	--The first column contains negative distance
	--while the second one is the positive distance
	os.exit()
	--]]
end

-- Loss = wMSE(SoftMax(input), Target). The target is (re)built whenever the
-- batch shape changes.
function DistanceRatioCriterion_allW:updateOutput(input, target)
	if not self.Target:isSameSizeAs(input) then
		self:createTarget(input, target)
	end
	self.output = self.wMSE:updateOutput(self.SoftMax:updateOutput(input),self.Target)
	return self.output
end

-- Chain rule: backprop wMSE's gradient through the SoftMax.
function DistanceRatioCriterion_allW:updateGradInput(input, target)
	if not self.Target:isSameSizeAs(input) then
		self:createTarget(input, target)
	end

	self.gradInput = self.SoftMax:updateGradInput(input, self.wMSE:updateGradInput(self.SoftMax.output,self.Target))
	return self.gradInput
end

-- Propagate type conversion (e.g. :cuda()) to the wrapped modules/tensors.
function DistanceRatioCriterion_allW:type(t)
	parent.type(self, t)
	self.SoftMax:type(t)
	self.wMSE:type(t)
	self.Target = self.Target:type(t)
	return self
end
--------------------------------------------------------------------------------
/main/utils.lua:
--------------------------------------------------------------------------------
-- From pnnet
-- Load a Brown/UBC patch dataset stored as <name>.t7 ("this one is modified"),
-- convert the patches to float, and shift the zero-indexed labels to 1-based.
-- example here: http://vbalnt.io/notredame-torch.tar.gz
function read_brown_data(name)
	local dataset = torch.load(name..'.t7')
	dataset.patches32 = dataset.patches32:float()
	-- labels in data are zero-indexed
	dataset.labels:add(1)
	return dataset
end
11 |
-- Compute normalization statistics over all patches.
-- Returns {mi = mean, sigma = std} of d.patches32.
function get_stats(d)
	return {
		mi = d.patches32:mean(),
		sigma = d.patches32:std(),
	}
end
21 |
-- Normalize the patches in place using precomputed stats (from get_stats):
-- subtract the mean and divide by the std. Stats are typically computed on
-- the TRAINING subset and applied to whichever subset is being processed.
function norm_data(d,stats)
	d.patches32:add(-stats.mi):div(stats.sigma)
end
26 |
-- following functions taken from Elad Hoffer's
-- TripletNet https://github.com/eladhoffer

-- Group patch indices by label: returns a table mapping
-- label -> array of indices into traind whose label matches.
-- (Callers use #result as the number of classes.)
-- FIX: removed the unused local 'numClasses' (a wasted labels:max() scan).
function ArrangeByLabel(traind)
	local Ordered = {}
	for i=1,traind.labels:size(1) do
		if Ordered[traind.labels[i]] == nil then
			Ordered[traind.labels[i]] = {}
		end
		table.insert(Ordered[traind.labels[i]], i)
	end
	return Ordered
end
41 |
42 |
-- Generate num_pairs labelled training pairs.
-- Each row of the returned IntTensor(num_pairs, 3) is {idx_a, idx_b, label}:
-- label 1 for a matching pair (two distinct patches with the same label,
-- chosen with probability 1/3) and -1 for a non-matching pair (patches from
-- two different labels, probability 2/3).
function generate_pairs(traind, num_pairs)
	local list = torch.IntTensor(num_pairs,3)
	local pairs = torch.IntTensor(num_pairs,3)

	local Ordered = ArrangeByLabel(traind)
	local nClasses = #Ordered
	for i=1, num_pairs do
		-- c1: anchor class; c2: a different class.
		local c1 = math.random(nClasses)
		local c2 = math.random(nClasses)
		while c2 == c1 do
			c2 = math.random(nClasses)
		end
		-- n1/n3: two distinct patches of c1; n2: a patch of c2.
		local n1 = math.random(#Ordered[c1])
		local n2 = math.random(#Ordered[c2])
		local n3 = math.random(#Ordered[c1])
		while n3 == n1 do
			n3 = math.random(#Ordered[c1])
		end

		list[i][1] = Ordered[c1][n1]
		list[i][2] = Ordered[c2][n2]
		list[i][3] = Ordered[c1][n3]

		-- FIX: 'lbl' was an accidental global (missing 'local').
		local lbl = math.random(0,2)
		if ((lbl==0) or (lbl==1)) then
			-- non-matching pair
			pairs[i][1] = list[i][1]
			pairs[i][2] = list[i][2]
			pairs[i][3] = -1
		else
			-- matching pair
			pairs[i][1] = list[i][1]
			pairs[i][2] = list[i][3]
			pairs[i][3] = 1
		end
	end

	return pairs
end
82 |
83 |
84 |
-- Build num_pairs random triplets of patch indices.
-- Row layout: {patch of class c1, patch of a different class c2,
-- another (distinct) patch of c1}.
function generate_triplets(traind, num_pairs)
	local triplets = torch.IntTensor(num_pairs,3)

	local byLabel = ArrangeByLabel(traind)
	local nClasses = #byLabel
	for row=1, num_pairs do
		local sameClass = math.random(nClasses)
		-- Draw a second, different class.
		local otherClass = math.random(nClasses)
		while otherClass == sameClass do
			otherClass = math.random(nClasses)
		end
		local anchorPick = math.random(#byLabel[sameClass])
		local otherPick = math.random(#byLabel[otherClass])
		-- Draw a second, distinct patch of the same class.
		local matchPick = math.random(#byLabel[sameClass])
		while matchPick == anchorPick do
			matchPick = math.random(#byLabel[sameClass])
		end
		triplets[row][1] = byLabel[sameClass][anchorPick]
		triplets[row][2] = byLabel[otherClass][otherPick]
		triplets[row][3] = byLabel[sameClass][matchPick]
	end

	return triplets
end
110 |
111 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepCD
2 | Code Author: Tsun-Yi Yang
3 |
4 | Last update: 2017/08/17 (Training and testing codes are both uploaded.)
5 |
6 | Platform: Ubuntu 14.04, Torch7
7 |
8 | Paper
9 | -
10 | **[ICCV17] DeepCD: Learning Deep Complementary Descriptors for Patch Representations**
11 |
12 | **Authors: [Tsun-Yi Yang](http://shamangary.logdown.com/), Jo-Han Hsu, [Yen-Yu Lin](https://www.citi.sinica.edu.tw/pages/yylin/index_zh.html), and [Yung-Yu Chuang](https://www.csie.ntu.edu.tw/~cyy/)**
13 |
14 | **PDF:**
15 | + Link1: http://www.csie.ntu.edu.tw/~cyy/publications/papers/Yang2017DLD.pdf
16 | + Link2: https://github.com/shamangary/DeepCD/blob/master/1557.pdf
17 |
18 | Code abstract
19 | -
20 | This is the source code of DeepCD. The training is done on Brown dataset.
21 |
22 | Two distinct descriptors are learned for the same network.
23 |
24 | Product late fusion in distance domain is performed before the final ranking.
25 |
26 | DeepCD project is heavily inspired by pnnet https://github.com/vbalnt/pnnet
27 |
28 | ***This repository:*** (author: Tsun-Yi Yang)
29 | + Brown dataset (Training and testing) https://github.com/shamangary/DeepCD/tree/master/main
30 |
31 | ***Related repositories:*** (author: Jo-Han Hsu)
32 | + MVS dataset (testing) https://github.com/Rohan8288/DeepCD_MVS
33 | + Oxford dataset (testing) https://github.com/Rohan8288/DeepCD_Oxford
34 |
35 | Model
36 | -
37 |
38 |
39 | Training with Data-Dependent Modulation (DDM) layer
40 | -
41 | + The DDM layer dynamically adapts the learning rate of the complementary stream.
42 | 
43 | + It considers information of the whole batch by taking both leading and complementary distances into account.
44 |
45 | The backward gradient value is scaled by a factor η (1e-3~1e-4). This step not only lets us slow down the learning of the fully connected layer inside the DDM layer, but also lets us approximately ignore the effect of the DDM layer on the forward propagation of the complementary stream, making it an identity operation. The update equation is basically the backward equation derived from multiplying a parameter w from the previous layer.
46 |
47 | 
48 |
49 | ```
50 | a_DDM = nn.Identity()
51 | output_layer_DDM = nn.Linear(pT.batch_size*2,pT.batch_size)
52 | output_layer_DDM.weight:fill(0)
53 | output_layer_DDM.bias:fill(1)
54 | b_DDM = nn.Sequential():add(nn.Reshape(pT.batch_size*2,false)):add(output_layer_DDM):add(nn.Sigmoid())
55 | DDM_ct1 = nn.ConcatTable():add(a_DDM:clone()):add(b_DDM:clone())
56 | DDM_layer = nn.Sequential():add(DDM_ct1):add(nn.DataDependentModule(pT.DDM_LR))
57 | ```
58 | Testing stage
59 | -
60 | + A **hard threshold** will be applied on the complementary descriptor before the Hamming distance calculation.
61 |
62 | + **DDM layer is not involved in the testing stage** since we only need the trained model from the triplet structure.
63 |
64 | + **Product late fusion at distance domain** is computed before the final ranking.
65 |
66 | Brown dataset results
67 | -
68 |
69 |
70 |
--------------------------------------------------------------------------------
/main/eval/fun_evalDeepCD_2S.lua:
--------------------------------------------------------------------------------
require 'cutorch'
require 'xlua'
require 'trepl'
require 'cunn'
require 'cudnn'
require 'image'
require 'nn'
require 'torch'
require 'gnuplot'
require '../utils.lua'

-- Command-line arguments:
--   arg[1] des_type : 'rb' (real-valued leading + binary complementary)
--                     or 'bb' (binary leading + binary complementary)
--   arg[2], arg[3]  : leading / complementary descriptor size
--   arg[4] network  : subset the model was trained on
--   arg[5] eval_data: subset to evaluate on
--   arg[6]          : 'true' if the model was trained with the DDM layer
--   arg[7]          : epoch of the checkpoint to load (kept as a string,
--                     only used to build the checkpoint filename)
des_type = arg[1]
if des_type == 'rb' then
	dim1 = tonumber(arg[2])
	bits2 = tonumber(arg[3])
elseif des_type == 'bb' then
	bits1 = tonumber(arg[2])
	bits2 = tonumber(arg[3])
end
network = arg[4]
eval_data = arg[5]
if arg[6] == 'true' then
	isDDM = true
elseif arg[6] == 'false' then
	isDDM = false
end
checkEpoch = arg[7]

-- load default 128 out tanh-maxpooling network trained on liberty dataset
-- for more details http://phototour.cs.washington.edu/patches/default.htm

-- DDM-trained checkpoints live under names tagged with '_DDM'.
if isDDM then
	addName = '_DDM'
else
	addName = ''
end
37 |
-- Load the trained network checkpoint for the requested configuration/epoch.
if des_type == 'rb' then
	net = torch.load('./train_epoch/DeepCD_2S'..addName..'_'..dim1..'dim1_'..bits2..'bits2_'..network..'/NET_DeepCD_2S'..addName..'_'..network..'_epoch'..checkEpoch..'.t7')
elseif des_type == 'bb' then
	net = torch.load('./train_epoch/DeepCD_2S'..addName..'_'..bits1..'bits1_'..bits2..'bits2_'..network..'/NET_DeepCD_2S'..addName..'_'..network..'_epoch'..checkEpoch..'.t7')
end
net:evaluate()

-- test on the testing gt (100k pairs from Brown's dataset)
ntest = 100000
-- R row 1: ground-truth match label per pair; row 2: running PRODUCT of the
-- two stream distances (starts at 1, multiplied once per stream below).
R = torch.ones(2,ntest)
print(net)
-- Remove module 12 of the complementary stream — the trailing
-- nn.MulConstant(w_com) training-time scaling (see the model files) — so the
-- stream outputs the raw sigmoid response for thresholding.
net:get(1):get(2):remove(12)

-- Normalization statistics come from the TRAINING subset and are applied to
-- the evaluation subset.
trained = torch.load('./UBCdataset/'..network..'.t7')
dataset = torch.load('./UBCdataset/'..eval_data..'.t7')
stats = get_stats(trained)
print(stats)
norm_data(dataset,stats)
npatches = (dataset.patches32:size(1))
print(npatches)

-- normalize data
patches32 = dataset.patches32

-- split the patches in batches to avoid memory problems
BatchSize = 128
64 |
65 |
66 |
-- Two passes over the data: iter 1 scores the leading stream, iter 2 the
-- complementary stream; each pass multiplies its distance into R row 2
-- (product late fusion in the distance domain).
for iter=1, 2 do

	if iter == 1 then
		if des_type == 'rb' then
			dim = dim1
			isBinary = false
		elseif des_type == 'bb' then
			dim = bits1
			isBinary = true
		end
		-- Select the leading output from the {lead, com} table.
		net:add(nn.SelectTable(1))

	elseif iter == 2 then
		dim = bits2
		isBinary = true
		-- Swap the SelectTable(1) appended above for the complementary stream.
		net:remove(2)
		net:add(nn.SelectTable(2))

	end

	-- Compute descriptors for all patches, batched to limit GPU memory.
	local Descrs = torch.CudaTensor(npatches,dim)
	local DescrsSplit = Descrs:split(BatchSize)
	for i,v in ipairs(patches32:split(BatchSize)) do
		temp = v:clone():cuda()
		DescrsSplit[i]:copy(net:forward(temp))
		--print(net:get(1):get(2):get(1):get(1):get(2):get(11).output)
		--os.exit()
	end

	-- Score the 100k ground-truth pairs.
	for j=1,ntest do
		l = dataset.gt100k[j]
		-- Pair is a match when fields 2 and 5 agree (presumably the 3D point
		-- ids of the Brown gt format — TODO confirm).
		lbl = l[2]==l[5] and 1 or 0
		-- Patch ids in the gt file are zero-indexed.
		id1 = l[1]+1
		id2 = l[4]+1
		dl = Descrs[{ {id1},{} }]
		dr = Descrs[{ {id2},{} }]

		if isBinary then
			-- Hard threshold at 0.5; L2 distance between the resulting 0/1
			-- vectors equals the square root of the Hamming distance.
			dl = dl:gt(0.5):float()
			dr = dr:gt(0.5):float()
		end
		d = torch.dist(dl,dr)

		if iter == 1 then
			R[{{1},{j}}] = lbl
			R[{{2},{j}}] = R[{{2},{j}}]*d
		elseif iter ==2 then
			R[{{2},{j}}] = R[{{2},{j}}]*d
		end

		--io.write(string.format("%d %.4f \n", lbl,d))
	end
end
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
--FPR95(FPR at TPR or recall 95%) and ROC curve
-- Rank all pairs by fused distance, ascending (best matches first), and
-- reorder the ground-truth labels accordingly.
val_sorted, temp_id_sorted = torch.sort(R[{{2},{}}])
id_sorted = temp_id_sorted[{1,{}}]
pn_sorted = R[{{1},{}}]:index(2,id_sorted)

pos_all = torch.sum(R[{{1},{}}])
print('pos_all:'..pos_all)
neg_all = ntest-pos_all
-- Number of true positives corresponding to 95% recall.
pos_95 = torch.floor(pos_all*0.95)

-- Cumulative true-positive count along the ranking.
pos_acc = pn_sorted:clone() --Don't forget "clone()"

for j=2,ntest do
	pos_acc[{1,j}] = pos_acc[{1,j-1}] + pos_acc[{1,j}]
end

print('pos_95:'..pos_95)
-- First rank position at which 95% of the positives have been retrieved.
id_95 = pos_acc[{1,{}}]:eq(pos_95):nonzero():min()
print(id_95)

-- TPR/FPR at every cut-off of the ranking.
tpr = torch.Tensor(1,ntest)
fpr = torch.Tensor(1,ntest)
for k=1,ntest do
	tpr[{1,k}] = pos_acc[{1,k}]/pos_all
	-- (k - pos_acc[k]) is the false-positive count among the top k.
	fpr[{1,k}] = torch.abs(pos_acc[{1,k}]-k)/neg_all
end

Result = torch.cat(fpr,tpr,1)
gnuplot.plot(Result:t())
gnuplot.title('ROC curve')
gnuplot.xlabel('FPR')
gnuplot.ylabel('TPR')
FPR95 = fpr[{1,id_95}]
print('FPR95%:'.. FPR95*100)
183 |
184 |
185 |
186 |
187 |
--------------------------------------------------------------------------------
/main/train/fun_DeepCD_2S.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require '../utils.lua'
3 | require 'image'
4 | require 'optim'
5 | require './DistanceRatioCriterion_allW.lua'
6 | require 'cudnn'
7 | require 'cutorch'
8 | require 'cunn'
9 | require './allWeightedMSECriterion.lua'
10 | require './DataDependentModule.lua'
11 |
12 | local matio = require 'matio'
13 |
14 | des_type = arg[1]
15 | if des_type == 'rb' then
16 | dim1 = tonumber(arg[2])
17 | bits2 = tonumber(arg[3])
18 | elseif des_type == 'bb' then
19 | bits1 = tonumber(arg[2])
20 | bits2 = tonumber(arg[3])
21 | end
22 | w_com = tonumber(arg[4])
23 | ws_lead = tonumber(arg[5])
24 | ws_pro = tonumber(arg[6])
25 | name = arg[7]
26 | if arg[8] == 'true' then
27 | isDDM = true
28 | elseif arg[8] == 'false' then
29 | isDDM = false
30 | end
31 | DDM_LR = tonumber(arg[9])
32 |
33 | pT = {
34 |
35 | --dim1 = 128,
36 | --bits2 = 256,
37 | --w_com= 1.41,
38 | --ws_lead = 1,
39 | --ws_pro = 5,
40 | scal_sigmoid = 100,
41 | isNorm = true,
42 | --isSTN = false,
43 | --name = 'liberty',
44 | batch_size = 128,
45 | num_triplets = 1280000,
46 | max_epoch = 100
47 |
48 | }
49 |
50 |
51 | if des_type == 'rb' then
52 | pT.dim1 = dim1
53 | pT.bits2 = bits2
54 | elseif des_type == 'bb' then
55 | pT.bits1 = bits1
56 | pT.bits2 = bits2
57 | end
58 | pT.w_com = w_com
59 | pT.ws_lead = ws_lead
60 | pT.ws_pro = ws_pro
61 | pT.name = name
62 | pT.isDDM = isDDM
63 | pT.DDM_LR = DDM_LR
64 |
65 | -- optim parameters
66 | optimState = {
67 | learningRate = 0.1,
68 | weightDecay = 1e-4,
69 | momentum = 0.9,
70 | learningRateDecay = 1e-6
71 | }
72 | pT.optimState = optimState
73 |
-- Each run gets its own checkpoint directory whose name encodes the
-- architecture variant, stream sizes, DDM usage and dataset.
addName = pT.isDDM and 'DDM_' or ''

if des_type == 'rb' then
	folder_name = './train_epoch/DeepCD_2S_'..addName..pT.dim1..'dim1_'..pT.bits2..'bits2_'..pT.name
elseif des_type == 'bb' then
	folder_name = './train_epoch/DeepCD_2S_'..addName..pT.bits1..'bits1_'..pT.bits2..'bits2_'..pT.name
end

print(pT)
os.execute("mkdir -p " .. folder_name)
if pT.isDDM then
	-- Per-epoch DDM weight vectors are dumped here as .mat files.
	os.execute("mkdir -p " .. folder_name .. "/DDM_vec/")
end
-- Persist the full parameter table next to the checkpoints.
torch.save(folder_name..'/ParaTable_DeepCD_2S_'..addName..pT.name..'.t7', pT)
92 |
93 |
94 | ------------------------------------------------------------------------------------------------
95 |
96 |
-- number of threads
-- Cap Torch's CPU thread pool for the data-side tensor work.
torch.setnumthreads(13)

-- read training data, save mu and sigma & normalize

-- read_brown_data / get_stats / norm_data / generate_triplets are helpers
-- from ../utils.lua (required above).
traind = read_brown_data('./UBCdataset/'..pT.name)
stats = get_stats(traind)
print(stats)
if pT.isNorm then
	-- Normalize the patches in place using the dataset statistics.
	norm_data(traind,stats)
end
print("==> read the dataset")

-- generate random triplets for training data

-- Each triplet holds three patch indices; per the distance branches below,
-- patches 1 and 3 form the matching pair and patch 2 is the non-match.
training_triplets = generate_triplets(traind, pT.num_triplets)
print("==> created the tests")
114 | ------------------------------------------------------------------------------------------------
-- Load the model factory matching the descriptor type; each script defines
-- the global createModel(pT).
if des_type == 'rb' then
	paths.dofile('../models/model_DeepCD_2stream.lua')
elseif des_type == 'bb' then
	paths.dofile('../models/model_DeepCD_2S_binary_binary.lua')
end
model1 = createModel(pT)
model1:training()

-- clone the other two networks in the triplet
-- Sharing weight/bias AND gradWeight/gradBias means the three towers use a
-- single parameter set (true siamese/triplet weight sharing).
model2 = model1:clone('weight', 'bias','gradWeight','gradBias')
model3 = model1:clone('weight', 'bias','gradWeight','gradBias')

-- add them to a parallel table
-- prl maps the input {x, y, z} to the three towers' descriptor outputs,
-- each itself a table of {leading, completing} streams.
prl = nn.ParallelTable()
prl:add(model1)
prl:add(model2)
prl:add(model3)
prl:cuda()
133 | ------------------------------------------------------------------------------------------------
mlp = nn.Sequential()
mlp:add(prl)

-- Distance stage: six parallel branches, each computing the batched L2
-- distance between one stream of two towers as a (batch_size x 1) column.
cc = nn.ConcatTable()

-- Builds one distance branch. left_idx/right_idx pick two of the three
-- tower outputs; stream_idx picks the stream (1 = leading, 2 = completing).
local function dist_branch(left_idx, right_idx, stream_idx)
	local pick_left = nn.Sequential()
		:add(nn.SelectTable(left_idx))
		:add(nn.SelectTable(stream_idx))
	local pick_right = nn.Sequential()
		:add(nn.SelectTable(right_idx))
		:add(nn.SelectTable(stream_idx))
	local branch = nn.Sequential()
		:add(nn.ConcatTable():add(pick_left):add(pick_right))
		:add(nn.PairwiseDistance(2))
		:add(nn.View(pT.batch_size, 1))
	branch:cuda()
	return branch
end

-- Output order consumed downstream:
-- 1: towers 1-2, leading       2: towers 1-2, completing
-- 3: towers 2-3, leading       4: towers 2-3, completing
-- 5: towers 1-3, leading       6: towers 1-3, completing
cc:add(dist_branch(1, 2, 1))
cc:add(dist_branch(1, 2, 2))
cc:add(dist_branch(2, 3, 1))
cc:add(dist_branch(2, 3, 2))
cc:add(dist_branch(1, 3, 1))
cc:add(dist_branch(1, 3, 2))

cc:cuda()

mlp:add(cc)
251 | ------------------------------------------------------------------------------------------------
last_layer = nn.ConcatTable()

-- Branch 1: hard-negative mining on the leading stream -- take the
-- elementwise minimum of the two leading negative distances
-- (distance-stage outputs 1 and 3).
mining_layer = nn.ConcatTable()
	:add(nn.SelectTable(1))
	:add(nn.SelectTable(3))
mined_neg = nn.Sequential()
	:add(mining_layer)
	:add(nn.JoinTable(2))
	:add(nn.Min(2))
	:add(nn.View(pT.batch_size, 1))
last_layer:add(mined_neg)

-- Branch 2: the leading positive distance (output 5), passed through.
pos_layer = nn.Sequential()
	:add(nn.SelectTable(5))
	:add(nn.View(pT.batch_size, 1))
last_layer:add(pos_layer)
271 |
272 |
273 | ------------------------------------------------------------------------------------------------
-- Data-Dependent Modulation (DDM) sub-network.
-- a_DDM forwards the joined (batch_size x 2) distance pair unchanged;
-- b_DDM flattens it to a 2*batch_size vector and maps it to batch_size
-- per-sample weights in (0,1) via Linear + Sigmoid.
a_DDM = nn.Identity()
output_layer_DDM = nn.Linear(pT.batch_size*2,pT.batch_size)
-- Initialized as a constant function: zero weights and bias 1 make every
-- output sigmoid(1), i.e. a uniform weighting at the start of training.
output_layer_DDM.weight:fill(0)
output_layer_DDM.bias:fill(1)
b_DDM = nn.Sequential():add(nn.Reshape(pT.batch_size*2,false)):add(output_layer_DDM):add(nn.Sigmoid())
-- NOTE(review): a_DDM/b_DDM are cloned (deep copies), so the originals are
-- unused afterwards. How the identity path and the learned weights are
-- combined with learning rate pT.DDM_LR is defined in DataDependentModule.lua.
DDM_ct = nn.ConcatTable():add(a_DDM:clone()):add(b_DDM:clone())
DDM_layer = nn.Sequential():add(DDM_ct):add(nn.DataDependentModule(pT.DDM_LR))
281 | ------------------------------------------------------------------------------------------------
282 |
283 |
--add neg1 (real,binary) distance
-- Product distance for the 1-2 pair: geometric mean sqrt(d_lead * d_com)
-- of distance-stage outputs 1 and 2, optionally modulated by the DDM.
neg1_RB_layer = nn.Sequential()
temp_neg1_RB = nn.ConcatTable()

S1_neg1 = nn.Sequential()
S1_neg1:add(nn.SelectTable(1))

S2_neg1 = nn.Sequential()
S2_neg1:add(nn.SelectTable(2))

temp_neg1_RB:add(S1_neg1)
temp_neg1_RB:add(S2_neg1)
neg1_RB_layer:add(temp_neg1_RB)
if pT.isDDM then
	-- Join the two columns, apply the DDM, then split back into a table
	-- so CMulTable can multiply them elementwise.
	neg1_RB_layer:add(nn.JoinTable(2))
	neg1_RB_layer:add(DDM_layer)
	neg1_RB_layer:add(nn.SplitTable(2))
end
neg1_RB_layer:add(nn.CMulTable())
neg1_RB_layer:add(nn.Sqrt())
neg1_RB_layer:add(nn.View(pT.batch_size ,1))
neg1_RB_layer:cuda()
last_layer:add(neg1_RB_layer)

--add neg2 (real,binary) distance
-- Same product distance for the 2-3 pair (outputs 3 and 4).
-- NOTE(review): this inserts the SAME DDM_layer instance used in the neg1
-- branch above, not a clone. In nn containers that means both branches run
-- forward/backward through one module object; presumably intended as weight
-- sharing -- verify against DataDependentModule.lua, since a shared instance
-- also shares its output/gradInput buffers.
neg2_RB_layer = nn.Sequential()
temp_neg2_RB = nn.ConcatTable()

S1_neg2 = nn.Sequential()
S1_neg2:add(nn.SelectTable(3))

S2_neg2 = nn.Sequential()
S2_neg2:add(nn.SelectTable(4))

temp_neg2_RB:add(S1_neg2)
temp_neg2_RB:add(S2_neg2 )
neg2_RB_layer:add(temp_neg2_RB)
if pT.isDDM then
	neg2_RB_layer:add(nn.JoinTable(2))
	neg2_RB_layer:add(DDM_layer)
	neg2_RB_layer:add(nn.SplitTable(2))
end
neg2_RB_layer:add(nn.CMulTable())
neg2_RB_layer:add(nn.Sqrt())
neg2_RB_layer:add(nn.View(pT.batch_size ,1))
neg2_RB_layer:cuda()
last_layer:add(neg2_RB_layer)


--add pos (real,binary) distance
-- Product distance for the positive 1-3 pair (outputs 5 and 6).
-- The positive branch never goes through the DDM.
pos_RB_layer = nn.Sequential()
temp_pos_RB = nn.ConcatTable()

S1_pos = nn.Sequential()
S1_pos:add(nn.SelectTable(5))
S2_pos = nn.Sequential()
S2_pos:add(nn.SelectTable(6))

temp_pos_RB:add(S1_pos)
temp_pos_RB:add(S2_pos)
pos_RB_layer:add(temp_pos_RB)
pos_RB_layer:add(nn.CMulTable())
pos_RB_layer:add(nn.Sqrt())

pos_RB_layer:add(nn.View(pT.batch_size ,1))
pos_RB_layer:cuda()
last_layer:add(pos_RB_layer)
351 | ------------------------------------------------------------------------------------------------
mlp:add(last_layer)

-- Join the five (batch_size x 1) branch outputs into one (batch_size x 5)
-- matrix: [mined leading negative, leading positive, neg1 product,
-- neg2 product, positive product].
mlp:add(nn.JoinTable(2))
mlp:cuda()
------------------------------------------------------------------------------------------------
-- setup the criterion: ratio of min-negative to positive
epoch = 1


-- Preallocated GPU input batches for the three triplet patches
-- (1 channel, 32x32); refilled in place every mini-batch.
x=torch.zeros(pT.batch_size,1,32,32):cuda()
y=torch.zeros(pT.batch_size,1,32,32):cuda()
z=torch.zeros(pT.batch_size,1,32,32):cuda()


-- Flat views over all learnable parameters and their gradients, as
-- required by optim.sgd.
parameters, gradParameters = mlp:getParameters()
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 | -- main training loop
377 |
-- Per-epoch average loss, saved to disk after training.
Loss = torch.zeros(1,pT.max_epoch)

-- Per-column criterion weights over the 5 joined network outputs:
-- columns 1-2 (mined leading negative, leading positive) use ws_lead,
-- columns 3-5 (the three product distances) use ws_pro.
-- Fix: size the weight matrix with pT.batch_size instead of the hard-coded
-- 128 so it stays consistent if the batch size in pT is changed.
w = torch.zeros(pT.batch_size,5)
w[{{},{1,2}}]:fill(pT.ws_lead)
w[{{},{3,5}}]:fill(pT.ws_pro)
crit=nn.DistanceRatioCriterion_allW(w):cuda()
386 |
-- Main training loop: SGD over shuffled mini-batches of triplets.
for epoch=epoch, pT.max_epoch do

	Gerr = 0
	shuffle = torch.randperm(pT.num_triplets)
	nbatches = pT.num_triplets/pT.batch_size

	-- Fix: iterate all nbatches batches with 1-based slicing. The original
	-- loop (k = 1 .. nbatches-1, slice k*bs .. k*bs+bs) skipped the first
	-- batch_size-1 shuffled triplets every epoch, fetched one unused extra
	-- id per batch, and then averaged over nbatches anyway.
	for k=1,nbatches do
		xlua.progress(k, nbatches)

		-- Ids of the k-th batch of triplets.
		s = shuffle[{ {(k-1)*pT.batch_size+1, k*pT.batch_size} }]
		-- Fill the preallocated GPU buffers with the triplet patches
		-- (1 and 3 are the matching pair, 2 the non-match).
		for i=1,pT.batch_size do
			x[i] = traind.patches32[training_triplets[s[i]][1]]
			y[i] = traind.patches32[training_triplets[s[i]][2]]
			z[i] = traind.patches32[training_triplets[s[i]][3]]
		end

		-- Closure for optim.sgd: returns the batch loss and dLoss/dParams.
		-- (Renamed the parameter from the original 'f', which was shadowed
		-- by the local loss value below.)
		local feval = function(new_params)
			if new_params ~= parameters then parameters:copy(new_params) end
			gradParameters:zero()
			inputs = {x,y,z}
			local outputs = mlp:forward(inputs)

			local batch_loss = crit:forward(outputs, 1)
			Gerr = Gerr + batch_loss
			-- Pass the same target as forward (the custom criterion may
			-- ignore it, but this matches the nn.Criterion contract).
			local df_do = crit:backward(outputs, 1)
			mlp:backward(inputs, df_do)
			return batch_loss, gradParameters
		end
		optim.sgd(feval, parameters, optimState)

	end
	loss = Gerr/nbatches
	Loss[{{1},{epoch}}] = loss
	if pT.isDDM then
		-- DDM_vec is presumably published as a global by
		-- nn.DataDependentModule -- dump it per epoch for offline analysis.
		print(DDM_vec)
		matio.save(folder_name..'/DDM_vec/DDM_vec_epoch'..epoch..'.mat',{DDM_vec=DDM_vec:double()})
	end

	print('==> epoch '..epoch)
	print(loss)
	print('')

	-- Snapshot the first tower only; weights are shared across all three.
	net_save = mlp:get(1):get(1):clone()
	torch.save(folder_name..'/NET_DeepCD_2S_'..addName..pT.name..'_epoch'..epoch..'.t7',net_save:clearState())

end

torch.save(folder_name..'/Loss_DeepCD_2S_'..addName..pT.name..'.t7',Loss)
440 |
441 |
442 |
443 |
--------------------------------------------------------------------------------