├── LICENSE
├── README.md
├── checkpoints.lua
├── dataloader.lua
├── datasets
│   ├── README.md
│   ├── cifar10-gen.lua
│   ├── cifar10.lua
│   ├── cifar100-gen.lua
│   ├── cifar100.lua
│   ├── imagenet-gen.lua
│   ├── imagenet.lua
│   ├── init.lua
│   └── transforms.lua
├── main.lua
├── models
│   ├── DenseConnectLayer.lua
│   ├── README.md
│   ├── densenet.lua
│   └── init.lua
├── opts.lua
└── train.lua
/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Zhuang Liu. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name DenseNet nor the names of its contributors may be used to 15 | endorse or promote products derived from this software without specific 16 | prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Densely Connected Convolutional Networks (DenseNets) 2 | 3 | This repository contains the code for DenseNet, introduced in the following paper: 4 | 5 | [Densely Connected Convolutional Networks](http://arxiv.org/abs/1608.06993) (CVPR 2017, Best Paper Award) 6 | 7 | [Gao Huang](http://www.cs.cornell.edu/~gaohuang/)\*, [Zhuang Liu](https://liuzhuang13.github.io/)\*, [Laurens van der Maaten](https://lvdmaaten.github.io/) and [Kilian Weinberger](https://www.cs.cornell.edu/~kilian/) (\* Authors contributed equally). 8 | 9 | 10 | **Now with a much more memory-efficient implementation!** Please check the [technical report](https://arxiv.org/pdf/1707.06990.pdf) and [code](https://github.com/liuzhuang13/DenseNet/tree/master/models) for more information. 11 | 12 | The code is built on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch).
13 | 14 | ### Citation 15 | If you find DenseNet useful in your research, please consider citing: 16 | 17 | @inproceedings{DenseNet2017, 18 | title={Densely connected convolutional networks}, 19 | author={Huang, Gao and Liu, Zhuang and van der Maaten, Laurens and Weinberger, Kilian Q}, 20 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 21 | year={2017} 22 | } 23 | 24 | 25 | ## Other Implementations 26 | Our [[Caffe]](https://github.com/liuzhuang13/DenseNetCaffe), Our memory-efficient [[Caffe]](https://github.com/Tongcheng/DN_CaffeScript), Our memory-efficient [[PyTorch]](https://github.com/gpleiss/efficient_densenet_pytorch), 27 | [[PyTorch]](https://github.com/andreasveit/densenet-pytorch) by Andreas Veit, [[PyTorch]](https://github.com/bamos/densenet.pytorch) by Brandon Amos, [[PyTorch]](https://github.com/baldassarreFe/pytorch-densenet-tiramisu) by Federico Baldassarre, 28 | [[MXNet]](https://github.com/Nicatio/Densenet/tree/master/mxnet) by Nicatio, 29 | [[MXNet]](https://github.com/bruinxiong/densenet.mxnet) by Xiong Lin, 30 | [[MXNet]](https://github.com/miraclewkf/DenseNet) by miraclewkf, 31 | [[Tensorflow]](https://github.com/YixuanLi/densenet-tensorflow) by Yixuan Li, 32 | [[Tensorflow]](https://github.com/LaurentMazare/deep-models/tree/master/densenet) by Laurent Mazare, 33 | [[Tensorflow]](https://github.com/ikhlestov/vision_networks) by Illarion Khlestov, 34 | [[Lasagne]](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) by Jan Schlüter, 35 | [[Keras]](https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/DenseNet) by tdeboissiere, 36 | [[Keras]](https://github.com/robertomest/convnet-study) by Roberto de Moura Estevão Filho, 37 | [[Keras]](https://github.com/titu1994/DenseNet) by Somshubra Majumdar, 38 | [[Chainer]](https://github.com/t-hanya/chainer-DenseNet) by Toshinori Hanya, 39 | [[Chainer]](https://github.com/yasunorikudo/chainer-DenseNet) by Yasunori Kudo, 40 | [[Torch 3D-DenseNet]](https://github.com/barrykui/3ddensenet.torch) by Barry Kui, 41 | [[Keras]](https://github.com/cmasch/densenet) by Christopher Masch, 42 | [[Tensorflow2]](https://github.com/okason97/DenseNet-Tensorflow2) by Gaston Rios and Ulises Jeremias Cornejo Fandos. 43 | 44 | 45 | 46 | Note that we only listed some early implementations here. If you would like to add yours, please submit a pull request. 47 | 48 | ## Some Follow-up Projects 49 | 0. [Multi-Scale Dense Convolutional Networks for Efficient Prediction](https://github.com/gaohuang/MSDNet) 50 | 0. [DSOD: Learning Deeply Supervised Object Detectors from Scratch](https://github.com/szq0214/DSOD) 51 | 0. [CondenseNet: An Efficient DenseNet using Learned Group Convolutions](https://github.com/ShichenLiu/CondenseNet) 52 | 0. [Fully Convolutional DenseNets for Semantic Segmentation](https://github.com/SimJeg/FC-DenseNet) 53 | 0. [Pelee: A Real-Time Object Detection System on Mobile Devices](https://github.com/Robert-JunWang/Pelee) 54 | 55 | 56 | 57 | 58 | ## Contents 59 | 1. [Introduction](#introduction) 60 | 2. [Usage](#usage) 61 | 3. [Results on CIFAR](#results-on-cifar) 62 | 4. [Results on ImageNet and Pretrained Models](#results-on-imagenet-and-pretrained-models) 63 | 5. [Updates](#updates) 64 | 65 | 66 | ## Introduction 67 | DenseNet is a network architecture in which each layer is directly connected to every other layer in a feed-forward fashion (within each *dense block*). For each layer, the feature maps of all preceding layers are treated as separate inputs, while its own feature maps are passed on as inputs to all subsequent layers. This connectivity pattern yields state-of-the-art accuracies on CIFAR-10/100 (with or without data augmentation) and SVHN. On the large-scale ILSVRC 2012 (ImageNet) dataset, DenseNet achieves accuracy similar to ResNet, while using fewer than half the parameters and roughly half the FLOPs.
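This connectivity pattern maps directly onto the Torch code: each densely connected layer is an `nn.Concat` of the identity and a small BN-ReLU-Conv stack (see `DenseConnectLayerStandard` in `models/DenseConnectLayer.lua`). Below is a minimal sketch of a dense block, using plain `nn` modules in place of the `cudnn` ones the repository uses; `denseLayer` is an illustrative name, not part of the code base.

```lua
require 'nn'

-- One densely connected layer: concatenate the input with the newly
-- computed feature maps along the channel dimension (dim 2 of NCHW).
local function denseLayer(nChannels, growthRate)
   local net = nn.Sequential()
      :add(nn.SpatialBatchNormalization(nChannels))
      :add(nn.ReLU(true))
      :add(nn.SpatialConvolution(nChannels, growthRate, 3, 3, 1, 1, 1, 1))
   return nn.Concat(2):add(nn.Identity()):add(net)
end

-- A dense block with 5 layers and growth rate k = 12: every layer
-- receives the feature maps of all preceding layers.
local block, nChannels = nn.Sequential(), 24
for i = 1, 5 do
   block:add(denseLayer(nChannels, 12))
   nChannels = nChannels + 12 -- channels grow by k per layer
end
```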
68 | 69 | 70 | 71 | Figure 1: A dense block with 5 layers and growth rate 4. 72 | 73 | 74 | ![densenet](https://cloud.githubusercontent.com/assets/8370623/17981496/fa648b32-6ad1-11e6-9625-02fdd72fdcd3.jpg) 75 | Figure 2: A deep DenseNet with three dense blocks. 76 | 77 | 78 | ## Usage 79 | 0. Install Torch and required dependencies like cuDNN. See the instructions [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md) for a step-by-step guide. 80 | 1. Clone this repo: ```git clone https://github.com/liuzhuang13/DenseNet.git``` 81 | 82 | As an example, the following command trains a DenseNet-BC with depth L=100 and growth rate k=12 on CIFAR-10: 83 | ``` 84 | th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 100 -growthRate 12 85 | ``` 86 | As another example, the following command trains a DenseNet-BC with depth L=121 and growth rate k=32 on ImageNet: 87 | ``` 88 | th main.lua -netType densenet -dataset imagenet -data [dataFolder] -batchSize 256 -nEpochs 90 -depth 121 -growthRate 32 -nGPU 4 -nThreads 16 -optMemory 3 89 | ``` 90 | Please refer to [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) for data preparation. 91 | 92 | ### DenseNet and DenseNet-BC 93 | By default, the code runs with the DenseNet-BC architecture, which has 1x1 convolutional *bottleneck* layers and *compresses* the number of channels at each transition layer by 0.5. To run the original DenseNet, simply use the options *-bottleneck false* and *-reduction 1*. 94 | 95 | ### Memory-efficient implementation (newly added feature on June 6, 2017) 96 | The option *-optMemory* is very useful for reducing the GPU memory footprint when training a DenseNet. By default, the value is set to 2, which activates the *shareGradInput* function (with small modifications from [here](https://github.com/facebook/fb.resnet.torch/blob/master/models/init.lua#L102)). There are two extremely memory-efficient modes (*-optMemory 3* or *-optMemory 4*) that use a customized densely connected layer. With *-optMemory 4*, the largest 190-layer DenseNet-BC on CIFAR can be trained on a single NVIDIA TITAN X GPU (using 8.3G of 12G) instead of fully occupying four GPUs with the standard (recursive concatenation) implementation. 97 | 98 | More details about the memory-efficient implementation are discussed [here](https://github.com/liuzhuang13/DenseNet/tree/master/models). 99 | 100 | 101 | ## Results on CIFAR 102 | The table below shows the results of DenseNets on the CIFAR datasets. The "+" mark at the end denotes standard data augmentation (random crop after zero-padding, and horizontal flip). For a DenseNet model, L denotes its depth and k denotes its growth rate. On CIFAR-10 and CIFAR-100 without data augmentation, a Dropout layer with drop rate 0.2 is introduced after each convolutional layer except the very first one.
103 | 104 | Model | Parameters| CIFAR-10 | CIFAR-10+ | CIFAR-100 | CIFAR-100+ 105 | -------|:-------:|:--------:|:--------:|:--------:|:--------:| 106 | DenseNet (L=40, k=12) |1.0M |7.00 |5.24 | 27.55|24.42 107 | DenseNet (L=100, k=12)|7.0M |5.77 |4.10 | 23.79|20.20 108 | DenseNet (L=100, k=24)|27.2M |5.83 |3.74 | 23.42|19.25 109 | DenseNet-BC (L=100, k=12)|0.8M |5.92 |4.51 | 24.15|22.27 110 | DenseNet-BC (L=250, k=24)|15.3M |**5.19** |3.62 | **19.64**|17.60 111 | DenseNet-BC (L=190, k=40)|25.6M |- |**3.46** | -|**17.18** 112 | 113 | 114 | ## Results on ImageNet and Pretrained Models 115 | ### Torch 116 | 117 | **Note: the pre-trained models in Torch are deprecated and no longer maintained. Please use PyTorch's pre-trained [DenseNet models](https://pytorch.org/vision/stable/models.html) instead.** 118 | 119 | #### Models in the original paper 120 | The Torch models are trained under the same setting as in [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch). The error rates shown are 224x224 1-crop test errors. 121 | 122 | | Network | Top-1 error | Torch Model | 123 | | ------------- | ----------- | ----------- | 124 | | DenseNet-121 (k=32) | 25.0 | [Download (64.5MB)] | 125 | | DenseNet-169 (k=32) | 23.6 | [Download (114.4MB)] | 126 | | DenseNet-201 (k=32) | 22.5 | [Download (161.8MB)] | 127 | | DenseNet-161 (k=48) | 22.2 | [Download (230.8MB)] 128 | 129 | #### Models in the tech report 130 | More accurate models trained with the memory-efficient implementation described in the [technical report](https://arxiv.org/pdf/1707.06990.pdf). 131 | 132 | 133 | | Network | Top-1 error | Torch Model | 134 | | ------------- | ----------- | ------------ | 135 | | DenseNet-264 (k=32) | 22.1 | [Download (256MB)] 136 | | DenseNet-232 (k=48) | 21.2 | [Download (426MB)] 137 | | DenseNet-cosine-264 (k=32) | 21.6 | [Download (256MB)] 138 | | DenseNet-cosine-264 (k=48) | 20.4 | [Download (557MB)] 139 | 140 | 141 | ### Caffe 142 | https://github.com/shicai/DenseNet-Caffe. 143 | 144 | ### PyTorch 145 | [PyTorch documentation on models](http://pytorch.org/docs/torchvision/models.html?highlight=densenet). We would like to thank @gpleiss for this nice work in PyTorch. 146 | 147 | ### Keras, Tensorflow and Theano 148 | https://github.com/flyyufelix/DenseNet-Keras. 149 | 150 | ### MXNet 151 | https://github.com/miraclewkf/DenseNet. 152 | 153 | 154 | ## Wide-DenseNet for a Better Time/Accuracy and Memory/Accuracy Tradeoff 155 | 156 | If you use DenseNet as a model in your learning task, we recommend using a wide and shallow DenseNet to reduce memory and time consumption, following the strategy of [wide residual networks](https://github.com/szagoruyko/wide-residual-networks). To obtain a wide DenseNet, we set the depth to be smaller (e.g., L=40) and the growthRate to be larger (e.g., k=48). 157 | 158 | We tested a set of Wide-DenseNet-BCs and compared their memory and time consumption with the DenseNet-BC (L=100, k=12) shown above. The statistics were obtained on a single TITAN X card, with batch size 64, and without any memory optimization. 159 | 160 | 161 | Model | Parameters| CIFAR-10+ | CIFAR-100+ | Time per Iteration | Memory 162 | -------|:-------:|:--------:|:--------:|:--------:|:--------:| 163 | DenseNet-BC (L=100, k=12)|0.8M |4.51 |22.27 | 0.156s | 5452MB 164 | Wide-DenseNet-BC (L=40, k=36)|1.5M |4.58 |22.30 | 0.130s|4008MB 165 | Wide-DenseNet-BC (L=40, k=48)|2.7M |3.99 |20.29 | 0.165s|5245MB 166 | Wide-DenseNet-BC (L=40, k=60)|4.3M |4.01 |19.99 | 0.223s|6508MB 167 |
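For reference, the Wide-DenseNet-BC (L=40, k=48) row above corresponds to a training command analogous to the CIFAR-10 example in the Usage section (same flags, with only the depth and growth rate changed):
```
th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 40 -growthRate 48
```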
168 | Observations: 169 | 170 | 1. Wide-DenseNet-BC (L=40, k=36) uses less memory/time while achieving about the same accuracy as DenseNet-BC (L=100, k=12). 171 | 2. Wide-DenseNet-BC (L=40, k=48) uses about the same memory/time as DenseNet-BC (L=100, k=12), while being much more accurate. 172 | 173 | Thus, for practical use, we suggest picking one model from those Wide-DenseNet-BCs. 174 | 175 | 176 | 177 | ## Updates 178 | **08/23/2017:** 179 | 180 | 1. Add supporting code, so one can simply *git clone* and run. 181 | 182 | **06/06/2017:** 183 | 184 | 1. Support **ultra memory efficient** training of DenseNet with *customized densely connected layer*. 185 | 186 | 2. Support **memory efficient** training of DenseNet with *standard densely connected layer* (recursive concatenation) by fixing the *shareGradInput* function. 187 | 188 | **05/17/2017:** 189 | 190 | 1. Add Wide-DenseNet. 191 | 2. Add Keras, TensorFlow and Theano links for pretrained models. 192 | 193 | **04/20/2017:** 194 | 195 | 1. Add usage of models in PyTorch. 196 | 197 | **03/29/2017:** 198 | 199 | 1. Add the code for ImageNet training. 200 | 201 | **12/03/2016:** 202 | 203 | 1. Add ImageNet results and pretrained models. 204 | 2. Add DenseNet-BC structures. 205 | 206 | 207 | 208 | ## Contact 209 | liuzhuangthu at gmail.com 210 | Any discussions, suggestions and questions are welcome! 211 | 212 | -------------------------------------------------------------------------------- /checkpoints.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | local checkpoint = {} 10 | 11 | local function deepCopy(tbl) 12 | -- creates a copy of a network with new modules and the same tensors 13 | local copy = {} 14 | for k, v in pairs(tbl) do 15 | if type(v) == 'table' then 16 | copy[k] = deepCopy(v) 17 | else 18 | copy[k] = v 19 | end 20 | end 21 | if torch.typename(tbl) then 22 | torch.setmetatable(copy, torch.typename(tbl)) 23 | end 24 | return copy 25 | end 26 | 27 | function checkpoint.latest(opt) 28 | if opt.resume == 'none' then 29 | return nil 30 | end 31 | 32 | local latestPath = paths.concat(opt.resume, 'latest.t7') 33 | if not paths.filep(latestPath) then 34 | return nil 35 | end 36 | 37 | print('=> Loading checkpoint ' .. latestPath) 38 | local latest = torch.load(latestPath) 39 | local optimState = torch.load(paths.concat(opt.resume, latest.optimFile)) 40 | 41 | return latest, optimState 42 | end 43 | 44 | function checkpoint.save(epoch, model, optimState, isBestModel, opt) 45 | -- don't save the DataParallelTable for easier loading on other machines 46 | if torch.type(model) == 'nn.DataParallelTable' then 47 | model = model:get(1) 48 | end 49 | 50 | -- create a clean copy on the CPU without modifying the original network 51 | model = deepCopy(model):float():clearState() 52 | 53 | local modelFile = 'model_' .. epoch .. '.t7' 54 | local optimFile = 'optimState_' .. epoch ..
'.t7' 55 | 56 | torch.save(paths.concat(opt.save, modelFile), model) 57 | torch.save(paths.concat(opt.save, optimFile), optimState) 58 | torch.save(paths.concat(opt.save, 'latest.t7'), { 59 | epoch = epoch, 60 | modelFile = modelFile, 61 | optimFile = optimFile, 62 | }) 63 | 64 | if isBestModel then 65 | torch.save(paths.concat(opt.save, 'model_best.t7'), model) 66 | end 67 | end 68 | 69 | return checkpoint 70 | -------------------------------------------------------------------------------- /dataloader.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Multi-threaded data loader 10 | -- 11 | 12 | local datasets = require 'datasets/init' 13 | local Threads = require 'threads' 14 | Threads.serialization('threads.sharedserialize') 15 | 16 | local M = {} 17 | local DataLoader = torch.class('resnet.DataLoader', M) 18 | 19 | function DataLoader.create(opt) 20 | -- The train and val loader 21 | local loaders = {} 22 | 23 | for i, split in ipairs{'train', 'val'} do 24 | local dataset = datasets.create(opt, split) 25 | loaders[i] = M.DataLoader(dataset, opt, split) 26 | end 27 | 28 | return table.unpack(loaders) 29 | end 30 | 31 | function DataLoader:__init(dataset, opt, split) 32 | local manualSeed = opt.manualSeed 33 | local function init() 34 | require('datasets/' .. opt.dataset) 35 | end 36 | local function main(idx) 37 | if manualSeed ~= 0 then 38 | torch.manualSeed(manualSeed + idx) 39 | end 40 | torch.setnumthreads(1) 41 | _G.dataset = dataset 42 | _G.preprocess = dataset:preprocess() 43 | return dataset:size() 44 | end 45 | 46 | local threads, sizes = Threads(opt.nThreads, init, main) 47 | self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1 48 | self.threads = threads 49 | self.__size = sizes[1][1] 50 | self.batchSize = math.floor(opt.batchSize / self.nCrops) 51 | local function getCPUType(tensorType) 52 | if tensorType == 'torch.CudaHalfTensor' then 53 | return 'HalfTensor' 54 | elseif tensorType == 'torch.CudaDoubleTensor' then 55 | return 'DoubleTensor' 56 | else 57 | return 'FloatTensor' 58 | end 59 | end 60 | self.cpuType = getCPUType(opt.tensorType) 61 | end 62 | 63 | function DataLoader:size() 64 | return math.ceil(self.__size / self.batchSize) 65 | end 66 | 67 | function DataLoader:run() 68 | local threads = self.threads 69 | local size, batchSize = self.__size, self.batchSize 70 | local perm = torch.randperm(size) 71 | 72 | local idx, sample = 1, nil 73 | local function enqueue() 74 | while idx <= size and threads:acceptsjob() do 75 | local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1)) 76 | threads:addjob( 77 | function(indices, nCrops, cpuType) 78 | local sz = indices:size(1) 79 | local batch, imageSize 80 | local target = torch.IntTensor(sz) 81 | for i, idx in ipairs(indices:totable()) do 82 | local sample = _G.dataset:get(idx) 83 | local input = _G.preprocess(sample.input) 84 | if not batch then 85 | imageSize = input:size():totable() 86 | if nCrops > 1 then table.remove(imageSize, 1) end 87 | batch = torch[cpuType](sz, nCrops, table.unpack(imageSize)) 88 | end 89 | batch[i]:copy(input) 90 | target[i] = sample.target 91 | end 92 | collectgarbage() 93 | return { 94 | input = 
batch:view(sz * nCrops, table.unpack(imageSize)), 95 | target = target, 96 | } 97 | end, 98 | function(_sample_) 99 | sample = _sample_ 100 | end, 101 | indices, 102 | self.nCrops, 103 | self.cpuType 104 | ) 105 | idx = idx + batchSize 106 | end 107 | end 108 | 109 | local n = 0 110 | local function loop() 111 | enqueue() 112 | if not threads:hasjob() then 113 | return nil 114 | end 115 | threads:dojob() 116 | if threads:haserror() then 117 | threads:synchronize() 118 | end 119 | enqueue() 120 | n = n + 1 121 | return n, sample 122 | end 123 | 124 | return loop 125 | end 126 | 127 | return M.DataLoader 128 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | 3 | Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` is responsible for one-time setup, while 4 | the `dataset.lua` handles the actual data loading. 5 | 6 | If you want to be able to use the new dataset from main.lua, you should also modify `opts.lua` to handle the new dataset name. 7 | 8 | ### `dataset-gen.lua` 9 | 10 | The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data. 11 | 12 | The module should have a single function `exec(opt, cacheFile)`. 13 | - `opt`: the command line options 14 | - `cacheFile`: path to output 15 | 16 | ```lua 17 | local M = {} 18 | function M.exec(opt, cacheFile) 19 | local imageInfo = {} 20 | -- preprocess dataset, store results in imageInfo, save to cacheFile 21 | torch.save(cacheFile, imageInfo) 22 | end 23 | return M 24 | ``` 25 | 26 | ### `dataset.lua` 27 | 28 | The `dataset.lua` should return a class that implements three functions: 29 | - `get(i)`: returns a table containing two entries, `input` and `target` 30 | - `input`: the training or validation image as a Torch tensor 31 | - `target`: the image category as a number 1-N 32 | - `size()`: returns the number of entries in the dataset 33 | - `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization 34 | 35 | ```lua 36 | local t = require 'datasets/transforms'; local M = {} 37 | local FakeDataset = torch.class('resnet.FakeDataset', M) 38 | 39 | function FakeDataset:__init(imageInfo, opt, split) 40 | -- imageInfo: result from dataset-gen.lua 41 | -- opt: command-line arguments 42 | -- split: "train" or "val" 43 | end 44 | 45 | function FakeDataset:get(i) 46 | return { 47 | input = torch.Tensor(3, 800, 600):uniform(), 48 | target = 42, 49 | } 50 | end 51 | 52 | function FakeDataset:size() 53 | -- size of dataset 54 | return 2000 55 | end 56 | 57 | function FakeDataset:preprocess() 58 | -- Scale smaller side to 256 and take 224x224 center-crop 59 | return t.Compose{ 60 | t.Scale(256), 61 | t.CenterCrop(224), 62 | } 63 | end 64 | 65 | return M.FakeDataset 66 | ``` 67 | -------------------------------------------------------------------------------- /datasets/cifar10-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree.
An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Script to download and prepare the CIFAR-10 dataset 10 | -- 11 | -- This automatically downloads the CIFAR-10 dataset from 12 | -- http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz 13 | -- 14 | 15 | local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz' 16 | 17 | local M = {} 18 | 19 | local function convertToTensor(files) 20 | local data, labels 21 | 22 | for _, file in ipairs(files) do 23 | local m = torch.load(file, 'ascii') 24 | if not data then 25 | data = m.data:t() 26 | labels = m.labels:squeeze() 27 | else 28 | data = torch.cat(data, m.data:t(), 1) 29 | labels = torch.cat(labels, m.labels:squeeze()) 30 | end 31 | end 32 | 33 | -- This is *very* important. The downloaded files have labels 0-9, which do 34 | -- not work with CrossEntropyCriterion 35 | labels:add(1) 36 | 37 | return { 38 | data = data:contiguous():view(-1, 3, 32, 32), 39 | labels = labels, 40 | } 41 | end 42 | 43 | function M.exec(opt, cacheFile) 44 | print("=> Downloading CIFAR-10 dataset from " .. URL) 45 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 46 | assert(ok == true or ok == 0, 'error downloading CIFAR-10') 47 | 48 | print(" | combining dataset into a single file") 49 | local trainData = convertToTensor({ 50 | 'gen/cifar-10-batches-t7/data_batch_1.t7', 51 | 'gen/cifar-10-batches-t7/data_batch_2.t7', 52 | 'gen/cifar-10-batches-t7/data_batch_3.t7', 53 | 'gen/cifar-10-batches-t7/data_batch_4.t7', 54 | 'gen/cifar-10-batches-t7/data_batch_5.t7', 55 | }) 56 | local testData = convertToTensor({ 57 | 'gen/cifar-10-batches-t7/test_batch.t7', 58 | }) 59 | 60 | print(" | saving CIFAR-10 dataset to " .. cacheFile) 61 | torch.save(cacheFile, { 62 | train = trainData, 63 | val = testData, 64 | }) 65 | end 66 | 67 | return M 68 | -------------------------------------------------------------------------------- /datasets/cifar10.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- CIFAR-10 dataset loader 10 | -- 11 | 12 | local t = require 'datasets/transforms' 13 | 14 | local M = {} 15 | local CifarDataset = torch.class('resnet.CifarDataset', M) 16 | 17 | function CifarDataset:__init(imageInfo, opt, split) 18 | assert(imageInfo[split], split) 19 | self.imageInfo = imageInfo[split] 20 | self.split = split 21 | end 22 | 23 | function CifarDataset:get(i) 24 | local image = self.imageInfo.data[i]:float() 25 | local label = self.imageInfo.labels[i] 26 | 27 | return { 28 | input = image, 29 | target = label, 30 | } 31 | end 32 | 33 | function CifarDataset:size() 34 | return self.imageInfo.data:size(1) 35 | end 36 | 37 | -- Computed from entire CIFAR-10 training set 38 | local meanstd = { 39 | mean = {125.3, 123.0, 113.9}, 40 | std = {63.0, 62.1, 66.7}, 41 | } 42 | 43 | function CifarDataset:preprocess() 44 | if self.split == 'train' then 45 | return t.Compose{ 46 | t.ColorNormalize(meanstd), 47 | t.HorizontalFlip(0.5), 48 | t.RandomCrop(32, 4), 49 | } 50 | elseif self.split == 'val' then 51 | return t.ColorNormalize(meanstd) 52 | else 53 | error('invalid split: ' ..
self.split) 54 | end 55 | end 56 | 57 | return M.CifarDataset 58 | -------------------------------------------------------------------------------- /datasets/cifar100-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | ------------ 11 | -- This file automatically downloads the CIFAR-100 dataset from 12 | -- http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz 13 | -- It is based on cifar10-gen.lua 14 | -- Ludovic Trottier 15 | ------------ 16 | 17 | local URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz' 18 | 19 | local M = {} 20 | 21 | local function convertCifar100BinToTorchTensor(inputFname) 22 | local m = torch.DiskFile(inputFname, 'r'):binary() 23 | m:seekEnd() 24 | local length = m:position() - 1 25 | local nSamples = length / 3074 -- 1 coarse-label byte, 1 fine-label byte, 3072 pixel bytes 26 | 27 | assert(nSamples == math.floor(nSamples), 'expecting nSamples to be an exact integer') 28 | m:seek(1) 29 | 30 | local coarse = torch.ByteTensor(nSamples) 31 | local fine = torch.ByteTensor(nSamples) 32 | local data = torch.ByteTensor(nSamples, 3, 32, 32) 33 | for i=1,nSamples do 34 | coarse[i] = m:readByte() 35 | fine[i] = m:readByte() 36 | local store = m:readByte(3072) 37 | data[i]:copy(torch.ByteTensor(store)) 38 | end 39 | 40 | local out = {} 41 | out.data = data 42 | -- This is *very* important. The downloaded files have fine labels 0-99, which do 43 | -- not work with CrossEntropyCriterion 44 | out.labels = fine + 1 45 | 46 | return out 47 | end 48 | 49 | function M.exec(opt, cacheFile) 50 | print("=> Downloading CIFAR-100 dataset from " .. URL) 51 | 52 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 53 | assert(ok == true or ok == 0, 'error downloading CIFAR-100') 54 | 55 | print(" | combining dataset into a single file") 56 | 57 | local trainData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/train.bin') 58 | local testData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/test.bin') 59 | 60 | print(" | saving CIFAR-100 dataset to " .. cacheFile) 61 | torch.save(cacheFile, { 62 | train = trainData, 63 | val = testData, 64 | }) 65 | end 66 | 67 | return M 68 | -------------------------------------------------------------------------------- /datasets/cifar100.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | ------------ 11 | -- This file downloads and transforms CIFAR-100.
12 | -- It is based on cifar10.lua 13 | -- Ludovic Trottier 14 | ------------ 15 | 16 | local t = require 'datasets/transforms' 17 | 18 | local M = {} 19 | local CifarDataset = torch.class('resnet.CifarDataset', M) 20 | 21 | function CifarDataset:__init(imageInfo, opt, split) 22 | assert(imageInfo[split], split) 23 | self.imageInfo = imageInfo[split] 24 | self.split = split 25 | end 26 | 27 | function CifarDataset:get(i) 28 | local image = self.imageInfo.data[i]:float() 29 | local label = self.imageInfo.labels[i] 30 | 31 | return { 32 | input = image, 33 | target = label, 34 | } 35 | end 36 | 37 | function CifarDataset:size() 38 | return self.imageInfo.data:size(1) 39 | end 40 | 41 | 42 | -- Computed from entire CIFAR-100 training set with this code: 43 | -- dataset = torch.load('cifar100.t7') 44 | -- tt = dataset.train.data:double(); 45 | -- tt = tt:transpose(2,4); 46 | -- tt = tt:reshape(50000*32*32, 3); 47 | -- tt:mean(1) 48 | -- tt:std(1) 49 | local meanstd = { 50 | mean = {129.3, 124.1, 112.4}, 51 | std = {68.2, 65.4, 70.4}, 52 | } 53 | 54 | function CifarDataset:preprocess() 55 | if self.split == 'train' then 56 | return t.Compose{ 57 | t.ColorNormalize(meanstd), 58 | t.HorizontalFlip(0.5), 59 | t.RandomCrop(32, 4), 60 | } 61 | elseif self.split == 'val' then 62 | return t.ColorNormalize(meanstd) 63 | else 64 | error('invalid split: ' .. self.split) 65 | end 66 | end 67 | 68 | return M.CifarDataset 69 | -------------------------------------------------------------------------------- /datasets/imagenet-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Script to compute list of ImageNet filenames and classes 10 | -- 11 | -- This generates a file gen/imagenet.t7 which contains the list of all 12 | -- ImageNet training and validation images and their classes. This script also 13 | -- works for other datasets arranged with the same layout. 14 | -- 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findClasses(dir) 22 | local dirs = paths.dir(dir) 23 | table.sort(dirs) 24 | 25 | local classList = {} 26 | local classToIdx = {} 27 | for _, class in ipairs(dirs) do 28 | if not classToIdx[class] and class ~= '.' and class ~= '..' and class ~= '.DS_Store' then 29 | table.insert(classList, class) 30 | classToIdx[class] = #classList 31 | end 32 | end 33 | 34 | -- assert(#classList == 1000, 'expected 1000 ImageNet classes') 35 | return classList, classToIdx 36 | end 37 | 38 | local function findImages(dir, classToIdx) 39 | local imagePath = torch.CharTensor() 40 | local imageClass = torch.LongTensor() 41 | 42 | ---------------------------------------------------------------------- 43 | -- Options for the GNU and BSD find command 44 | local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 45 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 46 | for i=2,#extensionList do 47 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 48 | end 49 | 50 | -- Find all the images using the find command 51 | local f = io.popen('find -L ' .. dir ..
findOptions) 52 | 53 | local maxLength = -1 54 | local imagePaths = {} 55 | local imageClasses = {} 56 | 57 | -- Generate a list of all the images and their class 58 | while true do 59 | local line = f:read('*line') 60 | if not line then break end 61 | 62 | local className = paths.basename(paths.dirname(line)) 63 | local filename = paths.basename(line) 64 | local path = className .. '/' .. filename 65 | 66 | local classId = classToIdx[className] 67 | assert(classId, 'class not found: ' .. className) 68 | 69 | table.insert(imagePaths, path) 70 | table.insert(imageClasses, classId) 71 | 72 | maxLength = math.max(maxLength, #path + 1) 73 | end 74 | 75 | f:close() 76 | 77 | -- Convert the generated list to a tensor for faster loading 78 | local nImages = #imagePaths 79 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 80 | for i, path in ipairs(imagePaths) do 81 | ffi.copy(imagePath[i]:data(), path) 82 | end 83 | 84 | local imageClass = torch.LongTensor(imageClasses) 85 | return imagePath, imageClass 86 | end 87 | 88 | function M.exec(opt, cacheFile) 89 | -- find the image path names 90 | local imagePath = torch.CharTensor() -- path to each image in dataset 91 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 92 | 93 | local trainDir = paths.concat(opt.data, 'train') 94 | local valDir = paths.concat(opt.data, 'val') 95 | assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir) 96 | assert(paths.dirp(valDir), 'val directory not found: ' .. valDir) 97 | 98 | print("=> Generating list of images") 99 | local classList, classToIdx = findClasses(trainDir) 100 | 101 | print(" | finding all validation images") 102 | local valImagePath, valImageClass = findImages(valDir, classToIdx) 103 | 104 | print(" | finding all training images") 105 | local trainImagePath, trainImageClass = findImages(trainDir, classToIdx) 106 | 107 | local info = { 108 | basedir = opt.data, 109 | classList = classList, 110 | train = { 111 | imagePath = trainImagePath, 112 | imageClass = trainImageClass, 113 | }, 114 | val = { 115 | imagePath = valImagePath, 116 | imageClass = valImageClass, 117 | }, 118 | } 119 | 120 | print(" | saving list of images to " .. cacheFile) 121 | torch.save(cacheFile, info) 122 | return info 123 | end 124 | 125 | return M 126 | -------------------------------------------------------------------------------- /datasets/imagenet.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- ImageNet dataset loader 10 | -- 11 | 12 | local image = require 'image' 13 | local paths = require 'paths' 14 | local t = require 'datasets/transforms' 15 | local ffi = require 'ffi' 16 | 17 | local M = {} 18 | local ImagenetDataset = torch.class('resnet.ImagenetDataset', M) 19 | 20 | function ImagenetDataset:__init(imageInfo, opt, split) 21 | self.imageInfo = imageInfo[split] 22 | self.opt = opt 23 | self.split = split 24 | self.dir = paths.concat(opt.data, split) 25 | assert(paths.dirp(self.dir), 'directory does not exist: ' .. 
self.dir) 26 | end 27 | 28 | function ImagenetDataset:get(i) 29 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 30 | 31 | local image = self:_loadImage(paths.concat(self.dir, path)) 32 | local class = self.imageInfo.imageClass[i] 33 | 34 | return { 35 | input = image, 36 | target = class, 37 | } 38 | end 39 | 40 | function ImagenetDataset:_loadImage(path) 41 | local ok, input = pcall(function() 42 | return image.load(path, 3, 'float') 43 | end) 44 | 45 | -- Sometimes image.load fails because the file extension does not match the 46 | -- image format. In that case, use image.decompress on a ByteTensor. 47 | if not ok then 48 | local f = io.open(path, 'r') 49 | assert(f, 'Error reading: ' .. tostring(path)) 50 | local data = f:read('*a') 51 | f:close() 52 | 53 | local b = torch.ByteTensor(string.len(data)) 54 | ffi.copy(b:data(), data, b:size(1)) 55 | 56 | input = image.decompress(b, 3, 'float') 57 | end 58 | 59 | return input 60 | end 61 | 62 | function ImagenetDataset:size() 63 | return self.imageInfo.imageClass:size(1) 64 | end 65 | 66 | -- Computed from random subset of ImageNet training images 67 | local meanstd = { 68 | mean = { 0.485, 0.456, 0.406 }, 69 | std = { 0.229, 0.224, 0.225 }, 70 | } 71 | local pca = { 72 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 73 | eigvec = torch.Tensor{ 74 | { -0.5675, 0.7192, 0.4009 }, 75 | { -0.5808, -0.0045, -0.8140 }, 76 | { -0.5836, -0.6948, 0.4203 }, 77 | }, 78 | } 79 | 80 | function ImagenetDataset:preprocess() 81 | if self.split == 'train' then 82 | return t.Compose{ 83 | t.RandomSizedCrop(224), 84 | t.ColorJitter({ 85 | brightness = 0.4, 86 | contrast = 0.4, 87 | saturation = 0.4, 88 | }), 89 | t.Lighting(0.1, pca.eigval, pca.eigvec), 90 | t.ColorNormalize(meanstd), 91 | t.HorizontalFlip(0.5), 92 | } 93 | elseif self.split == 'val' then 94 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 95 | return t.Compose{ 96 | t.Scale(256), 97 | t.ColorNormalize(meanstd), 98 | Crop(224), 99 | } 100 | else 101 | error('invalid split: ' .. self.split) 102 | end 103 | end 104 | 105 | return M.ImagenetDataset 106 | -------------------------------------------------------------------------------- /datasets/init.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- ImageNet and CIFAR-10 datasets 10 | -- 11 | 12 | local M = {} 13 | 14 | local function isvalid(opt, cachePath) 15 | local imageInfo = torch.load(cachePath) 16 | if imageInfo.basedir and imageInfo.basedir ~= opt.data then 17 | return false 18 | end 19 | return true 20 | end 21 | 22 | function M.create(opt, split) 23 | local cachePath = paths.concat(opt.gen, opt.dataset .. '.t7') 24 | if not paths.filep(cachePath) or not isvalid(opt, cachePath) then 25 | paths.mkdir('gen') 26 | 27 | local script = paths.dofile(opt.dataset .. '-gen.lua') 28 | script.exec(opt, cachePath) 29 | end 30 | local imageInfo = torch.load(cachePath) 31 | 32 | local Dataset = require('datasets/' .. 
opt.dataset) 33 | return Dataset(imageInfo, opt, split) 34 | end 35 | 36 | return M 37 | -------------------------------------------------------------------------------- /datasets/transforms.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Image transforms for data augmentation and input normalization 10 | -- 11 | 12 | require 'image' 13 | 14 | local M = {} 15 | 16 | function M.Compose(transforms) 17 | return function(input) 18 | for _, transform in ipairs(transforms) do 19 | input = transform(input) 20 | end 21 | return input 22 | end 23 | end 24 | 25 | function M.ColorNormalize(meanstd) 26 | return function(img) 27 | img = img:clone() 28 | for i=1,3 do 29 | img[i]:add(-meanstd.mean[i]) 30 | img[i]:div(meanstd.std[i]) 31 | end 32 | return img 33 | end 34 | end 35 | 36 | -- Scales the smaller edge to size 37 | function M.Scale(size, interpolation) 38 | interpolation = interpolation or 'bicubic' 39 | return function(input) 40 | local w, h = input:size(3), input:size(2) 41 | if (w <= h and w == size) or (h <= w and h == size) then 42 | return input 43 | end 44 | if w < h then 45 | return image.scale(input, size, h/w * size, interpolation) 46 | else 47 | return image.scale(input, w/h * size, size, interpolation) 48 | end 49 | end 50 | end 51 | 52 | -- Crop to centered rectangle 53 | function M.CenterCrop(size) 54 | return function(input) 55 | local w1 = math.ceil((input:size(3) - size)/2) 56 | local h1 = math.ceil((input:size(2) - size)/2) 57 | return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch 58 | end 59 | end 60 | 61 | -- Random crop from a larger image with optional zero padding 62 | function M.RandomCrop(size, padding) 63 | padding = padding or 0 64 | 65 | return function(input) 66 | if padding > 0 then 67 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) 68 | temp:zero() 69 | :narrow(2, padding+1, input:size(2)) 70 | :narrow(3, padding+1, input:size(3)) 71 | :copy(input) 72 | input = temp 73 | end 74 | 75 | local w, h = input:size(3), input:size(2) 76 | if w == size and h == size then 77 | return input 78 | end 79 | 80 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 81 | local out = image.crop(input, x1, y1, x1 + size, y1 + size) 82 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 83 | return out 84 | end 85 | end 86 | 87 | -- Four corner patches and center crop from image and its horizontal reflection 88 | function M.TenCrop(size) 89 | local centerCrop = M.CenterCrop(size) 90 | 91 | return function(input) 92 | local w, h = input:size(3), input:size(2) 93 | 94 | local output = {} 95 | for _, img in ipairs{input, image.hflip(input)} do 96 | table.insert(output, centerCrop(img)) 97 | table.insert(output, image.crop(img, 0, 0, size, size)) 98 | table.insert(output, image.crop(img, w-size, 0, w, size)) 99 | table.insert(output, image.crop(img, 0, h-size, size, h)) 100 | table.insert(output, image.crop(img, w-size, h-size, w, h)) 101 | end 102 | 103 | -- View as mini-batch 104 | for i, img in ipairs(output) do 105 | output[i] = img:view(1, img:size(1), img:size(2), img:size(3)) 106 | end 107 | 108 | return input.cat(output, 1)
109 | end 110 | end 111 | 112 | -- Resized with shorter side randomly sampled from [minSize, maxSize] (ResNet-style) 113 | function M.RandomScale(minSize, maxSize) 114 | return function(input) 115 | local w, h = input:size(3), input:size(2) 116 | 117 | local targetSz = torch.random(minSize, maxSize) 118 | local targetW, targetH = targetSz, targetSz 119 | if w < h then 120 | targetH = torch.round(h / w * targetW) 121 | else 122 | targetW = torch.round(w / h * targetH) 123 | end 124 | 125 | return image.scale(input, targetW, targetH, 'bicubic') 126 | end 127 | end 128 | 129 | -- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style) 130 | function M.RandomSizedCrop(size) 131 | local scale = M.Scale(size) 132 | local crop = M.CenterCrop(size) 133 | 134 | return function(input) 135 | local attempt = 0 136 | repeat 137 | local area = input:size(2) * input:size(3) 138 | local targetArea = torch.uniform(0.08, 1.0) * area 139 | 140 | local aspectRatio = torch.uniform(3/4, 4/3) 141 | local w = torch.round(math.sqrt(targetArea * aspectRatio)) 142 | local h = torch.round(math.sqrt(targetArea / aspectRatio)) 143 | 144 | if torch.uniform() < 0.5 then 145 | w, h = h, w 146 | end 147 | 148 | if h <= input:size(2) and w <= input:size(3) then 149 | local y1 = torch.random(0, input:size(2) - h) 150 | local x1 = torch.random(0, input:size(3) - w) 151 | 152 | local out = image.crop(input, x1, y1, x1 + w, y1 + h) 153 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size') 154 | 155 | return image.scale(out, size, size, 'bicubic') 156 | end 157 | attempt = attempt + 1 158 | until attempt >= 10 159 | 160 | -- fallback 161 | return crop(scale(input)) 162 | end 163 | end 164 | 165 | function M.HorizontalFlip(prob) 166 | return function(input) 167 | if torch.uniform() < prob then 168 | input = image.hflip(input) 169 | end 170 | return input 171 | end 172 | end 173 | 174 | function M.Rotation(deg) 175 | return function(input) 176 | if deg ~= 0 then 177 | input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear') 178 | end 179 | return input 180 | end 181 | end 182 | 183 | -- Lighting noise (AlexNet-style PCA-based noise) 184 | function M.Lighting(alphastd, eigval, eigvec) 185 | return function(input) 186 | if alphastd == 0 then 187 | return input 188 | end 189 | 190 | local alpha = torch.Tensor(3):normal(0, alphastd) 191 | local rgb = eigvec:clone() 192 | :cmul(alpha:view(1, 3):expand(3, 3)) 193 | :cmul(eigval:view(1, 3):expand(3, 3)) 194 | :sum(2) 195 | :squeeze() 196 | 197 | input = input:clone() 198 | for i=1,3 do 199 | input[i]:add(rgb[i]) 200 | end 201 | return input 202 | end 203 | end 204 | 205 | local function blend(img1, img2, alpha) 206 | return img1:mul(alpha):add(1 - alpha, img2) 207 | end 208 | 209 | local function grayscale(dst, img) 210 | dst:resizeAs(img) 211 | dst[1]:zero() 212 | dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3]) 213 | dst[2]:copy(dst[1]) 214 | dst[3]:copy(dst[1]) 215 | return dst 216 | end 217 | 218 | function M.Saturation(var) 219 | local gs 220 | 221 | return function(input) 222 | gs = gs or input.new() 223 | grayscale(gs, input) 224 | 225 | local alpha = 1.0 + torch.uniform(-var, var) 226 | blend(input, gs, alpha) 227 | return input 228 | end 229 | end 230 | 231 | function M.Brightness(var) 232 | local gs 233 | 234 | return function(input) 235 | gs = gs or input.new() 236 | gs:resizeAs(input):zero() 237 | 238 | local alpha = 1.0 + torch.uniform(-var, var) 239 | blend(input, gs, alpha) 240 | return input 241 
| end 242 | end 243 | 244 | function M.Contrast(var) 245 | local gs 246 | 247 | return function(input) 248 | gs = gs or input.new() 249 | grayscale(gs, input) 250 | gs:fill(gs[1]:mean()) 251 | 252 | local alpha = 1.0 + torch.uniform(-var, var) 253 | blend(input, gs, alpha) 254 | return input 255 | end 256 | end 257 | 258 | function M.RandomOrder(ts) 259 | return function(input) 260 | local img = input.img or input 261 | local order = torch.randperm(#ts) 262 | for i=1,#ts do 263 | img = ts[order[i]](img) 264 | end 265 | return img 266 | end 267 | end 268 | 269 | function M.ColorJitter(opt) 270 | local brightness = opt.brightness or 0 271 | local contrast = opt.contrast or 0 272 | local saturation = opt.saturation or 0 273 | 274 | local ts = {} 275 | if brightness ~= 0 then 276 | table.insert(ts, M.Brightness(brightness)) 277 | end 278 | if contrast ~= 0 then 279 | table.insert(ts, M.Contrast(contrast)) 280 | end 281 | if saturation ~= 0 then 282 | table.insert(ts, M.Saturation(saturation)) 283 | end 284 | 285 | if #ts == 0 then 286 | return function(input) return input end 287 | end 288 | 289 | return M.RandomOrder(ts) 290 | end 291 | 292 | return M 293 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | require 'torch' 10 | require 'paths' 11 | require 'optim' 12 | require 'nn' 13 | local DataLoader = require 'dataloader' 14 | local models = require 'models/init' 15 | local Trainer = require 'train' 16 | local opts = require 'opts' 17 | local checkpoints = require 'checkpoints' 18 | 19 | -- we don't change this to the 'correct' type (e.g. HalfTensor), because math 20 | -- isn't supported on that type. Type conversion later will handle having 21 | -- the correct type. 
22 | torch.setdefaulttensortype('torch.FloatTensor') 23 | torch.setnumthreads(1) 24 | 25 | local opt = opts.parse(arg) 26 | torch.manualSeed(opt.manualSeed) 27 | cutorch.manualSeedAll(opt.manualSeed) 28 | 29 | -- Load previous checkpoint, if it exists 30 | local checkpoint, optimState = checkpoints.latest(opt) 31 | 32 | -- Create model 33 | local model, criterion = models.setup(opt, checkpoint) 34 | 35 | -- Data loading 36 | local trainLoader, valLoader = DataLoader.create(opt) 37 | 38 | -- The trainer handles the training loop and evaluation on validation set 39 | local trainer = Trainer(model, criterion, opt, optimState) 40 | 41 | if opt.testOnly then 42 | local top1Err, top5Err = trainer:test(0, valLoader) 43 | print(string.format(' * Results top1: %6.3f top5: %6.3f', top1Err, top5Err)) 44 | return 45 | end 46 | 47 | local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber 48 | local bestTop1 = math.huge 49 | local bestTop5 = math.huge 50 | for epoch = startEpoch, opt.nEpochs do 51 | -- Train for a single epoch 52 | local trainTop1, trainTop5, trainLoss = trainer:train(epoch, trainLoader) 53 | 54 | -- Run model on validation set 55 | local testTop1, testTop5 = trainer:test(epoch, valLoader) 56 | 57 | local bestModel = false 58 | if testTop1 < bestTop1 then 59 | bestModel = true 60 | bestTop1 = testTop1 61 | bestTop5 = testTop5 62 | print(' * Best model ', testTop1, testTop5) 63 | end 64 | 65 | checkpoints.save(epoch, model, trainer.optimState, bestModel, opt) 66 | end 67 | 68 | print(string.format(' * Finished top1: %6.3f top5: %6.3f', bestTop1, bestTop5)) -------------------------------------------------------------------------------- /models/DenseConnectLayer.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cudnn' 3 | require 'cunn' 4 | 5 | 6 | local function ShareGradInput(module, key) 7 | assert(key) 8 | module.__shareGradInputKey = key 9 | return module 10 | end 11 | 12 | -------------------------------------------------------------------------------- 13 | -- Standard densely connected layer (memory inefficient) 14 | -------------------------------------------------------------------------------- 15 | function DenseConnectLayerStandard(nChannels, opt) 16 | local net = nn.Sequential() 17 | 18 | net:add(ShareGradInput(cudnn.SpatialBatchNormalization(nChannels), 'first')) 19 | net:add(cudnn.ReLU(true)) 20 | if opt.bottleneck then 21 | net:add(cudnn.SpatialConvolution(nChannels, 4 * opt.growthRate, 1, 1, 1, 1, 0, 0)) 22 | nChannels = 4 * opt.growthRate 23 | if opt.dropRate > 0 then net:add(nn.Dropout(opt.dropRate)) end 24 | net:add(cudnn.SpatialBatchNormalization(nChannels)) 25 | net:add(cudnn.ReLU(true)) 26 | end 27 | net:add(cudnn.SpatialConvolution(nChannels, opt.growthRate, 3, 3, 1, 1, 1, 1)) 28 | if opt.dropRate > 0 then net:add(nn.Dropout(opt.dropRate)) end 29 | 30 | return nn.Sequential() 31 | :add(nn.Concat(2) 32 | :add(nn.Identity()) 33 | :add(net)) 34 | end 35 | 36 | -------------------------------------------------------------------------------- 37 | -- Customized densely connected layer (memory efficient) 38 | -------------------------------------------------------------------------------- 39 | local DenseConnectLayerCustom, parent = torch.class('nn.DenseConnectLayerCustom', 'nn.Container') 40 | 41 | function DenseConnectLayerCustom:__init(nChannels, opt) 42 | parent.__init(self) 43 | self.train = true 44 | self.opt = opt 45 | 46 | self.net1 = nn.Sequential() 47 | 
self.net1:add(ShareGradInput(cudnn.SpatialBatchNormalization(nChannels), 'first')) 48 | self.net1:add(cudnn.ReLU(true)) 49 | 50 | self.net2 = nn.Sequential() 51 | if opt.bottleneck then 52 | self.net2:add(cudnn.SpatialConvolution(nChannels, 4*opt.growthRate, 1, 1, 1, 1, 0, 0)) 53 | nChannels = 4 * opt.growthRate 54 | self.net2:add(cudnn.SpatialBatchNormalization(nChannels)) 55 | self.net2:add(cudnn.ReLU(true)) 56 | end 57 | self.net2:add(cudnn.SpatialConvolution(nChannels, opt.growthRate, 3, 3, 1, 1, 1, 1)) 58 | 59 | -- contiguous outputs of previous layers 60 | self.input_c = torch.Tensor():type(opt.tensorType) 61 | -- save a copy of BatchNorm statistics before forwarding it for the second time when optMemory=4 62 | self.saved_bn_running_mean = torch.Tensor():type(opt.tensorType) 63 | self.saved_bn_running_var = torch.Tensor():type(opt.tensorType) 64 | 65 | self.gradInput = {} 66 | self.output = {} 67 | 68 | self.modules = {self.net1, self.net2} 69 | end 70 | 71 | function DenseConnectLayerCustom:updateOutput(input) 72 | 73 | if type(input) ~= 'table' then 74 | self.output[1] = input 75 | self.output[2] = self.net2:forward(self.net1:forward(input)) 76 | else 77 | for i = 1, #input do 78 | self.output[i] = input[i] 79 | end 80 | torch.cat(self.input_c, input, 2) 81 | self.net1:forward(self.input_c) 82 | self.output[#input+1] = self.net2:forward(self.net1.output) 83 | end 84 | 85 | if self.opt.optMemory == 4 then 86 | local running_mean, running_var = self.net1:get(1).running_mean, self.net1:get(1).running_var 87 | self.saved_bn_running_mean:resizeAs(running_mean):copy(running_mean) 88 | self.saved_bn_running_var:resizeAs(running_var):copy(running_var) 89 | end 90 | 91 | return self.output 92 | end 93 | 94 | function DenseConnectLayerCustom:updateGradInput(input, gradOutput) 95 | 96 | if type(input) ~= 'table' then 97 | self.gradInput = gradOutput[1] 98 | if self.opt.optMemory == 4 then self.net1:forward(input) end 99 | self.net2:updateGradInput(self.net1.output, gradOutput[2]) 100 | self.gradInput:add(self.net1:updateGradInput(input, self.net2.gradInput)) 101 | else 102 | torch.cat(self.input_c, input, 2) 103 | if self.opt.optMemory == 4 then self.net1:forward(self.input_c) end 104 | self.net2:updateGradInput(self.net1.output, gradOutput[#gradOutput]) 105 | self.net1:updateGradInput(self.input_c, self.net2.gradInput) 106 | local nC = 1 107 | for i = 1, #input do 108 | self.gradInput[i] = gradOutput[i] 109 | self.gradInput[i]:add(self.net1.gradInput:narrow(2,nC,input[i]:size(2))) 110 | nC = nC + input[i]:size(2) 111 | end 112 | end 113 | 114 | if self.opt.optMemory == 4 then 115 | self.net1:get(1).running_mean:copy(self.saved_bn_running_mean) 116 | self.net1:get(1).running_var:copy(self.saved_bn_running_var) 117 | end 118 | 119 | return self.gradInput 120 | end 121 | 122 | function DenseConnectLayerCustom:accGradParameters(input, gradOutput, scale) 123 | scale = scale or 1 124 | self.net2:accGradParameters(self.net1.output, gradOutput[#gradOutput], scale) 125 | if type(input) ~= 'table' then 126 | self.net1:accGradParameters(input, self.net2.gradInput, scale) 127 | else 128 | self.net1:accGradParameters(self.input_c, self.net2.gradInput, scale) 129 | end 130 | end 131 | 132 | function DenseConnectLayerCustom:__tostring__() 133 | local tab = ' ' 134 | local line = '\n' 135 | local next = ' |`-> ' 136 | local lastNext = ' `-> ' 137 | local ext = ' | ' 138 | local extlast = ' ' 139 | local last = ' ... -> ' 140 | local str = 'DenseConnectLayerCustom' 141 | str = str .. ' {' .. line .. tab .. 
'{input}' 142 | for i=1,#self.modules do 143 | if i == #self.modules then 144 | str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) 145 | else 146 | str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) 147 | end 148 | end 149 | str = str .. line .. tab .. last .. '{output}' 150 | str = str .. line .. '}' 151 | return str 152 | end 153 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # Memory Efficient Implementation of DenseNets 2 | 3 | The standard (original) implementation of DenseNet with recursive concatenation is very memory inefficient. This can be an obstacle when we need to train DenseNets on high-resolution images (such as for object detection and localization tasks) or on devices with limited memory. 4 | 5 | In theory, DenseNet should use memory more efficiently than other networks, because one of its key features is that it encourages *feature reuse* in the network. The fact that DenseNet is "memory hungry" in practice is simply an artifact of implementation. In particular, the culprit is the recursive concatenation, which *re-allocates memory* for all previous outputs at each layer. Consider a dense block with N layers: the first layer's output has N copies in memory, the second layer's output has (N-1) copies, ..., leading to a quadratic increase (1+2+...+N) in memory consumption as the network depth grows. 6 | 7 | Using *optnet* (*-optMemory 1*) or *shareGradInput* (*-optMemory 2*), we can significantly reduce the run-time memory footprint of the standard implementation (with recursive concatenation). However, the memory consumption is still a quadratic function of depth. 8 | 9 | We implement a customized densely connected layer (largely motivated by the [Caffe implementation of memory efficient DenseNet](https://github.com/Tongcheng/DN_CaffeScript) by [Tongcheng](https://github.com/Tongcheng)), which uses shared buffers to store the concatenated outputs and gradients, thus dramatically reducing the memory footprint of DenseNet during training. The mode *-optMemory 3* activates shareGradInput and shared output buffers, while the mode *-optMemory 4* further shares the memory used to store the output of the Batch-Normalization layer before each 1x1 convolution layer. The latter makes the memory consumption *linear* in network depth, but introduces a training time overhead due to the need to re-forward these Batch-Normalization layers in the backward pass. 10 | 11 | In practice, we suggest using the default *-optMemory 2*, as it does not require customized layers, while the memory consumption is moderate.
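The flag is consumed in `models/densenet.lua`, where each dense layer is added through a switch between the two implementations (quoted from that file):

```lua
function addLayer(model, nChannels, opt)
   if opt.optMemory >= 3 then
      model:add(nn.DenseConnectLayerCustom(nChannels, opt)) -- shared buffers
   else
      model:add(DenseConnectLayerStandard(nChannels, opt))  -- nn.Concat
   end
end
```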
When GPU memory is truly the bottleneck, we can adopt the customized implementation by setting *-optMemory* to 3 or 4, e.g., 12 | ``` 13 | th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 100 -growthRate 12 -optMemory 4 14 | ``` 15 | 16 | The time and memory footprints below were benchmarked with a DenseNet-BC (L=100, k=12) on CIFAR-10, on an NVIDIA TitanX GPU: 17 | 18 | | optMemory | Memory | Time (s/mini-batch) | Description | 19 | |:-------:|:-------:|:--------:|:-------| 20 | | 0 | 5453M | 0.153 | Original implementation 21 | | 1 | 3746M | 0.153 | Original implementation with optnet 22 | | 2 | 2969M | 0.152 | Original implementation with shareGradInput 23 | | 3 | 2188M | 0.155 | Customized implementation with shareGradInput and sharePrevOutput 24 | | 4 | 1655M | 0.175 | Customized implementation with shareGradInput, sharePrevOutput and shareBNOutput -------------------------------------------------------------------------------- /models/densenet.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cunn' 3 | require 'cudnn' 4 | require 'models/DenseConnectLayer' 5 | 6 | local function createModel(opt) 7 | 8 | --growth rate 9 | local growthRate = opt.growthRate 10 | 11 | --dropout rate; set to 0 to disable dropout, or to a non-zero value to enable it with that drop rate 12 | local dropRate = opt.dropRate 13 | 14 | --# channels before entering the first Dense-Block 15 | local nChannels = 2 * growthRate 16 | 17 | --compression rate at transition layers 18 | local reduction = opt.reduction 19 | 20 | --whether to use bottleneck structures 21 | local bottleneck = opt.bottleneck 22 | 23 | --N: # densely connected layers in each dense block 24 | local N = (opt.depth - 4)/3 25 | if bottleneck then N = N/2 end 26 | 27 | 28 | function addLayer(model, nChannels, opt) 29 | if opt.optMemory >= 3 then 30 | model:add(nn.DenseConnectLayerCustom(nChannels, opt)) 31 | else 32 | model:add(DenseConnectLayerStandard(nChannels, opt)) 33 | end 34 | end 35 | 36 | 37 | function addTransition(model, nChannels, nOutChannels, opt, last, pool_size) 38 | if opt.optMemory >= 3 then 39 | model:add(nn.JoinTable(2)) 40 | end 41 | 42 | model:add(cudnn.SpatialBatchNormalization(nChannels)) 43 | model:add(cudnn.ReLU(true)) 44 | if last then 45 | model:add(cudnn.SpatialAveragePooling(pool_size, pool_size)) 46 | model:add(nn.Reshape(nChannels)) 47 | else 48 | model:add(cudnn.SpatialConvolution(nChannels, nOutChannels, 1, 1, 1, 1, 0, 0)) 49 | if opt.dropRate > 0 then model:add(nn.Dropout(opt.dropRate)) end 50 | model:add(cudnn.SpatialAveragePooling(2, 2)) 51 | end 52 | end 53 | 54 | 55 | local function addDenseBlock(model, nChannels, opt, N) 56 | for i = 1, N do 57 | addLayer(model, nChannels, opt) 58 | nChannels = nChannels + opt.growthRate 59 | end 60 | return nChannels 61 | end 62 | 63 | 64 | -- Build DenseNet 65 | local model = nn.Sequential() 66 | 67 | if opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then 68 | 69 | --Initial convolution layer 70 | model:add(cudnn.SpatialConvolution(3, nChannels, 3,3, 1,1, 1,1)) 71 | 72 | --Dense-Block 1 and transition 73 | nChannels = addDenseBlock(model, nChannels, opt, N) 74 | addTransition(model, nChannels, math.floor(nChannels*reduction), opt) 75 | nChannels = math.floor(nChannels*reduction) 76 | 77 | --Dense-Block 2 and transition 78 | nChannels = addDenseBlock(model, nChannels, opt, N) 79 | addTransition(model, nChannels, math.floor(nChannels*reduction), opt) 80 | nChannels =
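--[[ Channel bookkeeping, worked through for the DenseNet-BC (depth 100, k = 12)
configuration used in models/README.md (this is exactly what the code above and
below computes):

   N = (100 - 4)/3/2 = 16 layers per block
   initial conv:   nChannels = 2*12 = 24
   dense block 1:  24 + 16*12 = 216  -> transition -> floor(216*0.5) = 108
   dense block 2:  108 + 16*12 = 300 -> transition -> floor(300*0.5) = 150
   dense block 3:  150 + 16*12 = 342 -> BN, ReLU, 8x8 avg pool -> Linear(342, 10)
--]]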
math.floor(nChannels*reduction) 81 | 82 | --Dense-Block 3 and transition 83 | nChannels = addDenseBlock(model, nChannels, opt, N) 84 | addTransition(model, nChannels, nChannels, opt, true, 8) 85 | 86 | elseif opt.dataset == 'imagenet' then 87 | 88 | --number of layers in each block 89 | if opt.depth == 121 then 90 | stages = {6, 12, 24, 16} 91 | elseif opt.depth == 169 then 92 | stages = {6, 12, 32, 32} 93 | elseif opt.depth == 201 then 94 | stages = {6, 12, 48, 32} 95 | elseif opt.depth == 161 then 96 | stages = {6, 12, 36, 24} 97 | else 98 | stages = {opt.d1, opt.d2, opt.d3, opt.d4} 99 | end 100 | 101 | --Initial transforms follow ResNet(224x224) 102 | model:add(cudnn.SpatialConvolution(3, nChannels, 7,7, 2,2, 3,3)) 103 | model:add(cudnn.SpatialBatchNormalization(nChannels)) 104 | model:add(cudnn.ReLU(true)) 105 | model:add(nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1)) 106 | 107 | --Dense-Block 1 and transition (56x56) 108 | nChannels = addDenseBlock(model, nChannels, opt, stages[1]) 109 | addTransition(model, nChannels, math.floor(nChannels*reduction), opt) 110 | nChannels = math.floor(nChannels*reduction) 111 | 112 | --Dense-Block 2 and transition (28x28) 113 | nChannels = addDenseBlock(model, nChannels, opt, stages[2]) 114 | addTransition(model, nChannels, math.floor(nChannels*reduction), opt) 115 | nChannels = math.floor(nChannels*reduction) 116 | 117 | --Dense-Block 3 and transition (14x14) 118 | nChannels = addDenseBlock(model, nChannels, opt, stages[3]) 119 | addTransition(model, nChannels, math.floor(nChannels*reduction), opt) 120 | nChannels = math.floor(nChannels*reduction) 121 | 122 | --Dense-Block 4 and transition (7x7) 123 | nChannels = addDenseBlock(model, nChannels, opt, stages[4]) 124 | addTransition(model, nChannels, nChannels, opt, true, 7) 125 | 126 | end 127 | 128 | 129 | if opt.dataset == 'cifar10' then 130 | model:add(nn.Linear(nChannels, 10)) 131 | elseif opt.dataset == 'cifar100' then 132 | model:add(nn.Linear(nChannels, 100)) 133 | elseif opt.dataset == 'imagenet' then 134 | model:add(nn.Linear(nChannels, 1000)) 135 | end 136 | 137 | --Initialization following ResNet 138 | local function ConvInit(name) 139 | for k,v in pairs(model:findModules(name)) do 140 | local n = v.kW*v.kH*v.nOutputPlane 141 | v.weight:normal(0,math.sqrt(2/n)) 142 | if cudnn.version >= 4000 then 143 | v.bias = nil 144 | v.gradBias = nil 145 | else 146 | v.bias:zero() 147 | end 148 | end 149 | end 150 | 151 | local function BNInit(name) 152 | for k,v in pairs(model:findModules(name)) do 153 | v.weight:fill(1) 154 | v.bias:zero() 155 | end 156 | end 157 | 158 | ConvInit('cudnn.SpatialConvolution') 159 | BNInit('cudnn.SpatialBatchNormalization') 160 | for k,v in pairs(model:findModules('nn.Linear')) do 161 | v.bias:zero() 162 | end 163 | 164 | model:type(opt.tensorType) 165 | 166 | if opt.cudnn == 'deterministic' then 167 | model:apply(function(m) 168 | if m.setMode then m:setMode(1,1,1) end 169 | end) 170 | end 171 | 172 | model:get(1).gradInput = nil 173 | 174 | print(model) 175 | return model 176 | end 177 | 178 | return createModel -------------------------------------------------------------------------------- /models/init.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. 
An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Generic model creating code. For the specific ResNet model see 10 | -- models/resnet.lua 11 | -- 12 | -- Code modified for DenseNet (https://arxiv.org/abs/1608.06993) by Gao Huang. 13 | -- 14 | -- More details about the memory efficient implementation can be found in the 15 | -- technical report "Memory-Efficient Implementation of DenseNets" 16 | -- (https://arxiv.org/pdf/1707.06990.pdf) 17 | 18 | require 'nn' 19 | require 'cunn' 20 | require 'cudnn' 21 | require 'models/DenseConnectLayer' 22 | 23 | local M = {} 24 | 25 | function M.setup(opt, checkpoint) 26 | 27 | print('=> Creating model from file: models/' .. opt.netType .. '.lua') 28 | local model = require('models/' .. opt.netType)(opt) 29 | if checkpoint then 30 | local modelPath = paths.concat(opt.resume, checkpoint.modelFile) 31 | assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath) 32 | print('=> Resuming model from ' .. modelPath) 33 | local model0 = torch.load(modelPath):type(opt.tensorType) 34 | M.copyModel(model, model0) 35 | elseif opt.retrain ~= 'none' then 36 | assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain) 37 | print('Loading model from file: ' .. opt.retrain) 38 | local model0 = torch.load(opt.retrain):type(opt.tensorType) 39 | M.copyModel(model, model0) 40 | end 41 | 42 | -- First remove any DataParallelTable 43 | if torch.type(model) == 'nn.DataParallelTable' then 44 | model = model:get(1) 45 | end 46 | 47 | -- optnet is a general library for reducing memory usage in neural networks 48 | if opt.optnet or opt.optMemory == 1 then 49 | local optnet = require 'optnet' 50 | local imsize = opt.dataset == 'imagenet' and 224 or 32 51 | local sampleInput = torch.zeros(4,3,imsize,imsize):type(opt.tensorType) 52 | optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'}) 53 | end 54 | 55 | -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all 56 | -- containers override backwards to call backwards recursively on submodules 57 | if opt.shareGradInput or opt.optMemory >= 2 then 58 | M.shareGradInput(model, opt) 59 | end 60 | 61 | 62 | -- Share the contiguous (concatenated) outputs of previous layers in DenseNet. 63 | if opt.optMemory >= 3 then 64 | M.sharePrevOutput(model, opt) 65 | end 66 | 67 | -- Share the output of BatchNorm in bottleneck layers of DenseNet. This requires 68 | -- forwarding the BN layer twice at each mini-batch, but makes the memory consumption 69 | -- linear (instead of quadratic) in depth 70 | if opt.optMemory == 4 then 71 | M.shareBNOutput(model, opt) 72 | end 73 | 74 | -- For resetting the classifier when fine-tuning on a different dataset 75 | if opt.resetClassifier and not checkpoint then 76 | print(' => Replacing classifier with ' .. opt.nClasses ..
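--[[ How the -optMemory levels handled above compose, paired with the layer
choice in models/densenet.lua (optMemory >= 3 selects DenseConnectLayerCustom):

   0: standard layers, no sharing
   1: standard layers + optnet
   2: standard layers + shareGradInput (the default)
   3: custom layers + shareGradInput + sharePrevOutput
   4: custom layers + shareGradInput + sharePrevOutput + shareBNOutput
--]]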
'-way classifier') 77 | 78 | local orig = model:get(#model.modules) 79 | assert(torch.type(orig) == 'nn.Linear', 80 | 'expected last layer to be fully connected') 81 | 82 | local linear = nn.Linear(orig.weight:size(2), opt.nClasses) 83 | linear.bias:zero() 84 | 85 | model:remove(#model.modules) 86 | model:add(linear:type(opt.tensorType)) 87 | end 88 | 89 | -- Set the CUDNN flags 90 | if opt.cudnn == 'fastest' then 91 | cudnn.fastest = true 92 | cudnn.benchmark = true 93 | elseif opt.cudnn == 'deterministic' then 94 | -- Use a deterministic convolution implementation 95 | model:apply(function(m) 96 | if m.setMode then m:setMode(1, 1, 1) end 97 | end) 98 | end 99 | 100 | -- Wrap the model with DataParallelTable, if using more than one GPU 101 | if opt.nGPU > 1 then 102 | local gpus = torch.range(1, opt.nGPU):totable() 103 | local fastest, benchmark = cudnn.fastest, cudnn.benchmark 104 | 105 | local dpt = nn.DataParallelTable(1, true, true) 106 | :add(model, gpus) 107 | :threads(function() 108 | local cudnn = require 'cudnn' 109 | require 'models/DenseConnectLayer' 110 | cudnn.fastest, cudnn.benchmark = fastest, benchmark 111 | end) 112 | dpt.gradInput = nil 113 | 114 | model = dpt:type(opt.tensorType) 115 | end 116 | 117 | local criterion = nn.CrossEntropyCriterion():type(opt.tensorType) 118 | return model, criterion 119 | end 120 | 121 | function M.shareGradInput(model, opt) 122 | local function sharingKey(m) 123 | local key = torch.type(m) 124 | if m.__shareGradInputKey then 125 | key = key .. ':' .. m.__shareGradInputKey 126 | end 127 | return key 128 | end 129 | 130 | -- Share gradInput for memory efficient backprop 131 | local cache = {} 132 | model:apply(function(m) 133 | local moduleType = torch.type(m) 134 | if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' and moduleType ~= 'nn.Concat' then 135 | local key = sharingKey(m) 136 | if cache[key] == nil then 137 | cache[key] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 138 | end 139 | m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[key], 1, 0) 140 | end 141 | end) 142 | for i, m in ipairs(model:findModules('nn.ConcatTable')) do 143 | if cache[i % 2] == nil then 144 | cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 145 | end 146 | m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0) 147 | end 148 | for i, m in ipairs(model:findModules('nn.Concat')) do 149 | if cache[i % 2] == nil then 150 | cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 151 | end 152 | m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0) 153 | end 154 | end 155 | 156 | function M.sharePrevOutput(model, opt) 157 | -- Share contiguous output for memory efficient densenet 158 | local buffer = nil 159 | model:apply(function(m) 160 | local moduleType = torch.type(m) 161 | if moduleType == 'nn.DenseConnectLayerCustom' then 162 | if buffer == nil then 163 | buffer = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 164 | end 165 | m.input_c = torch[opt.tensorType:match('torch.(%a+)')](buffer, 1, 0) 166 | end 167 | end) 168 | end 169 | 170 | function M.shareBNOutput(model, opt) 171 | -- Share BN.output for memory efficient densenet 172 | local buffer = nil 173 | model:apply(function(m) 174 | local moduleType = torch.type(m) 175 | if moduleType == 'nn.DenseConnectLayerCustom' then 176 | if buffer == nil then 177 | buffer = 
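--[[ The sharing idiom used by shareGradInput, sharePrevOutput and shareBNOutput:
allocate a one-element Storage per buffer group, then rebuild each module's
tensor as a zero-element view onto it (offset 1, size 0). The first resize
during forward/backward grows the shared storage in place. In isolation
(a sketch, assuming the default torch.CudaTensor type):

   local storage = torch.CudaStorage(1)
   local a = torch.CudaTensor(storage, 1, 0)  -- empty view on the shared storage
   local b = torch.CudaTensor(storage, 1, 0)
   a:resize(64, 12, 32, 32)  -- grows the storage; a later resize of b reuses
                             -- the same memory, so a and b must not be needed
                             -- at the same time
--]]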
torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 178 | end 179 | m.net1:get(1).output = torch[opt.tensorType:match('torch.(%a+)')](buffer, 1, 0) 180 | end 181 | end) 182 | end 183 | 184 | function M.copyModel(t, s) 185 | local wt, ws = t:parameters(), s:parameters() 186 | assert(#wt==#ws, 'Model configuration does not match the resumed model!') 187 | for l = 1, #wt do 188 | wt[l]:copy(ws[l]) 189 | end 190 | local bn_t, bn_s = {}, {} 191 | for i, m in ipairs(s:findModules('cudnn.SpatialBatchNormalization')) do 192 | bn_s[i] = m 193 | end 194 | for i, m in ipairs(t:findModules('cudnn.SpatialBatchNormalization')) do 195 | bn_t[i] = m 196 | end 197 | assert(#bn_t==#bn_s, 'Model configuration does not match the resumed model!') 198 | for i = 1, #bn_s do 199 | bn_t[i].running_mean:copy(bn_s[i].running_mean) 200 | bn_t[i].running_var:copy(bn_s[i].running_var) 201 | end 202 | end 203 | 204 | return M 205 | -------------------------------------------------------------------------------- /opts.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Code modified for DenseNet (https://arxiv.org/abs/1608.06993) by Gao Huang. 10 | -- 11 | local M = { } 12 | 13 | function M.parse(arg) 14 | local cmd = torch.CmdLine() 15 | cmd:text() 16 | cmd:text('Torch-7 DenseNet Training script') 17 | cmd:text('See https://github.com/facebook/fb.resnet.torch/blob/master/TRAINING.md for examples') 18 | cmd:text() 19 | cmd:text('Options:') 20 | ------------ General options -------------------- 21 | cmd:option('-data', '', 'Path to dataset') 22 | cmd:option('-dataset', 'cifar10', 'Options: imagenet | cifar10 | cifar100') 23 | cmd:option('-manualSeed', 0, 'Manually set RNG seed') 24 | cmd:option('-nGPU', 1, 'Number of GPUs to use by default') 25 | cmd:option('-backend', 'cudnn', 'Options: cudnn | cunn') 26 | cmd:option('-cudnn', 'fastest', 'Options: fastest | default | deterministic') 27 | cmd:option('-gen', 'gen', 'Path to save generated files') 28 | cmd:option('-precision', 'single', 'Options: single | double | half') 29 | ------------- Data options ------------------------ 30 | cmd:option('-nThreads', 2, 'number of data loading threads') 31 | ------------- Training options -------------------- 32 | cmd:option('-nEpochs', 0, 'Number of total epochs to run') 33 | cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)') 34 | cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)') 35 | cmd:option('-testOnly', 'false', 'Run on validation set only') 36 | cmd:option('-tenCrop', 'false', 'Ten-crop testing') 37 | ------------- Checkpointing options --------------- 38 | cmd:option('-save', 'checkpoints', 'Directory in which to save checkpoints') 39 | cmd:option('-resume', 'none', 'Resume from the latest checkpoint in this directory') 40 | ---------- Optimization options ---------------------- 41 | cmd:option('-LR', 0.1, 'initial learning rate') 42 | cmd:option('-momentum', 0.9, 'momentum') 43 | cmd:option('-weightDecay', 1e-4, 'weight decay') 44 | cmd:option('-lrShape', 'multistep', 'Learning rate: multistep|cosine') 45 | ---------- Model options ---------------------------------- 46 | cmd:option('-netType',
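--[[ Example invocation exercising the options below (taken from models/README.md):

   th main.lua -netType densenet -dataset cifar10 -batchSize 64 -nEpochs 300 -depth 100 -growthRate 12 -optMemory 4
--]]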
'densenet', 'Options: densenet') 47 | cmd:option('-depth', 20, 'DenseNet depth: e.g. 100 for CIFAR, or 121 | 161 | 169 | 201 for ImageNet', 'number') 48 | cmd:option('-shortcutType', '', 'Options: A | B | C') 49 | cmd:option('-retrain', 'none', 'Path to model to retrain with') 50 | cmd:option('-optimState', 'none', 'Path to an optimState to reload from') 51 | ---------- Model options ---------------------------------- 52 | cmd:option('-shareGradInput', 'false', 'Share gradInput tensors to reduce memory usage') 53 | cmd:option('-optnet', 'false', 'Use optnet to reduce memory usage') 54 | cmd:option('-resetClassifier', 'false', 'Reset the fully connected layer for fine-tuning') 55 | cmd:option('-nClasses', 0, 'Number of classes in the dataset') 56 | ---------- Model options for DenseNet ---------------------------------- 57 | cmd:option('-growthRate', 12, 'Number of output channels at each convolutional layer') 58 | cmd:option('-bottleneck', 'true', 'Use 1x1 convolution to reduce dimension (DenseNet-B)') 59 | cmd:option('-reduction', 0.5, 'Channel compression ratio at transition layers (DenseNet-C)') 60 | cmd:option('-dropRate', 0, 'Dropout probability') 61 | cmd:option('-optMemory', 2, 'Optimize memory for DenseNet: 0 | 1 | 2 | 3 | 4', 'number') 62 | -- The following hyperparameters are used when depth is not one of {121, 161, 169, 201} (for ImageNet only) 63 | cmd:option('-d1', 0, 'Number of layers in block 1') 64 | cmd:option('-d2', 0, 'Number of layers in block 2') 65 | cmd:option('-d3', 0, 'Number of layers in block 3') 66 | cmd:option('-d4', 0, 'Number of layers in block 4') 67 | 68 | cmd:text() 69 | 70 | local opt = cmd:parse(arg or {}) 71 | 72 | opt.testOnly = opt.testOnly ~= 'false' 73 | opt.tenCrop = opt.tenCrop ~= 'false' 74 | opt.shareGradInput = opt.shareGradInput ~= 'false' 75 | opt.optnet = opt.optnet ~= 'false' 76 | opt.resetClassifier = opt.resetClassifier ~= 'false' 77 | opt.bottleneck = opt.bottleneck ~= 'false' 78 | 79 | if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then 80 | cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n') 81 | end 82 | 83 | if opt.dataset == 'imagenet' then 84 | -- Handle the most common case of missing -data flag 85 | local trainDir = paths.concat(opt.data, 'train') 86 | if not paths.dirp(opt.data) then 87 | cmd:error('error: missing ImageNet data directory') 88 | elseif not paths.dirp(trainDir) then 89 | cmd:error('error: ImageNet missing `train` directory: ' .. trainDir) 90 | end 91 | -- Default shortcutType=B and nEpochs=90 92 | opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType 93 | opt.nEpochs = opt.nEpochs == 0 and 90 or opt.nEpochs 94 | elseif opt.dataset == 'cifar10' then 95 | -- Default shortcutType=A and nEpochs=164 96 | opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType 97 | opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs 98 | elseif opt.dataset == 'cifar100' then 99 | -- Default shortcutType=A and nEpochs=164 100 | opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType 101 | opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs 102 | else 103 | cmd:error('unknown dataset: ' .. opt.dataset) 104 | end 105 | 106 | if opt.precision == nil or opt.precision == 'single' then 107 | opt.tensorType = 'torch.CudaTensor' 108 | elseif opt.precision == 'double' then 109 | opt.tensorType = 'torch.CudaDoubleTensor' 110 | elseif opt.precision == 'half' then 111 | opt.tensorType = 'torch.CudaHalfTensor' 112 | else 113 | cmd:error('unknown precision: ' ..
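--[[ Two parsing conventions in this function: CmdLine returns flag values as
strings, so booleans are decoded with the `~= 'false'` idiom above (any value
other than the literal string 'false' enables the flag), and -precision is
mapped here to the tensor type used throughout the code. A sketch:

   th main.lua -testOnly true      --> opt.testOnly == true
   th main.lua -testOnly false     --> opt.testOnly == false
   th main.lua -precision half     --> opt.tensorType == 'torch.CudaHalfTensor'
--]]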
opt.precision) 114 | end 115 | 116 | if opt.resetClassifier then 117 | if opt.nClasses == 0 then 118 | cmd:error('-nClasses required when resetClassifier is set') 119 | end 120 | end 121 | if opt.shareGradInput and opt.optnet then 122 | cmd:error('error: cannot use both -shareGradInput and -optnet') 123 | end 124 | 125 | return opt 126 | end 127 | 128 | return M 129 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- The training loop and learning rate schedule 10 | -- 11 | -- Code modified for DenseNet (https://arxiv.org/abs/1608.06993) by Gao Huang. 12 | -- 13 | 14 | local optim = require 'optim' 15 | 16 | local M = {} 17 | local Trainer = torch.class('resnet.Trainer', M) 18 | 19 | function Trainer:__init(model, criterion, opt, optimState) 20 | self.model = model 21 | self.criterion = criterion 22 | self.optimState = optimState or { 23 | learningRate = opt.LR, 24 | learningRateDecay = 0.0, 25 | momentum = opt.momentum, 26 | nesterov = true, 27 | dampening = 0.0, 28 | weightDecay = opt.weightDecay, 29 | } 30 | self.opt = opt 31 | self.params, self.gradParams = model:getParameters() 32 | end 33 | 34 | function Trainer:train(epoch, dataloader) 35 | -- Trains the model for a single epoch 36 | 37 | ------for LR------ 38 | if self.opt.lrShape == 'multistep' then 39 | self.optimState.learningRate = self:learningRate(epoch) 40 | end 41 | ------for LR------ 42 | 43 | local timer = torch.Timer() 44 | local dataTimer = torch.Timer() 45 | 46 | local function feval() 47 | return self.criterion.output, self.gradParams 48 | end 49 | 50 | local trainSize = dataloader:size() 51 | local top1Sum, top5Sum, lossSum = 0.0, 0.0, 0.0 52 | local N = 0 53 | 54 | print('=> Training epoch # ' .. 
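--[[ Note on the update step in the loop below: the forward and backward passes
are run explicitly, so the feval closure defined above does no work of its
own; it simply hands optim.sgd the loss and gradient buffer that were already
computed:

   local function feval()
      return self.criterion.output, self.gradParams
   end
   ...
   optim.sgd(feval, self.params, self.optimState)

This is the usual fb.resnet.torch pattern for combining optim with manual
backprop. --]]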
epoch) 55 | -- set the batch norm to training mode 56 | self.model:training() 57 | for n, sample in dataloader:run() do 58 | 59 | ------for LR------ 60 | if self.opt.lrShape == 'cosine' then 61 | self.optimState.learningRate = self:learningRateCosine(epoch, n, trainSize) 62 | end 63 | ------for LR------ 64 | 65 | local dataTime = dataTimer:time().real 66 | 67 | -- Copy input and target to the GPU 68 | self:copyInputs(sample) 69 | 70 | local output = self.model:forward(self.input):float() 71 | local batchSize = output:size(1) 72 | local loss = self.criterion:forward(self.model.output, self.target) 73 | 74 | self.model:zeroGradParameters() 75 | self.criterion:backward(self.model.output, self.target) 76 | self.model:backward(self.input, self.criterion.gradInput) 77 | 78 | optim.sgd(feval, self.params, self.optimState) 79 | 80 | local top1, top5 = self:computeScore(output, sample.target, 1) 81 | top1Sum = top1Sum + top1*batchSize 82 | top5Sum = top5Sum + top5*batchSize 83 | lossSum = lossSum + loss*batchSize 84 | N = N + batchSize 85 | 86 | print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.3f top1 %7.2f top5 %7.2f lr %.4f'):format( 87 | epoch, n, trainSize, timer:time().real, dataTime, loss, top1, top5, self.optimState.learningRate)) 88 | 89 | -- check that the storage didn't get changed due to an unfortunate getParameters call 90 | assert(self.params:storage() == self.model:parameters()[1]:storage()) 91 | 92 | timer:reset() 93 | dataTimer:reset() 94 | end 95 | 96 | return top1Sum / N, top5Sum / N, lossSum / N 97 | end 98 | 99 | function Trainer:test(epoch, dataloader) 100 | -- Computes the top-1 and top-5 error on the validation set 101 | 102 | local timer = torch.Timer() 103 | local dataTimer = torch.Timer() 104 | local size = dataloader:size() 105 | 106 | local nCrops = self.opt.tenCrop and 10 or 1 107 | local top1Sum, top5Sum = 0.0, 0.0 108 | local N = 0 109 | 110 | self.model:evaluate() 111 | for n, sample in dataloader:run() do 112 | local dataTime = dataTimer:time().real 113 | 114 | -- Copy input and target to the GPU 115 | self:copyInputs(sample) 116 | 117 | local output = self.model:forward(self.input):float() 118 | local batchSize = output:size(1) / nCrops 119 | local loss = self.criterion:forward(self.model.output, self.target) 120 | 121 | local top1, top5 = self:computeScore(output, sample.target, nCrops) 122 | top1Sum = top1Sum + top1*batchSize 123 | top5Sum = top5Sum + top5*batchSize 124 | N = N + batchSize 125 | 126 | print((' | Test: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (%7.3f) top5 %7.3f (%7.3f)'):format( 127 | epoch, n, size, timer:time().real, dataTime, top1, top1Sum / N, top5, top5Sum / N)) 128 | 129 | timer:reset() 130 | dataTimer:reset() 131 | end 132 | self.model:training() 133 | 134 | print((' * Finished epoch # %d top1: %7.3f top5: %7.3f\n'):format( 135 | epoch, top1Sum / N, top5Sum / N)) 136 | 137 | return top1Sum / N, top5Sum / N 138 | end 139 | 140 | function Trainer:computeScore(output, target, nCrops) 141 | if nCrops > 1 then 142 | -- Sum over crops 143 | output = output:view(output:size(1) / nCrops, nCrops, output:size(2)) 144 | --:exp() 145 | :sum(2):squeeze(2) 146 | end 147 | 148 | -- Computes the top-1 and top-5 error rates 149 | local batchSize = output:size(1) 150 | 151 | local _ , predictions = output:float():topk(5, 2, true, true) -- descending 152 | 153 | -- Find which predictions match the target 154 | local correct = predictions:eq( 155 | target:long():view(batchSize, 1):expandAs(predictions)) 156 | 157 | -- Top-1 score 158 | local top1 = 1.0 -
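--[[ A worked example of the scoring logic (hypothetical scores, 5 classes,
a single sample whose target class is 3):

   scores      = {0.10, 0.50, 0.20, 0.15, 0.05}
   predictions = {2, 3, 4, 1, 5}  -- topk(5, 2, true, true): 5 largest, sorted
   correct     = {0, 1, 0, 0, 0}  -- eq() against the expanded target

   top-1 error = (1 - 0/1) * 100 = 100%  (the top prediction is class 2, not 3)
   top-5 error = (1 - 1/1) * 100 = 0%    (class 3 appears among the top five)
--]]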
(correct:narrow(2, 1, 1):sum() / batchSize) 159 | 160 | -- Top-5 score, if there are at least 5 classes 161 | local len = math.min(5, correct:size(2)) 162 | local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize) 163 | 164 | return top1 * 100, top5 * 100 165 | end 166 | 167 | local function getCudaTensorType(tensorType) 168 | if tensorType == 'torch.CudaHalfTensor' then 169 | return cutorch.createCudaHostHalfTensor() 170 | elseif tensorType == 'torch.CudaDoubleTensor' then 171 | return cutorch.createCudaHostDoubleTensor() 172 | else 173 | return cutorch.createCudaHostTensor() 174 | end 175 | end 176 | 177 | function Trainer:copyInputs(sample) 178 | -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory, 179 | -- if using DataParallelTable. The target is always copied to a CUDA tensor 180 | self.input = self.input or (self.opt.nGPU == 1 181 | and torch[self.opt.tensorType:match('torch.(%a+)')]() 182 | or getCudaTensorType(self.opt.tensorType)) 183 | self.target = self.target or (torch.CudaLongTensor and torch.CudaLongTensor()) 184 | self.input:resize(sample.input:size()):copy(sample.input) 185 | self.target:resize(sample.target:size()):copy(sample.target) 186 | end 187 | 188 | function Trainer:learningRate(epoch) 189 | -- Training schedule 190 | local decay = 0 191 | if self.opt.dataset == 'imagenet' then 192 | decay = math.floor((epoch - 1) / 30) 193 | elseif self.opt.dataset == 'cifar10' then 194 | decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0 195 | elseif self.opt.dataset == 'cifar100' then 196 | decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0 197 | end 198 | return self.opt.LR * math.pow(0.1, decay) 199 | end 200 | 201 | ------for LR------ 202 | function Trainer:learningRateCosine(epoch, iter, nBatches) 203 | local nEpochs = self.opt.nEpochs 204 | local T_total = nEpochs * nBatches 205 | local T_cur = ((epoch-1) % nEpochs) * nBatches + iter 206 | return 0.5 * self.opt.LR * (1 + torch.cos(math.pi * T_cur / T_total)) 207 | end 208 | ------for LR------ 209 | 210 | return M.Trainer 211 | --------------------------------------------------------------------------------
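A standalone sanity check of the cosine schedule in `Trainer:learningRateCosine` above (a sketch: `math.cos` stands in for `torch.cos`, and LR = 0.1, 300 epochs and 100 mini-batches per epoch are assumed):

```
local function lrCosine(LR, epoch, iter, nBatches, nEpochs)
   local T_total = nEpochs * nBatches
   local T_cur = ((epoch - 1) % nEpochs) * nBatches + iter
   return 0.5 * LR * (1 + math.cos(math.pi * T_cur / T_total))
end

print(lrCosine(0.1, 1, 1, 100, 300))     -- ~0.100 at the start of training
print(lrCosine(0.1, 150, 50, 100, 300))  -- ~0.050 halfway through
print(lrCosine(0.1, 300, 100, 100, 300)) -- 0.000 at the final iteration
```

The annealing is smooth per mini-batch, which is why the cosine branch updates `learningRate` inside the batch loop rather than once per epoch.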