├── .gitignore ├── .travis.yml ├── CMakeLists.txt ├── COPYRIGHT.txt ├── JustElement.lua ├── JustTable.lua ├── ModuleFromCriterion.lua ├── README.md ├── doc ├── annotation_bg.png ├── annotation_fg.png ├── mlp.png ├── mlp2.png ├── mlp3_backward.png ├── mlp3_forward.png ├── mlp4_backward.png ├── mlp4_forward.png └── my_bad_linear_net.png ├── gmodule.lua ├── graphinspecting.lua ├── init.lua ├── nest.lua ├── nesting.lua ├── nngraph-scm-1.rockspec ├── node.lua ├── simple_print.lua ├── test ├── speed.lua ├── test_JustElement.lua ├── test_JustTable.lua ├── test_ModuleFromCriterion.lua ├── test_connectivity.lua ├── test_debug.lua ├── test_nest.lua ├── test_nngraph.lua └── test_old.lua └── utils.lua /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | compiler: 3 | - gcc 4 | - clang 5 | cache: 6 | directories: 7 | - $HOME/OpenBlasInstall 8 | - $HOME/GraphViz 9 | sudo: false 10 | env: 11 | - TORCH_LUA_VERSION=LUAJIT21 12 | - TORCH_LUA_VERSION=LUA51 13 | - TORCH_LUA_VERSION=LUA52 14 | addons: 15 | apt: 16 | packages: 17 | - cmake 18 | - gfortran 19 | - gcc-multilib 20 | - gfortran-multilib 21 | - liblapack-dev 22 | - build-essential 23 | - gcc 24 | - g++ 25 | - curl 26 | - cmake 27 | - libreadline-dev 28 | - git-core 29 | - libqt4-core 30 | - libqt4-gui 31 | - libqt4-dev 32 | - libjpeg-dev 33 | - libpng-dev 34 | - ncurses-dev 35 | - imagemagick 36 | - libzmq3-dev 37 | - gfortran 38 | - unzip 39 | - gnuplot 40 | - gnuplot-x11 41 | before_script: 42 | - export ROOT_TRAVIS_DIR=$(pwd) 43 | - export INSTALL_PREFIX=~/torch/install 44 | - ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install) 45 | - ls $HOME/GraphViz/lib || (cd /tmp/ && wget -c http://www.graphviz.org/pub/graphviz/stable/SOURCES/graphviz-2.38.0.tar.gz && tar -xvf graphviz-2.38.0.tar.gz && cd graphviz-2.38.0 && (./configure prefix=$HOME/GraphViz/ 2>/dev/null >/dev/null) && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make install) 46 | - export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH 47 | - git clone https://github.com/torch/distro.git ~/torch --recursive 48 | - cd ~/torch && git submodule update --init --recursive 49 | - mkdir build && cd build 50 | - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH 51 | - cmake .. 
-DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON 52 | - make && make install 53 | - ${INSTALL_PREFIX}/bin/luarocks install totem 54 | - if [[ $TORCH_LUA_VERSION != 'LUAJIT21' && $TORCH_LUA_VERSION != 'LUAJIT20' ]]; then ${INSTALL_PREFIX}/bin/luarocks install luaffi; fi 55 | - cd $ROOT_TRAVIS_DIR 56 | - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH 57 | script: 58 | - ${INSTALL_PREFIX}/bin/luarocks make 59 | - export PATH=${INSTALL_PREFIX}/bin:$PATH 60 | - export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH 61 | - export TESTLUA=$(which luajit lua | head -n 1) 62 | - ${TESTLUA} -lnngraph -e "print('nngraph loaded succesfully')" 63 | - cd test 64 | - ${TESTLUA} test_ModuleFromCriterion.lua 65 | - ${TESTLUA} test_nest.lua 66 | - ${TESTLUA} test_nngraph.lua 67 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 3 | CMAKE_POLICY(VERSION 2.6) 4 | FIND_PACKAGE(Torch REQUIRED) 5 | 6 | FILE(GLOB luasrc *.lua) 7 | 8 | ADD_TORCH_PACKAGE(nngraph "" "${luasrc}" "Neural Net Graph Package") 9 | -------------------------------------------------------------------------------- /COPYRIGHT.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 2 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 3 | Copyright (c) 2011-2013 NYU (Clement Farabet) 4 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 5 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 6 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 7 | 8 | All rights reserved. 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | 13 | 1. Redistributions of source code must retain the above copyright 14 | notice, this list of conditions and the following disclaimer. 15 | 16 | 2. Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in the 18 | documentation and/or other materials provided with the distribution. 19 | 20 | 3. Neither the names of NEC Laboratories American and IDIAP Research 21 | Institute nor the names of its contributors may be used to endorse or 22 | promote products derived from this software without specific prior 23 | written permission. 24 | 25 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 26 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 27 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 28 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 29 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 30 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 31 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 32 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 33 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 34 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 35 | POSSIBILITY OF SUCH DAMAGE. 
36 | 
--------------------------------------------------------------------------------
/JustElement.lua:
--------------------------------------------------------------------------------
1 | 
2 | local JustElement, parent = torch.class('nngraph.JustElement', 'nn.Module')
3 | function JustElement:__init()
4 |    self.gradInput = {}
5 | end
6 | 
7 | -- The input is a table with one element.
8 | -- The output is the element from the table.
9 | function JustElement:updateOutput(input)
10 |    assert(#input == 1, "expecting one element")
11 |    self.output = input[1]
12 |    return self.output
13 | end
14 | 
15 | function JustElement:updateGradInput(input, gradOutput)
16 |    self.gradInput[1] = gradOutput
17 |    return self.gradInput
18 | end
--------------------------------------------------------------------------------
/JustTable.lua:
--------------------------------------------------------------------------------
1 | 
2 | local JustTable, parent = torch.class('nngraph.JustTable', 'nn.Module')
3 | function JustTable:__init()
4 |    self.output = {}
5 | end
6 | 
7 | -- The input is one element.
8 | -- The output is a table with one element: {element}
9 | function JustTable:updateOutput(input)
10 |    self.output[1] = input
11 |    return self.output
12 | end
13 | 
14 | function JustTable:updateGradInput(input, gradOutput)
15 |    self.gradInput = gradOutput[1]
16 |    return self.gradInput
17 | end
--------------------------------------------------------------------------------
/ModuleFromCriterion.lua:
--------------------------------------------------------------------------------
1 | 
2 | --[[ A wrapper to turn a criterion into a module.
3 | 
4 | The gradient with respect to the target will be zero.
5 | --]]
6 | local ModuleFromCriterion, parent = torch.class('nn.ModuleFromCriterion','nn.Module')
7 | function ModuleFromCriterion:__init(criterion)
8 |    self.criterion = criterion
9 |    self.output = torch.Tensor(1)
10 |    self.gradInput = {torch.Tensor(), torch.Tensor()}
11 | end
12 | 
13 | local unpack = unpack or table.unpack -- lua52 compat
14 | 
15 | --[[ The input is a {prediction, target} pair.
16 | The output is a tensor with one number: the criterion output.
17 | --]]
18 | function ModuleFromCriterion:updateOutput(input)
19 |    local prediction, target = unpack(input)
20 |    self.output[1] = self.criterion:updateOutput(prediction, target)
21 |    return self.output
22 | end
23 | 
24 | function ModuleFromCriterion:updateGradInput(input, gradOutput)
25 |    local prediction, target = unpack(input)
26 |    local gradPrediction = self.criterion:updateGradInput(prediction, target)
27 |    if type(gradPrediction) == 'table' then
28 |       if type(self.gradInput[1]) ~= 'table' then
29 |          self.gradInput[1] = {} -- initialize to a table the first time, since it starts out as a tensor (see line 10)
30 |          for i=1, #gradPrediction do
31 |             self.gradInput[1][i] = gradPrediction[i].new() -- and put tensors of the right size inside.
32 |          end
33 |       end
34 |       for i=1, #gradPrediction do
35 |          self.gradInput[1][i]:resizeAs(gradPrediction[i]):copy(gradPrediction[i]):mul(gradOutput[1])
36 |       end
37 |    else
38 |       self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
39 |    end
40 |    self.gradInput[2]:resizeAs(target):zero()
41 |    return self.gradInput
42 | end
43 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Network Graph Package
2 | 
3 | [![Build Status](https://travis-ci.org/torch/nngraph.svg)](https://travis-ci.org/torch/nngraph)
4 | 
5 | This package provides graphical computation for the `nn` library in [Torch](https://github.com/torch/torch7/blob/master/README.md).
6 | 
7 | ## Requirements
8 | 
9 | You do *not* need `graphviz` to use this library, but if you have it you will be able to display the graphs that you create. To install it, run the appropriate command below:
10 | 
11 | ```bash
12 | # Mac users
13 | brew install graphviz
14 | # Debian/Ubuntu users
15 | sudo apt-get install graphviz -y
16 | ```
17 | 
18 | ## Usage
19 | 
20 | [Plug: A more explanatory nngraph tutorial by Nando De Freitas of Oxford](https://www.cs.ox.ac.uk/people/nando.defreitas/machinelearning/practicals/practical5.pdf)
21 | 
22 | The aim of this library is to provide users of the `nn` package with tools to easily create complicated architectures.
23 | Any given `nn` `module` is bundled into a *graph node*.
24 | The `__call__` operator of an `nn.Module` instance is used to create architectures as if one were writing function calls.
25 | 
26 | ### An MLP with two hidden layers
27 | 
28 | ```lua
29 | h1 = nn.Linear(20, 10)()
30 | h2 = nn.Linear(10, 1)(nn.Tanh()(nn.Linear(10, 10)(nn.Tanh()(h1))))
31 | mlp = nn.gModule({h1}, {h2})
32 | 
33 | x = torch.rand(20)
34 | dx = torch.rand(1)
35 | mlp:updateOutput(x)
36 | mlp:updateGradInput(x, dx)
37 | mlp:accGradParameters(x, dx)
38 | 
39 | -- draw graph (the forward graph, '.fg')
40 | graph.dot(mlp.fg, 'MLP')
41 | ```
42 | 
43 | 
44 | 
45 | Read this diagram from top to bottom, with the first and last nodes being *dummy nodes* that regroup all inputs and outputs of the graph.
46 | The `module` entry describes the function of the node, as applied to `input`, producing a result of the shape `gradOutput`; `mapindex` contains pointers to the parent nodes.
47 | 
48 | To save the *graph* to a file, specify the file name, and both `dot` and `svg` files will be saved.
For example, you can type:
49 | 
50 | ```lua
51 | graph.dot(mlp.fg, 'MLP', 'myMLP')
52 | ```
53 | 
54 | You can also use the `__unm__` and `__sub__` operators to replace all `__call__` invocations:
55 | ```lua
56 | h1 = - nn.Linear(20,10)
57 | h2 = h1
58 |    - nn.Tanh()
59 |    - nn.Linear(10,10)
60 |    - nn.Tanh()
61 |    - nn.Linear(10, 1)
62 | mlp = nn.gModule({h1}, {h2})
63 | ```
64 | 
65 | 
66 | ### A network with 2 inputs and 2 outputs
67 | 
68 | ```lua
69 | h1 = nn.Linear(20, 20)()
70 | h2 = nn.Linear(10, 10)()
71 | hh1 = nn.Linear(20, 1)(nn.Tanh()(h1))
72 | hh2 = nn.Linear(10, 1)(nn.Tanh()(h2))
73 | madd = nn.CAddTable()({hh1, hh2})
74 | oA = nn.Sigmoid()(madd)
75 | oB = nn.Tanh()(madd)
76 | gmod = nn.gModule({h1, h2}, {oA, oB})
77 | 
78 | x1 = torch.rand(20)
79 | x2 = torch.rand(10)
80 | 
81 | gmod:updateOutput({x1, x2})
82 | gmod:updateGradInput({x1, x2}, {torch.rand(1), torch.rand(1)})
83 | graph.dot(gmod.fg, 'Big MLP')
84 | ```
85 | 
86 | Alternatively, you can use `-` to make your code look like the data flow:
87 | 
88 | ```lua
89 | h1 = - nn.Linear(20,20)
90 | h2 = - nn.Linear(10,10)
91 | hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
92 | hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
93 | madd = {hh1,hh2} - nn.CAddTable()
94 | oA = madd - nn.Sigmoid()
95 | oB = madd - nn.Tanh()
96 | gmod = nn.gModule( {h1,h2}, {oA,oB} )
97 | ```
98 | 
99 | 
100 | 
101 | 
102 | ### A network with containers
103 | 
104 | Here is another net that uses container modules (like `ParallelTable`) that output a table of outputs.
105 | 
106 | ```lua
107 | m = nn.Sequential()
108 | m:add(nn.SplitTable(1))
109 | m:add(nn.ParallelTable():add(nn.Linear(10, 20)):add(nn.Linear(10, 30)))
110 | input = nn.Identity()()
111 | input1, input2 = m(input):split(2)
112 | m3 = nn.JoinTable(1)({input1, input2})
113 | 
114 | g = nn.gModule({input}, {m3})
115 | 
116 | indata = torch.rand(2, 10)
117 | gdata = torch.rand(50)
118 | g:forward(indata)
119 | g:backward(indata, gdata)
120 | 
121 | graph.dot(g.fg, 'Forward Graph')
122 | graph.dot(g.bg, 'Backward Graph')
123 | ```
124 | 
125 | 
126 | 
127 | 
128 | 
129 | ### More fun with graphs
130 | 
131 | Here is a multi-layer network where each layer takes the output of the previous two layers as input.
132 | 
133 | ```lua
134 | input = nn.Identity()()
135 | L1 = nn.Tanh()(nn.Linear(10, 20)(input))
136 | L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1})))
137 | L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2})))
138 | 
139 | g = nn.gModule({input}, {L3})
140 | 
141 | indata = torch.rand(10)
142 | gdata = torch.rand(160)
143 | g:forward(indata)
144 | g:backward(indata, gdata)
145 | 
146 | graph.dot(g.fg, 'Forward Graph')
147 | graph.dot(g.bg, 'Backward Graph')
148 | ```
149 | 
150 | As your graph gets bigger and more complicated, the nested parentheses may become confusing. In such cases, chaining the modules with `-` is clearer and easier:
151 | ```lua
152 | input = - nn.Identity()
153 | L1 = input
154 |    - nn.Linear(10, 20)
155 |    - nn.Tanh()
156 | L2 = { input, L1 }
157 |    - nn.JoinTable(1)
158 |    - nn.Linear(30,60)
159 |    - nn.Tanh()
160 | L3 = { L1,L2 }
161 |    - nn.JoinTable(1)
162 |    - nn.Linear(80,160)
163 |    - nn.Tanh()
164 | g = nn.gModule({input},{L3})
165 | ```
166 | 
167 | 
168 | 
169 | 
170 | 
171 | ## Annotations
172 | 
173 | It is possible to add annotations to your network, such as labeling nodes with names or attributes, which will show up when you graph the network.
174 | This can be helpful in large graphs.
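For instance, a node can be named inline as it is created. This is a minimal sketch, reusing an `input` node as in the examples above (the name `encoder` is just an illustrative label):

```lua
-- hypothetical one-liner: label a node and color it in the rendered graph
h = nn.Linear(10, 20)(input):annotate{name = 'encoder', graphAttributes = {color = 'red'}}
```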
175 | 
176 | For the full list of graph attributes see the
177 | [graphviz documentation](http://www.graphviz.org/doc/info/attrs.html).
178 | 
179 | ```lua
180 | input = nn.Identity()()
181 | L1 = nn.Tanh()(nn.Linear(10, 20)(input)):annotate{
182 |    name = 'L1', description = 'Level 1 Node',
183 |    graphAttributes = {color = 'red'}
184 | }
185 | L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1}))):annotate{
186 |    name = 'L2', description = 'Level 2 Node',
187 |    graphAttributes = {color = 'blue', fontcolor = 'green'}
188 | }
189 | L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2}))):annotate{
190 |    name = 'L3', description = 'Level 3 Node',
191 |    graphAttributes = {color = 'green',
192 |    style = 'filled', fillcolor = 'yellow'}
193 | }
194 | 
195 | g = nn.gModule({input},{L3})
196 | 
197 | indata = torch.rand(10)
198 | gdata = torch.rand(160)
199 | g:forward(indata)
200 | g:backward(indata, gdata)
201 | 
202 | graph.dot(g.fg, 'Forward Graph', '/tmp/fg')
203 | graph.dot(g.bg, 'Backward Graph', '/tmp/bg')
204 | ```
205 | 
206 | In this case, the graphs are saved in the following 4 files: `/tmp/{fg,bg}.{dot,svg}`.
207 | 
208 | 
209 | 
210 | 
211 | ## Debugging
212 | 
213 | With nngraph, one can create very complicated networks, and in such networks finding errors can be hard. For that purpose, nngraph provides several useful utilities. The following code snippet shows how to use local variable names to annotate the nodes in a graph, and how to enable the debugging mode, which automatically creates an SVG file with the failing node marked whenever a runtime error occurs.
214 | 
215 | ```lua
216 | 
217 | require 'nngraph'
218 | 
219 | -- generate SVG of the graph with the problem node highlighted
220 | -- and hover over the nodes in svg to see the filename:line_number info
221 | -- nodes will be annotated with local variable names even if debug mode is not enabled.
222 | nngraph.setDebug(true) 223 | 224 | local function get_net(from, to) 225 | local from = from or 10 226 | local to = to or 10 227 | local input_x = nn.Identity()() 228 | local linear_module = nn.Linear(from, to)(input_x) 229 | 230 | -- Annotate nodes with local variable names 231 | nngraph.annotateNodes() 232 | return nn.gModule({input_x},{linear_module}) 233 | end 234 | 235 | local net = get_net(10,10) 236 | 237 | -- if you give a name to the net, it will use that name to produce the 238 | -- svg in case of error, if not, it will come up with a name 239 | -- that is derived from number of inputs and outputs to the graph 240 | net.name = 'my_bad_linear_net' 241 | 242 | -- prepare an input that is of the wrong size to force an error 243 | local input = torch.rand(11) 244 | pcall(function() net:updateOutput(input) end) 245 | -- it should have produced an error and spit out a graph 246 | -- just run Safari to display the svg 247 | os.execute('open -a Safari my_bad_linear_net.svg') 248 | ``` 249 | 250 | 251 | -------------------------------------------------------------------------------- /doc/annotation_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/annotation_bg.png -------------------------------------------------------------------------------- /doc/annotation_fg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/annotation_fg.png -------------------------------------------------------------------------------- /doc/mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp.png -------------------------------------------------------------------------------- /doc/mlp2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp2.png -------------------------------------------------------------------------------- /doc/mlp3_backward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp3_backward.png -------------------------------------------------------------------------------- /doc/mlp3_forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp3_forward.png -------------------------------------------------------------------------------- /doc/mlp4_backward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp4_backward.png -------------------------------------------------------------------------------- /doc/mlp4_forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp4_forward.png -------------------------------------------------------------------------------- /doc/my_bad_linear_net.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/my_bad_linear_net.png
--------------------------------------------------------------------------------
/gmodule.lua:
--------------------------------------------------------------------------------
1 | local nesting = require('nngraph.nesting')
2 | local utils = require('nngraph.utils')
3 | local istensor = torch.isTensor
4 | local istable = utils.istable
5 | local istorchclass = utils.istorchclass
6 | 
7 | local function getTotalGradOutput(node)
8 |    local gradOutput = node.data.gradOutput
9 |    assert(istable(gradOutput), "expecting gradients to sum")
10 |    if #gradOutput > 1 then
11 |       -- Check if we can bypass the allocation, for the special case where all
12 |       -- gradOutputs but one are zero tensors with an underlying one-element
13 |       -- storage. Note that when we cannot bypass it, this check will only be
14 |       -- performed once
15 |       if not node.data.gradOutputBuffer then
16 |          local count = 0
17 |          local idx = 1
18 |          -- Count the gradOutputs that are not one-element tensors filled with zero
19 |          for i=1,#gradOutput do
20 |             local zero = torch.isTensor(gradOutput[i]) and
21 |                          gradOutput[i]:storage() ~= nil and
22 |                          gradOutput[i]:storage():size() == 1 and
23 |                          gradOutput[i]:storage()[1] == 0
24 |             if not zero then
25 |                idx = i
26 |                count = count + 1
27 |             end
28 |          end
29 |          if count < 2 then
30 |             -- Return the only non-zero one, or the first one
31 |             -- if they are all zero
32 |             return gradOutput[idx]
33 |          end
34 |       end
35 |       node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
36 |       local gobuff = node.data.gradOutputBuffer
37 |       nesting.resizeNestedAs(gobuff, gradOutput[1])
38 |       nesting.copyNested(gobuff, gradOutput[1])
39 |       for i=2,#gradOutput do
40 |          nesting.addNestedTo(gobuff, gradOutput[i])
41 |       end
42 |       gradOutput = gobuff
43 |    else
44 |       gradOutput = gradOutput[1]
45 |    end
46 |    return gradOutput
47 | end
48 | 
49 | -- The gModule allows one to build a general acyclic graph of modules.
50 | --
51 | -- Each node of the graph can have multiple inputs.
52 | -- The order of inputs is remembered in node.data.mapindex.
53 | --
54 | -- Each node has only one output.
55 | -- The output can also be a table.
56 | -- To route parts of the outputted table to different modules,
57 | -- use the node:split(nOutputs) function.
58 | -- The split will create subnodes with narrowed output.
59 | --
60 | -- Implementation details:
61 | -- The node.data.input holds a list of inputs.
62 | -- If a module expects only one input, the node.data.input[1] is used.
63 | --
64 | -- The node.data.gradOutput holds the to-be-summed gradOutputs.
65 | -- Each node has only one output. So we need only one gradOutput.
66 | local gModule, parent = torch.class('nn.gModule','nn.Container')
67 | 
68 | function gModule:__init(inputs,outputs)
69 |    parent.__init(self)
70 |    -- the graph is defined backwards; we have the output modules as input here
71 |    -- we will define a dummy output node that connects all output modules
72 |    -- into itself. This will be the output for the forward graph and
73 |    -- the input point for the backward graph
74 |    local node
75 |    local outnode = nngraph.Node({input={}})
76 |    for i = 1, utils.tableMaxN(outputs) do
77 |       node = outputs[i]
78 |       if torch.typename(node) ~= 'nngraph.Node' then
79 |          error(utils.expectingNodeErrorMessage(node, 'outputs', i))
80 |       end
81 |       outnode:add(node, true)
82 |    end
83 |    for i = 1, utils.tableMaxN(inputs) do
84 |       node = inputs[i]
85 |       if torch.typename(node) ~= 'nngraph.Node' then
86 |          error(utils.expectingNodeErrorMessage(node, 'inputs', i))
87 |       end
88 |    end
89 |    -- We also add a dummy input node.
90 |    -- The input node will be split to feed the passed input nodes.
91 |    local innode = nngraph.Node({input={}})
92 |    assert(#inputs > 0, "a gModule must have at least one input")
93 |    if #inputs == 1 then
94 |       inputs[1]:add(innode,true)
95 |    else
96 |       local splits = {innode:split(#inputs)}
97 |       for i = 1, #inputs do
98 |          assert(#inputs[i].children == 0, "an input should have no inputs")
99 |       end
100 |       for i = 1, #inputs do
101 |          inputs[i]:add(splits[i],true)
102 |       end
103 |    end
104 | 
105 |    -- the backward graph (bg) is for gradients
106 |    -- the forward graph (fg) is for function evaluation
107 |    self.bg = outnode:graph()
108 |    self.fg = self.bg:reverse()
109 | 
110 |    -- the complete graph is constructed
111 |    -- now regenerate the graphs with the additional nodes
112 | 
113 |    local roots = self.fg:roots()
114 |    -- if there is more than one root in the forward graph, then make sure that
115 |    -- the extra roots are parameter nodes
116 |    if #roots > 1 then
117 |       local innodeRoot = nil
118 |       -- first find our innode
119 |       for _, root in ipairs(roots) do
120 |          if root.data == innode.data then
121 |             assert(innodeRoot == nil, 'more than one matching input node found in leaves')
122 |             innodeRoot = root
123 |          else
124 |             assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
125 |             assert(torch.typename(root.data.module) == 'nnop.Parameters',
126 |                    'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
127 |          end
128 |       end
129 |       assert(innodeRoot ~= nil, 'input node not found among roots')
130 |       self.innode = innodeRoot
131 |    else
132 |       assert(#self.fg:roots() == 1, "expecting only one start")
133 |       self.innode = self.fg:roots()[1]
134 |    end
135 | 
136 |    assert(self.innode.data == innode.data, "expecting the forward innode")
137 |    self.outnode = outnode
138 |    self.verbose = false
139 |    self.nInputs = #inputs
140 | 
141 |    -- computation on the graph is done through topsort of forward and backward graphs
142 |    self.forwardnodes = self.fg:topsort()
143 |    self.backwardnodes = self.bg:topsort()
144 | 
145 |    -- iterate over all nodes: check, tag and add to container
146 |    for i,node in ipairs(self.forwardnodes) do
147 |       -- check for unused inputs or unused split() outputs
148 |       if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
149 |          local nUnused = node.data.nSplitOutputs - #node.children
150 |          local debugLabel = node.data.annotations._debugLabel
151 |          local errStr =
152 |             "%s of split(%s) outputs from the node declared at %s are unused"
153 |          error(string.format(errStr, nUnused, node.data.nSplitOutputs,
154 |                              debugLabel))
155 |       end
156 | 
157 |       -- Check whether any nodes were defined as taking this node as an input,
158 |       -- but then left dangling and don't connect to the output. If this is
159 |       -- the case, then they won't be present in forwardnodes, so error out.
160 |       for successor, _ in pairs(node.data.reverseMap) do
161 |          local successorIsInGraph = false
162 | 
163 |          -- Only need to check the part of forwardnodes from i onwards; the
164 |          -- topological sort guarantees it cannot be in the first part.
165 |          for j = i+1, #self.forwardnodes do
166 |             -- Compare equality of data tables, as new Node objects have been
167 |             -- created by processes such as topological sort, but the
168 |             -- underlying .data table is shared.
169 |             if self.forwardnodes[j].data == successor.data then
170 |                successorIsInGraph = true
171 |                break
172 |             end
173 |          end
174 |          local errStr =
175 |             "node declared on %s does not connect to gmodule output"
176 |          assert(successorIsInGraph,
177 |                 string.format(errStr, successor.data.annotations._debugLabel))
178 |       end
179 | 
180 |       -- set data.forwardNodeId for node:label() output
181 |       node.data.forwardNodeId = node.id
182 | 
183 |       -- add module to container
184 |       if node.data.module then
185 |          self:add(node.data.module)
186 |       end
187 |    end
188 | 
189 |    self.output = nil
190 |    self.gradInput = nil
191 |    if #self.outnode.children > 1 then
192 |       self.output = self.outnode.data.input
193 |    end
194 | end
195 | 
196 | function gModule:replace(callback)
197 |    local out = callback(self)
198 |    local revmodules = {}
199 |    for i,m in ipairs(self.modules) do
200 |       revmodules[m] = i
201 |    end
202 |    for i,node in ipairs(self.forwardnodes) do
203 |       if node.data.module then
204 |          local m = node.data.module
205 |          node.data.module = m:replace(callback)
206 |          self.modules[revmodules[m]] = node.data.module
207 |       end
208 |    end
209 |    return out
210 | end
211 | 
212 | function gModule:map(gm, func)
213 |    for i,node in ipairs(self.forwardnodes) do
214 |       local gmnode = gm.forwardnodes[i]
215 |       assert(gmnode, 'trying to map another gModule with a different structure')
216 |       if node.data.module then
217 |          assert(gmnode.data.module, 'trying to map another gModule with a different structure')
218 |          func(node.data.module, gmnode.data.module)
219 |       end
220 |    end
221 | end
222 | 
223 | --[[ Recursively applies type(type_str) to any tensors in the argument. If the
224 | argument is a tensor, type(type_str) is applied; if the argument is an array,
225 | this function recurses into it.
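For example, recursiveType({t1, {t2}}, 'torch.FloatTensor') would convert both tensors and return the same table.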
]] 226 | local function recursiveType(param, type_str) 227 | if torch.type(param) == 'table' then 228 | for i = 1, #param do 229 | param[i] = recursiveType(param[i], type_str) 230 | end 231 | elseif torch.typename(param) and 232 | torch.typename(param):find('torch%..+Tensor') then 233 | param = param:type(type_str) 234 | end 235 | return param 236 | end 237 | 238 | function gModule:type(type, tensorCache) 239 | if not type then 240 | return self._type 241 | end 242 | 243 | tensorCache = tensorCache or {} 244 | 245 | local function applyTypeToTable(table) 246 | for key, value in pairs(table) do 247 | table[key] = recursiveType(table[key], type) 248 | end 249 | end 250 | 251 | -- Convert any stored data in self, and in the in and out nodes 252 | applyTypeToTable(self) 253 | if self.innode then applyTypeToTable(self.innode.data) end 254 | if self.outnode then applyTypeToTable(self.outnode.data) end 255 | 256 | -- Loop through modules and convert data 257 | for _, m in ipairs(self.modules) do 258 | m:type(type, tensorCache) 259 | end 260 | 261 | for i,node in ipairs(self.backwardnodes) do 262 | if node.data.gradOutputBuffer ~= nil then 263 | node.data.gradOutputBuffer = 264 | recursiveType(node.data.gradOutputBuffer, type) 265 | end 266 | for k, child in ipairs(node.children) do 267 | applyTypeToTable(child.data) 268 | end 269 | end 270 | 271 | for i,node in ipairs(self.forwardnodes) do 272 | if node.data.input ~= nil then 273 | node.data.input = recursiveType(node.data.input, type) 274 | end 275 | for k, child in ipairs(node.children) do 276 | applyTypeToTable(child.data) 277 | end 278 | end 279 | 280 | self._type = type 281 | return self 282 | end 283 | 284 | function gModule:updateOutput(input) 285 | return self:runForwardFunction('updateOutput',input) 286 | end 287 | 288 | function gModule:clearState() 289 | local ret = parent.clearState(self) 290 | for _,node in ipairs(self.backwardnodes) do 291 | node.data.gradOutput = nil 292 | node.data.gradOutputBuffer = nil 293 | end 294 | for _,node in ipairs(self.forwardnodes) do 295 | node.data.input = nil 296 | end 297 | return ret 298 | end 299 | 300 | function gModule:runForwardFunction(func,input) 301 | if type(func) == "string" then 302 | local func_name = func 303 | func = function(module,input) return module[func_name](module,input) end 304 | end 305 | -- For backward compatibility, we allow self.nInputs to be missing. 306 | local nInputs = self.nInputs or #self.innode.children 307 | -- We see the input as a list of inputs. 
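-- A single input arrives bare and is wrapped into a one-element list below; graphs with several inputs must receive them as a table.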
308 |    if nInputs <= 1 then
309 |       input={input}
310 |    elseif type(input) ~= "table" then
311 |       error(string.format("expecting table of %s inputs", nInputs))
312 |    end
313 |    local function neteval(node)
314 |       local function propagate(node,x)
315 |          for i,child in ipairs(node.children) do
316 |             child.data.input = child.data.input or {}
317 |             local mapindex = child.data.mapindex[node.data]
318 |             assert(not child.data.input[mapindex], "each input should have one source")
319 |             child.data.input[mapindex] = x
320 |          end
321 |       end
322 |       if node.data.selectindex then
323 |          assert(not node.data.module, "the selectindex-handling nodes should have no module")
324 |          local input = node.data.input
325 |          assert(#input == 1, "only the split node should be the input")
326 |          assert(istable(input[1]), "the input for a split should be a table")
327 |          input = input[1][node.data.selectindex]
328 |          propagate(node,input)
329 |       else
330 |          local input = node.data.input
331 | 
332 |          -- a parameter node is captured
333 |          if input == nil and node.data.module ~= nil then
334 |             input = {}
335 |          end
336 |          if #input == 1 then
337 |             input = input[1]
338 |          end
339 |          -- forward through this node
340 |          -- If no module is present, the node behaves like nn.Identity.
341 |          local output
342 |          if not node.data.module then
343 |             output = input
344 |          else
345 |             output = func(node.data.module,input)
346 |          end
347 |          if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
348 |             error(string.format("split(%s) cannot split %s outputs",
349 |                                 node.data.nSplitOutputs,
350 |                                 #output))
351 |          end
352 |          -- propagate the output to children
353 |          propagate(node,output)
354 |       end
355 |       if self.verbose then
356 |          print(' V : ' .. node:label())
357 |       end
358 |    end
359 | 
360 |    local innode = self.innode
361 |    if #input ~= nInputs then
362 |       error(string.format('Got %s inputs instead of %s', #input, nInputs))
363 |    end
364 |    -- first clear the input states
365 |    for _,node in ipairs(self.forwardnodes) do
366 |       local input = node.data.input
367 |       while input and #input>0 do
368 |          table.remove(input)
369 |       end
370 |    end
371 |    -- Set the starting input.
372 |    -- We copy instead of modifying the passed input.
373 |    innode.data.input = innode.data.input or {}
374 |    for i, item in ipairs(input) do
375 |       innode.data.input[i] = item
376 |    end
377 | 
378 |    -- then run forward
379 |    for i,node in ipairs(self.forwardnodes) do
380 |       neteval(node)
381 |    end
382 | 
383 |    self.output = self.outnode.data.input
384 |    if #self.outnode.children == 1 then
385 |       self.output = self.output[1]
386 |    end
387 |    return self.output
388 | end
389 | 
390 | function gModule:updateGradInput(input,gradOutput)
391 |    local function neteval(node)
392 |       if node.data.selectindex then
393 |          assert(not node.data.module, "the selectindex-handling nodes should have no module")
394 |          assert(#node.children == 1, "only the split node should be the input")
395 |          local child = node.children[1]
396 |          local go = getTotalGradOutput(node)
397 |          child.data.gradOutput = child.data.gradOutput or {}
398 |          assert(#child.data.gradOutput <= 1, "the split node should be used only once")
399 |          -- The data.gradOutput holds the to-be-summed gradients.
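-- Each selectindex node writes its gradient into the slot of the split node's single gradOutput table that matches its index.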
400 | child.data.gradOutput[1] = child.data.gradOutput[1] or {} 401 | assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet") 402 | child.data.gradOutput[1][node.data.selectindex] = go 403 | else 404 | local gradOutput = getTotalGradOutput(node) 405 | -- updateGradInput through this node 406 | -- If no module is present, the node behaves like nn.Identity. 407 | local gradInput 408 | if not node.data.module then 409 | gradInput = gradOutput 410 | else 411 | local input = node.data.input 412 | -- a parameter node is captured 413 | if input == nil and node.data.module ~= nil then 414 | input = {} 415 | end 416 | if #input == 1 then 417 | input = input[1] 418 | end 419 | local module = node.data.module 420 | gradInput = module:updateGradInput(input,gradOutput) 421 | end 422 | -- propagate the output to children 423 | for i,child in ipairs(node.children) do 424 | child.data.gradOutput = child.data.gradOutput or {} 425 | local mapindex = node.data.mapindex[child.data] 426 | local gi 427 | if #node.children == 1 then 428 | gi = gradInput 429 | else 430 | gi = gradInput[mapindex] 431 | end 432 | table.insert(child.data.gradOutput,gi) 433 | end 434 | end 435 | if self.verbose then 436 | print(' V : ' .. node:label()) 437 | end 438 | end 439 | local outnode = self.outnode 440 | if #outnode.children > 1 and #gradOutput ~= #outnode.children then 441 | error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) 442 | end 443 | for _,node in ipairs(self.backwardnodes) do 444 | local gradOutput = node.data.gradOutput 445 | while gradOutput and #gradOutput >0 do 446 | table.remove(gradOutput) 447 | end 448 | end 449 | -- Set the starting gradOutput. 450 | outnode.data.gradOutput = outnode.data.gradOutput or {} 451 | outnode.data.gradOutput[1] = gradOutput 452 | 453 | for i,node in ipairs(self.backwardnodes) do 454 | neteval(node) 455 | end 456 | 457 | assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once") 458 | self.gradInput = self.innode.data.gradOutput[1] 459 | return self.gradInput 460 | end 461 | 462 | function gModule:accGradParameters(input,gradOutput,lr) 463 | local function neteval(node) 464 | if node.data.module then 465 | local module = node.data.module 466 | local gradOutput = node.data.gradOutput[1] 467 | if #node.data.gradOutput > 1 then 468 | gradOutput = node.data.gradOutputBuffer 469 | end 470 | local input = node.data.input 471 | -- a parameter node is captured 472 | if input == nil and node.data.module ~= nil then 473 | input = {} 474 | end 475 | if #input == 1 then 476 | input = input[1] 477 | end 478 | -- accGradParameters through this node 479 | module:accGradParameters(input,gradOutput,lr) 480 | end 481 | if self.verbose then 482 | print(' V : ' .. node:label()) 483 | end 484 | end 485 | local outnode = self.outnode 486 | if #outnode.children > 1 and #gradOutput ~= #outnode.children then 487 | error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) 488 | end 489 | for i,node in ipairs(self.backwardnodes) do 490 | neteval(node) 491 | end 492 | end 493 | 494 | function gModule:read(file) 495 | local data = file:readObject() 496 | for k, v in pairs(data) do 497 | self[k] = v 498 | end 499 | 500 | -- Initialize the modules table if necessary. 
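-- (presumably for gModules serialized before self.modules was stored; it is rebuilt here from the forward nodes)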
501 | if not self.modules then 502 | self.modules = {} 503 | for _, node in ipairs(self.forwardnodes) do 504 | if node.data.module then 505 | table.insert(self.modules, node.data.module) 506 | end 507 | end 508 | end 509 | end 510 | 511 | 512 | function gModule:__tostring__() 513 | return self.name or torch.type(self) 514 | end 515 | -------------------------------------------------------------------------------- /graphinspecting.lua: -------------------------------------------------------------------------------- 1 | 2 | -- The findCurrentNode() depends on the names of the 3 | -- local variables in the nngraph.gModule source code. 4 | local function findCurrentNode() 5 | for level = 2, math.huge do 6 | local info = debug.getinfo(level, "n") 7 | if info == nil then 8 | return nil 9 | end 10 | 11 | local funcName = info.name 12 | if funcName == "neteval" then 13 | local varName, node = debug.getlocal(level, 1) 14 | if varName == "node" then 15 | return node 16 | end 17 | end 18 | end 19 | end 20 | 21 | -- Runs the func and calls onError(failedNode, ...) on an error. 22 | -- The stack trace is inspected to find the failedNode. 23 | local function runChecked(func, onError, ...) 24 | -- The current node needs to be searched-for, before unrolling the stack. 25 | local failedNode 26 | local function errorHandler(message) 27 | -- The stack traceback is added only if not already present. 28 | if not string.find(message, 'stack traceback:\n', 1, true) then 29 | message = debug.traceback(message, 2) 30 | end 31 | failedNode = findCurrentNode() 32 | return message 33 | end 34 | 35 | local ok, result = xpcall(func, errorHandler) 36 | if ok then 37 | return result 38 | end 39 | 40 | onError(failedNode, ...) 41 | -- Passing the level 0 avoids adding an additional error position info 42 | -- to the message. 43 | error(result, 0) 44 | end 45 | 46 | local function customToDot(graph, title, failedNode) 47 | local str = graph:todot(title) 48 | if not failedNode then 49 | return str 50 | end 51 | 52 | local failedNodeId = nil 53 | for i, node in ipairs(graph.nodes) do 54 | if node.data == failedNode.data then 55 | failedNodeId = node.id 56 | break 57 | end 58 | end 59 | 60 | if failedNodeId ~= nil then 61 | -- The closing '}' is removed. 62 | -- And red fillcolor is specified for the failedNode. 63 | str = string.gsub(str, '}%s*$', '') 64 | str = str .. string.format('n%s[style=filled, fillcolor=red];\n}', 65 | failedNodeId) 66 | end 67 | return str 68 | end 69 | 70 | local function saveSvg(svgPathPrefix, dotStr) 71 | io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix)) 72 | local dotPath = svgPathPrefix .. '.dot' 73 | local dotFile = io.open(dotPath, 'w') 74 | dotFile:write(dotStr) 75 | dotFile:close() 76 | 77 | local svgPath = svgPathPrefix .. '.svg' 78 | local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath) 79 | os.execute(cmd) 80 | end 81 | 82 | local function onError(failedNode, gmodule) 83 | local nInputs = gmodule.nInputs or #gmodule.innode.children 84 | local svgPathPrefix = gmodule.name or string.format( 85 | 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children) 86 | if paths.filep(svgPathPrefix .. '.svg') then 87 | svgPathPrefix = svgPathPrefix .. '_' .. 
paths.basename(os.tmpname()) 88 | end 89 | local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode) 90 | saveSvg(svgPathPrefix, dotStr) 91 | end 92 | 93 | local origFuncs = { 94 | runForwardFunction = nn.gModule.runForwardFunction, 95 | updateGradInput = nn.gModule.updateGradInput, 96 | accGradParameters = nn.gModule.accGradParameters, 97 | } 98 | 99 | -- When debug is enabled, 100 | -- a gmodule.name .. '.svg' will be saved 101 | -- if an exception occurs in a graph execution. 102 | -- The problematic node will be marked by red color. 103 | function nngraph.setDebug(enable) 104 | if not enable then 105 | -- When debug is disabled, 106 | -- the origFuncs are restored on nn.gModule. 107 | for funcName, origFunc in pairs(origFuncs) do 108 | nn.gModule[funcName] = origFunc 109 | end 110 | return 111 | end 112 | 113 | for funcName, origFunc in pairs(origFuncs) do 114 | nn.gModule[funcName] = function(...) 115 | local args = {...} 116 | local gmodule = args[1] 117 | local unpack = unpack or table.unpack 118 | return runChecked(function() 119 | return origFunc(unpack(args)) 120 | end, onError, gmodule) 121 | end 122 | end 123 | end 124 | 125 | -- Sets node.data.annotations.name for the found nodes. 126 | -- The local variables at the given stack level are inspected. 127 | -- The default stack level is 1 (the function that called annotateNodes()). 128 | function nngraph.annotateNodes(stackLevel) 129 | stackLevel = stackLevel or 1 130 | for index = 1, math.huge do 131 | local varName, varValue = debug.getlocal(stackLevel + 1, index) 132 | if not varName then 133 | break 134 | end 135 | if torch.typename(varValue) == "nngraph.Node" then 136 | -- An explicit name is preserved. 137 | if not varValue.data.annotations.name then 138 | varValue:annotate({name = varName}) 139 | end 140 | end 141 | end 142 | end 143 | 144 | --[[ 145 | SVG visualization for gmodule 146 | TODO: add custom coloring with node types 147 | ]] 148 | function nngraph.display(gmodule) 149 | local ffi = require 'ffi' 150 | local cmd 151 | if ffi.os == 'Linux' then 152 | cmd = 'xdg-open' 153 | elseif ffi.os == 'OSX' then 154 | cmd = 'open -a Safari' 155 | end 156 | local fname = os.tmpname() 157 | graph.dot(gmodule.fg, fname, fname) 158 | os.execute(cmd .. ' ' .. fname .. '.svg') 159 | end 160 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'graph' 3 | 4 | nngraph = {} 5 | 6 | require('nngraph.nest') 7 | require('nngraph.node') 8 | require('nngraph.gmodule') 9 | require('nngraph.graphinspecting') 10 | require('nngraph.JustElement') 11 | require('nngraph.JustTable') 12 | require('nngraph.ModuleFromCriterion') 13 | 14 | -- handy functions 15 | local utils = require('nngraph.utils') 16 | local istensor = torch.isTensor 17 | local istable = utils.istable 18 | local istorchclass = utils.istorchclass 19 | 20 | -- simpler todot functions 21 | nngraph.simple_print = require('nngraph.simple_print') 22 | 23 | -- Modify the __call function to hack into nn.Module 24 | local Module = torch.getmetatable('nn.Module') 25 | function Module:__call__(...) 26 | local nArgs = select("#", ...) 27 | assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.') 28 | 29 | local input = ... 
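-- `...` is Lua varargs; since at most one argument is allowed, `input` is that single argument (or nil).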
30 | if nArgs == 1 and input == nil then 31 | error(utils.expectingNodeErrorMessage(input, 'inputs', 1)) 32 | end 33 | -- Disallow passing empty table, in case someone passes a table with some 34 | -- typo'd variable name in. 35 | if type(input) == 'table' and next(input) == nil then 36 | error('cannot pass an empty table of inputs. To indicate no incoming ' .. 37 | 'connections, leave the second set of parens blank.') 38 | end 39 | if not istable(input) then 40 | input = {input} 41 | end 42 | local mnode = nngraph.Node({module=self}) 43 | 44 | local dnode 45 | for i = 1, utils.tableMaxN(input) do 46 | dnode = input[i] 47 | if torch.typename(dnode) ~= 'nngraph.Node' then 48 | error(utils.expectingNodeErrorMessage(dnode, 'inputs', i)) 49 | end 50 | mnode:add(dnode,true) 51 | end 52 | 53 | return mnode 54 | end 55 | 56 | local Criterion = torch.getmetatable('nn.Criterion') 57 | function Criterion:__call__(...) 58 | return nn.ModuleFromCriterion(self)(...) 59 | end 60 | 61 | 62 | 63 | 64 | Module.__unm__ = function( obj ) 65 | return obj() 66 | end 67 | 68 | Module.__sub__ = function( prev, next ) 69 | return next(prev) 70 | end 71 | 72 | 73 | do 74 | local Node = torch.getmetatable('nngraph.Node') 75 | Node.__sub__ = function( prev, next ) 76 | return next(prev) 77 | end 78 | end 79 | 80 | return nngraph 81 | -------------------------------------------------------------------------------- /nest.lua: -------------------------------------------------------------------------------- 1 | 2 | local function isNode(input) 3 | local typename = torch.typename(input) 4 | return typename and typename == 'nngraph.Node' 5 | end 6 | 7 | local function isNonEmptyList(input) 8 | return type(input) == "table" and #input > 0 9 | end 10 | 11 | local function _nest(input) 12 | if not isNode(input) and not isNonEmptyList(input) then 13 | error('what is this in the nest input? ' .. tostring(input)) 14 | end 15 | 16 | if isNode(input) then 17 | return input 18 | end 19 | 20 | if #input == 1 then 21 | return nngraph.JustTable()(input) 22 | end 23 | 24 | local wrappedChildren = {} 25 | for i, child in ipairs(input) do 26 | wrappedChildren[i] = _nest(child) 27 | end 28 | return nn.Identity()(wrappedChildren) 29 | end 30 | 31 | -- Returns a nngraph node to represent a nested structure. 32 | -- Usage example: 33 | -- local in1 = nn.Identity()() 34 | -- local in2 = nn.Identity()() 35 | -- local in3 = nn.Identity()() 36 | -- local ok = nn.CAddTable()(nngraph.nest({in1})) 37 | -- local in1Again = nngraph.nest(in1) 38 | -- local state = nngraph.nest({in1, {in2}, in3}) 39 | function nngraph.nest(...) 40 | local nArgs = select("#", ...) 41 | assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.') 42 | 43 | local input = ... 44 | assert(nArgs > 0 and input ~= nil, 'Pass an input.') 45 | return _nest(input) 46 | end 47 | -------------------------------------------------------------------------------- /nesting.lua: -------------------------------------------------------------------------------- 1 | 2 | local nesting = {} 3 | 4 | local utils = require('nngraph.utils') 5 | 6 | -- Creates a clone of a tensor or of a table with tensors. 7 | function nesting.cloneNested(obj) 8 | if torch.isTensor(obj) then 9 | return obj:clone() 10 | end 11 | 12 | local result = {} 13 | for key, child in pairs(obj) do 14 | result[key] = nesting.cloneNested(child) 15 | end 16 | return result 17 | end 18 | 19 | -- Fills the obj with the given value. 20 | -- The obj can be a tensor or a table with tensors. 
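-- For example, nesting.fillNested({torch.Tensor(2), {torch.Tensor(3)}}, 0) zeroes every tensor in the nesting.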
21 | function nesting.fillNested(obj, value) 22 | if torch.isTensor(obj) then 23 | obj:fill(value) 24 | else 25 | for key, child in pairs(obj) do 26 | nesting.fillNested(child, value) 27 | end 28 | end 29 | end 30 | 31 | -- Resizes all tensors in the output. 32 | function nesting.resizeNestedAs(output, input) 33 | if torch.isTensor(output) then 34 | output:resizeAs(input) 35 | else 36 | for key, child in pairs(input) do 37 | -- A new element is added to the output, if needed. 38 | if not output[key] then 39 | output[key] = nesting.cloneNested(child) 40 | else 41 | nesting.resizeNestedAs(output[key], child) 42 | end 43 | end 44 | -- Extra elements are removed from the output. 45 | for key, child in pairs(output) do 46 | if not input[key] then 47 | output[key] = nil 48 | end 49 | end 50 | end 51 | end 52 | 53 | -- Copies all tensors in the output. 54 | function nesting.copyNested(output, input) 55 | if torch.isTensor(output) then 56 | output:copy(input) 57 | else 58 | for key, child in pairs(input) do 59 | nesting.copyNested(output[key], child) 60 | end 61 | -- Extra elements in the output table cause an error. 62 | for key, child in pairs(output) do 63 | if not input[key] then 64 | error('key ' .. tostring(key) .. 65 | ' present in output but not in input') 66 | end 67 | end 68 | end 69 | end 70 | 71 | -- Adds the input to the output. 72 | -- The input can contain nested tables. 73 | -- The output will contain the same nesting of tables. 74 | function nesting.addNestedTo(output, input) 75 | if torch.isTensor(output) then 76 | output:add(input) 77 | else 78 | for key, child in pairs(input) do 79 | assert(output[key] ~= nil, "missing key") 80 | nesting.addNestedTo(output[key], child) 81 | end 82 | end 83 | end 84 | 85 | return nesting 86 | -------------------------------------------------------------------------------- /nngraph-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "nngraph" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/torch/nngraph", 6 | tag = "master" 7 | } 8 | 9 | description = { 10 | summary = "This package provides graphical computation for nn library in Torch7.", 11 | homepage = "https://github.com/torch/nngraph", 12 | license = "UNKNOWN" 13 | } 14 | 15 | dependencies = { 16 | "torch >= 7.0", 17 | "graph", 18 | "nn" 19 | } 20 | 21 | build = { 22 | type = "command", 23 | build_command = [[ 24 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." 
-DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) 25 | ]], 26 | install_command = "cd build && $(MAKE) install" 27 | } 28 | -------------------------------------------------------------------------------- /node.lua: -------------------------------------------------------------------------------- 1 | 2 | local utils = require('nngraph.utils') 3 | local istensor = torch.isTensor 4 | local istable = utils.istable 5 | local istorchclass = utils.istorchclass 6 | require 'debug' 7 | 8 | local nnNode,parent = torch.class('nngraph.Node','graph.Node') 9 | 10 | function nnNode:__init(data) 11 | parent.__init(self,data) 12 | self.data.annotations = self.data.annotations or {} 13 | self.data.mapindex = self.data.mapindex or {} 14 | self.data.reverseMap = self.data.reverseMap or {} 15 | if not self.data.annotations._debugLabel then 16 | self:_makeDebugLabel(debug.getinfo(6, 'Sl')) 17 | end 18 | end 19 | 20 | --[[ Build a string label which will be used a tooltip when 21 | making a graph.]] 22 | function nnNode:_makeDebugLabel(dinfo) 23 | if dinfo then 24 | self.data.annotations._debugLabel = string.format('[%s]:%d_%s', 25 | dinfo.short_src, 26 | dinfo.currentline, 27 | dinfo.name or '') 28 | end 29 | end 30 | 31 | -- domap ensures that this node will keep track of the order its children are added. 32 | -- mapindex is a forward/backward list 33 | -- index = self.data.mapindex[child.data] 34 | -- child.data = self.data.mapindex[index] 35 | function nnNode:add(child,domap) 36 | parent.add(self,child) 37 | if domap then 38 | local mapindex = self.data.mapindex 39 | local data = child.data 40 | assert(not mapindex[data], "Don't pass the same input twice.") 41 | table.insert(mapindex,data) 42 | mapindex[data] = #mapindex 43 | 44 | -- The "child" that is added here actually represents the input node, 45 | -- so we write into that node to indicate that we are downstream of it. 46 | -- This enables dangling pointer detection. 47 | local revMap = child.data.reverseMap 48 | assert(not revMap[self], 'this connection has already been made!') 49 | revMap[self] = true 50 | end 51 | end 52 | 53 | -- this function returns noutput number of new nodes 54 | -- that each take a single component of the output of this 55 | -- node in the order they are returned. 56 | function nnNode:split(noutput) 57 | if noutput == 1 then 58 | return nngraph.JustElement()(self) 59 | end 60 | local debugLabel = self.data.annotations._debugLabel 61 | -- Specify the source location where :split is called. 62 | local dinfo = debug.getinfo(2, 'Sl') 63 | local splitLoc = string.format(' split at [%s]:%d', 64 | dinfo.short_src, 65 | dinfo.currentline) 66 | local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}}) 67 | mnode:add(self,true) 68 | 69 | local selectnodes = {} 70 | for i=1,noutput do 71 | local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}}) 72 | node:add(mnode,true) 73 | table.insert(selectnodes,node) 74 | end 75 | 76 | local unpack = unpack or table.unpack -- Lua52 compat 77 | return unpack(selectnodes) 78 | end 79 | 80 | function nnNode:annotate(annotations) 81 | for k, v in pairs(annotations) do 82 | self.data.annotations[k] = v 83 | end 84 | 85 | return self 86 | end 87 | 88 | function nnNode:graphNodeName() 89 | if self.data.annotations.name then 90 | return self.data.annotations.name .. ' (' .. self.id .. ')' 91 | else 92 | return 'Node' .. 
self.id 93 | end 94 | end 95 | 96 | function nnNode:graphNodeAttributes() 97 | self.data.annotations.graphAttributes = 98 | self.data.annotations.graphAttributes or {} 99 | if not self.data.annotations.graphAttributes.tooltip then 100 | self.data.annotations.graphAttributes.tooltip = 101 | self.data.annotations._debugLabel 102 | end 103 | 104 | return self.data.annotations.graphAttributes 105 | end 106 | 107 | local function getNanFlag(data) 108 | if data:nElement() == 0 then 109 | return '' 110 | end 111 | local isNan = (data:ne(data):sum() > 0) 112 | if isNan then 113 | return 'NaN' 114 | end 115 | if data:max() == math.huge then 116 | return 'inf' 117 | end 118 | if data:min() == -math.huge then 119 | return '-inf' 120 | end 121 | return '' 122 | end 123 | 124 | function nnNode:label() 125 | 126 | local lbl = {} 127 | 128 | local function getstr(data) 129 | if not data then return '' end 130 | if istensor(data) then 131 | local nanFlag = getNanFlag(data) 132 | local tensorType = 'Tensor' 133 | if data:type() ~= torch.Tensor():type() then 134 | tensorType = data:type() 135 | end 136 | return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag 137 | elseif istable(data) then 138 | local tstr = {} 139 | for i,v in ipairs(data) do 140 | table.insert(tstr, getstr(v)) 141 | end 142 | return '{' .. table.concat(tstr,',') .. '}' 143 | else 144 | return tostring(data):gsub('\n','\\l') 145 | end 146 | end 147 | local function getmapindexstr(mapindex) 148 | local tstr = {} 149 | for i,data in ipairs(mapindex) do 150 | local inputId = 'Node' .. (data.forwardNodeId or '') 151 | table.insert(tstr, inputId) 152 | end 153 | return '{' .. table.concat(tstr,',') .. '}' 154 | end 155 | 156 | for k,v in pairs(self.data) do 157 | local vstr = '' 158 | if k== 'mapindex' then 159 | if #v > 1 then 160 | vstr = getmapindexstr(v) 161 | table.insert(lbl, k .. ' = ' .. vstr) 162 | end 163 | elseif k== 'forwardNodeId' or k== 'annotations' then 164 | -- the forwardNodeId is not displayed in the label. 165 | else 166 | vstr = getstr(v) 167 | table.insert(lbl, k .. ' = ' .. vstr) 168 | end 169 | end 170 | 171 | local desc 172 | if self.data.annotations.description then 173 | desc = 'desc = ' .. self.data.annotations.description .. '\\n' 174 | else 175 | desc = '' 176 | end 177 | return desc .. 
table.concat(lbl,"\\l") 178 | end 179 | -------------------------------------------------------------------------------- /simple_print.lua: -------------------------------------------------------------------------------- 1 | local function removeNodeFromEdges(node_id, edges) 2 | local from_nodes = {} 3 | local to_nodes = {} 4 | -- remove edges 5 | local idx = 1 6 | while idx <= #edges do 7 | local edge = edges[idx] 8 | if edge.source == node_id then 9 | local to_node = edges[idx].target 10 | table.insert(to_nodes, to_node) 11 | table.remove(edges, idx) 12 | elseif edge.target == node_id then 13 | local from_node = edges[idx].source 14 | table.insert(from_nodes, from_node) 15 | table.remove(edges, idx) 16 | else 17 | idx = idx + 1 18 | end 19 | end 20 | 21 | -- add new edges 22 | for _, f in pairs(from_nodes) do 23 | for _, t in pairs(to_nodes) do 24 | local edge = {source = f, target= t} 25 | table.insert(edges, edge) 26 | end 27 | end 28 | 29 | return edges 30 | end 31 | 32 | local function isNodeGood(node) 33 | return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity') 34 | end 35 | 36 | local function reIndexNodes(nodes, edges) 37 | -- make reverse map 38 | local rev_map = {} 39 | for idx = 1, #nodes do 40 | rev_map[nodes[idx].id] = idx 41 | nodes[idx].id = idx 42 | end 43 | for idx = 1, #edges do 44 | local edge = edges[idx] 45 | edge.source = rev_map[edge.source] 46 | edge.target = rev_map[edge.target] 47 | end 48 | return nodes, edges 49 | end 50 | 51 | local function cleanGraph(nodes, edges) 52 | local idx = 1 53 | while idx <= #nodes do 54 | local node = nodes[idx] 55 | if isNodeGood(node.orig_node) then 56 | idx = idx + 1 57 | else 58 | local id = node.id 59 | table.remove(nodes, idx) 60 | edges = removeNodeFromEdges(id, edges) 61 | end 62 | end 63 | return reIndexNodes(nodes, edges) 64 | end 65 | 66 | local function loadGraph(graph) 67 | local nodes = {} 68 | local edges = {} 69 | for _, node in ipairs(graph.nodes) do 70 | local idx = node.id 71 | table.insert(nodes, {id=idx, orig_node = node} ) 72 | for ich = 1, #node.children do 73 | table.insert( edges, {source = idx, target = node.children[ich].id}) 74 | end 75 | end 76 | nodes, edges = cleanGraph(nodes, edges) 77 | return nodes , edges 78 | end 79 | 80 | local M = {} 81 | 82 | function M.todot( graph, title ) 83 | local nodes, edges = loadGraph(graph) 84 | local str = {} 85 | table.insert(str,'digraph G {\n') 86 | if title then 87 | table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') 88 | end 89 | table.insert(str,'node [shape = oval]; ') 90 | local nodelabels = {} 91 | for i,node in ipairs(nodes) do 92 | local true_node = node.orig_node 93 | local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"' 94 | nodelabels[i] = 'n' .. true_node.id 95 | table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];') 96 | end 97 | table.insert(str,'\n') 98 | for i,edge in ipairs(edges) do 99 | table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n') 100 | end 101 | table.insert(str,'}') 102 | return table.concat(str,'') 103 | end 104 | 105 | function M.dot(g,title,fname) 106 | local gv = M.todot(g, title) 107 | local fngv = (fname or os.tmpname()) .. '.dot' 108 | local fgv = io.open(fngv,'w') 109 | fgv:write(gv) 110 | fgv:close() 111 | local fnsvg = (fname or os.tmpname()) .. '.svg' 112 | os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. 
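--[[ Illustrative sketch (added note, not part of the original source):
simple_print renders a reduced graph -- cleanGraph drops nn.Identity nodes
and nodes without modules before the dot source is emitted. A minimal use,
assuming graphviz's `dot` binary is on the PATH and `net` is a gModule:

  local str = nngraph.simple_print.todot(net.fg, 'my net')  -- dot source string
  nngraph.simple_print.dot(net.fg, 'my net', '/tmp/mynet')  -- writes .dot and .svg

When fname is nil, M.dot instead shows the SVG in a qt widget and removes
the temporary files, as the code below does. --]]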
fngv) 113 | if not fname then 114 | require 'qtsvg' 115 | local qs = qt.QSvgWidget(fnsvg) 116 | qs:show() 117 | os.remove(fngv) 118 | os.remove(fnsvg) 119 | -- print(fngv,fnpng) 120 | return qs 121 | end 122 | end 123 | 124 | return M 125 | -------------------------------------------------------------------------------- /test/speed.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'nngraph' 3 | 4 | function time_benchmark(model, input, n) 5 | local forward_timer = torch.Timer():stop():reset() 6 | local backward_timer = torch.Timer():stop():reset() 7 | local total_timer = torch.Timer():stop():reset() 8 | local gradOut 9 | total_timer:resume() 10 | for i = 1, n do 11 | forward_timer:resume() 12 | local out = model:forward(input) 13 | forward_timer:stop() 14 | if not gradOut then 15 | gradOut = torch.rand(out:size()) 16 | end 17 | backward_timer:resume() 18 | model:backward(input, gradOut) 19 | backward_timer:stop() 20 | end 21 | total_timer:stop() 22 | 23 | return {forward = forward_timer:time().real, 24 | backward = backward_timer:time().real, 25 | total = total_timer:time().real} 26 | end 27 | 28 | function report_benchmark(result, title) 29 | local nspace = (80-string.len(title))/2 30 | report = {string.rep('#', 80), 31 | string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))), 32 | string.format('Total Time Spent = %.2f s', result.total), 33 | string.format(' Forward Time = %.2f s', result.forward), 34 | string.format(' Backward Time = %.2f s', result.backward), 35 | string.rep('#', 80) 36 | } 37 | print(table.concat(report, '\n')) 38 | return result 39 | end 40 | 41 | function compare_benchmarks(result, base, title) 42 | local nspace = (80-string.len(title))/2 43 | report = {string.rep('#', 80), 44 | string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))), 45 | string.format('Total Time Spent = %.2f %%', result.total/base.total*100), 46 | string.format(' Forward Time = %.2f %%', result.forward/base.forward*100), 47 | string.format(' Backward Time = %.2f %%', result.backward/base.backward*100), 48 | string.rep('#', 80) 49 | } 50 | print(table.concat(report, '\n')) 51 | return result 52 | end 53 | 54 | function get_models(nhidden_layers, ninput, noutput, nhidden) 55 | 56 | local function get_concat_layer(nfrom, nto) 57 | local concat_module = nn.Sequential() 58 | local concat_layer = nn.ConcatTable() 59 | concat_layer:add(nn.Linear(nfrom, nto)) 60 | concat_layer:add(nn.Linear(nfrom, nto)) 61 | concat_module:add(concat_layer) 62 | concat_module:add(nn.CAddTable()) 63 | concat_module:add(nn.ReLU()) 64 | return concat_module 65 | end 66 | 67 | -- NN 68 | local nn_model = nn.Sequential() 69 | nn_model:add(get_concat_layer(ninput, nhidden))--nn.Linear(ninput, nhidden)):add(nn.ReLU()) 70 | for i = 1, nhidden_layers do 71 | nn_model:add(get_concat_layer(nhidden, nhidden))--nn.Linear(nhidden, nhidden)):add(nn.ReLU()) 72 | end 73 | nn_model:add(get_concat_layer(nhidden, noutput))--nn.Linear(nhidden, noutput)) 74 | 75 | -- NN graph 76 | local input = nn.Identity()() 77 | local nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(ninput, nhidden)(input), 78 | nn.Linear(ninput, nhidden)(input)})) 79 | for i = 1, nhidden_layers do 80 | nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, nhidden)(nng_model), 81 | nn.Linear(nhidden, nhidden)(nng_model)})) 82 | end 83 | nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, noutput)(nng_model), 84 | 
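--[[ Illustrative note (added, not part of the original source): the two
models built here are architecturally equivalent -- a stack of two-branch
Linear + CAddTable + ReLU blocks -- expressed once with
nn.Sequential/ConcatTable and once as an nngraph expression, so the
benchmark isolates graph-traversal overhead. A typical invocation, assuming
a Torch `th` interpreter:

  th test/speed.lua -niter 20 -nhidden_layers 10 -batch_size 16

which prints absolute timings for each model and the nngraph/nn ratio. --]]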
nn.Linear(nhidden, noutput)(nng_model)})) 85 | 86 | nng_model = nn.gModule({input},{nng_model}) 87 | 88 | return {nn_model = nn_model, nng_model = nng_model} 89 | end 90 | 91 | function get_options(arg) 92 | local cmd = torch.CmdLine() 93 | cmd:text('nngraph benchmarking') 94 | cmd:option('-niter', 10, 'number of iterations of forward/backward for each model') 95 | cmd:option('-nhidden_layers', 10, 'number of hidden layers') 96 | cmd:option('-input_size', 512, 'size of input') 97 | cmd:option('-batch_size', 16, 'size of batch') 98 | cmd:option('-hidden_size', 1024, 'size of hidden layer') 99 | cmd:option('-output_size', 128, 'size of output layer') 100 | local opt = cmd:parse(arg) 101 | return opt 102 | end 103 | 104 | local opt = get_options(arg) 105 | models = get_models(opt.nhidden_layers, opt.input_size, opt.output_size, opt.hidden_size) 106 | print(opt) 107 | 108 | local nn_bench = report_benchmark(time_benchmark(models.nn_model, torch.rand(opt.batch_size,opt.input_size), opt.niter), 'NN') 109 | local nng_bench = report_benchmark(time_benchmark(models.nng_model, torch.rand(opt.batch_size,opt.input_size), opt.niter), 'NNGRAPH') 110 | compare_benchmarks(nng_bench, nn_bench, 'NNGRAPH / NN (%)') 111 | 112 | -------------------------------------------------------------------------------- /test/test_JustElement.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'totem' 3 | require 'nngraph' 4 | local test = {} 5 | local tester = totem.Tester() 6 | 7 | function test.test_output() 8 | local input = {torch.randn(7, 11)} 9 | local module = nngraph.JustElement() 10 | tester:eq(module:forward(input), input[1], "output") 11 | end 12 | 13 | function test.test_grad() 14 | local input = {torch.randn(7, 11)} 15 | local module = nngraph.JustElement() 16 | totem.nn.checkGradients(tester, module, input) 17 | end 18 | 19 | function test.test_split() 20 | local in1 = nn.Identity()() 21 | local output = in1:split(1) 22 | local net = nn.gModule({in1}, {output}) 23 | 24 | local input = {torch.randn(7, 11)} 25 | tester:eq(net:forward(input), input[1], "output of split(1)") 26 | end 27 | 28 | tester:add(test):run() 29 | -------------------------------------------------------------------------------- /test/test_JustTable.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'totem' 3 | require 'nngraph' 4 | local test = {} 5 | local tester = totem.Tester() 6 | 7 | function test.test_output() 8 | local input = torch.randn(7, 11) 9 | local module = nngraph.JustTable() 10 | tester:eq(module:forward(input), {input}, "output") 11 | end 12 | 13 | function test.test_grad() 14 | local input = torch.randn(7, 11) 15 | local module = nngraph.JustTable() 16 | totem.nn.checkGradients(tester, module, input) 17 | end 18 | 19 | tester:add(test):run() 20 | -------------------------------------------------------------------------------- /test/test_ModuleFromCriterion.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'totem' 3 | require 'nngraph' 4 | local test = {} 5 | local tester = totem.Tester() 6 | 7 | function test.test_call() 8 | local prediction = nn.Identity()() 9 | local target = nn.Identity()() 10 | local mse = nn.MSECriterion()({prediction, target}) 11 | local costBits = nn.MulConstant(1/math.log(2))(mse) 12 | local net = nn.gModule({prediction, target}, {costBits}) 13 | 14 | local input = {torch.randn(3, 5), torch.rand(3, 5)} 15 | local criterion = nn.MSECriterion() 16 
| local output = net:forward(input) 17 | criterion:forward(input[1], input[2]) 18 | tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14) 19 | 20 | local gradOutput = torch.randn(1) 21 | local gradInput = net:backward(input, gradOutput) 22 | criterion:backward(input[1], input[2]) 23 | tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14) 24 | tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget") 25 | end 26 | 27 | function test.test_grad() 28 | local prediction = nn.Identity()() 29 | local zero = nn.MulConstant(0)(prediction) 30 | -- The target is created inside of the nngraph 31 | -- to ignore the zero gradTarget. 32 | local target = nn.AddConstant(1.23)(zero) 33 | local mse = nn.MSECriterion()({prediction, target}) 34 | local net = nn.gModule({prediction}, {mse}) 35 | 36 | local input = torch.randn(4, 7) 37 | totem.nn.checkGradients(tester, net, input) 38 | end 39 | 40 | local function module() 41 | local module = nn.ModuleFromCriterion(nn.MSECriterion()) 42 | local input = {torch.randn(3, 5), torch.randn(3, 5)} 43 | return module, input 44 | end 45 | 46 | function test.test_serializable() 47 | local module, input = module() 48 | totem.nn.checkSerializable(tester, module, input) 49 | end 50 | 51 | function test.test_typeCastable() 52 | local module, input = module() 53 | totem.nn.checkTypeCastable(tester, module, input) 54 | end 55 | 56 | 57 | tester:add(test):run() 58 | -------------------------------------------------------------------------------- /test/test_connectivity.lua: -------------------------------------------------------------------------------- 1 | local totem = require 'totem' 2 | require 'nngraph' 3 | local tests = totem.TestSuite() 4 | local tester = totem.Tester() 5 | 6 | function tests.connectivity() 7 | -- Store debug info here, need to call debug.getinfo on same line as the 8 | -- dangling pointer is declared. 9 | local dInfo 10 | local input = nn.Identity()() 11 | local lin = nn.Linear(20, 10)(input) 12 | -- The Sigmoid does not connect to the output, so should cause an error 13 | -- when we call gModule. 14 | local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl') 15 | local actualOutput = nn.Tanh()(lin) 16 | local errStr = string.format( 17 | 'node declared on %%[%s%%]:%d_ does not connect to gmodule output', 18 | dInfo.short_src, dInfo.currentline) 19 | tester:assertErrorPattern( 20 | function() 21 | return nn.gModule({input}, {actualOutput}) 22 | end, 23 | errStr) 24 | end 25 | 26 | return tester:add(tests):run() 27 | -------------------------------------------------------------------------------- /test/test_debug.lua: -------------------------------------------------------------------------------- 1 | local totem = require 'totem' 2 | require 'nngraph' 3 | local tests = totem.TestSuite() 4 | local tester = totem.Tester() 5 | 6 | function tests.whatIsThisInTheInput() 7 | tester:assertErrorPattern( 8 | function() 9 | local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens 10 | local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2}) 11 | end, 12 | 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' .. 
13 | 'only valid thing to pass is an instance of nngraph%.Node') 14 | 15 | tester:assertErrorPattern( 16 | function() 17 | -- pass-through module, again with same mistake 18 | local graphNode, nnModule = nn.Identity()(), nn.Identity() 19 | return nn.gModule({graphNode, nnModule}, {graphNode}) 20 | end, 21 | 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' .. 22 | 'only valid thing to pass is an instance of nngraph%.Node') 23 | 24 | tester:assertErrorPattern( 25 | function() 26 | local input = nn.Identity()() 27 | local out1 = nn.Linear(20, 10)(input) 28 | local out2 = nn.Sigmoid()(input) 29 | local unconnectedOut = nn.Linear(2, 3) 30 | return nn.gModule({input}, {out1, out2, unconnectedOut}) 31 | end, 32 | 'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' .. 33 | 'only valid thing to pass is an instance of nngraph%.Node') 34 | 35 | -- Check for detecting a nil in the middle of a table. 36 | tester:assertErrorPattern( 37 | function() 38 | local input = nn.Identity()() 39 | local out1 = nn.Tanh()(input) 40 | local out2 = nn.Sigmoid()(input) 41 | -- nil here is simulating a mis-spelt variable name 42 | return nn.gModule({input}, {out1, nil, out2}) 43 | end, 44 | 'outputs%[2%] is nil %(typo / bad index%?%)') 45 | 46 | tester:assertErrorPattern( 47 | function() 48 | -- Typo variable name returns nil, meaning an empty table 49 | local input = nn.Identity()({aNonExistentVariable}) 50 | end, 51 | 'cannot pass an empty table of inputs%. To indicate no incoming ' .. 52 | 'connections, leave the second set of parens blank%.') 53 | end 54 | 55 | function tests.splitUnused() 56 | -- Need to do debuginfo on the same lines as the other code here to match 57 | -- what debug.getinfo inside those calls will return 58 | local dInfoDeclare, dInfoSplit 59 | local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl') 60 | local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl') 61 | 62 | local function willCrash() 63 | return nn.gModule({input}, {output}) 64 | end 65 | 66 | -- Work out what strings will be in the error message 67 | local declareLoc = string.format('%%[%s%%]:%d_', 68 | dInfoDeclare.short_src, 69 | dInfoDeclare.currentline) 70 | local splitLoc = string.format('%%[%s%%]:%d', 71 | dInfoSplit.short_src, 72 | dInfoSplit.currentline) 73 | 74 | tester:assertErrorPattern( 75 | willCrash, 76 | '1 of split%(2%) outputs from the node declared at ' .. 77 | declareLoc .. ' split at ' .. splitLoc .. 
'%-mnode are unused') 78 | end 79 | 80 | tester:add(tests):run() 81 | -------------------------------------------------------------------------------- /test/test_nest.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'totem' 3 | require 'nngraph' 4 | 5 | local test = {} 6 | local tester = totem.Tester() 7 | 8 | function test.test_output() 9 | local in1 = nn.Identity()() 10 | local in2 = nn.Identity()() 11 | local in3 = nn.Identity()() 12 | local ok = nn.CAddTable()(nngraph.nest({in1})) 13 | local in1Again = nngraph.nest(in1) 14 | local state = nngraph.nest({in1, {in2}, in3}) 15 | 16 | local net = nn.gModule( 17 | {in1, in2, in3}, 18 | {ok, in1Again, state, nngraph.nest({in3}), nngraph.nest({in1, in2})}) 19 | 20 | local val1 = torch.randn(7, 3) 21 | local val2 = torch.randn(2) 22 | local val3 = torch.randn(3) 23 | local expectedOutput = { 24 | val1, val1, {val1, {val2}, val3}, {val3}, {val1, val2}, 25 | } 26 | local output = net:forward({val1, val2, val3}) 27 | tester:eq(output, expectedOutput, "output") 28 | end 29 | 30 | 31 | return tester:add(test):run() 32 | 33 | 34 | -------------------------------------------------------------------------------- /test/test_nngraph.lua: -------------------------------------------------------------------------------- 1 | 2 | require 'totem' 3 | require 'nngraph' 4 | local test = {} 5 | local tester = totem.Tester() 6 | 7 | local function checkGradients(...) 8 | totem.nn.checkGradients(tester, ...) 9 | end 10 | 11 | function test.test_oneOutput() 12 | local in1 = nn.Identity()() 13 | local out1 = nn.Identity()(in1) 14 | local module = nn.gModule({in1}, {out1}) 15 | 16 | local input = torch.Tensor({1}) 17 | module:forward(input) 18 | tester:eq(module.output, torch.Tensor{1}, "output") 19 | local gradInput = module:backward(input, torch.Tensor({-123})) 20 | tester:eq(gradInput, torch.Tensor{-123}, "gradInput") 21 | 22 | local input2 = torch.Tensor({2}) 23 | module:forward(input2) 24 | tester:eq(module.output, torch.Tensor{2}, "output for input2") 25 | gradInput = module:backward(input2, torch.Tensor({-2})) 26 | tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput") 27 | end 28 | 29 | 30 | function test.test_twoOutputs() 31 | local in1 = nn.Identity()() 32 | local out1 = nn.Identity()(in1) 33 | local out2 = nn.Identity()(in1) 34 | local module = nn.gModule({in1}, {out1, out2}) 35 | 36 | local input = torch.Tensor({1}) 37 | module:forward(input) 38 | local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})}) 39 | tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") 40 | checkGradients(module, input) 41 | end 42 | 43 | function test.test_twoGradOutputs() 44 | local in1 = nn.Sigmoid()() 45 | local splitTable = nn.SplitTable(1)({in1}) 46 | local out1, out2 = splitTable:split(2) 47 | local module = nn.gModule({in1}, {out1, out2}) 48 | 49 | local input = torch.randn(2, 3) 50 | local output = module:forward(input) 51 | assert(#output == 2, "wrong number of outputs") 52 | module:backward(input, {torch.randn(3), torch.randn(3)}) 53 | checkGradients(module, input) 54 | end 55 | 56 | function test.test_twoInputs() 57 | local in1 = nn.Identity()() 58 | local in2 = nn.Identity()() 59 | local prevH, prevCell = in2:split(2) 60 | 61 | local out1 = nn.CMulTable()({in1, prevH, prevCell}) 62 | local module = nn.gModule({in1, in2}, {out1}) 63 | 64 | local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}} 65 | module:forward(input) 66 | local gradInput = 
module:backward(input, torch.randn(3)) 67 | assert(#gradInput == 2, "wrong number of gradInputs") 68 | assert(type(gradInput[2]) == "table", "wrong gradInput[2] type") 69 | checkGradients(module, input) 70 | end 71 | 72 | function test.test_twoInputs2() 73 | local in1 = nn.Sigmoid()() 74 | local in2 = nn.Sigmoid()() 75 | local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)}) 76 | 77 | local input = {torch.randn(3), torch.randn(3)} 78 | module:forward(input) 79 | local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) 80 | checkGradients(module, input) 81 | end 82 | 83 | function test.test_splitDebugLabels() 84 | local node = nn.Identity()() 85 | node.data.annotations._debugLabel = "node" 86 | local node1, node2 = node:split(2) 87 | assert(node1.data.annotations._debugLabel == "node-1") 88 | assert(node2.data.annotations._debugLabel == "node-2") 89 | end 90 | 91 | function test.test_identity() 92 | local in1 = nn.Identity()() 93 | local in2 = nn.Identity()() 94 | local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)}) 95 | 96 | local input = {torch.randn(3), torch.randn(3)} 97 | module:forward(input) 98 | module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) 99 | checkGradients(module, input) 100 | end 101 | 102 | function test.test_gradInputType() 103 | local xInput = torch.randn(3) 104 | local h = torch.randn(3) 105 | 106 | local x = nn.Identity()() 107 | local prevRnnState = nn.Identity()() 108 | local prevH1, prevCell = prevRnnState:split(2) 109 | local prevH = prevH1 110 | 111 | local cellOut = nn.CAddTable()({ 112 | nn.CMulTable()({x, prevH}), 113 | nn.CMulTable()({prevH, prevCell})}) 114 | local module = nn.gModule({x, prevRnnState}, {cellOut}) 115 | 116 | local c = torch.randn(h:size()) 117 | local prevRnnState = {h, c} 118 | local input = {xInput, prevRnnState} 119 | local output = module:forward(input) 120 | 121 | local gradOutput = torch.randn(h:size()) 122 | local gradInput = module:backward(input, gradOutput) 123 | 124 | local unpack = unpack or table.unpack 125 | local gradX, gradPrevState = unpack(gradInput) 126 | local gradPrevH, gradPrevCell = unpack(gradPrevState) 127 | assert(type(gradPrevH) == type(h), "wrong gradPrevH type") 128 | 129 | tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type") 130 | tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size") 131 | checkGradients(module, input) 132 | end 133 | 134 | function test.test_tabularInput() 135 | local in1 = nn.SplitTable(1)() 136 | local out1 = nn.CAddTable()(in1) 137 | local module = nn.gModule({in1}, {out1}) 138 | 139 | local input = torch.randn(2, 3) 140 | checkGradients(module, input) 141 | end 142 | 143 | function test.test_extraTable() 144 | local in1 = nn.Identity()() 145 | local out1 = nn.Identity()(in1) 146 | local module = nn.gModule({in1}, {out1}) 147 | 148 | local input = torch.Tensor({123}) 149 | tester:eq(module:forward(input), input, "simple output") 150 | tester:eq(module:forward({input}), {input}, "tabular output") 151 | end 152 | 153 | function test.test_accGradParameters() 154 | local input = torch.randn(10) 155 | 156 | local in1 = nn.CMul(input:nElement())() 157 | local out1 = nn.Identity()(in1) 158 | local out2 = nn.Identity()(in1) 159 | local module = nn.gModule({in1}, {out1, out2}) 160 | checkGradients(module, input) 161 | end 162 | 163 | function test.test_example1() 164 | local x1 = nn.Linear(20,10)() 165 | local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1)))) 166 | local mlp = 
nn.gModule({x1},{mout}) 167 | 168 | local x = torch.rand(20) 169 | checkGradients(mlp, x) 170 | end 171 | 172 | function test.test_example2() 173 | local x1=nn.Linear(20,20)() 174 | local x2=nn.Linear(10,10)() 175 | local m0=nn.Linear(20,1)(nn.Tanh()(x1)) 176 | local m1=nn.Linear(10,1)(nn.Tanh()(x2)) 177 | local madd=nn.CAddTable()({m0,m1}) 178 | local m2=nn.Sigmoid()(madd) 179 | local m3=nn.Tanh()(madd) 180 | local gmod = nn.gModule({x1,x2},{m2,m3}) 181 | 182 | local x = torch.rand(20) 183 | local y = torch.rand(10) 184 | checkGradients(gmod, {x, y}) 185 | end 186 | 187 | function test.test_example3() 188 | local m = nn.Sequential() 189 | m:add(nn.SplitTable(1)) 190 | m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) 191 | local input = nn.Identity()() 192 | local input1,input2 = m(input):split(2) 193 | local m3 = nn.JoinTable(1)({input1,input2}) 194 | local g = nn.gModule({input},{m3}) 195 | 196 | local indata = torch.rand(2,10) 197 | checkGradients(g, indata) 198 | end 199 | 200 | function test.test_example4() 201 | local input = nn.Identity()() 202 | local L1 = nn.Tanh()(nn.Linear(1,2)(input)) 203 | local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1}))) 204 | local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2}))) 205 | local g = nn.gModule({input},{L3}) 206 | 207 | local indata = torch.rand(1) 208 | checkGradients(g, indata) 209 | end 210 | 211 | function test.test_type() 212 | local in1 = nn.Linear(20,10)() 213 | local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) 214 | local module = nn.gModule({in1}, {out1}) 215 | local input = torch.rand(20) 216 | local output = module:forward(input) 217 | local gradOutput = output:clone():normal() 218 | local gradInput = module:backward(input, gradOutput) 219 | 220 | module:backward(input, output) 221 | tester:eq(torch.typename(output), "torch.DoubleTensor") 222 | tester:eq(torch.typename(module.output), "torch.DoubleTensor") 223 | tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor") 224 | tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") 225 | tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") 226 | tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") 227 | tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor") 228 | tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.DoubleTensor") 229 | 230 | module:float() 231 | tester:eq(torch.typename(module.output), "torch.FloatTensor") 232 | tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") 233 | tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") 234 | tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") 235 | tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") 236 | tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor") 237 | tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.FloatTensor") 238 | local output = module:forward(input:float()) 239 | tester:eq(torch.typename(output), "torch.FloatTensor") 240 | local gradInput = module:backward(input:float(), gradOutput:float()) 241 | tester:eq(torch.typename(gradInput), "torch.FloatTensor") 242 | 243 | end 244 | 245 | function test.test_nestedGradInput() 246 | local x = nn.Identity()() 247 | local h1 = 
nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh()) 248 | local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity()) 249 | local out = nn.CAddTable()({h1(x), h2(x)}) 250 | 251 | local model = nn.gModule({x}, {out}) 252 | 253 | local input = {} 254 | input[1] = torch.randn(3, 3) 255 | input[2] = torch.randn(3, 3) 256 | input[3] = torch.randn(3, 3) 257 | 258 | checkGradients(model, input) 259 | 260 | local input = {} 261 | input[1] = torch.randn(2, 3) 262 | input[2] = torch.randn(2, 3) 263 | input[3] = torch.randn(2, 3) 264 | 265 | checkGradients(model, input) 266 | end 267 | 268 | function test.test_unusedInput() 269 | local x = nn.Identity()() 270 | local h = nn.Identity()() 271 | local h2 = nn.Identity()() 272 | 273 | local ok, result = pcall(nn.gModule, {x, h}, {x}) 274 | assert(not ok, "the unused input should be detected") 275 | end 276 | 277 | function test.test_unusedChild() 278 | local prevState = nn.Identity()() 279 | local h, cell = prevState:split(2) 280 | 281 | local ok, result = pcall(nn.gModule, {prevState}, {h}) 282 | assert(not ok, "the unused cell should be detected") 283 | end 284 | 285 | function test.test_nilInput() 286 | local ok, result = pcall(function() nn.Sigmoid()(nil) end) 287 | assert(not ok, "the nil input should be detected") 288 | end 289 | 290 | function test.test_unusedNode() 291 | local in1 = nn.Identity()() 292 | local in2 = nn.Identity()() 293 | local middleResult = nn.Sigmoid()(in2) 294 | local out1 = nn.Sigmoid()(in1) 295 | 296 | local ok, result = pcall(nn.gModule, {in1, in2}, {out1}) 297 | assert(not ok, "the unused middleResult should be detected") 298 | end 299 | 300 | function test.test_usageAfterSplit() 301 | local prevState = nn.Identity()() 302 | local h, cell = prevState:split(2) 303 | local nextState = nn.Identity()(prevState) 304 | local transformed = nn.Sigmoid()(cell) 305 | 306 | local model = nn.gModule({prevState}, {h, nextState, transformed}) 307 | local nHidden = 10 308 | local input = {torch.randn(nHidden), torch.randn(nHidden)} 309 | checkGradients(model, input) 310 | end 311 | 312 | function test.test_resizeNestedAs() 313 | local in1 = nn.Identity()() 314 | local out1 = nn.Identity()(in1) 315 | local out2 = nn.Identity()(in1) 316 | 317 | local net = nn.gModule({in1}, {out1, out2}) 318 | local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} 319 | net:forward(input) 320 | net:backward(input, net.output) 321 | checkGradients(net, input) 322 | 323 | input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}} 324 | net:forward(input) 325 | net:backward(input, net.output) 326 | checkGradients(net, input) 327 | 328 | input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} 329 | net:forward(input) 330 | local gradInput = net:backward(input, net.output) 331 | tester:eq(#(gradInput[2]), 2, "gradInput[2] size") 332 | checkGradients(net, input) 333 | end 334 | 335 | 336 | function test.test_annotateGraph() 337 | local input = nn.Identity()():annotate( 338 | {name = 'Input', description = 'DescA', 339 | graphAttributes = {color = 'red'}}) 340 | 341 | local hidden_a = nn.Linear(10, 10)(input):annotate( 342 | {name = 'Hidden A', description = 'DescB', 343 | graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}}) 344 | local hidden_b = nn.Sigmoid()(hidden_a) 345 | local output = nn.Linear(10, 10)(hidden_b) 346 | local net = nn.gModule({input}, {output}) 347 | 348 | tester:assert(hidden_a:label():match('DescB') ~= nil) 349 | local fg_tmpfile = os.tmpname() 350 | local bg_tmpfile = 
os.tmpname() 351 | if not pcall(function() graph.dot(net.fg, 'Test', fg_tmpfile) end) then 352 | return -- prevent graphviz not found error 353 | end 354 | graph.dot(net.fg, 'Test BG', bg_tmpfile) 355 | 356 | local function checkDotFile(tmpfile) 357 | local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all") 358 | tester:assert( 359 | dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]') ~= nil) 360 | tester:assert( 361 | dotcontent:match( 362 | '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]') ~= nil) 363 | tester:assert( 364 | dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]') ~= nil) 365 | tester:assert( 366 | dotcontent:match( 367 | '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]') ~= nil) 368 | end 369 | 370 | checkDotFile(fg_tmpfile) 371 | checkDotFile(bg_tmpfile) 372 | end 373 | 374 | function test.test_splitMore() 375 | local nSplits = 2 376 | local in1 = nn.Identity()() 377 | local out1, out2 = nn.SplitTable(2)(in1):split(nSplits) 378 | 379 | local model = nn.gModule({in1}, {out1, out2}) 380 | local input = torch.randn(10, nSplits + 1) 381 | local ok, result = pcall(model.forward, model, input) 382 | assert(not ok, "the extra input to split should be detected") 383 | end 384 | 385 | function test.test_splitLess() 386 | local nSplits = 3 387 | local in1 = nn.Identity()() 388 | local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits) 389 | 390 | local model = nn.gModule({in1}, {out1, out2, out3}) 391 | local input = torch.randn(10, nSplits - 1) 392 | local ok, result = pcall(model.forward, model, input) 393 | assert(not ok, "the missing input to split should be detected") 394 | end 395 | 396 | function test.test_gradOutputZeroOptim() 397 | local unpack = function(...) -- Lua 5.1/5.2 compat shim 398 | if _G['unpack'] then return _G['unpack'](...) 399 | else return table.unpack(...)
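--[[ Illustrative note (added, not part of the original source): the
dummyModule defined below fabricates its gradInput as a single zero viewed
to an all-ones shape and then expanded, e.g. for a 1x2x3 input:

  torch.Tensor{0}:view(1, 1, 1):expandAs(input)

This produces a tensor with all-zero strides (one shared element), i.e. the
kind of non-contiguous zero gradient that the gModule backward pass must
tolerate; both cases below only assert that backward succeeds. --]]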
end 400 | end 401 | -- Make module that produces an expanded zero gradInput tensor 402 | local dummyModule = nn.Module() 403 | dummyModule.updateOutput = function(self, input) 404 | self.output = torch.Tensor(1, 2, 3):uniform() 405 | return self.output 406 | end 407 | dummyModule.updateGradInput = function(self, input, gradOutput) 408 | local zeroTensor = torch.Tensor{0} 409 | :view(unpack(torch.ones(input:dim()):totable())) 410 | :expandAs(input) 411 | self.gradInput = zeroTensor 412 | return self.gradInput 413 | end 414 | 415 | -- First input and final gradOutput 416 | local input = torch.Tensor(1, 2, 3):uniform() 417 | local gradOutput = torch.Tensor(1, 2, 3):uniform() 418 | 419 | -- First case: one intermediary gradOutput is going to be zero 420 | local x = nn.Identity()() 421 | local h1 = dummyModule:clone()(x) 422 | local h2 = nn.Identity()(x) 423 | local y = nn.CAddTable()({h1, h2}) 424 | local mod = nn.gModule({x}, {y}) 425 | 426 | local ok, result = pcall(nn.Module.forward, mod, input) 427 | assert(ok, "forward should succeed") 428 | 429 | nn.Module.backward( mod, input, gradOutput) 430 | ok, result = pcall(nn.Module.backward, mod, input, gradOutput) 431 | assert(ok, "backward should succeed") 432 | 433 | -- Second case: all intermediary gradOutputs are going to be zero 434 | local x = nn.Identity()() 435 | local h1 = dummyModule:clone()(x) 436 | local h2 = dummyModule:clone()(x) 437 | local y = nn.CAddTable()({h1, h2}) 438 | local mod = nn.gModule({x}, {y}) 439 | 440 | local ok, result = pcall(nn.Module.forward, mod, input) 441 | assert(ok, "forward should succeed") 442 | 443 | ok, result = pcall(nn.Module.backward, mod, input, gradOutput) 444 | assert(ok, "backward should succeed") 445 | end 446 | 447 | function test.test_replace() 448 | local i = nn.Identity()() 449 | local l1 = nn.Linear(5, 2)(i) 450 | local sig = nn.Sigmoid()(l1) 451 | local l2 = nn.Linear(2, 5)(sig) 452 | local model = nn.gModule({i}, {l2}) 453 | 454 | local input = torch.randn(4, 5) 455 | local gradOutput = torch.randn(4, 5) 456 | tester:eq(model:forward(input):size(), input:size(), "inconsistent output size") 457 | tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size") 458 | 459 | model:replace(function(m) 460 | if torch.type(m) == 'nn.Linear' then 461 | if m.weight:size(1) == 5 then 462 | return nn.Linear(2, 10) 463 | elseif m.weight:size(1) == 2 then 464 | return nn.Linear(10, 2) 465 | end 466 | elseif torch.type(m) == 'nn.Sigmoid' then 467 | return nn.Tanh() 468 | end 469 | return m 470 | end) 471 | 472 | local input = torch.randn(4, 10) 473 | local gradOutput = torch.randn(4, 10) 474 | tester:eq(model:forward(input):size(), input:size(), "inconsistent output size") 475 | tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size") 476 | 477 | tester:ne(model.modules[2], l1, "gModule.modules wasn't updated") 478 | tester:ne(model.modules[3], sig, "gModule.modules wasn't updated") 479 | tester:eq(torch.type(model.modules[3]), 'nn.Tanh', "replace didn't update gModule.modules") 480 | tester:ne(model.modules[4], l2, "gModule.modules wasn't updated") 481 | end 482 | 483 | tester:add(test):run() 484 | -------------------------------------------------------------------------------- /test/test_old.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | function t1() 3 | local x1 = nn.Linear(20,20)() 4 | local x2 = nn.Linear(10,10)() 5 | local m0=nn.Linear(20,1)(nn.Tanh()(x1)) 6 | 
local m1=nn.Linear(10,1)(nn.Tanh()(x2)) 7 | local madd=nn.CAddTable()({m0,m1}) 8 | local m2=nn.Sigmoid()(madd) 9 | local m3=nn.Tanh()(madd) 10 | local x = torch.rand(20) 11 | local y = torch.rand(10) 12 | gmod = nn.gModule({x1,x2},{m2,m3}) 13 | gmod.verbose = true 14 | print('forward') 15 | gmod:updateOutput({x,y}) 16 | print('updateGradInput') 17 | gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)}) 18 | graph.dot(gmod.fg,'forward') 19 | graph.dot(gmod.bg,'backward') 20 | end 21 | 22 | function t2() 23 | print('compare') 24 | local m0 = nn.Linear(5,10)() 25 | local m1 = nn.Linear(10,20)() 26 | local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1}) 27 | gmod = nn.gModule({m0,m1},{m2}) 28 | 29 | local nn0 = nn.Linear(5,10) 30 | local nn1 = nn.Linear(10,20) 31 | local nn2 = nn.Linear(30,50) 32 | local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2) 33 | 34 | nn0.weight:copy(m0.data.module.weight) 35 | nn0.bias:copy(m0.data.module.bias) 36 | nn1.weight:copy(m1.data.module.weight) 37 | nn1.bias:copy(m1.data.module.bias) 38 | nn2.weight:copy(m2.data.module.weight) 39 | nn2.bias:copy(m2.data.module.bias) 40 | 41 | 42 | for i=1,5 do 43 | local x,y = torch.rand(5),torch.rand(10) 44 | local xx,yy = x:clone(),y:clone() 45 | 46 | gmod:updateOutput({x,y}) 47 | nnmod:updateOutput({xx,yy}) 48 | print('fdiff = ', torch.dist(gmod.output,nnmod.output)) 49 | 50 | local odx = torch.rand(50) 51 | local odxx = odx:clone() 52 | 53 | gmod:updateGradInput({x,y},odx) 54 | nnmod:updateGradInput({xx,yy},odxx) 55 | graph.dot(gmod.fg,tostring(i)) 56 | for i,v in ipairs(gmod.gradInput) do 57 | print('bdiff [' ..i.. '] = ', torch.dist(gmod.gradInput[i],nnmod.gradInput[i])) 58 | end 59 | end 60 | 61 | local gms = {m0,m1,m2} 62 | local nms = {nn0,nn1,nn2} 63 | 64 | for i=1,5 do 65 | local x,y = torch.rand(5),torch.rand(10) 66 | local xx,yy = x:clone(),y:clone() 67 | 68 | gmod:updateOutput({x,y}) 69 | nnmod:updateOutput({xx,yy}) 70 | print('fdiff = ', torch.dist(gmod.output,nnmod.output)) 71 | 72 | local odx = torch.rand(50) 73 | local odxx = odx:clone() 74 | 75 | gmod:zeroGradParameters() 76 | nnmod:zeroGradParameters() 77 | 78 | gmod:updateGradInput({x,y},odx) 79 | nnmod:updateGradInput({xx,yy},odxx) 80 | 81 | gmod:accGradParameters({x,y},odx) 82 | nnmod:accGradParameters({xx,yy},odxx) 83 | graph.dot(gmod.fg) 84 | for i,v in ipairs(gms) do 85 | print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradWeight,nms[i].gradWeight)) 86 | print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradBias,nms[i].gradBias)) 87 | end 88 | end 89 | end 90 | 91 | function t3() 92 | mlp=nn.Sequential(); --Create a network that takes a Tensor as input 93 | mlp:add(nn.SplitTable(2)) 94 | c=nn.ParallelTable() --The two Tensors go through two different Linear 95 | c:add(nn.Linear(10,3)) --Layers in Parallel 96 | c:add(nn.Linear(10,7)) 97 | mlp:add(c) --Outputing a table with 2 elements 98 | p=nn.ParallelTable() --These tables go through two more linear layers 99 | p:add(nn.Linear(3,2)) -- separately. 100 | p:add(nn.Linear(7,1)) 101 | mlp:add(p) 102 | mlp:add(nn.JoinTable(1)) --Finally, the tables are joined together and output. 103 | 104 | pred=mlp:forward(torch.randn(10,2)) 105 | print(pred) 106 | 107 | for i=1,25 do -- A few steps of training such a network.. 
108 | x=torch.ones(10,2); 109 | y=torch.Tensor(3); y:copy(x:select(2,1,1):narrow(1,1,3)) 110 | pred=mlp:forward(x) 111 | 112 | criterion= nn.MSECriterion() 113 | local err=criterion:forward(pred,y) 114 | local gradCriterion = criterion:backward(pred,y); 115 | print(x,y) 116 | mlp:zeroGradParameters(); 117 | mlp:backward(x, gradCriterion); 118 | mlp:updateParameters(0.05); 119 | 120 | print(err) 121 | end 122 | end 123 | 124 | function t4() 125 | local getInput1 = nn.Identity()() 126 | local getInput2 = nn.Identity()() 127 | local mlp = nn.Tanh()(getInput1) 128 | net = nn.gModule({getInput1, getInput2}, {mlp, getInput2}) 129 | 130 | 131 | local input1 = torch.randn(2) 132 | local input2 = torch.randn(5) 133 | 134 | net:forward({input1, input2}) 135 | local gradInput = net:backward({input1, input2}, 136 | {torch.randn(input1:size()), torch.randn(input2:size())}) 137 | print("gradInput[1]:", gradInput[1]) 138 | print("gradInput[2]:", gradInput[2]) 139 | graph.dot(net.fg) 140 | assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch") 141 | 142 | end 143 | 144 | function t5() 145 | local m = nn.Sequential() 146 | m:add(nn.SplitTable(1)) 147 | m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) 148 | local input = nn.Identity()() 149 | local input1,input2 = m(input):split(2) 150 | local m3 = nn.JoinTable(1)({input1,input2}) 151 | 152 | g = nn.gModule({input},{m3}) 153 | graph.dot(g.fg,'init forward') 154 | 155 | local indata = torch.rand(2,10) 156 | local gdata = torch.rand(50) 157 | g:forward(indata) 158 | g:backward(indata,gdata) 159 | 160 | graph.dot(g.fg,'forward') 161 | graph.dot(g.bg,'backward') 162 | end 163 | 164 | function topsort(a) 165 | -- first clone the graph 166 | -- local g = self:clone() 167 | -- local nodes = g.nodes 168 | -- local edges = g.edges 169 | -- for i,node in ipairs(nodes) do 170 | -- node.children = {} 171 | -- end 172 | 173 | -- reverse the graph 174 | rg,map = a:reverse() 175 | local rmap = {} 176 | for k,v in pairs(map) do 177 | rmap[v] = k 178 | end 179 | 180 | -- work on the sorted graph 181 | sortednodes = {} 182 | rootnodes = rg:roots() 183 | 184 | if #rootnodes == 0 then 185 | print('Graph has cycles') 186 | end 187 | 188 | -- run 189 | for i,root in ipairs(rootnodes) do 190 | root:dfs(function(node) 191 | print(node.id,rmap[node].id) 192 | -- print(rmap[node]) 193 | table.insert(sortednodes,rmap[node]) end) 194 | end 195 | 196 | if #sortednodes ~= #a.nodes then 197 | print('Graph has cycles') 198 | end 199 | return sortednodes,rg,rootnodes 200 | end 201 | 202 | local my={eq = 203 | function(a,b,s) 204 | if a:dist(b) == 0 then 205 | print('ok') 206 | else 207 | print('error : ' .. 
s) 208 | print('a : ');print(a) 209 | print('b : ');print(b) 210 | end 211 | end} 212 | 213 | function t8() 214 | local in1 = nn.Identity()() 215 | local m = nn.Linear(10,10)(in1) 216 | local out1 = nn.Tanh()(m) 217 | local out2 = nn.Tanh()(m) 218 | local out = nn.CAddTable(){out1, out2} 219 | local mod = nn.gModule({in1}, {out}) 220 | 221 | local dot = nngraph.simple_print.todot(mod.fg, 'bogus') 222 | print (dot) 223 | nngraph.simple_print.dot(mod.fg, 'bogus', 'new') 224 | graph.dot(mod.fg, 'bogus', 'old') 225 | end 226 | -- t2() 227 | t8() 228 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | local utils = {} 2 | 3 | function utils.istorchclass(x) 4 | return type(x) == 'table' and torch.typename(x) 5 | end 6 | 7 | function utils.istable(x) 8 | return type(x) == 'table' and not torch.typename(x) 9 | end 10 | 11 | --[[ Returns a useful error message when a nngraph.Node is expected. ]] 12 | function utils.expectingNodeErrorMessage(badVal, array, idx) 13 | if badVal == nil then 14 | return string.format('%s[%d] is nil (typo / bad index?)', array, idx) 15 | elseif torch.isTypeOf(badVal, 'nn.Module') then 16 | local errStr = '%s[%d] is an nn.Module, specifically a %s, but the ' .. 17 | 'only valid thing to pass is an instance of ' .. 18 | 'nngraph.Node. Did you forget a second set of parens, ' .. 19 | 'which convert a nn.Module to a nngraph.Node?' 20 | return string.format(errStr, array, idx, torch.typename(badVal)) 21 | else 22 | local errStr = '%s[%d] should be an nngraph.Node but is of type %s' 23 | return string.format(errStr, array, idx, 24 | torch.typename(badVal) or type(badVal)) 25 | end 26 | end 27 | 28 | --[[ Lua 5.2+ removed table.maxn, provide fallback implementation. ]] 29 | if table.maxn then 30 | utils.tableMaxN = table.maxn 31 | else 32 | function utils.tableMaxN(tbl) 33 | local max = 0 34 | for k, v in pairs(tbl) do 35 | if type(k) == 'number' and k > max then 36 | max = k 37 | end 38 | end 39 | return max 40 | end 41 | end 42 | return utils 43 | --------------------------------------------------------------------------------
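--[[ Illustrative note (added, not part of the original source):
utils.tableMaxN mirrors Lua 5.1's table.maxn, whose behaviour on tables with
holes differs from the # operator. A quick sketch:

  local utils = require 'nngraph.utils'
  print(utils.tableMaxN({'a', nil, 'c'}))  -- 3: the largest numeric key
  -- whereas #{'a', nil, 'c'} may legally report 1 or 3

This matters when scanning input/output tables: a nil left by a typo (see
utils.expectingNodeErrorMessage) can still be detected instead of silently
truncating the list. --]]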