├── License
├── README.md
├── checkpoints.lua
├── dataloader.lua
├── datasets
├── README.md
├── cifar10-gen.lua
├── cifar10.lua
├── cifar100-gen.lua
├── cifar100.lua
├── imagenet-gen.lua
├── imagenet.lua
├── init.lua
└── transforms.lua
├── main.lua
├── models
├── JointTrainContainer.lua
├── MSDNet.lua
├── MSDNet_Layer.lua
├── init.lua
└── msdnet.lua
├── opts.lua
├── saveTXT.lua
└── train_joint.lua
/License:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 Gao Huang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MSDNet
2 |
3 | This repository provides the code for the paper [Multi-Scale Dense Networks for Resource Efficient Image Classification](http://arxiv.org/abs/1703.09844).
4 |
5 | ### Update on April 3, 2019 -- PyTorch implementation released!
6 | [A PyTorch implementation of MSDNet can be found from here.](https://github.com/kalviny/MSDNet-PyTorch)
7 |
8 |
9 | ## Introduction
10 |
11 | This paper studies convolutional networks that require limited computational resources at test time. We develop a new network architecture that performs on par with state-of-the-art convolutional networks, whilst facilitating prediction in two settings: (1) an anytime-prediction setting in which the network's prediction for one example is progressively updated, facilitating the output of a prediction at any time; and (2) a batch computational budget setting in which a fixed amount of computation is available to classify a set of examples that can be spent unevenly across 'easier' and 'harder' examples.
12 |
13 | **Figure 1**: MSDNet layout (2D).
14 |
15 |
16 |
17 | **Figure 2**: MSDNet layout (3D).
18 |
19 |
20 |
21 |
22 | ## Results
23 | ### (a) anytime-prediction setting
24 |
25 | **Figure 3**: Anytime prediction on ImageNet.
26 |
27 |
28 |
29 |
30 | ### (b) batch computational budget setting
31 |
32 | **Figure 4**: Prediction under batch computational budget on ImageNet.
33 |
34 |
35 |
36 | **Figure 5**: Random example images from the ImageNet classes `Red wine` and `Volcano`. Top row: images exited from the first classification layer of an MSDNet with correct prediction; Bottom row: images failed to be correctly classified at the first classifier but were correctly predicted and exited at the last layer.
37 |
38 |
39 |
40 |
41 | ## Usage
42 |
43 | Our code is written under the framework of Torch ResNet (https://github.com/facebook/fb.resnet.torch). The training scripts come with several options, which can be listed with the `--help` flag.
44 |
45 | th main.lua --help
46 |
47 | #### Configuration
48 |
49 | In all the experiments, we use a **validation set** for model selection. We hold out `5000` training images on CIFAR, and
50 | `50000`
51 | images on ImageNet as the validation set.
52 |
53 |
54 | #### Training recipe
55 |
56 | Train an MSDNet with 10 classifiers attached to every other layer for anytime prediction:
57 | ```bash
58 | th main.lua -netType msdnet -dataset cifar10 -batchSize 64 -nEpochs 300 -nBlocks 10 -stepmode even -step 2 -base 4
59 | ```
60 |
61 | Train an MSDNet with 7 classifiers, where the span between classifiers increases linearly, for efficient batch computation:
62 | ```bash
63 | th main.lua -netType msdnet -dataset cifar10 -batchSize 64 -nEpochs 300 -nBlocks 7 -stepmode lin_grow -step 1 -base 1
64 | ```
65 |
66 | #### Pre-trained ImageNet Models
67 |
68 | 1. [Download](https://www.dropbox.com/sh/elnyjl4xdi4zyas/AACxCdjV-RWYrHYfbz61FFDma?dl=0) model checkpoints and the validation set indices.
69 |
70 | 2. Testing script: `th main.lua -dataset imagenet -testOnly true -resume <path-to-checkpoint-folder> -data <path-to-imagenet-data> -gen <path-to-validation-set-indices>`
71 |
72 | ## FAQ
73 |
74 | 1. How to calculate the FLOPs (or mul-add op) of a model?
75 |
76 | We strongly recommend doing it automatically. Please refer to the [op-counter](https://github.com/apaszke/torch-opCounter) project (LuaTorch), or the [script](https://github.com/ShichenLiu/CondenseNet/blob/master/utils.py#L58-L162) in CondenseNet (PyTorch). The basic idea of these op counters is to add a hook before the forward pass of a model.
77 |
78 |
79 |
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/checkpoints.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | local checkpoint = {}
10 |
-- Recursively clones the table structure of `tbl`. Nested tables are copied,
-- every non-table value (e.g. tensors, which are userdata) is shared with the
-- original, and the torch type name of each table, if any, is re-applied to
-- its copy.
local function deepCopy(tbl)
   local clone = {}
   for key, value in pairs(tbl) do
      -- deepCopy always returns a table, so the and/or idiom is safe here
      clone[key] = (type(value) == 'table') and deepCopy(value) or value
   end

   local typeName = torch.typename(tbl)
   if typeName then
      torch.setmetatable(clone, typeName)
   end
   return clone
end
26 |
-- Loads the most recent checkpoint description from the directory opt.resume.
-- Returns nil when resuming is disabled (opt.resume == 'none') or when the
-- directory has no 'latest.t7'; otherwise returns the table stored in
-- 'latest.t7' followed by the optimizer state it refers to.
function checkpoint.latest(opt)
   local resumeDir = opt.resume
   if resumeDir ~= 'none' then
      local infoPath = paths.concat(resumeDir, 'latest.t7')
      if paths.filep(infoPath) then
         print('=> Loading checkpoint ' .. infoPath)
         local info = torch.load(infoPath)
         local optimState = torch.load(paths.concat(resumeDir, info.optimFile))
         return info, optimState
      end
   end
   return nil
end
43 |
-- Writes the model, the optimizer state and a 'latest.t7' index file into
-- opt.save. When isBestModel is truthy the model is additionally stored as
-- 'model_best.t7'.
function checkpoint.save(epoch, model, optimState, isBestModel, opt)
   local function target(name)
      return paths.concat(opt.save, name)
   end

   -- unwrap DataParallelTable so the file is easier to load on other machines
   local net = model
   if torch.type(net) == 'nn.DataParallelTable' then
      net = net:get(1)
   end

   print('save model!')

   -- work on a float copy with cleared state; the live network is untouched
   local snapshot = deepCopy(net):float():clearState()

   local modelFile = 'model.t7'
   local optimFile = 'optimState.t7'

   torch.save(target(modelFile), snapshot)
   torch.save(target(optimFile), optimState)

   local latest = {
      epoch = epoch,
      modelFile = modelFile,
      optimFile = optimFile,
      opt = opt,
   }
   torch.save(target('latest.t7'), latest)

   if isBestModel then
      torch.save(target('model_best.t7'), snapshot)
   end
end
71 |
72 | return checkpoint
73 |
--------------------------------------------------------------------------------
/dataloader.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- Multi-threaded data loader
10 | --
11 |
12 | local datasets = require 'datasets/init'
13 | local Threads = require 'threads'
14 | Threads.serialization('threads.sharedserialize')
15 |
16 | local M = {}
17 | local DataLoader = torch.class('resnet.DataLoader', M)
18 |
-- Creates one loader per data split and returns them as multiple values:
-- train, val, and additionally test when opt.validset is exactly true.
function DataLoader.create(opt)
   local splits = {'train', 'val'}
   if opt.validset == true then
      splits[#splits + 1] = 'test'
   end

   local loaders = {}
   for _, split in ipairs(splits) do
      local dataset = datasets.create(opt, split)
      loaders[#loaders + 1] = M.DataLoader(dataset, opt, split)
   end

   return table.unpack(loaders)
end
37 |
-- Starts a pool of opt.nThreads worker threads for `dataset` and records the
-- batching parameters for `split`.
function DataLoader:__init(dataset, opt, split)
   local manualSeed = opt.manualSeed

   -- first thread callback: load the dataset class definition
   local function loadDatasetModule()
      require('datasets/' .. opt.dataset)
   end

   -- second thread callback: seed the thread, publish the dataset and its
   -- preprocessing function as globals, and report the dataset size
   local function setupWorker(idx)
      if manualSeed ~= 0 then
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end

   local pool, sizes = Threads(opt.nThreads, loadDatasetModule, setupWorker)

   -- ten-crop evaluation only applies to the val/test splits
   local nCrops = 1
   if (split == 'val' or split == 'test') and opt.tenCrop then
      nCrops = 10
   end

   self.opt = opt
   self.split = split
   self.nCrops = nCrops
   self.threads = pool
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / nCrops)
end
61 |
-- Number of batches per pass over the data (the last batch may be partial).
function DataLoader:size()
   local nSamples, perBatch = self.__size, self.batchSize
   return math.ceil(nSamples / perBatch)
end
65 |
-- Returns an iterator over the batches of one pass through the data. Each call
-- yields (n, sample), where sample has the fields input, target and path, and
-- returns nil once all jobs are done.
function DataLoader:run()
   local pool = self.threads
   local nSamples, perBatch = self.__size, self.batchSize

   -- modified to not shuffle data for testing: only 'train' is permuted
   local order
   if self.split == 'train' then
      order = torch.randperm(nSamples)
   else
      order = torch.range(1, nSamples)
   end

   -- Job run on a worker thread: loads and preprocesses the samples listed in
   -- `indices`. It uses only its arguments and the _G.dataset / _G.preprocess
   -- globals of the worker.
   local function loadBatch(indices, nCrops)
      local count = indices:size(1)
      local batch, imageSize
      local target, path = torch.IntTensor(count), {}
      for slot, sampleIdx in ipairs(indices:totable()) do
         local item = _G.dataset:get(sampleIdx)
         local input = _G.preprocess(item.input)
         if not batch then
            -- allocate the batch lazily, once the image size is known
            imageSize = input:size():totable()
            if nCrops > 1 then table.remove(imageSize, 1) end
            batch = torch.FloatTensor(count, nCrops, table.unpack(imageSize))
         end
         batch[slot]:copy(input)
         target[slot] = item.target
         path[slot] = item.path
      end
      collectgarbage()
      return {
         input = batch:view(count * nCrops, table.unpack(imageSize)),
         target = target,
         path = path,
      }
   end

   local cursor, latest = 1, nil

   -- end callback, run on the calling thread: keeps the finished batch
   local function storeResult(result)
      latest = result
   end

   -- submit jobs while the pool accepts them and samples remain
   local function enqueue()
      while cursor <= nSamples and pool:acceptsjob() do
         local take = math.min(perBatch, nSamples - cursor + 1)
         local indices = order:narrow(1, cursor, take)
         pool:addjob(loadBatch, storeResult, indices, self.nCrops)
         cursor = cursor + perBatch
      end
   end

   local delivered = 0
   local function loop()
      enqueue()
      if not pool:hasjob() then
         return nil
      end
      pool:dojob()
      if pool:haserror() then
         pool:synchronize()
      end
      enqueue()
      delivered = delivered + 1
      return delivered, latest
   end

   return loop
end
126 |
127 | return M.DataLoader
128 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | ## Datasets
2 |
3 | Each dataset consists of two files: `dataset-gen.lua` and `dataset.lua`. The `dataset-gen.lua` is responsible for one-time setup, while
4 | the `dataset.lua` handles the actual data loading.
5 |
6 | If you want to be able to use the new dataset from main.lua, you should also modify `opts.lua` to handle the new dataset name.
7 |
8 | ### `dataset-gen.lua`
9 |
10 | The `dataset-gen.lua` performs any necessary one-time setup. For example, the [`cifar10-gen.lua`](cifar10-gen.lua) file downloads the CIFAR-10 dataset, and the [`imagenet-gen.lua`](imagenet-gen.lua) file indexes all the training and validation data.
11 |
12 | The module should have a single function `exec(opt, cacheFile)`.
13 | - `opt`: the command line options
14 | - `cacheFile`: path to output
15 |
16 | ```lua
17 | local M = {}
18 | function M.exec(opt, cacheFile)
19 | local imageInfo = {}
20 | -- preprocess dataset, store results in imageInfo, save to cacheFile
21 | torch.save(cacheFile, imageInfo)
22 | end
23 | return M
24 | ```
25 |
26 | ### `dataset.lua`
27 |
28 | The `dataset.lua` should return a class that implements three functions:
29 | - `get(i)`: returns a table containing two entries, `input` and `target`
30 | - `input`: the training or validation image as a Torch tensor
31 | - `target`: the image category as a number 1-N
32 | - `size()`: returns the number of entries in the dataset
33 | - `preprocess()`: returns a function that transforms the `input` for data augmentation or input normalization
34 |
35 | ```lua
36 | local M = {}
37 | local FakeDataset = torch.class('resnet.FakeDataset', M)
38 |
39 | function FakeDataset:__init(imageInfo, opt, split)
40 | -- imageInfo: result from dataset-gen.lua
41 | -- opt: command-line arguments
42 | -- split: "train" or "val"
43 | end
44 |
45 | function FakeDataset:get(i)
46 | return {
47 | input = torch.Tensor(3, 800, 600):uniform(),
48 | target = 42,
49 | }
50 | end
51 |
52 | function FakeDataset:size()
53 | -- size of dataset
54 | return 2000
55 | end
56 |
57 | function FakeDataset:preprocess()
58 | -- Scale smaller side to 256 and take 224x224 center-crop
59 | return t.Compose{
60 | t.Scale(256),
61 | t.CenterCrop(224),
62 | }
63 | end
64 |
65 | return M.FakeDataset
66 | ```
67 |
--------------------------------------------------------------------------------
/datasets/cifar10-gen.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- Script to download the CIFAR-10 dataset and combine it into a single file
10 | --
11 | -- This automatically downloads the CIFAR-10 dataset from
12 | -- http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz
13 | --
14 |
15 | local URL = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz'
16 |
17 | local M = {}
18 |
-- Loads the given ascii .t7 batch files, concatenates their images and labels,
-- and returns { data = <N x 3 x 32 x 32 view>, labels = <labels shifted by 1> }.
local function convertToTensor(files)
   local images, classes

   for _, fname in ipairs(files) do
      local chunk = torch.load(fname, 'ascii')
      local chunkImages = chunk.data:t()
      local chunkLabels = chunk.labels:squeeze()
      if images then
         images = torch.cat(images, chunkImages, 1)
         classes = torch.cat(classes, chunkLabels)
      else
         images, classes = chunkImages, chunkLabels
      end
   end

   -- This is *very* important. The downloaded files have labels 0-9, which do
   -- not work with CrossEntropyCriterion
   classes:add(1)

   return {
      data = images:contiguous():view(-1, 3, 32, 32),
      labels = classes,
   }
end
42 |
-- Downloads CIFAR-10 into gen/ when the first training batch is missing, then
-- combines the batches and saves them to cacheFile. With opt.validset == true
-- the 50000 training images are shuffled with a fixed seed and the last 5000
-- are stored as 'val' (the test batch becomes 'test'); otherwise the test
-- batch is stored as 'val'.
function M.exec(opt, cacheFile)
   local batchDir = 'gen/cifar-10-batches-t7/'

   if not paths.filep(batchDir .. 'data_batch_1.t7') then
      print("=> Downloading CIFAR-10 dataset from " .. URL)
      local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/')
      assert(ok == true or ok == 0, 'error downloading CIFAR-10')
   end

   print(" | combining dataset into a single file")
   local trainFiles = {}
   for i = 1, 5 do
      trainFiles[i] = batchDir .. 'data_batch_' .. i .. '.t7'
   end
   local trainData = convertToTensor(trainFiles)
   local testData = convertToTensor({ batchDir .. 'test_batch.t7' })
   print(" | saving CIFAR-10 dataset to " .. cacheFile)

   local cache
   if opt.validset == true then
      -- fixed seed so the train/val split is the same on every run
      torch.manualSeed(1)
      local order = torch.randperm(50000):long()
      trainData.data[{ {1, 50000} }] = trainData.data:index(1, order)
      trainData.labels[{ {1, 50000} }] = trainData.labels:index(1, order)

      -- take the validation slices before narrowing the training tensors
      local valData = {
         data = trainData.data[{ {45001, 50000} }],
         labels = trainData.labels[{ {45001, 50000} }],
      }
      trainData.data = trainData.data[{ {1, 45000} }]
      trainData.labels = trainData.labels[{ {1, 45000} }]

      cache = { train = trainData, val = valData, test = testData }
   else
      cache = { train = trainData, val = testData }
   end

   torch.save(cacheFile, cache)
end
88 |
89 | return M
90 |
--------------------------------------------------------------------------------
/datasets/cifar10.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- CIFAR-10 dataset loader
10 | --
11 |
12 | local t = require 'datasets/transforms'
13 |
14 | local M = {}
15 | local CifarDataset = torch.class('resnet.CifarDataset', M)
16 |
-- Keeps the part of imageInfo that belongs to `split`; fails (with the split
-- name as message) when imageInfo has no entry for it.
function CifarDataset:__init(imageInfo, opt, split)
   local splitInfo = imageInfo[split]
   assert(splitInfo, split)
   self.opt = opt
   self.split = split
   self.imageInfo = splitInfo
end
23 |
-- Returns sample i as { input = <image converted to float>, target = <label> }.
function CifarDataset:get(i)
   local info = self.imageInfo
   return {
      input = info.data[i]:float(),
      target = info.labels[i],
   }
end
33 |
-- Number of samples in this split (first dimension of the data tensor).
function CifarDataset:size()
   local data = self.imageInfo.data
   return data:size(1)
end
37 |
-- Computed from entire CIFAR-10 training set
local meanstd = {
   mean = {125.3, 123.0, 113.9},
   std = {63.0, 62.1, 66.7},
}

-- Returns the transform for this split. 'val' and 'test' are only color
-- normalized; 'train' additionally gets a random horizontal flip and a random
-- crop when opt.DataAug is exactly true. Any other split raises an error.
function CifarDataset:preprocess()
   local split = self.split

   if split == 'val' or split == 'test' then
      return t.ColorNormalize(meanstd)
   end
   if split ~= 'train' then
      error('invalid split: ' .. split)
   end

   local steps = { t.ColorNormalize(meanstd) }
   if self.opt.DataAug == true then
      steps[#steps + 1] = t.HorizontalFlip(0.5)
      steps[#steps + 1] = t.RandomCrop(32, 4)
   end
   return t.Compose(steps)
end
63 |
64 | return M.CifarDataset
65 |
--------------------------------------------------------------------------------
/datasets/cifar100-gen.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 |
10 | ------------
11 | -- This file automatically downloads the CIFAR-100 dataset from
12 | -- http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz
13 | -- It is based on cifar10-gen.lua
14 | -- Ludovic Trottier
15 | ------------
16 |
17 | local URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz'
18 |
19 | local M = {}
20 |
-- Reads a CIFAR-100 binary file and returns { data = <N x 3 x 32 x 32
-- ByteTensor>, labels = <fine labels shifted by 1> }.
-- Each record is 3074 bytes: 1 coarse-label byte, 1 fine-label byte and
-- 3072 pixel bytes. The coarse labels are read but not returned.
-- Raises an error when the file size is not a multiple of the record size.
local function convertCifar100BinToTorchTensor(inputFname)
   local m = torch.DiskFile(inputFname, 'r'):binary()

   -- determine the file length by seeking to the end
   m:seekEnd()
   local length = m:position() - 1
   local nSamples = length / 3074 -- 1 coarse-label byte, 1 fine-label byte, 3072 pixel bytes

   if nSamples ~= math.floor(nSamples) then
      -- release the handle before failing; level 0 keeps the message unchanged
      m:close()
      error('expecting numSamples to be an exact integer', 0)
   end
   m:seek(1)

   local coarse = torch.ByteTensor(nSamples)
   local fine = torch.ByteTensor(nSamples)
   local data = torch.ByteTensor(nSamples, 3, 32, 32)
   for i = 1, nSamples do
      coarse[i] = m:readByte()
      fine[i] = m:readByte()
      local store = m:readByte(3072)
      data[i]:copy(torch.ByteTensor(store))
   end

   -- the original never closed the file; do it once everything is read
   m:close()

   local out = {}
   out.data = data
   -- This is *very* important. The downloaded files have fine labels 0-99,
   -- which do not work with CrossEntropyCriterion
   out.labels = fine + 1

   return out
end
48 |
-- Downloads CIFAR-100 into gen/ (only when the extracted binaries are not
-- already there), converts it to tensors and saves it to cacheFile.
--   opt:       command line options; opt.validset == true holds out the last
--              5000 of the shuffled 50000 training images as 'val' and stores
--              the test set as 'test', otherwise the test set is stored as 'val'
--   cacheFile: path of the .t7 file to write
function M.exec(opt, cacheFile)
   local trainBin = 'gen/cifar-100-binary/train.bin'
   local testBin = 'gen/cifar-100-binary/test.bin'

   -- skip the download when a previous run already extracted both files
   -- (mirrors the check done in cifar10-gen.lua)
   if not (paths.filep(trainBin) and paths.filep(testBin)) then
      print("=> Downloading CIFAR-100 dataset from " .. URL)

      local ok = os.execute('curl ' .. URL .. ' | tar xz -C gen/')
      assert(ok == true or ok == 0, 'error downloading CIFAR-100')
   end

   print(" | combining dataset into a single file")

   local trainData = convertCifar100BinToTorchTensor(trainBin)
   local testData = convertCifar100BinToTorchTensor(testBin)

   print(" | saving CIFAR-100 dataset to " .. cacheFile)

   if opt.validset == true then
      -- fixed seed so the train/val split is the same on every run
      torch.manualSeed(1)
      local shuffle = torch.randperm(50000):long()
      trainData.data[{ {1, 50000} }] = trainData.data:index(1, shuffle)
      trainData.labels[{ {1, 50000} }] = trainData.labels:index(1, shuffle)

      -- take the validation slices before narrowing the training tensors
      local valData = {
         data = trainData.data[{ {45001, 50000} }],
         labels = trainData.labels[{ {45001, 50000} }],
      }
      trainData.data = trainData.data[{ {1, 45000} }]
      trainData.labels = trainData.labels[{ {1, 45000} }]

      torch.save(cacheFile, {
         train = trainData,
         val = valData,
         test = testData,
      })
   else
      torch.save(cacheFile, {
         train = trainData,
         val = testData,
      })
   end
end
84 |
85 | return M
--------------------------------------------------------------------------------
/datasets/cifar100.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 |
10 | ------------
11 | -- This file loads and preprocesses CIFAR-100.
12 | -- It is based on cifar10.lua
13 | -- Ludovic Trottier
14 | ------------
15 |
16 | local t = require 'datasets/transforms'
17 |
18 | local M = {}
19 | local CifarDataset = torch.class('resnet.CifarDataset', M)
20 |
-- Stores the options, the split name and the imageInfo entry for `split`.
-- The assertion fails with the split name when that entry is missing.
function CifarDataset:__init(imageInfo, opt, split)
   local entry = imageInfo[split]
   assert(entry, split)
   self.imageInfo, self.split, self.opt = entry, split, opt
end
27 |
-- Fetches sample i: the image converted to float as `input` and its label
-- as `target`.
function CifarDataset:get(i)
   local sample = {}
   sample.input = self.imageInfo.data[i]:float()
   sample.target = self.imageInfo.labels[i]
   return sample
end
37 |
-- Sample count of this split, i.e. the size of the data tensor's first dimension.
function CifarDataset:size()
   local images = self.imageInfo.data
   return images:size(1)
end
41 |
42 |
-- Computed from entire CIFAR-100 training set with this code:
--    dataset = torch.load('cifar100.t7')
--    tt = dataset.train.data:double();
--    tt = tt:transpose(2,4);
--    tt = tt:reshape(50000*32*32, 3);
--    tt:mean(1)
--    tt:std(1)
local meanstd = {
   mean = {129.3, 124.1, 112.4},
   std = {68.2, 65.4, 70.4},
}

-- Builds the input transform for this split: color normalization for
-- 'val'/'test'; for 'train' the same normalization, followed by a random
-- horizontal flip and random crop when opt.DataAug is exactly true.
-- Unknown splits raise an error.
function CifarDataset:preprocess()
   local split = self.split

   if split == 'train' then
      local pipeline = { t.ColorNormalize(meanstd) }
      if self.opt.DataAug == true then
         table.insert(pipeline, t.HorizontalFlip(0.5))
         table.insert(pipeline, t.RandomCrop(32, 4))
      end
      return t.Compose(pipeline)
   end

   if split == 'val' or split == 'test' then
      return t.ColorNormalize(meanstd)
   end

   error('invalid split: ' .. split)
end
74 |
75 | return M.CifarDataset
76 |
--------------------------------------------------------------------------------
/datasets/imagenet-gen.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- Script to compute list of ImageNet filenames and classes
10 | --
11 | -- This generates a file gen/imagenet.t7 which contains the list of all
12 | -- ImageNet training and validation images and their classes. This script also
13 | -- works for other datasets arranged with the same layout.
14 | --
15 |
16 | local sys = require 'sys'
17 | local ffi = require 'ffi'
18 |
19 | local M = {}
20 |
-- Lists the entries of `dir` in sorted order and treats each one (except '.',
-- '..' and '.DS_Store') as a class. Returns the list of class names and a
-- map from class name to its 1-based position in that list.
local function findClasses(dir)
   local ignored = { ['.'] = true, ['..'] = true, ['.DS_Store'] = true }

   local entries = paths.dir(dir)
   table.sort(entries)

   local classList, classToIdx = {}, {}
   for i = 1, #entries do
      local name = entries[i]
      if not ignored[name] and not classToIdx[name] then
         classList[#classList + 1] = name
         classToIdx[name] = #classList
      end
   end

   -- assert(#classList == 1000, 'expected 1000 ImageNet classes')
   return classList, classToIdx
end
38 |
-- Finds all image files below `dir` with the `find` command.
--   dir:        directory whose sub-directories are named after classes
--   classToIdx: map from class (directory) name to class index
-- Returns a CharTensor holding one NUL-terminated '<class>/<file>' path per
-- row and a LongTensor with the class index of every image.
-- Raises an error when `find` cannot be started or when an image lies in a
-- directory that is not a key of classToIdx.
local function findImages(dir, classToIdx)

   ----------------------------------------------------------------------
   -- Options for the GNU and BSD find command
   local extensionList =
      {'jpg', 'png', 'jpeg', 'JPG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
   local patterns = {}
   for i, ext in ipairs(extensionList) do
      patterns[i] = '-iname "*.' .. ext .. '"'
   end
   local findOptions = ' ' .. table.concat(patterns, ' -o ')

   -- Single-quote the directory so that spaces or shell metacharacters in the
   -- path cannot break or alter the command; embedded quotes become '\''.
   local quotedDir = "'" .. (dir:gsub("'", "'\\''")) .. "'"

   -- Find all the images using the find command
   local command = 'find -L ' .. quotedDir .. findOptions
   local f = assert(io.popen(command), 'could not run: ' .. command)

   local maxLength = -1
   local imagePaths = {}
   local imageClasses = {}

   -- Generate a list of all the images and their class
   for line in f:lines() do
      local className = paths.basename(paths.dirname(line))
      local filename = paths.basename(line)
      local path = className .. '/' .. filename

      local classId = classToIdx[className]
      assert(classId, 'class not found: ' .. className)

      table.insert(imagePaths, path)
      table.insert(imageClasses, classId)

      -- +1 leaves room for the terminating NUL written by ffi.copy
      maxLength = math.max(maxLength, #path + 1)
   end

   f:close()

   -- Convert the generated list to a tensor for faster loading
   local nImages = #imagePaths
   local imagePath = torch.CharTensor(nImages, maxLength):zero()
   for i, path in ipairs(imagePaths) do
      ffi.copy(imagePath[i]:data(), path)
   end

   local imageClass = torch.LongTensor(imageClasses)
   return imagePath, imageClass
end
87 |
-- Indexes the images below opt.data/train and opt.data/val, saves the result
-- to cacheFile and returns it. With opt.validset == true, 50000 randomly
-- chosen training images (fixed seed) become the 'val' split and the images
-- from opt.data/val are stored as 'test'; otherwise they are stored as 'val'.
function M.exec(opt, cacheFile)

   local trainDir = paths.concat(opt.data, 'train')
   local valDir = paths.concat(opt.data, 'val')
   assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir)
   assert(paths.dirp(valDir), 'val directory not found: ' .. valDir)

   print("=> Generating list of images")
   local classList, classToIdx = findClasses(trainDir)

   print(" | finding all validation images")
   local valImagePath, valImageClass = findImages(valDir, classToIdx)

   print(" | finding all training images")
   local trainImagePath, trainImageClass = findImages(trainDir, classToIdx)

   local info = {
      basedir = opt.data,
      classList = classList,
   }
   -- images found under opt.data/val
   local officialVal = {
      imagePath = valImagePath,
      imageClass = valImageClass,
   }

   if opt.validset == true then
      torch.manualSeed(1)

      local nAll = trainImageClass:size(1)
      local shuffle = torch.randperm(nAll)
      local heldOut = shuffle:sub(1, 50000):long()
      local kept = shuffle:sub(50001, nAll):long()

      info.val = {
         imagePath = trainImagePath:index(1, heldOut),
         imageClass = trainImageClass:index(1, heldOut),
      }
      info.train = {
         imagePath = trainImagePath:index(1, kept),
         imageClass = trainImageClass:index(1, kept),
      }
      info.test = officialVal
   else
      info.train = {
         imagePath = trainImagePath,
         imageClass = trainImageClass,
      }
      info.val = officialVal
   end

   print(" | saving list of images to " .. cacheFile)
   torch.save(cacheFile, info)
   return info
end
151 |
152 | return M
153 |
--------------------------------------------------------------------------------
/datasets/imagenet.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- ImageNet dataset loader
10 | --
11 |
12 | local image = require 'image'
13 | local paths = require 'paths'
14 | local t = require 'datasets/transforms'
15 | local ffi = require 'ffi'
16 |
17 | local M = {}
18 | local ImagenetDataset = torch.class('resnet.ImagenetDataset', M)
19 |
-- Stores the image list for `split` and works out which directory its images
-- live in. With opt.validset set, 'train' and 'val' both read from
-- opt.data/train and 'test' reads from opt.data/val; without it the split
-- name itself is the sub-directory. The directory must exist.
function ImagenetDataset:__init(imageInfo, opt, split)
   self.imageInfo = imageInfo[split]
   self.opt = opt
   self.split = split

   local dataRoot = opt.data
   if not opt.validset then
      self.dir = paths.concat(dataRoot, split)
   elseif split == 'test' then
      self.dir = paths.concat(dataRoot, 'val')
   elseif split == 'train' or split == 'val' then
      self.dir = paths.concat(dataRoot, 'train')
   end

   assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir)
end
37 |
-- Loads image i from disk and returns { input = <image>, target = <class index> }.
function ImagenetDataset:get(i)
   local info = self.imageInfo

   -- each row of imagePath holds a NUL-terminated relative path
   local relPath = ffi.string(info.imagePath[i]:data())
   local img = self:_loadImage(paths.concat(self.dir, relPath))

   return {
      input = img,
      target = info.imageClass[i],
   }
end
49 |
-- Loads the image at `path` as a 3-channel float tensor.
-- Falls back to decoding the raw file contents when image.load fails, and to
-- an all-zero 3x224x224 tensor when that fails as well, so that a single
-- undecodable image does not abort the run.
-- Raises an error only when the file cannot be opened for the fallback.
function ImagenetDataset:_loadImage(path)
   local ok, input = pcall(image.load, path, 3, 'float')

   -- Sometimes image.load fails because the file extension does not match the
   -- image format. In that case, use image.decompress on a ByteTensor.
   if not ok then
      -- binary mode: the contents are image bytes, not text (the original
      -- used text mode 'r', which can alter the data on some platforms)
      local f = io.open(path, 'rb')
      assert(f, 'Error reading: ' .. tostring(path))
      local data = f:read('*a')
      f:close()

      local b = torch.ByteTensor(string.len(data))
      ffi.copy(b:data(), data, b:size(1))

      ok, input = pcall(image.decompress, b, 3, 'float')
      if not ok then
         -- deliberate best effort: substitute a black image
         input = torch.FloatTensor(3, 224, 224):fill(0)
      end
   end

   return input
end
74 |
-- Number of images in this split (length of the class-index tensor).
function ImagenetDataset:size()
   local classes = self.imageInfo.imageClass
   return classes:size(1)
end
78 |
-- Per-channel statistics, computed from random subset of ImageNet training images
local meanstd = {}
meanstd.mean = { 0.485, 0.456, 0.406 }
meanstd.std = { 0.229, 0.224, 0.225 }

-- Eigenvalues / eigenvectors handed to the Lighting transform in preprocess()
local pca = {}
pca.eigval = torch.Tensor({ 0.2175, 0.0188, 0.0045 })
pca.eigvec = torch.Tensor({
   { -0.5675, 0.7192, 0.4009 },
   { -0.5808, -0.0045, -0.8140 },
   { -0.5836, -0.6948, 0.4203 },
})
92 |
--- Build the transform pipeline for this dataset's split.
-- 'train' gets random-sized crop, color jitter, lighting noise, normalization
-- and random horizontal flip. 'val' and 'test' get scale-to-256, normalization
-- and a 224 crop (ten-crop when opt.tenCrop is set, center crop otherwise).
-- Any other split raises an error.
function ImagenetDataset:preprocess()
   local split = self.split

   if split == 'train' then
      local pipeline = {}
      pipeline[1] = t.RandomSizedCrop(224)
      pipeline[2] = t.ColorJitter({
         brightness = 0.4,
         contrast = 0.4,
         saturation = 0.4,
      })
      pipeline[3] = t.Lighting(0.1, pca.eigval, pca.eigvec)
      pipeline[4] = t.ColorNormalize(meanstd)
      pipeline[5] = t.HorizontalFlip(0.5)
      return t.Compose(pipeline)
   end

   if split ~= 'val' and split ~= 'test' then
      error('invalid split: ' .. self.split)
   end

   local Crop
   if self.opt.tenCrop then
      Crop = t.TenCrop
   else
      Crop = t.CenterCrop
   end
   return t.Compose({
      t.Scale(256),
      t.ColorNormalize(meanstd),
      Crop(224),
   })
end
117 |
118 | return M.ImagenetDataset
119 |
--------------------------------------------------------------------------------
/datasets/init.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- ImageNet and CIFAR-10 datasets
10 | --
11 |
12 | local M = {}
13 |
--- Decide whether the cached image-info file can be reused.
-- In test-only mode the cache is always accepted without being read.
-- Otherwise the cache is rejected only when it records a `basedir` that
-- differs from opt.data.
local function isvalid(opt, cachePath)
   if opt.testOnly == true then
      return true -- don't check the basedir in test mode
   end

   local cached = torch.load(cachePath)
   local recordedBase = cached.basedir
   return not (recordedBase and recordedBase ~= opt.data)
end
22 |
--- Create the dataset object for `split`.
-- The image-info cache lives under opt.gen (with a `_withvalid` suffix when
-- opt.validset is set). If the cache is missing or fails isvalid(), the
-- dataset's `<name>-gen.lua` script is run to (re)generate it. The cache is
-- then loaded and passed to the dataset class from `datasets/<name>`.
function M.create(opt, split)
   local cachePath = opt.validset
      and paths.concat(opt.gen, opt.dataset .. '_withvalid.t7')
      or paths.concat(opt.gen, opt.dataset .. '.t7')

   if not paths.filep(cachePath) or not isvalid(opt, cachePath) then
      -- Create the directory the cache is actually written to; the previous
      -- hard-coded 'gen' only matched when opt.gen happened to be 'gen'.
      paths.mkdir(opt.gen)
      local script = paths.dofile(opt.dataset .. '-gen.lua')
      script.exec(opt, cachePath)
   end
   local imageInfo = torch.load(cachePath)

   local Dataset = require('datasets/' .. opt.dataset)
   return Dataset(imageInfo, opt, split)
end
37 |
38 | return M
39 |
--------------------------------------------------------------------------------
/datasets/transforms.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- Image transforms for data augmentation and input normalization
10 | --
11 |
12 | require 'image'
13 |
14 | local M = {}
15 |
--- Chain a list of transforms into one.
-- Returns a function that feeds its input through each transform in list
-- order (stopping at the first nil entry) and returns the final result.
function M.Compose(transforms)
   return function(input)
      local result = input
      local idx = 1
      local step = transforms[idx]
      while step ~= nil do
         result = step(result)
         idx = idx + 1
         step = transforms[idx]
      end
      return result
   end
end
24 |
--- Per-channel normalization: (x - mean) / std on the first three channels.
-- Works on a clone, so the caller's tensor is left unchanged.
function M.ColorNormalize(meanstd)
   return function(img)
      local normalized = img:clone()
      for channel = 1, 3 do
         local plane = normalized[channel]
         plane:add(-meanstd.mean[channel])
         plane:div(meanstd.std[channel])
      end
      return normalized
   end
end
35 |
-- Scales the smaller edge to `size`, keeping the aspect ratio.
-- `interpolation` defaults to 'bicubic'. The input is returned untouched when
-- its smaller edge already equals `size`.
function M.Scale(size, interpolation)
   interpolation = interpolation or 'bicubic'
   return function(input)
      local height, width = input:size(2), input:size(3)

      local widthAlreadyFits = width <= height and width == size
      local heightAlreadyFits = height <= width and height == size
      if widthAlreadyFits or heightAlreadyFits then
         return input
      end

      if width < height then
         return image.scale(input, size, height/width * size, interpolation)
      end
      return image.scale(input, width/height * size, size, interpolation)
   end
end
51 |
-- Crop a centered size x size patch
function M.CenterCrop(size)
   return function(input)
      local left = math.ceil((input:size(3) - size)/2)
      local top = math.ceil((input:size(2) - size)/2)
      local right, bottom = left + size, top + size
      return image.crop(input, left, top, right, bottom) -- center patch
   end
end
60 |
-- Random size x size crop from a larger image, with optional zero padding
-- of `padding` pixels on every side (default 0) applied before cropping.
function M.RandomCrop(size, padding)
   padding = padding or 0

   return function(input)
      if padding > 0 then
         -- Allocate the padded canvas, zero it, and copy the image into its interior
         local padded = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding)
         padded:zero()
         local interior = padded
            :narrow(2, padding+1, input:size(2))
            :narrow(3, padding+1, input:size(3))
         interior:copy(input)
         input = padded
      end

      local height, width = input:size(2), input:size(3)
      if width == size and height == size then
         return input
      end

      local x1 = torch.random(0, width - size)
      local y1 = torch.random(0, height - size)
      local patch = image.crop(input, x1, y1, x1 + size, y1 + size)
      assert(patch:size(2) == size and patch:size(3) == size, 'wrong crop size')
      return patch
   end
end
86 |
-- Center crop plus the four corner patches, taken from the image and from its
-- horizontal reflection. The ten patches are concatenated along a new first
-- dimension (center, top-left, top-right, bottom-left, bottom-right; original
-- image first, then the flipped one).
function M.TenCrop(size)
   local centerCrop = M.CenterCrop(size)

   return function(input)
      local height, width = input:size(2), input:size(3)

      local views = { input, image.hflip(input) }
      local patches = {}
      for v = 1, #views do
         local img = views[v]
         patches[#patches + 1] = centerCrop(img)
         patches[#patches + 1] = image.crop(img, 0, 0, size, size)
         patches[#patches + 1] = image.crop(img, width-size, 0, width, size)
         patches[#patches + 1] = image.crop(img, 0, height-size, size, height)
         patches[#patches + 1] = image.crop(img, width-size, height-size, width, height)
      end

      -- View as mini-batch: give every patch a leading singleton dimension
      for k = 1, #patches do
         local patch = patches[k]
         patches[k] = patch:view(1, patch:size(1), patch:size(2), patch:size(3))
      end

      return input.cat(patches, 1)
   end
end
111 |
-- Resize so the shorter side is an integer drawn from [minSize, maxSize]
-- (ResNet-style); the longer side follows the aspect ratio, rounded.
function M.RandomScale(minSize, maxSize)
   return function(input)
      local height, width = input:size(2), input:size(3)

      local shortSide = torch.random(minSize, maxSize)
      local newWidth, newHeight = shortSide, shortSide
      if width < height then
         newHeight = torch.round(height / width * newWidth)
      else
         newWidth = torch.round(width / height * newHeight)
      end

      return image.scale(input, newWidth, newHeight, 'bicubic')
   end
end
128 |
-- Random crop covering 8%-100% of the image area with aspect ratio in
-- [3/4, 4/3] (Inception-style), resized to size x size. Up to ten random
-- proposals are tried; if none fits inside the image, falls back to
-- Scale(size) followed by CenterCrop(size).
function M.RandomSizedCrop(size)
   local scale = M.Scale(size)
   local crop = M.CenterCrop(size)

   return function(input)
      for _ = 1, 10 do
         local area = input:size(2) * input:size(3)
         local targetArea = torch.uniform(0.08, 1.0) * area

         local aspectRatio = torch.uniform(3/4, 4/3)
         local w = torch.round(math.sqrt(targetArea * aspectRatio))
         local h = torch.round(math.sqrt(targetArea / aspectRatio))

         -- Randomly swap orientation
         if torch.uniform() < 0.5 then
            w, h = h, w
         end

         local fits = h <= input:size(2) and w <= input:size(3)
         if fits then
            local y1 = torch.random(0, input:size(2) - h)
            local x1 = torch.random(0, input:size(3) - w)

            local patch = image.crop(input, x1, y1, x1 + w, y1 + h)
            assert(patch:size(2) == h and patch:size(3) == w, 'wrong crop size')

            return image.scale(patch, size, size, 'bicubic')
         end
      end

      -- fallback
      return crop(scale(input))
   end
end
164 |
--- Mirror the image horizontally with probability `prob`.
function M.HorizontalFlip(prob)
   return function(input)
      local doFlip = torch.uniform() < prob
      if not doFlip then
         return input
      end
      return image.hflip(input)
   end
end
173 |
--- Rotate by a random angle drawn uniformly from [-deg/2, deg/2) degrees,
-- using bilinear interpolation. A `deg` of 0 leaves the input untouched.
function M.Rotation(deg)
   return function(input)
      if deg == 0 then
         return input
      end
      local radians = (torch.uniform() - 0.5) * deg * math.pi / 180
      return image.rotate(input, radians, 'bilinear')
   end
end
182 |
-- Lighting noise (AlexNet-style PCA-based noise).
-- Draws three coefficients from N(0, alphastd), combines them with the given
-- eigenvalues and eigenvectors into one offset per channel, and adds those
-- offsets to a clone of the input. With alphastd == 0 the input is returned as is.
function M.Lighting(alphastd, eigval, eigvec)
   return function(input)
      if alphastd == 0 then
         return input
      end

      local alpha = torch.Tensor(3):normal(0, alphastd)
      local alphaRows = alpha:view(1, 3):expand(3, 3)
      local eigvalRows = eigval:view(1, 3):expand(3, 3)
      local rgb = eigvec:clone()
         :cmul(alphaRows)
         :cmul(eigvalRows)
         :sum(2)
         :squeeze()

      local shifted = input:clone()
      for channel = 1, 3 do
         shifted[channel]:add(rgb[channel])
      end
      return shifted
   end
end
204 |
--- In-place blend: img1 <- alpha * img1 + (1 - alpha) * img2. Returns img1.
local function blend(img1, img2, alpha)
   img1:mul(alpha)
   img1:add(1 - alpha, img2)
   return img1
end
208 |
--- Write a grayscale version of `img` into `dst` (resized to match) and return it.
-- The first channel is the 0.299/0.587/0.114 weighted sum of the three input
-- channels; the other two channels are copies of it.
local function grayscale(dst, img)
   dst:resizeAs(img)

   local luma = dst[1]
   luma:zero()
   luma:add(0.299, img[1])
   luma:add(0.587, img[2])
   luma:add(0.114, img[3])

   dst[2]:copy(luma)
   dst[3]:copy(luma)
   return dst
end
217 |
--- Random saturation jitter: blends the input (in place) with its grayscale
-- version using a factor drawn uniformly from [1 - var, 1 + var].
function M.Saturation(var)
   local scratch -- grayscale buffer reused across calls

   return function(input)
      if not scratch then
         scratch = input.new()
      end
      grayscale(scratch, input)

      local factor = 1.0 + torch.uniform(-var, var)
      blend(input, scratch, factor)
      return input
   end
end
230 |
--- Random brightness jitter: blends the input (in place) with an all-zero
-- image using a factor drawn uniformly from [1 - var, 1 + var].
function M.Brightness(var)
   local scratch -- zero buffer reused across calls

   return function(input)
      if not scratch then
         scratch = input.new()
      end
      scratch:resizeAs(input)
      scratch:zero()

      local factor = 1.0 + torch.uniform(-var, var)
      blend(input, scratch, factor)
      return input
   end
end
243 |
--- Random contrast jitter: blends the input (in place) with a constant image
-- holding the mean grayscale value, using a factor drawn uniformly from
-- [1 - var, 1 + var].
function M.Contrast(var)
   local scratch -- buffer reused across calls

   return function(input)
      if not scratch then
         scratch = input.new()
      end
      grayscale(scratch, input)
      local meanLuma = scratch[1]:mean()
      scratch:fill(meanLuma)

      local factor = 1.0 + torch.uniform(-var, var)
      blend(input, scratch, factor)
      return input
   end
end
257 |
--- Apply every transform in `ts` exactly once, in a freshly shuffled order per call.
-- Accepts either an image or a table carrying the image in its `img` field.
function M.RandomOrder(ts)
   return function(input)
      local current = input.img or input
      local permutation = torch.randperm(#ts)
      for position = 1, #ts do
         local transform = ts[permutation[position]]
         current = transform(current)
      end
      return current
   end
end
268 |
--- Combine brightness, contrast and saturation jitter, applied in random order.
-- `opt` may set `brightness`, `contrast` and `saturation` strengths; a missing
-- or zero strength disables that jitter. With all three disabled the returned
-- transform is the identity.
function M.ColorJitter(opt)
   local ts = {}

   -- Append factory(strength) when the strength is non-zero
   local function enable(strength, factory)
      if strength ~= 0 then
         ts[#ts + 1] = factory(strength)
      end
   end

   enable(opt.brightness or 0, M.Brightness)
   enable(opt.contrast or 0, M.Contrast)
   enable(opt.saturation or 0, M.Saturation)

   if #ts == 0 then
      return function(input)
         return input
      end
   end

   return M.RandomOrder(ts)
end
291 |
292 | return M
293 |
--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | require 'torch'
10 | require 'paths'
11 | require 'optim'
12 | require 'nn'
13 | require 'sys'
14 | require 'paths'
15 |
16 | local save2txt = require 'saveTXT'
17 | local DataLoader = require 'dataloader'
18 | local models = require 'models/init'
19 | local Trainer = require 'train_joint'
20 | local opts = require 'opts'
21 | local checkpoints = require 'checkpoints'
22 |
-- Default to float tensors and a single CPU thread for this process
torch.setdefaulttensortype('torch.FloatTensor')
torch.setnumthreads(1)

-- Parse the command line and seed the torch and cutorch random generators.
-- NOTE(review): cutorch is not required in this file; presumably it becomes a
-- global as a side effect of models/init requiring cunn -- confirm.
local opt = opts.parse(arg)
torch.manualSeed(opt.manualSeed)
cutorch.manualSeedAll(opt.manualSeed)

-- Load previous checkpoint, if it exists
local checkpoint, optimState = checkpoints.latest(opt)

-- Create model
local model = models.setup(opt, checkpoint)

-- Use parallel criterion for multiple exits:
-- one cross-entropy term per block, each with weight 1
local criterion = nn.ParallelCriterion()
for i = 1, opt.nBlocks do
   criterion:add(nn.CrossEntropyCriterion():type(opt.tensorType), 1)
end

-- Data loading
print('Creating dataloader ...')
local trainLoader, valLoader, testLoader = DataLoader.create(opt)

-- The trainer handles the training loop and evaluation on validation set
local trainer = Trainer(model, criterion, opt, optimState)
48 |
-- Test only, need to specify -save and -retrain.
-- Evaluates once on the validation loader and once on the test loader, prints
-- the per-exit and ensemble top-1 errors, stores them in
-- <opt.save>/anytime_result.t7 and stops the script.
if opt.testOnly then
   -- (Optional FLOPs reporting is disabled:
   --    local flops = torch.load(paths.concat(opt.save, 'flops.t7'))
   --    opt.nBlocks = #flops)
   -- trainer:test returns top-5 values in positions 2 and 4; they are unused here.
   local valSingleErr, _, valEnsembleErr = trainer:test(0, valLoader)
   local testSingleErr, _, testEnsembleErr = trainer:test(0, testLoader)

   print('results from: ' .. opt.save)
   print(
      '\nval single: \n', valSingleErr,
      '\n val ensemble: \n', valEnsembleErr,
      '\n test single:\n', testSingleErr,
      '\n test ensemble:\n', testEnsembleErr)

   local resultFile = paths.concat(opt.save, 'anytime_result.t7')
   torch.save(resultFile, {valSingleErr, valEnsembleErr, testSingleErr, testEnsembleErr})
   return
end
66 |
-- Initialize some parameters
local paramsize = trainer.params:size(1)
print('Parameters:', paramsize)
-- Save an epoch-0 checkpoint before any training happens
checkpoints.save(0, model, trainer.optimState, false, opt)
-- Per-epoch result tables, indexed by epoch number
local valSingle, valEnsemble, testSingle, testEnsemble = {}, {}, {}, {}
-- Resume after the checkpointed epoch when one was loaded
local startEpoch = checkpoint and checkpoint.epoch + 1 or opt.epochNumber
local bestTop1 = math.huge
-- NOTE(review): bestTop5 is never updated below, so the final summary always
-- prints it as math.huge -- confirm whether top-5 tracking was intended.
local bestTop5 = math.huge
local timer = torch.Timer()

-- Training epochs
for epoch = startEpoch, opt.nEpochs do

   -- Train for a single epoch
   timer:reset()
   trainer:train(epoch, trainLoader)

   -- Run model on validation set (first and third return values are kept)
   local valTop1All, _, valTop1Evolve = trainer:test(epoch, valLoader)
   valSingle[epoch] = valTop1All
   valEnsemble[epoch] = valTop1Evolve

   -- Run model on test set
   if opt.validset == true then
      local testTop1All, _, testTop1Evolve = trainer:test(epoch, testLoader)
      testSingle[epoch] = testTop1All
      testEnsemble[epoch] = testTop1Evolve
   end

   -- Log results to text file
   local filename = paths.concat(opt.save, 'result_')
   save2txt(filename..'valSingle', valSingle)
   save2txt(filename..'valEnsemble', valEnsemble)
   if opt.validset == true then
      save2txt(filename..'testSingle', testSingle)
      save2txt(filename..'testEnsemble', testEnsemble)
   end

   -- Checkpoint best model, judged by the validation ensemble value at index opt.nBlocks
   local bestModel = false
   if valEnsemble[epoch][opt.nBlocks] < bestTop1 then
      bestModel = true
      bestTop1 = valEnsemble[epoch][opt.nBlocks]
      print(' * Best model ', valEnsemble[epoch][opt.nBlocks])
      -- testEnsemble[epoch] is nil here unless opt.validset == true
      torch.save(paths.concat(opt.save, 'best_result_ensemble.t7'), {valEnsemble[epoch], testEnsemble[epoch]})
   end
   -- Save on improvement, on the last epoch, and on every epoch for imagenet
   if bestModel or epoch == opt.nEpochs or opt.dataset == 'imagenet' then
      checkpoints.save(epoch, model, trainer.optimState, bestModel, opt)
   end
end

-- Done
print(string.format(' * Finished top1: %6.3f top5: %6.3f', bestTop1, bestTop5))
120 |
--------------------------------------------------------------------------------
/models/JointTrainContainer.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'cudnn'
3 | require 'cunn'
4 | require 'models/MSDNet_Layer'
5 |
6 |
--- Build a transition layer: one 1x1 conv + batch norm + ReLU branch per output scale.
-- Branch i maps nIn * grFactor[offset + i] channels to nOut * grFactor[offset + i].
-- Returns an nn.ParallelTable holding `outScales` branches.
local function build_transition(nIn, nOut, outScales, offset, opt)
   local net = nn.ParallelTable()
   for i = 1, outScales do
      local factor = opt.grFactor[offset + i]
      local inWidth, outWidth = nIn * factor, nOut * factor

      local branch = nn.Sequential()
      branch:add(cudnn.SpatialConvolution(inWidth, outWidth, 1,1, 1,1, 0,0))
      branch:add(cudnn.SpatialBatchNormalization(outWidth))
      branch:add(cudnn.ReLU(true))

      net:add(branch)
   end
   return net
end
21 |
22 |
--- Build one block of `step` MSDNet layers.
-- nChannels:     channel count entering the block
-- step:          number of layers in this block
-- layer_all:     total number of layers in the whole network
-- layer_tillnow: number of layers built before this block (0 for the first block)
-- Returns the block (nn.Sequential) and the channel count after its last layer.
-- Raises an error when opt.prune is neither 'min' nor 'max'.
local function build_block(nChannels, opt, step, layer_all, layer_tillnow)
   local block = nn.Sequential()
   if layer_tillnow == 0 then -- layer_tillnow == 0 means we are at the first block
      block:add(nn.MSDNet_Layer_first(3, nChannels, opt))
   end

   local nIn = nChannels
   for _ = 1, step do
      layer_tillnow = layer_tillnow + 1

      -- Number of scales entering and leaving this layer, per pruning mode
      local inScales, outScales
      if opt.prune == 'min' then
         inScales = math.min(opt.nScales, layer_all - layer_tillnow + 2)
         outScales = math.min(opt.nScales, layer_all - layer_tillnow + 1)
      elseif opt.prune == 'max' then
         local interval = torch.ceil(layer_all/opt.nScales)
         inScales = opt.nScales - torch.floor((math.max(0, layer_tillnow -2))/interval)
         outScales = opt.nScales - torch.floor((layer_tillnow -1)/interval)
      else
         error('Unknown prune option!')
      end
      print('|', 'inScales ', inScales, 'outScales ', outScales , '|')

      block:add(nn.MSDNet_Layer(nIn, opt.growthRate, opt, inScales, outScales))
      nIn = nIn + opt.growthRate

      -- 'max': a transition follows every layer that drops a scale.
      -- 'min': a transition follows the layers at 1/3 and 2/3 of the network depth.
      -- Either way only when opt.reduction > 0.
      local needsTransition
      if opt.prune == 'max' then
         needsTransition = inScales > outScales and opt.reduction > 0
      else
         needsTransition = opt.reduction > 0
            and (layer_tillnow == torch.floor(layer_all/3) or layer_tillnow == torch.floor(2*layer_all/3))
      end

      if needsTransition then
         local offset = opt.nScales - outScales
         local reduced = math.floor(opt.reduction*nIn)
         block:add(build_transition(nIn, reduced, outScales, offset, opt))
         nIn = reduced
         print('|', 'Transition layer inserted!', '\t\t|')
      end
   end
   return block, nIn
end
60 |
--- Classifier head for CIFAR: two stride-2 3x3 conv + BN + ReLU stages with
-- 128 channels each, a 2x2 average pool, a reshape to a 128-vector and a
-- linear layer producing `nClass` outputs.
local function build_classifier_cifar(nChannels, nClass)
   local width1, width2 = 128, 128

   local head = nn.Sequential()
   -- stage 1
   head:add(cudnn.SpatialConvolution(nChannels, width1,3,3,2,2,1,1))
   head:add(cudnn.SpatialBatchNormalization(width1))
   head:add(cudnn.ReLU(true))
   -- stage 2
   head:add(cudnn.SpatialConvolution(width1, width2,3,3,2,2,1,1))
   head:add(cudnn.SpatialBatchNormalization(width2))
   head:add(cudnn.ReLU(true))
   -- pooling and linear classifier
   head:add(cudnn.SpatialAveragePooling(2,2))
   head:add(nn.Reshape(width2))
   head:add(nn.Linear(width2, nClass))
   return head
end
75 |
--- Classifier head for ImageNet: two stride-2 3x3 conv + BN + ReLU stages that
-- keep the channel count, a 2x2 average pool, a reshape to an nChannels-vector
-- and a linear layer producing `nClass` outputs.
-- nClass defaults to 1000. The parameter used to be ignored in favour of a
-- hard-coded 1000.
local function build_classifier_imagenet(nChannels, nClass)
   nClass = nClass or 1000

   local c = nn.Sequential()
   c:add(cudnn.SpatialConvolution(nChannels, nChannels,3,3,2,2,1,1))
   c:add(cudnn.SpatialBatchNormalization(nChannels))
   c:add(cudnn.ReLU(true))
   c:add(cudnn.SpatialConvolution(nChannels, nChannels,3,3,2,2,1,1))
   c:add(cudnn.SpatialBatchNormalization(nChannels))
   c:add(cudnn.ReLU(true))
   c:add(cudnn.SpatialAveragePooling(2,2))
   c:add(nn.Reshape(nChannels))
   c:add(nn.Linear(nChannels, nClass))
   return c
end
89 |
90 | local JointTrainModule, parent = torch.class('nn.JointTrainModule', 'nn.Container')
91 |
--- Build the multi-exit network.
-- self.modules[1..nBlocks] are the feature blocks and
-- self.modules[nBlocks+1..2*nBlocks] the matching classifiers.
-- The first block has opt.base layers. Block i >= 2 has opt.step layers
-- ('even') or opt.step*(i-1)+1 layers ('lin_grow'). The per-block layer counts
-- are also saved to <opt.save>/layer_specific.t7.
-- Raises an error for an unknown opt.stepmode (when nBlocks > 1) or an
-- unknown opt.dataset.
function JointTrainModule:__init(nChannels, opt)
   parent.__init(self)

   self.train = true
   self.nChannels = nChannels
   self.nBlocks = opt.nBlocks
   self.opt = opt

   local nIn = nChannels
   self.modules = {}

   -- calculate the step size of each blocks
   local layer_curr, layer_all = 0, opt.base
   local steps = {}
   steps[1] = opt.base
   for i = 2, self.nBlocks do
      -- Explicit branches instead of an and/or chain, which silently produced
      -- `false` for an unrecognised stepmode and failed later on arithmetic.
      if opt.stepmode == 'even' then
         steps[i] = opt.step
      elseif opt.stepmode == 'lin_grow' then
         steps[i] = opt.step*(i-1)+1
      else
         error('Unknown stepmode: ' .. tostring(opt.stepmode))
      end
      layer_all = layer_all + steps[i]
   end
   print("building network of steps: ")
   print(steps)
   torch.save(paths.concat(opt.save, 'layer_specific.t7'), steps)

   for i = 1, self.nBlocks do
      print(' ----------------------- Block ' .. i .. ' -----------------------')
      self.modules[i], nIn = build_block(nIn, opt, steps[i], layer_all, layer_curr)
      layer_curr = layer_curr + steps[i]

      -- The classifier reads the last entry of the block's output table, whose
      -- width is nIn scaled by the growth factor of the last scale.
      local featureWidth = nIn*opt.grFactor[opt.nScales]
      if opt.dataset == 'cifar10' then
         self.modules[i+self.nBlocks] = build_classifier_cifar(featureWidth, 10)
      elseif opt.dataset == 'cifar100' then
         self.modules[i+self.nBlocks] = build_classifier_cifar(featureWidth, 100)
      elseif opt.dataset == 'imagenet' then
         self.modules[i+self.nBlocks] = build_classifier_imagenet(featureWidth, 1000)
      else
         error('Unknown dataset!')
      end
   end
   self.gradInput = {}
   self.output = {}

end
133 |
--- Forward pass. Blocks run in sequence, each on the previous block's output
-- (the first on `input`). After every block its classifier is applied to the
-- last entry of the block's output table. Returns the table of nBlocks
-- classifier outputs.
function JointTrainModule:updateOutput(input)
   local nBlocks = self.nBlocks
   local blockInput = input

   for i = 1, nBlocks do
      local features = self:rethrowErrors(self.modules[i], i, 'updateOutput', blockInput)
      local classifierIdx = i + nBlocks
      local prediction = self:rethrowErrors(self.modules[classifierIdx], classifierIdx, 'updateOutput', features[#features])
      self.output[i] = prediction

      -- The next block consumes this block's stored output
      blockInput = self.modules[i].output
   end

   return self.output
end
143 |
--- Backward pass w.r.t. the input, walking the blocks from last to first.
-- gradOutput[i] is the gradient for classifier i's output. For every block
-- the classifier gradient is added to the gradient of the block's last output
-- entry before the block itself is back-propagated.
-- Returns the first block's gradInput.
function JointTrainModule:updateGradInput(input, gradOutput)

   local nScales = self.opt.nScales
   for i = self.nBlocks, 1, -1 do
      -- The classifier of block i consumed the last entry of the block's output table
      local features = self.modules[i].output[#self.modules[i].output]
      self.modules[i+self.nBlocks]:updateGradInput(features, gradOutput[i])
      local gOut = {}
      if i == self.nBlocks then
         -- Last block: nothing downstream, start from zero gradients shaped
         -- like each existing output entry
         for s = 1, nScales do
            local out = self.modules[i].output[s]
            if out then
               gOut[s] = out.new():resizeAs(out):zero()
            end
         end
      else
         -- Earlier blocks: reuse the next block's gradInput table. The add
         -- below therefore modifies that table's last tensor in place.
         gOut = self.modules[i+1].gradInput
      end
      -- Accumulate the classifier's gradient onto the last entry
      gOut[#gOut]:add(self.modules[i+self.nBlocks].gradInput)
      local tmp_input = (i==1) and input or self.modules[i-1].output
      self.modules[i]:updateGradInput(tmp_input, gOut)
   end
   self.gradInput = self.modules[1].gradInput

   return self.gradInput
end
169 |
--- Accumulate parameter gradients for every classifier and block, last to first.
-- `scale` defaults to 1. Relies on the gradInput fields that the sub-modules
-- hold from a preceding updateGradInput call.
function JointTrainModule:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   for i = self.nBlocks, 1, -1 do
      -- Classifier i: input is the last entry of block i's output table
      local features = self.modules[i].output[#self.modules[i].output]
      self.modules[i+self.nBlocks]:accGradParameters(features, gradOutput[i], scale)
      local tmp_input = (i==1) and input or self.modules[i-1].output
      local gOut
      if i == self.nBlocks then
         -- NOTE(review): the last block is handed its own output table as
         -- gradOutput, whereas updateGradInput used zero tensors plus the
         -- classifier gradient -- confirm this is intentional/harmless.
         gOut = self.modules[i].output
      else
         gOut = self.modules[i+1].gradInput
      end
      self.modules[i]:accGradParameters(tmp_input, gOut, scale)
   end
end
185 |
--- Render the container and its sub-modules as an indented tree string.
function JointTrainModule:__tostring__()
   local tab = ' '
   local line = '\n'
   local branchArrow = ' |`-> '
   local lastArrow = ' `-> '
   local branchPad = ' | '
   local lastPad = ' '
   local tail = ' ... -> '

   local str = 'JointTrainModule' .. ' {' .. line .. tab .. '{input}'
   local count = #self.modules
   for i = 1, count do
      local isLast = (i == count)
      local arrow = isLast and lastArrow or branchArrow
      local pad = isLast and lastPad or branchPad
      -- Indent the continuation lines of a multi-line sub-module description
      local description = tostring(self.modules[i]):gsub(line, line .. tab .. pad)
      str = str .. line .. tab .. arrow .. '(' .. i .. '): ' .. description
   end
   str = str .. line .. tab .. tail .. '{output}'
   str = str .. line .. '}'
   return str
end
207 |
208 |
--------------------------------------------------------------------------------
/models/MSDNet.lua:
--------------------------------------------------------------------------------
1 | msdnet.lua
--------------------------------------------------------------------------------
/models/MSDNet_Layer.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'cudnn'
3 | require 'cunn'
4 |
--- Build one conv unit: optional 1x1 bottleneck, then a 3x3 conv, BN and ReLU.
-- type:       'normal' (stride 1), 'down' (stride 2) or 'up' (stride-2 full convolution)
-- bottleneck: when exactly true, prepend a 1x1 conv + BN + ReLU that narrows
--             the input to min(nChannels, bnWidth * nOutChannels) channels
-- bnWidth:    bottleneck width multiplier, default 4
-- Raises an error for any other `type`.
local function build(nChannels, nOutChannels, type, bottleneck, bnWidth)
   bnWidth = bnWidth or 4

   local net = nn.Sequential()
   local innerChannels = nChannels
   if bottleneck == true then
      innerChannels = math.min(innerChannels, bnWidth * nOutChannels)
      net:add(cudnn.SpatialConvolution(nChannels, innerChannels, 1,1, 1,1, 0,0))
      net:add(cudnn.SpatialBatchNormalization(innerChannels))
      net:add(cudnn.ReLU(true))
   end

   local conv
   if type == 'normal' then
      conv = cudnn.SpatialConvolution(innerChannels, nOutChannels, 3,3, 1,1, 1,1)
   elseif type == 'down' then
      conv = cudnn.SpatialConvolution(innerChannels, nOutChannels, 3,3, 2,2, 1,1)
   elseif type == 'up' then
      conv = cudnn.SpatialFullConvolution(innerChannels, nOutChannels, 3,3, 2,2, 1,1, 1,1)
   else
      error("Please implement me: " .. type)
   end
   net:add(conv)

   net:add(cudnn.SpatialBatchNormalization(nOutChannels))
   net:add(cudnn.ReLU(true))
   return net
end
28 |
--- Two-branch unit over a table of two inputs: the first passes through
-- unchanged, the second goes through a 'normal' conv unit; the two results
-- are joined along dimension 2.
local function build_net_normal(nChannels, nOutChannels, bottleneck, bnWidth)
   local branches = nn.ParallelTable()
   branches:add(nn.Identity())
   branches:add(build(nChannels, nOutChannels, 'normal', bottleneck, bnWidth))

   local wrapper = nn.Sequential()
   wrapper:add(branches)
   wrapper:add(nn.JoinTable(2))
   return wrapper
end
37 |
--- Three-branch unit over a table of three inputs: identity, a 'down' conv
-- unit and a 'normal' conv unit, each conv producing nOutChannels/2 channels;
-- the three results are joined along dimension 2.
-- Asserts that nOutChannels is even.
local function build_net_down_normal(nChannels1, nChannels2, nOutChannels, bottleneck, bnWidth1, bnWidth2)
   assert(nOutChannels % 2 == 0, 'Growth rate invalid!')
   local half = nOutChannels/2

   local branches = nn.ParallelTable()
   branches:add(nn.Identity())
   branches:add(build(nChannels1, half, 'down', bottleneck, bnWidth1))
   branches:add(build(nChannels2, half, 'normal', bottleneck, bnWidth2))

   local wrapper = nn.Sequential()
   wrapper:add(branches)
   wrapper:add(nn.JoinTable(2))
   return wrapper
end
48 |
49 |
50 | ----------------
51 | --- MSDNet_Layer_first:
52 | --- the input layer of MSDNet
53 | --- input: a tensor (original image)
54 | --- output: a table of nScale tensors
55 | ----------------
56 |
57 | local MSDNet_Layer_first, parent = torch.class('nn.MSDNet_Layer_first', 'nn.Container')
58 |
--- Build the input layer: one sub-network per scale.
-- modules[1] turns the raw input into the first-scale features (3x3 stride-1
-- conv for CIFAR; 7x7 stride-2 conv followed by 3x3 stride-2 max pooling for
-- ImageNet). modules[i], i >= 2, is a stride-2 3x3 conv + BN + ReLU fed by
-- the previous scale. Scale i has nOutChannels * opt.grFactor[i] channels.
function MSDNet_Layer_first:__init(nChannels, nOutChannels, opt)
   parent.__init(self)

   self.train = true
   self.nChannels = nChannels
   self.nOutChannels = nOutChannels
   self.opt = opt
   self.modules = {}

   -- transform raw input to first layer
   local firstWidth = nOutChannels*opt.grFactor[1]
   local stem = nn.Sequential()
   self.modules[1] = stem
   local isCifar = opt.dataset == 'cifar10' or opt.dataset == 'cifar100'
   if isCifar then
      stem:add(cudnn.SpatialConvolution(nChannels, firstWidth, 3,3, 1,1, 1,1))
      stem:add(cudnn.SpatialBatchNormalization(firstWidth))
      stem:add(cudnn.ReLU(true))
   elseif opt.dataset == 'imagenet' then
      stem:add(cudnn.SpatialConvolution(nChannels,firstWidth, 7,7, 2,2, 3,3))
      stem:add(nn.SpatialBatchNormalization(firstWidth))
      stem:add(cudnn.ReLU(true))
      stem:add(nn.SpatialMaxPooling(3,3,2,2,1,1))
   end

   -- each further scale downsamples the previous one
   local nIn = firstWidth
   for i = 2, opt.nScales do
      local width = nOutChannels*opt.grFactor[i]
      local down = nn.Sequential()
      self.modules[i] = down
      down:add(cudnn.SpatialConvolution(nIn, width, 3,3, 2,2, 1,1))
      down:add(cudnn.SpatialBatchNormalization(width))
      down:add(cudnn.ReLU(true))
      nIn = width
   end

   self.gradInput = torch.CudaTensor()
   self.output = {}
end
94 |
--- Forward pass: scale 1 is computed from `input`, every further scale from
-- the stored output of the scale before it. Each result is copied into
-- self.output[i]. Returns the table of opt.nScales tensors.
function MSDNet_Layer_first:updateOutput(input)
   local source = input
   for scaleIdx = 1, self.opt.nScales do
      if not self.output[scaleIdx] then
         self.output[scaleIdx] = input.new()
      end
      local result = self:rethrowErrors(self.modules[scaleIdx], scaleIdx, 'updateOutput', source)
      local slot = self.output[scaleIdx]
      slot:resizeAs(result)
      slot:copy(result)
      source = slot
   end
   return self.output
end
104 |
--- Backward pass w.r.t. the input image.
-- Walks the scales from last to first; the gradient flowing back through
-- scale i+1 is added in place to gradOutput[i]. The accumulated gradOutput[1]
-- is then pushed through the first sub-network.
function MSDNet_Layer_first:updateGradInput(input, gradOutput)
   if not self.gradInput then
      self.gradInput = input.new()
   end
   self.gradInput:resizeAs(input)
   self.gradInput:zero()

   for scaleIdx = self.opt.nScales-1, 1, -1 do
      local fromFiner = self.modules[scaleIdx+1]:updateGradInput(self.output[scaleIdx], gradOutput[scaleIdx+1])
      gradOutput[scaleIdx]:add(fromFiner)
   end

   local inputGrad = self.modules[1]:updateGradInput(input, gradOutput[1])
   self.gradInput:resizeAs(input)
   self.gradInput:copy(inputGrad)
   return self.gradInput
end
114 |
--- Accumulate parameter gradients for every scale's sub-network, last scale
-- first. Scale i >= 2 takes the stored output of scale i-1 as its input;
-- scale 1 takes `input`. `scale` defaults to 1.
function MSDNet_Layer_first:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   for scaleIdx = self.opt.nScales, 1, -1 do
      local source
      if scaleIdx == 1 then
         source = input
      else
         source = self.output[scaleIdx-1]
      end
      self.modules[scaleIdx]:accGradParameters(source, gradOutput[scaleIdx], scale)
   end
end
122 |
--- Render the layer and its per-scale sub-networks as an indented tree string.
function MSDNet_Layer_first:__tostring__()
   local tab = ' '
   local line = '\n'
   local branchArrow = ' |`-> '
   local lastArrow = ' `-> '
   local branchPad = ' | '
   local lastPad = ' '
   local tail = ' ... -> '

   local str = 'MSDNet_Layer_first' .. ' {' .. line .. tab .. '{input}'
   local count = #self.modules
   for i = 1, count do
      local isLast = (i == count)
      local arrow = isLast and lastArrow or branchArrow
      local pad = isLast and lastPad or branchPad
      -- Indent the continuation lines of a multi-line sub-module description
      local description = tostring(self.modules[i]):gsub(line, line .. tab .. pad)
      str = str .. line .. tab .. arrow .. '(' .. i .. '): ' .. description
   end
   str = str .. line .. tab .. tail .. '{output}'
   str = str .. line .. '}'
   return str
end
144 |
145 | ----------------
146 | --- MSDNet_Layer (without upsampling):
147 | --- subsequent layers of MSDNet
148 | --- input: a table of `nScales` tensors
149 | --- output: a table of `nScales` tensors
150 | ----------------
151 |
152 |
153 | local MSDNet_Layer, parent = torch.class('nn.MSDNet_Layer', 'nn.Container')
154 |
155 |
156 | function MSDNet_Layer:__init(nChannels, nOutChannels, opt, inScales, outScales)
157 | parent.__init(self)
158 |
159 | self.train = true
160 | self.nChannels = nChannels
161 | self.nOutChannels = nOutChannels
162 | self.opt = opt
163 | self.inScales = inScales or opt.nScales
164 | self.outScales = outScales or opt.nScales
165 | self.nScales = opt.nScales
166 | self.discard = self.inScales - self.outScales
167 | assert(self.discard<=1, 'Double check inScales'..self.inScales..'and outScales: '..self.outScales)
168 |
169 | local offset = self.nScales - self.outScales
170 | self.modules = {}
171 | local isTrans = self.outScales 1 then
251 | self.gradInput[self.inScales]:resizeAs(self.gIn[self.outScales][1])
252 | :copy(self.gIn[self.outScales][1]):add(self.gIn[self.outScales][3])
253 | end
254 |
255 | return self.gradInput
256 |
257 | end
258 |
259 |
--- Accumulate parameter gradients for each of the outScales sub-networks,
-- using the per-scale inputs kept in self.real_input. `scale` defaults to 1.
function MSDNet_Layer:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   local scaleIdx = 1
   while scaleIdx <= self.outScales do
      local subnet = self.modules[scaleIdx]
      subnet:accGradParameters(self.real_input[scaleIdx], gradOutput[scaleIdx], scale)
      scaleIdx = scaleIdx + 1
   end
end
267 |
268 |
--- Drop cached intermediate state (output, gradInput, real_input, gIn) and
-- clear every sub-module. Tensors are replaced by fresh empty tensors of the
-- same type, tables by empty tables, anything else by nil. Returns self.
function MSDNet_Layer:clearState()
   -- don't call set because it might reset referenced tensors
   local cachedFields = { 'output', 'gradInput', 'real_input', 'gIn' }
   for _, name in ipairs(cachedFields) do
      local value = self[name]
      if value then
         if torch.isTensor(value) then
            self[name] = value.new()
         elseif type(value) == 'table' then
            self[name] = {}
         else
            self[name] = nil
         end
      end
   end

   if self.modules then
      for _, module in pairs(self.modules) do
         module:clearState()
      end
   end
   return self
end
293 |
--- Render the layer and its per-scale sub-networks as an indented tree string.
function MSDNet_Layer:__tostring__()
   local tab = ' '
   local line = '\n'
   local branchArrow = ' |`-> '
   local lastArrow = ' `-> '
   local branchPad = ' | '
   local lastPad = ' '
   local tail = ' ... -> '

   local str = 'MSDNet_Layer' .. ' {' .. line .. tab .. '{input}'
   local count = #self.modules
   for i = 1, count do
      local isLast = (i == count)
      local arrow = isLast and lastArrow or branchArrow
      local pad = isLast and lastPad or branchPad
      -- Indent the continuation lines of a multi-line sub-module description
      local description = tostring(self.modules[i]):gsub(line, line .. tab .. pad)
      str = str .. line .. tab .. arrow .. '(' .. i .. '): ' .. description
   end
   str = str .. line .. tab .. tail .. '{output}'
   str = str .. line .. '}'
   return str
end
--------------------------------------------------------------------------------
/models/init.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | -- Generic model creating code. For the specific MSDNet model see
10 | -- models/MSDNet.lua
11 | --
12 |
13 | require 'nn'
14 | require 'cunn'
15 | require 'cudnn'
16 | require 'models/MSDNet_Layer'
17 | require 'models/JointTrainContainer'
18 |
19 | local M = {}
20 |
--- Create, load or resume the network and build its training criterion.
-- The model comes from (in priority order) a checkpoint, the file named by
-- opt.retrain, or the constructor returned by models/<opt.netType>.lua.
-- @param opt        options table (reads resume, retrain, netType, tensorType,
--                   optnet, dataset, shareGradInput, cudnn, nGPU)
-- @param checkpoint optional table whose modelFile names the saved model
-- @return model, criterion
function M.setup(opt, checkpoint)
   local net
   if checkpoint then
      local savedPath = paths.concat(opt.resume, checkpoint.modelFile)
      assert(paths.filep(savedPath), 'Saved model not found: ' .. savedPath)
      print('=> Resuming model from ' .. savedPath)
      net = torch.load(savedPath):type(opt.tensorType)
   elseif opt.retrain ~= 'none' then
      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
      print('Loading model from file: ' .. opt.retrain)
      net = torch.load(opt.retrain):type(opt.tensorType)
      print(net)
      net.__memoryOptimized = nil
   else
      print('=> Creating model from file: models/' .. opt.netType .. '.lua')
      local build = require('models/' .. opt.netType)
      net = build(opt)
   end

   -- Unwrap a DataParallelTable so the bare network is configured below
   if torch.type(net) == 'nn.DataParallelTable' then
      net = net:get(1)
   end

   -- optnet is a general library for reducing memory usage in neural networks
   if opt.optnet then
      local optnet = require 'optnet'
      local side = 32
      if opt.dataset == 'imagenet' then
         side = 224
      end
      local probe = torch.zeros(4, 3, side, side):type(opt.tensorType)
      optnet.optimizeMemory(net, probe, {inplace = false, mode = 'training'})
   end

   -- Sharing gradInput buffers requires that all containers override
   -- backward to call backward recursively on their sub-modules
   if opt.shareGradInput then
      M.shareGradInput(net, opt)
   end

   -- Set the CUDNN flags
   local cudnnMode = opt.cudnn
   if cudnnMode == 'fastest' then
      cudnn.fastest = true
      cudnn.benchmark = true
   elseif cudnnMode == 'deterministic' then
      -- Use a deterministic convolution implementation
      net:apply(function(layer)
         if layer.setMode then
            layer:setMode(1, 1, 1)
         end
      end)
   end

   -- Wrap the model with DataParallelTable, if using more than one GPU
   if opt.nGPU > 1 then
      local deviceIds = torch.range(1, opt.nGPU):totable()
      local useFastest, useBenchmark = cudnn.fastest, cudnn.benchmark

      -- Runs in every worker thread: load the custom layers and copy the
      -- cudnn flags chosen above
      local function initWorker()
         local cudnn = require 'cudnn'
         require 'models/MSDNet_Layer'
         require 'models/JointTrainContainer'
         cudnn.fastest, cudnn.benchmark = useFastest, useBenchmark
      end

      local parallel = nn.DataParallelTable(1, true, true)
         :add(net, deviceIds)
         :threads(initWorker)
      parallel.gradInput = nil

      net = parallel:type(opt.tensorType)
   end

   local criterion = nn.CrossEntropyCriterion():type(opt.tensorType)
   return net, criterion
end
90 |
--- Make modules of the same type share one gradInput storage to save memory.
-- Fixes over the previous version:
--   * the function is defined with '.', so there is no `self`; the tensor
--     type is now read from the `opt` argument instead of `self.opt`;
--   * the storage is created by calling the storage constructor once with
--     size 1 (the old code called the constructor and then called the
--     resulting storage);
--   * the ConcatTable storage follows opt.tensorType instead of a hard-coded
--     torch.CudaStorage, and its gradInput is built with the tensor
--     constructor rather than the storage constructor.
-- @param model network whose modules get shared gradInput buffers
-- @param opt   options table; opt.tensorType is e.g. 'torch.CudaTensor'
function M.shareGradInput(model, opt)
   local tensorName = opt.tensorType:match('torch.(%a+)')
   -- parentheses drop gsub's second return value (the substitution count)
   local storageName = (tensorName:gsub('Tensor', 'Storage'))
   local newTensor = torch[tensorName]
   local newStorage = torch[storageName]

   local function sharingKey(m)
      local key = torch.type(m)
      if m.__shareGradInputKey then
         key = key .. ':' .. m.__shareGradInputKey
      end
      return key
   end

   -- Share gradInput for memory efficient backprop
   local cache = {}
   model:apply(function(m)
      local moduleType = torch.type(m)
      if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' then
         local key = sharingKey(m)
         if cache[key] == nil then
            cache[key] = newStorage(1)
         end
         -- zero-sized tensor at offset 1 over the shared storage
         m.gradInput = newTensor(cache[key], 1, 0)
      end
   end)

   -- ConcatTables alternate between two storages, selected by the parity of
   -- their position in the findModules result
   for i, m in ipairs(model:findModules('nn.ConcatTable')) do
      local slot = i % 2
      if cache[slot] == nil then
         cache[slot] = newStorage(1)
      end
      m.gradInput = newTensor(cache[slot], 1, 0)
   end
end
119 |
120 | return M
121 |
--------------------------------------------------------------------------------
/models/msdnet.lua:
--------------------------------------------------------------------------------
1 | local nn = require 'nn'
2 | require 'cunn'
3 | require 'models/JointTrainContainer'
4 |
--- Build an MSDNet model, initialise its weights and dump its definition.
-- Writes <opt.save>/model_definition.txt containing every option followed by
-- the printed model. Opening that file is now checked, so a failure reports
-- the path and reason instead of an "attempt to index a nil value" error.
-- NOTE(review): `cudnn` is read as a global here; this file does not require
-- it itself — confirm it is always loaded before this function runs.
-- @param opt options table (reads stepmode, base, step, dataset, initChannels,
--            nBlocks, tensorType, cudnn, save)
-- @return the constructed model
local function createModel(opt)
   -- (1) configure
   if opt.stepmode == 'even' then
      assert(opt.base - opt.step >= 0, 'Base should not be smaller than step!')
   end

   local nChannels
   if opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then
      nChannels = opt.initChannels>0 and opt.initChannels or 16
   elseif opt.dataset == 'imagenet' then
      nChannels = opt.initChannels>0 and opt.initChannels or 64
   else
      error('invalid dataset: ' .. opt.dataset)
   end

   -- (2) build model
   print(' | MSDNet-Block' .. opt.nBlocks.. '-'..opt.step .. ' ' .. opt.dataset)
   local model = nn.Sequential()
   model:add(nn.JointTrainModule(nChannels, opt))

   -- (3) init model
   local function ConvInit(name)
      for k,v in pairs(model:findModules(name)) do
         local n = v.kW*v.kH*v.nOutputPlane
         v.weight:normal(0,math.sqrt(2/n))
         if cudnn.version >= 4000 then
            -- convolution biases are dropped entirely in this case
            v.bias = nil
            v.gradBias = nil
         else
            v.bias:zero()
         end
      end
   end
   local function BNInit(name)
      for k,v in pairs(model:findModules(name)) do
         v.weight:fill(1)
         v.bias:zero()
      end
   end
   ConvInit('cudnn.SpatialConvolution')
   ConvInit('nn.SpatialConvolution')
   BNInit('fbnn.SpatialBatchNormalization')
   BNInit('cudnn.SpatialBatchNormalization')
   BNInit('nn.SpatialBatchNormalization')
   for k,v in pairs(model:findModules('nn.Linear')) do
      v.bias:zero()
   end
   model:type(opt.tensorType)

   if opt.cudnn == 'deterministic' then
      model:apply(function(m)
         if m.setMode then m:setMode(1,1,1) end
      end)
   end

   model:get(1).gradInput = nil

   -- (4) save the network definition
   local definitionPath = paths.concat(opt.save, 'model_definition.txt')
   local file, openErr = io.open(definitionPath, "w")
   assert(file, 'unable to write model definition: ' .. tostring(openErr))
   for k, v in pairs(opt) do
      -- tostring already renders booleans as 'true' / 'false'
      file:write(tostring(k) .. ': '..tostring(v)..'\n')
   end

   file:write('\n model definition \n\n')
   file:write(model:__tostring__())
   file:close()

   print(model)

   return model
end
84 |
85 | return createModel
86 |
--------------------------------------------------------------------------------
/opts.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- Copyright (c) 2016, Facebook, Inc.
3 | -- All rights reserved.
4 | --
5 | -- This source code is licensed under the BSD-style license found in the
6 | -- LICENSE file in the root directory of this source tree. An additional grant
7 | -- of patent rights can be found in the PATENTS file in the same directory.
8 | --
9 | local M = { }
10 |
--- Parse command-line arguments into the options table used by the project.
-- Besides the declared options this derives: boolean flags from their
-- 'true'/'false' strings, numeric bnFactor/grFactor lists (the raw strings
-- are kept in _bnFactor/_grFactor for logging), per-dataset defaults
-- (nEpochs, nScales) and opt.tensorType from opt.precision.
-- @param arg list of command-line arguments (defaults to an empty list)
-- @return the populated options table
function M.parse(arg)
   local cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Torch-7 ResNet Training script')
   cmd:text('See https://github.com/facebook/fb.resnet.torch/blob/master/TRAINING.md for examples')
   cmd:text()
   cmd:text('Options:')
   ------------ General options --------------------
   cmd:option('-data', '', 'Path to dataset')
   cmd:option('-dataset', 'imagenet', 'Options: imagenet | cifar10 | cifar100')
   cmd:option('-manualSeed', 0, 'Manually set RNG seed')
   cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
   cmd:option('-backend', 'cudnn', 'Options: cudnn | cunn')
   cmd:option('-cudnn', 'fastest', 'Options: fastest | default | deterministic')
   cmd:option('-gen', 'gen', 'Path to save generated files')
   cmd:option('-precision', 'single', 'Options: single | double | half')
   ------------- Data options ------------------------
   cmd:option('-nThreads', 2, 'number of data loading threads')
   cmd:option('-DataAug', 'true', 'use data augmentation or not')
   cmd:option('-validset', 'true', 'use validation set or not')
   ------------- Training options --------------------
   cmd:option('-nEpochs', 0, 'Number of total epochs to run')
   cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)')
   cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)')
   cmd:option('-testOnly', 'false', 'Run on validation set only')
   cmd:option('-tenCrop', 'false', 'Ten-crop testing')
   ------------- Checkpointing options ---------------
   cmd:option('-save', 'checkpoints', 'Directory in which to save checkpoints')
   cmd:option('-resume', 'none', 'Resume from the latest checkpoint in this directory')

   ---------- Optimization options ----------------------
   cmd:option('-LR', 0.1, 'initial learning rate')
   cmd:option('-momentum', 0.9, 'momentum')
   cmd:option('-weightDecay', 1e-4, 'weight decay')
   ---------- Model options ----------------------------------
   cmd:option('-netType', 'MSDNet', '')
   cmd:option('-retrain', 'none', 'Path to model to retrain with')
   cmd:option('-optimState', 'none', 'Path to an optimState to reload from')
   ---------- Model options ----------------------------------
   cmd:option('-shareGradInput', 'false', 'Share gradInput tensors to reduce memory usage')
   cmd:option('-optnet', 'false', 'Use optnet to reduce memory usage')
   ---------- MSDNet MOdel options ----------------------------------
   cmd:option('-base', 4, 'the layer to attach the first classifier')
   cmd:option('-nBlocks', 1, 'number of blocks/classifiers')
   cmd:option('-stepmode', 'even', 'patten of span between two adjacent classifers |even|lin_grow|')
   cmd:option('-step', 1, 'span between two adjacent classifers')
   cmd:option('-bottleneck', 'true', 'use 1x1 conv layer or not')
   cmd:option('-reduction', 0.5, 'dimension reduction ratio at transition layers')
   cmd:option('-growthRate', 6, 'number of output channels for each layer (the first scale)')
   cmd:option('-grFactor', '1-2-4-4', 'growth rate factor of each sacle')
   cmd:option('-prune', 'max', 'specify how to prune the network, min | max')
   cmd:option('-joinType', 'concat', 'add or concat for features from different paths')
   cmd:option('-bnFactor', '1-2-4-4', 'bottleneck factor of each sacle, 4-4-4-4 | 1-2-4-4')
   cmd:option('-initChannels', 0, 'number of features produced by the initial conv layer')

   ---------- joint training options ----------------------------------
   cmd:option('-clearstate', 'true', 'save a model with clearsate or not')
   cmd:option('-joint_weight', 'uniform', 'weight of differnt classifiers: uniform | lin_grow | triangle | chi | gauss | exp')


   ---------- early exit testing options ----------------------------------
   cmd:option('-EEpath', 'none', 'the path to a saved model for early exit calculation')
   cmd:option('-EEensemble', 'true', 'use ensemble or not in early exit')

   cmd:text()

   local opt = cmd:parse(arg or {})

   -- String flags become booleans: anything other than the literal 'false'
   -- counts as true.
   -- NOTE(review): resetClassifier and override are not declared as options
   -- above, so they are nil here and therefore end up true — confirm whether
   -- that is intended.
   local flagNames = {
      'testOnly', 'tenCrop', 'shareGradInput', 'optnet', 'resetClassifier',
      'bottleneck', 'DataAug', 'validset', 'override', 'clearstate',
      'EEensemble',
   }
   for _, name in ipairs(flagNames) do
      opt[name] = opt[name] ~= 'false'
   end

   -- Turns a dash-separated spec such as '1-2-4-4' into a list of numbers
   local function parseFactors(spec)
      local factors = {}
      for idx, token in pairs(spec:split('-')) do
         factors[idx] = tonumber(token)
      end
      return factors
   end

   -- for logging
   opt._grFactor = opt.grFactor
   opt._bnFactor = opt.bnFactor

   opt.bnFactor = parseFactors(opt.bnFactor)
   opt.grFactor = parseFactors(opt.grFactor)

   if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then
      cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n')
   end

   local dataset = opt.dataset
   if dataset == 'imagenet' then
      -- Handle the most common case of missing -data flag
      local trainDir = paths.concat(opt.data, 'train')
      if not paths.dirp(opt.data) then
         cmd:error('error: missing ImageNet data directory')
      elseif not paths.dirp(trainDir) then
         cmd:error('error: ImageNet missing `train` directory: ' .. trainDir)
      end
      -- Default shortcutType=B and nEpochs=90
      if opt.shortcutType == '' then
         opt.shortcutType = 'B'
      end
      if opt.nEpochs == 0 then
         opt.nEpochs = 90
      end
      opt.nScales = 4
   elseif dataset == 'cifar10' or dataset == 'cifar100' then
      -- Default shortcutType=A and nEpochs=164
      if opt.shortcutType == '' then
         opt.shortcutType = 'A'
      end
      if opt.nEpochs == 0 then
         opt.nEpochs = 164
      end
      opt.nScales = 3
   else
      cmd:error('unknown dataset: ' .. opt.dataset)
   end

   -- Map the precision name to the CUDA tensor type; nil means 'single'
   local tensorTypes = {
      single = 'torch.CudaTensor',
      double = 'torch.CudaDoubleTensor',
      half = 'torch.CudaHalfTensor',
   }
   local precisionName = opt.precision
   if precisionName == nil then
      precisionName = 'single'
   end
   local tensorType = tensorTypes[precisionName]
   if tensorType then
      opt.tensorType = tensorType
   else
      cmd:error('unknown precision: ' .. opt.precision)
   end

   if opt.shareGradInput and opt.optnet then
      cmd:error('error: cannot use both -shareGradInput and -optnet')
   end

   return opt
end
153 |
154 | return M
155 |
--------------------------------------------------------------------------------
/saveTXT.lua:
--------------------------------------------------------------------------------
--- Write numeric data to `<filename>.txt` as tab-terminated fixed-point text.
-- Accepted inputs: a 1-D or 2-D tensor, a non-empty table of rows (each row
-- a table or 1-D tensor), or a non-empty flat table of numbers. One output
-- line is written per row / per element.
-- Anything else (including an empty table) prints a message and returns
-- without touching the file system; previously that path crashed while
-- taking the length of a nil size. A flat table of numbers also crashed
-- before and is now written one value per line.
-- @param filename  output path without the '.txt' extension
-- @param data      tensor or table as described above
-- @param precision digits after the decimal point, defaults to 4
local save2txt = function(filename, data, precision)
   precision = precision or 4

   local sz
   if torch.type(data) == 'table' and #data > 0 then
      local first = data[1]
      if type(first) == 'number' then
         -- flat list of numbers
         sz = {#data}
      else
         local d = #first
         -- for a tensor row `#` does not yield a plain number; take its
         -- first entry as the row length
         if torch.type(d) ~= 'number' then
            d = d[1]
         end
         sz = {#data, d}
      end
   elseif torch.isTensor(data) then
      sz = data:size()
   end

   if sz == nil or (#sz ~= 1 and #sz ~= 2) then
      print('Input cannot be saved to TXT file')
      return
   end

   local fmt = '%0.' .. precision .. 'f\t'
   local file, openErr = io.open(filename .. '.txt', "w")
   assert(file, 'cannot open ' .. filename .. '.txt: ' .. tostring(openErr))

   -- pcall so the handle is closed even when formatting a value fails
   local ok, writeErr = pcall(function()
      for i = 1, sz[1] do
         if #sz == 1 then
            file:write(string.format(fmt, data[i]))
         else
            for j = 1, sz[2] do
               file:write(string.format(fmt, data[i][j]))
            end
         end
         file:write('\n')
      end
   end)
   file:close()
   if not ok then
      error(writeErr, 0)
   end
end
39 |
40 | return save2txt
41 |
--------------------------------------------------------------------------------
/train_joint.lua:
--------------------------------------------------------------------------------
1 |
2 | local optim = require 'optim'
3 |
4 | local M = {}
5 | local Trainer = torch.class('MSDNet.Trainer', M)
6 |
--- Set up a trainer around a model and its criterion.
-- When no optimState is supplied, a default SGD configuration is built from
-- opt.LR, opt.momentum and opt.weightDecay. The flattened parameters and
-- gradients returned by model:getParameters() are stored on the trainer.
-- @param model      network to train
-- @param criterion  loss applied to the model output
-- @param opt        options table
-- @param optimState optional optimizer state to reuse
function Trainer:__init(model, criterion, opt, optimState)
   if not optimState then
      optimState = {
         learningRate = opt.LR,
         learningRateDecay = 0.0,
         momentum = opt.momentum,
         nesterov = true,
         dampening = 0.0,
         weightDecay = opt.weightDecay,
      }
   end

   self.opt = opt
   self.model = model
   self.criterion = criterion
   self.optimState = optimState

   local flatParams, flatGrads = model:getParameters()
   self.params = flatParams
   self.gradParams = flatGrads
end
21 |
--- Train the model for a single epoch.
-- For every exit i of the network this tracks the top-1/top-5 error of that
-- exit alone and of the ensemble formed by summing the softmax outputs of
-- exits 1..i.
-- Fix: the running loss is now weighted by the batch size before being
-- divided by the number of samples N; previously per-batch losses were
-- summed unweighted and then divided by N, shrinking the reported loss by
-- roughly the batch size.
-- NOTE(review): this assumes the criterion returns a per-sample average over
-- the batch — confirm against the criterion built by the caller.
-- @param epoch      epoch number (used for the learning-rate schedule and logs)
-- @param dataloader object providing :size() and :run()
-- @return top1All, top5All, top1Evolve, top5Evolve (tensors of length
--         nBlocks) and the mean training loss
function Trainer:train(epoch, dataloader)
   -- Trains the model for a single epoch
   self.optimState.learningRate = self:learningRate(epoch)

   local timer = torch.Timer()
   local dataTimer = torch.Timer()

   -- optim.sgd only needs the already-computed loss and gradients
   local function feval()
      return self.criterion.output, self.gradParams
   end

   local trainSize = dataloader:size()
   local lossSum = 0.0
   local N = 0
   local nBlocks = self.opt.nBlocks

   local top1All, top5All = torch.zeros(nBlocks), torch.zeros(nBlocks)
   local top1Evolve, top5Evolve = torch.zeros(nBlocks), torch.zeros(nBlocks)

   -- one softmax module reused for every exit; its result is consumed by
   -- ensemble:add before the next forward call
   local softmax = nn.SoftMax()

   print('=> Training epoch # ' .. epoch)
   -- set the batch norm to training mode
   self.model:training()
   for n, sample in dataloader:run() do
      local dataTime = dataTimer:time().real

      -- Copy input and target to the GPU
      self:copyInputs(sample)

      local output = self.model:forward(self.input)
      local batchSize = output[1]:size(1)

      -- create a table which contains `nBlocks' same targets
      local multi_targets = {}
      for i = 1, nBlocks do
         multi_targets[i] = self.target
      end
      local loss = self.criterion:forward(self.model.output, multi_targets)
      lossSum = lossSum + loss * batchSize

      self.model:zeroGradParameters()
      self.criterion:backward(self.model.output, multi_targets)
      self.model:backward(self.input, self.criterion.gradInput)

      optim.sgd(feval, self.params, self.optimState)

      -- store the cumulative softmax() of exits to obtain the ensemble performance
      local ensemble = torch.Tensor():resizeAs(output[1]:float()):zero()
      local top1, top5 = 0, 0
      for i = 1, nBlocks do
         -- single exit
         top1, top5 = self:computeScore(output[i]:float(), sample.target, 1)
         top1All[i] = top1All[i] + top1*batchSize
         top5All[i] = top5All[i] + top5*batchSize
         -- ensemble
         ensemble:add(softmax:forward(output[i]:float()))
         top1, top5 = self:computeScore(ensemble, sample.target, 1)
         top1Evolve[i] = top1Evolve[i] + top1*batchSize
         top5Evolve[i] = top5Evolve[i] + top5*batchSize
      end

      N = N + batchSize

      print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f top1 %7.3f top5 %7.3f'):format(
         epoch, n, trainSize, timer:time().real, dataTime, loss, top1All[nBlocks]/N, top5All[nBlocks]/N))

      -- check that the storage didn't get changed due to an unfortunate getParameters call
      assert(self.params:storage() == self.model:parameters()[1]:storage())

      timer:reset()
      dataTimer:reset()
   end

   for i = 1, nBlocks do
      top1All[i] = top1All[i] / N
      top5All[i] = top5All[i] / N
      top1Evolve[i] = top1Evolve[i] / N
      top5Evolve[i] = top5Evolve[i] / N
      print((' * Train %d exit single top1: %7.3f top5: %7.3f, \t Ensemble %d exit(s) top1: %7.3f top5: %7.3f')
         :format(i, top1All[i], top5All[i], i, top1Evolve[i], top5Evolve[i]))
   end

   return top1All, top5All, top1Evolve, top5Evolve, lossSum / N
end
104 |
--- Compute top-1 and top-5 error of every exit on a validation/test set.
-- Fixes:
--   * the crop count is now passed to computeScore. The batch size was
--     already divided by nCrops, but scores were computed with a hard-coded
--     crop count of 1, which cannot match the targets under ten-crop testing;
--   * the criterion forward pass (whose result is unused) is only run for
--     single-crop evaluation, where behavior is unchanged.
-- NOTE(review): under ten-crop the outputs are assumed to hold
-- batchSize * nCrops rows while the targets hold batchSize entries — confirm
-- against the data loader.
-- @param epoch      epoch number, used for logging only
-- @param dataloader object providing :size() and :run()
-- @param prefix     label used in log lines, defaults to 'Test'
-- @return top1All, top5All, top1Evolve, top5Evolve (tensors of length nBlocks)
function Trainer:test(epoch, dataloader, prefix)
   -- Computes the top-1 and top-5 err on the validation/test set
   prefix = prefix or 'Test'
   local timer = torch.Timer()
   local dataTimer = torch.Timer()
   local size = dataloader:size()

   local nCrops = self.opt.tenCrop and 10 or 1
   local N = 0
   local nBlocks = self.opt.nBlocks

   local top1All, top5All = torch.zeros(nBlocks), torch.zeros(nBlocks)
   local top1Evolve, top5Evolve = torch.zeros(nBlocks), torch.zeros(nBlocks)

   -- one softmax module reused for every exit; its result is consumed by
   -- ensemble:add before the next forward call
   local softmax = nn.SoftMax()

   self.model:evaluate()
   for n, sample in dataloader:run() do
      local dataTime = dataTimer:time().real

      -- Copy input and target to the GPU
      self:copyInputs(sample)

      local output = self.model:forward(self.input)
      local batchSize = output[1]:size(1) / nCrops

      if nCrops == 1 then
         -- create a table which contains `nBlocks' same targets
         local multi_targets = {}
         for i = 1, nBlocks do
            multi_targets[i] = self.target
         end
         -- the returned loss is not used; the call is kept so the criterion
         -- state is updated exactly as before
         self.criterion:forward(self.model.output, multi_targets)
      end

      -- store the cumulative softmax() of exits to obtain the ensemble performance
      local ensemble = torch.Tensor():resizeAs(output[1]:float()):zero()
      local top1, top5 = 0, 0
      for i = 1, nBlocks do
         -- single exit
         top1, top5 = self:computeScore(output[i]:float(), sample.target, nCrops)
         top1All[i] = top1All[i] + top1*batchSize
         top5All[i] = top5All[i] + top5*batchSize
         -- ensemble
         ensemble:add(softmax:forward(output[i]:float()))
         top1, top5 = self:computeScore(ensemble, sample.target, nCrops)
         top1Evolve[i] = top1Evolve[i] + top1*batchSize
         top5Evolve[i] = top5Evolve[i] + top5*batchSize
      end
      ----------------------------------------------------
      N = N + batchSize

      -- top1/top5 here are the ensemble scores of the last exit for this batch
      print((' | %s: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (cumul: %7.3f) top5 %7.3f (cumul: %7.3f)'):format(
         prefix, epoch, n, size, timer:time().real, dataTime, top1, top1All[nBlocks]/N,
         top5, top5All[nBlocks]/N))

      timer:reset()
      dataTimer:reset()
   end
   self.model:training()

   for i = 1, nBlocks do
      top1All[i] = top1All[i] / N
      top5All[i] = top5All[i] / N
      top1Evolve[i] = top1Evolve[i] / N
      top5Evolve[i] = top5Evolve[i] / N
      print((' * %s %d exit top1: %7.3f top5: %7.3f, \t Ensemble %d exit(s) top1: %7.3f top5: %7.3f'):format(
         prefix, i, top1All[i], top5All[i], i, top1Evolve[i], top5Evolve[i]))
   end

   return top1All, top5All, top1Evolve, top5Evolve
end
172 |
--- Compute top-1 and top-5 error rates (in percent) for a batch of scores.
-- Fix: the number of candidates requested from topk is clamped to the number
-- of classes. The old code always asked for 5, so it failed on outputs with
-- fewer than 5 columns even though the following lines already clamped.
-- @param output  2-D score tensor; with nCrops > 1 its rows are grouped in
--                runs of nCrops and summed per group
-- @param target  class index per sample
-- @param nCrops  number of crops per sample
-- @return top1 error * 100, top5 error * 100
function Trainer:computeScore(output, target, nCrops)
   if nCrops > 1 then
      -- Sum over crops
      output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
         :sum(2):squeeze(2)
   end

   -- Computes the top1 and top5 error rate
   local batchSize = output:size(1)
   local k = math.min(5, output:size(2))

   local _, predictions = output:float():topk(k, 2, true, true) -- descending sort

   -- Find which predictions match the target
   local correct = predictions:eq(
      target:long():view(batchSize, 1):expandAs(predictions))

   -- Top-1 score
   local top1 = 1.0 - (correct:narrow(2, 1, 1):sum() / batchSize)

   -- Top-5 score (top-k when there are fewer than 5 classes)
   local top5 = 1.0 - (correct:narrow(2, 1, k):sum() / batchSize)

   return top1 * 100, top5 * 100
end
199 |
--- Create a CUDA host tensor matching the requested tensor type name.
-- Half and double names map to their dedicated cutorch constructors; every
-- other value falls back to cutorch.createCudaHostTensor.
-- @param tensorType tensor type name such as 'torch.CudaHalfTensor'
-- @return the newly created host tensor
local function getCudaTensorType(tensorType)
   if tensorType == 'torch.CudaHalfTensor' then
      return cutorch.createCudaHostHalfTensor()
   end
   if tensorType == 'torch.CudaDoubleTensor' then
      return cutorch.createCudaHostDoubleTensor()
   end
   return cutorch.createCudaHostTensor()
end
209 |
--- Copy a sample's input and target into the trainer's reusable buffers.
-- The input buffer is a tensor of opt.tensorType when a single GPU is used,
-- otherwise a CUDA host tensor from getCudaTensorType. The target buffer is
-- a torch.CudaLongTensor. Both are created on first use and resized on
-- every call.
-- @param sample table with `input` and `target` tensors
function Trainer:copyInputs(sample)
   if not self.input then
      local tensorType = self.opt.tensorType
      if self.opt.nGPU == 1 then
         self.input = torch[tensorType:match('torch.(%a+)')]()
      else
         self.input = getCudaTensorType(tensorType)
      end
   end

   if not self.target then
      self.target = torch.CudaLongTensor and torch.CudaLongTensor()
   end

   local src, labels = sample.input, sample.target
   self.input:resize(src:size()):copy(src)
   self.target:resize(labels:size()):copy(labels)
end
220 |
--- Learning rate for the given epoch under the step-decay schedule.
-- imagenet: the rate is divided by 10 every 30 epochs.
-- cifar10 / cifar100: divided by 10 once 50% of opt.nEpochs is reached and
-- by 100 once 75% is reached.
-- Any other dataset keeps the base rate.
-- @param epoch current epoch number
-- @return opt.LR scaled by 0.1 ^ (number of decay steps)
function Trainer:learningRate(epoch)
   local dataset = self.opt.dataset
   local steps = 0
   if dataset == 'imagenet' then
      steps = math.floor((epoch - 1) / 30)
   elseif dataset == 'cifar10' or dataset == 'cifar100' then
      local total = self.opt.nEpochs
      if epoch >= 0.75 * total then
         steps = 2
      elseif epoch >= 0.5 * total then
         steps = 1
      end
   end
   return self.opt.LR * math.pow(0.1, steps)
end
233 |
234 |
235 | return M.Trainer
--------------------------------------------------------------------------------