├── License ├── README.md ├── checkpoints.lua ├── dataloader.lua ├── datasets ├── README.md ├── cifar10-gen.lua ├── cifar10.lua ├── cifar100-gen.lua ├── cifar100.lua ├── imagenet-gen.lua ├── imagenet.lua ├── init.lua └── transforms.lua ├── main.lua ├── models ├── JointTrainContainer.lua ├── MSDNet.lua ├── MSDNet_Layer.lua ├── init.lua └── msdnet.lua ├── opts.lua ├── saveTXT.lua └── train_joint.lua /License: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Gao Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSDNet 2 | 3 | This repository provides the code for the paper [Multi-Scale Dense Networks for Resource Efficient Image Classification](http://arxiv.org/abs/1703.09844). 4 | 5 | ### Update on April 3, 2019 -- PyTorch implementation released! 6 | [A PyTorch implementation of MSDNet can be found here.](https://github.com/kalviny/MSDNet-PyTorch) 7 | 8 | 9 | ## Introduction 10 | 11 | This paper studies convolutional networks that require limited computational resources at test time. We develop a new network architecture that performs on par with state-of-the-art convolutional networks, whilst facilitating prediction in two settings: (1) an anytime-prediction setting in which the network's prediction for one example is progressively updated, facilitating the output of a prediction at any time; and (2) a batch computational budget setting in which a fixed amount of computation is available to classify a set of examples, and that computation can be spent unevenly across 'easier' and 'harder' examples. 12 | 13 | **Figure 1**: MSDNet layout (2D). 14 | 15 | 16 | 17 | **Figure 2**: MSDNet layout (3D). 18 | 19 | 20 | 21 | 22 | ## Results 23 | ### (a) anytime-prediction setting 24 | 25 | **Figure 3**: Anytime prediction on ImageNet. 26 | 27 | 28 | 29 | 30 | ### (b) batch computational budget setting 31 | 32 | **Figure 4**: Prediction under batch computational budget on ImageNet. 33 | 34 | 35 | 36 | **Figure 5**: Random example images from the ImageNet classes `Red wine` and `Volcano`. Top row: images that were correctly classified and exited at the first classifier of an MSDNet; Bottom row: images that were misclassified by the first classifier but were correctly classified and exited at the last layer.
37 | 38 | 39 | 40 | 41 | ## Usage 42 | 43 | Our code is built on the Torch ResNet framework (https://github.com/facebook/fb.resnet.torch). The training scripts come with several options, which can be listed with the `--help` flag. 44 | 45 | th main.lua --help 46 | 47 | #### Configuration 48 | 49 | In all the experiments, we use a **validation set** for model selection. We hold out `5000` training images on CIFAR, and 50 | `50000` 51 | images on ImageNet as the validation set. 52 | 53 | 54 | #### Training recipe 55 | 56 | Train an MSDNet with 10 classifiers attached to every other layer for anytime prediction: 57 | ```bash 58 | th main.lua -netType msdnet -dataset cifar10 -batchSize 64 -nEpochs 300 -nBlocks 10 -stepmode even -step 2 -base 4 59 | ``` 60 | 61 | Train an MSDNet with 7 classifiers whose spans increase linearly, for efficient batch computation: 62 | ```bash 63 | th main.lua -netType msdnet -dataset cifar10 -batchSize 64 -nEpochs 300 -nBlocks 7 -stepmode lin_grow -step 1 -base 1 64 | ``` 65 | 66 | #### Pre-trained ImageNet Models 67 | 68 | 1. [Download](https://www.dropbox.com/sh/elnyjl4xdi4zyas/AACxCdjV-RWYrHYfbz61FFDma?dl=0) the model checkpoints and the validation set indices. 69 | 70 | 2. Testing script: `th main.lua -dataset imagenet -testOnly true -resume -data -gen ` 71 | 72 | ## FAQ 73 | 74 | 1. How do I calculate the FLOPs (or mul-add operations) of a model? 75 | 76 | We strongly recommend doing it automatically. Please refer to the [op-counter](https://github.com/apaszke/torch-opCounter) project (LuaTorch), or the [script](https://github.com/ShichenLiu/CondenseNet/blob/master/utils.py#L58-L162) in CondenseNet (PyTorch). The basic idea of these op counters is to install a hook on each layer before running a forward pass of the model; during the pass, each hook adds up the operations performed by its layer.
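As a rough illustration of the hook idea in the FAQ answer, here is a minimal sketch in plain Lua that runs without Torch. It is not the op-counter project's API: the constructors `Conv` and `Linear` and the helpers `addHook` and `countMulAdds` are hypothetical names, layers are plain tables that map an input shape `{c, h, w}` to an output shape, and only the multiply-adds of square-kernel convolutions and fully connected layers are counted (biases, batch normalization, and activations are ignored). A real counter would wrap the forward method of each `nn` module and read sizes from the actual tensors.

```lua
-- Hypothetical sketch: count mul-adds by wrapping each layer's forward
-- function with a counting hook. Layers map shapes, not tensors.
local counter = { total = 0 }

-- Output size of one spatial dimension of a convolution.
local function convOutSize(inSize, k, stride, pad)
  return math.floor((inSize + 2 * pad - k) / stride) + 1
end

local function Conv(nIn, nOut, k, stride, pad)
  local layer = { nIn = nIn, nOut = nOut, k = k, stride = stride, pad = pad }
  function layer:forward(shape)
    assert(shape[1] == self.nIn, 'input channel mismatch')
    return { self.nOut,
             convOutSize(shape[2], self.k, self.stride, self.pad),
             convOutSize(shape[3], self.k, self.stride, self.pad) }
  end
  -- every output element costs nIn * k * k mul-adds (bias ignored)
  function layer:cost(_, out)
    return out[1] * out[2] * out[3] * self.nIn * self.k * self.k
  end
  return layer
end

local function Linear(nIn, nOut)
  local layer = { nIn = nIn, nOut = nOut }
  function layer:forward(shape)
    local n = 1
    for _, d in ipairs(shape) do n = n * d end  -- flatten the input shape
    assert(n == self.nIn, 'input size mismatch')
    return { self.nOut }
  end
  function layer:cost() return self.nIn * self.nOut end
  return layer
end

-- The hook: call the original forward, then add this layer's cost.
local function addHook(layer)
  local forward = layer.forward
  layer.forward = function(self, shape)
    local out = forward(self, shape)
    counter.total = counter.total + self:cost(shape, out)
    return out
  end
end

-- Run a forward pass over a chain of layers and return the mul-add count.
local function countMulAdds(layers, inputShape)
  counter.total = 0
  local shape = inputShape
  for _, layer in ipairs(layers) do shape = layer:forward(shape) end
  return counter.total
end

local model = { Conv(3, 16, 3, 1, 1), Conv(16, 32, 3, 2, 1), Linear(32 * 16 * 16, 10) }
for _, layer in ipairs(model) do addHook(layer) end
print(countMulAdds(model, { 3, 32, 32 }))  -- 1703936
```

For a 3x32x32 input, the three layers contribute 442368, 1179648, and 81920 mul-adds respectively. In this sketch the hook runs the original forward first so it can read the output shape; a hook placed before the forward call would instead have to compute the output size itself.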
77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /checkpoints.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | local checkpoint = {} 10 | 11 | local function deepCopy(tbl) 12 | -- creates a copy of a network with new modules and the same tensors 13 | local copy = {} 14 | for k, v in pairs(tbl) do 15 | if type(v) == 'table' then 16 | copy[k] = deepCopy(v) 17 | else 18 | copy[k] = v 19 | end 20 | end 21 | if torch.typename(tbl) then 22 | torch.setmetatable(copy, torch.typename(tbl)) 23 | end 24 | return copy 25 | end 26 | 27 | function checkpoint.latest(opt) 28 | if opt.resume == 'none' then 29 | return nil 30 | end 31 | 32 | local latestPath = paths.concat(opt.resume, 'latest.t7') 33 | if not paths.filep(latestPath) then 34 | return nil 35 | end 36 | 37 | print('=> Loading checkpoint ' .. latestPath) 38 | local latest = torch.load(latestPath) 39 | local optimState = torch.load(paths.concat(opt.resume, latest.optimFile)) 40 | 41 | return latest, optimState 42 | end 43 | 44 | function checkpoint.save(epoch, model, optimState, isBestModel, opt) 45 | -- don't save the DataParallelTable for easier loading on other machines 46 | if torch.type(model) == 'nn.DataParallelTable' then 47 | model = model:get(1) 48 | end 49 | 50 | print('save model!') 51 | 52 | -- create a clean copy on the CPU without modifying the original network 53 | model = deepCopy(model):float():clearState() 54 | 55 | local modelFile = 'model' .. '.t7' 56 | local optimFile = 'optimState' .. 
'.t7' 57 | 58 | torch.save(paths.concat(opt.save, modelFile), model) 59 | torch.save(paths.concat(opt.save, optimFile), optimState) 60 | torch.save(paths.concat(opt.save, 'latest.t7'), { 61 | epoch = epoch, 62 | modelFile = modelFile, 63 | optimFile = optimFile, 64 | opt = opt, 65 | }) 66 | 67 | if isBestModel then 68 | torch.save(paths.concat(opt.save, 'model_best.t7'), model) 69 | end 70 | end 71 | 72 | return checkpoint 73 | -------------------------------------------------------------------------------- /dataloader.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Multi-threaded data loader 10 | -- 11 | 12 | local datasets = require 'datasets/init' 13 | local Threads = require 'threads' 14 | Threads.serialization('threads.sharedserialize') 15 | 16 | local M = {} 17 | local DataLoader = torch.class('resnet.DataLoader', M) 18 | 19 | function DataLoader.create(opt) 20 | -- The train and val loader 21 | local loaders = {} 22 | 23 | local data_pair 24 | if opt.validset == true then 25 | data_pair = {'train', 'val', 'test'} 26 | else 27 | data_pair = {'train', 'val'} 28 | end 29 | 30 | for i, split in ipairs(data_pair) do 31 | local dataset = datasets.create(opt, split) 32 | loaders[i] = M.DataLoader(dataset, opt, split) 33 | end 34 | 35 | return table.unpack(loaders) 36 | end 37 | 38 | function DataLoader:__init(dataset, opt, split) 39 | self.opt = opt 40 | self.split = split 41 | local manualSeed = opt.manualSeed 42 | local function init() 43 | require('datasets/' .. 
opt.dataset) 44 | end 45 | local function main(idx) 46 | if manualSeed ~= 0 then 47 | torch.manualSeed(manualSeed + idx) 48 | end 49 | torch.setnumthreads(1) 50 | _G.dataset = dataset 51 | _G.preprocess = dataset:preprocess() 52 | return dataset:size() 53 | end 54 | 55 | local threads, sizes = Threads(opt.nThreads, init, main) 56 | self.nCrops = ((split == 'val' or split == 'test') and opt.tenCrop) and 10 or 1 57 | self.threads = threads 58 | self.__size = sizes[1][1] 59 | self.batchSize = math.floor(opt.batchSize / self.nCrops) 60 | end 61 | 62 | function DataLoader:size() 63 | return math.ceil(self.__size / self.batchSize) 64 | end 65 | 66 | function DataLoader:run() 67 | local threads = self.threads 68 | local size, batchSize = self.__size, self.batchSize 69 | local perm = (self.split == 'train') and torch.randperm(size) or torch.range(1,size) -- modified to not shuffle the val/test data 70 | 71 | local idx, sample = 1, nil 72 | local function enqueue() 73 | while idx <= size and threads:acceptsjob() do 74 | local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1)) 75 | threads:addjob( 76 | function(indices, nCrops) 77 | local sz = indices:size(1) 78 | local batch, imageSize 79 | local target, path = torch.IntTensor(sz), {} 80 | for i, idx in ipairs(indices:totable()) do 81 | local sample = _G.dataset:get(idx) 82 | local input = _G.preprocess(sample.input) 83 | if not batch then 84 | imageSize = input:size():totable() 85 | if nCrops > 1 then table.remove(imageSize, 1) end 86 | batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize)) 87 | end 88 | batch[i]:copy(input) 89 | target[i] = sample.target 90 | path[i] = sample.path 91 | end 92 | collectgarbage() 93 | return { 94 | input = batch:view(sz * nCrops, table.unpack(imageSize)), 95 | target = target, 96 | path = path, 97 | } 98 | end, 99 | function(_sample_) 100 | sample = _sample_ 101 | end, 102 | indices, 103 | self.nCrops 104 | ) 105 | idx = idx + batchSize 106 | end 107 | end 108 | 109
| local n = 0 110 | local function loop() 111 | enqueue() 112 | if not threads:hasjob() then 113 | return nil 114 | end 115 | threads:dojob() 116 | if threads:haserror() then 117 | threads:synchronize() 118 | end 119 | enqueue() 120 | n = n + 1 121 | return n, sample 122 | end 123 | 124 | return loop 125 | end 126 | 127 | return M.DataLoader 128 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | 3 | Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` is responsible for one-time setup, while 4 | the `dataset.lua` handles the actual data loading. 5 | 6 | To use a new dataset from `main.lua`, you should also modify `opts.lua` to accept the new dataset name. 7 | 8 | ### `dataset-gen.lua` 9 | 10 | The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data. 11 | 12 | The module should have a single function `exec(opt, cacheFile)`.
13 | - `opt`: the command line options 14 | - `cacheFile`: path to output 15 | 16 | ```lua 17 | local M = {} 18 | function M.exec(opt, cacheFile) 19 | local imageInfo = {} 20 | -- preprocess dataset, store results in imageInfo, save to cacheFile 21 | torch.save(cacheFile, imageInfo) 22 | end 23 | return M 24 | ``` 25 | 26 | ### `dataset.lua` 27 | 28 | The `dataset.lua` should return a class that implements three functions: 29 | - `get(i)`: returns a table containing two entries, `input` and `target` 30 | - `input`: the training or validation image as a Torch tensor 31 | - `target`: the image category as a number 1-N 32 | - `size()`: returns the number of entries in the dataset 33 | - `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization 34 | 35 | ```lua 36 | local M = {} 37 | local FakeDataset = torch.class('resnet.FakeDataset', M) 38 | 39 | function FakeDataset:__init(imageInfo, opt, split) 40 | -- imageInfo: result from dataset-gen.lua 41 | -- opt: command-line arguments 42 | -- split: "train" or "val" 43 | end 44 | 45 | function FakeDataset:get(i) 46 | return { 47 | input = torch.Tensor(3, 800, 600):uniform(), 48 | target = 42, 49 | } 50 | end 51 | 52 | function FakeDataset:size() 53 | -- size of dataset 54 | return 2000 55 | end 56 | 57 | function FakeDataset:preprocess() 58 | -- Scale smaller side to 256 and take 224x224 center-crop 59 | return t.Compose{ 60 | t.Scale(256), 61 | t.CenterCrop(224), 62 | } 63 | end 64 | 65 | return M.FakeDataset 66 | ``` 67 | -------------------------------------------------------------------------------- /datasets/cifar10-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. 
An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Script to compute list of ImageNet filenames and classes 10 | -- 11 | -- This automatically downloads the CIFAR-10 dataset from 12 | -- http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz 13 | -- 14 | 15 | local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz' 16 | 17 | local M = {} 18 | 19 | local function convertToTensor(files) 20 | local data, labels 21 | 22 | for _, file in ipairs(files) do 23 | local m = torch.load(file, 'ascii') 24 | if not data then 25 | data = m.data:t() 26 | labels = m.labels:squeeze() 27 | else 28 | data = torch.cat(data, m.data:t(), 1) 29 | labels = torch.cat(labels, m.labels:squeeze()) 30 | end 31 | end 32 | 33 | -- This is *very* important. The downloaded files have labels 0-9, which do 34 | -- not work with CrossEntropyCriterion 35 | labels:add(1) 36 | 37 | return { 38 | data = data:contiguous():view(-1, 3, 32, 32), 39 | labels = labels, 40 | } 41 | end 42 | 43 | function M.exec(opt, cacheFile) 44 | local rawpath = 'gen/cifar-10-batches-t7/data_batch_1.t7' 45 | if not paths.filep(rawpath) then 46 | print("=> Downloading CIFAR-10 dataset from " .. URL) 47 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 48 | assert(ok == true or ok == 0, 'error downloading CIFAR-10') 49 | end 50 | 51 | print(" | combining dataset into a single file") 52 | local trainData = convertToTensor({ 53 | 'gen/cifar-10-batches-t7/data_batch_1.t7', 54 | 'gen/cifar-10-batches-t7/data_batch_2.t7', 55 | 'gen/cifar-10-batches-t7/data_batch_3.t7', 56 | 'gen/cifar-10-batches-t7/data_batch_4.t7', 57 | 'gen/cifar-10-batches-t7/data_batch_5.t7', 58 | }) 59 | local testData = convertToTensor({ 60 | 'gen/cifar-10-batches-t7/test_batch.t7', 61 | }) 62 | print(" | saving CIFAR-10 dataset to " .. 
cacheFile) 63 | 64 | 65 | if opt.validset == true then 66 | torch.manualSeed(1) 67 | local shuffle = torch.randperm(50000) 68 | trainData.data[{ {1, 50000} }] = trainData.data:index(1, shuffle:long()) 69 | trainData.labels[{ {1, 50000} }] = trainData.labels:index(1, shuffle:long()) 70 | 71 | local valData = {data = torch.Tensor(), labels = torch.Tensor()} 72 | valData.data = trainData.data[{ {45001,50000} }] 73 | valData.labels = trainData.labels[{ {45001,50000} }] 74 | trainData.data = trainData.data[{ {1,45000} }] 75 | trainData.labels = trainData.labels[{ {1,45000} }] 76 | 77 | torch.save(cacheFile, { 78 | train = trainData, 79 | val = valData, 80 | test = testData,}) 81 | else 82 | torch.save(cacheFile, { 83 | train = trainData, 84 | val = testData,}) 85 | end 86 | 87 | end 88 | 89 | return M 90 | -------------------------------------------------------------------------------- /datasets/cifar10.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | -- CIFAR-10 dataset loader 10 | -- 11 | 12 | local t = require 'datasets/transforms' 13 | 14 | local M = {} 15 | local CifarDataset = torch.class('resnet.CifarDataset', M) 16 | 17 | function CifarDataset:__init(imageInfo, opt, split) 18 | assert(imageInfo[split], split) 19 | self.imageInfo = imageInfo[split] 20 | self.split = split 21 | self.opt = opt 22 | end 23 | 24 | function CifarDataset:get(i) 25 | local image = self.imageInfo.data[i]:float() 26 | local label = self.imageInfo.labels[i] 27 | 28 | return { 29 | input = image, 30 | target = label, 31 | } 32 | end 33 | 34 | function CifarDataset:size() 35 | return self.imageInfo.data:size(1) 36 | end 37 | 38 | -- Computed from entire CIFAR-10 training set 39 | local meanstd = { 40 | mean = {125.3, 123.0, 113.9}, 41 | std = {63.0, 62.1, 66.7}, 42 | } 43 | 44 | function CifarDataset:preprocess() 45 | if self.split == 'train' then 46 | if self.opt.DataAug == true then 47 | return t.Compose{ 48 | t.ColorNormalize(meanstd), 49 | t.HorizontalFlip(0.5), 50 | t.RandomCrop(32, 4), 51 | } 52 | else 53 | return t.Compose{ 54 | t.ColorNormalize(meanstd), 55 | } 56 | end 57 | elseif self.split == 'val' or self.split == 'test' then 58 | return t.ColorNormalize(meanstd) 59 | else 60 | error('invalid split: ' .. self.split) 61 | end 62 | end 63 | 64 | return M.CifarDataset 65 | -------------------------------------------------------------------------------- /datasets/cifar100-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 
8 | -- 9 | 10 | ------------ 11 | -- This file automatically downloads the CIFAR-100 dataset from 12 | -- http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz 13 | -- It is based on cifar10-gen.lua 14 | -- Ludovic Trottier 15 | ------------ 16 | 17 | local URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz' 18 | 19 | local M = {} 20 | 21 | local function convertCifar100BinToTorchTensor(inputFname) 22 | local m=torch.DiskFile(inputFname, 'r'):binary() 23 | m:seekEnd() 24 | local length = m:position() - 1 25 | local nSamples = length / 3074 -- 1 coarse-label byte, 1 fine-label byte, 3072 pixel bytes 26 | 27 | assert(nSamples == math.floor(nSamples), 'expecting numSamples to be an exact integer') 28 | m:seek(1) 29 | 30 | local coarse = torch.ByteTensor(nSamples) 31 | local fine = torch.ByteTensor(nSamples) 32 | local data = torch.ByteTensor(nSamples, 3, 32, 32) 33 | for i=1,nSamples do 34 | coarse[i] = m:readByte() 35 | fine[i] = m:readByte() 36 | local store = m:readByte(3072) 37 | data[i]:copy(torch.ByteTensor(store)) 38 | end 39 | 40 | local out = {} 41 | out.data = data 42 | -- This is *very* important. The downloaded files have fine labels 0-99, which do 43 | -- not work with CrossEntropyCriterion (it expects labels starting at 1) 44 | out.labels = fine + 1 45 | 46 | return out 47 | end 48 | 49 | function M.exec(opt, cacheFile) 50 | print("=> Downloading CIFAR-100 dataset from " .. URL) 51 | 52 | local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/') 53 | assert(ok == true or ok == 0, 'error downloading CIFAR-100') 54 | 55 | print(" | combining dataset into a single file") 56 | 57 | local trainData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/train.bin') 58 | local testData = convertCifar100BinToTorchTensor('gen/cifar-100-binary/test.bin') 59 | 60 | print(" | saving CIFAR-100 dataset to " ..
cacheFile) 61 | 62 | if opt.validset == true then 63 | torch.manualSeed(1) 64 | local shuffle = torch.randperm(50000) 65 | trainData.data[{ {1, 50000} }] = trainData.data:index(1, shuffle:long()) 66 | trainData.labels[{ {1, 50000} }] = trainData.labels:index(1, shuffle:long()) 67 | 68 | local valData = {data = torch.Tensor(), labels = torch.Tensor()} 69 | valData.data = trainData.data[{ {45001,50000} }] 70 | valData.labels = trainData.labels[{ {45001,50000} }] 71 | trainData.data = trainData.data[{ {1,45000} }] 72 | trainData.labels = trainData.labels[{ {1,45000} }] 73 | 74 | torch.save(cacheFile, { 75 | train = trainData, 76 | val = valData, 77 | test = testData,}) 78 | else 79 | torch.save(cacheFile, { 80 | train = trainData, 81 | val = testData,}) 82 | end 83 | end 84 | 85 | return M -------------------------------------------------------------------------------- /datasets/cifar100.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | 10 | ------------ 11 | -- This file loads and preprocesses CIFAR-100.
12 | -- It is based on cifar10.lua 13 | -- Ludovic Trottier 14 | ------------ 15 | 16 | local t = require 'datasets/transforms' 17 | 18 | local M = {} 19 | local CifarDataset = torch.class('resnet.CifarDataset', M) 20 | 21 | function CifarDataset:__init(imageInfo, opt, split) 22 | assert(imageInfo[split], split) 23 | self.imageInfo = imageInfo[split] 24 | self.split = split 25 | self.opt = opt 26 | end 27 | 28 | function CifarDataset:get(i) 29 | local image = self.imageInfo.data[i]:float() 30 | local label = self.imageInfo.labels[i] 31 | 32 | return { 33 | input = image, 34 | target = label, 35 | } 36 | end 37 | 38 | function CifarDataset:size() 39 | return self.imageInfo.data:size(1) 40 | end 41 | 42 | 43 | -- Computed from entire CIFAR-100 training set with this code: 44 | -- dataset = torch.load('cifar100.t7') 45 | -- tt = dataset.train.data:double(); 46 | -- tt = tt:transpose(2,4); 47 | -- tt = tt:reshape(50000*32*32, 3); 48 | -- tt:mean(1) 49 | -- tt:std(1) 50 | local meanstd = { 51 | mean = {129.3, 124.1, 112.4}, 52 | std = {68.2, 65.4, 70.4}, 53 | } 54 | 55 | function CifarDataset:preprocess() 56 | if self.split == 'train' then 57 | if self.opt.DataAug == true then 58 | return t.Compose{ 59 | t.ColorNormalize(meanstd), 60 | t.HorizontalFlip(0.5), 61 | t.RandomCrop(32, 4), 62 | } 63 | else 64 | return t.Compose{ 65 | t.ColorNormalize(meanstd), 66 | } 67 | end 68 | elseif self.split == 'val' or self.split == 'test' then 69 | return t.ColorNormalize(meanstd) 70 | else 71 | error('invalid split: ' .. self.split) 72 | end 73 | end 74 | 75 | return M.CifarDataset 76 | -------------------------------------------------------------------------------- /datasets/imagenet-gen.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Script to compute list of ImageNet filenames and classes 10 | -- 11 | -- This generates a file gen/imagenet.t7 which contains the list of all 12 | -- ImageNet training and validation images and their classes. This script also 13 | -- works for other datasets arranged with the same layout. 14 | -- 15 | 16 | local sys = require 'sys' 17 | local ffi = require 'ffi' 18 | 19 | local M = {} 20 | 21 | local function findClasses(dir) 22 | local dirs = paths.dir(dir) 23 | table.sort(dirs) 24 | 25 | local classList = {} 26 | local classToIdx = {} 27 | for _ ,class in ipairs(dirs) do 28 | if not classToIdx[class] and class ~= '.' and class ~= '..' 29 | and class ~= '.DS_Store' then 30 | table.insert(classList, class) 31 | classToIdx[class] = #classList 32 | end 33 | end 34 | 35 | -- assert(#classList == 1000, 'expected 1000 ImageNet classes') 36 | return classList, classToIdx 37 | end 38 | 39 | local function findImages(dir, classToIdx) 40 | 41 | ---------------------------------------------------------------------- 42 | -- Options for the GNU and BSD find command 43 | local extensionList = 44 | {'jpg', 'png', 'jpeg', 'JPG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 45 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 46 | for i=2,#extensionList do 47 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 48 | end 49 | 50 | -- Find all the images using the find command 51 | local f = io.popen('find -L ' .. dir ..
findOptions) 52 | 53 | local maxLength = -1 54 | local imagePaths = {} 55 | local imageClasses = {} 56 | 57 | -- Generate a list of all the images and their class 58 | while true do 59 | local line = f:read('*line') 60 | if not line then break end 61 | 62 | local className = paths.basename(paths.dirname(line)) 63 | local filename = paths.basename(line) 64 | local path = className .. '/' .. filename 65 | 66 | local classId = classToIdx[className] 67 | assert(classId, 'class not found: ' .. className) 68 | 69 | table.insert(imagePaths, path) 70 | table.insert(imageClasses, classId) 71 | 72 | maxLength = math.max(maxLength, #path + 1) 73 | end 74 | 75 | f:close() 76 | 77 | -- Convert the generated list to a tensor for faster loading 78 | local nImages = #imagePaths 79 | local imagePath = torch.CharTensor(nImages, maxLength):zero() 80 | for i, path in ipairs(imagePaths) do 81 | ffi.copy(imagePath[i]:data(), path) 82 | end 83 | 84 | local imageClass = torch.LongTensor(imageClasses) 85 | return imagePath, imageClass 86 | end 87 | 88 | function M.exec(opt, cacheFile) 89 | 90 | local trainDir = paths.concat(opt.data, 'train') 91 | local valDir = paths.concat(opt.data, 'val') 92 | assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir) 93 | assert(paths.dirp(valDir), 'val directory not found: ' .. 
valDir) 94 | 95 | print("=> Generating list of images") 96 | local classList, classToIdx = findClasses(trainDir) 97 | 98 | print(" | finding all validation images") 99 | local valImagePath, valImageClass = findImages(valDir, classToIdx) 100 | 101 | print(" | finding all training images") 102 | local trainImagePath, trainImageClass = findImages(trainDir, classToIdx) 103 | 104 | local info 105 | if opt.validset == true then 106 | torch.manualSeed(1) 107 | 108 | local nAll = trainImageClass:size(1) 109 | local shuffle = torch.randperm(nAll) 110 | local val0ImagePath = trainImagePath:index(1,shuffle:sub(1,50000):long()) 111 | local val0ImageClass = trainImageClass:index(1,shuffle:sub(1,50000):long()) 112 | trainImagePath = trainImagePath:index(1,shuffle:sub(50001,nAll):long()) 113 | trainImageClass = trainImageClass:index(1,shuffle:sub(50001,nAll):long()) 114 | 115 | info = { 116 | basedir = opt.data, 117 | classList = classList, 118 | train = { 119 | imagePath = trainImagePath, 120 | imageClass = trainImageClass, 121 | }, 122 | val = { 123 | imagePath = val0ImagePath, 124 | imageClass = val0ImageClass, 125 | }, 126 | test = { 127 | imagePath = valImagePath, 128 | imageClass = valImageClass, 129 | }, 130 | } 131 | else 132 | info = { 133 | basedir = opt.data, 134 | classList = classList, 135 | train = { 136 | imagePath = trainImagePath, 137 | imageClass = trainImageClass, 138 | }, 139 | val = { 140 | imagePath = valImagePath, 141 | imageClass = valImageClass, 142 | }, 143 | } 144 | end 145 | 146 | 147 | print(" | saving list of images to " .. cacheFile) 148 | torch.save(cacheFile, info) 149 | return info 150 | end 151 | 152 | return M 153 | -------------------------------------------------------------------------------- /datasets/imagenet.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- ImageNet dataset loader 10 | -- 11 | 12 | local image = require 'image' 13 | local paths = require 'paths' 14 | local t = require 'datasets/transforms' 15 | local ffi = require 'ffi' 16 | 17 | local M = {} 18 | local ImagenetDataset = torch.class('resnet.ImagenetDataset', M) 19 | 20 | function ImagenetDataset:__init(imageInfo, opt, split) 21 | self.imageInfo = imageInfo[split] 22 | self.opt = opt 23 | self.split = split 24 | -------------------------------------------- 25 | if opt.validset then 26 | if split == 'train' or split == 'val' then 27 | self.dir = paths.concat(opt.data, 'train') 28 | elseif split == 'test' then 29 | self.dir = paths.concat(opt.data, 'val') 30 | end 31 | else 32 | self.dir = paths.concat(opt.data, split) 33 | end 34 | assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir) 35 | ------------------------------------------- 36 | end 37 | 38 | function ImagenetDataset:get(i) 39 | local path = ffi.string(self.imageInfo.imagePath[i]:data()) 40 | 41 | local image = self:_loadImage(paths.concat(self.dir, path)) 42 | local class = self.imageInfo.imageClass[i] 43 | 44 | return { 45 | input = image, 46 | target = class, 47 | } 48 | end 49 | 50 | function ImagenetDataset:_loadImage(path) 51 | local ok, input = pcall(function() 52 | return image.load(path, 3, 'float') 53 | end) 54 | 55 | -- Sometimes image.load fails because the file extension does not match the 56 | -- image format. In that case, use image.decompress on a ByteTensor. 57 | if not ok then 58 | local f = io.open(path, 'r') 59 | assert(f, 'Error reading: ' .. 
tostring(path)) 60 | local data = f:read('*a') 61 | f:close() 62 | 63 | local b = torch.ByteTensor(string.len(data)) 64 | ffi.copy(b:data(), data, b:size(1)) 65 | 66 | ok, input = pcall(image.decompress, b, 3, 'float') 67 | if not ok then 68 | input = torch.FloatTensor(3, 224, 224):fill(0) 69 | end 70 | end 71 | 72 | return input 73 | end 74 | 75 | function ImagenetDataset:size() 76 | return self.imageInfo.imageClass:size(1) 77 | end 78 | 79 | -- Computed from random subset of ImageNet training images 80 | local meanstd = { 81 | mean = { 0.485, 0.456, 0.406 }, 82 | std = { 0.229, 0.224, 0.225 }, 83 | } 84 | local pca = { 85 | eigval = torch.Tensor{ 0.2175, 0.0188, 0.0045 }, 86 | eigvec = torch.Tensor{ 87 | { -0.5675, 0.7192, 0.4009 }, 88 | { -0.5808, -0.0045, -0.8140 }, 89 | { -0.5836, -0.6948, 0.4203 }, 90 | }, 91 | } 92 | 93 | function ImagenetDataset:preprocess() 94 | if self.split == 'train' then 95 | return t.Compose{ 96 | t.RandomSizedCrop(224), 97 | t.ColorJitter({ 98 | brightness = 0.4, 99 | contrast = 0.4, 100 | saturation = 0.4, 101 | }), 102 | t.Lighting(0.1, pca.eigval, pca.eigvec), 103 | t.ColorNormalize(meanstd), 104 | t.HorizontalFlip(0.5), 105 | } 106 | elseif self.split == 'val' or self.split == 'test' then 107 | local Crop = self.opt.tenCrop and t.TenCrop or t.CenterCrop 108 | return t.Compose{ 109 | t.Scale(256), 110 | t.ColorNormalize(meanstd), 111 | Crop(224), 112 | } 113 | else 114 | error('invalid split: ' .. self.split) 115 | end 116 | end 117 | 118 | return M.ImagenetDataset 119 | -------------------------------------------------------------------------------- /datasets/init.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. 
An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- ImageNet, CIFAR-10 and CIFAR-100 datasets 10 | -- 11 | 12 | local M = {} 13 | 14 | local function isvalid(opt, cachePath) 15 | if opt.testOnly == true then return true end -- don't check the basedir in test mode 16 | local imageInfo = torch.load(cachePath) 17 | if imageInfo.basedir and imageInfo.basedir ~= opt.data then 18 | return false 19 | end 20 | return true 21 | end 22 | 23 | function M.create(opt, split) 24 | local cachePath = opt.validset 25 | and paths.concat(opt.gen, opt.dataset .. '_withvalid.t7') 26 | or paths.concat(opt.gen, opt.dataset .. '.t7') 27 | if not paths.filep(cachePath) or not isvalid(opt, cachePath) then 28 | paths.mkdir(opt.gen) 29 | local script = paths.dofile(opt.dataset .. '-gen.lua') 30 | script.exec(opt, cachePath) 31 | end 32 | local imageInfo = torch.load(cachePath) 33 | 34 | local Dataset = require('datasets/' .. opt.dataset) 35 | return Dataset(imageInfo, opt, split) 36 | end 37 | 38 | return M 39 | -------------------------------------------------------------------------------- /datasets/transforms.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree.
8 | -- 9 | -- Image transforms for data augmentation and input normalization 10 | -- 11 | 12 | require 'image' 13 | 14 | local M = {} 15 | 16 | function M.Compose(transforms) 17 | return function(input) 18 | for _, transform in ipairs(transforms) do 19 | input = transform(input) 20 | end 21 | return input 22 | end 23 | end 24 | 25 | function M.ColorNormalize(meanstd) 26 | return function(img) 27 | img = img:clone() 28 | for i=1,3 do 29 | img[i]:add(-meanstd.mean[i]) 30 | img[i]:div(meanstd.std[i]) 31 | end 32 | return img 33 | end 34 | end 35 | 36 | -- Scales the smaller edge to size 37 | function M.Scale(size, interpolation) 38 | interpolation = interpolation or 'bicubic' 39 | return function(input) 40 | local w, h = input:size(3), input:size(2) 41 | if (w <= h and w == size) or (h <= w and h == size) then 42 | return input 43 | end 44 | if w < h then 45 | return image.scale(input, size, h/w * size, interpolation) 46 | else 47 | return image.scale(input, w/h * size, size, interpolation) 48 | end 49 | end 50 | end 51 | 52 | -- Crop to centered rectangle 53 | function M.CenterCrop(size) 54 | return function(input) 55 | local w1 = math.ceil((input:size(3) - size)/2) 56 | local h1 = math.ceil((input:size(2) - size)/2) 57 | return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch 58 | end 59 | end 60 | 61 | -- Random crop from larger image with optional zero padding 62 | function M.RandomCrop(size, padding) 63 | padding = padding or 0 64 | 65 | return function(input) 66 | if padding > 0 then 67 | local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) 68 | temp:zero() 69 | :narrow(2, padding+1, input:size(2)) 70 | :narrow(3, padding+1, input:size(3)) 71 | :copy(input) 72 | input = temp 73 | end 74 | 75 | local w, h = input:size(3), input:size(2) 76 | if w == size and h == size then 77 | return input 78 | end 79 | 80 | local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) 81 | local out = image.crop(input, x1, y1, x1 +
size, y1 + size) 82 | assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') 83 | return out 84 | end 85 | end 86 | 87 | -- Four corner patches and center crop from image and its horizontal reflection 88 | function M.TenCrop(size) 89 | local centerCrop = M.CenterCrop(size) 90 | 91 | return function(input) 92 | local w, h = input:size(3), input:size(2) 93 | 94 | local output = {} 95 | for _, img in ipairs{input, image.hflip(input)} do 96 | table.insert(output, centerCrop(img)) 97 | table.insert(output, image.crop(img, 0, 0, size, size)) 98 | table.insert(output, image.crop(img, w-size, 0, w, size)) 99 | table.insert(output, image.crop(img, 0, h-size, size, h)) 100 | table.insert(output, image.crop(img, w-size, h-size, w, h)) 101 | end 102 | 103 | -- View as mini-batch 104 | for i, img in ipairs(output) do 105 | output[i] = img:view(1, img:size(1), img:size(2), img:size(3)) 106 | end 107 | 108 | return input.cat(output, 1) 109 | end 110 | end 111 | 112 | -- Resized with shorter side randomly sampled from [minSize, maxSize] (ResNet-style) 113 | function M.RandomScale(minSize, maxSize) 114 | return function(input) 115 | local w, h = input:size(3), input:size(2) 116 | 117 | local targetSz = torch.random(minSize, maxSize) 118 | local targetW, targetH = targetSz, targetSz 119 | if w < h then 120 | targetH = torch.round(h / w * targetW) 121 | else 122 | targetW = torch.round(w / h * targetH) 123 | end 124 | 125 | return image.scale(input, targetW, targetH, 'bicubic') 126 | end 127 | end 128 | 129 | -- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style) 130 | function M.RandomSizedCrop(size) 131 | local scale = M.Scale(size) 132 | local crop = M.CenterCrop(size) 133 | 134 | return function(input) 135 | local attempt = 0 136 | repeat 137 | local area = input:size(2) * input:size(3) 138 | local targetArea = torch.uniform(0.08, 1.0) * area 139 | 140 | local aspectRatio = torch.uniform(3/4, 4/3) 141 | local w = 
torch.round(math.sqrt(targetArea * aspectRatio)) 142 | local h = torch.round(math.sqrt(targetArea / aspectRatio)) 143 | 144 | if torch.uniform() < 0.5 then 145 | w, h = h, w 146 | end 147 | 148 | if h <= input:size(2) and w <= input:size(3) then 149 | local y1 = torch.random(0, input:size(2) - h) 150 | local x1 = torch.random(0, input:size(3) - w) 151 | 152 | local out = image.crop(input, x1, y1, x1 + w, y1 + h) 153 | assert(out:size(2) == h and out:size(3) == w, 'wrong crop size') 154 | 155 | return image.scale(out, size, size, 'bicubic') 156 | end 157 | attempt = attempt + 1 158 | until attempt >= 10 159 | 160 | -- fallback 161 | return crop(scale(input)) 162 | end 163 | end 164 | 165 | function M.HorizontalFlip(prob) 166 | return function(input) 167 | if torch.uniform() < prob then 168 | input = image.hflip(input) 169 | end 170 | return input 171 | end 172 | end 173 | 174 | function M.Rotation(deg) 175 | return function(input) 176 | if deg ~= 0 then 177 | input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear') 178 | end 179 | return input 180 | end 181 | end 182 | 183 | -- Lighting noise (AlexNet-style PCA-based noise) 184 | function M.Lighting(alphastd, eigval, eigvec) 185 | return function(input) 186 | if alphastd == 0 then 187 | return input 188 | end 189 | 190 | local alpha = torch.Tensor(3):normal(0, alphastd) 191 | local rgb = eigvec:clone() 192 | :cmul(alpha:view(1, 3):expand(3, 3)) 193 | :cmul(eigval:view(1, 3):expand(3, 3)) 194 | :sum(2) 195 | :squeeze() 196 | 197 | input = input:clone() 198 | for i=1,3 do 199 | input[i]:add(rgb[i]) 200 | end 201 | return input 202 | end 203 | end 204 | 205 | local function blend(img1, img2, alpha) 206 | return img1:mul(alpha):add(1 - alpha, img2) 207 | end 208 | 209 | local function grayscale(dst, img) 210 | dst:resizeAs(img) 211 | dst[1]:zero() 212 | dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3]) 213 | dst[2]:copy(dst[1]) 214 | dst[3]:copy(dst[1]) 215 | return dst 216 | 
end 217 | 218 | function M.Saturation(var) 219 | local gs 220 | 221 | return function(input) 222 | gs = gs or input.new() 223 | grayscale(gs, input) 224 | 225 | local alpha = 1.0 + torch.uniform(-var, var) 226 | blend(input, gs, alpha) 227 | return input 228 | end 229 | end 230 | 231 | function M.Brightness(var) 232 | local gs 233 | 234 | return function(input) 235 | gs = gs or input.new() 236 | gs:resizeAs(input):zero() 237 | 238 | local alpha = 1.0 + torch.uniform(-var, var) 239 | blend(input, gs, alpha) 240 | return input 241 | end 242 | end 243 | 244 | function M.Contrast(var) 245 | local gs 246 | 247 | return function(input) 248 | gs = gs or input.new() 249 | grayscale(gs, input) 250 | gs:fill(gs[1]:mean()) 251 | 252 | local alpha = 1.0 + torch.uniform(-var, var) 253 | blend(input, gs, alpha) 254 | return input 255 | end 256 | end 257 | 258 | function M.RandomOrder(ts) 259 | return function(input) 260 | local img = input.img or input 261 | local order = torch.randperm(#ts) 262 | for i=1,#ts do 263 | img = ts[order[i]](img) 264 | end 265 | return img 266 | end 267 | end 268 | 269 | function M.ColorJitter(opt) 270 | local brightness = opt.brightness or 0 271 | local contrast = opt.contrast or 0 272 | local saturation = opt.saturation or 0 273 | 274 | local ts = {} 275 | if brightness ~= 0 then 276 | table.insert(ts, M.Brightness(brightness)) 277 | end 278 | if contrast ~= 0 then 279 | table.insert(ts, M.Contrast(contrast)) 280 | end 281 | if saturation ~= 0 then 282 | table.insert(ts, M.Saturation(saturation)) 283 | end 284 | 285 | if #ts == 0 then 286 | return function(input) return input end 287 | end 288 | 289 | return M.RandomOrder(ts) 290 | end 291 | 292 | return M 293 | -------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | require 'torch' 10 | require 'paths' 11 | require 'optim' 12 | require 'nn' 13 | require 'sys' 14 | require 'cutorch' 15 | 16 | local save2txt = require 'saveTXT' 17 | local DataLoader = require 'dataloader' 18 | local models = require 'models/init' 19 | local Trainer = require 'train_joint' 20 | local opts = require 'opts' 21 | local checkpoints = require 'checkpoints' 22 | 23 | torch.setdefaulttensortype('torch.FloatTensor') 24 | torch.setnumthreads(1) 25 | 26 | local opt = opts.parse(arg) 27 | torch.manualSeed(opt.manualSeed) 28 | cutorch.manualSeedAll(opt.manualSeed) 29 | 30 | -- Load previous checkpoint, if it exists 31 | local checkpoint, optimState = checkpoints.latest(opt) 32 | 33 | -- Create model 34 | local model = models.setup(opt, checkpoint) 35 | 36 | -- Use parallel criterion for multiple exits 37 | local criterion = nn.ParallelCriterion() 38 | for i = 1, opt.nBlocks do 39 | criterion:add(nn.CrossEntropyCriterion():type(opt.tensorType), 1) 40 | end 41 | 42 | -- Data loading 43 | print('Creating dataloader ...') 44 | local trainLoader, valLoader, testLoader = DataLoader.create(opt) 45 | 46 | -- The trainer handles the training loop and evaluation on the validation set 47 | local trainer = Trainer(model, criterion, opt, optimState) 48 | 49 | -- Test only: requires -save and -retrain 50 | if opt.testOnly then 51 | -- local flops = torch.load(paths.concat(opt.save, 'flops.t7')) 52 | -- opt.nBlocks = #flops 53 | local top1ErrValid, top5Err, top1ErrEnsembleValid, top5ErrEnsemble = trainer:test(0, valLoader) 54 | local top1Err, top5Err, top1ErrEnsemble, top5ErrEnsemble = trainer:test(0, testLoader) 55 | print('results from: ' ..
opt.save) 56 | print( 57 | -- 'flops: \n', flops, 58 | '\nval single: \n', top1ErrValid, 59 | '\n val ensemble: \n', top1ErrEnsembleValid, 60 | '\n test single:\n', top1Err, 61 | '\n test ensemble:\n', top1ErrEnsemble) 62 | torch.save(paths.concat(opt.save, 'anytime_result.t7'), {top1ErrValid, top1ErrEnsembleValid, 63 | top1Err, top1ErrEnsemble}) 64 | return 65 | end 66 | 67 | -- Initialize some parameters 68 | local paramsize = trainer.params:size(1) 69 | print('Parameters:', paramsize) 70 | checkpoints.save(0, model, trainer.optimState, false, opt) 71 | local valSingle, valEnsemble, testSingle, testEnsemble = {}, {}, {}, {} 72 | local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber 73 | local bestTop1 = math.huge 74 | local bestTop5 = math.huge 75 | local timer = torch.Timer() 76 | 77 | -- Training epochs 78 | for epoch = startEpoch, opt.nEpochs do 79 | 80 | -- Train for a single epoch 81 | timer:reset() 82 | trainer:train(epoch, trainLoader) 83 | 84 | -- Run model on validation set 85 | local valTop1All, _, valTop1Evolve = trainer:test(epoch, valLoader) 86 | valSingle[epoch] = valTop1All 87 | valEnsemble[epoch] = valTop1Evolve 88 | 89 | -- Run model on test set 90 | if opt.validset == true then 91 | local testTop1All, _, testTop1Evolve = trainer:test(epoch, testLoader) 92 | testSingle[epoch] = testTop1All 93 | testEnsemble[epoch] = testTop1Evolve 94 | end 95 | 96 | -- Log results to text file 97 | local filename = paths.concat(opt.save, 'result_') 98 | save2txt(filename..'valSingle', valSingle) 99 | save2txt(filename..'valEnsemble', valEnsemble) 100 | if opt.validset == true then 101 | save2txt(filename..'testSingle', testSingle) 102 | save2txt(filename..'testEnsemble', testEnsemble) 103 | end 104 | 105 | -- Checkpoint best model 106 | local bestModel = false 107 | if valEnsemble[epoch][opt.nBlocks] < bestTop1 then 108 | bestModel = true 109 | bestTop1 = valEnsemble[epoch][opt.nBlocks] 110 | print(' * Best model ', 
valEnsemble[epoch][opt.nBlocks]) 111 | torch.save(paths.concat(opt.save, 'best_result_ensemble.t7'), {valEnsemble[epoch], testEnsemble[epoch]}) 112 | end 113 | if bestModel or epoch == opt.nEpochs or opt.dataset == 'imagenet' then 114 | checkpoints.save(epoch, model, trainer.optimState, bestModel, opt) 115 | end 116 | end 117 | 118 | -- Done 119 | print(string.format(' * Finished top1: %6.3f', bestTop1)) 120 | -------------------------------------------------------------------------------- /models/JointTrainContainer.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cudnn' 3 | require 'cunn' 4 | require 'models/MSDNet_Layer' 5 | 6 | 7 | local function build_transition(nIn, nOut, outScales, offset, opt) 8 | local function conv1x1(nIn, nOut) 9 | local s = nn.Sequential() 10 | s:add(cudnn.SpatialConvolution(nIn, nOut, 1,1, 1,1, 0,0)) 11 | s:add(cudnn.SpatialBatchNormalization(nOut)) 12 | s:add(cudnn.ReLU(true)) 13 | return s 14 | end 15 | local net = nn.ParallelTable() 16 | for i = 1, outScales do 17 | net:add(conv1x1(nIn * opt.grFactor[offset + i], nOut * opt.grFactor[offset + i])) 18 | end 19 | return net 20 | end 21 | 22 | 23 | local function build_block(nChannels, opt, step, layer_all, layer_tillnow) 24 | local block = nn.Sequential() 25 | if layer_tillnow == 0 then -- layer_tillnow == 0 means we are at the first block 26 | block:add(nn.MSDNet_Layer_first(3, nChannels, opt)) 27 | end 28 | 29 | local nIn = nChannels 30 | for i = 1, step do 31 | local inScales, outScales = opt.nScales, opt.nScales 32 | layer_tillnow = layer_tillnow + 1 33 | if opt.prune == 'min' then 34 | inScales = math.min(opt.nScales, layer_all - layer_tillnow + 2) 35 | outScales = math.min(opt.nScales, layer_all - layer_tillnow + 1) 36 | elseif opt.prune == 'max' then 37 | local interval = torch.ceil(layer_all/opt.nScales) 38 | inScales = opt.nScales - torch.floor((math.max(0, layer_tillnow -2))/interval) 39 |
outScales = opt.nScales - torch.floor((layer_tillnow -1)/interval) 40 | else 41 | error('Unknown prune option!') 42 | end 43 | print('|', 'inScales ', inScales, 'outScales ', outScales , '|') 44 | block:add(nn.MSDNet_Layer(nIn, opt.growthRate, opt, inScales, outScales)) 45 | nIn = nIn + opt.growthRate 46 | if opt.prune == 'max' and inScales > outScales and opt.reduction > 0 then 47 | local offset = opt.nScales - outScales 48 | block:add(build_transition(nIn, math.floor(opt.reduction*nIn), outScales, offset, opt)) 49 | nIn = math.floor(opt.reduction*nIn) 50 | print('|', 'Transition layer inserted!', '\t\t|') 51 | elseif opt.prune == 'min' and opt.reduction > 0 and (layer_tillnow == torch.floor(layer_all/3) or layer_tillnow == torch.floor(2*layer_all/3)) then 52 | local offset = opt.nScales - outScales 53 | block:add(build_transition(nIn, math.floor(opt.reduction*nIn), outScales, offset, opt)) 54 | nIn = math.floor(opt.reduction*nIn) 55 | print('|', 'Transition layer inserted!', '\t\t|') 56 | end 57 | end 58 | return block, nIn 59 | end 60 | 61 | local function build_classifier_cifar(nChannels, nClass) 62 | local interChannels1, interChannels2 = 128, 128 63 | local c = nn.Sequential() 64 | c:add(cudnn.SpatialConvolution(nChannels, interChannels1,3,3,2,2,1,1)) 65 | c:add(cudnn.SpatialBatchNormalization(interChannels1)) 66 | c:add(cudnn.ReLU(true)) 67 | c:add(cudnn.SpatialConvolution(interChannels1, interChannels2,3,3,2,2,1,1)) 68 | c:add(cudnn.SpatialBatchNormalization(interChannels2)) 69 | c:add(cudnn.ReLU(true)) 70 | c:add(cudnn.SpatialAveragePooling(2,2)) 71 | c:add(nn.Reshape(interChannels2)) 72 | c:add(nn.Linear(interChannels2, nClass)) 73 | return c 74 | end 75 | 76 | local function build_classifier_imagenet(nChannels, nClass) 77 | local c = nn.Sequential() 78 | c:add(cudnn.SpatialConvolution(nChannels, nChannels,3,3,2,2,1,1)) 79 | c:add(cudnn.SpatialBatchNormalization(nChannels)) 80 | c:add(cudnn.ReLU(true)) 81 | c:add(cudnn.SpatialConvolution(nChannels, 
nChannels,3,3,2,2,1,1)) 82 | c:add(cudnn.SpatialBatchNormalization(nChannels)) 83 | c:add(cudnn.ReLU(true)) 84 | c:add(cudnn.SpatialAveragePooling(2,2)) 85 | c:add(nn.Reshape(nChannels)) 86 | c:add(nn.Linear(nChannels, nClass)) 87 | return c 88 | end 89 | 90 | local JointTrainModule, parent = torch.class('nn.JointTrainModule', 'nn.Container') 91 | 92 | function JointTrainModule:__init(nChannels, opt) 93 | parent.__init(self) 94 | 95 | self.train = true 96 | self.nChannels = nChannels 97 | self.nBlocks = opt.nBlocks 98 | self.opt = opt 99 | 100 | local nIn = nChannels 101 | self.modules = {} 102 | 103 | -- calculate the number of layers (step) in each block 104 | local layer_curr, layer_all = 0, opt.base 105 | local steps = {} 106 | steps[1] = opt.base 107 | for i = 2, self.nBlocks do 108 | steps[i] = opt.stepmode=='even' and opt.step or opt.stepmode=='lin_grow' and opt.step*(i-1)+1 109 | layer_all = layer_all + steps[i] 110 | end 111 | print("building network of steps: ") 112 | print(steps) 113 | torch.save(paths.concat(opt.save, 'layer_specific.t7'), steps) 114 | 115 | for i = 1, self.nBlocks do 116 | print(' ----------------------- Block ' .. i ..
' -----------------------') 117 | self.modules[i], nIn = build_block(nIn, opt, steps[i], layer_all, layer_curr) 118 | layer_curr = layer_curr + steps[i] 119 | if opt.dataset == 'cifar10' then 120 | self.modules[i+self.nBlocks] = build_classifier_cifar(nIn*opt.grFactor[opt.nScales], 10) 121 | elseif opt.dataset == 'cifar100' then 122 | self.modules[i+self.nBlocks] = build_classifier_cifar(nIn*opt.grFactor[opt.nScales], 100) 123 | elseif opt.dataset == 'imagenet' then 124 | self.modules[i+self.nBlocks] = build_classifier_imagenet(nIn*opt.grFactor[opt.nScales], 1000) 125 | else 126 | error('Unknown dataset!') 127 | end 128 | end 129 | self.gradInput = {} 130 | self.output = {} 131 | 132 | end 133 | 134 | function JointTrainModule:updateOutput(input) 135 | for i = 1, self.nBlocks do 136 | local tmp_input = (i==1) and input or self.modules[i-1].output 137 | local tmp_output1 = self:rethrowErrors(self.modules[i], i, 'updateOutput', tmp_input) 138 | local tmp_output2 = self:rethrowErrors(self.modules[i+self.nBlocks], i+self.nBlocks, 'updateOutput', tmp_output1[#tmp_output1]) 139 | self.output[i] = tmp_output2 140 | end 141 | return self.output 142 | end 143 | 144 | function JointTrainModule:updateGradInput(input, gradOutput) 145 | 146 | local nScales = self.opt.nScales 147 | for i = self.nBlocks, 1, -1 do 148 | local features = self.modules[i].output[#self.modules[i].output] 149 | self.modules[i+self.nBlocks]:updateGradInput(features, gradOutput[i]) 150 | local gOut = {} 151 | if i == self.nBlocks then 152 | for s = 1, nScales do 153 | local out = self.modules[i].output[s] 154 | if out then 155 | gOut[s] = out.new():resizeAs(out):zero() 156 | end 157 | end 158 | else 159 | gOut = self.modules[i+1].gradInput 160 | end 161 | gOut[#gOut]:add(self.modules[i+self.nBlocks].gradInput) 162 | local tmp_input = (i==1) and input or self.modules[i-1].output 163 | self.modules[i]:updateGradInput(tmp_input, gOut) 164 | end 165 | self.gradInput = self.modules[1].gradInput 166 | 167 | 
return self.gradInput 168 | end 169 | 170 | function JointTrainModule:accGradParameters(input, gradOutput, scale) 171 | scale = scale or 1 172 | for i = self.nBlocks, 1, -1 do 173 | local features = self.modules[i].output[#self.modules[i].output] 174 | self.modules[i+self.nBlocks]:accGradParameters(features, gradOutput[i], scale) 175 | local tmp_input = (i==1) and input or self.modules[i-1].output 176 | local gOut 177 | if i == self.nBlocks then 178 | gOut = self.modules[i].output 179 | else 180 | gOut = self.modules[i+1].gradInput 181 | end 182 | self.modules[i]:accGradParameters(tmp_input, gOut, scale) 183 | end 184 | end 185 | 186 | function JointTrainModule:__tostring__() 187 | local tab = ' ' 188 | local line = '\n' 189 | local next = ' |`-> ' 190 | local lastNext = ' `-> ' 191 | local ext = ' | ' 192 | local extlast = ' ' 193 | local last = ' ... -> ' 194 | local str = 'JointTrainModule' 195 | str = str .. ' {' .. line .. tab .. '{input}' 196 | for i=1,#self.modules do 197 | if i == #self.modules then 198 | str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) 199 | else 200 | str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) 201 | end 202 | end 203 | str = str .. line .. tab .. last .. '{output}' 204 | str = str .. line .. 
'}' 205 | return str 206 | end 207 | 208 | -------------------------------------------------------------------------------- /models/MSDNet.lua: -------------------------------------------------------------------------------- 1 | msdnet.lua -------------------------------------------------------------------------------- /models/MSDNet_Layer.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cudnn' 3 | require 'cunn' 4 | 5 | local function build(nChannels, nOutChannels, type, bottleneck, bnWidth) 6 | local net = nn.Sequential() 7 | local innerChannels = nChannels 8 | bnWidth = bnWidth or 4 9 | if bottleneck == true then 10 | innerChannels = math.min(innerChannels, bnWidth * nOutChannels) 11 | net:add(cudnn.SpatialConvolution(nChannels, innerChannels, 1,1, 1,1, 0,0)) 12 | net:add(cudnn.SpatialBatchNormalization(innerChannels)) 13 | net:add(cudnn.ReLU(true)) 14 | end 15 | if type == 'normal' then 16 | net:add(cudnn.SpatialConvolution(innerChannels, nOutChannels, 3,3, 1,1, 1,1)) 17 | elseif type == 'down' then 18 | net:add(cudnn.SpatialConvolution(innerChannels, nOutChannels, 3,3, 2,2, 1,1)) 19 | elseif type == 'up' then 20 | net:add(cudnn.SpatialFullConvolution(innerChannels, nOutChannels, 3,3, 2,2, 1,1, 1,1)) 21 | else 22 | error("Please implement me: " .. 
type) 23 | end 24 | net:add(cudnn.SpatialBatchNormalization(nOutChannels)) 25 | net:add(cudnn.ReLU(true)) 26 | return net 27 | end 28 | 29 | local function build_net_normal(nChannels, nOutChannels, bottleneck, bnWidth) 30 | local net_warp = nn.Sequential() 31 | local net = nn.ParallelTable() 32 | net:add(nn.Identity()) 33 | net:add(build(nChannels, nOutChannels, 'normal', bottleneck, bnWidth)) 34 | net_warp:add(net):add(nn.JoinTable(2)) 35 | return net_warp 36 | end 37 | 38 | local function build_net_down_normal(nChannels1, nChannels2, nOutChannels, bottleneck, bnWidth1, bnWidth2) 39 | local net_warpper = nn.Sequential() 40 | local net = nn.ParallelTable() 41 | assert(nOutChannels % 2 == 0, 'Growth rate invalid!') 42 | net:add(nn.Identity()) 43 | net:add(build(nChannels1, nOutChannels/2, 'down', bottleneck, bnWidth1)) 44 | net:add(build(nChannels2, nOutChannels/2, 'normal', bottleneck, bnWidth2)) 45 | net_warpper:add(net):add(nn.JoinTable(2)) 46 | return net_warpper 47 | end 48 | 49 | 50 | ---------------- 51 | --- MSDNet_Layer_first: 52 | --- the input layer of MSDNet 53 | --- input: a tensor (original image) 54 | --- output: a table of `nScales` tensors 55 | ---------------- 56 | 57 | local MSDNet_Layer_first, parent = torch.class('nn.MSDNet_Layer_first', 'nn.Container') 58 | 59 | function MSDNet_Layer_first:__init(nChannels, nOutChannels, opt) 60 | parent.__init(self) 61 | 62 | self.train = true 63 | self.nChannels = nChannels 64 | self.nOutChannels = nOutChannels 65 | 66 | self.opt = opt 67 | 68 | self.modules = {} 69 | -- transform raw input to first layer 70 | self.modules[1] = nn.Sequential() 71 | if opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then 72 | self.modules[1]:add(cudnn.SpatialConvolution(nChannels, nOutChannels*opt.grFactor[1], 3,3, 1,1, 1,1)) 73 | self.modules[1]:add(cudnn.SpatialBatchNormalization(nOutChannels*opt.grFactor[1])) 74 | self.modules[1]:add(cudnn.ReLU(true)) 75 | elseif opt.dataset == 'imagenet' then 76 |
self.modules[1]:add(cudnn.SpatialConvolution(nChannels,nOutChannels*opt.grFactor[1], 7,7, 2,2, 3,3)) 77 | self.modules[1]:add(nn.SpatialBatchNormalization(nOutChannels*opt.grFactor[1])) 78 | self.modules[1]:add(cudnn.ReLU(true)) 79 | self.modules[1]:add(nn.SpatialMaxPooling(3,3,2,2,1,1)) 80 | end 81 | 82 | local nIn = nOutChannels * opt.grFactor[1] 83 | for i = 2, opt.nScales do 84 | self.modules[i] = nn.Sequential() 85 | self.modules[i]:add(cudnn.SpatialConvolution(nIn, nOutChannels*opt.grFactor[i], 3,3, 2,2, 1,1)) 86 | self.modules[i]:add(cudnn.SpatialBatchNormalization(nOutChannels*opt.grFactor[i])) 87 | self.modules[i]:add(cudnn.ReLU(true)) 88 | nIn = nOutChannels*opt.grFactor[i] 89 | end 90 | 91 | self.gradInput = torch.CudaTensor() 92 | self.output = {} 93 | end 94 | 95 | function MSDNet_Layer_first:updateOutput(input) 96 | for i = 1, self.opt.nScales do 97 | self.output[i] = self.output[i] or input.new() 98 | local tmp_input = (i==1) and input or self.output[i-1] 99 | local tmp_output = self:rethrowErrors(self.modules[i], i, 'updateOutput', tmp_input) 100 | self.output[i]:resizeAs(tmp_output):copy(tmp_output) 101 | end 102 | return self.output 103 | end 104 | 105 | function MSDNet_Layer_first:updateGradInput(input, gradOutput) 106 | self.gradInput = self.gradInput or input.new() 107 | self.gradInput:resizeAs(input):zero() 108 | for i = self.opt.nScales-1, 1, -1 do 109 | gradOutput[i]:add(self.modules[i+1]:updateGradInput(self.output[i], gradOutput[i+1])) 110 | end 111 | self.gradInput:resizeAs(input):copy(self.modules[1]:updateGradInput(input, gradOutput[1])) 112 | return self.gradInput 113 | end 114 | 115 | function MSDNet_Layer_first:accGradParameters(input, gradOutput, scale) 116 | scale = scale or 1 117 | for i = self.opt.nScales, 2, -1 do 118 | self.modules[i]:accGradParameters(self.output[i-1], gradOutput[i], scale) 119 | end 120 | self.modules[1]:accGradParameters(input, gradOutput[1], scale) 121 | end 122 | 123 | function 
MSDNet_Layer_first:__tostring__() 124 | local tab = ' ' 125 | local line = '\n' 126 | local next = ' |`-> ' 127 | local lastNext = ' `-> ' 128 | local ext = ' | ' 129 | local extlast = ' ' 130 | local last = ' ... -> ' 131 | local str = 'MSDNet_Layer_first' 132 | str = str .. ' {' .. line .. tab .. '{input}' 133 | for i=1,#self.modules do 134 | if i == #self.modules then 135 | str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) 136 | else 137 | str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) 138 | end 139 | end 140 | str = str .. line .. tab .. last .. '{output}' 141 | str = str .. line .. '}' 142 | return str 143 | end 144 | 145 | ---------------- 146 | --- MSDNet_Layer (without upsampling): 147 | --- subsequent layers of MSDNet 148 | --- input: a table of `nScales` tensors 149 | --- output: a table of `nScales` tensors 150 | ---------------- 151 | 152 | 153 | local MSDNet_Layer, parent = torch.class('nn.MSDNet_Layer', 'nn.Container') 154 | 155 | 156 | function MSDNet_Layer:__init(nChannels, nOutChannels, opt, inScales, outScales) 157 | parent.__init(self) 158 | 159 | self.train = true 160 | self.nChannels = nChannels 161 | self.nOutChannels = nOutChannels 162 | self.opt = opt 163 | self.inScales = inScales or opt.nScales 164 | self.outScales = outScales or opt.nScales 165 | self.nScales = opt.nScales 166 | self.discard = self.inScales - self.outScales 167 | assert(self.discard<=1, 'Double check inScales'..self.inScales..'and outScales: '..self.outScales) 168 | 169 | local offset = self.nScales - self.outScales 170 | self.modules = {} 171 | local isTrans = self.outScales 1 then 251 | self.gradInput[self.inScales]:resizeAs(self.gIn[self.outScales][1]) 252 | :copy(self.gIn[self.outScales][1]):add(self.gIn[self.outScales][3]) 253 | end 254 | 255 | return self.gradInput 256 | 257 | end 258 | 259 | 260 | function 
MSDNet_Layer:accGradParameters(input, gradOutput, scale) 261 | 262 | scale = scale or 1 263 | for i = 1, self.outScales do 264 | self.modules[i]:accGradParameters(self.real_input[i], gradOutput[i], scale) 265 | end 266 | end 267 | 268 | 269 | function MSDNet_Layer:clearState() 270 | -- don't call set because it might reset referenced tensors 271 | local function clear(f) 272 | if self[f] then 273 | if torch.isTensor(self[f]) then 274 | self[f] = self[f].new() 275 | elseif type(self[f]) == 'table' then 276 | self[f] = {} 277 | else 278 | self[f] = nil 279 | end 280 | end 281 | end 282 | clear('output') 283 | clear('gradInput') 284 | clear('real_input') 285 | clear('gIn') 286 | if self.modules then 287 | for i,module in pairs(self.modules) do 288 | module:clearState() 289 | end 290 | end 291 | return self 292 | end 293 | 294 | function MSDNet_Layer:__tostring__() 295 | local tab = ' ' 296 | local line = '\n' 297 | local next = ' |`-> ' 298 | local lastNext = ' `-> ' 299 | local ext = ' | ' 300 | local extlast = ' ' 301 | local last = ' ... -> ' 302 | local str = 'MSDNet_Layer' 303 | str = str .. ' {' .. line .. tab .. '{input}' 304 | for i = 1,#self.modules do 305 | if i == #self.modules then 306 | str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) 307 | else 308 | str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) 309 | end 310 | end 311 | str = str .. line .. tab .. last .. '{output}' 312 | str = str .. line .. '}' 313 | return str 314 | end -------------------------------------------------------------------------------- /models/init.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Generic model creating code. For the specific MSDNet model see 10 | -- models/MSDNet.lua 11 | -- 12 | 13 | require 'nn' 14 | require 'cunn' 15 | require 'cudnn' 16 | require 'models/MSDNet_Layer' 17 | require 'models/JointTrainContainer' 18 | 19 | local M = {} 20 | 21 | function M.setup(opt, checkpoint) 22 | local model 23 | if checkpoint then 24 | local modelPath = paths.concat(opt.resume, checkpoint.modelFile) 25 | assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath) 26 | print('=> Resuming model from ' .. modelPath) 27 | model = torch.load(modelPath):type(opt.tensorType) 28 | elseif opt.retrain ~= 'none' then 29 | assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain) 30 | print('Loading model from file: ' .. opt.retrain) 31 | model = torch.load(opt.retrain):type(opt.tensorType) 32 | print(model) 33 | model.__memoryOptimized = nil 34 | else 35 | print('=> Creating model from file: models/' .. opt.netType .. '.lua') 36 | model = require('models/' .. 
opt.netType)(opt) 37 | end 38 | 39 | -- First remove any DataParallelTable 40 | if torch.type(model) == 'nn.DataParallelTable' then 41 | model = model:get(1) 42 | end 43 | 44 | -- optnet is a general library for reducing memory usage in neural networks 45 | if opt.optnet then 46 | local optnet = require 'optnet' 47 | local imsize = opt.dataset == 'imagenet' and 224 or 32 48 | local sampleInput = torch.zeros(4,3,imsize,imsize):type(opt.tensorType) 49 | optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'}) 50 | end 51 | 52 | -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all 53 | -- containers override backwards to call backwards recursively on submodules 54 | if opt.shareGradInput then 55 | M.shareGradInput(model, opt) 56 | end 57 | 58 | -- Set the CUDNN flags 59 | if opt.cudnn == 'fastest' then 60 | cudnn.fastest = true 61 | cudnn.benchmark = true 62 | elseif opt.cudnn == 'deterministic' then 63 | -- Use a deterministic convolution implementation 64 | model:apply(function(m) 65 | if m.setMode then m:setMode(1, 1, 1) end 66 | end) 67 | end 68 | 69 | -- Wrap the model with DataParallelTable, if using more than one GPU 70 | if opt.nGPU > 1 then 71 | local gpus = torch.range(1, opt.nGPU):totable() 72 | local fastest, benchmark = cudnn.fastest, cudnn.benchmark 73 | 74 | local dpt = nn.DataParallelTable(1, true, true) 75 | :add(model, gpus) 76 | :threads(function() 77 | local cudnn = require 'cudnn' 78 | require 'models/MSDNet_Layer' 79 | require 'models/JointTrainContainer' 80 | cudnn.fastest, cudnn.benchmark = fastest, benchmark 81 | end) 82 | dpt.gradInput = nil 83 | 84 | model = dpt:type(opt.tensorType) 85 | end 86 | 87 | local criterion = nn.CrossEntropyCriterion():type(opt.tensorType) 88 | return model, criterion 89 | end 90 | 91 | function M.shareGradInput(model, opt) 92 | local function sharingKey(m) 93 | local key = torch.type(m) 94 | if m.__shareGradInputKey then 95 | key = key .. ':' ..
m.__shareGradInputKey 96 | end 97 | return key 98 | end 99 | 100 | -- Share gradInput for memory efficient backprop 101 | local cache = {} 102 | model:apply(function(m) 103 | local moduleType = torch.type(m) 104 | if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' then 105 | local key = sharingKey(m) 106 | if cache[key] == nil then 107 | cache[key] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 108 | end 109 | m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[key], 1, 0) 110 | end 111 | end) 112 | for i, m in ipairs(model:findModules('nn.ConcatTable')) do 113 | if cache[i % 2] == nil then 114 | cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) 115 | end 116 | m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0) 117 | end 118 | end 119 | 120 | return M 121 | -------------------------------------------------------------------------------- /models/msdnet.lua: -------------------------------------------------------------------------------- 1 | local nn = require 'nn' 2 | require 'cunn' 3 | require 'models/JointTrainContainer' 4 | 5 | local function createModel(opt) 6 | -- (1) configure 7 | if opt.stepmode == 'even' then 8 | assert(opt.base - opt.step >= 0, 'Base should not be smaller than step!') 9 | end 10 | 11 | local nChannels 12 | if opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then 13 | nChannels = opt.initChannels>0 and opt.initChannels or 16 14 | elseif opt.dataset == 'imagenet' then 15 | nChannels = opt.initChannels>0 and opt.initChannels or 64 16 | else 17 | error('invalid dataset: ' .. opt.dataset) 18 | end 19 | 20 | -- (2) build model 21 | print(' | MSDNet-Block' .. opt.nBlocks.. '-'..opt.step .. ' ' ..
opt.dataset) 22 | local model = nn.Sequential() 23 | model:add(nn.JointTrainModule(nChannels, opt)) 24 | 25 | -- (3) init model 26 | local function ConvInit(name) 27 | for k,v in pairs(model:findModules(name)) do 28 | local n = v.kW*v.kH*v.nOutputPlane 29 | v.weight:normal(0,math.sqrt(2/n)) 30 | if cudnn.version >= 4000 then 31 | v.bias = nil 32 | v.gradBias = nil 33 | else 34 | v.bias:zero() 35 | end 36 | end 37 | end 38 | local function BNInit(name) 39 | for k,v in pairs(model:findModules(name)) do 40 | v.weight:fill(1) 41 | v.bias:zero() 42 | end 43 | end 44 | ConvInit('cudnn.SpatialConvolution') 45 | ConvInit('nn.SpatialConvolution') 46 | BNInit('fbnn.SpatialBatchNormalization') 47 | BNInit('cudnn.SpatialBatchNormalization') 48 | BNInit('nn.SpatialBatchNormalization') 49 | for k,v in pairs(model:findModules('nn.Linear')) do 50 | v.bias:zero() 51 | end 52 | model:type(opt.tensorType) 53 | 54 | if opt.cudnn == 'deterministic' then 55 | model:apply(function(m) 56 | if m.setMode then m:setMode(1,1,1) end 57 | end) 58 | end 59 | 60 | model:get(1).gradInput = nil 61 | 62 | -- (4) save the network definition 63 | local file = io.open(paths.concat(opt.save, 'model_definition.txt'), "w") 64 | for k, v in pairs(opt) do 65 | local s 66 | if v == true then 67 | s = 'true' 68 | elseif v == false then 69 | s = 'false' 70 | else 71 | s = v 72 | end 73 | file:write(tostring(k) .. ': '..tostring(s)..'\n') 74 | end 75 | 76 | file:write('\n model definition \n\n') 77 | file:write(model:__tostring__()) 78 | file:close() 79 | 80 | print(model) 81 | 82 | return model 83 | end 84 | 85 | return createModel 86 | -------------------------------------------------------------------------------- /opts.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 
4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | local M = { } 10 | 11 | function M.parse(arg) 12 | local cmd = torch.CmdLine() 13 | cmd:text() 14 | cmd:text('Torch-7 ResNet Training script') 15 | cmd:text('See https://github.com/facebook/fb.resnet.torch/blob/master/TRAINING.md for examples') 16 | cmd:text() 17 | cmd:text('Options:') 18 | ------------ General options -------------------- 19 | cmd:option('-data', '', 'Path to dataset') 20 | cmd:option('-dataset', 'imagenet', 'Options: imagenet | cifar10 | cifar100') 21 | cmd:option('-manualSeed', 0, 'Manually set RNG seed') 22 | cmd:option('-nGPU', 1, 'Number of GPUs to use by default') 23 | cmd:option('-backend', 'cudnn', 'Options: cudnn | cunn') 24 | cmd:option('-cudnn', 'fastest', 'Options: fastest | default | deterministic') 25 | cmd:option('-gen', 'gen', 'Path to save generated files') 26 | cmd:option('-precision', 'single', 'Options: single | double | half') 27 | ------------- Data options ------------------------ 28 | cmd:option('-nThreads', 2, 'number of data loading threads') 29 | cmd:option('-DataAug', 'true', 'use data augmentation or not') 30 | cmd:option('-validset', 'true', 'use validation set or not') 31 | ------------- Training options -------------------- 32 | cmd:option('-nEpochs', 0, 'Number of total epochs to run') 33 | cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)') 34 | cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)') 35 | cmd:option('-testOnly', 'false', 'Run on validation set only') 36 | cmd:option('-tenCrop', 'false', 'Ten-crop testing') 37 | ------------- Checkpointing options --------------- 38 | cmd:option('-save', 'checkpoints', 'Directory in which to save checkpoints') 39 | cmd:option('-resume', 'none', 'Resume from the latest 
checkpoint in this directory') 40 | 41 | ---------- Optimization options ---------------------- 42 | cmd:option('-LR', 0.1, 'initial learning rate') 43 | cmd:option('-momentum', 0.9, 'momentum') 44 | cmd:option('-weightDecay', 1e-4, 'weight decay') 45 | ---------- Model options ---------------------------------- 46 | cmd:option('-netType', 'MSDNet', '') 47 | cmd:option('-retrain', 'none', 'Path to model to retrain with') 48 | cmd:option('-optimState', 'none', 'Path to an optimState to reload from') 49 | ---------- Memory options ---------------------------------- 50 | cmd:option('-shareGradInput', 'false', 'Share gradInput tensors to reduce memory usage') 51 | cmd:option('-optnet', 'false', 'Use optnet to reduce memory usage') 52 | ---------- MSDNet Model options ---------------------------------- 53 | cmd:option('-base', 4, 'the layer to attach the first classifier') 54 | cmd:option('-nBlocks', 1, 'number of blocks/classifiers') 55 | cmd:option('-stepmode', 'even', 'pattern of span between two adjacent classifiers |even|lin_grow|') 56 | cmd:option('-step', 1, 'span between two adjacent classifiers') 57 | cmd:option('-bottleneck', 'true', 'use 1x1 conv layer or not') 58 | cmd:option('-reduction', 0.5, 'dimension reduction ratio at transition layers') 59 | cmd:option('-growthRate', 6, 'number of output channels for each layer (the first scale)') 60 | cmd:option('-grFactor', '1-2-4-4', 'growth rate factor of each scale') 61 | cmd:option('-prune', 'max', 'specify how to prune the network, min | max') 62 | cmd:option('-joinType', 'concat', 'add or concat for features from different paths') 63 | cmd:option('-bnFactor', '1-2-4-4', 'bottleneck factor of each scale, 4-4-4-4 | 1-2-4-4') 64 | cmd:option('-initChannels', 0, 'number of features produced by the initial conv layer') 65 | 66 | ---------- joint training options ---------------------------------- 67 | cmd:option('-clearstate', 'true', 'save the model after clearState() or not') 68 | cmd:option('-joint_weight', 'uniform',
'weight of different classifiers: uniform | lin_grow | triangle | chi | gauss | exp') 69 | 70 | 71 | ---------- early exit testing options ---------------------------------- 72 | cmd:option('-EEpath', 'none', 'the path to a saved model for early exit calculation') 73 | cmd:option('-EEensemble', 'true', 'use ensemble or not in early exit') 74 | 75 | cmd:text() 76 | 77 | local opt = cmd:parse(arg or {}) 78 | 79 | opt.testOnly = opt.testOnly ~= 'false' 80 | opt.tenCrop = opt.tenCrop ~= 'false' 81 | opt.shareGradInput = opt.shareGradInput ~= 'false' 82 | opt.optnet = opt.optnet ~= 'false' 83 | opt.resetClassifier = opt.resetClassifier ~= 'false' 84 | opt.bottleneck = opt.bottleneck ~= 'false' 85 | opt.DataAug = opt.DataAug ~= 'false' 86 | opt.validset = opt.validset ~= 'false' 87 | opt.override = opt.override ~= 'false' 88 | opt.clearstate = opt.clearstate ~= 'false' 89 | opt.EEensemble = opt.EEensemble ~= 'false' 90 | 91 | -- for logging 92 | opt._grFactor = opt.grFactor 93 | opt._bnFactor = opt.bnFactor 94 | 95 | local bnFactor = {} 96 | for i, s in pairs(opt.bnFactor:split('-')) do 97 | bnFactor[i] = tonumber(s) 98 | end 99 | opt.bnFactor = bnFactor 100 | 101 | local grFactor = {} 102 | for i, s in pairs(opt.grFactor:split('-')) do 103 | grFactor[i] = tonumber(s) 104 | end 105 | opt.grFactor = grFactor 106 | 107 | if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then 108 | cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n') 109 | end 110 | 111 | if opt.dataset == 'imagenet' then 112 | -- Handle the most common case of missing -data flag 113 | local trainDir = paths.concat(opt.data, 'train') 114 | if not paths.dirp(opt.data) then 115 | cmd:error('error: missing ImageNet data directory') 116 | elseif not paths.dirp(trainDir) then 117 | cmd:error('error: ImageNet missing `train` directory: ' ..
trainDir) 118 | end 119 | -- Default shortcutType=B and nEpochs=90 120 | opt.shortcutType = opt.shortcutType == '' and 'B' or opt.shortcutType 121 | opt.nEpochs = opt.nEpochs == 0 and 90 or opt.nEpochs 122 | opt.nScales = 4 123 | elseif opt.dataset == 'cifar10' then 124 | -- Default shortcutType=A and nEpochs=164 125 | opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType 126 | opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs 127 | opt.nScales = 3 128 | elseif opt.dataset == 'cifar100' then 129 | -- Default shortcutType=A and nEpochs=164 130 | opt.shortcutType = opt.shortcutType == '' and 'A' or opt.shortcutType 131 | opt.nEpochs = opt.nEpochs == 0 and 164 or opt.nEpochs 132 | opt.nScales = 3 133 | else 134 | cmd:error('unknown dataset: ' .. opt.dataset) 135 | end 136 | 137 | if opt.precision == nil or opt.precision == 'single' then 138 | opt.tensorType = 'torch.CudaTensor' 139 | elseif opt.precision == 'double' then 140 | opt.tensorType = 'torch.CudaDoubleTensor' 141 | elseif opt.precision == 'half' then 142 | opt.tensorType = 'torch.CudaHalfTensor' 143 | else 144 | cmd:error('unknown precision: ' .. 
opt.precision) 145 | end 146 | 147 | if opt.shareGradInput and opt.optnet then 148 | cmd:error('error: cannot use both -shareGradInput and -optnet') 149 | end 150 | 151 | return opt 152 | end 153 | 154 | return M 155 | -------------------------------------------------------------------------------- /saveTXT.lua: -------------------------------------------------------------------------------- 1 | local save2txt = function(filename, data, precision) 2 | 3 | local precision = precision or 4 4 | 5 | local sz 6 | if torch.type(data) == 'table' and #data > 0 then 7 | local n = #data 8 | local d = #data[1] 9 | if torch.type(d) ~= 'number' then 10 | d = d[1] 11 | end 12 | sz = {n,d} 13 | elseif torch.isTensor(data) then 14 | sz = data:size() 15 | else 16 | print('Input cannot be saved to TXT file'); return -- sz is nil here, so stop before #sz 17 | end 18 | 19 | if #sz == 1 then 20 | local file = io.open(filename..'.txt', "w") 21 | for i = 1, sz[1] do 22 | file:write(string.format('%0.'..precision..'f\t', data[i])) 23 | file:write('\n') 24 | end 25 | file:close() 26 | elseif #sz ==2 then 27 | local file = io.open(filename..'.txt', "w") 28 | for i = 1, sz[1] do 29 | for j = 1, sz[2] do 30 | file:write(string.format('%0.'..precision..'f\t', data[i][j])) 31 | end 32 | file:write('\n') 33 | end 34 | file:close() 35 | else 36 | print('Input cannot be saved to TXT file') 37 | end 38 | end 39 | 40 | return save2txt 41 | -------------------------------------------------------------------------------- /train_joint.lua: -------------------------------------------------------------------------------- 1 | 2 | local optim = require 'optim' 3 | 4 | local M = {} 5 | local Trainer = torch.class('MSDNet.Trainer', M) 6 | 7 | function Trainer:__init(model, criterion, opt, optimState) 8 | self.model = model 9 | self.criterion = criterion 10 | self.optimState = optimState or { 11 | learningRate = opt.LR, 12 | learningRateDecay = 0.0, 13 | momentum = opt.momentum, 14 | nesterov = true, 15 | dampening = 0.0, 16 | weightDecay =
opt.weightDecay, 17 | } 18 | self.opt = opt 19 | self.params, self.gradParams = model:getParameters() 20 | end 21 | 22 | function Trainer:train(epoch, dataloader) 23 | -- Trains the model for a single epoch 24 | self.optimState.learningRate = self:learningRate(epoch) 25 | 26 | local timer = torch.Timer() 27 | local dataTimer = torch.Timer() 28 | 29 | local function feval() 30 | return self.criterion.output, self.gradParams 31 | end 32 | 33 | local trainSize = dataloader:size() 34 | local lossSum = 0.0 35 | local N = 0 36 | 37 | local top1All, top5All = torch.zeros(self.opt.nBlocks), torch.zeros(self.opt.nBlocks) 38 | local top1Evolve, top5Evolve = torch.zeros(self.opt.nBlocks), torch.zeros(self.opt.nBlocks) 39 | 40 | print('=> Training epoch # ' .. epoch) 41 | -- set the batch norm to training mode 42 | self.model:training() 43 | for n, sample in dataloader:run() do 44 | local dataTime = dataTimer:time().real 45 | 46 | -- Copy input and target to the GPU 47 | self:copyInputs(sample) 48 | 49 | local output = self.model:forward(self.input) 50 | local batchSize = output[1]:size(1) 51 | 52 | -- create a table holding nBlocks references to the same target, one per exit 53 | local multi_targets = {} 54 | for i = 1, self.opt.nBlocks do 55 | multi_targets[i] = self.target 56 | end 57 | local loss = self.criterion:forward(self.model.output, multi_targets) 58 | lossSum = lossSum + loss*batchSize -- weight the batch-mean loss so that lossSum / N is a per-sample mean 59 | 60 | self.model:zeroGradParameters() 61 | self.criterion:backward(self.model.output, multi_targets) 62 | self.model:backward(self.input, self.criterion.gradInput) 63 | 64 | optim.sgd(feval, self.params, self.optimState) 65 | 66 | -- store the cumulative softmax() of exits to obtain the ensemble performance 67 | local ensemble = torch.Tensor():resizeAs(output[1]:float()):zero() 68 | local top1, top5 = 0, 0 69 | for i = 1, self.opt.nBlocks do 70 | -- single exit 71 | top1, top5 = self:computeScore(output[i]:float(), sample.target, 1) 72 | top1All[i] = top1All[i] + top1*batchSize 73 | top5All[i] = top5All[i] +
top5*batchSize 74 | -- ensemble 75 | ensemble:add(nn.SoftMax():forward(output[i]:float())) 76 | top1, top5 = self:computeScore(ensemble, sample.target, 1) 77 | top1Evolve[i] = top1Evolve[i] + top1*batchSize 78 | top5Evolve[i] = top5Evolve[i] + top5*batchSize 79 | end 80 | 81 | N = N + batchSize 82 | 83 | print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f top1 %7.3f top5 %7.3f'):format( 84 | epoch, n, trainSize, timer:time().real, dataTime, loss, top1All[self.opt.nBlocks]/N, top5All[self.opt.nBlocks]/N)) 85 | 86 | -- check that the storage didn't get changed due to an unfortunate getParameters call 87 | assert(self.params:storage() == self.model:parameters()[1]:storage()) 88 | 89 | timer:reset() 90 | dataTimer:reset() 91 | end 92 | 93 | for i = 1, self.opt.nBlocks do 94 | top1All[i] = top1All[i] / N 95 | top5All[i] = top5All[i] / N 96 | top1Evolve[i] = top1Evolve[i] / N 97 | top5Evolve[i] = top5Evolve[i] / N 98 | print((' * Train %d exit single top1: %7.3f top5: %7.3f, \t Ensemble %d exit(s) top1: %7.3f top5: %7.3f') 99 | :format(i, top1All[i], top5All[i], i, top1Evolve[i], top5Evolve[i])) 100 | end 101 | 102 | return top1All, top5All, top1Evolve, top5Evolve, lossSum / N 103 | end 104 | 105 | function Trainer:test(epoch, dataloader, prefix) 106 | -- Computes the top-1 and top-5 err on the validation/test set 107 | prefix = prefix or 'Test' 108 | local timer = torch.Timer() 109 | local dataTimer = torch.Timer() 110 | local size = dataloader:size() 111 | 112 | local nCrops = self.opt.tenCrop and 10 or 1 113 | local N = 0 114 | 115 | local top1All, top5All = torch.zeros(self.opt.nBlocks), torch.zeros(self.opt.nBlocks) 116 | local top1Evolve, top5Evolve = torch.zeros(self.opt.nBlocks), torch.zeros(self.opt.nBlocks) 117 | 118 | self.model:evaluate() 119 | for n, sample in dataloader:run() do 120 | local dataTime = dataTimer:time().real 121 | 122 | -- Copy input and target to the GPU 123 | self:copyInputs(sample) 124 | 125 | local output = 
self.model:forward(self.input) 126 | local batchSize = output[1]:size(1) / nCrops 127 | -- create a table holding nBlocks references to the same target, one per exit 128 | local multi_targets = {} 129 | for i = 1, self.opt.nBlocks do 130 | multi_targets[i] = self.target 131 | end 132 | local loss = self.criterion:forward(self.model.output, multi_targets) 133 | 134 | -- store the cumulative softmax() of exits to obtain the ensemble performance 135 | local ensemble = torch.Tensor():resizeAs(output[1]:float()):zero() 136 | local top1, top5 = 0, 0 137 | for i = 1, self.opt.nBlocks do 138 | -- single exit 139 | ensemble = ensemble or torch.Tensor():resizeAs(output[1]:float()):zero() 140 | top1, top5 = self:computeScore(output[i]:float(), sample.target, 1) 141 | top1All[i] = top1All[i] + top1*batchSize 142 | top5All[i] = top5All[i] + top5*batchSize 143 | -- ensemble 144 | ensemble:add(nn.SoftMax():forward(output[i]:float())) 145 | top1, top5 = self:computeScore(ensemble, sample.target, 1) 146 | top1Evolve[i] = top1Evolve[i] + top1*batchSize 147 | top5Evolve[i] = top5Evolve[i] + top5*batchSize 148 | end 149 | ---------------------------------------------------- 150 | N = N + batchSize 151 | 152 | print((' | %s: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (cumul: %7.3f) top5 %7.3f (cumul: %7.3f)'):format( 153 | prefix, epoch, n, size, timer:time().real, dataTime, top1, top1All[self.opt.nBlocks]/N, 154 | top5, top5All[self.opt.nBlocks]/N)) 155 | 156 | timer:reset() 157 | dataTimer:reset() 158 | end 159 | self.model:training() 160 | 161 | for i = 1, self.opt.nBlocks do 162 | top1All[i] = top1All[i] / N 163 | top5All[i] = top5All[i] / N 164 | top1Evolve[i] = top1Evolve[i] / N 165 | top5Evolve[i] = top5Evolve[i] / N 166 | print((' * %s %d exit top1: %7.3f top5: %7.3f, \t Ensemble %d exit(s) top1: %7.3f top5: %7.3f'):format( 167 | prefix, i, top1All[i], top5All[i], i, top1Evolve[i], top5Evolve[i])) 168 | end 169 | 170 | return top1All, top5All, top1Evolve, top5Evolve 171 | end 172 | 173 | function
Trainer:computeScore(output, target, nCrops) 174 | if nCrops > 1 then 175 | -- Sum over crops 176 | output = output:view(output:size(1) / nCrops, nCrops, output:size(2)) 177 | --:exp() 178 | :sum(2):squeeze(2) 179 | end 180 | 181 | -- Computes the top1 and top5 error rate 182 | local batchSize = output:size(1) 183 | 184 | local _, predictions = output:float():topk(5, 2, true, true) -- descending sort 185 | 186 | -- Find which predictions match the target 187 | local correct = predictions:eq( 188 | target:long():view(batchSize, 1):expandAs(predictions)) 189 | 190 | -- Top-1 error 191 | local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize) 192 | 193 | -- Top-5 error, if there are at least 5 classes 194 | local len = math.min(5, correct:size(2)) 195 | local top5 = 1.0 - (correct:narrow(2, 1, len):sum() / batchSize) 196 | 197 | return top1 * 100, top5 * 100 198 | end 199 | 200 | local function getCudaTensorType(tensorType) 201 | if tensorType == 'torch.CudaHalfTensor' then 202 | return cutorch.createCudaHostHalfTensor() 203 | elseif tensorType == 'torch.CudaDoubleTensor' then 204 | return cutorch.createCudaHostDoubleTensor() 205 | else 206 | return cutorch.createCudaHostTensor() 207 | end 208 | end 209 | 210 | function Trainer:copyInputs(sample) 211 | -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory, 212 | -- if using DataParallelTable.
The target is always copied to a CUDA tensor 213 | self.input = self.input or (self.opt.nGPU == 1 214 | and torch[self.opt.tensorType:match('torch.(%a+)')]() 215 | or getCudaTensorType(self.opt.tensorType)) 216 | self.target = self.target or (torch.CudaLongTensor and torch.CudaLongTensor()) 217 | self.input:resize(sample.input:size()):copy(sample.input) 218 | self.target:resize(sample.target:size()):copy(sample.target) 219 | end 220 | 221 | function Trainer:learningRate(epoch) 222 | -- Training schedule 223 | local decay = 0 224 | if self.opt.dataset == 'imagenet' then 225 | decay = math.floor((epoch - 1) / 30) 226 | elseif self.opt.dataset == 'cifar10' then 227 | decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0 228 | elseif self.opt.dataset == 'cifar100' then 229 | decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0 230 | end 231 | return self.opt.LR * math.pow(0.1, decay) 232 | end 233 | 234 | 235 | return M.Trainer --------------------------------------------------------------------------------
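The sketch below is not part of the repository. It restates two pieces of `train_joint.lua` in plain Lua with no torch dependency so they can be read and checked in isolation: the step schedule in `Trainer:learningRate` and the cumulative-softmax ensemble that `Trainer:train` and `Trainer:test` accumulate across exits. The helper names `learningRate`, `softmax`, `argmax` and `exitPredictions` are hypothetical. The sketch also assumes that logits for a single example are plain Lua arrays rather than batched tensors, and it writes `0.1 ^ decay` because `math.pow` has been deprecated since Lua 5.3.

```lua
-- Hypothetical plain-Lua restatement of two pieces of train_joint.lua.

-- Step schedule, mirroring Trainer:learningRate: ImageNet divides the LR by 10
-- every 30 epochs; CIFAR divides it by 10 at 50% and again at 75% of nEpochs.
local function learningRate(opt, epoch)
  local decay = 0
  if opt.dataset == 'imagenet' then
    decay = math.floor((epoch - 1) / 30)
  elseif opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then
    if epoch >= 0.75 * opt.nEpochs then
      decay = 2
    elseif epoch >= 0.5 * opt.nEpochs then
      decay = 1
    end
  end
  return opt.LR * 0.1 ^ decay
end

-- Numerically stable softmax over a plain array of logits.
local function softmax(logits)
  local maxv = -math.huge
  for _, v in ipairs(logits) do
    if v > maxv then maxv = v end
  end
  local probs, sum = {}, 0
  for i, v in ipairs(logits) do
    probs[i] = math.exp(v - maxv)
    sum = sum + probs[i]
  end
  for i = 1, #probs do probs[i] = probs[i] / sum end
  return probs
end

-- Index of the largest entry, i.e. the top-1 class.
local function argmax(t)
  local best = 1
  for i = 2, #t do
    if t[i] > t[best] then best = i end
  end
  return best
end

-- For one example, return the top-1 class of each exit on its own and of
-- the running ensemble (sum of the softmax outputs of exits 1..i).
local function exitPredictions(exitLogits)
  local single, ensemble, acc = {}, {}, {}
  for i, logits in ipairs(exitLogits) do
    single[i] = argmax(logits)
    local p = softmax(logits)
    for c = 1, #p do acc[c] = (acc[c] or 0) + p[c] end
    ensemble[i] = argmax(acc)
  end
  return single, ensemble
end

local single, ensemble = exitPredictions({{2.0, 1.0, 0.0}, {0.0, 0.5, 0.0}})
print(single[2], ensemble[2])  -- prints 2 and 1, tab-separated
```

In this example the first exit strongly favours class 1 and the second exit weakly favours class 2. The second exit alone therefore predicts class 2, while the two-exit ensemble keeps class 1. The `Ensemble %d exit(s)` figures in the training and test logs track exactly this difference between single-exit and ensemble predictions.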