├── README.md
├── addpyramidnet.lua
└── mulpyramidnet.lua
/README.md:
--------------------------------------------------------------------------------
1 | # PyramidNet
2 | This repository contains the code for the paper:
3 |
4 | Dongyoon Han*, Jiwhan Kim*, and Junmo Kim, "Deep Pyramidal Residual Networks", CVPR 2017 (* equal contribution).
5 |
6 | Arxiv: https://arxiv.org/abs/1610.02915.
7 |
8 | The code is based on Facebook's implementation of ResNet (https://github.com/facebook/fb.resnet.torch).
9 |
10 | ### Caffe implementation of PyramidNet: [site](https://github.com/jhkim89/PyramidNet-caffe)
11 | ### PyTorch implementation of PyramidNet: [site](https://github.com/dyhan0920/PyramidNet-PyTorch)
12 |
13 | ## Abstract
14 | Deep convolutional neural networks (DCNNs) have shown remarkable performance in image classification tasks in recent years. Generally, deep neural network architectures are stacks consisting of a large number of convolution layers, and they perform downsampling along the spatial dimension via pooling to reduce memory usage. At the same time, the feature map dimension (i.e., the number of channels) is sharply increased at downsampling locations, which is essential to ensure effective performance because it increases the capability of high-level attributes. Moreover, this also applies to residual networks and is very closely related to their performance. In this research, instead of using downsampling to achieve a sharp increase at each residual unit, we gradually increase the feature map dimension at all the units to involve as many locations as possible. This is discussed in depth together with our new insights as it has proven to be an effective design to improve the generalization ability. Furthermore, we propose a novel residual unit capable of further improving the classification accuracy with our new network architecture. Experiments on benchmark CIFAR datasets have shown that our network architecture has a superior generalization ability compared to the original residual networks.
15 |
16 |
17 |
18 | Figure 1: Schematic illustration of (a) basic residual units, (b) bottleneck, (c) wide residual units, and (d) our pyramidal residual units.
19 |
20 | 
21 |
22 | Figure 2: Visual illustrations of (a) additive PyramidNet, (b) multiplicative PyramidNet, and (c) comparison of (a) and (b).
23 |
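The additive network adds a fixed number of channels, alpha/N, at every residual unit, while the multiplicative network multiplies the width by a fixed factor, alpha^(1/N), where N is the total number of units. A minimal Lua sketch of the two schedules (helper names are ours, not part of the repository):

```lua
-- Per-unit channel widths for the two pyramidal schedules.
-- start: initial width (16 on CIFAR, 64 on ImageNet); N: total number of residual units.
local function additiveWidths(start, alpha, N)
  local widths, c = {}, start
  for k = 1, N do
    c = c + alpha / N                 -- constant increment per unit
    widths[k] = math.floor(c + 0.5)   -- round to the nearest integer, as in the model files
  end
  return widths                       -- final width is about start + alpha
end

local function multiplicativeWidths(start, alpha, N)
  local widths, c = {}, start
  for k = 1, N do
    c = c * alpha ^ (1 / N)           -- constant growth factor per unit
    widths[k] = math.floor(c + 0.5)
  end
  return widths                       -- final width is about start * alpha
end
```

Starting from 16 channels on CIFAR, the additive schedule with alpha=48 ends near 64 channels and the multiplicative schedule with alpha=4.75 ends near 76, which matches the output feature dimensions in the first results table below (the bottleneck table reports 4x these values because of the channel expansion inside the bottleneck unit).
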
24 | ## Usage
25 |
26 | 1. Install Torch (http://torch.ch) and fb.resnet.torch (https://github.com/facebook/fb.resnet.torch).
27 | 2. Add the files addpyramidnet.lua and mulpyramidnet.lua to the folder "models".
28 | 3. Manually set the parameter "alpha" in the files addpyramidnet.lua and mulpyramidnet.lua (Line 28).
29 | 4. Change the learning rate schedule in train.lua from "decay = epoch >= 122 and 2 or epoch >= 81 and 1 or 0" to "decay = epoch >= 225 and 2 or epoch >= 150 and 1 or 0".
30 | 5. Train our PyramidNet by running main.lua as shown below:
31 |
32 | To train additive PyramidNet-164 (alpha=48) on the CIFAR-10 dataset:
33 | ```bash
34 | th main.lua -dataset cifar10 -depth 164 -nEpochs 300 -LR 0.1 -netType addpyramidnet -batchSize 128 -shareGradInput true
35 | ```
36 | To train additive PyramidNet-164 (alpha=48) with 4 GPUs on the CIFAR-100 dataset:
37 | ```bash
38 | th main.lua -dataset cifar100 -depth 164 -nEpochs 300 -LR 0.5 -nGPU 4 -nThreads 8 -netType addpyramidnet -batchSize 128 -shareGradInput true
39 | ```
40 |
41 | ## Results
42 |
43 | #### CIFAR
44 |
45 | Top-1 error rates on the CIFAR-10 and CIFAR-100 datasets. "alpha" denotes the widening factor; "add" and "mul" denote the results obtained with additive and multiplicative pyramidal networks, respectively.
46 |
47 | | Network | # of parameters | Output feat. dimension | CIFAR-10 | CIFAR-100 |
48 | | --------------------------------- | --------------- | ---------------------- | ----------- | ----------- |
49 | | PyramidNet-110 (mul), alpha=4.75 | 1.7M | 76 | 4.62 | 23.16 |
50 | | PyramidNet-110 (add), alpha=48 | 1.7M | **64** | 4.62 | 23.31 |
51 | | PyramidNet-110 (mul), alpha=8 | 3.8M | 128 | 4.50 | 20.94 |
52 | | PyramidNet-110 (add), alpha=84 | 3.8M | **100** | 4.27 | 20.21 |
53 | | PyramidNet-110 (mul), alpha=27 | 28.3M | 432 | 4.06 | 18.79 |
54 | | PyramidNet-110 (add), alpha=270 | 28.3M | **286** | **3.73** | **18.25** |
55 |
56 | Top-1 error rates of our model with the **bottleneck architecture** on the CIFAR-10 and CIFAR-100 datasets. We use additive pyramidal networks.
57 |
58 | | Network | # of parameters | Output feat. dimension | CIFAR-10 | CIFAR-100 |
59 | | --------------------------------- | --------------- | ---------------------- | ----------- | ----------- |
60 | | PyramidNet-164 (add), alpha=48 | 1.7M | 256 | 4.21 | 19.52 |
61 | | PyramidNet-164 (add), alpha=84 | 3.8M | 400 | 3.96 | 18.32 |
62 | | PyramidNet-164 (add), alpha=270 | 27.0M | 1144 | **3.48** | **17.01** |
63 | | PyramidNet-200 (add), alpha=240 | 26.6M | 1024 | **3.44** | **16.51** |
64 | | PyramidNet-236 (add), alpha=220 | 26.8M | 944 | **3.40** | **16.37** |
65 | | PyramidNet-272 (add), alpha=200 | 26.0M | 864 | **3.31** | **16.35** |
66 |
67 | 
68 |
69 | Figure 3: Performance distribution according to number of parameters on CIFAR-10 (left) and CIFAR-100 (right).
70 |
71 | #### ImageNet
72 |
73 | Top-1 and Top-5 error rates of a single model with a single 224x224 crop on the ImageNet dataset. We use the additive PyramidNet for our results.
74 |
75 | | Network | # of parameters | Output feat. dimension | Top-1 error | Top-5 error |
76 | | ----------------------------------------- | --------------- | ---------------------- | ----------- | ----------- |
77 | | PreResNet-200 | 64.5M | 2048 | 21.66 | 5.79 |
78 | | PyramidNet-200, alpha=300 | 62.1M | 1456 | 20.47 | 5.29 |
79 | | PyramidNet-200, alpha=450, Dropout (0.5) | 116.4M | 2056 | 20.11 | 5.43 |
80 |
81 | Model files download: [link](https://1drv.ms/f/s!AmNvwgeB0n4GsiDFDNJWZkEbajJf)
82 |
83 |
84 | ## Notes
85 |
86 | 1. The parameter "alpha" can only be changed in the files addpyramidnet.lua and mulpyramidnet.lua (Line 28); see the snippet below.
87 | 2. We recommend using multiple GPUs when training additive PyramidNet with alpha=270 or multiplicative PyramidNet with alpha=27; otherwise you may run into an "out of memory" error.
88 | 3. We are currently testing our code on the ImageNet dataset and will upload the results when training is completed.
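
For reference, lines 28-29 of addpyramidnet.lua look like this; edit the value of alpha in place (mulpyramidnet.lua sets alpha = 4.75 on its line 28):

```lua
local alpha = 48
-- local alpha = 300
```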
89 |
90 | ## Updates
91 |
92 | 07/17/2017:
93 |
94 | 1. Caffe implementation of PyramidNet is released.
95 |
96 | 02/12/2017:
97 |
98 | 1. Results of the bottleneck architecture on CIFAR datasets are updated.
99 |
100 | 01/23/2017:
101 |
102 | 1. Added ImageNet pretrained models.
103 |
104 | ## Contact
105 | Jiwhan Kim (jhkim89@kaist.ac.kr),
106 | Dongyoon Han (dyhan@kaist.ac.kr),
107 | Junmo Kim (junmo.kim@kaist.ac.kr)
108 |
--------------------------------------------------------------------------------
/addpyramidnet.lua:
--------------------------------------------------------------------------------
1 | -- Implementation of "Deep Pyramidal Residual Networks"
2 |
3 | -- ************************************************************************
4 | -- This code incorporates material from:
5 |
6 | -- fb.resnet.torch (https://github.com/facebook/fb.resnet.torch)
7 | -- Copyright (c) 2016, Facebook, Inc.
8 | -- All rights reserved.
9 | --
10 | -- This source code is licensed under the BSD-style license found in the
11 | -- LICENSE file in the root directory of this source tree. An additional grant
12 | -- of patent rights can be found in the PATENTS file in the same directory.
13 | --
14 | -- ************************************************************************
15 |
16 | local nn = require 'nn'
17 | require 'cunn'
18 |
19 | local Convolution = cudnn.SpatialConvolution
20 | local Avg = cudnn.SpatialAveragePooling
21 | local ReLU = cudnn.ReLU
22 | local Max = nn.SpatialMaxPooling
23 | local SBatchNorm = nn.SpatialBatchNormalization
24 |
25 | local function createModel(opt)
26 | local depth = opt.depth
27 | local iChannels
28 | local alpha = 48
29 | -- local alpha = 300
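   -- alpha = 48 reproduces the smallest CIFAR bottleneck model in the README; alpha = 300 is the ImageNet setting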
30 | local function round(x)
31 | return math.floor(x+0.5)
32 | end
33 |
34 | local function shortcut(nInputPlane, nOutputPlane, stride)
35 | -- Strided, zero-padded identity shortcut
36 | local short = nn.Sequential()
37 | if stride == 2 then
38 | short:add(nn.SpatialAveragePooling(2, 2, 2, 2))
39 | end
40 | if nInputPlane ~= nOutputPlane then
41 | short:add(nn.Padding(1, (nOutputPlane - nInputPlane), 3))
42 | else
43 | short:add(nn.Identity())
44 | end
45 | return short
46 | end
47 |
48 | local function basicblock(n, stride)
49 | local nInputPlane = iChannels
50 | iChannels = n
51 |
52 | local s = nn.Sequential()
53 | s:add(SBatchNorm(nInputPlane))
54 | s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1))
55 | s:add(SBatchNorm(n))
56 | s:add(ReLU(true))
57 | s:add(Convolution(n,n,3,3,1,1,1,1))
58 | s:add(SBatchNorm(n))
59 | return nn.Sequential()
60 | :add(nn.ConcatTable()
61 | :add(s)
62 | :add(shortcut(nInputPlane, n, stride)))
63 | :add(nn.CAddTable(true))
64 | end
65 |
66 | local function bottleneck(n, stride, type)
67 | local nInputPlane = iChannels
68 | iChannels = n * 4
69 |
70 | local s = nn.Sequential()
71 | s:add(SBatchNorm(nInputPlane))
72 | s:add(Convolution(nInputPlane,n,1,1,1,1,0,0))
73 | s:add(SBatchNorm(n))
74 | s:add(ReLU(true))
75 | s:add(Convolution(n,n,3,3,stride,stride,1,1))
76 | s:add(SBatchNorm(n))
77 | s:add(ReLU(true))
78 | s:add(Convolution(n,n*4,1,1,1,1,0,0))
79 | s:add(SBatchNorm(n*4))
80 |
81 | return nn.Sequential()
82 | :add(nn.ConcatTable()
83 | :add(s)
84 | :add(shortcut(nInputPlane, n * 4, stride)))
85 | :add(nn.CAddTable(true))
86 | end
87 |
88 | -- Creates count residual blocks with specified number of features
89 | local function layer(block, features, count, stride)
90 | local s = nn.Sequential()
91 | if count < 1 then
92 | return s
93 | end
94 | for i=1,count do
95 | s:add(block(features, stride))
96 | end
97 | return s
98 | end
99 |
100 | local model = nn.Sequential()
101 | if opt.dataset == 'imagenet' then
102 | -- Configurations for ResNet:
103 | -- num. residual blocks, num features, residual block function
104 | local cfg = {
105 | [18] = {{2, 2, 2, 2}, 512, basicblock},
106 | [34] = {{3, 4, 6, 3}, 512, basicblock},
107 | [50] = {{3, 4, 6, 3}, 2048, bottleneck},
108 | [101] = {{3, 4, 23, 3}, 2048, bottleneck},
109 | [152] = {{3, 8, 36, 3}, 2048, bottleneck},
110 | [200] = {{3, 24, 36, 3}, 2048, bottleneck},
111 | }
112 |
113 | assert(cfg[depth], 'Invalid depth: ' .. tostring(depth))
114 | local def, nFeatures, block = table.unpack(cfg[depth])
115 | iChannels = 64
116 | Channeltemp = 64
117 | local addrate = alpha/(def[1]+def[2]+def[3]+def[4])
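      -- additive widening: every residual block adds alpha / (total number of blocks) channels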
118 | print(' | PyramidNet-' .. depth .. ' ImageNet')
119 |
120 | model:add(Convolution(3,64,7,7,2,2,3,3))
121 | model:add(SBatchNorm(64))
122 | model:add(ReLU(true))
123 | model:add(Max(3,3,2,2,1,1))
124 | Channeltemp = Channeltemp + addrate
125 | model:add(bottleneck(round(Channeltemp), 1, 1, 'first'))
126 | for i=2,def[1] do
127 | Channeltemp = Channeltemp + addrate
128 | model:add(bottleneck(round(Channeltemp), 1, 1))
129 | end
130 | Channeltemp = Channeltemp + addrate
131 | model:add(bottleneck(round(Channeltemp), 2, 1))
132 | for i=2,def[2] do
133 | Channeltemp = Channeltemp + addrate
134 | model:add(bottleneck(round(Channeltemp), 1, 1))
135 | end
136 | Channeltemp = Channeltemp + addrate
137 | model:add(bottleneck(round(Channeltemp), 2, 1))
138 | for i=2,def[3] do
139 | Channeltemp = Channeltemp + addrate
140 | model:add(bottleneck(round(Channeltemp), 1, 1))
141 | end
142 | Channeltemp = Channeltemp + addrate
143 | model:add(bottleneck(round(Channeltemp), 2, 1))
144 | for i=2,def[4] do
145 | Channeltemp = Channeltemp + addrate
146 | model:add(bottleneck(round(Channeltemp), 1, 1))
147 | end
148 | model:add(nn.Copy(nil, nil, true))
149 | model:add(SBatchNorm(iChannels))
150 | model:add(ReLU(true))
151 | model:add(Avg(7, 7, 1, 1))
152 | model:add(nn.View(iChannels):setNumInputDims(3))
153 | model:add(nn.Linear(iChannels, 1000))
154 |
155 | elseif opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then
156 | -- local n = (depth - 2) / 6 -- basicblock
157 | local n = (depth - 2) / 9 -- bottleneck
158 | iChannels = 16
159 | local startChannel = 16
160 | local Channeltemp = 16
161 | addChannel = alpha/(3*n)
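   -- additive widening on CIFAR: 3*n blocks in total, so the width grows from 16 to about 16 + alpha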
162 | print(' | PyramidNet-' .. depth .. ' CIFAR')
163 |
164 | model:add(Convolution(3,16,3,3,1,1,1,1))
165 | model:add(SBatchNorm(iChannels))
166 |
167 | Channeltemp = startChannel
168 | startChannel = startChannel + addChannel
169 | model:add(layer(bottleneck, round(startChannel), 1, 1, 1))
170 | for i=2,n do
171 | Channeltemp = startChannel
172 | startChannel = startChannel + addChannel
173 | model:add(layer(bottleneck, round(startChannel), 1, 1, 1))
174 | end
175 |
176 | Channeltemp = startChannel
177 | startChannel = startChannel + addChannel
178 | model:add(layer(bottleneck, round(startChannel), 1, 2, 1))
179 | for i=2,n do
180 | Channeltemp = startChannel
181 | startChannel = startChannel + addChannel
182 | model:add(layer(bottleneck, round(startChannel), 1, 1, 1))
183 | end
184 | Channeltemp = startChannel
185 | startChannel = startChannel + addChannel
186 | model:add(layer(bottleneck, round(startChannel), 1, 2, 1))
187 | for i=2,n do
188 | Channeltemp = startChannel
189 | startChannel = startChannel + addChannel
190 | model:add(layer(bottleneck, round(startChannel), 1, 1, 1))
191 | end
192 | model:add(nn.Copy(nil, nil, true))
193 | model:add(SBatchNorm(iChannels))
194 | model:add(ReLU(true))
195 | model:add(Avg(8, 8, 1, 1))
196 | model:add(nn.View(iChannels):setNumInputDims(3))
197 | if opt.dataset == 'cifar10' then
198 | model:add(nn.Linear(iChannels, 10))
199 | elseif opt.dataset == 'cifar100' then
200 | model:add(nn.Linear(iChannels, 100))
201 | end
202 | else
203 | error('invalid dataset: ' .. opt.dataset)
204 | end
205 |
206 | local function ConvInit(name)
207 | for k,v in pairs(model:findModules(name)) do
208 | local n = v.kW*v.kH*v.nOutputPlane
209 | v.weight:normal(0,math.sqrt(2/n))
210 | if cudnn.version >= 4000 then
211 | v.bias = nil
212 | v.gradBias = nil
213 | else
214 | v.bias:zero()
215 | end
216 | end
217 | end
218 | local function BNInit(name)
219 | for k,v in pairs(model:findModules(name)) do
220 | v.weight:fill(1)
221 | v.bias:zero()
222 | end
223 | end
224 |
225 | ConvInit('cudnn.SpatialConvolution')
226 | ConvInit('nn.SpatialConvolution')
227 | BNInit('fbnn.SpatialBatchNormalization')
228 | BNInit('cudnn.SpatialBatchNormalization')
229 | BNInit('nn.SpatialBatchNormalization')
230 | for k,v in pairs(model:findModules('nn.Linear')) do
231 | v.bias:zero()
232 | end
233 | model:cuda()
234 |
235 | if opt.cudnn == 'deterministic' then
236 | model:apply(function(m)
237 | if m.setMode then m:setMode(1,1,1) end
238 | end)
239 | end
240 |
241 | model:get(1).gradInput = nil
242 |
243 | return model
244 | end
245 |
246 | return createModel
247 |
--------------------------------------------------------------------------------
/mulpyramidnet.lua:
--------------------------------------------------------------------------------
1 | -- Implementation of "Deep Pyramidal Residual Networks"
2 |
3 | -- ************************************************************************
4 | -- This code incorporates material from:
5 |
6 | -- fb.resnet.torch (https://github.com/facebook/fb.resnet.torch)
7 | -- Copyright (c) 2016, Facebook, Inc.
8 | -- All rights reserved.
9 | --
10 | -- This source code is licensed under the BSD-style license found in the
11 | -- LICENSE file in the root directory of this source tree. An additional grant
12 | -- of patent rights can be found in the PATENTS file in the same directory.
13 | --
14 | -- ************************************************************************
15 |
16 | local nn = require 'nn'
17 | require 'cunn'
18 |
19 | local Convolution = cudnn.SpatialConvolution
20 | local Avg = cudnn.SpatialAveragePooling
21 | local ReLU = cudnn.ReLU
22 | local Max = nn.SpatialMaxPooling
23 | local SBatchNorm = nn.SpatialBatchNormalization
24 |
25 | local function createModel(opt)
26 | local depth = opt.depth
27 | local iChannels
28 | local alpha = 4.75
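   -- alpha = 4.75 reproduces the smallest multiplicative CIFAR model in the README (alpha = 8 and 27 give the larger ones)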
29 | local function round(x)
30 | return math.floor(x+0.5)
31 | end
32 |
33 | local function shortcut(nInputPlane, nOutputPlane, stride)
34 | -- Strided, zero-padded identity shortcut
35 | local short = nn.Sequential()
36 | if stride == 2 then
37 | short:add(nn.SpatialAveragePooling(2, 2, 2, 2))
38 | end
39 | if nInputPlane ~= nOutputPlane then
40 | short:add(nn.Padding(1, (nOutputPlane - nInputPlane), 3))
41 | else
42 | short:add(nn.Identity())
43 | end
44 | return short
45 | end
46 |
47 | local function basicblock(n, stride)
48 | local nInputPlane = iChannels
49 | iChannels = n
50 |
51 | local s = nn.Sequential()
52 | s:add(SBatchNorm(nInputPlane))
53 | s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1))
54 | s:add(SBatchNorm(n))
55 | s:add(ReLU(true))
56 | s:add(Convolution(n,n,3,3,1,1,1,1))
57 | s:add(SBatchNorm(n))
58 | return nn.Sequential()
59 | :add(nn.ConcatTable()
60 | :add(s)
61 | :add(shortcut(nInputPlane, n, stride)))
62 | :add(nn.CAddTable(true))
63 | end
64 |
65 | local function bottleneck(n, stride, type)
66 | local nInputPlane = iChannels
67 | iChannels = n * 4
68 |
69 | local s = nn.Sequential()
70 | s:add(SBatchNorm(nInputPlane))
71 | s:add(Convolution(nInputPlane,n,1,1,1,1,0,0))
72 | s:add(SBatchNorm(n))
73 | s:add(ReLU(true))
74 | s:add(Convolution(n,n,3,3,stride,stride,1,1))
75 | s:add(SBatchNorm(n))
76 | s:add(ReLU(true))
77 | s:add(Convolution(n,n*4,1,1,1,1,0,0))
78 | s:add(SBatchNorm(n*4))
79 |
80 | return nn.Sequential()
81 | :add(nn.ConcatTable()
82 | :add(s)
83 | :add(shortcut(nInputPlane, n * 4, stride)))
84 | :add(nn.CAddTable(true))
85 | end
86 |
87 |
88 | -- Creates count residual blocks with specified number of features
89 | local function layer(block, features, count, stride)
90 | local s = nn.Sequential()
91 | if count < 1 then
92 | return s
93 | end
94 | for i=1,count do
95 | s:add(block(features, stride))
96 | end
97 | return s
98 | end
99 |
100 | local model = nn.Sequential()
101 | if opt.dataset == 'imagenet' then
102 | -- Configurations for ResNet:
103 | -- num. residual blocks, num features, residual block function
104 | local cfg = {
105 | [18] = {{2, 2, 2, 2}, 512, basicblock},
106 | [34] = {{3, 4, 6, 3}, 512, basicblock},
107 | [50] = {{3, 4, 6, 3}, 2048, bottleneck},
108 | [101] = {{3, 4, 23, 3}, 2048, bottleneck},
109 | [152] = {{3, 8, 36, 3}, 2048, bottleneck},
110 | [200] = {{3, 24, 36, 3}, 2048, bottleneck},
111 | }
112 |
113 | assert(cfg[depth], 'Invalid depth: ' .. tostring(depth))
114 | local def, nFeatures, block = table.unpack(cfg[depth])
115 | iChannels = 64
116 | Channeltemp = 64
117 | local addrate = alpha^(1/(def[1]+def[2]+def[3]+def[4]))
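      -- multiplicative widening: every residual block multiplies the width by alpha^(1 / total number of blocks)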
118 |    print(' | PyramidNet-' .. depth .. ' ImageNet')
119 |
120 | model:add(Convolution(3,64,7,7,2,2,3,3))
121 | model:add(SBatchNorm(64))
122 | model:add(ReLU(true))
123 | model:add(Max(3,3,2,2,1,1))
124 | Channeltemp = Channeltemp * addrate
125 | model:add(bottleneck(round(Channeltemp), 1, 1, 'first'))
126 | for i=2,def[1] do
127 | Channeltemp = Channeltemp * addrate
128 | model:add(bottleneck(round(Channeltemp), 1, 1))
129 | end
130 | Channeltemp = Channeltemp * addrate
131 | model:add(bottleneck(round(Channeltemp), 2, 1))
132 | for i=2,def[2] do
133 | Channeltemp = Channeltemp * addrate
134 | model:add(bottleneck(round(Channeltemp), 1, 1))
135 | end
136 | Channeltemp = Channeltemp * addrate
137 | model:add(bottleneck(round(Channeltemp), 2, 1))
138 | for i=2,def[3] do
139 | Channeltemp = Channeltemp * addrate
140 | model:add(bottleneck(round(Channeltemp), 1, 1))
141 | end
142 | Channeltemp = Channeltemp * addrate
143 | model:add(bottleneck(round(Channeltemp), 2, 1))
144 | for i=2,def[4] do
145 | Channeltemp = Channeltemp * addrate
146 | model:add(bottleneck(round(Channeltemp), 1, 1))
147 | end
148 | model:add(nn.Copy(nil, nil, true))
149 | model:add(SBatchNorm(iChannels))
150 | model:add(ReLU(true))
151 | model:add(Avg(7, 7, 1, 1))
152 | model:add(nn.View(iChannels):setNumInputDims(3))
153 | model:add(nn.Linear(iChannels, 1000))
154 |
155 | elseif opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then
156 | local n = (depth - 2) / 6
157 | iChannels = 16
158 | local startChannel = 16
159 | local Channeltemp = 16
160 | addChannel = alpha^(1/(3*n))
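   -- multiplicative widening on CIFAR: the width grows geometrically from 16 to about 16 * alpha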
161 |    print(' | PyramidNet-' .. depth .. ' CIFAR')
162 |
163 | model:add(Convolution(3,16,3,3,1,1,1,1))
164 | model:add(SBatchNorm(iChannels))
165 |
166 | Channeltemp = startChannel
167 | startChannel = startChannel * addChannel
168 | model:add(layer(basicblock, round(startChannel), 1, 1, 1))
169 | for i=2,n do
170 | Channeltemp = startChannel
171 | startChannel = startChannel * addChannel
172 | model:add(layer(basicblock, round(startChannel), 1, 1, 1))
173 | end
174 |
175 | Channeltemp = startChannel
176 | startChannel = startChannel * addChannel
177 | model:add(layer(basicblock, round(startChannel), 1, 2, 1))
178 | for i=2,n do
179 | Channeltemp = startChannel
180 | startChannel = startChannel * addChannel
181 | model:add(layer(basicblock, round(startChannel), 1, 1, 1))
182 | end
183 | Channeltemp = startChannel
184 | startChannel = startChannel * addChannel
185 | model:add(layer(basicblock, round(startChannel), 1, 2, 1))
186 | for i=2,n do
187 | Channeltemp = startChannel
188 | startChannel = startChannel * addChannel
189 | model:add(layer(basicblock, round(startChannel), 1, 1, 1))
190 | end
191 | model:add(nn.Copy(nil, nil, true))
192 | model:add(SBatchNorm(iChannels))
193 | model:add(ReLU(true))
194 | model:add(Avg(8, 8, 1, 1))
195 | model:add(nn.View(iChannels):setNumInputDims(3))
196 | if opt.dataset == 'cifar10' then
197 | model:add(nn.Linear(iChannels, 10))
198 | elseif opt.dataset == 'cifar100' then
199 | model:add(nn.Linear(iChannels, 100))
200 | end
201 | else
202 | error('invalid dataset: ' .. opt.dataset)
203 | end
204 |
205 | local function ConvInit(name)
206 | for k,v in pairs(model:findModules(name)) do
207 | local n = v.kW*v.kH*v.nOutputPlane
208 | v.weight:normal(0,math.sqrt(2/n))
209 | if cudnn.version >= 4000 then
210 | v.bias = nil
211 | v.gradBias = nil
212 | else
213 | v.bias:zero()
214 | end
215 | end
216 | end
217 | local function BNInit(name)
218 | for k,v in pairs(model:findModules(name)) do
219 | v.weight:fill(1)
220 | v.bias:zero()
221 | end
222 | end
223 |
224 | ConvInit('cudnn.SpatialConvolution')
225 | ConvInit('nn.SpatialConvolution')
226 | BNInit('fbnn.SpatialBatchNormalization')
227 | BNInit('cudnn.SpatialBatchNormalization')
228 | BNInit('nn.SpatialBatchNormalization')
229 | for k,v in pairs(model:findModules('nn.Linear')) do
230 | v.bias:zero()
231 | end
232 | model:cuda()
233 |
234 | if opt.cudnn == 'deterministic' then
235 | model:apply(function(m)
236 | if m.setMode then m:setMode(1,1,1) end
237 | end)
238 | end
239 |
240 | model:get(1).gradInput = nil
241 |
242 | return model
243 | end
244 |
245 | return createModel
246 |
--------------------------------------------------------------------------------