├── CMakeLists.txt ├── LICENCE ├── README.md ├── countUsedMemory.lua ├── doc ├── googlenet.gif └── googlenet_optimized.gif ├── env.lua ├── example.lua ├── graphgen.lua ├── init.lua ├── models.lua ├── rocks └── optnet-scm-1.rockspec ├── tests.lua └── utils.lua /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 2 | CMAKE_POLICY(VERSION 2.6) 3 | FIND_PACKAGE(Torch REQUIRED) 4 | 5 | FILE(GLOB luasrc *.lua) 6 | 7 | ADD_TORCH_PACKAGE(optnet "" "${luasrc}" "Memory optimizations for nn") 8 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Francisco Massa 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OptNet - reducing memory usage in torch neural networks 2 | 3 | Memory optimizations for torch neural networks. 4 | 5 | Heavily inspired by the `Optimizer` from https://github.com/facebook/fb-caffe-exts 6 | 7 | ## Installing 8 | Simply do 9 | ``` 10 | luarocks install optnet 11 | ``` 12 | 13 | ## How does it work? 14 | 15 | It goes over the network and verifies which buffers can be reused. 16 | It supports both inference (evaluation) mode and training mode. 17 | 18 | ### Inference mode 19 | 20 | Here is a list of currently tested networks. 
Numbers are for the CPU version, with a batch size of 1 and the **double** type, in the format 21 | **(total memory used, memory used for the outputs, memory used for the internal buffers, memory used for the parameters and grad parameters)**: 22 | 23 | | Network | before optimization | after optimization | Relative savings | 24 | | ------- | :--------: | :-------: | :------: | 25 | |alexnet | (973MB, 6MB, 43MB, 924MB) | (472MB, 1.5MB, 9MB, 462MB) | (51%, 75%, 80%, 50%) | 26 | |vgg-A | (2311MB, 69MB, 215MB, 2027MB) | (1106MB, 31MB, 61MB, 1014MB) | (52%, 55%, 72%, 50%) | 27 | |googlenet | (505MB, 69MB, 145MB, 292MB) | (193MB, 31MB, 16MB, 146MB) | (62%, 54%, 89%, 50%) | 28 | |resnet 110 (cifar)| (113MB, 16MB, 71MB, 26MB) | (15MB, 0.5MB, 1.3MB, 13MB) | (87%, 97%, 98%, 50%) | 29 | 30 | Note that for most of the models, with a batch size of 1 most of the memory is spent on the `weights` and `gradWeights` of the network (and the latter can be safely freed during inference). 31 | More interestingly, the output size is *linearly* dependent on the batch size, which means that the total savings are much more significant for bigger batch sizes. 32 | 33 | In a more realistic setup where we use `cudnn` and a batch size of 128, the gains are 34 | much more significant, especially for very deep networks like resnet. The memory usage is shown in the following table (for the **float** type), following the format **(total memory used, memory used for the outputs, memory used for the parameters and grad parameters)**, as `cudnn` uses almost no internal buffers: 35 | 36 | | Network | before optimization | after optimization | Relative savings | 37 | | ------- | :--------: | :-------: | :------: | 38 | |alexnet | (859MB, 397MB, 462MB) | (328MB, 97MB, 231MB) | (62%, 75%, 50%) | 39 | |vgg-A | (5340MB, 4386MB, 1014MB) | (2467MB, 1960MB, 507MB) | (54%, 55%, 50%) | 40 | |googlenet | (4536MB, 4390MB, 146MB) | (2066MB, 1993MB, 73MB) | (54%, 55%, 50%) | 41 | |resnet 110 (cifar)| (1049MB, 1036MB, 13MB) | (39MB, 32MB, 7MB) | (96%, 97%, 50%) | 42 | 43 | ### Training mode 44 | 45 | We currently support a basic algorithm for training mode. 46 | Using `cudnn` with a batch size of 64, we currently obtain the following savings, in the format **(total memory used, memory used for the outputs, memory used for the gradInputs, memory used for the parameters and gradParameters)**: 47 | 48 | | Network | before optimization | after optimization | Relative savings | 49 | | ------- | :--------: | :-------: | :------: | 50 | |alexnet | (963MB, 195MB, 303MB, 462MB) | (816MB, 195MB, 156MB, 462MB) | (15%, 0%, 48%, 0%) | 51 | |vgg-A | (5433MB, 2191MB, 2228MB, 1014MB) | (4228MB, 2191MB, 1023MB, 1014MB) | (22%, 0%, 54%, 0%) | 52 | |googlenet | (6092MB, 2195MB, 3346MB, 146MB) | (4844MB, 2195MB, 2098MB, 146MB) | (20%, 0%, 37%, 0%) | 53 | |resnet 110 (cifar)| (664MB, 259MB, 392MB, 13MB) | (428MB, 259MB, 156MB, 13MB) | (36%, 0%, 60%, 0%) | 54 | 55 | Note that the relative savings for the `gradInput`s stay constant across different batch sizes, meaning that the total relative savings will be more significant for bigger batch sizes (as the parameters don't depend on the batch size). 
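To see how the savings scale with the batch size on your own machine, one rough check is to compare `optnet.countUsedMemory` (described in the last section of this README) before and after the optimization with a larger input batch. The snippet below is only an illustrative sketch: it assumes the bundled googlenet test model from `optnet.models` (which takes 3x224x224 crops) and picks an arbitrary batch size of 8.

```lua
optnet = require 'optnet'
models = require 'optnet.models'

net, input = models['googlenet']()

-- replace the batch-size-1 sample input with a bigger batch
-- (the batch size of 8 is arbitrary, for illustration only)
input = torch.rand(8, 3, 224, 224)

-- the network needs to be initialized with all its buffers
-- before counting, so run a forward pass first
net:forward(input)
mem1 = optnet.countUsedMemory(net)

optnet.optimizeMemory(net, input)

net:forward(input)
mem2 = optnet.countUsedMemory(net)

print('Outputs before optimization: ' .. mem1.outputs/1024/1024 .. ' MBytes')
print('Outputs after optimization : ' .. mem2.outputs/1024/1024 .. ' MBytes')
```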
56 | 57 | We can set up the optimizations for training mode by passing `mode='training'`, as follows: 58 | 59 | ```lua 60 | models = require 'optnet.models' 61 | modelname = 'googlenet' 62 | net, input = models[modelname]() 63 | 64 | opts = {inplace=true, mode='training'} 65 | 66 | optnet = require 'optnet' 67 | 68 | optnet.optimizeMemory(net, input, opts) 69 | ``` 70 | 71 | ### Optional parameters 72 | 73 | Here is a list of the options that are currently supported; they should be passed in the `opts` table as a third argument: 74 | * `inplace`: uses in-place modules when available (boolean) 75 | * `mode`: selects between the `training` and `inference` optimization algorithms (string) 76 | * `reuseBuffers`: shares internal buffers between modules of the same type (like unfolded images for convolution) (boolean) 77 | * `removeGradParams`: removes `gradWeight` and `gradBias` from the network, saving their sharing information so that they can be exactly reconstructed. Only applies to `inference` mode (boolean) 78 | 79 | ## Visualizing the memory reuse 80 | 81 | We can analyse the sharing of the internal buffers by looking at the computation 82 | graph of the network before and after the sharing. 83 | 84 | For that, we have the `graphgen(net, input, opts)` function, which creates the 85 | graph corresponding to the network `net`. The generated graph contains the storage 86 | id of each `output`, and outputs sharing the same storage share the same color. 87 | 88 | Note that `net` is an `nn` model, and **not** a `nngraph` network. This allows us 89 | to use `optnet.graphgen` to generate graph visualizations of `nn` networks without 90 | having to use `nngraph`. 91 | 92 | Let's have a look: 93 | 94 | ```lua 95 | -- some handy models are defined in optnet.models 96 | -- like alexnet, googlenet, vgg and resnet 97 | models = require 'optnet.models' 98 | modelname = 'googlenet' 99 | net, input = models[modelname]() 100 | 101 | generateGraph = require 'optnet.graphgen' 102 | 103 | -- visual properties of the generated graph 104 | -- follows graphviz attributes 105 | graphOpts = { 106 | displayProps = {shape='ellipse',fontsize=14, style='solid'}, 107 | nodeData = function(oldData, tensor) 108 | return oldData .. '\n' .. 'Size: '.. tensor:numel() 109 | end 110 | } 111 | 112 | g = generateGraph(net, input, graphOpts) 113 | 114 | graph.dot(g,modelname,modelname) 115 | ``` 116 | 117 | This generates the following graph: 118 | 119 | ![GoogleNet without memory optimization](doc/googlenet.gif) 120 | 121 | Now what happens after we optimize the network? Check the colors and the storage 122 | ids. 123 | 124 | ```lua 125 | models = require 'optnet.models' 126 | modelname = 'googlenet' 127 | net, input = models[modelname]() 128 | 129 | opts = {inplace=true, reuseBuffers=true} 130 | 131 | generateGraph = require 'optnet.graphgen' 132 | 133 | optnet = require 'optnet' 134 | 135 | optnet.optimizeMemory(net, input, opts) 136 | 137 | graphOpts = { 138 | displayProps = {shape='ellipse',fontsize=14, style='solid'}, 139 | nodeData = function(oldData, tensor) 140 | return oldData .. '\n' .. 'Size: '.. tensor:numel() 141 | end 142 | } 143 | 144 | g = generateGraph(net, input, graphOpts) 145 | 146 | graph.dot(g,modelname..'_optimized',modelname..'_optimized') 147 | ``` 148 | ![GoogleNet with memory optimization](doc/googlenet_optimized.gif) 149 | 150 | ## Counting the amount of saved memory 151 | 152 | We also provide a function that computes the amount of memory used by the network 153 | in bytes, which allows us to check the amount of saved memory. 
154 | It decomposes the total amount of memory in four fields: 155 | * total memory used by the outputs of each module 156 | * total memory used by the gradInputs of each module 157 | * total memory used by the internal buffers of each module 158 | * total memory used by the weights and gradWeights of each module. 159 | 160 | Here is an example 161 | 162 | ```lua 163 | optnet = require 'optnet' 164 | 165 | models = require 'optnet.models' 166 | modelname = 'googlenet' 167 | net, input = models[modelname]() 168 | 169 | -- countUsedMemory needs the network to 170 | -- be initialized with all its buffers 171 | -- to output correct results 172 | net:forward(input) 173 | mem1 = optnet.countUsedMemory(net) 174 | 175 | optnet.optimizeMemory(net, input) 176 | 177 | net:forward(input) 178 | mem2 = optnet.countUsedMemory(net) 179 | 180 | optnet.removeOptimization(net) 181 | 182 | net:forward(input) 183 | mem3 = optnet.countUsedMemory(net) 184 | 185 | print('Before optimization : '.. mem1.total_size/1024/1024 .. ' MBytes') 186 | print('After optimization : '.. mem2.total_size/1024/1024 .. ' MBytes') 187 | print('After removing optimization: '.. mem3.total_size/1024/1024 .. ' MBytes') 188 | 189 | ``` 190 | -------------------------------------------------------------------------------- /countUsedMemory.lua: -------------------------------------------------------------------------------- 1 | local optnet = require 'optnet.env' 2 | local utils = require 'optnet.utils' 3 | local keepTrack = utils.keepTrack 4 | 5 | function optnet.countUsedMemory(net) 6 | local tensors = {outputs={},buffers={},params={},gradInputs={}} 7 | local function entry_fun(t) 8 | return t 9 | end 10 | local function count_func(self) 11 | keepTrack(self.output, tensors.outputs, entry_fun) 12 | keepTrack(self.gradInput, tensors.gradInputs, entry_fun) 13 | for k, v in pairs(self) do 14 | if torch.isTensor(v) and 15 | k ~= 'weight' and k ~= 'bias' and 16 | k ~= 'gradWeight' and k ~= 'gradBias' and 17 | k ~= 'output' and k ~= 'gradInput' then 18 | keepTrack(v, tensors.buffers, entry_fun) 19 | end 20 | end 21 | for _, k in ipairs({'weight', 'bias', 'gradWeight','gradBias'}) do 22 | if self[k] then 23 | keepTrack(self[k], tensors.params, entry_fun) 24 | end 25 | end 26 | end 27 | net:apply(count_func) 28 | local total_size = 0 29 | local sizes = {} 30 | for typeTensor, subTensors in pairs(tensors) do 31 | sizes[typeTensor] = 0 32 | for k,v in pairs(subTensors) do 33 | local size = v:storage():size()*v:elementSize() 34 | total_size = total_size + size 35 | sizes[typeTensor] = sizes[typeTensor] + size 36 | end 37 | end 38 | sizes.total_size = total_size 39 | return sizes 40 | end 41 | -------------------------------------------------------------------------------- /doc/googlenet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmassa/optimize-net/d380df067126f14ea3e1a0d35f21c322e957e804/doc/googlenet.gif -------------------------------------------------------------------------------- /doc/googlenet_optimized.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmassa/optimize-net/d380df067126f14ea3e1a0d35f21c322e957e804/doc/googlenet_optimized.gif -------------------------------------------------------------------------------- /env.lua: -------------------------------------------------------------------------------- 1 | local optnet = {} 2 | return optnet 3 | 
-------------------------------------------------------------------------------- /example.lua: -------------------------------------------------------------------------------- 1 | optnet = require 'optnet' 2 | generateGraph = require 'optnet.graphgen' 3 | models = require 'optnet.models' 4 | 5 | modelname = 'googlenet' 6 | net, input = models[modelname]() 7 | 8 | graphOpts = { 9 | displayProps = {shape='box',fontsize=10, style='solid'}, 10 | nodeData = function(oldData, tensor) 11 | return oldData .. '\n' .. 'Size: '.. tensor:numel() 12 | end 13 | } 14 | 15 | g = generateGraph(net, input, graphOpts) 16 | graph.dot(g, modelname, modelname) 17 | 18 | optnet.optimizeMemory(net, input) 19 | 20 | g = generateGraph(net, input, graphOpts) 21 | graph.dot(g, modelname..'_optimized', modelname..'_optimized') 22 | -------------------------------------------------------------------------------- /graphgen.lua: -------------------------------------------------------------------------------- 1 | require 'graph' 2 | local utils = require 'optnet.utils' 3 | 4 | -- taken from http://www.graphviz.org/doc/info/colors.html 5 | local colorNames = { 6 | "aliceblue","antiquewhite","antiquewhite1","antiquewhite2","antiquewhite3", 7 | "antiquewhite4","aquamarine","aquamarine1","aquamarine2","aquamarine3", 8 | "aquamarine4","azure","azure1","azure2","azure3", 9 | "azure4","beige","bisque","bisque1","bisque2", 10 | "bisque3","bisque4","black","blanchedalmond","blue", 11 | "blue1","blue2","blue3","blue4","blueviolet", 12 | "brown","brown1","brown2","brown3","brown4", 13 | "burlywood","burlywood1","burlywood2","burlywood3","burlywood4", 14 | "cadetblue","cadetblue1","cadetblue2","cadetblue3","cadetblue4", 15 | "chartreuse","chartreuse1","chartreuse2","chartreuse3","chartreuse4", 16 | "chocolate","chocolate1","chocolate2","chocolate3","chocolate4", 17 | "coral","coral1","coral2","coral3","coral4", 18 | "cornflowerblue","cornsilk","cornsilk1","cornsilk2","cornsilk3", 19 | "cornsilk4","crimson","cyan","cyan1","cyan2", 20 | "cyan3","cyan4","darkgoldenrod","darkgoldenrod1","darkgoldenrod2", 21 | "darkgoldenrod3","darkgoldenrod4","darkgreen","darkkhaki","darkolivegreen", 22 | "darkolivegreen1","darkolivegreen2","darkolivegreen3","darkolivegreen4","darkorange", 23 | "darkorange1","darkorange2","darkorange3","darkorange4","darkorchid", 24 | "darkorchid1","darkorchid2","darkorchid3","darkorchid4","darksalmon", 25 | "darkseagreen","darkseagreen1","darkseagreen2","darkseagreen3","darkseagreen4", 26 | "darkslateblue","darkslategray","darkslategray1","darkslategray2","darkslategray3", 27 | "darkslategray4","darkslategrey","darkturquoise","darkviolet","deeppink", 28 | "deeppink1","deeppink2","deeppink3","deeppink4","deepskyblue", 29 | "deepskyblue1","deepskyblue2","deepskyblue3","deepskyblue4","dimgray", 30 | "dimgrey","dodgerblue","dodgerblue1","dodgerblue2","dodgerblue3", 31 | "dodgerblue4","firebrick","firebrick1","firebrick2","firebrick3", 32 | "firebrick4","floralwhite","forestgreen","gainsboro","ghostwhite", 33 | "gold","gold1","gold2","gold3","gold4", 34 | "goldenrod","goldenrod1","goldenrod2","goldenrod3","goldenrod4" 35 | } 36 | 37 | -- some modules exist only for constructing 38 | -- the flow of information, and should not 39 | -- have their place in the computation graph 40 | -- as separate entities 41 | local function isSingleOperationModule(m) 42 | if m.modules then 43 | return false 44 | end 45 | local constructorModules = { 46 | 'nn.Identity', 47 | 'nn.SelectTable', 48 | 'nn.NarrowTable', 49 | 'nn.FlattenTable' 50 | } 
51 | local mType = torch.typename(m) 52 | for _, v in ipairs(constructorModules) do 53 | if mType == v then 54 | return false 55 | end 56 | end 57 | return true 58 | end 59 | 60 | local function isOperativeContainer(m) 61 | local mType = torch.typename(m) 62 | 63 | local opContainers = { 64 | 'nn.Concat', 65 | 'nn.Parallel', 66 | 'nn.DepthConcat' 67 | } 68 | for _, v in ipairs(opContainers) do 69 | if mType == v then 70 | return true 71 | end 72 | end 73 | 74 | -- those modules heritate from an 75 | -- operative container like nn.Concat 76 | local fakeContainers = { 77 | 'inn.SpatialPyramidPooling', 78 | } 79 | for _, v in ipairs(fakeContainers) do 80 | if mType == v then 81 | return true 82 | end 83 | end 84 | 85 | return false 86 | end 87 | 88 | -- generates a graph from a nn network 89 | -- Arguments: 90 | -- net: nn network 91 | -- input: input to the network 92 | -- opts: table with options for the graph generation. Options are 93 | -- nodeData: function that takes the string with storage id plus 94 | -- the tensor output from the module and outputs a 95 | -- string which will be displayed in the graph 96 | -- displayProps: display options from graphviz, like color, fontsize, 97 | -- style, etc 98 | -- addOutputNode: insert a dummy output node in the generated graph 99 | -- returns a graph representing the network 100 | local function generateGraph(net, input, opts) 101 | opts = opts or {} 102 | 103 | local storageHash = {} 104 | local nodes = {} 105 | 106 | local g = graph.Graph() 107 | 108 | -- basic function for creating an annotated nn.Node to suit our purposes 109 | -- gives the same color for the same storage. 110 | -- note that two colors being the same does not imply the same storage 111 | -- as we have a limited number of colors 112 | local function createNode(name, tensor) 113 | local data = torch.pointer(tensor:storage()) 114 | local storageId 115 | if not storageHash[data] then 116 | storageHash[data] = torch.random(1,#colorNames) 117 | table.insert(storageHash, data) 118 | end 119 | for k, v in ipairs(storageHash) do 120 | if v == data then 121 | storageId = k 122 | end 123 | end 124 | local nodeData = 'Storage id: '..storageId 125 | if opts.nodeData then 126 | nodeData = opts.nodeData(nodeData, tensor) 127 | end 128 | local node = graph.Node(nodeData) 129 | function node:graphNodeName() 130 | return name 131 | end 132 | function node:graphNodeAttributes() 133 | local prop = { 134 | color=colorNames[storageHash[data]], 135 | style = 'filled', 136 | shape = 'box', 137 | fontsize = 10, 138 | } 139 | if opts.displayProps then 140 | for k, v in pairs(opts.displayProps) do 141 | prop[k] = v 142 | end 143 | end 144 | return prop 145 | end 146 | return node 147 | end 148 | 149 | -- generate input/output nodes 150 | local function createBoundaryNode(input, name) 151 | if torch.isTensor(input) then 152 | local ptr = torch.pointer(input) 153 | nodes[ptr] = createNode(name,input) 154 | else 155 | for k,v in ipairs(input) do 156 | createBoundaryNode(v, name..' 
'..k) 157 | end 158 | end 159 | end 160 | 161 | -- create edge "from" -> "to", creating "to" on the way with "name" 162 | -- the edges can be seen as linking modules, but in fact it links the output 163 | -- tensor of each module 164 | local function addEdge(from, to, name) 165 | if torch.isTensor(to) and torch.isTensor(from) then 166 | local fromPtr = torch.pointer(from) 167 | local toPtr = torch.pointer(to) 168 | 169 | nodes[toPtr] = nodes[toPtr] or createNode(name,to) 170 | 171 | assert(nodes[fromPtr], 'Parent node inexistant for module '.. name) 172 | 173 | -- insert edge 174 | g:add(graph.Edge(nodes[fromPtr],nodes[toPtr])) 175 | elseif torch.isTensor(from) then 176 | for k,v in ipairs(to) do 177 | addEdge(from, v, name) 178 | end 179 | else 180 | for k,v in ipairs(from) do 181 | addEdge(v, to, name) 182 | end 183 | end 184 | end 185 | 186 | -- go over the network keeping track of the input/output for each module 187 | -- we overwrite the updateOutput for that. 188 | local function apply_func(m) 189 | local basefunc = m.updateOutput 190 | m.updateOutput = function(self, input) 191 | if isSingleOperationModule(m) then 192 | local name = tostring(m) 193 | if m.inplace then -- handle it differently ? 194 | addEdge(input,self.output,name) 195 | else 196 | addEdge(input,self.output,name) 197 | end 198 | elseif isOperativeContainer(m) then 199 | -- those containers effectively do some computation, so they have their 200 | -- place in the graph 201 | for i,branch in ipairs(m.modules) do 202 | local last_module 203 | if branch.modules then 204 | last_module = branch:get(branch:size()) 205 | else 206 | last_module = branch 207 | end 208 | local out = last_module.output 209 | local ptr = torch.pointer(out) 210 | 211 | local name = torch.typename(last_module) 212 | nodes[ptr] = nodes[ptr] or createNode(name,out) 213 | addEdge(out, self.output, torch.typename(m)) 214 | end 215 | end 216 | return basefunc(self, input) 217 | end 218 | end 219 | 220 | createBoundaryNode(input, 'Input') 221 | 222 | if torch.typename(net) == 'nn.DataParallelTable' then 223 | net = net.modules[1] 224 | end 225 | 226 | -- fill the states from each tensor 227 | net:forward(input) 228 | 229 | -- overwriting the standard functions to generate our graph 230 | net:apply(apply_func) 231 | -- generate the graph 232 | net:forward(input) 233 | 234 | if opts.addOutputNode then 235 | -- add dummy output node and link the last module to it 236 | local output = utils.recursiveClone(net.output) 237 | createBoundaryNode(output, 'Output') 238 | local function addOutputEdge(lastModule, output) 239 | if torch.isTensor(lastModule) then 240 | local fromPtr = torch.pointer(lastModule) 241 | local toPtr = torch.pointer(output) 242 | -- insert edge 243 | g:add(graph.Edge(nodes[fromPtr],nodes[toPtr])) 244 | 245 | else 246 | for k,v in ipairs(lastModule) do 247 | addOutputEdge(v, output[k]) 248 | end 249 | end 250 | end 251 | addOutputEdge(net.output, output) 252 | end 253 | 254 | -- clean up the modified function 255 | net:apply(function(x) 256 | x.updateOutput = nil 257 | end) 258 | 259 | return g 260 | end 261 | 262 | return generateGraph 263 | 264 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local optnet = require 'optnet.env' 4 | require 'optnet.countUsedMemory' 5 | require 'optnet.tests' 6 | 7 | local utils = require 'optnet.utils' 8 | 9 | local kNotUsed = 10000---1 10 | local 
kNotDefined = 0 11 | local kMinimumForSharing = 2 12 | local kAlwaysLive = 10000 13 | 14 | local function analyse(net, input, opts) 15 | opts = opts or {} 16 | local mode = opts.mode or 'inference' 17 | 18 | local track = {} 19 | local analysis = {} 20 | 21 | local function entry_fun(t, args) 22 | local ptr = torch.pointer(t:storage()) 23 | local info = {used=kNotUsed, defined=kNotDefined, 24 | name=args.name, ptr=ptr, tensor=t} 25 | table.insert(analysis, info) 26 | return info 27 | end 28 | 29 | local function fun(t, track, args) 30 | local ptr = torch.pointer(t:storage()) 31 | local val = track[ptr][args.var] 32 | if val == args.notUsed then 33 | track[ptr][args.var] = args.c 34 | else 35 | track[ptr][args.var] = args.f(args.c,val) 36 | end 37 | end 38 | 39 | local c = 1 40 | local function apply_func(m) 41 | local func = 'updateOutput' 42 | local basefunc = m[func] 43 | m[func] = function(self, input) 44 | local opts = { 45 | analysis=analysis, c=c, name=tostring(m), 46 | kNotUsed=kNotUsed, kNotDefined=kNotDefined 47 | } 48 | if mode == 'inference' then 49 | -- always keep track of the input 50 | opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed 51 | utils.keepTrack(input, track, entry_fun, fun, opts) 52 | 53 | if not m.modules then 54 | -- always keep track of the outputs of non-containers 55 | opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined 56 | utils.keepTrack(self.output, track, entry_fun, fun, opts) 57 | elseif torch.typename(m) == 'nn.Concat' or 58 | torch.typename(m) == 'nn.Parallel' or 59 | torch.typename(m) == 'nn.DepthConcat' then 60 | 61 | -- for containers that do some operations on the input, need to keep 62 | -- track of each output of its branches uppon entry on the module, 63 | -- as well as to keep track of it's own output (as it's a non-trivial 64 | -- operation on the childs output, contrary to nn.Sequential for 65 | -- example) 66 | opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined 67 | utils.keepTrack(self.output, track, entry_fun, fun, opts) 68 | 69 | for i,branch in ipairs(m.modules) do 70 | local last_module 71 | -- if brach is a container, get its last element, if not, take it 72 | if branch.modules then 73 | last_module = branch:get(branch:size()) 74 | else 75 | last_module = branch 76 | end 77 | local out = last_module.output 78 | opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined 79 | utils.keepTrack(out, track, entry_fun, fun, opts) 80 | end 81 | end 82 | end 83 | c = c + 1 84 | return basefunc(self,input) 85 | end 86 | 87 | for _, func in ipairs({'updateGradInput','accGradParameters','backward'}) do 88 | local basefunc = m[func] 89 | m[func] = function(self, input, gradOutput, scale) 90 | local opts = { 91 | analysis=analysis, c=c, name=tostring(m), 92 | kNotUsed=kNotUsed, kNotDefined=kNotDefined 93 | } 94 | 95 | -- always keep track of the input 96 | --opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed 97 | --utils.keepTrack(input, track, entry_fun, fun, opts) 98 | if not torch.isTypeOf(m, 'nn.Sequential') then 99 | -- always keep track of the gradOutput 100 | opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed 101 | utils.keepTrack(gradOutput, track, entry_fun, fun, opts) 102 | 103 | opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined 104 | utils.keepTrack(self.gradInput, track, entry_fun, fun, opts) 105 | end 106 | 107 | --[[ 108 | if not m.modules then 109 | -- always keep track of the gradInput of non-containers 110 | opts.var = 'defined'; opts.f = 
math.min; opts.notUsed = kNotDefined 111 | utils.keepTrack(self.gradInput, track, entry_fun, fun, opts) 112 | elseif torch.typename(m) == 'nn.Concat' or 113 | torch.typename(m) == 'nn.Parallel' or 114 | torch.typename(m) == 'nn.DepthConcat' then 115 | 116 | -- for containers that do some operations on the gradOutput, need to keep 117 | -- track of each gradInput of its branches uppon entry on the module, 118 | -- as well as to keep track of it's own gradInput (as it's a non-trivial 119 | -- operation on the childs output, contrary to nn.Sequential for 120 | -- example) 121 | opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined 122 | utils.keepTrack(self.gradInput, track, entry_fun, fun, opts) 123 | 124 | for i,branch in ipairs(m.modules) do 125 | local first_module = branch:get(1) 126 | local out = first_module.gradInput 127 | opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined 128 | utils.keepTrack(out, track, entry_fun, fun, opts) 129 | end 130 | end 131 | --]] 132 | c = c + 1 133 | return basefunc(self,input,gradOutput,scale) 134 | end 135 | 136 | end 137 | 138 | end 139 | net:apply(apply_func) 140 | local out = net['forward'](net, input) 141 | local grad 142 | if mode == 'training' then 143 | grad = utils.recursiveClone(out) 144 | net['backward'](net, input, grad) 145 | end 146 | local function trackInputs(t, name) 147 | if torch.isTensor(t) then 148 | local f = function(a,b) return a end 149 | utils.keepTrack(t, track, entry_fun, fun, 150 | {var='used', c=kAlwaysLive, 151 | f=f, notUsed=0, name=name}) 152 | utils.keepTrack(t, track, entry_fun, fun, 153 | {var='defined', c=-kAlwaysLive, 154 | f=f, notUsed=0, name=name}) 155 | else 156 | for k,v in ipairs(t) do 157 | trackInputs(v) 158 | end 159 | end 160 | end 161 | trackInputs(input,'input') 162 | if grad then 163 | trackInputs(grad,'grad') 164 | end 165 | -- clean up the modified function 166 | net:apply(function(x) 167 | x['updateOutput'] = nil 168 | x['updateGradInput'] = nil 169 | x['accGradParameters'] = nil 170 | x['backward'] = nil 171 | end) 172 | 173 | -- disable backward pass if in evaluation mode 174 | if mode == 'inference' then 175 | net:apply(function(m) 176 | m.updateGradInput = function(self, input, gradInput) 177 | error([[Backward pass disabled! 178 | You are using inference optimization. 
179 | Call optnet.removeOptimization(net) to enable backward again]]) 180 | end 181 | end) 182 | end 183 | return analysis 184 | end 185 | 186 | local function isCompatible(candidate, assignment) 187 | if candidate.used == kNotUsed then 188 | return false 189 | end 190 | if candidate.tensor:numel() < kMinimumForSharing then 191 | return false 192 | end 193 | local a_used = assignment[#assignment].used 194 | return candidate.defined > a_used 195 | end 196 | 197 | local function assign(net, analysis) 198 | table.sort(analysis, function(a,b) 199 | local x = a.used 200 | local y = b.used 201 | return x < y 202 | end) 203 | local assignments = {} 204 | for _,candidate in ipairs(analysis) do 205 | local assigned = false 206 | local bestAssignment = 0 207 | local minDist = math.huge 208 | local candidateSize = candidate.tensor:numel() 209 | for idx, assignment in ipairs(assignments) do 210 | if isCompatible(candidate, assignment) then 211 | assigned = true 212 | local dist = math.abs(assignment.maxSize-candidateSize) 213 | if dist < minDist then 214 | minDist = dist 215 | bestAssignment = idx 216 | end 217 | end 218 | end 219 | if assigned then 220 | local assignment = assignments[bestAssignment] 221 | table.insert(assignment, candidate) 222 | assignment.maxSize = math.max(assignment.maxSize, candidateSize) 223 | else 224 | table.insert(assignments, {candidate, maxSize=candidateSize}) 225 | end 226 | end 227 | return assignments 228 | end 229 | 230 | local function applyAssignments(net, assignments) 231 | for _, assignment in ipairs(assignments) do 232 | local storage 233 | for k, v in ipairs(assignment) do 234 | if v.used == kAlwaysLive and v.defined == -kAlwaysLive then 235 | break 236 | end 237 | storage = storage or v.tensor.new(1):storage() 238 | v.tensor:set(storage) 239 | end 240 | end 241 | end 242 | 243 | local function defaultValue(var, val) 244 | if var == nil then 245 | var = val 246 | end 247 | return var 248 | end 249 | 250 | -- set to inplace modules that allows it 251 | local function setInplace(net, opts) 252 | local inplace = defaultValue(opts.inplace, true) 253 | 254 | if inplace then 255 | net:apply(function(m) 256 | if m.inplace ~= nil then 257 | -- inplace is not always supported for threshold, 258 | -- depending on the values. Disabling it for the moment 259 | if torch.typename(m) ~= 'nn.Threshold' then 260 | m.inplace = true 261 | end 262 | end 263 | end) 264 | end 265 | end 266 | 267 | local reusableBuffers = { 268 | ['nn.SpatialConvolution'] = {{'finput','fgradInput'},{'fgradInput'}}, 269 | ['nn.SpatialConvolutionMM'] = {{'finput','fgradInput'},{'fgradInput'}}, 270 | ['nn.Normalize'] = {{'norm','buffer','normp','_indices'},{}}, 271 | ['nn.SpatialCrossMapLRN'] = {{'scale'},{}}, 272 | ['nn.SpatialMaxPooling'] = {{'indices'},{}}, 273 | } 274 | -- basic reusing scheme: keeps a list of all possible buffers 275 | -- that can be reused in evaluation mode and also in training 276 | -- mode. 
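-- Each entry in reusableBuffers maps a module type to two lists of buffer
-- field names: the first list holds the buffers that can be shared between
-- modules of the same type in inference (evaluation) mode, and the second
-- the ones that are safe to share in training mode (picked via mode_idx below).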
277 | local function reuseStateBuffers(net, opts) 278 | local reuseBuffers = defaultValue(opts.reuseBuffers, true) 279 | local mode = defaultValue(opts.mode, 'inference') 280 | local mode_idx = 1 281 | if mode == 'training' then 282 | mode_idx = 2 283 | end 284 | -- workaround SpatialMaxUnpooling corner case 285 | -- https://github.com/fmassa/optimize-net/issues/14 286 | local reusableBuffers = utils.copyTable(reusableBuffers) 287 | if #net:findModules('nn.SpatialMaxUnpooling') > 0 then 288 | reusableBuffers['nn.SpatialMaxPooling'] = {{},{}} 289 | end 290 | if reuseBuffers then 291 | local reusedBuffers = {} 292 | net:apply(function(m) 293 | local name = torch.typename(m) 294 | if reusableBuffers[name] then 295 | local rb = reusableBuffers[name][mode_idx] 296 | for k, v in ipairs(rb) do 297 | if m[v] then 298 | reusedBuffers[name..','..v] = reusedBuffers[name..','..v] or m[v]:storage() 299 | if reusedBuffers[name..','..v] then 300 | m[v]:set(reusedBuffers[name..','..v]) 301 | end 302 | end 303 | end 304 | end 305 | end) 306 | end 307 | end 308 | 309 | -- needed for cudnn 310 | local function resetInputDescriptors(net) 311 | net:apply(function(m) 312 | if torch.typename(m):find('cudnn') and 313 | torch.typename(m.iSize) == 'torch.LongStorage' then 314 | m.iSize:fill(0) 315 | end 316 | end) 317 | end 318 | 319 | -- need to keep a list of shared gradParams 320 | -- to avoid problems when removing the optimization 321 | local function removeGradParams(net, opts) 322 | local removeGradParams = defaultValue(opts.removeGradParams, true) 323 | local mode = defaultValue(opts.mode, 'inference') 324 | if not removeGradParams then return end 325 | if mode == 'training' then return end 326 | local storages = {} 327 | net:apply(function(m) 328 | for _, k in ipairs({'gradWeight','gradBias'}) do 329 | if m[k] and m[k]:storage() then 330 | local strPtr = torch.pointer(m[k]:storage()) 331 | local strOffset = m[k]:storageOffset() 332 | local strSize = m[k]:storage():size() 333 | local tPtr = torch.pointer(m[k]) 334 | storages[tPtr] = {storage=strPtr, offSet=strOffset, 335 | size=strSize, stride=m[k]:stride()} 336 | m[k]:set() 337 | end 338 | end 339 | -- disabling getParameters 340 | m.getParameters = function(self) 341 | error('getParameters was disabled by optnet '.. 342 | '(by option removeGradParams=true). '.. 343 | 'Call optnet.removeOptimization(net) to enable it back.') 344 | end 345 | end) 346 | net.__gradParamsInfo = storages 347 | end 348 | 349 | local function addGradParams(net) 350 | local storages = net.__gradParamsInfo 351 | if not storages then return end 352 | local createdStorages = {} 353 | net:apply(function(m) 354 | for k, v in pairs({weight='gradWeight',bias='gradBias'}) do 355 | if m[v] then 356 | local tPtr = torch.pointer(m[v]) 357 | local info = storages[tPtr] 358 | if not createdStorages[info.storage] then 359 | local strSize = info.size 360 | createdStorages[info.storage] = m[v].new(strSize):storage() 361 | end 362 | local storage = createdStorages[info.storage] 363 | local tSize = m[k]:size() 364 | local tStride = info.stride 365 | local tOffset = info.offSet 366 | m[v]:set(storage, tOffset, tSize, tStride) 367 | end 368 | end 369 | -- add back original getParameters 370 | m.getParameters = nil 371 | end) 372 | net.__gradParamsInfo = nil 373 | end 374 | 375 | 376 | function optnet.optimizeMemory(net, input, opts) 377 | opts = opts or {} 378 | 379 | if net.__memoryOptimized then 380 | print('Skipping memory optimization. '.. 
381 | 'Network is already optimized for '..net.__memoryOptimized..' mode.') 382 | return 383 | end 384 | 385 | local mode = defaultValue(opts.mode,'inference') 386 | 387 | local out = net['forward'](net, input) 388 | local grad 389 | if mode == 'training' then 390 | grad = utils.recursiveClone(out) 391 | net['backward'](net, input, grad) 392 | end 393 | 394 | setInplace(net, opts) 395 | reuseStateBuffers(net, opts) 396 | removeGradParams(net, opts) 397 | 398 | -- share outputs 399 | local analysis = analyse(net, input, opts) 400 | --print(analysis) 401 | local assignments = assign(net,analysis) 402 | --print(assignments) 403 | applyAssignments(net, assignments) 404 | resetInputDescriptors(net) 405 | 406 | -- add flag to mention that it was optimized 407 | net.__memoryOptimized = mode 408 | end 409 | 410 | function optnet.removeOptimization(net) 411 | 412 | if not net.__memoryOptimized then 413 | print('Skipping memory optimization removal, as the network was not optimized.') 414 | return 415 | end 416 | 417 | local function rem(m) 418 | if torch.isTensor(m) then 419 | m:set() 420 | end 421 | if torch.type(m) == 'table' then 422 | for k, v in ipairs(m) do 423 | rem(v) 424 | end 425 | end 426 | end 427 | 428 | net:apply(function(m) 429 | rem(m.output) 430 | rem(m.gradInput) 431 | local name = torch.typename(m) 432 | if reusableBuffers[name] then 433 | local rb = reusableBuffers[name][1] 434 | for k, v in ipairs(rb) do 435 | if m[v] then 436 | m[v]:set() 437 | end 438 | end 439 | end 440 | -- remove backward blocking 441 | m.updateGradInput = nil 442 | end) 443 | resetInputDescriptors(net) 444 | addGradParams(net) 445 | 446 | net.__memoryOptimized = nil 447 | end 448 | 449 | return optnet 450 | 451 | -------------------------------------------------------------------------------- /models.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local models = {} 4 | models.basic_parallel = function() 5 | local m = nn.Sequential() 6 | local prl = nn.ParallelTable() 7 | prl:add(nn.Linear(2,2)) 8 | prl:add(nn.Sequential():add(nn.Linear(2,1)):add(nn.Sigmoid()):add(nn.Linear(1,1))) 9 | m:add(prl) 10 | m:add(nn.JoinTable(2)) 11 | m:add(nn.Linear(3,2)) 12 | m:add(nn.ReLU(true)) 13 | 14 | local input = {torch.rand(2,2), torch.rand(2,2)} 15 | return m, input 16 | end 17 | models.basic_conv = function() 18 | local m = nn.Sequential() 19 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 20 | -- m:add(nn.ReLU(true)) 21 | -- m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 22 | -- m:add(nn.ReLU(true)) 23 | m:add(nn.View(32*32):setNumInputDims(3)) 24 | m:add(nn.Linear(32*32,100)) 25 | -- m:add(nn.ReLU(true)) 26 | -- m:add(nn.Linear(100,10)) 27 | local input = torch.rand(1,1,32,32) 28 | return m, input 29 | end 30 | 31 | models.basic_deep_conv = function() 32 | local inplace = true 33 | local m = nn.Sequential() 34 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 35 | m:add(nn.ReLU(inplace)) 36 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 37 | m:add(nn.ReLU(inplace)) 38 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 39 | m:add(nn.ReLU(inplace)) 40 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 41 | m:add(nn.ReLU(inplace)) 42 | m:add(nn.View(32*32):setNumInputDims(3)) 43 | m:add(nn.Linear(32*32,100)) 44 | m:add(nn.ReLU(inplace)) 45 | m:add(nn.Linear(100,10)) 46 | local input = torch.rand(1,1,32,32) 47 | return m, input 48 | end 49 | 50 | models.basic_unpooling = function() 51 | local inplace = true 52 | local m = nn.Sequential() 53 | 
m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 54 | m:add(nn.ReLU(inplace)) 55 | local mp1 = nn.SpatialMaxPooling(2,2,2,2) 56 | m:add(mp1) 57 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 58 | m:add(nn.ReLU(inplace)) 59 | local mp2 = nn.SpatialMaxPooling(2,2,2,2) 60 | m:add(mp2) 61 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 62 | m:add(nn.ReLU(inplace)) 63 | m:add(nn.SpatialMaxUnpooling(mp2)) 64 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 65 | m:add(nn.ReLU(inplace)) 66 | m:add(nn.SpatialMaxUnpooling(mp1)) 67 | m:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 68 | local input = torch.rand(1,1,32,32) 69 | return m, input 70 | end 71 | 72 | models.siamese = function() 73 | local inplace = false 74 | local b1 = nn.Sequential() 75 | b1:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 76 | b1:add(nn.ReLU(inplace)) 77 | b1:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 78 | b1:add(nn.ReLU(inplace)) 79 | b1:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 80 | b1:add(nn.ReLU(inplace)) 81 | b1:add(nn.SpatialConvolution(1,1,3,3,1,1,1,1)) 82 | b1:add(nn.ReLU(inplace)) 83 | b1:add(nn.View(-1):setNumInputDims(3)) 84 | 85 | local b2 = b1:clone('weight','bias','gradWeight','gradBias') 86 | local prl = nn.ParallelTable() 87 | prl:add(b1) 88 | prl:add(b2) 89 | 90 | m = nn.Sequential() 91 | m:add(prl) 92 | m:add(nn.PairwiseDistance(2)) 93 | local input = {torch.rand(1,1,32,32), torch.rand(1,1,32,32)} 94 | return m, input 95 | end 96 | 97 | models.basic_concat = function() 98 | local m = nn.Sequential() 99 | local cat = nn.ConcatTable() 100 | local b1 = nn.Sequential():add(nn.Linear(2,1)):add(nn.ReLU(true)):add(nn.Linear(1,1)) 101 | local b2 = nn.Sequential():add(nn.Linear(2,2)):add(nn.ReLU()) 102 | local b3 = nn.Sequential():add(nn.Linear(2,3)):add(nn.ReLU(true)):add(nn.Linear(3,2)) 103 | cat:add(b1):add(b2):add(b3) 104 | local cat2 = nn.ConcatTable() 105 | local bb1 = nn.SelectTable(1) 106 | local bb2 = nn.NarrowTable(2,2) 107 | cat2:add(bb1):add(bb2) 108 | local prl = nn.ParallelTable() 109 | local bbb1 = nn.Sequential():add(nn.Linear(1,2)) 110 | local bbb2 = nn.CAddTable() 111 | prl:add(bbb1):add(bbb2) 112 | local final = nn.CMulTable() 113 | m:add(cat):add(cat2):add(prl):add(final) 114 | 115 | local input = torch.rand(2,2) 116 | return m, input 117 | 118 | end 119 | 120 | models.basic_multiOutput = function() 121 | local m = nn.Sequential() 122 | m:add(nn.Linear(2,2)) 123 | m:add(nn.ReLU()) 124 | m:add(nn.Linear(2,2)) 125 | m:add(nn.ReLU()) 126 | m:add(nn.Linear(2,2)) 127 | m:add(nn.ReLU()) 128 | local p = nn.ConcatTable() 129 | p:add(nn.Linear(2,2)) 130 | p:add(nn.Linear(2,2)) 131 | p:add(nn.Linear(2,2)) 132 | 133 | m:add(p) 134 | 135 | local input = torch.rand(2,2) 136 | return m, input 137 | end 138 | 139 | models.alexnet = function() 140 | -- taken from soumith's imagenet-multiGPU 141 | -- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/alexnet.lua 142 | local features = nn.Concat(2) 143 | local fb1 = nn.Sequential() -- branch 1 144 | fb1:add(nn.SpatialConvolution(3,48,11,11,4,4,2,2)) -- 224 -> 55 145 | fb1:add(nn.ReLU(true)) 146 | fb1:add(nn.SpatialMaxPooling(3,3,2,2)) -- 55 -> 27 147 | fb1:add(nn.SpatialConvolution(48,128,5,5,1,1,2,2)) -- 27 -> 27 148 | fb1:add(nn.ReLU(true)) 149 | fb1:add(nn.SpatialMaxPooling(3,3,2,2)) -- 27 -> 13 150 | fb1:add(nn.SpatialConvolution(128,192,3,3,1,1,1,1)) -- 13 -> 13 151 | fb1:add(nn.ReLU(true)) 152 | fb1:add(nn.SpatialConvolution(192,192,3,3,1,1,1,1)) -- 13 -> 13 153 | fb1:add(nn.ReLU(true)) 154 | 
fb1:add(nn.SpatialConvolution(192,128,3,3,1,1,1,1)) -- 13 -> 13 155 | fb1:add(nn.ReLU(true)) 156 | fb1:add(nn.SpatialMaxPooling(3,3,2,2)) -- 13 -> 6 157 | 158 | local fb2 = fb1:clone() -- branch 2 159 | for k,v in ipairs(fb2:findModules('nn.SpatialConvolution')) do 160 | v:reset() -- reset branch 2's weights 161 | end 162 | 163 | features:add(fb1) 164 | features:add(fb2) 165 | 166 | -- 1.3. Create Classifier (fully connected layers) 167 | local classifier = nn.Sequential() 168 | classifier:add(nn.View(256*6*6)) 169 | --classifier:add(nn.Dropout(0.5)) 170 | classifier:add(nn.Linear(256*6*6, 4096)) 171 | classifier:add(nn.Threshold(0, 1e-6)) 172 | --classifier:add(nn.Dropout(0.5)) 173 | classifier:add(nn.Linear(4096, 4096)) 174 | classifier:add(nn.Threshold(0, 1e-6)) 175 | classifier:add(nn.Linear(4096, 1000)) 176 | classifier:add(nn.LogSoftMax()) 177 | 178 | -- 1.4. Combine 1.1 and 1.3 to produce final model 179 | local model = nn.Sequential():add(features):add(classifier) 180 | model.imageSize = 256 181 | model.imageCrop = 224 182 | 183 | local input = torch.rand(1,3,model.imageCrop,model.imageCrop) 184 | 185 | return model, input 186 | end 187 | 188 | models.googlenet = function() 189 | -- taken from soumith's imagenet-multiGPU 190 | -- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/googlenet.lua 191 | local function inception(input_size, config) 192 | local concat = nn.Concat(2) 193 | if config[1][1] ~= 0 then 194 | local conv1 = nn.Sequential() 195 | conv1:add(nn.SpatialConvolution(input_size, config[1][1],1,1,1,1)):add(nn.ReLU(true)) 196 | concat:add(conv1) 197 | end 198 | 199 | local conv3 = nn.Sequential() 200 | conv3:add(nn.SpatialConvolution( input_size, config[2][1],1,1,1,1)):add(nn.ReLU(true)) 201 | conv3:add(nn.SpatialConvolution(config[2][1], config[2][2],3,3,1,1,1,1)):add(nn.ReLU(true)) 202 | concat:add(conv3) 203 | 204 | local conv3xx = nn.Sequential() 205 | conv3xx:add(nn.SpatialConvolution( input_size, config[3][1],1,1,1,1)):add(nn.ReLU(true)) 206 | conv3xx:add(nn.SpatialConvolution(config[3][1], config[3][2],3,3,1,1,1,1)):add(nn.ReLU(true)) 207 | conv3xx:add(nn.SpatialConvolution(config[3][2], config[3][2],3,3,1,1,1,1)):add(nn.ReLU(true)) 208 | concat:add(conv3xx) 209 | 210 | local pool = nn.Sequential() 211 | pool:add(nn.SpatialZeroPadding(1,1,1,1)) -- remove after getting nn R2 into fbcode 212 | if config[4][1] == 'max' then 213 | pool:add(nn.SpatialMaxPooling(3,3,1,1):ceil()) 214 | elseif config[4][1] == 'avg' then 215 | pool:add(nn.SpatialAveragePooling(3,3,1,1):ceil()) 216 | else 217 | error('Unknown pooling') 218 | end 219 | if config[4][2] ~= 0 then 220 | pool:add(nn.SpatialConvolution(input_size, config[4][2],1,1,1,1)):add(nn.ReLU(true)) 221 | end 222 | concat:add(pool) 223 | 224 | return concat 225 | end 226 | 227 | local nClasses = 1000 228 | 229 | local features = nn.Sequential() 230 | features:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3)):add(nn.ReLU(true)) 231 | features:add(nn.SpatialMaxPooling(3,3,2,2):ceil()) 232 | features:add(nn.SpatialConvolution(64,64,1,1)):add(nn.ReLU(true)) 233 | features:add(nn.SpatialConvolution(64,192,3,3,1,1,1,1)):add(nn.ReLU(true)) 234 | features:add(nn.SpatialMaxPooling(3,3,2,2):ceil()) 235 | features:add(inception( 192, {{ 64},{ 64, 64},{ 64, 96},{'avg', 32}})) -- 3(a) 236 | features:add(inception( 256, {{ 64},{ 64, 96},{ 64, 96},{'avg', 64}})) -- 3(b) 237 | features:add(inception( 320, {{ 0},{128,160},{ 64, 96},{'max', 0}})) -- 3(c) 238 | features:add(nn.SpatialConvolution(576,576,2,2,2,2)) 239 | 
features:add(inception( 576, {{224},{ 64, 96},{ 96,128},{'avg',128}})) -- 4(a) 240 | features:add(inception( 576, {{192},{ 96,128},{ 96,128},{'avg',128}})) -- 4(b) 241 | features:add(inception( 576, {{160},{128,160},{128,160},{'avg', 96}})) -- 4(c) 242 | features:add(inception( 576, {{ 96},{128,192},{160,192},{'avg', 96}})) -- 4(d) 243 | 244 | local main_branch = nn.Sequential() 245 | main_branch:add(inception( 576, {{ 0},{128,192},{192,256},{'max', 0}})) -- 4(e) 246 | main_branch:add(nn.SpatialConvolution(1024,1024,2,2,2,2)) 247 | main_branch:add(inception(1024, {{352},{192,320},{160,224},{'avg',128}})) -- 5(a) 248 | main_branch:add(inception(1024, {{352},{192,320},{192,224},{'max',128}})) -- 5(b) 249 | main_branch:add(nn.SpatialAveragePooling(7,7,1,1)) 250 | main_branch:add(nn.View(1024):setNumInputDims(3)) 251 | main_branch:add(nn.Linear(1024,nClasses)) 252 | main_branch:add(nn.LogSoftMax()) 253 | 254 | -- add auxillary classifier here (thanks to Christian Szegedy for the details) 255 | local aux_classifier = nn.Sequential() 256 | aux_classifier:add(nn.SpatialAveragePooling(5,5,3,3):ceil()) 257 | aux_classifier:add(nn.SpatialConvolution(576,128,1,1,1,1)) 258 | aux_classifier:add(nn.View(128*4*4):setNumInputDims(3)) 259 | aux_classifier:add(nn.Linear(128*4*4,768)) 260 | aux_classifier:add(nn.ReLU()) 261 | aux_classifier:add(nn.Linear(768,nClasses)) 262 | aux_classifier:add(nn.LogSoftMax()) 263 | 264 | local splitter = nn.Concat(2) 265 | splitter:add(main_branch):add(aux_classifier) 266 | local model = nn.Sequential():add(features):add(splitter) 267 | 268 | model.imageSize = 256 269 | model.imageCrop = 224 270 | 271 | local input = torch.rand(1,3,model.imageCrop,model.imageCrop) 272 | 273 | return model, input 274 | 275 | 276 | end 277 | 278 | models.vgg = function(modelType) 279 | -- taken from soumith's imagenet-multiGPU 280 | -- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/vgg.lua 281 | 282 | local nClasses = 1000 283 | 284 | local modelType = modelType or 'A' -- on a titan black, B/D/E run out of memory even for batch-size 32 285 | 286 | -- Create tables describing VGG configurations A, B, D, E 287 | local cfg = {} 288 | if modelType == 'A' then 289 | cfg = {64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'} 290 | elseif modelType == 'B' then 291 | cfg = {64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'} 292 | elseif modelType == 'D' then 293 | cfg = {64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'} 294 | elseif modelType == 'E' then 295 | cfg = {64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'} 296 | else 297 | error('Unknown model type: ' .. modelType .. 
' | Please specify a modelType A or B or D or E') 298 | end 299 | 300 | local features = nn.Sequential() 301 | do 302 | local iChannels = 3; 303 | for k,v in ipairs(cfg) do 304 | if v == 'M' then 305 | features:add(nn.SpatialMaxPooling(2,2,2,2)) 306 | else 307 | local oChannels = v; 308 | local conv3 = nn.SpatialConvolution(iChannels,oChannels,3,3,1,1,1,1); 309 | features:add(conv3) 310 | features:add(nn.ReLU(true)) 311 | iChannels = oChannels; 312 | end 313 | end 314 | end 315 | 316 | local classifier = nn.Sequential() 317 | classifier:add(nn.View(512*7*7)) 318 | classifier:add(nn.Linear(512*7*7, 4096)) 319 | classifier:add(nn.Threshold(0, 1e-6)) 320 | -- classifier:add(nn.Dropout(0.5)) 321 | classifier:add(nn.Linear(4096, 4096)) 322 | classifier:add(nn.Threshold(0, 1e-6)) 323 | -- classifier:add(nn.Dropout(0.5)) 324 | classifier:add(nn.Linear(4096, nClasses)) 325 | classifier:add(nn.LogSoftMax()) 326 | 327 | local model = nn.Sequential() 328 | model:add(features):add(classifier) 329 | model.imageSize = 256 330 | model.imageCrop = 224 331 | 332 | local input = torch.rand(1,3,model.imageCrop,model.imageCrop) 333 | 334 | return model, input 335 | end 336 | 337 | models.resnet = function(opt) 338 | -- taken from https://github.com/facebook/fb.resnet.torch 339 | local Convolution = nn.SpatialConvolution 340 | local Avg = nn.SpatialAveragePooling 341 | local ReLU = nn.ReLU 342 | local Max = nn.SpatialMaxPooling 343 | local SBatchNorm = function(n) 344 | local m = nn.MulConstant(1) 345 | m.inplace = nil 346 | return m 347 | end--nn.SpatialBatchNormalization 348 | 349 | local function createModel(opt) 350 | local depth = opt.depth 351 | local shortcutType = opt.shortcutType or 'B' 352 | local iChannels 353 | 354 | -- The shortcut layer is either identity or 1x1 convolution 355 | local function shortcut(nInputPlane, nOutputPlane, stride) 356 | local useConv = shortcutType == 'C' or 357 | (shortcutType == 'B' and nInputPlane ~= nOutputPlane) 358 | if useConv then 359 | -- 1x1 convolution 360 | return nn.Sequential() 361 | :add(Convolution(nInputPlane, nOutputPlane, 1, 1, stride, stride)) 362 | :add(SBatchNorm(nOutputPlane)) 363 | elseif nInputPlane ~= nOutputPlane then 364 | -- Strided, zero-padded identity shortcut 365 | return nn.Sequential() 366 | :add(nn.SpatialAveragePooling(1, 1, stride, stride)) 367 | :add(nn.Concat(2) 368 | :add(nn.Identity()) 369 | :add(nn.MulConstant(0))) 370 | else 371 | return nn.Identity() 372 | end 373 | end 374 | 375 | -- The basic residual layer block for 18 and 34 layer network, and the 376 | -- CIFAR networks 377 | local function basicblock(n, stride) 378 | local nInputPlane = iChannels 379 | iChannels = n 380 | 381 | local s = nn.Sequential() 382 | s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1)) 383 | s:add(SBatchNorm(n)) 384 | s:add(ReLU(true)) 385 | s:add(Convolution(n,n,3,3,1,1,1,1)) 386 | s:add(SBatchNorm(n)) 387 | 388 | return nn.Sequential() 389 | :add(nn.ConcatTable() 390 | :add(s) 391 | :add(shortcut(nInputPlane, n, stride))) 392 | :add(nn.CAddTable(true)) 393 | :add(ReLU(true)) 394 | end 395 | 396 | -- The bottleneck residual layer for 50, 101, and 152 layer networks 397 | local function bottleneck(n, stride) 398 | local nInputPlane = iChannels 399 | iChannels = n * 4 400 | 401 | local s = nn.Sequential() 402 | s:add(Convolution(nInputPlane,n,1,1,1,1,0,0)) 403 | s:add(SBatchNorm(n)) 404 | s:add(ReLU(true)) 405 | s:add(Convolution(n,n,3,3,stride,stride,1,1)) 406 | s:add(SBatchNorm(n)) 407 | s:add(ReLU(true)) 408 | 
s:add(Convolution(n,n*4,1,1,1,1,0,0)) 409 | s:add(SBatchNorm(n * 4)) 410 | 411 | return nn.Sequential() 412 | :add(nn.ConcatTable() 413 | :add(s) 414 | :add(shortcut(nInputPlane, n * 4, stride))) 415 | :add(nn.CAddTable(true)) 416 | :add(ReLU(true)) 417 | end 418 | 419 | -- Creates count residual blocks with specified number of features 420 | local function layer(block, features, count, stride) 421 | local s = nn.Sequential() 422 | for i=1,count do 423 | s:add(block(features, i == 1 and stride or 1)) 424 | end 425 | return s 426 | end 427 | 428 | local model = nn.Sequential() 429 | local input 430 | if opt.dataset == 'imagenet' then 431 | -- Configurations for ResNet: 432 | -- num. residual blocks, num features, residual block function 433 | local cfg = { 434 | [18] = {{2, 2, 2, 2}, 512, basicblock}, 435 | [34] = {{3, 4, 6, 3}, 512, basicblock}, 436 | [50] = {{3, 4, 6, 3}, 2048, bottleneck}, 437 | [101] = {{3, 4, 23, 3}, 2048, bottleneck}, 438 | [152] = {{3, 8, 36, 3}, 2048, bottleneck}, 439 | } 440 | 441 | assert(cfg[depth], 'Invalid depth: ' .. tostring(depth)) 442 | local def, nFeatures, block = table.unpack(cfg[depth]) 443 | iChannels = 64 444 | --print(' | ResNet-' .. depth .. ' ImageNet') 445 | 446 | -- The ResNet ImageNet model 447 | model:add(Convolution(3,64,7,7,2,2,3,3)) 448 | model:add(SBatchNorm(64)) 449 | model:add(ReLU(true)) 450 | model:add(Max(3,3,2,2,1,1)) 451 | model:add(layer(block, 64, def[1])) 452 | model:add(layer(block, 128, def[2], 2)) 453 | model:add(layer(block, 256, def[3], 2)) 454 | model:add(layer(block, 512, def[4], 2)) 455 | model:add(Avg(7, 7, 1, 1)) 456 | model:add(nn.View(nFeatures):setNumInputDims(3)) 457 | model:add(nn.Linear(nFeatures, 1000)) 458 | 459 | input = torch.rand(1,3,224,224) 460 | elseif opt.dataset == 'cifar10' then 461 | -- Model type specifies number of layers for CIFAR-10 model 462 | assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110, 1202') 463 | local n = (depth - 2) / 6 464 | iChannels = 16 465 | --print(' | ResNet-' .. depth .. ' CIFAR-10') 466 | 467 | -- The ResNet CIFAR-10 model 468 | model:add(Convolution(3,16,3,3,1,1,1,1)) 469 | model:add(SBatchNorm(16)) 470 | model:add(ReLU(true)) 471 | model:add(layer(basicblock, 16, n)) 472 | model:add(layer(basicblock, 32, n, 2)) 473 | model:add(layer(basicblock, 64, n, 2)) 474 | model:add(Avg(8, 8, 1, 1)) 475 | model:add(nn.View(64):setNumInputDims(3)) 476 | model:add(nn.Linear(64, 10)) 477 | input = torch.rand(1,3,32,32) 478 | else 479 | error('invalid dataset: ' .. 
opt.dataset) 480 | end 481 | 482 | local function ConvInit(name) 483 | for k,v in pairs(model:findModules(name)) do 484 | local n = v.kW*v.kH*v.nOutputPlane 485 | v.weight:normal(0,math.sqrt(2/n)) 486 | if false and cudnn.version >= 4000 then 487 | v.bias = nil 488 | v.gradBias = nil 489 | else 490 | v.bias:zero() 491 | end 492 | end 493 | end 494 | local function BNInit(name) 495 | for k,v in pairs(model:findModules(name)) do 496 | v.weight:fill(1) 497 | v.bias:zero() 498 | end 499 | end 500 | 501 | ConvInit('cudnn.SpatialConvolution') 502 | ConvInit('nn.SpatialConvolution') 503 | BNInit('fbnn.SpatialBatchNormalization') 504 | BNInit('cudnn.SpatialBatchNormalization') 505 | BNInit('nn.SpatialBatchNormalization') 506 | for k,v in pairs(model:findModules('nn.Linear')) do 507 | v.bias:zero() 508 | end 509 | 510 | if opt.cudnn == 'deterministic' then 511 | model:apply(function(m) 512 | if m.setMode then m:setMode(1,1,1) end 513 | end) 514 | end 515 | 516 | return model, input 517 | end 518 | 519 | return createModel(opt) 520 | end 521 | 522 | models.preresnet = function(opt) 523 | 524 | local Convolution = nn.SpatialConvolution 525 | local Avg = nn.SpatialAveragePooling 526 | local ReLU = nn.ReLU 527 | local Max = nn.SpatialMaxPooling 528 | local SBatchNorm = function(n) 529 | local m = nn.MulConstant(1) 530 | m.inplace = nil 531 | return m 532 | end--nn.SpatialBatchNormalization 533 | 534 | local function createModel(opt) 535 | local depth = opt.depth 536 | local shortcutType = opt.shortcutType or 'B' 537 | local iChannels 538 | 539 | -- The shortcut layer is either identity or 1x1 convolution 540 | local function shortcut(nInputPlane, nOutputPlane, stride) 541 | local useConv = shortcutType == 'C' or 542 | (shortcutType == 'B' and nInputPlane ~= nOutputPlane) 543 | if useConv then 544 | -- 1x1 convolution 545 | return nn.Sequential() 546 | :add(Convolution(nInputPlane, nOutputPlane, 1, 1, stride, stride)) 547 | elseif nInputPlane ~= nOutputPlane then 548 | -- Strided, zero-padded identity shortcut 549 | return nn.Sequential() 550 | :add(nn.SpatialAveragePooling(1, 1, stride, stride)) 551 | :add(nn.Concat(2) 552 | :add(nn.Identity()) 553 | :add(nn.MulConstant(0))) 554 | else 555 | return nn.Identity() 556 | end 557 | end 558 | 559 | -- The basic residual layer block for 18 and 34 layer network, and the 560 | -- CIFAR networks 561 | local function basicblock(n, stride, type) 562 | local nInputPlane = iChannels 563 | iChannels = n 564 | 565 | local block = nn.Sequential() 566 | local s = nn.Sequential() 567 | if type == 'both_preact' then 568 | block:add(SBatchNorm(nInputPlane)) 569 | block:add(ReLU(true)) 570 | elseif type ~= 'no_preact' then 571 | s:add(SBatchNorm(nInputPlane)) 572 | s:add(ReLU(true)) 573 | end 574 | s:add(Convolution(nInputPlane,n,3,3,stride,stride,1,1)) 575 | s:add(SBatchNorm(n)) 576 | s:add(ReLU(true)) 577 | s:add(Convolution(n,n,3,3,1,1,1,1)) 578 | 579 | return block 580 | :add(nn.ConcatTable() 581 | :add(s) 582 | :add(shortcut(nInputPlane, n, stride))) 583 | :add(nn.CAddTable(true)) 584 | end 585 | 586 | -- The bottleneck residual layer for 50, 101, and 152 layer networks 587 | local function bottleneck(n, stride, type) 588 | local nInputPlane = iChannels 589 | iChannels = n * 4 590 | 591 | local block = nn.Sequential() 592 | local s = nn.Sequential() 593 | if type == 'both_preact' then 594 | block:add(SBatchNorm(nInputPlane)) 595 | block:add(ReLU(true)) 596 | elseif type ~= 'no_preact' then 597 | s:add(SBatchNorm(nInputPlane)) 598 | s:add(ReLU(true)) 599 | end 600 | 
s:add(Convolution(nInputPlane,n,1,1,1,1,0,0)) 601 | s:add(SBatchNorm(n)) 602 | s:add(ReLU(true)) 603 | s:add(Convolution(n,n,3,3,stride,stride,1,1)) 604 | s:add(SBatchNorm(n)) 605 | s:add(ReLU(true)) 606 | s:add(Convolution(n,n*4,1,1,1,1,0,0)) 607 | 608 | return block 609 | :add(nn.ConcatTable() 610 | :add(s) 611 | :add(shortcut(nInputPlane, n * 4, stride))) 612 | :add(nn.CAddTable(true)) 613 | end 614 | 615 | -- Creates count residual blocks with specified number of features 616 | local function layer(block, features, count, stride, type) 617 | local s = nn.Sequential() 618 | if count < 1 then 619 | return s 620 | end 621 | s:add(block(features, stride, 622 | type == 'first' and 'no_preact' or 'both_preact')) 623 | for i=2,count do 624 | s:add(block(features, 1)) 625 | end 626 | return s 627 | end 628 | 629 | local model = nn.Sequential() 630 | local input 631 | if opt.dataset == 'imagenet' then 632 | -- Configurations for ResNet: 633 | -- num. residual blocks, num features, residual block function 634 | local cfg = { 635 | [18] = {{2, 2, 2, 2}, 512, basicblock}, 636 | [34] = {{3, 4, 6, 3}, 512, basicblock}, 637 | [50] = {{3, 4, 6, 3}, 2048, bottleneck}, 638 | [101] = {{3, 4, 23, 3}, 2048, bottleneck}, 639 | [152] = {{3, 8, 36, 3}, 2048, bottleneck}, 640 | [200] = {{3, 24, 36, 3}, 2048, bottleneck}, 641 | } 642 | 643 | assert(cfg[depth], 'Invalid depth: ' .. tostring(depth)) 644 | local def, nFeatures, block = table.unpack(cfg[depth]) 645 | iChannels = 64 646 | --print(' | ResNet-' .. depth .. ' ImageNet') 647 | 648 | -- The ResNet ImageNet model 649 | model:add(Convolution(3,64,7,7,2,2,3,3)) 650 | model:add(SBatchNorm(64)) 651 | model:add(ReLU(true)) 652 | model:add(Max(3,3,2,2,1,1)) 653 | model:add(layer(block, 64, def[1], 1, 'first')) 654 | model:add(layer(block, 128, def[2], 2)) 655 | model:add(layer(block, 256, def[3], 2)) 656 | model:add(layer(block, 512, def[4], 2)) 657 | model:add(nn.Copy(nil, nil, true)) 658 | model:add(SBatchNorm(iChannels)) 659 | model:add(ReLU(true)) 660 | model:add(Avg(7, 7, 1, 1)) 661 | model:add(nn.View(nFeatures):setNumInputDims(3)) 662 | model:add(nn.Linear(nFeatures, 1000)) 663 | 664 | input = torch.rand(1,3,224,224) 665 | elseif opt.dataset == 'cifar10' then 666 | -- Model type specifies number of layers for CIFAR-10 model 667 | assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110, 1202') 668 | local n = (depth - 2) / 6 669 | iChannels = 16 670 | --print(' | ResNet-' .. depth .. ' CIFAR-10') 671 | 672 | -- The ResNet CIFAR-10 model 673 | model:add(Convolution(3,16,3,3,1,1,1,1)) 674 | model:add(layer(basicblock, 16, n, 1)) 675 | model:add(layer(basicblock, 32, n, 2)) 676 | model:add(layer(basicblock, 64, n, 2)) 677 | model:add(nn.Copy(nil, nil, true)) 678 | model:add(SBatchNorm(iChannels)) 679 | model:add(ReLU(true)) 680 | model:add(Avg(8, 8, 1, 1)) 681 | model:add(nn.View(64):setNumInputDims(3)) 682 | model:add(nn.Linear(64, 10)) 683 | 684 | input = torch.rand(1,3,32,32) 685 | else 686 | error('invalid dataset: ' .. 
opt.dataset) 687 | end 688 | 689 | local function ConvInit(name) 690 | for k,v in pairs(model:findModules(name)) do 691 | local n = v.kW*v.kH*v.nOutputPlane 692 | v.weight:normal(0,math.sqrt(2/n)) 693 | if false and cudnn.version >= 4000 then -- note: the leading 'false' disables this cudnn bias-removal branch, so biases are always kept and zeroed below 694 | v.bias = nil 695 | v.gradBias = nil 696 | else 697 | v.bias:zero() 698 | end 699 | end 700 | end 701 | local function BNInit(name) 702 | for k,v in pairs(model:findModules(name)) do 703 | v.weight:fill(1) 704 | v.bias:zero() 705 | end 706 | end 707 | 708 | --ConvInit('cudnn.SpatialConvolution') 709 | ConvInit('nn.SpatialConvolution') 710 | --BNInit('fbnn.SpatialBatchNormalization') 711 | --BNInit('cudnn.SpatialBatchNormalization') 712 | BNInit('nn.SpatialBatchNormalization') 713 | for k,v in pairs(model:findModules('nn.Linear')) do 714 | v.bias:zero() 715 | end 716 | --model:cuda() 717 | 718 | if opt.cudnn == 'deterministic' then 719 | model:apply(function(m) 720 | if m.setMode then m:setMode(1,1,1) end 721 | end) 722 | end 723 | 724 | --model:get(1).gradInput = nil 725 | 726 | return model, input 727 | end 728 | 729 | return createModel(opt) 730 | 731 | end 732 | 733 | 734 | return models 735 | -------------------------------------------------------------------------------- /rocks/optnet-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "optnet" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/fmassa/optimize-net", 6 | tag = "master" 7 | } 8 | 9 | description = { 10 | summary = "This package provides memory optimizations for the nn library in Torch7.", 11 | homepage = "https://github.com/fmassa/optimize-net", 12 | license = "BSD" 13 | } 14 | 15 | dependencies = { 16 | "torch >= 7.0", 17 | "graph", 18 | "nn" 19 | } 20 | 21 | build = { 22 | type = "command", 23 | build_command = [[ 24 | cmake -E make_directory build; 25 | cd build; 26 | cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.."
-DCMAKE_INSTALL_PREFIX="$(PREFIX)"; 27 | $(MAKE) 28 | ]], 29 | install_command = "cd build && $(MAKE) install" 30 | } 31 | -------------------------------------------------------------------------------- /tests.lua: -------------------------------------------------------------------------------- 1 | local optnet = require 'optnet.env' 2 | local models = require 'optnet.models' 3 | local utils = require 'optnet.utils' 4 | local countUsedMemory = optnet.countUsedMemory 5 | 6 | local optest = torch.TestSuite() 7 | local tester = torch.Tester() 8 | 9 | local use_cudnn = false 10 | local backward_tol = 1e-6 11 | 12 | local function resizeAndConvert(input, type) 13 | local res 14 | local s = 64 15 | if torch.isTensor(input) then 16 | local iSize = torch.Tensor(input:size():totable())[{{2,-1}}] 17 | res = torch.rand(s,table.unpack(iSize:totable())):type(type) 18 | else 19 | res = {} 20 | for k, v in ipairs(input) do 21 | res[k] = resizeAndConvert(v,type) 22 | end 23 | end 24 | return res 25 | end 26 | 27 | local function cudnnSetDeterministic(net) 28 | net:apply(function(m) 29 | if m.setMode then m:setMode(1, 1, 1) end 30 | end) 31 | end 32 | 33 | local function genericTestForward(model,opts) 34 | local net, input = models[model](opts) 35 | net:evaluate() 36 | 37 | if use_cudnn then 38 | cudnn.convert(net,cudnn); 39 | net:cuda(); 40 | 41 | input = resizeAndConvert(input,'torch.CudaTensor') 42 | end 43 | 44 | local out_orig = utils.recursiveClone(net:forward(input)) 45 | 46 | local mems1 = optnet.countUsedMemory(net) 47 | 48 | optnet.optimizeMemory(net, input) 49 | 50 | local out = utils.recursiveClone(net:forward(input)) 51 | local mems2 = countUsedMemory(net) 52 | tester:eq(out_orig, out, 'Outputs differ after optimization of '..model) 53 | 54 | local mem1 = mems1.total_size 55 | local mem2 = mems2.total_size 56 | 57 | local omem1 = mems1.outputs 58 | local omem2 = mems2.outputs 59 | 60 | local bmem1 = mems1.buffers 61 | local bmem2 = mems2.buffers 62 | 63 | local pmem1 = mems1.params 64 | local pmem2 = mems2.params 65 | 66 | tester:assertle(mem2, mem1, 'Optimized model uses more memory! '.. 67 | 'Before: '.. mem1..' bytes, After: '..mem2..' 
bytes') 68 | print('Memory use') 69 | print('Total', mem1/1024/1024, mem2/1024/1024, 1-mem2/mem1) 70 | print('Outputs',omem1/1024/1024,omem2/1024/1024, 1-omem2/omem1) 71 | print('Buffers',bmem1/1024/1024,bmem2/1024/1024, 1-bmem2/bmem1) 72 | print('Params', pmem1/1024/1024,pmem2/1024/1024, 1-pmem2/pmem1) 73 | end 74 | 75 | ------------------------------------------------- 76 | -- Backward 77 | ------------------------------------------------- 78 | 79 | local function genericTestBackward(model,opts) 80 | local net, input = models[model](opts) 81 | net:training() 82 | 83 | if use_cudnn then 84 | cudnn.convert(net,cudnn); 85 | cudnnSetDeterministic(net) 86 | net:cuda(); 87 | 88 | input = resizeAndConvert(input,'torch.CudaTensor') 89 | end 90 | 91 | local out_orig = utils.recursiveClone(net:forward(input)) 92 | local grad_orig = utils.recursiveClone(out_orig) 93 | net:zeroGradParameters() 94 | local gradInput_orig = utils.recursiveClone(net:backward(input, grad_orig)) 95 | local _, gradParams_orig = net:getParameters() 96 | gradParams_orig = gradParams_orig:clone() 97 | 98 | local mems1 = optnet.countUsedMemory(net) 99 | 100 | optnet.optimizeMemory(net, input, {mode='training'}) 101 | 102 | local out = utils.recursiveClone(net:forward(input)) 103 | local grad = utils.recursiveClone(out) 104 | net:zeroGradParameters() 105 | local gradInput = utils.recursiveClone(net:backward(input, grad)) 106 | local _, gradParams = net:getParameters() 107 | gradParams = gradParams:clone() 108 | 109 | local mems2 = countUsedMemory(net) 110 | tester:eq(out_orig, out, 'Outputs differ after optimization of '..model) 111 | tester:eq(gradInput_orig, gradInput, backward_tol, 'GradInputs differ after optimization of '..model) 112 | tester:eq(gradParams_orig, gradParams, backward_tol, 'GradParams differ after optimization of '..model) 113 | 114 | local mem1 = mems1.total_size 115 | local mem2 = mems2.total_size 116 | 117 | local omem1 = mems1.outputs 118 | local omem2 = mems2.outputs 119 | 120 | local imem1 = mems1.gradInputs 121 | local imem2 = mems2.gradInputs 122 | 123 | local bmem1 = mems1.buffers 124 | local bmem2 = mems2.buffers 125 | 126 | local pmem1 = mems1.params 127 | local pmem2 = mems2.params 128 | 129 | tester:assertle(mem2, mem1, 'Optimized model uses more memory! '.. 130 | 'Before: '.. mem1..' bytes, After: '..mem2..' 
bytes') 131 | print('Memory use') 132 | print('Total', mem1/1024/1024, mem2/1024/1024, 1-mem2/mem1) 133 | print('Outputs',omem1/1024/1024,omem2/1024/1024, 1-omem2/omem1) 134 | print('gradInputs',imem1/1024/1024,imem2/1024/1024, 1-imem2/imem1) 135 | print('Buffers',bmem1/1024/1024,bmem2/1024/1024, 1-bmem2/bmem1) 136 | print('Params', pmem1/1024/1024,pmem2/1024/1024, 1-pmem2/pmem1) 137 | end 138 | 139 | ------------------------------------------------- 140 | -- removing optimization 141 | ------------------------------------------------- 142 | 143 | local function genericTestRemoveOptim(model,opts) 144 | local net, input = models[model](opts) 145 | net:training() 146 | 147 | if use_cudnn then 148 | cudnn.convert(net,cudnn); 149 | cudnnSetDeterministic(net) 150 | net:cuda(); 151 | 152 | input = resizeAndConvert(input,'torch.CudaTensor') 153 | end 154 | 155 | local out_orig = utils.recursiveClone(net:forward(input)) 156 | local grad_orig = utils.recursiveClone(out_orig) 157 | net:zeroGradParameters() 158 | local gradInput_orig = utils.recursiveClone(net:backward(input, grad_orig)) 159 | local _, gradParams_orig = net:getParameters() 160 | gradParams_orig = gradParams_orig:clone() 161 | 162 | optnet.optimizeMemory(net, input) 163 | optnet.removeOptimization(net) 164 | 165 | local out = utils.recursiveClone(net:forward(input)) 166 | local grad = utils.recursiveClone(out) 167 | net:zeroGradParameters() 168 | local gradInput = utils.recursiveClone(net:backward(input, grad)) 169 | local _, gradParams = net:getParameters() 170 | gradParams = gradParams:clone() 171 | 172 | tester:eq(out_orig, out, 'Outputs differ after optimization of '..model) 173 | tester:eq(gradInput_orig, gradInput, backward_tol, 'GradInputs differ after optimization of '..model) 174 | tester:eq(gradParams_orig, gradParams, backward_tol, 'GradParams differ after optimization of '..model) 175 | end 176 | 177 | for k, v in pairs(models) do 178 | if k ~= 'resnet' and k ~= 'preresnet' then 179 | optest[k] = function() 180 | genericTestForward(k) 181 | end 182 | optest[k..'_backward'] = function() 183 | genericTestBackward(k) 184 | end 185 | optest[k..'_remove'] = function() 186 | genericTestRemoveOptim(k) 187 | end 188 | end 189 | end 190 | 191 | for _, v in ipairs({20,32,56,110}) do 192 | for _, k in ipairs{'resnet', 'preresnet'} do 193 | local opts = {dataset='cifar10',depth=v} 194 | optest[k..v] = function() 195 | genericTestForward(k, opts) 196 | end 197 | optest[k..v..'_backward'] = function() 198 | genericTestBackward(k, opts) 199 | end 200 | optest[k..v..'_remove'] = function() 201 | genericTestRemoveOptim(k, opts) 202 | end 203 | end 204 | end 205 | 206 | tester:add(optest) 207 | 208 | function optnet.test(tests, opts) 209 | opts = opts or {} 210 | 211 | local tType = torch.getdefaulttensortype() 212 | torch.setdefaulttensortype('torch.FloatTensor') 213 | 214 | if opts.only_basic_tests then 215 | local disable = { 216 | 'alexnet','vgg','googlenet', 217 | 'resnet20','resnet32','resnet56','resnet110', 218 | 'preresnet20','preresnet32','preresnet56','preresnet110' 219 | } 220 | local toDisable = {} 221 | for _, v in ipairs(disable) do 222 | table.insert(toDisable,v) 223 | table.insert(toDisable,v..'_backward') 224 | table.insert(toDisable,v..'_remove') 225 | end 226 | tester:disable(toDisable) 227 | end 228 | if opts.use_cudnn then 229 | use_cudnn = true 230 | require 'cudnn' 231 | require 'cunn' 232 | end 233 | tester:run(tests) 234 | torch.setdefaulttensortype(tType) 235 | return tester 236 | end 237 | 
-------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | local utils = {} 2 | 3 | local function keepTrack(t, track, entry_fun, fun, ...) 4 | if torch.isTensor(t) and t:storage() then 5 | local ptr = torch.pointer(t:storage()) 6 | if not track[ptr] then 7 | track[ptr] = entry_fun(t, ...) 8 | end 9 | if fun then 10 | fun(t,track,...) 11 | end 12 | return 13 | end 14 | if torch.type(t) == 'table' then 15 | for k, v in ipairs(t) do 16 | keepTrack(v, track, entry_fun, fun, ...) 17 | end 18 | end 19 | end 20 | utils.keepTrack = keepTrack 21 | 22 | local function recursiveClone(out) 23 | if torch.isTensor(out) then 24 | return out:clone() 25 | else 26 | local res = {} 27 | for k, v in ipairs(out) do 28 | res[k] = recursiveClone(v) 29 | end 30 | return res 31 | end 32 | end 33 | utils.recursiveClone = recursiveClone 34 | 35 | local function copyTable(t) 36 | if type(t) == 'table' then 37 | local r = {} 38 | for k, v in pairs(t) do 39 | r[k] = copyTable(v) 40 | end 41 | return r 42 | else 43 | return t 44 | end 45 | end 46 | utils.copyTable = copyTable 47 | 48 | return utils 49 | --------------------------------------------------------------------------------
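For reference, the suite defined in tests.lua above is driven by `optnet.test(tests, opts)`. The sketch below is an editorial illustration rather than a file from the repository; it assumes that `require 'optnet'` loads tests.lua and exposes `optnet.test` (if your install does not, the file can be loaded directly), and that `cunn` and `cudnn` are available when the GPU path is requested.

```lua
-- Illustrative sketch (not part of the repository): running the optnet test suite.
-- Assumption: require 'optnet' loads tests.lua, which defines optnet.test.
local optnet = require 'optnet'

-- Run only the small synthetic models, skipping alexnet, vgg, googlenet
-- and the resnet/preresnet variants:
optnet.test(nil, {only_basic_tests = true})

-- Run the full suite on the GPU through cudnn (requires the cunn and cudnn packages):
-- optnet.test(nil, {use_cudnn = true})
```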