├── LICENSE
├── README.md
├── checkpoints.lua
├── dataloader.lua
├── datasets
    ├── README.md
    ├── cifar10-gen.lua
    ├── cifar10.lua
    ├── cifar100-gen.lua
    ├── cifar100.lua
    ├── imagenet-gen.lua
    ├── imagenet.lua
    ├── init.lua
    └── transforms.lua
├── main.lua
├── models
    ├── DenseConnectLayer.lua
    ├── README.md
    ├── densenet.lua
    └── init.lua
├── opts.lua
└── train.lua


/LICENSE:
--------------------------------------------------------------------------------
 1 | Copyright (c) 2016, Zhuang Liu. 
 2 | All rights reserved.
 3 | 
 4 | Redistribution and use in source and binary forms, with or without modification,
 5 | are permitted provided that the following conditions are met:
 6 | 
 7 |  * Redistributions of source code must retain the above copyright notice, this
 8 |    list of conditions and the following disclaimer.
 9 | 
10 |  * Redistributions in binary form must reproduce the above copyright notice,
11 |    this list of conditions and the following disclaimer in the documentation
12 |    and/or other materials provided with the distribution.
13 | 
14 |  * Neither the name DenseNet nor the names of its contributors may be used to
15 |    endorse or promote products derived from this software without specific
16 |    prior written permission.
17 | 
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Densely Connected Convolutional Networks (DenseNets)
  2 | 
  3 | This repository contains the code for DenseNet introduced in the following paper 
  4 | 
  5 | [Densely Connected Convolutional Networks](http://arxiv.org/abs/1608.06993) (CVPR 2017, Best Paper Award) 
  6 | 
  7 | [Gao Huang](http://www.cs.cornell.edu/~gaohuang/)\*, [Zhuang Liu](https://liuzhuang13.github.io/)\*, [Laurens van der Maaten](https://lvdmaaten.github.io/) and [Kilian Weinberger](https://www.cs.cornell.edu/~kilian/) (\* Authors contributed equally).
  8 | 
  9 | 
 10 | **Now with a much more memory-efficient implementation!** Please check the [technical report](https://arxiv.org/pdf/1707.06990.pdf) and [code](https://github.com/liuzhuang13/DenseNet/tree/master/models) for more information.
 11 |  
 12 | The code is built on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch).
 13 | 
 14 | ### Citation
 15 | If you find DenseNet useful in your research, please consider citing:
 16 | 
 17 | 	@inproceedings{DenseNet2017,
 18 | 	  title={Densely connected convolutional networks},
 19 | 	  author={Huang, Gao and Liu, Zhuang and van der Maaten, Laurens and Weinberger, Kilian Q },
 20 | 	  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
 21 | 	  year={2017}
 22 | 	}
 23 | 
 24 | 
 25 | ## Other Implementations
 26 | Our [[Caffe]](https://github.com/liuzhuang13/DenseNetCaffe), Our memory-efficient [[Caffe]](https://github.com/Tongcheng/DN_CaffeScript), Our memory-efficient [[PyTorch]](https://github.com/gpleiss/efficient_densenet_pytorch), 
 27 | [[PyTorch]](https://github.com/andreasveit/densenet-pytorch) by Andreas Veit, [[PyTorch]](https://github.com/bamos/densenet.pytorch) by Brandon Amos, [[PyTorch]](https://github.com/baldassarreFe/pytorch-densenet-tiramisu) by Federico Baldassarre,
 28 | [[MXNet]](https://github.com/Nicatio/Densenet/tree/master/mxnet) by Nicatio, 
 29 | [[MXNet]](https://github.com/bruinxiong/densenet.mxnet) by Xiong Lin, 
 30 | [[MXNet]](https://github.com/miraclewkf/DenseNet) by miraclewkf,
 31 | [[Tensorflow]](https://github.com/YixuanLi/densenet-tensorflow) by Yixuan Li, 
 32 | [[Tensorflow]](https://github.com/LaurentMazare/deep-models/tree/master/densenet) by Laurent Mazare,
 33 | [[Tensorflow]](https://github.com/ikhlestov/vision_networks) by Illarion Khlestov, 
 34 | [[Lasagne]](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) by Jan Schlüter, 
 35 | [[Keras]](https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/DenseNet) by tdeboissiere,  
 36 | [[Keras]](https://github.com/robertomest/convnet-study) by Roberto de Moura Estevão Filho, 
 37 | [[Keras]](https://github.com/titu1994/DenseNet) by Somshubra Majumdar, 
 38 | [[Chainer]](https://github.com/t-hanya/chainer-DenseNet) by Toshinori Hanya, 
 39 | [[Chainer]](https://github.com/yasunorikudo/chainer-DenseNet) by Yasunori Kudo, 
 40 | [[Torch 3D-DenseNet]](https://github.com/barrykui/3ddensenet.torch) by Barry Kui,
 41 | [[Keras]](https://github.com/cmasch/densenet) by Christopher Masch,
 42 | [[Tensorflow2]](https://github.com/okason97/DenseNet-Tensorflow2) by Gaston Rios and Ulises Jeremias Cornejo Fandos.
 43 | 
 44 | 
 45 | 
 46 | Note that we only listed some early implementations here. If you would like to add yours, please submit a pull request.
 47 | 
 48 | ## Some Follow-up Projects
 49 | 0. [Multi-Scale Dense Convolutional Networks for Efficient Prediction](https://github.com/gaohuang/MSDNet)
 50 | 0. [DSOD: Learning Deeply Supervised Object Detectors from Scratch](https://github.com/szq0214/DSOD)
 51 | 0. [CondenseNet: An Efficient DenseNet using Learned Group Convolutions](https://github.com/ShichenLiu/CondenseNet)
 52 | 0. [Fully Convolutional DenseNets for Semantic Segmentation](https://github.com/SimJeg/FC-DenseNet)
 53 | 0. [Pelee: A Real-Time Object Detection System on Mobile Devices](https://github.com/Robert-JunWang/Pelee)
 54 | 
 55 | 
 56 | 
 57 | 
 58 | ## Contents
 59 | 1. [Introduction](#introduction)
 60 | 2. [Usage](#usage)
 61 | 3. [Results on CIFAR](#results-on-cifar)
 62 | 4. [Results on ImageNet and Pretrained Models](#results-on-imagenet-and-pretrained-models)
 63 | 5. [Updates](#updates)
 64 | 
 65 | 
 66 | ## Introduction
 67 | DenseNet is a network architecture where each layer is directly connected to every other layer in a feed-forward fashion (within each *dense block*). For each layer, the feature maps of all preceding layers are treated as separate inputs whereas its own feature maps are passed on as inputs to all subsequent layers. This connectivity pattern yields state-of-the-art accuracies on CIFAR10/100 (with or without data augmentation) and SVHN. On the large scale ILSVRC 2012 (ImageNet) dataset, DenseNet achieves a similar accuracy as ResNet, but using less than half the amount of parameters and roughly half the number of FLOPs.
 68 | 
 69 | <img src="https://cloud.githubusercontent.com/assets/8370623/17981494/f838717a-6ad1-11e6-9391-f0906c80bc1d.jpg" width="480">
 70 | 
 71 | Figure 1: A dense block with 5 layers and growth rate 4. 
 72 | 
 73 | 
 74 | ![densenet](https://cloud.githubusercontent.com/assets/8370623/17981496/fa648b32-6ad1-11e6-9625-02fdd72fdcd3.jpg)
 75 | Figure 2: A deep DenseNet with three dense blocks. 
 76 | 
 77 | 
 78 | ## Usage 
 79 | 0. Install Torch and required dependencies like cuDNN. See the instructions [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md) for a step-by-step guide.
 80 | 1. Clone this repo: ```git clone https://github.com/liuzhuang13/DenseNet.git```
 81 | 
 82 | As an example, the following command trains a DenseNet-BC with depth L=100 and growth rate k=12 on CIFAR-10:
 83 | ```
 84 | th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 100 -growthRate 12
 85 | ``` 
 86 | As another example, the following command trains a DenseNet-BC with depth L=121 and growth rate k=32 on ImageNet:
 87 | ```
 88 | th main.lua -netType densenet -dataset imagenet -data [dataFolder] -batchSize 256 -nEpochs 90 -depth 121 -growthRate 32 -nGPU 4 -nThreads 16 -optMemory 3
 89 | ``` 
 90 | Please refer to [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) for data preparation.
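As implied by `datasets/imagenet-gen.lua`, `[dataFolder]` is expected to contain `train` and `val` directories with one subfolder per class. An illustrative (hypothetical) layout:
```
[dataFolder]/train/n01440764/xxx.jpeg
[dataFolder]/val/n01440764/yyy.jpeg
```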
 91 | 
 92 | ### DenseNet and DenseNet-BC
 93 | By default, the code runs the DenseNet-BC architecture, which uses 1x1 convolutional *bottleneck* layers and *compresses* the number of channels by a factor of 0.5 at each transition layer. To run the original DenseNet, simply use the options *-bottleneck false* and *-reduction 1*.
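For example, reusing the CIFAR-10 settings shown above, the original DenseNet (L=40, k=12) can be trained with:
```
th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 40 -growthRate 12 -bottleneck false -reduction 1
```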
 94 | 
 95 | ### Memory-efficient implementation (newly added feature on June 6, 2017)
 96 | The option *-optMemory* is very useful for reducing the GPU memory footprint when training a DenseNet. By default, the value is set to 2, which activates the *shareGradInput* function (with small modifications from [here](https://github.com/facebook/fb.resnet.torch/blob/master/models/init.lua#L102)). There are two extremely memory-efficient modes (*-optMemory 3* and *-optMemory 4*) which use a customized densely connected layer. With *-optMemory 4*, the largest 190-layer DenseNet-BC on CIFAR can be trained on a single NVIDIA Titan X GPU (using 8.3G of its 12G memory), instead of fully occupying four GPUs as with the standard (recursive concatenation) implementation.
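For instance, a sketch of training the largest model above in the most memory-efficient mode (batch size and epoch count are carried over from the CIFAR-10 example and may need adjustment):
```
th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 190 -growthRate 40 -optMemory 4
```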
 97 | 
 98 | More details about the memory efficient implementation are discussed [here](https://github.com/liuzhuang13/DenseNet/tree/master/models).
 99 | 
100 | 
101 | ## Results on CIFAR
102 | The table below shows the results of DenseNets on the CIFAR datasets. The "+" mark at the end denotes standard data augmentation (random crop after zero padding, and horizontal flip). For a DenseNet model, L denotes its depth and k denotes its growth rate. On CIFAR-10 and CIFAR-100 without data augmentation, a Dropout layer with drop rate 0.2 is introduced after each convolutional layer except the very first one.
103 | 
104 | Model | Parameters| CIFAR-10 | CIFAR-10+ | CIFAR-100 | CIFAR-100+ 
105 | -------|:-------:|:--------:|:--------:|:--------:|:--------:|
106 | DenseNet (L=40, k=12) |1.0M |7.00 |5.24 | 27.55|24.42
107 | DenseNet (L=100, k=12)|7.0M |5.77 |4.10 | 23.79|20.20
108 | DenseNet (L=100, k=24)|27.2M |5.83 |3.74 | 23.42|19.25
109 | DenseNet-BC (L=100, k=12)|0.8M |5.92 |4.51 | 24.15|22.27
110 | DenseNet-BC (L=250, k=24)|15.3M |**5.19** |3.62 | **19.64**|17.60
111 | DenseNet-BC (L=190, k=40)|25.6M |- |**3.46** | -|**17.18**
112 | 
113 | 
114 | ## Results on ImageNet and Pretrained Models
115 | ### Torch
116 | 
117 | **Note: the pre-trained models in Torch are deprecated and no longer maintained. Please use PyTorch's pre-trained [DenseNet models](https://pytorch.org/vision/stable/models.html) instead.**
118 | 
119 | #### Models in the original paper
120 | The Torch models are trained under the same setting as in [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch). The error rates shown are 224x224 1-crop test errors.
121 | 
122 | | Network       |  Top-1 error | Torch Model |
123 | | ------------- | ----------- | ----------- |
124 | | DenseNet-121 (k=32)  |  25.0    | [Download (64.5MB)]       |
125 | | DenseNet-169 (k=32)  | 23.6     | [Download (114.4MB)]       |
126 | | DenseNet-201 (k=32)  | 22.5     | [Download (161.8MB)]       |
127 | | DenseNet-161 (k=48)  | 22.2     | [Download (230.8MB)]       |
128 | 
129 | #### Models in the tech report
130 | More accurate models trained with the memory efficient implementation in the [technical report](https://arxiv.org/pdf/1707.06990.pdf). 
131 | 
132 | 
133 | | Network       |  Top-1 error | Torch Model |
134 | | ------------- | ----------- | ------------ |
135 | | DenseNet-264 (k=32)  | 22.1 | [Download (256MB)] |
136 | | DenseNet-232 (k=48)  | 21.2 | [Download (426MB)] |
137 | | DenseNet-cosine-264 (k=32) | 21.6 | [Download (256MB)] |
138 | | DenseNet-cosine-264 (k=48) | 20.4 | [Download (557MB)] |
139 | 
140 | 
141 | ### Caffe
142 | https://github.com/shicai/DenseNet-Caffe.
143 | 
144 | ### PyTorch
145 | [PyTorch documentation on models](http://pytorch.org/docs/torchvision/models.html?highlight=densenet). We would like to thank @gpleiss for this nice work in PyTorch.
146 | 
147 | ### Keras, Tensorflow and Theano
148 | https://github.com/flyyufelix/DenseNet-Keras.
149 | 
150 | ### MXNet
151 | https://github.com/miraclewkf/DenseNet.
152 | 
153 | 
154 | ## Wide-DenseNet for a Better Time/Accuracy and Memory/Accuracy Tradeoff
155 | 
156 | If you use DenseNet as a model in your learning task, we recommend using a wide and shallow DenseNet to reduce memory and time consumption, following the strategy of [wide residual networks](https://github.com/szagoruyko/wide-residual-networks). To obtain a wide DenseNet, set the depth to be smaller (e.g., L=40) and the growthRate to be larger (e.g., k=48).
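For example, the Wide-DenseNet-BC (L=40, k=48) from the table below can be trained with the same settings as the CIFAR-10 example above (bottleneck and compression are on by default):
```
th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 40 -growthRate 48
```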
157 | 
158 | We tested a set of Wide-DenseNet-BCs and compared their memory and time consumption with the DenseNet-BC (L=100, k=12) shown above. The statistics were obtained on a single Titan X card, with batch size 64, and without any memory optimization.
159 | 
160 | 
161 | Model | Parameters| CIFAR-10+ | CIFAR-100+ | Time per Iteration | Memory 
162 | -------|:-------:|:--------:|:--------:|:--------:|:--------:|
163 | DenseNet-BC (L=100, k=12)|0.8M |4.51 |22.27 | 0.156s | 5452MB
164 | Wide-DenseNet-BC (L=40, k=36)|1.5M |4.58 |22.30 | 0.130s|4008MB
165 | Wide-DenseNet-BC (L=40, k=48)|2.7M |3.99 |20.29 | 0.165s|5245MB
166 | Wide-DenseNet-BC (L=40, k=60)|4.3M |4.01 |19.99 | 0.223s|6508MB
167 | 
168 | Observations:
169 | 
170 | 1. Wide-DenseNet-BC (L=40, k=36) uses less memory/time while achieving about the same accuracy as DenseNet-BC (L=100, k=12). 
171 | 2. Wide-DenseNet-BC (L=40, k=48) uses about the same memory/time as DenseNet-BC (L=100, k=12), while being much more accurate.
172 | 
173 | Thus, for practical use, we suggest picking one model from those Wide-DenseNet-BCs.
174 | 
175 | 
176 | 
177 | ## Updates
178 | **08/23/2017:**
179 | 
180 | 1. Add supporting code, so one can simply *git clone* and run.
181 | 
182 | **06/06/2017:**
183 | 
184 | 1. Support **ultra memory efficient** training of DenseNet with *customized densely connected layer*.
185 | 
186 | 2. Support **memory efficient** training of DenseNet with *standard densely connected layer* (recursive concatenation) by fixing the *shareGradInput* function.
187 | 
188 | 05/17/2017:
189 | 
190 | 1. Add Wide-DenseNet.
191 | 2. Add Keras, TensorFlow and Theano links for pretrained models.
192 | 
193 | 04/20/2017:
194 | 
195 | 1. Add usage of models in PyTorch.
196 | 
197 | 03/29/2017:
198 | 
199 | 1. Add the code for ImageNet training.
200 | 
201 | 12/03/2016:
202 | 
203 | 1. Add ImageNet results and pretrained models.
204 | 2. Add DenseNet-BC structures.
205 | 
206 | 
207 | 
208 | ## Contact
209 | liuzhuangthu at gmail.com  
210 | Any discussions, suggestions and questions are welcome!
211 | 
212 | 


--------------------------------------------------------------------------------
/checkpoints.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | local checkpoint = {}
10 | 
11 | local function deepCopy(tbl)
12 |    -- creates a copy of a network with new modules and the same tensors
13 |    local copy = {}
14 |    for k, v in pairs(tbl) do
15 |       if type(v) == 'table' then
16 |          copy[k] = deepCopy(v)
17 |       else
18 |          copy[k] = v
19 |       end
20 |    end
21 |    if torch.typename(tbl) then
22 |       torch.setmetatable(copy, torch.typename(tbl))
23 |    end
24 |    return copy
25 | end
26 | 
27 | function checkpoint.latest(opt)
28 |    if opt.resume == 'none' then
29 |       return nil
30 |    end
31 | 
32 |    local latestPath = paths.concat(opt.resume, 'latest.t7')
33 |    if not paths.filep(latestPath) then
34 |       return nil
35 |    end
36 | 
37 |    print('=> Loading checkpoint ' .. latestPath)
38 |    local latest = torch.load(latestPath)
39 |    local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))
40 | 
41 |    return latest, optimState
42 | end
43 | 
44 | function checkpoint.save(epoch, model, optimState, isBestModel, opt)
45 |    -- don't save the DataParallelTable for easier loading on other machines
46 |    if torch.type(model) == 'nn.DataParallelTable' then
47 |       model = model:get(1)
48 |    end
49 | 
50 |    -- create a clean copy on the CPU without modifying the original network
51 |    model = deepCopy(model):float():clearState()
52 | 
53 |    local modelFile = 'model_' .. epoch .. '.t7'
54 |    local optimFile = 'optimState_' .. epoch .. '.t7'
55 | 
56 |    torch.save(paths.concat(opt.save, modelFile), model)
57 |    torch.save(paths.concat(opt.save, optimFile), optimState)
58 |    torch.save(paths.concat(opt.save, 'latest.t7'), {
59 |       epoch = epoch,
60 |       modelFile = modelFile,
61 |       optimFile = optimFile,
62 |    })
63 | 
64 |    if isBestModel then
65 |       torch.save(paths.concat(opt.save, 'model_best.t7'), model)
66 |    end
67 | end
68 | 
69 | return checkpoint
70 | 


--------------------------------------------------------------------------------
/dataloader.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | --  Multi-threaded data loader
 10 | --
 11 | 
 12 | local datasets = require 'datasets/init'
 13 | local Threads = require 'threads'
 14 | Threads.serialization('threads.sharedserialize')
 15 | 
 16 | local M = {}
 17 | local DataLoader = torch.class('resnet.DataLoader', M)
 18 | 
 19 | function DataLoader.create(opt)
 20 |    -- The train and val loader
 21 |    local loaders = {}
 22 | 
 23 |    for i, split in ipairs{'train', 'val'} do
 24 |       local dataset = datasets.create(opt, split)
 25 |       loaders[i] = M.DataLoader(dataset, opt, split)
 26 |    end
 27 | 
 28 |    return table.unpack(loaders)
 29 | end
 30 | 
 31 | function DataLoader:__init(dataset, opt, split)
 32 |    local manualSeed = opt.manualSeed
 33 |    local function init()
 34 |       require('datasets/' .. opt.dataset)
 35 |    end
 36 |    local function main(idx)
 37 |       if manualSeed ~= 0 then
 38 |          torch.manualSeed(manualSeed + idx)
 39 |       end
 40 |       torch.setnumthreads(1)
 41 |       _G.dataset = dataset
 42 |       _G.preprocess = dataset:preprocess()
 43 |       return dataset:size()
 44 |    end
 45 | 
 46 |    local threads, sizes = Threads(opt.nThreads, init, main)
 47 |    self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
 48 |    self.threads = threads
 49 |    self.__size = sizes[1][1]
 50 |    self.batchSize = math.floor(opt.batchSize / self.nCrops)
 51 |    local function getCPUType(tensorType)
 52 |       if tensorType == 'torch.CudaHalfTensor' then
 53 |          return 'HalfTensor'
 54 |       elseif tensorType == 'torch.CudaDoubleTensor' then
 55 |          return 'DoubleTensor'
 56 |       else
 57 |          return 'FloatTensor'
 58 |       end
 59 |    end
 60 |    self.cpuType = getCPUType(opt.tensorType)
 61 | end
 62 | 
 63 | function DataLoader:size()
 64 |    return math.ceil(self.__size / self.batchSize)
 65 | end
 66 | 
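-- Returns an iterator over one epoch of mini-batches. Typical usage (as in
-- train.lua): for n, sample in dataLoader:run() do ... end, where
-- sample.input is the batch of images and sample.target the class labels.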
 67 | function DataLoader:run()
 68 |    local threads = self.threads
 69 |    local size, batchSize = self.__size, self.batchSize
 70 |    local perm = torch.randperm(size)
 71 | 
 72 |    local idx, sample = 1, nil
 73 |    local function enqueue()
 74 |       while idx <= size and threads:acceptsjob() do
 75 |          local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
 76 |          threads:addjob(
 77 |             function(indices, nCrops, cpuType)
 78 |                local sz = indices:size(1)
 79 |                local batch, imageSize
 80 |                local target = torch.IntTensor(sz)
 81 |                for i, idx in ipairs(indices:totable()) do
 82 |                   local sample = _G.dataset:get(idx)
 83 |                   local input = _G.preprocess(sample.input)
 84 |                   if not batch then
 85 |                      imageSize = input:size():totable()
 86 |                      if nCrops > 1 then table.remove(imageSize, 1) end
 87 |                      batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
 88 |                   end
 89 |                   batch[i]:copy(input)
 90 |                   target[i] = sample.target
 91 |                end
 92 |                collectgarbage()
 93 |                return {
 94 |                   input = batch:view(sz * nCrops, table.unpack(imageSize)),
 95 |                   target = target,
 96 |                }
 97 |             end,
 98 |             function(_sample_)
 99 |                sample = _sample_
100 |             end,
101 |             indices,
102 |             self.nCrops,
103 |             self.cpuType
104 |          )
105 |          idx = idx + batchSize
106 |       end
107 |    end
108 | 
109 |    local n = 0
110 |    local function loop()
111 |       enqueue()
112 |       if not threads:hasjob() then
113 |          return nil
114 |       end
115 |       threads:dojob()
116 |       if threads:haserror() then
117 |          threads:synchronize()
118 |       end
119 |       enqueue()
120 |       n = n + 1
121 |       return n, sample
122 |    end
123 | 
124 |    return loop
125 | end
126 | 
127 | return M.DataLoader
128 | 


--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
 1 | ## Datasets
 2 | 
 3 | Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` file is responsible for one-time setup, while
 4 | the `dataset.lua` handles the actual data loading.
 5 | 
 6 | If you want to be able to use a new dataset from `main.lua`, you should also modify `opts.lua` to handle the new dataset name.
 7 | 
 8 | ### `dataset-gen.lua`
 9 | 
10 | The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data.
11 | 
12 | The module should have a single function `exec(opt, cacheFile)`.
13 | - `opt`: the command line options
14 | - `cacheFile`: path to the output cache file
15 | 
16 | ```lua
17 | local M = {}
18 | function M.exec(opt, cacheFile)
19 |   local imageInfo = {}
20 |   -- preprocess dataset, store results in imageInfo, save to cacheFile
21 |   torch.save(cacheFile, imageInfo)
22 | end
23 | return M
24 | ```
25 | 
26 | ### `dataset.lua`
27 | 
28 | The `dataset.lua` should return a class that implements three functions:
29 | - `get(i)`: returns a table containing two entries, `input` and `target`
30 |   - `input`: the training or validation image as a Torch tensor
31 |   - `target`: the image category as a number 1-N
32 | - `size()`: returns the number of entries in the dataset
33 | - `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization
34 | 
35 | ```lua
36 | local t = require 'datasets/transforms'
37 | 
38 | local M = {}
39 | local FakeDataset = torch.class('resnet.FakeDataset', M)
40 | 
41 | function FakeDataset:__init(imageInfo, opt, split)
42 |   -- imageInfo: result from dataset-gen.lua
43 |   -- opt: command-line arguments
44 |   -- split: "train" or "val"
45 | end
46 | 
47 | function FakeDataset:get(i)
48 |   return {
49 |     input = torch.Tensor(3, 800, 600):uniform(),
50 |     target = 42,
51 |   }
52 | end
53 | 
54 | function FakeDataset:size()
55 |   -- number of entries in the dataset
56 |   return 2000
57 | end
58 | 
59 | function FakeDataset:preprocess()
60 |   -- Scale smaller side to 256 and take a 224x224 center crop
61 |   return t.Compose{
62 |     t.Scale(256),
63 |     t.CenterCrop(224),
64 |   }
65 | end
66 | 
67 | return M.FakeDataset
68 | ```
69 | 


--------------------------------------------------------------------------------
/datasets/cifar10-gen.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | --  Script to download and prepare the CIFAR-10 dataset
10 | --
11 | --  This automatically downloads the CIFAR-10 dataset from
12 | --  http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz
13 | --
14 | 
15 | local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz'
16 | 
17 | local M = {}
18 | 
19 | local function convertToTensor(files)
20 |    local data, labels
21 | 
22 |    for _, file in ipairs(files) do
23 |       local m = torch.load(file, 'ascii')
24 |       if not data then
25 |          data = m.data:t()
26 |          labels = m.labels:squeeze()
27 |       else
28 |          data = torch.cat(data, m.data:t(), 1)
29 |          labels = torch.cat(labels, m.labels:squeeze())
30 |       end
31 |    end
32 | 
33 |    -- This is *very* important. The downloaded files have labels 0-9, which do
34 |    -- not work with CrossEntropyCriterion
35 |    labels:add(1)
36 | 
37 |    return {
38 |       data = data:contiguous():view(-1, 3, 32, 32),
39 |       labels = labels,
40 |    }
41 | end
42 | 
43 | function M.exec(opt, cacheFile)
44 |    print("=> Downloading CIFAR-10 dataset from " .. URL)
45 |    local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/')
46 |    assert(ok == true or ok == 0, 'error downloading CIFAR-10')
47 | 
48 |    print(" | combining dataset into a single file")
49 |    local trainData = convertToTensor({
50 |       'gen/cifar-10-batches-t7/data_batch_1.t7',
51 |       'gen/cifar-10-batches-t7/data_batch_2.t7',
52 |       'gen/cifar-10-batches-t7/data_batch_3.t7',
53 |       'gen/cifar-10-batches-t7/data_batch_4.t7',
54 |       'gen/cifar-10-batches-t7/data_batch_5.t7',
55 |    })
56 |    local testData = convertToTensor({
57 |       'gen/cifar-10-batches-t7/test_batch.t7',
58 |    })
59 | 
60 |    print(" | saving CIFAR-10 dataset to " .. cacheFile)
61 |    torch.save(cacheFile, {
62 |       train = trainData,
63 |       val = testData,
64 |    })
65 | end
66 | 
67 | return M
68 | 


--------------------------------------------------------------------------------
/datasets/cifar10.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | --  CIFAR-10 dataset loader
10 | --
11 | 
12 | local t = require 'datasets/transforms'
13 | 
14 | local M = {}
15 | local CifarDataset = torch.class('resnet.CifarDataset', M)
16 | 
17 | function CifarDataset:__init(imageInfo, opt, split)
18 |    assert(imageInfo[split], split)
19 |    self.imageInfo = imageInfo[split]
20 |    self.split = split
21 | end
22 | 
23 | function CifarDataset:get(i)
24 |    local image = self.imageInfo.data[i]:float()
25 |    local label = self.imageInfo.labels[i]
26 | 
27 |    return {
28 |       input = image,
29 |       target = label,
30 |    }
31 | end
32 | 
33 | function CifarDataset:size()
34 |    return self.imageInfo.data:size(1)
35 | end
36 | 
37 | -- Computed from entire CIFAR-10 training set
38 | local meanstd = {
39 |    mean = {125.3, 123.0, 113.9},
40 |    std  = {63.0,  62.1,  66.7},
41 | }
42 | 
43 | function CifarDataset:preprocess()
44 |    if self.split == 'train' then
45 |       return t.Compose{
46 |          t.ColorNormalize(meanstd),
47 |          t.HorizontalFlip(0.5),
48 |          t.RandomCrop(32, 4),
49 |       }
50 |    elseif self.split == 'val' then
51 |       return t.ColorNormalize(meanstd)
52 |    else
53 |       error('invalid split: ' .. self.split)
54 |    end
55 | end
56 | 
57 | return M.CifarDataset
58 | 


--------------------------------------------------------------------------------
/datasets/cifar100-gen.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | 
10 | ------------
11 | -- This file automatically downloads the CIFAR-100 dataset from
12 | --    http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz
13 | -- It is based on cifar10-gen.lua
14 | -- Ludovic Trottier
15 | ------------
16 | 
17 | local URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz'
18 | 
19 | local M = {}
20 | 
21 | local function convertCifar100BinToTorchTensor(inputFname)
22 |    local m=torch.DiskFile(inputFname, 'r'):binary()
23 |    m:seekEnd()
24 |    local length = m:position() - 1
25 |    local nSamples = length / 3074 -- 1 coarse-label byte, 1 fine-label byte, 3072 pixel bytes
26 | 
27 |    assert(nSamples == math.floor(nSamples), 'expecting numSamples to be an exact integer')
28 |    m:seek(1)
29 | 
30 |    local coarse = torch.ByteTensor(nSamples)
31 |    local fine = torch.ByteTensor(nSamples)
32 |    local data = torch.ByteTensor(nSamples, 3, 32, 32)
33 |    for i=1,nSamples do
34 |       coarse[i] = m:readByte()
35 |       fine[i]   = m:readByte()
36 |       local store = m:readByte(3072)
37 |       data[i]:copy(torch.ByteTensor(store))
38 |    end
39 | 
40 |    local out = {}
41 |    out.data = data
42 |    -- This is *very* important. The downloaded files have labels 0-99, which do
43 |    -- not work with CrossEntropyCriterion
44 |    out.labels = fine + 1
45 | 
46 |    return out
47 | end
48 | 
49 | function M.exec(opt, cacheFile)
50 |    print("=> Downloading CIFAR-100 dataset from " .. URL)
51 |    
52 |    local ok = os.execute('curl ' .. URL .. ' | tar xz -C  gen/')
53 |    assert(ok == true or ok == 0, 'error downloading CIFAR-100')
54 | 
55 |    print(" | combining dataset into a single file")
56 |    
57 |    local trainData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/train.bin')
58 |    local testData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/test.bin')
59 | 
60 |    print(" | saving CIFAR-100 dataset to " .. cacheFile)
61 |    torch.save(cacheFile, {
62 |       train = trainData,
63 |       val = testData,
64 |    })
65 | end
66 | 
67 | return M
68 | 


--------------------------------------------------------------------------------
/datasets/cifar100.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | 
10 | ------------
11 | -- This file loads and preprocesses the CIFAR-100 dataset.
12 | -- It is based on cifar10.lua
13 | -- Ludovic Trottier
14 | ------------
15 | 
16 | local t = require 'datasets/transforms'
17 | 
18 | local M = {}
19 | local CifarDataset = torch.class('resnet.CifarDataset', M)
20 | 
21 | function CifarDataset:__init(imageInfo, opt, split)
22 |    assert(imageInfo[split], split)
23 |    self.imageInfo = imageInfo[split]
24 |    self.split = split
25 | end
26 | 
27 | function CifarDataset:get(i)
28 |    local image = self.imageInfo.data[i]:float()
29 |    local label = self.imageInfo.labels[i]
30 | 
31 |    return {
32 |       input = image,
33 |       target = label,
34 |    }
35 | end
36 | 
37 | function CifarDataset:size()
38 |    return self.imageInfo.data:size(1)
39 | end
40 | 
41 | 
42 | -- Computed from entire CIFAR-100 training set with this code:
43 | --      dataset = torch.load('cifar100.t7')
44 | --      tt = dataset.train.data:double();
45 | --      tt = tt:transpose(2,4);
46 | --      tt = tt:reshape(50000*32*32, 3);
47 | --      tt:mean(1)
48 | --      tt:std(1)
49 | local meanstd = {
50 |    mean = {129.3, 124.1, 112.4},
51 |    std  = {68.2,  65.4,  70.4},
52 | }
53 | 
54 | function CifarDataset:preprocess()
55 |    if self.split == 'train' then
56 |       return t.Compose{
57 |          t.ColorNormalize(meanstd),
58 |          t.HorizontalFlip(0.5),
59 |          t.RandomCrop(32, 4),
60 |       }
61 |    elseif self.split == 'val' then
62 |       return t.ColorNormalize(meanstd)
63 |    else
64 |       error('invalid split: ' .. self.split)
65 |    end
66 | end
67 | 
68 | return M.CifarDataset
69 | 


--------------------------------------------------------------------------------
/datasets/imagenet-gen.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | --  Script to compute list of ImageNet filenames and classes
 10 | --
 11 | --  This generates a file gen/imagenet.t7 which contains the list of all
 12 | --  ImageNet training and validation images and their classes. This script also
 13 | --  works for other datasets arranged with the same layout.
 14 | --
 15 | 
 16 | local sys = require 'sys'
 17 | local ffi = require 'ffi'
 18 | 
 19 | local M = {}
 20 | 
 21 | local function findClasses(dir)
 22 |    local dirs = paths.dir(dir)
 23 |    table.sort(dirs)
 24 | 
 25 |    local classList = {}
 26 |    local classToIdx = {}
 27 |    for _ ,class in ipairs(dirs) do
 28 |       if not classToIdx[class] and class ~= '.' and class ~= '..' and class ~= '.DS_Store' then
 29 |          table.insert(classList, class)
 30 |          classToIdx[class] = #classList
 31 |       end
 32 |    end
 33 | 
 34 |    -- assert(#classList == 1000, 'expected 1000 ImageNet classes')
 35 |    return classList, classToIdx
 36 | end
 37 | 
 38 | local function findImages(dir, classToIdx)
 39 |    local imagePath = torch.CharTensor()
 40 |    local imageClass = torch.LongTensor()
 41 | 
 42 |    ----------------------------------------------------------------------
 43 |    -- Options for the GNU and BSD find command
 44 |    local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
 45 |    local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
 46 |    for i=2,#extensionList do
 47 |       findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
 48 |    end
 49 | 
 50 |    -- Find all the images using the find command
 51 |    local f = io.popen('find -L ' .. dir .. findOptions)
 52 | 
 53 |    local maxLength = -1
 54 |    local imagePaths = {}
 55 |    local imageClasses = {}
 56 | 
 57 |    -- Generate a list of all the images and their class
 58 |    while true do
 59 |       local line = f:read('*line')
 60 |       if not line then break end
 61 | 
 62 |       local className = paths.basename(paths.dirname(line))
 63 |       local filename = paths.basename(line)
 64 |       local path = className .. '/' .. filename
 65 | 
 66 |       local classId = classToIdx[className]
 67 |       assert(classId, 'class not found: ' .. className)
 68 | 
 69 |       table.insert(imagePaths, path)
 70 |       table.insert(imageClasses, classId)
 71 | 
 72 |       maxLength = math.max(maxLength, #path + 1)
 73 |    end
 74 | 
 75 |    f:close()
 76 | 
 77 |    -- Convert the generated list to a tensor for faster loading
 78 |    local nImages = #imagePaths
 79 |    local imagePath = torch.CharTensor(nImages, maxLength):zero()
 80 |    for i, path in ipairs(imagePaths) do
 81 |       ffi.copy(imagePath[i]:data(), path)
 82 |    end
 83 | 
 84 |    local imageClass = torch.LongTensor(imageClasses)
 85 |    return imagePath, imageClass
 86 | end
 87 | 
 88 | function M.exec(opt, cacheFile)
 89 |    -- find the image path names
 90 |    local imagePath = torch.CharTensor()  -- path to each image in dataset
 91 |    local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
 92 | 
 93 |    local trainDir = paths.concat(opt.data, 'train')
 94 |    local valDir = paths.concat(opt.data, 'val')
 95 |    assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir)
 96 |    assert(paths.dirp(valDir), 'val directory not found: ' .. valDir)
 97 | 
 98 |    print("=> Generating list of images")
 99 |    local classList, classToIdx = findClasses(trainDir)
100 | 
101 |    print(" | finding all validation images")
102 |    local valImagePath, valImageClass = findImages(valDir, classToIdx)
103 | 
104 |    print(" | finding all training images")
105 |    local trainImagePath, trainImageClass = findImages(trainDir, classToIdx)
106 | 
107 |    local info = {
108 |       basedir = opt.data,
109 |       classList = classList,
110 |       train = {
111 |          imagePath = trainImagePath,
112 |          imageClass = trainImageClass,
113 |       },
114 |       val = {
115 |          imagePath = valImagePath,
116 |          imageClass = valImageClass,
117 |       },
118 |    }
119 | 
120 |    print(" | saving list of images to " .. cacheFile)
121 |    torch.save(cacheFile, info)
122 |    return info
123 | end
124 | 
125 | return M
126 | 


--------------------------------------------------------------------------------
/datasets/imagenet.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | --  ImageNet dataset loader
 10 | --
 11 | 
 12 | local image = require 'image'
 13 | local paths = require 'paths'
 14 | local t = require 'datasets/transforms'
 15 | local ffi = require 'ffi'
 16 | 
 17 | local M = {}
 18 | local ImagenetDataset = torch.class('resnet.ImagenetDataset', M)
 19 | 
 20 | function ImagenetDataset:__init(imageInfo, opt, split)
 21 |    self.imageInfo = imageInfo[split]
 22 |    self.opt = opt
 23 |    self.split = split
 24 |    self.dir = paths.concat(opt.data, split)
 25 |    assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir)
 26 | end
 27 | 
 28 | function ImagenetDataset:get(i)
 29 |    local path = ffi.string(self.imageInfo.imagePath[i]:data())
 30 | 
 31 |    local image = self:_loadImage(paths.concat(self.dir, path))
 32 |    local class = self.imageInfo.imageClass[i]
 33 | 
 34 |    return {
 35 |       input = image,
 36 |       target = class,
 37 |    }
 38 | end
 39 | 
 40 | function ImagenetDataset:_loadImage(path)
 41 |    local ok, input = pcall(function()
 42 |       return image.load(path, 3, 'float')
 43 |    end)
 44 | 
 45 |    -- Sometimes image.load fails because the file extension does not match the
 46 |    -- image format. In that case, use image.decompress on a ByteTensor.
 47 |    if not ok then
 48 |       local f = io.open(path, 'r')
 49 |       assert(f, 'Error reading: ' .. tostring(path))
 50 |       local data = f:read('*a')
 51 |       f:close()
 52 | 
 53 |       local b = torch.ByteTensor(string.len(data))
 54 |       ffi.copy(b:data(), data, b:size(1))
 55 | 
 56 |       input = image.decompress(b, 3, 'float')
 57 |    end
 58 | 
 59 |    return input
 60 | end
 61 | 
 62 | function ImagenetDataset:size()
 63 |    return self.imageInfo.imageClass:size(1)
 64 | end
 65 | 
 66 | -- Computed from random subset of ImageNet training images
 67 | local meanstd = {
 68 |    mean = { 0.485, 0.456, 0.406 },
 69 |    std = { 0.229, 0.224, 0.225 },
 70 | }
 71 | local pca = {
 72 |    eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 },
 73 |    eigvec = torch.Tensor{
 74 |       { -0.5675,  0.7192,  0.4009 },
 75 |       { -0.5808, -0.0045, -0.8140 },
 76 |       { -0.5836, -0.6948,  0.4203 },
 77 |    },
 78 | }
 79 | 
 80 | function ImagenetDataset:preprocess()
 81 |    if self.split == 'train' then
 82 |       return t.Compose{
 83 |          t.RandomSizedCrop(224),
 84 |          t.ColorJitter({
 85 |             brightness = 0.4,
 86 |             contrast = 0.4,
 87 |             saturation = 0.4,
 88 |          }),
 89 |          t.Lighting(0.1, pca.eigval, pca.eigvec),
 90 |          t.ColorNormalize(meanstd),
 91 |          t.HorizontalFlip(0.5),
 92 |       }
 93 |    elseif self.split == 'val' then
 94 |       local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop
 95 |       return t.Compose{
 96 |          t.Scale(256),
 97 |          t.ColorNormalize(meanstd),
 98 |          Crop(224),
 99 |       }
100 |    else
101 |       error('invalid split: ' .. self.split)
102 |    end
103 | end
104 | 
105 | return M.ImagenetDataset
106 | 


--------------------------------------------------------------------------------
/datasets/init.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | --  ImageNet, CIFAR-10 and CIFAR-100 datasets
10 | --
11 | 
12 | local M = {}
13 | 
14 | local function isvalid(opt, cachePath)
15 |    local imageInfo = torch.load(cachePath)
16 |    if imageInfo.basedir and imageInfo.basedir ~= opt.data then
17 |       return false
18 |    end
19 |    return true
20 | end
21 | 
22 | function M.create(opt, split)
23 |    local cachePath = paths.concat(opt.gen, opt.dataset .. '.t7')
24 |    if not paths.filep(cachePath) or not isvalid(opt, cachePath) then
25 |       paths.mkdir('gen')
26 | 
27 |       local script = paths.dofile(opt.dataset .. '-gen.lua')
28 |       script.exec(opt, cachePath)
29 |    end
30 |    local imageInfo = torch.load(cachePath)
31 | 
32 |    local Dataset = require('datasets/' .. opt.dataset)
33 |    return Dataset(imageInfo, opt, split)
34 | end
35 | 
36 | return M
37 | 


--------------------------------------------------------------------------------
/datasets/transforms.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | --  Image transforms for data augmentation and input normalization
 10 | --
 11 | 
 12 | require 'image'
 13 | 
 14 | local M = {}
 15 | 
 16 | function M.Compose(transforms)
 17 |    return function(input)
 18 |       for _, transform in ipairs(transforms) do
 19 |          input = transform(input)
 20 |       end
 21 |       return input
 22 |    end
 23 | end
 24 | 
 25 | function M.ColorNormalize(meanstd)
 26 |    return function(img)
 27 |       img = img:clone()
 28 |       for i=1,3 do
 29 |          img[i]:add(-meanstd.mean[i])
 30 |          img[i]:div(meanstd.std[i])
 31 |       end
 32 |       return img
 33 |    end
 34 | end
 35 | 
 36 | -- Scales the smaller edge to size
 37 | function M.Scale(size, interpolation)
 38 |    interpolation = interpolation or 'bicubic'
 39 |    return function(input)
 40 |       local w, h = input:size(3), input:size(2)
 41 |       if (w <= h and w == size) or (h <= w and h == size) then
 42 |          return input
 43 |       end
 44 |       if w < h then
 45 |          return image.scale(input, size, h/w * size, interpolation)
 46 |       else
 47 |          return image.scale(input, w/h * size, size, interpolation)
 48 |       end
 49 |    end
 50 | end
 51 | 
 52 | -- Crop to centered rectangle
 53 | function M.CenterCrop(size)
 54 |    return function(input)
 55 |       local w1 = math.ceil((input:size(3) - size)/2)
 56 |       local h1 = math.ceil((input:size(2) - size)/2)
 57 |       return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch
 58 |    end
 59 | end
 60 | 
 61 | -- Random crop from a larger image with optional zero padding
 62 | function M.RandomCrop(size, padding)
 63 |    padding = padding or 0
 64 | 
 65 |    return function(input)
 66 |       if padding > 0 then
 67 |          local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding)
 68 |          temp:zero()
 69 |             :narrow(2, padding+1, input:size(2))
 70 |             :narrow(3, padding+1, input:size(3))
 71 |             :copy(input)
 72 |          input = temp
 73 |       end
 74 | 
 75 |       local w, h = input:size(3), input:size(2)
 76 |       if w == size and h == size then
 77 |          return input
 78 |       end
 79 | 
 80 |       local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
 81 |       local out = image.crop(input, x1, y1, x1 + size, y1 + size)
 82 |       assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
 83 |       return out
 84 |    end
 85 | end
 86 | 
 87 | -- Four corner patches and center crop from image and its horizontal reflection
 88 | function M.TenCrop(size)
 89 |    local centerCrop = M.CenterCrop(size)
 90 | 
 91 |    return function(input)
 92 |       local w, h = input:size(3), input:size(2)
 93 | 
 94 |       local output = {}
 95 |       for _, img in ipairs{input, image.hflip(input)} do
 96 |          table.insert(output, centerCrop(img))
 97 |          table.insert(output, image.crop(img, 0, 0, size, size))
 98 |          table.insert(output, image.crop(img, w-size, 0, w, size))
 99 |          table.insert(output, image.crop(img, 0, h-size, size, h))
100 |          table.insert(output, image.crop(img, w-size, h-size, w, h))
101 |       end
102 | 
103 |       -- View as mini-batch
104 |       for i, img in ipairs(output) do
105 |          output[i] = img:view(1, img:size(1), img:size(2), img:size(3))
106 |       end
107 | 
108 |       return input.cat(output, 1)
109 |    end
110 | end
111 | 
112 | -- Resize so the shorter side is randomly sampled from [minSize, maxSize] (ResNet-style)
113 | function M.RandomScale(minSize, maxSize)
114 |    return function(input)
115 |       local w, h = input:size(3), input:size(2)
116 | 
117 |       local targetSz = torch.random(minSize, maxSize)
118 |       local targetW, targetH = targetSz, targetSz
119 |       if w < h then
120 |          targetH = torch.round(h / w * targetW)
121 |       else
122 |          targetW = torch.round(w / h * targetH)
123 |       end
124 | 
125 |       return image.scale(input, targetW, targetH, 'bicubic')
126 |    end
127 | end
128 | 
129 | -- Random crop covering 8%-100% of the image area, with aspect ratio 3/4 - 4/3 (Inception-style)
130 | function M.RandomSizedCrop(size)
131 |    local scale = M.Scale(size)
132 |    local crop = M.CenterCrop(size)
133 | 
134 |    return function(input)
135 |       local attempt = 0
136 |       repeat
137 |          local area = input:size(2) * input:size(3)
138 |          local targetArea = torch.uniform(0.08, 1.0) * area
139 | 
140 |          local aspectRatio = torch.uniform(3/4, 4/3)
141 |          local w = torch.round(math.sqrt(targetArea * aspectRatio))
142 |          local h = torch.round(math.sqrt(targetArea / aspectRatio))
143 | 
144 |          if torch.uniform() < 0.5 then
145 |             w, h = h, w
146 |          end
147 | 
148 |          if h <= input:size(2) and w <= input:size(3) then
149 |             local y1 = torch.random(0, input:size(2) - h)
150 |             local x1 = torch.random(0, input:size(3) - w)
151 | 
152 |             local out = image.crop(input, x1, y1, x1 + w, y1 + h)
153 |             assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')
154 | 
155 |             return image.scale(out, size, size, 'bicubic')
156 |          end
157 |          attempt = attempt + 1
158 |       until attempt >= 10
159 | 
160 |       -- fallback
161 |       return crop(scale(input))
162 |    end
163 | end
164 | 
165 | function M.HorizontalFlip(prob)
166 |    return function(input)
167 |       if torch.uniform() < prob then
168 |          input = image.hflip(input)
169 |       end
170 |       return input
171 |    end
172 | end
173 | 
174 | function M.Rotation(deg)
175 |    return function(input)
176 |       if deg ~= 0 then
177 |          input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear')
178 |       end
179 |       return input
180 |    end
181 | end
182 | 
183 | -- Lighting noise (AlexNet-style PCA-based noise)
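-- (Adds eigvec * (alpha .* eigval) to every pixel, where alpha ~ N(0, alphastd):
-- a random shift along the principal components of the RGB channel values.)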
184 | function M.Lighting(alphastd, eigval, eigvec)
185 |    return function(input)
186 |       if alphastd == 0 then
187 |          return input
188 |       end
189 | 
190 |       local alpha = torch.Tensor(3):normal(0, alphastd)
191 |       local rgb = eigvec:clone()
192 |          :cmul(alpha:view(1, 3):expand(3, 3))
193 |          :cmul(eigval:view(1, 3):expand(3, 3))
194 |          :sum(2)
195 |          :squeeze()
196 | 
197 |       input = input:clone()
198 |       for i=1,3 do
199 |          input[i]:add(rgb[i])
200 |       end
201 |       return input
202 |    end
203 | end
204 | 
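-- Linearly interpolate img1 toward img2 in place: img1 = img1*alpha + img2*(1-alpha)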
205 | local function blend(img1, img2, alpha)
206 |    return img1:mul(alpha):add(1 - alpha, img2)
207 | end
208 | 
209 | local function grayscale(dst, img)
210 |    dst:resizeAs(img)
211 |    dst[1]:zero()
212 |    dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3])
213 |    dst[2]:copy(dst[1])
214 |    dst[3]:copy(dst[1])
215 |    return dst
216 | end
217 | 
218 | function M.Saturation(var)
219 |    local gs
220 | 
221 |    return function(input)
222 |       gs = gs or input.new()
223 |       grayscale(gs, input)
224 | 
225 |       local alpha = 1.0 + torch.uniform(-var, var)
226 |       blend(input, gs, alpha)
227 |       return input
228 |    end
229 | end
230 | 
231 | function M.Brightness(var)
232 |    local gs
233 | 
234 |    return function(input)
235 |       gs = gs or input.new()
236 |       gs:resizeAs(input):zero()
237 | 
238 |       local alpha = 1.0 + torch.uniform(-var, var)
239 |       blend(input, gs, alpha)
240 |       return input
241 |    end
242 | end
243 | 
244 | function M.Contrast(var)
245 |    local gs
246 | 
247 |    return function(input)
248 |       gs = gs or input.new()
249 |       grayscale(gs, input)
250 |       gs:fill(gs[1]:mean())
251 | 
252 |       local alpha = 1.0 + torch.uniform(-var, var)
253 |       blend(input, gs, alpha)
254 |       return input
255 |    end
256 | end
257 | 
258 | function M.RandomOrder(ts)
259 |    return function(input)
260 |       local img = input.img or input
261 |       local order = torch.randperm(#ts)
262 |       for i=1,#ts do
263 |          img = ts[order[i]](img)
264 |       end
265 |       return img
266 |    end
267 | end
268 | 
269 | function M.ColorJitter(opt)
270 |    local brightness = opt.brightness or 0
271 |    local contrast = opt.contrast or 0
272 |    local saturation = opt.saturation or 0
273 | 
274 |    local ts = {}
275 |    if brightness ~= 0 then
276 |       table.insert(ts, M.Brightness(brightness))
277 |    end
278 |    if contrast ~= 0 then
279 |       table.insert(ts, M.Contrast(contrast))
280 |    end
281 |    if saturation ~= 0 then
282 |       table.insert(ts, M.Saturation(saturation))
283 |    end
284 | 
285 |    if #ts == 0 then
286 |       return function(input) return input end
287 |    end
288 | 
289 |    return M.RandomOrder(ts)
290 | end
291 | 
292 | return M
293 | 


--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
 1 | --
 2 | --  Copyright (c) 2016, Facebook, Inc.
 3 | --  All rights reserved.
 4 | --
 5 | --  This source code is licensed under the BSD-style license found in the
 6 | --  LICENSE file in the root directory of this source tree. An additional grant
 7 | --  of patent rights can be found in the PATENTS file in the same directory.
 8 | --
 9 | require 'torch'
10 | require 'paths'
11 | require 'optim'
12 | require 'nn'
13 | local DataLoader = require 'dataloader'
14 | local models = require 'models/init'
15 | local Trainer = require 'train'
16 | local opts = require 'opts'
17 | local checkpoints = require 'checkpoints'
18 | 
19 | -- We don't change this to the 'correct' type (e.g. HalfTensor), because math
20 | -- isn't supported on that type.  Type conversion later will handle having
21 | -- the correct type.
22 | torch.setdefaulttensortype('torch.FloatTensor')
23 | torch.setnumthreads(1)
24 | 
25 | local opt = opts.parse(arg)
26 | torch.manualSeed(opt.manualSeed)
27 | cutorch.manualSeedAll(opt.manualSeed)
28 | 
29 | -- Load previous checkpoint, if it exists
30 | local checkpoint, optimState = checkpoints.latest(opt)
31 | 
32 | -- Create model
33 | local model, criterion = models.setup(opt, checkpoint)
34 | 
35 | -- Data loading
36 | local trainLoader, valLoader = DataLoader.create(opt)
37 | 
38 | -- The trainer handles the training loop and evaluation on validation set
39 | local trainer = Trainer(model, criterion, opt, optimState)
40 | 
41 | if opt.testOnly then
42 |    local top1Err, top5Err = trainer:test(0, valLoader)
43 |    print(string.format(' * Results top1: %6.3f  top5: %6.3f', top1Err, top5Err))
44 |    return
45 | end
46 | 
47 | local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber
48 | local bestTop1 = math.huge
49 | local bestTop5 = math.huge
50 | for epoch = startEpoch, opt.nEpochs do
51 |    -- Train for a single epoch
52 |    local trainTop1, trainTop5, trainLoss = trainer:train(epoch, trainLoader)
53 | 
54 |    -- Run model on validation set
55 |    local testTop1, testTop5 = trainer:test(epoch, valLoader)
56 | 
57 |    local bestModel = false
58 |    if testTop1 < bestTop1 then
59 |       bestModel = true
60 |       bestTop1 = testTop1
61 |       bestTop5 = testTop5
62 |       print(' * Best model ', testTop1, testTop5)
63 |    end
64 | 
65 |    checkpoints.save(epoch, model, trainer.optimState, bestModel, opt)
66 | end
67 | 
68 | print(string.format(' * Finished top1: %6.3f  top5: %6.3f', bestTop1, bestTop5))


--------------------------------------------------------------------------------
/models/DenseConnectLayer.lua:
--------------------------------------------------------------------------------
  1 | require 'nn'
  2 | require 'cudnn'
  3 | require 'cunn'
  4 | 
  5 | 
  6 | local function ShareGradInput(module, key)
  7 |    assert(key)
  8 |    module.__shareGradInputKey = key
  9 |    return module
 10 | end
 11 | 
 12 | --------------------------------------------------------------------------------
 13 | -- Standard densely connected layer (memory inefficient)
 14 | --------------------------------------------------------------------------------
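-- Returns a module that concatenates its input with growthRate new feature maps
-- computed by BN-ReLU-(optional 1x1 bottleneck conv-BN-ReLU)-3x3 conv, so the
-- channel count grows by growthRate at every layer.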
 15 | function DenseConnectLayerStandard(nChannels, opt)
 16 |    local net = nn.Sequential()
 17 | 
 18 |    net:add(ShareGradInput(cudnn.SpatialBatchNormalization(nChannels), 'first'))
 19 |    net:add(cudnn.ReLU(true))   
 20 |    if opt.bottleneck then
 21 |       net:add(cudnn.SpatialConvolution(nChannels, 4 * opt.growthRate, 1, 1, 1, 1, 0, 0))
 22 |       nChannels = 4 * opt.growthRate
 23 |       if opt.dropRate > 0 then net:add(nn.Dropout(opt.dropRate)) end
 24 |       net:add(cudnn.SpatialBatchNormalization(nChannels))
 25 |       net:add(cudnn.ReLU(true))      
 26 |    end
 27 |    net:add(cudnn.SpatialConvolution(nChannels, opt.growthRate, 3, 3, 1, 1, 1, 1))
 28 |    if opt.dropRate > 0 then net:add(nn.Dropout(opt.dropRate)) end
 29 | 
 30 |    return nn.Sequential()
 31 |       :add(nn.Concat(2)
 32 |          :add(nn.Identity())
 33 |          :add(net))  
 34 | end
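    | 
    | -- Shape check (illustrative, hypothetical values): with nChannels = 24 and
    | -- opt = {bottleneck = true, growthRate = 12, dropRate = 0}, the returned module
    | -- maps a (B, 24, H, W) input to a (B, 36, H, W) output: the nn.Concat keeps the
    | -- 24 input channels (nn.Identity) and appends the 12 new ones produced by `net`.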
 35 | 
 36 | --------------------------------------------------------------------------------
 37 | -- Customized densely connected layer (memory efficient)
 38 | --------------------------------------------------------------------------------
 39 | local DenseConnectLayerCustom, parent = torch.class('nn.DenseConnectLayerCustom', 'nn.Container')
 40 | 
 41 | function DenseConnectLayerCustom:__init(nChannels, opt)
 42 |    parent.__init(self)
 43 |    self.train = true
 44 |    self.opt = opt
 45 | 
 46 |    self.net1 = nn.Sequential()
 47 |    self.net1:add(ShareGradInput(cudnn.SpatialBatchNormalization(nChannels), 'first'))
 48 |    self.net1:add(cudnn.ReLU(true))  
 49 | 
 50 |    self.net2 = nn.Sequential()
 51 |    if opt.bottleneck then
 52 |       self.net2:add(cudnn.SpatialConvolution(nChannels, 4*opt.growthRate, 1, 1, 1, 1, 0, 0))
 53 |       nChannels = 4 * opt.growthRate
 54 |       self.net2:add(cudnn.SpatialBatchNormalization(nChannels))
 55 |       self.net2:add(cudnn.ReLU(true))
 56 |    end
 57 |    self.net2:add(cudnn.SpatialConvolution(nChannels, opt.growthRate, 3, 3, 1, 1, 1, 1))
 58 | 
 59 |    -- buffer for the contiguous (concatenated) outputs of previous layers
 60 |    self.input_c = torch.Tensor():type(opt.tensorType) 
 61 |    -- save a copy of BatchNorm statistics before forwarding it for the second time when optMemory=4
 62 |    self.saved_bn_running_mean = torch.Tensor():type(opt.tensorType)
 63 |    self.saved_bn_running_var = torch.Tensor():type(opt.tensorType)
 64 | 
 65 |    self.gradInput = {}
 66 |    self.output = {}
 67 | 
 68 |    self.modules = {self.net1, self.net2}
 69 | end
 70 | 
 71 | function DenseConnectLayerCustom:updateOutput(input)
 72 | 
 73 |    if type(input) ~= 'table' then
 74 |       self.output[1] = input
 75 |       self.output[2] = self.net2:forward(self.net1:forward(input))
 76 |    else
 77 |       for i = 1, #input do
 78 |          self.output[i] = input[i]
 79 |       end
 80 |       torch.cat(self.input_c, input, 2)
 81 |       self.net1:forward(self.input_c)
 82 |       self.output[#input+1] = self.net2:forward(self.net1.output)
 83 |    end
 84 | 
 85 |    if self.opt.optMemory == 4 then
 86 |       local running_mean, running_var = self.net1:get(1).running_mean, self.net1:get(1).running_var
 87 |       self.saved_bn_running_mean:resizeAs(running_mean):copy(running_mean)
 88 |       self.saved_bn_running_var:resizeAs(running_var):copy(running_var)
 89 |    end
 90 | 
 91 |    return self.output
 92 | end
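    | 
    | -- Note on the data layout (added for clarity): inside a dense block the Custom
    | -- layer passes activations around as a table {x1, ..., xk} rather than as one
    | -- concatenated tensor. The concatenation happens once per layer, into
    | -- self.input_c, whose storage M.sharePrevOutput (models/init.lua) replaces with
    | -- a single buffer shared by every DenseConnectLayerCustom in the model.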
 93 | 
 94 | function DenseConnectLayerCustom:updateGradInput(input, gradOutput)
 95 | 
 96 |    if type(input) ~= 'table' then
 97 |       self.gradInput = gradOutput[1]
 98 |       if self.opt.optMemory == 4 then self.net1:forward(input) end
 99 |       self.net2:updateGradInput(self.net1.output, gradOutput[2])
100 |       self.gradInput:add(self.net1:updateGradInput(input, self.net2.gradInput))
101 |    else
102 |       torch.cat(self.input_c, input, 2)
103 |       if self.opt.optMemory == 4 then self.net1:forward(self.input_c) end
104 |       self.net2:updateGradInput(self.net1.output, gradOutput[#gradOutput])
105 |       self.net1:updateGradInput(self.input_c, self.net2.gradInput)
106 |       local nC = 1
107 |       for i = 1, #input do
108 |          self.gradInput[i] = gradOutput[i]
109 |          self.gradInput[i]:add(self.net1.gradInput:narrow(2,nC,input[i]:size(2)))
110 |          nC = nC + input[i]:size(2)
111 |       end
112 |    end
113 | 
114 |    if self.opt.optMemory == 4 then
115 |       self.net1:get(1).running_mean:copy(self.saved_bn_running_mean)
116 |       self.net1:get(1).running_var:copy(self.saved_bn_running_var)
117 |    end
118 | 
119 |    return self.gradInput
120 | end
121 | 
122 | function DenseConnectLayerCustom:accGradParameters(input, gradOutput, scale)
123 |    scale = scale or 1
124 |    self.net2:accGradParameters(self.net1.output, gradOutput[#gradOutput], scale)
125 |    if type(input) ~= 'table' then
126 |       self.net1:accGradParameters(input, self.net2.gradInput, scale)
127 |    else
128 |       self.net1:accGradParameters(self.input_c, self.net2.gradInput, scale)
129 |    end
130 | end
131 | 
132 | function DenseConnectLayerCustom:__tostring__()
133 |    local tab = '  '
134 |    local line = '\n'
135 |    local next = '  |`-> '
136 |    local lastNext = '   `-> '
137 |    local ext = '  |    '
138 |    local extlast = '       '
139 |    local last = '   ... -> '
140 |    local str = 'DenseConnectLayerCustom'
141 |    str = str .. ' {' .. line .. tab .. '{input}'
142 |    for i=1,#self.modules do
143 |       if i == #self.modules then
144 |          str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast)
145 |       else
146 |          str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext)
147 |       end
148 |    end
149 |    str = str .. line .. tab .. last .. '{output}'
150 |    str = str .. line .. '}'
151 |    return str
152 | end
153 | 


--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
 1 | # Memory Efficient Implementation of DenseNets
 2 | 
 3 | The standard (original) implementation of DenseNet with recursive concatenation is very memory inefficient. This can be an obstacle when we need to train DenseNets on high-resolution images (such as for object detection and localization tasks) or on devices with limited memory.
 4 | 
 5 | In theory, DenseNet should use memory more efficiently than other networks, because one of its key features is that it encourages *feature reuse* in the network. The fact that DenseNet is "memory hungry" in practice is simply an artifact of implementation. In particular, the culprit is the recursive concatenation which *re-allocates memory* for all previous outputs at each layer. Consider a dense block with N layers: the first layer's output has N copies in memory, the second layer's output has (N-1) copies, ..., leading to a quadratic increase (1+2+...+N) in memory consumption as the network depth grows.
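   | 
   | To make the count explicit (a quick check of the claim above): in a block of N layers, the output of layer i is kept in (N+1-i) concatenated copies, so the total number of stored copies is
   | 
   | $$\sum_{i=1}^{N}(N+1-i) = 1+2+\dots+N = \frac{N(N+1)}{2} = O(N^2),$$
   | 
   | i.e. doubling the depth of a dense block roughly quadruples the activation memory of the naive implementation.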
 6 | 
 7 | Using *optnet* (*-optMemory 1*) or *shareGradInput* (*-optMemory 2*), we can significantly reduce the run-time memory footprint of the standard implementation (with recursive concatenation). However, the memory consumption is still quadratic in depth.
 8 | 
 9 | We implement a customized densely connected layer (largely motivated by the [Caffe implementation of memory efficient DenseNet](https://github.com/Tongcheng/DN_CaffeScript) by [Tongcheng](https://github.com/Tongcheng)), which uses shared buffers to store the concatenated outputs and gradients, thus dramatically reducing the memory footprint of DenseNet during training. The mode *-optMemory 3* activates shareGradInput and shared output buffers, while the mode *-optMemory 4* further shares the memory to store the output of the Batch-Normalization layer before each 1x1 convolution layer. The latter makes the memory consumption *linear* in network depth, but introduces a training time overhead due to the need to re-forward these Batch-Normalization layers in the backward pass.
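   | 
   | As a minimal illustration of the shared-buffer idea (a sketch using only stock Torch tensor/storage semantics; the actual implementation is in `models/DenseConnectLayer.lua` and `models/init.lua`), every layer concatenates its inputs into a *view* of one shared storage instead of allocating a fresh buffer:
   | 
   | ```
   | require 'torch'
   | 
   | local shared = torch.FloatStorage(1)   -- grown on demand, shared by all layers
   | 
   | local function concatIntoShared(inputs)
   |    -- zero-sized tensor viewing the shared storage; torch.cat resizes it in place
   |    local buffer = torch.FloatTensor(shared, 1, 0)
   |    torch.cat(buffer, inputs, 2)
   |    return buffer
   | end
   | 
   | local x1 = torch.randn(4, 12, 8, 8)
   | local x2 = torch.randn(4, 12, 8, 8)
   | local joined = concatIntoShared({x1, x2})
   | print(joined:size())   -- 4 x 24 x 8 x 8, backed by `shared`
   | ```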
10 | 
11 | In practice, we suggest using the default *-optMemory 2*, as it requires no customized layers while keeping memory consumption moderate. When GPU memory is really the bottleneck, we can adopt the customized implementation by setting *-optMemory* to 3 or 4, e.g.,
12 | ```
13 | th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 100 -growthRate 12 -optMemory 4
14 | ``` 
15 | 
16 | The following time and memory footprints are benchmarked with a DenseNet-BC (L=100, k=12) on CIFAR-10, on an NVIDIA TitanX GPU:
17 | 
18 | optMemory | Memory | Time (s/mini-batch) | Description
19 | :-------:|:------:|:-------------------:|:------------
20 | 0 | 5453M | 0.153 | Original implementation
21 | 1 | 3746M | 0.153 | Original implementation with optnet
22 | 2 | 2969M | 0.152 | Original implementation with shareGradInput
23 | 3 | 2188M | 0.155 | Customized implementation with shareGradInput and sharePrevOutput
24 | 4 | 1655M | 0.175 | Customized implementation with shareGradInput, sharePrevOutput and shareBNOutput


--------------------------------------------------------------------------------
/models/densenet.lua:
--------------------------------------------------------------------------------
  1 | require 'nn'
  2 | require 'cunn'
  3 | require 'cudnn'
  4 | require 'models/DenseConnectLayer'
  5 | 
  6 | local function createModel(opt)
  7 | 
  8 |    --growth rate
  9 |    local growthRate = opt.growthRate
 10 | 
 11 |    --dropout rate: 0 disables dropout, a non-zero value enables it and sets the drop probability
 12 |    local dropRate = opt.dropRate
 13 | 
 14 |    --# channels before entering the first Dense-Block
 15 |    local nChannels = 2 * growthRate
 16 | 
 17 |    --compression rate at transition layers
 18 |    local reduction = opt.reduction
 19 | 
 20 |    --whether to use bottleneck structures
 21 |    local bottleneck = opt.bottleneck
 22 | 
 23 |    --N: number of densely connected layers in each dense block
 24 |    local N = (opt.depth - 4)/3
 25 |    if bottleneck then N = N/2 end
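    |    -- Worked example: with -depth 100 and bottleneck, N = (100-4)/3 = 32, then
    |    -- N = 32/2 = 16 bottleneck layers per dense block (each holds two convolutions),
    |    -- which is the DenseNet-BC (L=100) configuration benchmarked in models/README.md.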
 26 | 
 27 | 
 28 |    local function addLayer(model, nChannels, opt)
 29 |       if opt.optMemory >= 3 then
 30 |          model:add(nn.DenseConnectLayerCustom(nChannels, opt))
 31 |       else
 32 |          model:add(DenseConnectLayerStandard(nChannels, opt))     
 33 |       end
 34 |    end
 35 | 
 36 | 
 37 |    local function addTransition(model, nChannels, nOutChannels, opt, last, pool_size)
 38 |       if opt.optMemory >= 3 then     
 39 |          model:add(nn.JoinTable(2))
 40 |       end
 41 | 
 42 |       model:add(cudnn.SpatialBatchNormalization(nChannels))
 43 |       model:add(cudnn.ReLU(true))      
 44 |       if last then
 45 |          model:add(cudnn.SpatialAveragePooling(pool_size, pool_size))
 46 |          model:add(nn.Reshape(nChannels))      
 47 |       else
 48 |          model:add(cudnn.SpatialConvolution(nChannels, nOutChannels, 1, 1, 1, 1, 0, 0))
 49 |          if opt.dropRate > 0 then model:add(nn.Dropout(opt.dropRate)) end
 50 |          model:add(cudnn.SpatialAveragePooling(2, 2))
 51 |       end      
 52 |    end
 53 | 
 54 | 
 55 |    local function addDenseBlock(model, nChannels, opt, N)
 56 |       for i = 1, N do 
 57 |          addLayer(model, nChannels, opt)
 58 |          nChannels = nChannels + opt.growthRate
 59 |       end
 60 |       return nChannels
 61 |    end
 62 | 
 63 | 
 64 |    -- Build DenseNet
 65 |    local model = nn.Sequential()
 66 | 
 67 |    if opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then
 68 | 
 69 |       --Initial convolution layer
 70 |       model:add(cudnn.SpatialConvolution(3, nChannels, 3,3, 1,1, 1,1))      
 71 | 
 72 |       --Dense-Block 1 and transition
 73 |       nChannels = addDenseBlock(model, nChannels, opt, N)
 74 |       addTransition(model, nChannels, math.floor(nChannels*reduction), opt)
 75 |       nChannels = math.floor(nChannels*reduction)
 76 | 
 77 |       --Dense-Block 2 and transition
 78 |       nChannels = addDenseBlock(model, nChannels, opt, N)
 79 |       addTransition(model, nChannels, math.floor(nChannels*reduction), opt)
 80 |       nChannels = math.floor(nChannels*reduction)
 81 | 
 82 |       --Dense-Block 3 and transition
 83 |       nChannels = addDenseBlock(model, nChannels, opt, N)
 84 |       addTransition(model, nChannels, nChannels, opt, true, 8)
 85 | 
 86 |    elseif opt.dataset == 'imagenet' then
 87 | 
 88 |       --number of layers in each block
 89 |       if opt.depth == 121 then
 90 |          stages = {6, 12, 24, 16}
 91 |       elseif opt.depth == 169 then
 92 |          stages = {6, 12, 32, 32}
 93 |       elseif opt.depth == 201 then
 94 |          stages = {6, 12, 48, 32}
 95 |       elseif opt.depth == 161 then
 96 |          stages = {6, 12, 36, 24}
 97 |       else
 98 |          stages = {opt.d1, opt.d2, opt.d3, opt.d4}
 99 |       end
100 | 
101 |       --Initial transforms follow ResNet(224x224)
102 |       model:add(cudnn.SpatialConvolution(3, nChannels, 7,7, 2,2, 3,3))
103 |       model:add(cudnn.SpatialBatchNormalization(nChannels))
104 |       model:add(cudnn.ReLU(true))
105 |       model:add(nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1))
106 | 
107 |       --Dense-Block 1 and transition (56x56)
108 |       nChannels = addDenseBlock(model, nChannels, opt, stages[1])
109 |       addTransition(model, nChannels, math.floor(nChannels*reduction), opt)
110 |       nChannels = math.floor(nChannels*reduction)
111 | 
112 |       --Dense-Block 2 and transition (28x28)
113 |       nChannels = addDenseBlock(model, nChannels, opt, stages[2])
114 |       addTransition(model, nChannels, math.floor(nChannels*reduction), opt)
115 |       nChannels = math.floor(nChannels*reduction)
116 | 
117 |       --Dense-Block 3 and transition (14x14)
118 |       nChannels = addDenseBlock(model, nChannels, opt, stages[3])
119 |       addTransition(model, nChannels, math.floor(nChannels*reduction), opt)
120 |       nChannels = math.floor(nChannels*reduction)
121 | 
122 |       --Dense-Block 4 and transition (7x7)
123 |       nChannels = addDenseBlock(model, nChannels, opt, stages[4])
124 |       addTransition(model, nChannels, nChannels, opt, true, 7)
125 | 
126 |    end
127 | 
128 | 
129 |    if opt.dataset == 'cifar10' then
130 |       model:add(nn.Linear(nChannels, 10))
131 |    elseif opt.dataset == 'cifar100' then
132 |       model:add(nn.Linear(nChannels, 100))
133 |    elseif opt.dataset == 'imagenet' then
134 |       model:add(nn.Linear(nChannels, 1000))
135 |    end
136 | 
137 |    --Initialization following ResNet
138 |    local function ConvInit(name)
139 |       for k,v in pairs(model:findModules(name)) do
140 |          local n = v.kW*v.kH*v.nOutputPlane
141 |          v.weight:normal(0,math.sqrt(2/n))
142 |          if cudnn.version >= 4000 then
143 |             v.bias = nil
144 |             v.gradBias = nil
145 |          else
146 |             v.bias:zero()
147 |          end
148 |       end
149 |    end
150 | 
151 |    local function BNInit(name)
152 |       for k,v in pairs(model:findModules(name)) do
153 |          v.weight:fill(1)
154 |          v.bias:zero()
155 |       end
156 |    end
157 | 
158 |    ConvInit('cudnn.SpatialConvolution')
159 |    BNInit('cudnn.SpatialBatchNormalization')
160 |    for k,v in pairs(model:findModules('nn.Linear')) do
161 |       v.bias:zero()
162 |    end
163 | 
164 |    model:type(opt.tensorType)
165 | 
166 |    if opt.cudnn == 'deterministic' then
167 |       model:apply(function(m)
168 |          if m.setMode then m:setMode(1,1,1) end
169 |       end)
170 |    end
171 | 
172 |    model:get(1).gradInput = nil
173 | 
174 |    print(model)
175 |    return model
176 | end
177 | 
178 | return createModel


--------------------------------------------------------------------------------
/models/init.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | --  Generic model creating code. For the specific DenseNet model see
 10 | --  models/densenet.lua
 11 | --
 12 | -- Code modified for DenseNet (https://arxiv.org/abs/1608.06993) by Gao Huang.
 13 | -- 
 14 | -- More details about the memory efficient implementation can be found in the 
 15 | -- technical report "Memory-Efficient Implementation of DenseNets" 
 16 | -- (https://arxiv.org/pdf/1707.06990.pdf)
 17 | 
 18 | require 'nn'
 19 | require 'cunn'
 20 | require 'cudnn'
 21 | require 'models/DenseConnectLayer'
 22 | 
 23 | local M = {}
 24 | 
 25 | function M.setup(opt, checkpoint)
 26 | 
 27 |    print('=> Creating model from file: models/' .. opt.netType .. '.lua')
 28 |    local model = require('models/' .. opt.netType)(opt)
 29 |    if checkpoint then
 30 |       local modelPath = paths.concat(opt.resume, checkpoint.modelFile)
 31 |       assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
 32 |       print('=> Resuming model from ' .. modelPath)
 33 |       local model0 = torch.load(modelPath):type(opt.tensorType)
 34 |       M.copyModel(model, model0)
 35 |    elseif opt.retrain ~= 'none' then
 36 |       assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
 37 |       print('Loading model from file: ' .. opt.retrain)
 38 |       local model0 = torch.load(opt.retrain):type(opt.tensorType)
 39 |       M.copyModel(model, model0)
 40 |    end
 41 | 
 42 |    -- First remove any DataParallelTable
 43 |    if torch.type(model) == 'nn.DataParallelTable' then
 44 |       model = model:get(1)
 45 |    end
 46 | 
 47 |    -- optnet is a general library for reducing memory usage in neural networks
 48 |    if opt.optnet or opt.optMemory == 1 then
 49 |       local optnet = require 'optnet'
 50 |       local imsize = opt.dataset == 'imagenet' and 224 or 32
 51 |       local sampleInput = torch.zeros(4,3,imsize,imsize):type(opt.tensorType)
 52 |       optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'})
 53 |    end
 54 | 
 55 |    -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all
 56 |    -- containers override backwards to call backwards recursively on submodules
 57 |    if opt.shareGradInput or opt.optMemory >= 2 then
 58 |       M.shareGradInput(model, opt)
 59 |    end
 60 | 
 61 |    -- Share the contiguous (concatenated) outputs of previous layers in DenseNet.
 62 |    -- This is needed by both -optMemory 3 and -optMemory 4.
 63 |    if opt.optMemory >= 3 then
 64 |       M.sharePrevOutput(model, opt)
 65 |    end
 66 | 
 67 |    -- Share the output of BatchNorm in bottleneck layers of DenseNet. This requires
 68 |    -- forwarding the BN layer twice at each mini-batch, but makes the memory consumption  
 69 |    -- linear (instead of quadratic) in depth
 70 |    if opt.optMemory == 4 then
 71 |       M.shareBNOutput(model, opt)
 72 |    end
 73 | 
 74 |    -- For resetting the classifier when fine-tuning on a different Dataset
 75 |    if opt.resetClassifier and not checkpoint then
 76 |       print(' => Replacing classifier with ' .. opt.nClasses .. '-way classifier')
 77 | 
 78 |       local orig = model:get(#model.modules)
 79 |       assert(torch.type(orig) == 'nn.Linear',
 80 |          'expected last layer to be fully connected')
 81 | 
 82 |       local linear = nn.Linear(orig.weight:size(2), opt.nClasses)
 83 |       linear.bias:zero()
 84 | 
 85 |       model:remove(#model.modules)
 86 |       model:add(linear:type(opt.tensorType))
 87 |    end
 88 | 
 89 |    -- Set the CUDNN flags
 90 |    if opt.cudnn == 'fastest' then
 91 |       cudnn.fastest = true
 92 |       cudnn.benchmark = true
 93 |    elseif opt.cudnn == 'deterministic' then
 94 |       -- Use a deterministic convolution implementation
 95 |       model:apply(function(m)
 96 |          if m.setMode then m:setMode(1, 1, 1) end
 97 |       end)
 98 |    end
 99 | 
100 |    -- Wrap the model with DataParallelTable, if using more than one GPU
101 |    if opt.nGPU > 1 then
102 |       local gpus = torch.range(1, opt.nGPU):totable()
103 |       local fastest, benchmark = cudnn.fastest, cudnn.benchmark
104 | 
105 |       local dpt = nn.DataParallelTable(1, true, true)
106 |          :add(model, gpus)
107 |          :threads(function()
108 |             local cudnn = require 'cudnn'
109 |             require 'models/DenseConnectLayer'
110 |             cudnn.fastest, cudnn.benchmark = fastest, benchmark
111 |          end)
112 |       dpt.gradInput = nil
113 | 
114 |       model = dpt:type(opt.tensorType)
115 |    end
116 | 
117 |    local criterion = nn.CrossEntropyCriterion():type(opt.tensorType)
118 |    return model, criterion
119 | end
120 | 
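    | -- Summary of the -optMemory levels handled in M.setup above (benchmarks are in
    | -- models/README.md):
    | --   0: none   1: optnet   2: shareGradInput
    | --   3: level 2 + sharePrevOutput   4: level 3 + shareBNOutput
    | 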
121 | function M.shareGradInput(model, opt)
122 |    local function sharingKey(m)
123 |       local key = torch.type(m)
124 |       if m.__shareGradInputKey then
125 |          key = key .. ':' .. m.__shareGradInputKey
126 |       end
127 |       return key
128 |    end
129 | 
130 |    -- Share gradInput for memory efficient backprop
131 |    local cache = {}
132 |    model:apply(function(m)
133 |       local moduleType = torch.type(m)
134 |       if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' and moduleType ~= 'nn.Concat' then
135 |          local key = sharingKey(m)
136 |          if cache[key] == nil then
137 |             cache[key] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) -- e.g. 'torch.CudaTensor' -> torch.CudaStorage(1)
138 |          end
139 |          m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[key], 1, 0) -- empty tensor viewing the shared storage; resized during backprop
140 |       end
141 |    end)
142 |    for i, m in ipairs(model:findModules('nn.ConcatTable')) do
143 |       if cache[i % 2] == nil then
144 |          cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
145 |       end
146 |       m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0)
147 |    end
148 |    for i, m in ipairs(model:findModules('nn.Concat')) do
149 |       if cache[i % 2] == nil then
150 |          cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
151 |       end
152 |       m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0)
153 |    end
154 | end
155 | 
156 | function M.sharePrevOutput(model, opt)
157 |    -- Share contiguous output for memory efficient densenet
158 |    local buffer = nil
159 |    model:apply(function(m)
160 |       local moduleType = torch.type(m)
161 |       if moduleType == 'nn.DenseConnectLayerCustom' then
162 |          if buffer == nil then
163 |             buffer = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
164 |          end
165 |          m.input_c = torch[opt.tensorType:match('torch.(%a+)')](buffer, 1, 0)
166 |       end
167 |    end)
168 | end
169 | 
170 | function M.shareBNOutput(model, opt)
171 |    -- Share BN.output for memory efficient densenet
172 |    local buffer = nil
173 |    model:apply(function(m)
174 |       local moduleType = torch.type(m)
175 |       if moduleType == 'nn.DenseConnectLayerCustom' then
176 |          if buffer == nil then
177 |             buffer = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
178 |          end
179 |          m.net1:get(1).output = torch[opt.tensorType:match('torch.(%a+)')](buffer, 1, 0)
180 |       end
181 |    end)
182 | end
183 | 
184 | function M.copyModel(t, s)
185 |    local wt, ws = t:parameters(), s:parameters()
186 |    assert(#wt==#ws, 'Model configuration does not match the resumed model!')
187 |    for l = 1, #wt do
188 |       wt[l]:copy(ws[l])
189 |    end
190 |    local bn_t, bn_s = {}, {}
191 |    for i, m in ipairs(s:findModules('cudnn.SpatialBatchNormalization')) do
192 |       bn_s[i] = m
193 |    end
194 |    for i, m in ipairs(t:findModules('cudnn.SpatialBatchNormalization')) do
195 |       bn_t[i] = m
196 |    end
197 |    assert(#bn_t==#bn_s, 'Model configuration does not match the resumed model!')
198 |    for i = 1, #bn_s do
199 |       bn_t[i].running_mean:copy(bn_s[i].running_mean)
200 |       bn_t[i].running_var:copy(bn_s[i].running_var) 
201 |    end
202 | end
203 | 
204 | return M
205 | 


--------------------------------------------------------------------------------
/opts.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | -- Code modified for DenseNet (https://arxiv.org/abs/1608.06993) by Gao Huang.
 10 | -- 
 11 | local M = { }
 12 | 
 13 | function M.parse(arg)
 14 |    local cmd = torch.CmdLine()
 15 |    cmd:text()
 16 |    cmd:text('Torch-7 DenseNet Training script')
 17 |    cmd:text('See https://github.com/facebook/fb.resnet.torch/blob/master/TRAINING.md for examples')
 18 |    cmd:text()
 19 |    cmd:text('Options:')
 20 |     ------------ General options --------------------
 21 |    cmd:option('-data',       '',         'Path to dataset')
 22 |    cmd:option('-dataset',    'cifar10', 'Options: imagenet | cifar10 | cifar100')
 23 |    cmd:option('-manualSeed', 0,          'Manually set RNG seed')
 24 |    cmd:option('-nGPU',       1,          'Number of GPUs to use by default')
 25 |    cmd:option('-backend',    'cudnn',    'Options: cudnn | cunn')
 26 |    cmd:option('-cudnn',      'fastest',  'Options: fastest | default | deterministic')
 27 |    cmd:option('-gen',        'gen',      'Path to save generated files')
 28 |    cmd:option('-precision', 'single',    'Options: single | double | half')
 29 |    ------------- Data options ------------------------
 30 |    cmd:option('-nThreads',        2, 'number of data loading threads')
 31 |    ------------- Training options --------------------
 32 |    cmd:option('-nEpochs',         0,       'Number of total epochs to run')
 33 |    cmd:option('-epochNumber',     1,       'Manual epoch number (useful on restarts)')
 34 |    cmd:option('-batchSize',       32,      'mini-batch size (1 = pure stochastic)')
 35 |    cmd:option('-testOnly',        'false', 'Run on validation set only')
 36 |    cmd:option('-tenCrop',         'false', 'Ten-crop testing')
 37 |    ------------- Checkpointing options ---------------
 38 |    cmd:option('-save',            'checkpoints', 'Directory in which to save checkpoints')
 39 |    cmd:option('-resume',          'none',        'Resume from the latest checkpoint in this directory')
 40 |    ---------- Optimization options ----------------------
 41 |    cmd:option('-LR',              0.1,   'initial learning rate')
 42 |    cmd:option('-momentum',        0.9,   'momentum')
 43 |    cmd:option('-weightDecay',     1e-4,  'weight decay')
 44 |    cmd:option('-lrShape',         'multistep',    'Learning rate: multistep|cosine')
 45 |    ---------- Model options ----------------------------------
 46 |    cmd:option('-netType',      'densenet', 'Options: densenet')
 47 |    cmd:option('-depth',        100,      'Network depth: e.g. 100 for DenseNet-BC on CIFAR; 121 | 169 | 201 | 161 for ImageNet', 'number')
 48 |    cmd:option('-shortcutType', '',       'Options: A | B | C')
 49 |    cmd:option('-retrain',      'none',   'Path to model to retrain with')
 50 |    cmd:option('-optimState',   'none',   'Path to an optimState to reload from')
 51 |    ---------- Model options ----------------------------------
 52 |    cmd:option('-shareGradInput',  'false', 'Share gradInput tensors to reduce memory usage')
 53 |    cmd:option('-optnet',          'false', 'Use optnet to reduce memory usage')
 54 |    cmd:option('-resetClassifier', 'false', 'Reset the fully connected layer for fine-tuning')
 55 |    cmd:option('-nClasses',         0,      'Number of classes in the dataset')
 56 |    ---------- Model options for DenseNet ----------------------------------
 57 |    cmd:option('-growthRate',        12,    'Number of output channels at each convolutional layer')
 58 |    cmd:option('-bottleneck',       'true', 'Use 1x1 convolution to reduce dimension (DenseNet-B)')
 59 |    cmd:option('-reduction',         0.5,   'Channel compress ratio at transition layer (DenseNet-C)')
 60 |    cmd:option('-dropRate',           0,    'Dropout probability')
 61 |    cmd:option('-optMemory',          2,    'Optimize memory for DenseNet: 0 | 1 | 2 | 3 | 4', 'number')
 62 |    -- The following hyperparameters are activated when depth is not from {121, 161, 169, 201} (for ImageNet only)
 63 |    cmd:option('-d1',                 0,    'Number of layers in block 1')
 64 |    cmd:option('-d2',                 0,    'Number of layers in block 2')
 65 |    cmd:option('-d3',                 0,    'Number of layers in block 3')
 66 |    cmd:option('-d4',                 0,    'Number of layers in block 4')
 67 | 
 68 |    cmd:text()
 69 | 
 70 |    local opt = cmd:parse(arg or {})
 71 | 
 72 |    opt.testOnly = opt.testOnly ~= 'false'   -- CmdLine parses flags as strings; here and below, anything but 'false' counts as true
 73 |    opt.tenCrop = opt.tenCrop ~= 'false'
 74 |    opt.shareGradInput = opt.shareGradInput ~= 'false'
 75 |    opt.optnet = opt.optnet ~= 'false'
 76 |    opt.resetClassifier = opt.resetClassifier ~= 'false'
 77 |    opt.bottleneck = opt.bottleneck ~= 'false'
 78 | 
 79 |    if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then
 80 |       cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n')
 81 |    end
 82 | 
 83 |    if opt.dataset == 'imagenet' then
 84 |       -- Handle the most common case of missing -data flag
 85 |       local trainDir = paths.concat(opt.data, 'train')
 86 |       if not paths.dirp(opt.data) then
 87 |          cmd:error('error: missing ImageNet data directory')
 88 |       elseif not paths.dirp(trainDir) then
 89 |          cmd:error('error: ImageNet missing `train` directory: ' .. trainDir)
 90 |       end
 91 |       -- Default shortcutType=B and nEpochs=90
 92 |       opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType
 93 |       opt.nEpochs = opt.nEpochs == 0 and 90 or opt.nEpochs
 94 |    elseif opt.dataset == 'cifar10' then
 95 |       -- Default shortcutType=A and nEpochs=164
 96 |       opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType
 97 |       opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs
 98 |    elseif opt.dataset == 'cifar100' then
 99 |       -- Default shortcutType=A and nEpochs=164
100 |       opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType
101 |       opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs
102 |    else
103 |       cmd:error('unknown dataset: ' .. opt.dataset)
104 |    end
105 | 
106 |    if opt.precision == nil or opt.precision == 'single' then
107 |       opt.tensorType = 'torch.CudaTensor'
108 |    elseif opt.precision == 'double' then
109 |       opt.tensorType = 'torch.CudaDoubleTensor'
110 |    elseif opt.precision == 'half' then
111 |       opt.tensorType = 'torch.CudaHalfTensor'
112 |    else
113 |       cmd:error('unknown precision: ' .. opt.precision)
114 |    end
115 | 
116 |    if opt.resetClassifier then
117 |       if opt.nClasses == 0 then
118 |          cmd:error('-nClasses required when resetClassifier is set')
119 |       end
120 |    end
121 |    if opt.shareGradInput and opt.optnet then
122 |       cmd:error('error: cannot use both -shareGradInput and -optnet')
123 |    end
124 | 
125 |    return opt
126 | end
127 | 
128 | return M
129 | 


--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
  1 | --
  2 | --  Copyright (c) 2016, Facebook, Inc.
  3 | --  All rights reserved.
  4 | --
  5 | --  This source code is licensed under the BSD-style license found in the
  6 | --  LICENSE file in the root directory of this source tree. An additional grant
  7 | --  of patent rights can be found in the PATENTS file in the same directory.
  8 | --
  9 | --  The training loop and learning rate schedule
 10 | --
 11 | -- Code modified for DenseNet (https://arxiv.org/abs/1608.06993) by Gao Huang.
 12 | -- 
 13 | 
 14 | local optim = require 'optim'
 15 | 
 16 | local M = {}
 17 | local Trainer = torch.class('resnet.Trainer', M)
 18 | 
 19 | function Trainer:__init(model, criterion, opt, optimState)
 20 |    self.model = model
 21 |    self.criterion = criterion
 22 |    self.optimState = optimState or {
 23 |       learningRate = opt.LR,
 24 |       learningRateDecay = 0.0,
 25 |       momentum = opt.momentum,
 26 |       nesterov = true,
 27 |       dampening = 0.0,
 28 |       weightDecay = opt.weightDecay,
 29 |    }
 30 |    self.opt = opt
 31 |    self.params, self.gradParams = model:getParameters()
 32 | end
 33 | 
 34 | function Trainer:train(epoch, dataloader)
 35 |    -- Trains the model for a single epoch
 36 | 
 37 |    ------for LR------
 38 |    if self.opt.lrShape == 'multistep' then
 39 |       self.optimState.learningRate = self:learningRate(epoch)
 40 |    end
 41 |    ------for LR------
 42 | 
 43 |    local timer = torch.Timer()
 44 |    local dataTimer = torch.Timer()
 45 | 
 46 |    local function feval()
 47 |       return self.criterion.output, self.gradParams
 48 |    end
 49 | 
 50 |    local trainSize = dataloader:size()
 51 |    local top1Sum, top5Sum, lossSum = 0.0, 0.0, 0.0
 52 |    local N = 0
 53 | 
 54 |    print('=> Training epoch # ' .. epoch)
 55 |    -- set the batch norm to training mode
 56 |    self.model:training()
 57 |    for n, sample in dataloader:run() do
 58 | 
 59 |       ------for LR------
 60 |       if self.opt.lrShape == 'cosine' then
 61 |          self.optimState.learningRate = self:learningRateCosine(epoch, n, trainSize)
 62 |       end
 63 |       ------for LR------
 64 | 
 65 |       local dataTime = dataTimer:time().real
 66 | 
 67 |       -- Copy input and target to the GPU
 68 |       self:copyInputs(sample)
 69 | 
 70 |       local output = self.model:forward(self.input):float()
 71 |       local batchSize = output:size(1)
 72 |       local loss = self.criterion:forward(self.model.output, self.target)
 73 | 
 74 |       self.model:zeroGradParameters()
 75 |       self.criterion:backward(self.model.output, self.target)
 76 |       self.model:backward(self.input, self.criterion.gradInput)
 77 | 
 78 |       optim.sgd(feval, self.params, self.optimState)
 79 | 
 80 |       local top1, top5 = self:computeScore(output, sample.target, 1)
 81 |       top1Sum = top1Sum + top1*batchSize
 82 |       top5Sum = top5Sum + top5*batchSize
 83 |       lossSum = lossSum + loss*batchSize
 84 |       N = N + batchSize
 85 | 
 86 |       print((' | Epoch: [%d][%d/%d]   Time %.3f  Data %.3f  Err %1.3f  top1 %7.2f  top5 %7.2f  lr %.4f'):format(
 87 |          epoch, n, trainSize, timer:time().real, dataTime, loss, top1, top5, self.optimState.learningRate))
 88 | 
 89 |       -- check that the storage didn't get changed due to an unfortunate getParameters call
 90 |       assert(self.params:storage() == self.model:parameters()[1]:storage())
 91 | 
 92 |       timer:reset()
 93 |       dataTimer:reset()
 94 |    end
 95 | 
 96 |    return top1Sum / N, top5Sum / N, lossSum / N
 97 | end
 98 | 
 99 | function Trainer:test(epoch, dataloader)
100 |    -- Computes the top-1 and top-5 err on the validation set
101 | 
102 |    local timer = torch.Timer()
103 |    local dataTimer = torch.Timer()
104 |    local size = dataloader:size()
105 | 
106 |    local nCrops = self.opt.tenCrop and 10 or 1
107 |    local top1Sum, top5Sum = 0.0, 0.0
108 |    local N = 0
109 | 
110 |    self.model:evaluate()
111 |    for n, sample in dataloader:run() do
112 |       local dataTime = dataTimer:time().real
113 | 
114 |       -- Copy input and target to the GPU
115 |       self:copyInputs(sample)
116 | 
117 |       local output = self.model:forward(self.input):float()
118 |       local batchSize = output:size(1) / nCrops
119 |       local loss = self.criterion:forward(self.model.output, self.target)
120 | 
121 |       local top1, top5 = self:computeScore(output, sample.target, nCrops)
122 |       top1Sum = top1Sum + top1*batchSize
123 |       top5Sum = top5Sum + top5*batchSize
124 |       N = N + batchSize
125 | 
126 |       print((' | Test: [%d][%d/%d]    Time %.3f  Data %.3f  top1 %7.3f (%7.3f)  top5 %7.3f (%7.3f)'):format(
127 |          epoch, n, size, timer:time().real, dataTime, top1, top1Sum / N, top5, top5Sum / N))
128 | 
129 |       timer:reset()
130 |       dataTimer:reset()
131 |    end
132 |    self.model:training()
133 | 
134 |    print((' * Finished epoch # %d     top1: %7.3f  top5: %7.3f\n'):format(
135 |       epoch, top1Sum / N, top5Sum / N))
136 | 
137 |    return top1Sum / N, top5Sum / N
138 | end
139 | 
140 | function Trainer:computeScore(output, target, nCrops)
141 |    if nCrops > 1 then
142 |       -- Sum over crops
143 |       output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
144 |          --:exp()
145 |          :sum(2):squeeze(2)
146 |    end
147 | 
148 |    -- Computes the top1 and top5 error rate
149 |    local batchSize = output:size(1)
150 | 
151 |    local _ , predictions = output:float():topk(5, 2, true, true) -- descending
152 | 
153 |    -- Find which predictions match the target
154 |    local correct = predictions:eq(
155 |       target:long():view(batchSize, 1):expandAs(predictions))
156 | 
157 |    -- Top-1 score
158 |    local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize)
159 | 
160 |    -- Top-5 score, if there are at least 5 classes
161 |    local len = math.min(5, correct:size(2))
162 |    local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize)
163 | 
164 |    return top1 * 100, top5 * 100
165 | end
166 | 
167 | local function getCudaTensorType(tensorType)
168 |    if tensorType == 'torch.CudaHalfTensor' then
169 |       return cutorch.createCudaHostHalfTensor()
170 |    elseif tensorType == 'torch.CudaDoubleTensor' then
171 |       return cutorch.createCudaHostDoubleTensor()
172 |    else
173 |       return cutorch.createCudaHostTensor()
174 |    end
175 | end
176 | 
177 | function Trainer:copyInputs(sample)
178 |    -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory,
179 |    -- if using DataParallelTable. The target is always copied to a CUDA tensor
180 |    self.input = self.input or (self.opt.nGPU == 1
181 |       and torch[self.opt.tensorType:match('torch.(%a+)')]()
182 |       or getCudaTensorType(self.opt.tensorType))
183 |    self.target = self.target or (torch.CudaLongTensor and torch.CudaLongTensor())
184 |    self.input:resize(sample.input:size()):copy(sample.input)
185 |    self.target:resize(sample.target:size()):copy(sample.target)
186 | end
187 | 
188 | function Trainer:learningRate(epoch)
189 |    -- Multistep training schedule: divide LR by 10 every 30 epochs on ImageNet, and at 50% and 75% of training on CIFAR
190 |    local decay = 0
191 |    if self.opt.dataset == 'imagenet' then
192 |       decay = math.floor((epoch - 1) / 30)
193 |    elseif self.opt.dataset == 'cifar10' then
194 |       decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0
195 |    elseif self.opt.dataset == 'cifar100' then
196 |       decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0
197 |    end
198 |    return self.opt.LR * math.pow(0.1, decay)
199 | end
200 | 
201 | ------for LR------
202 | function Trainer:learningRateCosine(epoch, iter, nBatches)
203 |    local nEpochs = self.opt.nEpochs
204 |    local T_total = nEpochs * nBatches
205 |    local T_cur = ((epoch-1) % nEpochs) * nBatches + iter
206 |    return 0.5 * self.opt.LR * (1 + torch.cos(math.pi * T_cur / T_total))
207 | end
208 | ------for LR------
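    | 
    | -- Worked example for the cosine schedule (illustrative numbers): with -LR 0.1,
    | -- T_cur/T_total = 0.5 at the midpoint of training, giving a learning rate of
    | -- 0.5 * 0.1 * (1 + cos(pi/2)) = 0.05; it starts near 0.1 and anneals to ~0 by
    | -- the final iteration.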
209 | 
210 | return M.Trainer
211 | 


--------------------------------------------------------------------------------