├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── checkpoints.lua ├── dataloader.lua ├── datasets │   ├── README.md │   ├── cifar10-gen.lua │   ├── cifar10.lua │   ├── cifar100-gen.lua │   ├── cifar100.lua │   ├── imagenet-gen.lua │   ├── imagenet.lua │   ├── init.lua │   └── transforms.lua ├── main.lua ├── models │   ├── init.lua │   └── resnext.lua ├── opts.lua └── train.lua /.gitignore: -------------------------------------------------------------------------------- 1 | gen/ 2 | libnccl.so 3 | model_best.t7 4 | checkpoints 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ResNeXt 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you haven't already, complete the Contributor License Agreement ("CLA"). 11 | 12 | ## Contributor License Agreement ("CLA") 13 | In order to accept your pull request, we need you to submit a CLA. You only need 14 | to do this once to work on any of Facebook's open source projects. 15 | 16 | Complete your CLA here: 17 | 18 | ## Issues 19 | We use GitHub issues to track public bugs. Please ensure your description is 20 | clear and has sufficient instructions to be able to reproduce the issue. 21 | 22 | ## Coding Style 23 | * 3 spaces for indentation rather than tabs 24 | * 80 character line length 25 | 26 | ## License 27 | By contributing to ResNeXt, you agree that your contributions will be licensed 28 | under its [BSD license](https://github.com/facebookresearch/ResNeXt/blob/master/LICENSE). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For ResNeXt software 4 | 5 | Copyright (c) 2017, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ResNeXt: Aggregated Residual Transformations for Deep Neural Networks 2 | 3 | By [Saining Xie](http://vcl.ucsd.edu/~sxie), [Ross Girshick](http://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Zhuowen Tu](http://pages.ucsd.edu/~ztu/), [Kaiming He](http://kaiminghe.com) 4 | 5 | UC San Diego, Facebook AI Research 6 | 7 | ### Table of Contents 8 | 0. [Introduction](#introduction) 9 | 0. [Citation](#citation) 10 | 0. [Requirements and Dependencies](#requirements-and-dependencies) 11 | 0. [Training](#training) 12 | 0. [ImageNet Pretrained Models](#imagenet-pretrained-models) 13 | 0. [Third-party re-implementations](#third-party-re-implementations) 14 | 15 | #### News 16 | * Congrats to the ILSVRC 2017 classification challenge winner [WMW](http://image-net.org/challenges/LSVRC/2017/results). 17 | ResNeXt is the foundation of their new SENet architecture (a **ResNeXt-152 (64 x 4d)** with the Squeeze-and-Excitation module)! 18 | * Check out Figure 6 in the new [Memory-Efficient Implementation of DenseNets](https://arxiv.org/pdf/1707.06990.pdf) paper for a comparison between ResNeXts and DenseNets. (*DenseNet cosine is DenseNet trained with a cosine learning rate schedule.*) 19 |
20 | 21 |
22 | 23 | 24 | ### Introduction 25 | 26 | This repository contains a [Torch](http://torch.ch) implementation of the [ResNeXt](https://arxiv.org/abs/1611.05431) algorithm for image classification. The code is based on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch). 27 | 28 | [ResNeXt](https://arxiv.org/abs/1611.05431) is a simple, highly modularized network architecture for image classification. Our network is constructed by repeating a building block that aggregates a set of transformations with the same topology. Our simple design results in a homogeneous, multi-branch architecture that has only a few hyper-parameters to set. This strategy exposes a new dimension, which we call “cardinality” (the size of the set of transformations), as an essential factor in addition to the dimensions of depth and width. 29 | 30 | 31 | ![teaser](http://vcl.ucsd.edu/resnext/teaser.png) 32 | ##### Figure: Training curves on ImageNet-1K. (Left): ResNet/ResNeXt-50 with the same complexity (~4.1 billion FLOPs, ~25 million parameters); (Right): ResNet/ResNeXt-101 with the same complexity (~7.8 billion FLOPs, ~44 million parameters). 33 | ----- 34 | 35 | ### Citation 36 | If you use ResNeXt in your research, please cite the paper: 37 | ``` 38 | @article{Xie2016, 39 | title={Aggregated Residual Transformations for Deep Neural Networks}, 40 | author={Saining Xie and Ross Girshick and Piotr Dollár and Zhuowen Tu and Kaiming He}, 41 | journal={arXiv preprint arXiv:1611.05431}, 42 | year={2016} 43 | } 44 | ``` 45 | 46 | ### Requirements and Dependencies 47 | See the fb.resnet.torch [installation instructions](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md) for a step-by-step guide. 48 | - Install [Torch](http://torch.ch/docs/getting-started.html) on a machine with a CUDA GPU 49 | - Install [cuDNN v4 or v5](https://developer.nvidia.com/cudnn) and the Torch [cuDNN bindings](https://github.com/soumith/cudnn.torch/tree/R4) 50 | - Download the [ImageNet](http://image-net.org/download-images) dataset and [move validation images](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset) to labeled subfolders 51 | 52 | ### Training 53 | 54 | Please follow [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) for the general usage of the code, including [how](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) to use pretrained ResNeXt models for your own task. 55 | 56 | There are two new hyperparameters that need to be specified to determine the bottleneck template: 57 | 58 | **-baseWidth** and **-cardinality** 59 | 60 | ### 1x Complexity Configurations Reference Table 61 | 62 | | baseWidth | cardinality | 63 | |---------- | ----------- | 64 | | 64 | 1 | 65 | | 40 | 2 | 66 | | 24 | 4 | 67 | | 14 | 8 | 68 | | 4 | 32 | 69 | 70 |
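Each row above is a (`-baseWidth`, `-cardinality`) pair with roughly the same FLOPs as the corresponding `1x64d` ResNet baseline. Inside `models/resnext.lua`, these two options set the width of the grouped 3x3 convolution in every bottleneck block; here is a minimal sketch of that computation (it mirrors `resnext_bottleneck_C` below, where `n` is the block's base feature count — the example numbers are for illustration):

```lua
-- Sketch: how -baseWidth and -cardinality set the grouped 3x3 conv width
-- (mirrors resnext_bottleneck_C in models/resnext.lua)
local n, baseWidth, cardinality = 64, 4, 32  -- first stage of ResNeXt-50 (32x4d)
local D = math.floor(n * (baseWidth / 64))   -- per-branch width: 4
local C = cardinality                        -- number of groups/branches: 32
print(('grouped conv width: %d (C=%d branches of width D=%d)'):format(D * C, C, D))
```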
71 | To train ResNeXt-50 (32x4d) on 8 GPUs for ImageNet: 72 | ```bash 73 | th main.lua -dataset imagenet -bottleneckType resnext_C -depth 50 -baseWidth 4 -cardinality 32 -batchSize 256 -nGPU 8 -nThreads 8 -shareGradInput true -data [imagenet-folder] 74 | ``` 75 | 76 | To reproduce the CIFAR results (e.g. ResNeXt-29 (16x64d) on CIFAR-10) on 8 GPUs: 77 | ```bash 78 | th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29 -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 128 -nGPU 8 -nThreads 8 -shareGradInput true 79 | ``` 80 | To get comparable results using 2 or 4 GPUs, you should change the batch size and the corresponding learning rate: 81 | ```bash 82 | th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29 -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 64 -nGPU 4 -LR 0.05 -nThreads 8 -shareGradInput true 83 | th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29 -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 32 -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true 84 | ``` 85 | Note: the CIFAR datasets are downloaded and processed automatically the first time the script is run. In the arXiv paper, the CIFAR results are based on pre-activated bottleneck blocks and a batch size of 256; we found that better CIFAR test accuracy can be achieved using the original bottleneck blocks and a batch size of 128. 86 | 87 | ### ImageNet Pretrained Models 88 | ImageNet pretrained models are licensed under CC BY-NC 4.0. 89 | 90 | [![CC BY-NC 4.0](https://i.creativecommons.org/l/by-nc/4.0/88x31.png)](https://creativecommons.org/licenses/by-nc/4.0/) 91 | 92 | #### Single-crop (224x224) validation error rate 93 | | Network | GFLOPS | Top-1 Error | Download | 94 | | ------------------- | ------ | ----------- | ------------| 95 | | ResNet-50 (1x64d) | ~4.1 | 23.9 | [Original ResNet-50](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) | 96 | | ResNeXt-50 (32x4d) | ~4.1 | 22.2 | [Download (191MB)](https://dl.fbaipublicfiles.com/resnext/imagenet_models/resnext_50_32x4d.t7) | 97 | | ResNet-101 (1x64d) | ~7.8 | 22.0 | [Original ResNet-101](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) | 98 | | ResNeXt-101 (32x4d) | ~7.8 | 21.2 | [Download (338MB)](https://dl.fbaipublicfiles.com/resnext/imagenet_models/resnext_101_32x4d.t7) | 99 | | ResNeXt-101 (64x4d) | ~15.6 | 20.4 | [Download (638MB)](https://dl.fbaipublicfiles.com/resnext/imagenet_models/resnext_101_64x4d.t7) | 100 |
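The checkpoints are plain Torch `.t7` files. As a quick, hypothetical smoke test of a downloaded model (assumes a CUDA machine with the cudnn bindings installed; a real pipeline would first apply the scale/crop and mean/std normalization from `datasets/transforms.lua` instead of a random input):

```lua
-- Sketch: load a pretrained checkpoint and classify one (dummy) input
require 'cudnn'
require 'cunn'

local model = torch.load('resnext_50_32x4d.t7'):cuda()
model:evaluate()  -- put batch normalization into inference mode

local input = torch.CudaTensor(1, 3, 224, 224):uniform()  -- stand-in for a preprocessed image
local scores = model:forward(input):float():squeeze()     -- 1000 raw class scores
local _, order = scores:sort(1, true)                     -- sort descending
print('predicted ImageNet class index: ' .. order[1])
```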
101 | ### Third-party re-implementations 102 | 103 | Besides our Torch implementation, we also recommend the following third-party re-implementations and extensions: 104 | 105 | 1. Training code in PyTorch [code](https://github.com/prlz77/ResNeXt.pytorch) 106 | 1. Conversion of the ImageNet pretrained models to PyTorch [code](https://github.com/clcarwin/convert_torch_to_pytorch) 107 | 1. Training code in MXNet and pretrained ImageNet models [code](https://github.com/dmlc/mxnet/tree/master/example/image-classification#imagenet-1k) 108 | 1. Caffe prototxt, pretrained ImageNet models (with ResNeXt-152), and training curves [code 1](https://github.com/cypw/ResNeXt-1), [code 2](https://github.com/terrychenism/ResNeXt) 109 | -------------------------------------------------------------------------------- /checkpoints.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | local checkpoint = {} 10 | 11 | local function deepCopy(tbl) 12 | -- creates a copy of a network with new modules and the same tensors 13 | local copy = {} 14 | for k, v in pairs(tbl) do 15 | if type(v) == 'table' then 16 | copy[k] = deepCopy(v) 17 | else 18 | copy[k] = v 19 | end 20 | end 21 | if torch.typename(tbl) then 22 | torch.setmetatable(copy, torch.typename(tbl)) 23 | end 24 | return copy 25 | end 26 | 27 | function checkpoint.latest(opt) 28 | if opt.resume == 'none' then 29 | return nil 30 | end 31 | 32 | local latestPath = paths.concat(opt.resume, 'latest.t7') 33 | if not paths.filep(latestPath) then 34 | return nil 35 | end 36 | 37 | print('=> Loading checkpoint ' .. latestPath) 38 | local latest = torch.load(latestPath) 39 | local optimState = torch.load(paths.concat(opt.resume, latest.optimFile)) 40 | 41 | return latest, optimState 42 | end 43 | 44 | function checkpoint.save(epoch, model, optimState, isBestModel, opt) 45 | -- don't save the DataParallelTable for easier loading on other machines 46 | if torch.type(model) == 'nn.DataParallelTable' then 47 | model = model:get(1) 48 | end 49 | 50 | -- create a clean copy on the CPU without modifying the original network 51 | model = deepCopy(model):float():clearState() 52 | 53 | local modelFile = 'model_' .. epoch .. '.t7' 54 | local optimFile = 'optimState_' .. epoch .. '.t7' 55 | 56 | torch.save(paths.concat(opt.save, modelFile), model) 57 | torch.save(paths.concat(opt.save, optimFile), optimState) 58 | torch.save(paths.concat(opt.save, 'latest.t7'), { 59 | epoch = epoch, 60 | modelFile = modelFile, 61 | optimFile = optimFile, 62 | }) 63 | 64 | if isBestModel then 65 | torch.save(paths.concat(opt.save, 'model_best.t7'), model) 66 | end 67 | end 68 | 69 | return checkpoint 70 | -------------------------------------------------------------------------------- /dataloader.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Multi-threaded data loader 10 | -- 11 | 12 | local datasets = require 'datasets/init' 13 | local Threads = require 'threads' 14 | Threads.serialization('threads.sharedserialize') 15 | 16 | local M = {} 17 | local DataLoader = torch.class('resnet.DataLoader', M) 18 | 19 | function DataLoader.create(opt) 20 | -- The train and val loader 21 | local loaders = {} 22 | 23 | for i, split in ipairs{'train', 'val'} do 24 | local dataset = datasets.create(opt, split) 25 | loaders[i] = M.DataLoader(dataset, opt, split) 26 | end 27 | 28 | return table.unpack(loaders) 29 | end 30 | 31 | function DataLoader:__init(dataset, opt, split) 32 | local manualSeed = opt.manualSeed 33 | local function init() 34 | require('datasets/' .. 
opt.dataset) 35 | end 36 | local function main(idx) 37 | if manualSeed ~= 0 then 38 | torch.manualSeed(manualSeed + idx) 39 | end 40 | torch.setnumthreads(1) 41 | _G.dataset = dataset 42 | _G.preprocess = dataset:preprocess() 43 | return dataset:size() 44 | end 45 | 46 | local threads, sizes = Threads(opt.nThreads, init, main) 47 | self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1 48 | self.threads = threads 49 | self.__size = sizes[1][1] 50 | self.batchSize = math.floor(opt.batchSize / self.nCrops) 51 | local function getCPUType(tensorType) 52 | if tensorType == 'torch.CudaHalfTensor' then 53 | return 'HalfTensor' 54 | elseif tensorType == 'torch.CudaDoubleTensor' then 55 | return 'DoubleTensor' 56 | else 57 | return 'FloatTensor' 58 | end 59 | end 60 | self.cpuType = getCPUType(opt.tensorType) 61 | end 62 | 63 | function DataLoader:size() 64 | return math.ceil(self.__size / self.batchSize) 65 | end 66 | 67 | function DataLoader:run() 68 | local threads = self.threads 69 | local size, batchSize = self.__size, self.batchSize 70 | local perm = torch.randperm(size) 71 | 72 | local idx, sample = 1, nil 73 | local function enqueue() 74 | while idx <= size and threads:acceptsjob() do 75 | local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1)) 76 | threads:addjob( 77 | function(indices, nCrops, cpuType) 78 | local sz = indices:size(1) 79 | local batch, imageSize 80 | local target = torch.IntTensor(sz) 81 | for i, idx in ipairs(indices:totable()) do 82 | local sample = _G.dataset:get(idx) 83 | local input = _G.preprocess(sample.input) 84 | if not batch then 85 | imageSize = input:size():totable() 86 | if nCrops > 1 then table.remove(imageSize, 1) end 87 | batch = torch[cpuType](sz, nCrops, table.unpack(imageSize)) 88 | end 89 | batch[i]:copy(input) 90 | target[i] = sample.target 91 | end 92 | collectgarbage() 93 | return { 94 | input = batch:view(sz * nCrops, table.unpack(imageSize)), 95 | target = target, 96 | } 97 | end, 98 | function(_sample_) 99 | sample = _sample_ 100 | end, 101 | indices, 102 | self.nCrops, 103 | self.cpuType 104 | ) 105 | idx = idx + batchSize 106 | end 107 | end 108 | 109 | local n = 0 110 | local function loop() 111 | enqueue() 112 | if not threads:hasjob() then 113 | return nil 114 | end 115 | threads:dojob() 116 | if threads:haserror() then 117 | threads:synchronize() 118 | end 119 | enqueue() 120 | n = n + 1 121 | return n, sample 122 | end 123 | 124 | return loop 125 | end 126 | 127 | return M.DataLoader 128 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | 3 | Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` is responsible for one-time setup, while 4 | the `dataset.lua` handles the actual data loading. 5 | 6 | If you want to be able to use a new dataset from `main.lua`, you should also modify `opts.lua` to handle the new dataset name (a sketch of this change follows the examples below). 7 | 8 | ### `dataset-gen.lua` 9 | 10 | The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data. 11 | 12 | The module should have a single function `exec(opt, cacheFile)`. 13 | - `opt`: the command line options 14 | - `cacheFile`: path to output 15 | 16 | ```lua 17 | local M = {} 18 | function M.exec(opt, cacheFile) 19 | local imageInfo = {} 20 | -- preprocess dataset, store results in imageInfo, save to cacheFile 21 | torch.save(cacheFile, imageInfo) 22 | end 23 | return M 24 | ``` 25 | 26 | ### `dataset.lua` 27 | 28 | The `dataset.lua` should return a class that implements three functions: 29 | - `get(i)`: returns a table containing two entries, `input` and `target` 30 | - `input`: the training or validation image as a Torch tensor 31 | - `target`: the image category as a number 1-N 32 | - `size()`: returns the number of entries in the dataset 33 | - `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization 34 | 35 | ```lua 36 | local t = require 'datasets/transforms' -- needed for the preprocess() example below 37 | local M = {} 38 | local FakeDataset = torch.class('resnet.FakeDataset', M) 39 | 40 | function FakeDataset:__init(imageInfo, opt, split) 41 | -- imageInfo: result from dataset-gen.lua 42 | -- opt: command-line arguments 43 | -- split: "train" or "val" 44 | end 45 | 46 | function FakeDataset:get(i) 47 | return { 48 | input = torch.Tensor(3, 800, 600):uniform(), 49 | target = 42, 50 | } 51 | end 52 | 53 | function FakeDataset:size() 54 | -- size of dataset 55 | return 2000 56 | end 57 | 58 | function FakeDataset:preprocess() 59 | -- Scale smaller side to 256 and take 224x224 center-crop 60 | return t.Compose{ 61 | t.Scale(256), 62 | t.CenterCrop(224), 63 | } 64 | end 65 | 66 | return M.FakeDataset 67 | ```
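The `opts.lua` change mentioned above is small. A hypothetical fragment for a dataset named `mydataset` (the name and defaults are placeholders), added to the dataset `if`/`elseif` chain in `M.parse`:

```lua
-- Hypothetical: registering 'mydataset' in opts.lua (fragment of the existing chain)
elseif opt.dataset == 'mydataset' then
   -- choose whatever defaults suit the new data
   opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType
   opt.nEpochs = opt.nEpochs == 0 and 300 or opt.nEpochs
```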
68 | -------------------------------------------------------------------------------- /datasets/cifar10-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Script to download the CIFAR-10 dataset and generate its cache file 10 | -- 11 | -- This automatically downloads the CIFAR-10 dataset from 12 | -- http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz 13 | -- 14 | 15 | local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz' 16 | 17 | local M = {} 18 | 19 | local function convertToTensor(files) 20 | local data, labels 21 | 22 | for _, file in ipairs(files) do 23 | local m = torch.load(file, 'ascii') 24 | if not data then 25 | data = m.data:t() 26 | labels = m.labels:squeeze() 27 | else 28 | data = torch.cat(data, m.data:t(), 1) 29 | labels = torch.cat(labels, m.labels:squeeze()) 30 | end 31 | end 32 | 33 | -- This is *very* important. The downloaded files have labels 0-9, which do 34 | -- not work with CrossEntropyCriterion 35 | labels:add(1) 36 | 37 | return { 38 | data = data:contiguous():view(-1, 3, 32, 32), 39 | labels = labels, 40 | } 41 | end 42 | 43 | function M.exec(opt, cacheFile) 44 | print("=> Downloading CIFAR-10 dataset from " .. URL) 45 | local ok = os.execute('curl ' .. URL .. 
' | tar xz -C gen/') 46 | assert(ok == true or ok == 0, 'error downloading CIFAR-10') 47 | 48 | print(" | combining dataset into a single file") 49 | local trainData = convertToTensor({ 50 | 'gen/cifar-10-batches-t7/data_batch_1.t7', 51 | 'gen/cifar-10-batches-t7/data_batch_2.t7', 52 | 'gen/cifar-10-batches-t7/data_batch_3.t7', 53 | 'gen/cifar-10-batches-t7/data_batch_4.t7', 54 | 'gen/cifar-10-batches-t7/data_batch_5.t7', 55 | }) 56 | local testData = convertToTensor({ 57 | 'gen/cifar-10-batches-t7/test_batch.t7', 58 | }) 59 | 60 | print(" | saving CIFAR-10 dataset to " .. cacheFile) 61 | torch.save(cacheFile, { 62 | train = trainData, 63 | val = testData, 64 | }) 65 | end 66 | 67 | return M 68 | -------------------------------------------------------------------------------- /datasets/cifar10.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- CIFAR-10 dataset loader 10 | -- 11 | 12 | local t = require 'datasets/transforms' 13 | 14 | local M = {} 15 | local CifarDataset = torch.class('resnet.CifarDataset', M) 16 | 17 | function CifarDataset:__init(imageInfo, opt, split) 18 | assert(imageInfo[split], split) 19 | self.imageInfo = imageInfo[split] 20 | self.split = split 21 | end 22 | 23 | function CifarDataset:get(i) 24 | local image = self.imageInfo.data[i]:float() 25 | local label = self.imageInfo.labels[i] 26 | 27 | return { 28 | input = image, 29 | target = label, 30 | } 31 | end 32 | 33 | function CifarDataset:size() 34 | return self.imageInfo.data:size(1) 35 | end 36 | 37 | -- Computed from entire CIFAR-10 training set 38 | local meanstd = { 39 | mean = {125.3, 123.0, 113.9}, 40 | std = {63.0, 62.1, 66.7}, 41 | } 42 | 43 | function CifarDataset:preprocess() 44 | if self.split == 'train' then 45 | return t.Compose{ 46 | t.ColorNormalize(meanstd), 47 | t.HorizontalFlip(0.5), 48 | t.RandomCrop(32, 4), 49 | } 50 | elseif self.split == 'val' then 51 | return t.ColorNormalize(meanstd) 52 | else 53 | error('invalid split: ' .. self.split) 54 | end 55 | end 56 | 57 | return M.CifarDataset 58 | -------------------------------------------------------------------------------- /datasets/cifar100-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | 10 | ------------ 11 | -- This file automatically downloads the CIFAR-100 dataset from 12 | -- http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz 13 | -- It is based on cifar10-gen.lua 14 | -- Ludovic Trottier 15 | ------------ 16 | 17 | local URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz' 18 | 19 | local M = {} 20 | 21 | local function convertCifar100BinToTorchTensor(inputFname) 22 | local m=torch.DiskFile(inputFname, 'r'):binary() 23 | m:seekEnd() 24 | local length = m:position() - 1 25 | local nSamples = length / 3074 -- 1 coarse-label byte, 1 fine-label byte, 3072 pixel bytes 26 | 27 | assert(nSamples == math.floor(nSamples), 'expecting numSamples to be an exact integer') 28 | m:seek(1) 29 | 30 | local coarse = torch.ByteTensor(nSamples) 31 | local fine = torch.ByteTensor(nSamples) 32 | local data = torch.ByteTensor(nSamples, 3, 32, 32) 33 | for i=1,nSamples do 34 | coarse[i] = m:readByte() 35 | fine[i] = m:readByte() 36 | local store = m:readByte(3072) 37 | data[i]:copy(torch.ByteTensor(store)) 38 | end 39 | 40 | local out = {} 41 | out.data = data 42 | -- This is *very* important. The downloaded files have fine labels 0-99, which do 43 | -- not work with CrossEntropyCriterion 44 | out.labels = fine + 1 45 | 46 | return out 47 | end 48 | 49 | function M.exec(opt, cacheFile) 50 | print("=> Downloading CIFAR-100 dataset from " .. URL) 51 | 52 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 53 | assert(ok == true or ok == 0, 'error downloading CIFAR-100') 54 | 55 | print(" | combining dataset into a single file") 56 | 57 | local trainData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/train.bin') 58 | local testData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/test.bin') 59 | 60 | print(" | saving CIFAR-100 dataset to " .. cacheFile) 61 | torch.save(cacheFile, { 62 | train = trainData, 63 | val = testData, 64 | }) 65 | end 66 | 67 | return M 68 | -------------------------------------------------------------------------------- /datasets/cifar100.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | ------------ 11 | -- This file loads and transforms the CIFAR-100 dataset. 
12 | -- It is based on cifar10.lua 13 | -- Ludovic Trottier 14 | ------------ 15 | 16 | local t = require 'datasets/transforms' 17 | 18 | local M = {} 19 | local CifarDataset = torch.class('resnet.CifarDataset', M) 20 | 21 | function CifarDataset:__init(imageInfo, opt, split) 22 | assert(imageInfo[split], split) 23 | self.imageInfo = imageInfo[split] 24 | self.split = split 25 | end 26 | 27 | function CifarDataset:get(i) 28 | local image = self.imageInfo.data[i]:float() 29 | local label = self.imageInfo.labels[i] 30 | 31 | return { 32 | input = image, 33 | target = label, 34 | } 35 | end 36 | 37 | function CifarDataset:size() 38 | return self.imageInfo.data:size(1) 39 | end 40 | 41 | 42 | -- Computed from entire CIFAR-100 training set with this code: 43 | -- dataset = torch.load('cifar100.t7') 44 | -- tt = dataset.train.data:double(); 45 | -- tt = tt:transpose(2,4); 46 | -- tt = tt:reshape(50000*32*32, 3); 47 | -- tt:mean(1) 48 | -- tt:std(1) 49 | local meanstd = { 50 | mean = {129.3, 124.1, 112.4}, 51 | std = {68.2, 65.4, 70.4}, 52 | } 53 | 54 | function CifarDataset:preprocess() 55 | if self.split == 'train' then 56 | return t.Compose{ 57 | t.ColorNormalize(meanstd), 58 | t.HorizontalFlip(0.5), 59 | t.RandomCrop(32, 4), 60 | } 61 | elseif self.split == 'val' then 62 | return t.ColorNormalize(meanstd) 63 | else 64 | error('invalid split: ' .. self.split) 65 | end 66 | end 67 | 68 | return M.CifarDataset 69 | -------------------------------------------------------------------------------- /datasets/imagenet-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Script to compute list of ImageNet filenames and classes 10 | -- 11 | -- This generates a file gen/imagenet.t7 which contains the list of all 12 | -- ImageNet training and validation images and their classes. This script also 13 | -- works for other datasets arranged with the same layout. 14 | -- 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findClasses(dir) 22 | local dirs = paths.dir(dir) 23 | table.sort(dirs) 24 | 25 | local classList = {} 26 | local classToIdx = {} 27 | for _ ,class in ipairs(dirs) do 28 | if not classToIdx[class] and class ~= '.' and class ~= '..' and class ~= '.DS_Store' then 29 | table.insert(classList, class) 30 | classToIdx[class] = #classList 31 | end 32 | end 33 | 34 | -- assert(#classList == 1000, 'expected 1000 ImageNet classes') 35 | return classList, classToIdx 36 | end 37 | 38 | local function findImages(dir, classToIdx) 39 | local imagePath = torch.CharTensor() 40 | local imageClass = torch.LongTensor() 41 | 42 | ---------------------------------------------------------------------- 43 | -- Options for the GNU and BSD find command 44 | local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 45 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 46 | for i=2,#extensionList do 47 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 48 | end 49 | 50 | -- Find all the images using the find command 51 | local f = io.popen('find -L ' .. dir .. 
findOptions) 52 | 53 | local maxLength = -1 54 | local imagePaths = {} 55 | local imageClasses = {} 56 | 57 | -- Generate a list of all the images and their class 58 | while true do 59 | local line = f:read('*line') 60 | if not line then break end 61 | 62 | local className = paths.basename(paths.dirname(line)) 63 | local filename = paths.basename(line) 64 | local path = className .. '/' .. filename 65 | 66 | local classId = classToIdx[className] 67 | assert(classId, 'class not found: ' .. className) 68 | 69 | table.insert(imagePaths, path) 70 | table.insert(imageClasses, classId) 71 | 72 | maxLength = math.max(maxLength, #path + 1) 73 | end 74 | 75 | f:close() 76 | 77 | -- Convert the generated list to a tensor for faster loading 78 | local nImages = #imagePaths 79 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 80 | for i, path in ipairs(imagePaths) do 81 | ffi.copy(imagePath[i]:data(), path) 82 | end 83 | 84 | local imageClass = torch.LongTensor(imageClasses) 85 | return imagePath, imageClass 86 | end 87 | 88 | function M.exec(opt, cacheFile) 89 | -- find the image path names 90 | local imagePath = torch.CharTensor() -- path to each image in dataset 91 | local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 92 | 93 | local trainDir = paths.concat(opt.data, 'train') 94 | local valDir = paths.concat(opt.data, 'val') 95 | assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir) 96 | assert(paths.dirp(valDir), 'val directory not found: ' .. valDir) 97 | 98 | print("=> Generating list of images") 99 | local classList, classToIdx = findClasses(trainDir) 100 | 101 | print(" | finding all validation images") 102 | local valImagePath, valImageClass = findImages(valDir, classToIdx) 103 | 104 | print(" | finding all training images") 105 | local trainImagePath, trainImageClass = findImages(trainDir, classToIdx) 106 | 107 | local info = { 108 | basedir = opt.data, 109 | classList = classList, 110 | train = { 111 | imagePath = trainImagePath, 112 | imageClass = trainImageClass, 113 | }, 114 | val = { 115 | imagePath = valImagePath, 116 | imageClass = valImageClass, 117 | }, 118 | } 119 | 120 | print(" | saving list of images to " .. cacheFile) 121 | torch.save(cacheFile, info) 122 | return info 123 | end 124 | 125 | return M 126 | -------------------------------------------------------------------------------- /datasets/imagenet.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- ImageNet dataset loader 10 | -- 11 | 12 | local image = require 'image' 13 | local paths = require 'paths' 14 | local t = require 'datasets/transforms' 15 | local ffi = require 'ffi' 16 | 17 | local M = {} 18 | local ImagenetDataset = torch.class('resnet.ImagenetDataset', M) 19 | 20 | function ImagenetDataset:__init(imageInfo, opt, split) 21 | self.imageInfo = imageInfo[split] 22 | self.opt = opt 23 | self.split = split 24 | self.dir = paths.concat(opt.data, split) 25 | assert(paths.dirp(self.dir), 'directory does not exist: ' .. 
self.dir) 26 | end 27 | 28 | function ImagenetDataset:get(i) 29 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 30 | 31 | local image = self:_loadImage(paths.concat(self.dir, path)) 32 | local class = self.imageInfo.imageClass[i] 33 | 34 | return { 35 | input = image, 36 | target = class, 37 | } 38 | end 39 | 40 | function ImagenetDataset:_loadImage(path) 41 | local ok, input = pcall(function() 42 | return image.load(path, 3, 'float') 43 | end) 44 | 45 | -- Sometimes image.load fails because the file extension does not match the 46 | -- image format. In that case, use image.decompress on a ByteTensor. 47 | if not ok then 48 | local f = io.open(path, 'r') 49 | assert(f, 'Error reading: ' .. tostring(path)) 50 | local data = f:read('*a') 51 | f:close() 52 | 53 | local b = torch.ByteTensor(string.len(data)) 54 | ffi.copy(b:data(), data, b:size(1)) 55 | 56 | input = image.decompress(b, 3, 'float') 57 | end 58 | 59 | return input 60 | end 61 | 62 | function ImagenetDataset:size() 63 | return self.imageInfo.imageClass:size(1) 64 | end 65 | 66 | -- Computed from random subset of ImageNet training images 67 | local meanstd = { 68 | mean = { 0.485, 0.456, 0.406 }, 69 | std = { 0.229, 0.224, 0.225 }, 70 | } 71 | local pca = { 72 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 73 | eigvec = torch.Tensor{ 74 | { -0.5675, 0.7192, 0.4009 }, 75 | { -0.5808, -0.0045, -0.8140 }, 76 | { -0.5836, -0.6948, 0.4203 }, 77 | }, 78 | } 79 | 80 | function ImagenetDataset:preprocess() 81 | if self.split == 'train' then 82 | return t.Compose{ 83 | t.RandomSizedCrop(224), 84 | t.ColorJitter({ 85 | brightness = 0.4, 86 | contrast = 0.4, 87 | saturation = 0.4, 88 | }), 89 | t.Lighting(0.1, pca.eigval, pca.eigvec), 90 | t.ColorNormalize(meanstd), 91 | t.HorizontalFlip(0.5), 92 | } 93 | elseif self.split == 'val' then 94 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 95 | return t.Compose{ 96 | t.Scale(256), 97 | t.ColorNormalize(meanstd), 98 | Crop(224), 99 | } 100 | else 101 | error('invalid split: ' .. self.split) 102 | end 103 | end 104 | 105 | return M.ImagenetDataset 106 | -------------------------------------------------------------------------------- /datasets/init.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- ImageNet and CIFAR-10 datasets 10 | -- 11 | 12 | local M = {} 13 | 14 | local function isvalid(opt, cachePath) 15 | local imageInfo = torch.load(cachePath) 16 | if imageInfo.basedir and imageInfo.basedir ~= opt.data then 17 | return false 18 | end 19 | return true 20 | end 21 | 22 | function M.create(opt, split) 23 | local cachePath = paths.concat(opt.gen, opt.dataset .. '.t7') 24 | if not paths.filep(cachePath) or not isvalid(opt, cachePath) then 25 | paths.mkdir('gen') 26 | 27 | local script = paths.dofile(opt.dataset .. '-gen.lua') 28 | script.exec(opt, cachePath) 29 | end 30 | local imageInfo = torch.load(cachePath) 31 | 32 | local Dataset = require('datasets/' .. 
opt.dataset) 33 | return Dataset(imageInfo, opt, split) 34 | end 35 | 36 | return M 37 | -------------------------------------------------------------------------------- /datasets/transforms.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Image transforms for data augmentation and input normalization 10 | -- 11 | 12 | require 'image' 13 | 14 | local M = {} 15 | 16 | function M.Compose(transforms) 17 | return function(input) 18 | for _, transform in ipairs(transforms) do 19 | input = transform(input) 20 | end 21 | return input 22 | end 23 | end 24 | 25 | function M.ColorNormalize(meanstd) 26 | return function(img) 27 | img = img:clone() 28 | for i=1,3 do 29 | img[i]:add(-meanstd.mean[i]) 30 | img[i]:div(meanstd.std[i]) 31 | end 32 | return img 33 | end 34 | end 35 | 36 | -- Scales the smaller edge to size 37 | function M.Scale(size, interpolation) 38 | interpolation = interpolation or 'bicubic' 39 | return function(input) 40 | local w, h = input:size(3), input:size(2) 41 | if (w <= h and w == size) or (h <= w and h == size) then 42 | return input 43 | end 44 | if w < h then 45 | return image.scale(input, size, h/w * size, interpolation) 46 | else 47 | return image.scale(input, w/h * size, size, interpolation) 48 | end 49 | end 50 | end 51 | 52 | -- Crop to centered rectangle 53 | function M.CenterCrop(size) 54 | return function(input) 55 | local w1 = math.ceil((input:size(3) - size)/2) 56 | local h1 = math.ceil((input:size(2) - size)/2) 57 | return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch 58 | end 59 | end 60 | 61 | -- Random crop from larger image with optional zero padding 62 | function M.RandomCrop(size, padding) 63 | padding = padding or 0 64 | 65 | return function(input) 66 | if padding > 0 then 67 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) 68 | temp:zero() 69 | :narrow(2, padding+1, input:size(2)) 70 | :narrow(3, padding+1, input:size(3)) 71 | :copy(input) 72 | input = temp 73 | end 74 | 75 | local w, h = input:size(3), input:size(2) 76 | if w == size and h == size then 77 | return input 78 | end 79 | 80 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 81 | local out = image.crop(input, x1, y1, x1 + size, y1 + size) 82 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 83 | return out 84 | end 85 | end 86 | 87 | -- Four corner patches and center crop from image and its horizontal reflection 88 | function M.TenCrop(size) 89 | local centerCrop = M.CenterCrop(size) 90 | 91 | return function(input) 92 | local w, h = input:size(3), input:size(2) 93 | 94 | local output = {} 95 | for _, img in ipairs{input, image.hflip(input)} do 96 | table.insert(output, centerCrop(img)) 97 | table.insert(output, image.crop(img, 0, 0, size, size)) 98 | table.insert(output, image.crop(img, w-size, 0, w, size)) 99 | table.insert(output, image.crop(img, 0, h-size, size, h)) 100 | table.insert(output, image.crop(img, w-size, h-size, w, h)) 101 | end 102 | 103 | -- View as mini-batch 104 | for i, img in ipairs(output) do 105 | output[i] = img:view(1, img:size(1), img:size(2), img:size(3)) 106 | end 107 | 108 | return input.cat(output, 1) 
109 | end 110 | end 111 | 112 | -- Resized with shorter side randomly sampled from [minSize, maxSize] (ResNet-style) 113 | function M.RandomScale(minSize, maxSize) 114 | return function(input) 115 | local w, h = input:size(3), input:size(2) 116 | 117 | local targetSz = torch.random(minSize, maxSize) 118 | local targetW, targetH = targetSz, targetSz 119 | if w < h then 120 | targetH = torch.round(h / w * targetW) 121 | else 122 | targetW = torch.round(w / h * targetH) 123 | end 124 | 125 | return image.scale(input, targetW, targetH, 'bicubic') 126 | end 127 | end 128 | 129 | -- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style) 130 | function M.RandomSizedCrop(size) 131 | local scale = M.Scale(size) 132 | local crop = M.CenterCrop(size) 133 | 134 | return function(input) 135 | local attempt = 0 136 | repeat 137 | local area = input:size(2) * input:size(3) 138 | local targetArea = torch.uniform(0.08, 1.0) * area 139 | 140 | local aspectRatio = torch.uniform(3/4, 4/3) 141 | local w = torch.round(math.sqrt(targetArea * aspectRatio)) 142 | local h = torch.round(math.sqrt(targetArea / aspectRatio)) 143 | 144 | if torch.uniform() < 0.5 then 145 | w, h = h, w 146 | end 147 | 148 | if h <= input:size(2) and w <= input:size(3) then 149 | local y1 = torch.random(0, input:size(2) - h) 150 | local x1 = torch.random(0, input:size(3) - w) 151 | 152 | local out = image.crop(input, x1, y1, x1 + w, y1 + h) 153 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size') 154 | 155 | return image.scale(out, size, size, 'bicubic') 156 | end 157 | attempt = attempt + 1 158 | until attempt >= 10 159 | 160 | -- fallback 161 | return crop(scale(input)) 162 | end 163 | end 164 | 165 | function M.HorizontalFlip(prob) 166 | return function(input) 167 | if torch.uniform() < prob then 168 | input = image.hflip(input) 169 | end 170 | return input 171 | end 172 | end 173 | 174 | function M.Rotation(deg) 175 | return function(input) 176 | if deg ~= 0 then 177 | input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear') 178 | end 179 | return input 180 | end 181 | end 182 | 183 | -- Lighting noise (AlexNet-style PCA-based noise) 184 | function M.Lighting(alphastd, eigval, eigvec) 185 | return function(input) 186 | if alphastd == 0 then 187 | return input 188 | end 189 | 190 | local alpha = torch.Tensor(3):normal(0, alphastd) 191 | local rgb = eigvec:clone() 192 | :cmul(alpha:view(1, 3):expand(3, 3)) 193 | :cmul(eigval:view(1, 3):expand(3, 3)) 194 | :sum(2) 195 | :squeeze() 196 | 197 | input = input:clone() 198 | for i=1,3 do 199 | input[i]:add(rgb[i]) 200 | end 201 | return input 202 | end 203 | end 204 | 205 | local function blend(img1, img2, alpha) 206 | return img1:mul(alpha):add(1 - alpha, img2) 207 | end 208 | 209 | local function grayscale(dst, img) 210 | dst:resizeAs(img) 211 | dst[1]:zero() 212 | dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3]) 213 | dst[2]:copy(dst[1]) 214 | dst[3]:copy(dst[1]) 215 | return dst 216 | end 217 | 218 | function M.Saturation(var) 219 | local gs 220 | 221 | return function(input) 222 | gs = gs or input.new() 223 | grayscale(gs, input) 224 | 225 | local alpha = 1.0 + torch.uniform(-var, var) 226 | blend(input, gs, alpha) 227 | return input 228 | end 229 | end 230 | 231 | function M.Brightness(var) 232 | local gs 233 | 234 | return function(input) 235 | gs = gs or input.new() 236 | gs:resizeAs(input):zero() 237 | 238 | local alpha = 1.0 + torch.uniform(-var, var) 239 | blend(input, gs, alpha) 240 | return input 241 
| end 242 | end 243 | 244 | function M.Contrast(var) 245 | local gs 246 | 247 | return function(input) 248 | gs = gs or input.new() 249 | grayscale(gs, input) 250 | gs:fill(gs[1]:mean()) 251 | 252 | local alpha = 1.0 + torch.uniform(-var, var) 253 | blend(input, gs, alpha) 254 | return input 255 | end 256 | end 257 | 258 | function M.RandomOrder(ts) 259 | return function(input) 260 | local img = input.img or input 261 | local order = torch.randperm(#ts) 262 | for i=1,#ts do 263 | img = ts[order[i]](img) 264 | end 265 | return img 266 | end 267 | end 268 | 269 | function M.ColorJitter(opt) 270 | local brightness = opt.brightness or 0 271 | local contrast = opt.contrast or 0 272 | local saturation = opt.saturation or 0 273 | 274 | local ts = {} 275 | if brightness ~= 0 then 276 | table.insert(ts, M.Brightness(brightness)) 277 | end 278 | if contrast ~= 0 then 279 | table.insert(ts, M.Contrast(contrast)) 280 | end 281 | if saturation ~= 0 then 282 | table.insert(ts, M.Saturation(saturation)) 283 | end 284 | 285 | if #ts == 0 then 286 | return function(input) return input end 287 | end 288 | 289 | return M.RandomOrder(ts) 290 | end 291 | 292 | return M 293 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | require 'torch' 10 | require 'paths' 11 | require 'optim' 12 | require 'nn' 13 | local DataLoader = require 'dataloader' 14 | local models = require 'models/init' 15 | local Trainer = require 'train' 16 | local opts = require 'opts' 17 | local checkpoints = require 'checkpoints' 18 | 19 | -- we don't change this to the 'correct' type (e.g. HalfTensor), because math 20 | -- isn't supported on that type. Type conversion later will handle having 21 | -- the correct type. 
22 | torch.setdefaulttensortype('torch.FloatTensor') 23 | torch.setnumthreads(1) 24 | 25 | local opt = opts.parse(arg) 26 | torch.manualSeed(opt.manualSeed) 27 | cutorch.manualSeedAll(opt.manualSeed) 28 | 29 | -- Load previous checkpoint, if it exists 30 | local checkpoint, optimState = checkpoints.latest(opt) 31 | 32 | -- Create model 33 | local model, criterion = models.setup(opt, checkpoint) 34 | 35 | -- Data loading 36 | local trainLoader, valLoader = DataLoader.create(opt) 37 | 38 | -- The trainer handles the training loop and evaluation on validation set 39 | local trainer = Trainer(model, criterion, opt, optimState) 40 | 41 | if opt.testOnly then 42 | local top1Err, top5Err = trainer:test(0, valLoader) 43 | print(string.format(' * Results top1: %6.3f top5: %6.3f', top1Err, top5Err)) 44 | return 45 | end 46 | 47 | local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber 48 | local bestTop1 = math.huge 49 | local bestTop5 = math.huge 50 | for epoch = startEpoch, opt.nEpochs do 51 | -- Train for a single epoch 52 | local trainTop1, trainTop5, trainLoss = trainer:train(epoch, trainLoader) 53 | 54 | -- Run model on validation set 55 | local testTop1, testTop5 = trainer:test(epoch, valLoader) 56 | 57 | local bestModel = false 58 | if testTop1 < bestTop1 then 59 | bestModel = true 60 | bestTop1 = testTop1 61 | bestTop5 = testTop5 62 | print(' * Best model ', testTop1, testTop5) 63 | end 64 | 65 | checkpoints.save(epoch, model, trainer.optimState, bestModel, opt) 66 | end 67 | 68 | print(string.format(' * Finished top1: %6.3f top5: %6.3f', bestTop1, bestTop5)) 69 | -------------------------------------------------------------------------------- /models/init.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Generic model creating code. For the specific ResNeXt model see 10 | -- models/resnext.lua 11 | -- 12 | 13 | require 'nn' 14 | require 'cunn' 15 | require 'cudnn' 16 | 17 | local M = {} 18 | 19 | function M.setup(opt, checkpoint) 20 | local model 21 | if checkpoint then 22 | local modelPath = paths.concat(opt.resume, checkpoint.modelFile) 23 | assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath) 24 | print('=> Resuming model from ' .. modelPath) 25 | model = torch.load(modelPath):type(opt.tensorType) 26 | elseif opt.retrain ~= 'none' then 27 | assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain) 28 | print('Loading model from file: ' .. opt.retrain) 29 | model = torch.load(opt.retrain):type(opt.tensorType) 30 | model.__memoryOptimized = nil 31 | else 32 | print('=> Creating model from file: models/' .. opt.netType .. '.lua') 33 | model = require('models/' .. 
opt.netType)(opt) 34 | end 35 | 36 | -- First remove any DataParallelTable 37 | if torch.type(model) == 'nn.DataParallelTable' then 38 | model = model:get(1) 39 | end 40 | 41 | -- optnet is a general library for reducing memory usage in neural networks 42 | if opt.optnet then 43 | local optnet = require 'optnet' 44 | local imsize = opt.dataset == 'imagenet' and 224 or 32 45 | local sampleInput = torch.zeros(4,3,imsize,imsize):type(opt.tensorType) 46 | optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'}) 47 | end 48 | 49 | -- This is useful for fitting ResNet/ResNeXt-50 on 4 GPUs, but requires that all 50 | -- containers override backwards to call backwards recursively on submodules 51 | if opt.shareGradInput then 52 | M.shareGradInput(model, opt) 53 | end 54 | 55 | -- For resetting the classifier when fine-tuning on a different dataset 56 | if opt.resetClassifier and not checkpoint then 57 | print(' => Replacing classifier with ' .. opt.nClasses .. '-way classifier') 58 | 59 | local orig = model:get(#model.modules) 60 | assert(torch.type(orig) == 'nn.Linear', 61 | 'expected last layer to be fully connected') 62 | 63 | local linear = nn.Linear(orig.weight:size(2), opt.nClasses) 64 | linear.bias:zero() 65 | 66 | model:remove(#model.modules) 67 | model:add(linear:type(opt.tensorType)) 68 | end 69 | 70 | -- Set the CUDNN flags 71 | if opt.cudnn == 'fastest' then 72 | cudnn.fastest = true 73 | cudnn.benchmark = true 74 | elseif opt.cudnn == 'deterministic' then 75 | -- Use a deterministic convolution implementation 76 | model:apply(function(m) 77 | if m.setMode then m:setMode(1, 1, 1) end 78 | end) 79 | end 80 | 81 | -- Wrap the model with DataParallelTable, if using more than one GPU 82 | if opt.nGPU > 1 then 83 | local gpus = torch.range(1, opt.nGPU):totable() 84 | local fastest, benchmark = cudnn.fastest, cudnn.benchmark 85 | 86 | local dpt = nn.DataParallelTable(1, true, true) 87 | :add(model, gpus) 88 | :threads(function() 89 | local cudnn = require 'cudnn' 90 | cudnn.fastest, cudnn.benchmark = fastest, benchmark 91 | end) 92 | dpt.gradInput = nil 93 | 94 | model = dpt:type(opt.tensorType) 95 | end 96 | 97 | local criterion = nn.CrossEntropyCriterion():type(opt.tensorType) 98 | return model, criterion 99 | end 100 | 101 | function M.shareGradInput(model, opt) 102 | print('shareGradInput function') 103 | local function sharingKey(m) 104 | local key = torch.type(m) 105 | if m.__shareGradInputKey then 106 | key = key .. ':' .. m.__shareGradInputKey 107 | end 108 | return key 109 | end 110 | 111 | -- Share gradInput for memory efficient backprop 112 | local cache = {} 113 | local input_sign_cache = {} 114 | model:apply(function(m) 115 | local moduleType = torch.type(m) 116 | if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' and moduleType ~= 'nn.JoinTable' then 117 | local key = sharingKey(m) 118 | if cache[key] == nil then 119 | cache[key] = torch.CudaStorage(1) 120 | end 121 | m.gradInput = torch.CudaTensor(cache[key], 1, 0) 122 | end 123 | end) 124 | for i, m in ipairs(model:findModules('nn.ConcatTable')) do 125 | if cache[i % 3] == nil then 126 | cache[i % 3] = torch.CudaStorage(1) 127 | end 128 | m.gradInput = torch.CudaTensor(cache[i % 3], 1, 0) 129 | end 130 | end 131 | 132 | return M 133 | -------------------------------------------------------------------------------- /models/resnext.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2017, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- The ResNeXt model definition 10 | -- 11 | 12 | local nn = require 'nn' 13 | require 'cunn' 14 | 15 | local Convolution = cudnn.SpatialConvolution 16 | local Avg = cudnn.SpatialAveragePooling 17 | local ReLU = cudnn.ReLU 18 | local Max = nn.SpatialMaxPooling 19 | local SBatchNorm = nn.SpatialBatchNormalization 20 | 21 | local function createModel(opt) 22 | local depth = opt.depth 23 | local shortcutType = opt.shortcutType or 'B' 24 | local iChannels 25 | 26 | -- The shortcut layer is either identity or 1x1 convolution 27 | local function shortcut(nInputPlane, nOutputPlane, stride) 28 | local useConv = shortcutType == 'C' or 29 | (shortcutType == 'B' and nInputPlane ~= nOutputPlane) 30 | if useConv then 31 | -- 1x1 convolution 32 | return nn.Sequential() 33 | :add(Convolution(nInputPlane, nOutputPlane, 1, 1, stride, stride)) 34 | :add(SBatchNorm(nOutputPlane)) 35 | elseif nInputPlane ~= nOutputPlane then 36 | -- Strided, zero-padded identity shortcut 37 | return nn.Sequential() 38 | :add(nn.SpatialAveragePooling(1, 1, stride, stride)) 39 | :add(nn.Concat(2) 40 | :add(nn.Identity()) 41 | :add(nn.MulConstant(0))) 42 | else 43 | return nn.Identity() 44 | end 45 | end 46 | 47 | -- The original bottleneck residual layer 48 | local function resnet_bottleneck(n, stride) 49 | local nInputPlane = iChannels 50 | iChannels = n * 4 51 | 52 | local s = nn.Sequential() 53 | s:add(Convolution(nInputPlane,n,1,1,1,1,0,0)) 54 | s:add(SBatchNorm(n)) 55 | s:add(ReLU(true)) 56 | s:add(Convolution(n,n,3,3,stride,stride,1,1)) 57 | s:add(SBatchNorm(n)) 58 | s:add(ReLU(true)) 59 | s:add(Convolution(n,n*4,1,1,1,1,0,0)) 60 | s:add(SBatchNorm(n * 4)) 61 | 62 | return nn.Sequential() 63 | :add(nn.ConcatTable() 64 | :add(s) 65 | :add(shortcut(nInputPlane, n * 4, stride))) 66 | :add(nn.CAddTable(true)) 67 | :add(ReLU(true)) 68 | end 69 | 70 | -- The aggregated residual transformation bottleneck layer, Form (B) 71 | local function split(nInputPlane, d, c, stride) 72 | local cat = nn.ConcatTable() 73 | for i=1,c do 74 | local s = nn.Sequential() 75 | s:add(Convolution(nInputPlane,d,1,1,1,1,0,0)) 76 | s:add(SBatchNorm(d)) 77 | s:add(ReLU(true)) 78 | s:add(Convolution(d,d,3,3,stride,stride,1,1)) 79 | s:add(SBatchNorm(d)) 80 | s:add(ReLU(true)) 81 | cat:add(s) 82 | end 83 | return cat 84 | end 85 | 86 | local function resnext_bottleneck_B(n, stride) 87 | local nInputPlane = iChannels 88 | iChannels = n * 4 89 | 90 | local D = math.floor(n * (opt.baseWidth/64)) 91 | local C = opt.cardinality 92 | 93 | local s = nn.Sequential() 94 | s:add(split(nInputPlane, D, C, stride)) 95 | s:add(nn.JoinTable(2)) 96 | s:add(Convolution(D*C,n*4,1,1,1,1,0,0)) 97 | s:add(SBatchNorm(n*4)) 98 | 99 | return nn.Sequential() 100 | :add(nn.ConcatTable() 101 | :add(s) 102 | :add(shortcut(nInputPlane, n * 4, stride))) 103 | :add(nn.CAddTable(true)) 104 | :add(ReLU(true)) 105 | end 106 | 107 | -- The aggregated residual transformation bottleneck layer, Form (C) 108 | local function resnext_bottleneck_C(n, stride) 109 | local nInputPlane = iChannels 110 | iChannels = n * 4 111 | 112 | local D = math.floor(n * (opt.baseWidth/64)) 113 | local C = opt.cardinality 114 | 115 | local s = nn.Sequential() 116 | s:add(Convolution(nInputPlane,D*C,1,1,1,1,0,0)) 117 | s:add(SBatchNorm(D*C)) 118 | 
s:add(ReLU(true)) 119 | s:add(Convolution(D*C,D*C,3,3,stride,stride,1,1,C)) 120 | s:add(SBatchNorm(D*C)) 121 | s:add(ReLU(true)) 122 | s:add(Convolution(D*C,n*4,1,1,1,1,0,0)) 123 | s:add(SBatchNorm(n*4)) 124 | 125 | return nn.Sequential() 126 | :add(nn.ConcatTable() 127 | :add(s) 128 | :add(shortcut(nInputPlane, n * 4, stride))) 129 | :add(nn.CAddTable(true)) 130 | :add(ReLU(true)) 131 | end 132 | 133 | -- Creates count residual blocks with specified number of features 134 | local function layer(block, features, count, stride) 135 | local s = nn.Sequential() 136 | for i=1,count do 137 | s:add(block(features, i == 1 and stride or 1)) 138 | end 139 | return s 140 | end 141 | 142 | local model = nn.Sequential() 143 | local bottleneck 144 | if opt.bottleneckType == 'resnet' then 145 | bottleneck = resnet_bottleneck 146 | print('Deploying ResNet bottleneck block') 147 | elseif opt.bottleneckType == 'resnext_B' then 148 | bottleneck = resnext_bottleneck_B 149 | print('Deploying ResNeXt bottleneck block form B') 150 | elseif opt.bottleneckType == 'resnext_C' then 151 | bottleneck = resnext_bottleneck_C 152 | print('Deploying ResNeXt bottleneck block form C (group convolution)') 153 | else 154 | error('invalid bottleneck type: ' .. opt.bottleneckType) 155 | end 156 | 157 | if opt.dataset == 'imagenet' then 158 | -- Configurations for ResNet: 159 | -- num. residual blocks, num features, residual block function 160 | local cfg = { 161 | [50] = {{3, 4, 6, 3}, 2048, bottleneck}, 162 | [101] = {{3, 4, 23, 3}, 2048, bottleneck}, 163 | [152] = {{3, 8, 36, 3}, 2048, bottleneck}, 164 | } 165 | 166 | assert(cfg[depth], 'Invalid depth: ' .. tostring(depth)) 167 | local def, nFeatures, block = table.unpack(cfg[depth]) 168 | iChannels = 64 169 | print(' | ResNet-' .. depth .. ' ImageNet') 170 | 171 | -- The ResNet ImageNet model 172 | model:add(Convolution(3,64,7,7,2,2,3,3)) 173 | model:add(SBatchNorm(64)) 174 | model:add(ReLU(true)) 175 | model:add(Max(3,3,2,2,1,1)) 176 | model:add(layer(block, 64, def[1])) 177 | model:add(layer(block, 128, def[2], 2)) 178 | model:add(layer(block, 256, def[3], 2)) 179 | model:add(layer(block, 512, def[4], 2)) 180 | model:add(Avg(7, 7, 1, 1)) 181 | model:add(nn.View(nFeatures):setNumInputDims(3)) 182 | model:add(nn.Linear(nFeatures, 1000)) 183 | 184 | elseif opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then 185 | -- Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 186 | assert((depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101') 187 | local n = (depth - 2) / 9 188 | 189 | iChannels = 64 190 | print(' | ResNet-' .. depth .. ' ' .. opt.dataset) 191 | 192 | model:add(Convolution(3,64,3,3,1,1,1,1)) 193 | model:add(SBatchNorm(64)) 194 | model:add(ReLU(true)) 195 | model:add(layer(bottleneck, 64, n, 1)) 196 | model:add(layer(bottleneck, 128, n, 2)) 197 | model:add(layer(bottleneck, 256, n, 2)) 198 | model:add(Avg(8, 8, 1, 1)) 199 | model:add(nn.View(1024):setNumInputDims(3)) 200 | local nCategories = opt.dataset == 'cifar10' and 10 or 100 201 | model:add(nn.Linear(1024, nCategories)) 202 | else 203 | error('invalid dataset: ' .. 
   local function ConvInit(name)
      for k,v in pairs(model:findModules(name)) do
         local n = v.kW*v.kH*v.nOutputPlane
         v.weight:normal(0,math.sqrt(2/n))
         if cudnn.version >= 4000 then
            v.bias = nil
            v.gradBias = nil
         else
            v.bias:zero()
         end
      end
   end
   local function BNInit(name)
      for k,v in pairs(model:findModules(name)) do
         v.weight:fill(1)
         v.bias:zero()
      end
   end

   ConvInit('cudnn.SpatialConvolution')
   ConvInit('nn.SpatialConvolution')
   BNInit('fbnn.SpatialBatchNormalization')
   BNInit('cudnn.SpatialBatchNormalization')
   BNInit('nn.SpatialBatchNormalization')
   for k,v in pairs(model:findModules('nn.Linear')) do
      v.bias:zero()
   end
   model:type(opt.tensorType)

   if opt.cudnn == 'deterministic' then
      model:apply(function(m)
         if m.setMode then m:setMode(1,1,1) end
      end)
   end

   -- the gradient w.r.t. the input image is never used, so don't keep it
   model:get(1).gradInput = nil

   return model
end

return createModel
--------------------------------------------------------------------------------
/opts.lua:
--------------------------------------------------------------------------------
--
-- Copyright (c) 2017, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local M = { }

function M.parse(arg)
   local cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Torch-7 ResNeXt Training script')
   cmd:text('See https://github.com/facebook/fb.resnet.torch/blob/master/TRAINING.md for examples')
   cmd:text()
   cmd:text('Options:')
   ------------ General options --------------------
   cmd:option('-data', '', 'Path to dataset')
   cmd:option('-dataset', 'imagenet', 'Options: imagenet | cifar10 | cifar100')
   cmd:option('-manualSeed', 0, 'Manually set RNG seed')
   cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
   cmd:option('-backend', 'cudnn', 'Options: cudnn | cunn')
   cmd:option('-cudnn', 'fastest', 'Options: fastest | default | deterministic')
   cmd:option('-gen', 'gen', 'Path to save generated files')
   cmd:option('-precision', 'single', 'Options: single | double | half')
   ------------- Data options ------------------------
   cmd:option('-nThreads', 2, 'number of data loading threads')
   ------------- Training options --------------------
   cmd:option('-nEpochs', 0, 'Number of total epochs to run')
   cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)')
   cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)')
   cmd:option('-testOnly', 'false', 'Run on validation set only')
   cmd:option('-tenCrop', 'false', 'Ten-crop testing')
   ------------- Checkpointing options ---------------
   cmd:option('-save', 'checkpoints', 'Directory in which to save checkpoints')
   cmd:option('-resume', 'none', 'Resume from the latest checkpoint in this directory')
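   -- For reference, a ResNeXt-29 (16x64d) on CIFAR-10 might be launched as:
   --   th main.lua -dataset cifar10 -bottleneckType resnext_C -depth 29
   --      -baseWidth 64 -cardinality 16 -weightDecay 5e-4 -batchSize 128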
   ---------- Optimization options ----------------------
   cmd:option('-LR', 0.1, 'initial learning rate')
   cmd:option('-momentum', 0.9, 'momentum')
   cmd:option('-weightDecay', 1e-4, 'weight decay')
   ---------- Model options ----------------------------------
   cmd:option('-netType', 'resnext', 'Options: resnext')
   cmd:option('-bottleneckType', 'resnext_C', 'Options: resnet | resnext_B | resnext_C')
   cmd:option('-depth', 50, 'Model depth: 50 | 101 | 152 (ImageNet), or 9n+2, e.g. 29 (CIFAR)', 'number')
   cmd:option('-baseWidth', 64, 'ResNeXt base width (d)', 'number')
   cmd:option('-cardinality', 1, 'ResNeXt cardinality (number of groups)', 'number')
   cmd:option('-shortcutType', '', 'Options: A | B | C')
   cmd:option('-retrain', 'none', 'Path to model to retrain with')
   cmd:option('-optimState', 'none', 'Path to an optimState to reload from')
   ---------- Miscellaneous options ---------------------------
   cmd:option('-shareGradInput', 'false', 'Share gradInput tensors to reduce memory usage')
   cmd:option('-optnet', 'false', 'Use optnet to reduce memory usage')
   cmd:option('-resetClassifier', 'false', 'Reset the fully connected layer for fine-tuning')
   cmd:option('-nClasses', 0, 'Number of classes in the dataset')
   cmd:text()

   local opt = cmd:parse(arg or {})

   -- torch.CmdLine stores these flags as strings; any value other than the
   -- literal string 'false' (e.g. 'true') enables them
   opt.testOnly = opt.testOnly ~= 'false'
   opt.tenCrop = opt.tenCrop ~= 'false'
   opt.shareGradInput = opt.shareGradInput ~= 'false'
   opt.optnet = opt.optnet ~= 'false'
   opt.resetClassifier = opt.resetClassifier ~= 'false'

   if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then
      cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n')
   end

   if opt.dataset == 'imagenet' then
      -- Handle the most common case of missing -data flag
      local trainDir = paths.concat(opt.data, 'train')
      if not paths.dirp(opt.data) then
         cmd:error('error: missing ImageNet data directory')
      elseif not paths.dirp(trainDir) then
         cmd:error('error: ImageNet missing `train` directory: ' .. trainDir)
      end
      -- Default shortcutType=B and nEpochs=100
      opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType
      opt.nEpochs = opt.nEpochs == 0 and 100 or opt.nEpochs
   elseif opt.dataset == 'cifar10' then
      -- Default shortcutType=B and nEpochs=300
      opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType
      opt.nEpochs = opt.nEpochs == 0 and 300 or opt.nEpochs
   elseif opt.dataset == 'cifar100' then
      -- Default shortcutType=B and nEpochs=300
      opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType
      opt.nEpochs = opt.nEpochs == 0 and 300 or opt.nEpochs
   else
      cmd:error('unknown dataset: ' .. opt.dataset)
   end

   if opt.precision == nil or opt.precision == 'single' then
      opt.tensorType = 'torch.CudaTensor'
   elseif opt.precision == 'double' then
      opt.tensorType = 'torch.CudaDoubleTensor'
   elseif opt.precision == 'half' then
      opt.tensorType = 'torch.CudaHalfTensor'
   else
      cmd:error('unknown precision: ' .. opt.precision)
   end

   if opt.resetClassifier then
      if opt.nClasses == 0 then
         cmd:error('-nClasses required when resetClassifier is set')
      end
   end
   if opt.shareGradInput and opt.optnet then
      cmd:error('error: cannot use both -shareGradInput and -optnet')
   end

   return opt
end

return M
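-- Usage sketch (assumed; mirrors how main.lua, not shown here, consumes the
-- parsed options):
--   local opts = require 'opts'
--   local opt = opts.parse(arg)
--   local model = require('models/resnext')(opt)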
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
--
-- Copyright (c) 2017, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- The training loop and learning rate schedule
--

local optim = require 'optim'

local M = {}
local Trainer = torch.class('resnet.Trainer', M)

function Trainer:__init(model, criterion, opt, optimState)
   self.model = model
   self.criterion = criterion
   self.optimState = optimState or {
      learningRate = opt.LR,
      learningRateDecay = 0.0,
      momentum = opt.momentum,
      nesterov = true,
      dampening = 0.0,
      weightDecay = opt.weightDecay,
   }
   self.opt = opt
   self.params, self.gradParams = model:getParameters()
end

function Trainer:train(epoch, dataloader)
   -- Trains the model for a single epoch
   self.optimState.learningRate = self:learningRate(epoch)

   local timer = torch.Timer()
   local dataTimer = torch.Timer()

   -- forward/backward are run manually below, so feval just hands the
   -- precomputed loss and gradients to optim.sgd
   local function feval()
      return self.criterion.output, self.gradParams
   end

   local trainSize = dataloader:size()
   local top1Sum, top5Sum, lossSum = 0.0, 0.0, 0.0
   local N = 0

   print('=> Training epoch # ' .. epoch)
   -- set the batch norm to training mode
   self.model:training()
   for n, sample in dataloader:run() do
      local dataTime = dataTimer:time().real

      -- Copy input and target to the GPU
      self:copyInputs(sample)

      local output = self.model:forward(self.input):float()
      local batchSize = output:size(1)
      local loss = self.criterion:forward(self.model.output, self.target)

      self.model:zeroGradParameters()
      self.criterion:backward(self.model.output, self.target)
      self.model:backward(self.input, self.criterion.gradInput)

      optim.sgd(feval, self.params, self.optimState)

      local top1, top5 = self:computeScore(output, sample.target, 1)
      top1Sum = top1Sum + top1*batchSize
      top5Sum = top5Sum + top5*batchSize
      lossSum = lossSum + loss*batchSize
      N = N + batchSize

      print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f top1 %7.3f top5 %7.3f'):format(
         epoch, n, trainSize, timer:time().real, dataTime, loss, top1, top5))

      -- check that the storage didn't get changed due to an unfortunate getParameters call
      assert(self.params:storage() == self.model:parameters()[1]:storage())

      timer:reset()
      dataTimer:reset()
   end

   return top1Sum / N, top5Sum / N, lossSum / N
end
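-- Ten-crop evaluation: with -tenCrop, the dataloader is expected to emit ten
-- crops per image, so the model output has batchSize*10 rows. Trainer:test
-- divides the row count by nCrops to recover the true batch size, and
-- computeScore sums each image's class scores over its crops before ranking.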
function Trainer:test(epoch, dataloader)
   -- Computes the top-1 and top-5 err on the validation set

   local timer = torch.Timer()
   local dataTimer = torch.Timer()
   local size = dataloader:size()

   local nCrops = self.opt.tenCrop and 10 or 1
   local top1Sum, top5Sum = 0.0, 0.0
   local N = 0

   self.model:evaluate()
   for n, sample in dataloader:run() do
      local dataTime = dataTimer:time().real

      -- Copy input and target to the GPU
      self:copyInputs(sample)

      local output = self.model:forward(self.input):float()
      local batchSize = output:size(1) / nCrops
      local loss = self.criterion:forward(self.model.output, self.target)

      local top1, top5 = self:computeScore(output, sample.target, nCrops)
      top1Sum = top1Sum + top1*batchSize
      top5Sum = top5Sum + top5*batchSize
      N = N + batchSize

      print((' | Test: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (%7.3f) top5 %7.3f (%7.3f)'):format(
         epoch, n, size, timer:time().real, dataTime, top1, top1Sum / N, top5, top5Sum / N))

      timer:reset()
      dataTimer:reset()
   end
   self.model:training()

   print((' * Finished epoch # %d top1: %7.3f top5: %7.3f\n'):format(
      epoch, top1Sum / N, top5Sum / N))

   return top1Sum / N, top5Sum / N
end

function Trainer:computeScore(output, target, nCrops)
   if nCrops > 1 then
      -- Sum over crops
      output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
         --:exp()
         :sum(2):squeeze(2)
   end

   -- Computes the top1 and top5 error rate
   local batchSize = output:size(1)

   local _ , predictions = output:float():topk(5, 2, true, true) -- descending

   -- Find which predictions match the target
   local correct = predictions:eq(
      target:long():view(batchSize, 1):expandAs(predictions))

   -- Top-1 score
   local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize)

   -- Top-5 score, if there are at least 5 classes
   local len = math.min(5, correct:size(2))
   local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize)

   return top1 * 100, top5 * 100
end

local function getCudaTensorType(tensorType)
   if tensorType == 'torch.CudaHalfTensor' then
      return cutorch.createCudaHostHalfTensor()
   elseif tensorType == 'torch.CudaDoubleTensor' then
      return cutorch.createCudaHostDoubleTensor()
   else
      return cutorch.createCudaHostTensor()
   end
end

function Trainer:copyInputs(sample)
   -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory,
   -- if using DataParallelTable. The target is always copied to a CUDA tensor
   self.input = self.input or (self.opt.nGPU == 1
      and torch[self.opt.tensorType:match('torch.(%a+)')]()
      or getCudaTensorType(self.opt.tensorType))
   self.target = self.target or (torch.CudaLongTensor and torch.CudaLongTensor())
   self.input:resize(sample.input:size()):copy(sample.input)
   self.target:resize(sample.target:size()):copy(sample.target)
end

function Trainer:learningRate(epoch)
   -- Training schedule
   local decay = 0
   if self.opt.dataset == 'imagenet' then
      decay = math.floor((epoch - 1) / 30)
   elseif self.opt.dataset == 'cifar10' or self.opt.dataset == 'cifar100' then
      decay = (epoch/self.opt.nEpochs) > 0.74 and 2 or (epoch/self.opt.nEpochs) > 0.49 and 1 or 0
   end
   return self.opt.LR * math.pow(0.1, decay)
end

return M.Trainer
--------------------------------------------------------------------------------
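For a concrete feel of Trainer:learningRate above, here is an illustrative
standalone sketch of the schedule it produces (not part of the repository):

-- step-decay schedule, as in Trainer:learningRate
local function learningRate(LR, dataset, epoch, nEpochs)
   local decay = 0
   if dataset == 'imagenet' then
      decay = math.floor((epoch - 1) / 30)      -- drop 10x every 30 epochs
   else
      local progress = epoch / nEpochs          -- cifar10 / cifar100
      decay = progress > 0.74 and 2 or progress > 0.49 and 1 or 0
   end
   return LR * math.pow(0.1, decay)
end

-- With LR = 0.1 on ImageNet: epochs 1-30 run at 0.1, 31-60 at 0.01, and
-- 61-90 at 0.001. On CIFAR with the default nEpochs = 300: 0.1 through
-- epoch 147, 0.01 through epoch 222, then 0.001.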