├── .gitignore
├── .travis.yml
├── CMakeLists.txt
├── COPYRIGHT.txt
├── JustElement.lua
├── JustTable.lua
├── ModuleFromCriterion.lua
├── README.md
├── doc
│   ├── annotation_bg.png
│   ├── annotation_fg.png
│   ├── mlp.png
│   ├── mlp2.png
│   ├── mlp3_backward.png
│   ├── mlp3_forward.png
│   ├── mlp4_backward.png
│   ├── mlp4_forward.png
│   └── my_bad_linear_net.png
├── gmodule.lua
├── graphinspecting.lua
├── init.lua
├── nest.lua
├── nesting.lua
├── nngraph-scm-1.rockspec
├── node.lua
├── simple_print.lua
├── test
│   ├── speed.lua
│   ├── test_JustElement.lua
│   ├── test_JustTable.lua
│   ├── test_ModuleFromCriterion.lua
│   ├── test_connectivity.lua
│   ├── test_debug.lua
│   ├── test_nest.lua
│   ├── test_nngraph.lua
│   └── test_old.lua
└── utils.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: c
2 | compiler:
3 | - gcc
4 | - clang
5 | cache:
6 | directories:
7 | - $HOME/OpenBlasInstall
8 | - $HOME/GraphViz
9 | sudo: false
10 | env:
11 | - TORCH_LUA_VERSION=LUAJIT21
12 | - TORCH_LUA_VERSION=LUA51
13 | - TORCH_LUA_VERSION=LUA52
14 | addons:
15 | apt:
16 | packages:
17 | - cmake
18 | - gfortran
19 | - gcc-multilib
20 | - gfortran-multilib
21 | - liblapack-dev
22 | - build-essential
23 | - gcc
24 | - g++
25 | - curl
26 | - cmake
27 | - libreadline-dev
28 | - git-core
29 | - libqt4-core
30 | - libqt4-gui
31 | - libqt4-dev
32 | - libjpeg-dev
33 | - libpng-dev
34 | - ncurses-dev
35 | - imagemagick
36 | - libzmq3-dev
37 | - gfortran
38 | - unzip
39 | - gnuplot
40 | - gnuplot-x11
41 | before_script:
42 | - export ROOT_TRAVIS_DIR=$(pwd)
43 | - export INSTALL_PREFIX=~/torch/install
44 | - ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install)
45 | - ls $HOME/GraphViz/lib || (cd /tmp/ && wget -c http://www.graphviz.org/pub/graphviz/stable/SOURCES/graphviz-2.38.0.tar.gz && tar -xvf graphviz-2.38.0.tar.gz && cd graphviz-2.38.0 && (./configure prefix=$HOME/GraphViz/ 2>/dev/null >/dev/null) && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make install)
46 | - export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH
47 | - git clone https://github.com/torch/distro.git ~/torch --recursive
48 | - cd ~/torch && git submodule update --init --recursive
49 | - mkdir build && cd build
50 | - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH
51 | - cmake .. -DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON
52 | - make && make install
53 | - ${INSTALL_PREFIX}/bin/luarocks install totem
54 | - if [[ $TORCH_LUA_VERSION != 'LUAJIT21' && $TORCH_LUA_VERSION != 'LUAJIT20' ]]; then ${INSTALL_PREFIX}/bin/luarocks install luaffi; fi
55 | - cd $ROOT_TRAVIS_DIR
56 | - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH
57 | script:
58 | - ${INSTALL_PREFIX}/bin/luarocks make
59 | - export PATH=${INSTALL_PREFIX}/bin:$PATH
60 | - export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH
61 | - export TESTLUA=$(which luajit lua | head -n 1)
62 | - ${TESTLUA} -lnngraph -e "print('nngraph loaded successfully')"
63 | - cd test
64 | - ${TESTLUA} test_ModuleFromCriterion.lua
65 | - ${TESTLUA} test_nest.lua
66 | - ${TESTLUA} test_nngraph.lua
67 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 |
2 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
3 | CMAKE_POLICY(VERSION 2.6)
4 | FIND_PACKAGE(Torch REQUIRED)
5 |
6 | FILE(GLOB luasrc *.lua)
7 |
8 | ADD_TORCH_PACKAGE(nngraph "" "${luasrc}" "Neural Net Graph Package")
9 |
--------------------------------------------------------------------------------
/COPYRIGHT.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
2 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
3 | Copyright (c) 2011-2013 NYU (Clement Farabet)
4 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
5 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
6 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
7 |
8 | All rights reserved.
9 |
10 | Redistribution and use in source and binary forms, with or without
11 | modification, are permitted provided that the following conditions are met:
12 |
13 | 1. Redistributions of source code must retain the above copyright
14 | notice, this list of conditions and the following disclaimer.
15 |
16 | 2. Redistributions in binary form must reproduce the above copyright
17 | notice, this list of conditions and the following disclaimer in the
18 | documentation and/or other materials provided with the distribution.
19 |
20 | 3. Neither the names of NEC Laboratories American and IDIAP Research
21 | Institute nor the names of its contributors may be used to endorse or
22 | promote products derived from this software without specific prior
23 | written permission.
24 |
25 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35 | POSSIBILITY OF SUCH DAMAGE.
36 |
--------------------------------------------------------------------------------
/JustElement.lua:
--------------------------------------------------------------------------------
1 |
2 | local JustElement, parent = torch.class('nngraph.JustElement', 'nn.Module')
3 | function JustElement:__init()
4 | self.gradInput = {}
5 | end
6 |
7 | -- The input is a table with one element.
8 | -- The output is the element from the table.
9 | function JustElement:updateOutput(input)
10 | assert(#input == 1, "expecting one element")
11 | self.output = input[1]
12 | return self.output
13 | end
14 |
15 | function JustElement:updateGradInput(input, gradOutput)
16 | self.gradInput[1] = gradOutput
17 | return self.gradInput
18 | end
19 |
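20 | -- Usage sketch: node:split(1) in node.lua uses this module to unwrap a
21 | -- one-element table output, i.e. it returns nngraph.JustElement()(node).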
--------------------------------------------------------------------------------
/JustTable.lua:
--------------------------------------------------------------------------------
1 |
2 | local JustTable, parent = torch.class('nngraph.JustTable', 'nn.Module')
3 | function JustTable:__init()
4 | self.output = {}
5 | end
6 |
7 | -- The input is one element.
8 | -- The output is a table with one element: {element}
9 | function JustTable:updateOutput(input)
10 | self.output[1] = input
11 | return self.output
12 | end
13 |
14 | function JustTable:updateGradInput(input, gradOutput)
15 | self.gradInput = gradOutput[1]
16 | return self.gradInput
17 | end
18 |
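19 | -- Usage sketch: nngraph.nest (see nest.lua) uses this module to wrap a single
20 | -- node into a one-element table, e.g. nngraph.nest({in1}) returns
21 | -- nngraph.JustTable()({in1}).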
--------------------------------------------------------------------------------
/ModuleFromCriterion.lua:
--------------------------------------------------------------------------------
1 |
2 | --[[ A wrapper to turn a criterion into a module.
3 |
4 | The gradient with respect to the target will be zero.
5 | --]]
6 | local ModuleFromCriterion, parent = torch.class('nn.ModuleFromCriterion','nn.Module')
7 | function ModuleFromCriterion:__init(criterion)
8 | self.criterion = criterion
9 | self.output = torch.Tensor(1)
10 | self.gradInput = {torch.Tensor(), torch.Tensor()}
11 | end
12 |
13 | local unpack = unpack or table.unpack -- lua52 compat
14 |
15 | --[[ The input is a {prediction, target} pair.
16 | The output is a tensor with one number: the criterion output.
17 | --]]
18 | function ModuleFromCriterion:updateOutput(input)
19 | local prediction, target = unpack(input)
20 | self.output[1] = self.criterion:updateOutput(prediction, target)
21 | return self.output
22 | end
23 |
24 | function ModuleFromCriterion:updateGradInput(input, gradOutput)
25 | local prediction, target = unpack(input)
26 | local gradPrediction = self.criterion:updateGradInput(prediction, target)
27 | if type(gradPrediction) == 'table' then
28 | if type(self.gradInput[1]) ~= 'table' then
29 | self.gradInput[1] = {} -- initialize to a table the first time; it starts out as a tensor (see line 10)
30 | for i=1, #gradPrediction do
31 | self.gradInput[1][i] = gradPrediction[i].new() -- and fill it with tensors of the right type.
32 | end
33 | end
34 | for i=1, #gradPrediction do
35 | self.gradInput[1][i]:resizeAs(gradPrediction[i]):copy(gradPrediction[i]):mul(gradOutput[1])
36 | end
37 | else
38 | self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
39 | end
40 | self.gradInput[2]:resizeAs(target):zero()
41 | return self.gradInput
42 | end
43 |
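44 | --[[ Usage sketch (variable names are illustrative): Criterion.__call__ in
45 | init.lua applies this wrapper automatically, so a criterion can be used as a
46 | graph node directly:
47 |
48 |   local prediction = nn.Identity()()
49 |   local target = nn.Identity()()
50 |   local loss = nn.MSECriterion()({prediction, target})
51 |   local net = nn.gModule({prediction, target}, {loss})
52 |   net:forward({torch.rand(5), torch.rand(5)}) -- one-element tensor: the loss
53 | --]]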
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Network Graph Package
2 |
3 | [![Build Status](https://travis-ci.org/torch/nngraph.svg)](https://travis-ci.org/torch/nngraph)
4 |
5 | This package provides graphical computation for the `nn` library in [Torch](https://github.com/torch/torch7/blob/master/README.md).
6 |
7 | ## Requirements
8 |
9 | You do *not* need `graphviz` to use this library, but if you have it you will be able to display the graphs you create. To install `graphviz`, run the appropriate command below:
10 |
11 | ```bash
12 | # Mac users
13 | brew install graphviz
14 | # Debian/Ubuntu users
15 | sudo apt-get install graphviz -y
16 | ```
17 |
18 | ## Usage
19 |
20 | [Plug: A more explanatory nngraph tutorial by Nando De Freitas of Oxford](https://www.cs.ox.ac.uk/people/nando.defreitas/machinelearning/practicals/practical5.pdf)
21 |
22 | The aim of this library is to provide users of the `nn` package with tools to easily create complicated architectures.
23 | Any given `nn` module is bundled into a *graph node*.
24 | The `__call__` operator of an `nn.Module` instance is used to create architectures as if one were writing function calls.
25 |
26 | ### An MLP with two hidden layers
27 |
28 | ```lua
29 | h1 = nn.Linear(20, 10)()
30 | h2 = nn.Linear(10, 1)(nn.Tanh()(nn.Linear(10, 10)(nn.Tanh()(h1))))
31 | mlp = nn.gModule({h1}, {h2})
32 |
33 | x = torch.rand(20)
34 | dx = torch.rand(1)
35 | mlp:updateOutput(x)
36 | mlp:updateGradInput(x, dx)
37 | mlp:accGradParameters(x, dx)
38 |
39 | -- draw graph (the forward graph, '.fg')
40 | graph.dot(mlp.fg, 'MLP')
41 | ```
42 |
43 | ![mlp](doc/mlp.png)
44 |
45 | Read this diagram from top to bottom; the first and last nodes are *dummy nodes* that regroup all inputs and outputs of the graph.
46 | The `module` entry describes the function of the node, as applied to `input`, producing a result of the shape shown in `gradOutput`; `mapindex` contains pointers to the parent nodes.
47 |
48 | To save the *graph* to file, specify a file name, and both `.dot` and `.svg` files will be saved. For example, you can type:
49 |
50 | ```lua
51 | graph.dot(mlp.fg, 'MLP', 'myMLP')
52 | ```
53 |
54 | You can also use the `__unm__` and `__sub__` operators to replace all the `__call__` invocations:
55 | ```lua
56 | h1 = - nn.Linear(20,10)
57 | h2 = h1
58 | - nn.Tanh()
59 | - nn.Linear(10,10)
60 | - nn.Tanh()
61 | - nn.Linear(10, 1)
62 | mlp = nn.gModule({h1}, {h2})
63 | ```
64 |
65 |
66 | ### A network with 2 inputs and 2 outputs
67 |
68 | ```lua
69 | h1 = nn.Linear(20, 20)()
70 | h2 = nn.Linear(10, 10)()
71 | hh1 = nn.Linear(20, 1)(nn.Tanh()(h1))
72 | hh2 = nn.Linear(10, 1)(nn.Tanh()(h2))
73 | madd = nn.CAddTable()({hh1, hh2})
74 | oA = nn.Sigmoid()(madd)
75 | oB = nn.Tanh()(madd)
76 | gmod = nn.gModule({h1, h2}, {oA, oB})
77 |
78 | x1 = torch.rand(20)
79 | x2 = torch.rand(10)
80 |
81 | gmod:updateOutput({x1, x2})
82 | gmod:updateGradInput({x1, x2}, {torch.rand(1), torch.rand(1)})
83 | graph.dot(gmod.fg, 'Big MLP')
84 | ```
85 |
86 | Alternatively, you can use `-` to make your code look like the data flow:
87 |
88 | ```lua
89 | h1 = - nn.Linear(20,20)
90 | h2 = - nn.Linear(10,10)
91 | hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
92 | hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
93 | madd = {hh1,hh2} - nn.CAddTable()
94 | oA = madd - nn.Sigmoid()
95 | oB = madd - nn.Tanh()
96 | gmod = nn.gModule( {h1,h2}, {oA,oB} )
97 | ```
98 |
99 | ![mlp2](doc/mlp2.png)
100 |
101 |
102 | ### A network with containers
103 |
104 | Here is another net, one that uses container modules (such as `ParallelTable`) that output a table of outputs.
105 |
106 | ```lua
107 | m = nn.Sequential()
108 | m:add(nn.SplitTable(1))
109 | m:add(nn.ParallelTable():add(nn.Linear(10, 20)):add(nn.Linear(10, 30)))
110 | input = nn.Identity()()
111 | input1, input2 = m(input):split(2)
112 | m3 = nn.JoinTable(1)({input1, input2})
113 |
114 | g = nn.gModule({input}, {m3})
115 |
116 | indata = torch.rand(2, 10)
117 | gdata = torch.rand(50)
118 | g:forward(indata)
119 | g:backward(indata, gdata)
120 |
121 | graph.dot(g.fg, 'Forward Graph')
122 | graph.dot(g.bg, 'Backward Graph')
123 | ```
124 |
125 | ![mlp3 forward](doc/mlp3_forward.png)
126 | ![mlp3 backward](doc/mlp3_backward.png)
127 |
128 |
129 | ### More fun with graphs
130 |
131 | A multi-layer network in which each layer takes the output of the previous two layers as input.
132 |
133 | ```lua
134 | input = nn.Identity()()
135 | L1 = nn.Tanh()(nn.Linear(10, 20)(input))
136 | L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1})))
137 | L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2})))
138 |
139 | g = nn.gModule({input}, {L3})
140 |
141 | indata = torch.rand(10)
142 | gdata = torch.rand(160)
143 | g:forward(indata)
144 | g:backward(indata, gdata)
145 |
146 | graph.dot(g.fg, 'Forward Graph')
147 | graph.dot(g.bg, 'Backward Graph')
148 | ```
149 |
150 | As your graph gets bigger and more complicated, the nested parentheses may become confusing. In that case, chaining the modules with `-` is clearer and easier:
151 | ```lua
152 | input = - nn.Identity()
153 | L1 = input
154 | - nn.Linear(10, 20)
155 | - nn.Tanh()
156 | L2 = { input, L1 }
157 | - nn.JoinTable(1)
158 | - nn.Linear(30,60)
159 | - nn.Tanh()
160 | L3 = { L1,L2 }
161 | - nn.JoinTable(1)
162 | - nn.Linear(80,160)
163 | - nn.Tanh()
164 | g = nn.gModule({input},{L3})
165 | ```
166 |
167 | ![mlp4 forward](doc/mlp4_forward.png)
168 | ![mlp4 backward](doc/mlp4_backward.png)
169 |
170 |
171 | ## Annotations
172 |
173 | You can add annotations to your network, such as labeling nodes with names or attributes that will show up when you graph the network.
174 | This can be helpful in large graphs.
175 |
176 | For the full list of graph attributes see the
177 | [graphviz documentation](http://www.graphviz.org/doc/info/attrs.html).
178 |
179 | ```lua
180 | input = nn.Identity()()
181 | L1 = nn.Tanh()(nn.Linear(10, 20)(input)):annotate{
182 | name = 'L1', description = 'Level 1 Node',
183 | graphAttributes = {color = 'red'}
184 | }
185 | L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1}))):annotate{
186 | name = 'L2', description = 'Level 2 Node',
187 | graphAttributes = {color = 'blue', fontcolor = 'green'}
188 | }
189 | L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2}))):annotate{
190 | name = 'L3', description = 'Level 3 Node',
191 | graphAttributes = {color = 'green',
192 | style = 'filled', fillcolor = 'yellow'}
193 | }
194 |
195 | g = nn.gModule({input},{L3})
196 |
197 | indata = torch.rand(10)
198 | gdata = torch.rand(160)
199 | g:forward(indata)
200 | g:backward(indata, gdata)
201 |
202 | graph.dot(g.fg, 'Forward Graph', '/tmp/fg')
203 | graph.dot(g.bg, 'Backward Graph', '/tmp/bg')
204 | ```
205 |
206 | In this case, the graphs are saved in the following 4 files: `/tmp/{fg,bg}.{dot,svg}`.
207 |
208 | ![annotation fg](doc/annotation_fg.png)
209 | ![annotation bg](doc/annotation_bg.png)
210 |
211 | ## Debugging
212 |
213 | With nngraph, one can create very complicated networks, and in those networks finding errors can be hard. For that purpose, nngraph provides several useful utilities. The following code snippet shows how to annotate the nodes of a graph with local variable names, and how to enable a debug mode that automatically creates an SVG file with the failing node marked whenever a runtime error occurs.
214 |
215 | ```lua
216 |
217 | require 'nngraph'
218 |
219 | -- In debug mode, an svg of the graph is generated on error, with the problem
220 | -- node highlighted; hover over the nodes in the svg to see filename:line_number info.
221 | -- Nodes are annotated with local variable names even if debug mode is not enabled.
222 | nngraph.setDebug(true)
223 |
224 | local function get_net(from, to)
225 | local from = from or 10
226 | local to = to or 10
227 | local input_x = nn.Identity()()
228 | local linear_module = nn.Linear(from, to)(input_x)
229 |
230 | -- Annotate nodes with local variable names
231 | nngraph.annotateNodes()
232 | return nn.gModule({input_x},{linear_module})
233 | end
234 |
235 | local net = get_net(10,10)
236 |
237 | -- if you give the net a name, that name is used for the svg produced
238 | -- on error; otherwise a name is derived from the number of inputs
239 | -- and outputs of the graph
240 | net.name = 'my_bad_linear_net'
241 |
242 | -- prepare an input that is of the wrong size to force an error
243 | local input = torch.rand(11)
244 | pcall(function() net:updateOutput(input) end)
245 | -- it should have produced an error and written out an svg of the graph;
246 | -- open the svg (here with Safari) to display it
247 | os.execute('open -a Safari my_bad_linear_net.svg')
248 | ```
249 |
250 | ![my_bad_linear_net](doc/my_bad_linear_net.png)
251 |
--------------------------------------------------------------------------------
/doc/annotation_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/annotation_bg.png
--------------------------------------------------------------------------------
/doc/annotation_fg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/annotation_fg.png
--------------------------------------------------------------------------------
/doc/mlp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp.png
--------------------------------------------------------------------------------
/doc/mlp2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp2.png
--------------------------------------------------------------------------------
/doc/mlp3_backward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp3_backward.png
--------------------------------------------------------------------------------
/doc/mlp3_forward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp3_forward.png
--------------------------------------------------------------------------------
/doc/mlp4_backward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp4_backward.png
--------------------------------------------------------------------------------
/doc/mlp4_forward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/mlp4_forward.png
--------------------------------------------------------------------------------
/doc/my_bad_linear_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/torch/nngraph/3ed3b9ba9d1adf72c1fe15291a1e50b843cf04f9/doc/my_bad_linear_net.png
--------------------------------------------------------------------------------
/gmodule.lua:
--------------------------------------------------------------------------------
1 | local nesting = require('nngraph.nesting')
2 | local utils = require('nngraph.utils')
3 | local istensor = torch.isTensor
4 | local istable = utils.istable
5 | local istorchclass = utils.istorchclass
6 |
7 | local function getTotalGradOutput(node)
8 | local gradOutput = node.data.gradOutput
9 | assert(istable(gradOutput), "expecting gradients to sum")
10 | if #gradOutput > 1 then
11 | -- Check if we can bypass the allocation, for the special case where all
12 | -- gradOutputs but one are zero tensors with an underlying one-element
13 | -- storage. Note that when we cannot bypass it, this check
14 | -- will only be performed once.
15 | if not node.data.gradOutputBuffer then
16 | local count = 0
17 | local idx = 1
18 | -- Count how many gradOutput are tensors of 1 element filled with zero
19 | for i=1,#gradOutput do
20 | local zero = torch.isTensor(gradOutput[i]) and
21 | gradOutput[i]:storage() ~= nil and
22 | gradOutput[i]:storage():size() == 1 and
23 | gradOutput[i]:storage()[1] == 0
24 | if not zero then
25 | idx = i
26 | count = count + 1
27 | end
28 | end
29 | if count < 2 then
30 | -- Return the only non-zero one, or the first one
31 | -- if they are all zero
32 | return gradOutput[idx]
33 | end
34 | end
35 | node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
36 | local gobuff = node.data.gradOutputBuffer
37 | nesting.resizeNestedAs(gobuff, gradOutput[1])
38 | nesting.copyNested(gobuff, gradOutput[1])
39 | for i=2,#gradOutput do
40 | nesting.addNestedTo(gobuff, gradOutput[i])
41 | end
42 | gradOutput = gobuff
43 | else
44 | gradOutput = gradOutput[1]
45 | end
46 | return gradOutput
47 | end
48 |
49 | -- The gModule allows one to build a general acyclic graph of modules.
50 | --
51 | -- Each node of the graph can have multiple inputs.
52 | -- The order of inputs is remembered in node.data.mapindex.
53 | --
54 | -- Each node has only one output.
55 | -- The output can be also a table.
56 | -- To route parts of the outputted table to different modules,
57 | -- use the node:split(nOutputs) function.
58 | -- The split will create subnodes with narrowed output.
59 | --
60 | -- Implementation details:
61 | -- The node.data.input holds a list of inputs.
62 | -- If a module expects only one input, the node.data.input[1] is used.
63 | --
64 | -- The node.data.gradOutput holds the to-be-summed gradOutputs.
65 | -- Each node has only one output. So we need only one gradOutput.
66 | local gModule, parent = torch.class('nn.gModule','nn.Container')
67 |
68 | function gModule:__init(inputs,outputs)
69 | parent.__init(self)
70 | -- the graph is defined backwards, so we receive the output modules as inputs here
71 | -- we will define a dummy output node that connects all output modules
72 | -- into itself. This will be the output for the forward graph and
73 | -- input point for the backward graph
74 | local node
75 | local outnode = nngraph.Node({input={}})
76 | for i = 1, utils.tableMaxN(outputs) do
77 | node = outputs[i]
78 | if torch.typename(node) ~= 'nngraph.Node' then
79 | error(utils.expectingNodeErrorMessage(node, 'outputs', i))
80 | end
81 | outnode:add(node, true)
82 | end
83 | for i = 1, utils.tableMaxN(inputs) do
84 | node = inputs[i]
85 | if torch.typename(node) ~= 'nngraph.Node' then
86 | error(utils.expectingNodeErrorMessage(node, 'inputs', i))
87 | end
88 | end
89 | -- We also add a dummy input node.
90 | -- The input node will be split to feed the passed input nodes.
91 | local innode = nngraph.Node({input={}})
92 | assert(#inputs > 0, "graphs without inputs are not supported")
93 | if #inputs == 1 then
94 | inputs[1]:add(innode,true)
95 | else
96 | local splits = {innode:split(#inputs)}
97 | for i = 1, #inputs do
98 | assert(#inputs[i].children == 0, "an input should have no inputs")
99 | end
100 | for i = 1, #inputs do
101 | inputs[i]:add(splits[i],true)
102 | end
103 | end
104 |
105 | -- the backward graph (bg) is for gradients
106 | -- the forward graph (fg) is for function evaluation
107 | self.bg = outnode:graph()
108 | self.fg = self.bg:reverse()
109 |
110 | -- the complete graph is constructed
111 | -- now regenerate the graphs with the additional nodes
112 |
113 | local roots = self.fg:roots()
114 | -- if there are more than one root in the forward graph, then make sure that
115 | -- extra roots are parameter nodes
116 | if #roots > 1 then
117 | local innodeRoot = nil
118 | -- first find our innode
119 | for _, root in ipairs(roots) do
120 | if root.data == innode.data then
121 | assert(innodeRoot == nil, 'more than one matching input node found in leaves')
122 | innodeRoot = root
123 | else
124 | assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
125 | assert(torch.typename(root.data.module) == 'nnop.Parameters',
126 | 'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
127 | end
128 | end
129 | assert(innodeRoot ~= nil, 'input node not found among roots')
130 | self.innode = innodeRoot
131 | else
132 | assert(#self.fg:roots() == 1, "expecting only one start")
133 | self.innode = self.fg:roots()[1]
134 | end
135 |
136 | assert(self.innode.data == innode.data, "expecting the forward innode")
137 | self.outnode = outnode
138 | self.verbose = false
139 | self.nInputs = #inputs
140 |
141 | -- computation on the graph is done through topsort of forward and backward graphs
142 | self.forwardnodes = self.fg:topsort()
143 | self.backwardnodes = self.bg:topsort()
144 |
145 | -- iterate over all nodes: check, tag and add to the container
146 | for i,node in ipairs(self.forwardnodes) do
147 | -- check for unused inputs or unused split() outputs
148 | if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
149 | local nUnused = node.data.nSplitOutputs - #node.children
150 | local debugLabel = node.data.annotations._debugLabel
151 | local errStr =
152 | "%s of split(%s) outputs from the node declared at %s are unused"
153 | error(string.format(errStr, nUnused, node.data.nSplitOutputs,
154 | debugLabel))
155 | end
156 |
157 | -- Check whether any nodes were defined as taking this node as an input,
158 | -- but then left dangling and don't connect to the output. If this is
159 | -- the case, then they won't be present in forwardnodes, so error out.
160 | for successor, _ in pairs(node.data.reverseMap) do
161 | local successorIsInGraph = false
162 |
163 | -- Only need to check the part of forwardnodes from i onwards; the
164 | -- topological sort guarantees it cannot be in the first part.
165 | for j = i+1, #self.forwardnodes do
166 | -- Compare equality of data tables, as new Node objects have been
167 | -- created by processes such as topological sort, but the
168 | -- underlying .data table is shared.
169 | if self.forwardnodes[j].data == successor.data then
170 | successorIsInGraph = true
171 | break
172 | end
173 | end
174 | local errStr =
175 | "node declared on %s does not connect to gmodule output"
176 | assert(successorIsInGraph,
177 | string.format(errStr, successor.data.annotations._debugLabel))
178 | end
179 |
180 | -- set data.forwardNodeId for node:label() output
181 | node.data.forwardNodeId = node.id
182 |
183 | -- add module to container
184 | if node.data.module then
185 | self:add(node.data.module)
186 | end
187 | end
188 |
189 | self.output = nil
190 | self.gradInput = nil
191 | if #self.outnode.children > 1 then
192 | self.output = self.outnode.data.input
193 | end
194 | end
195 |
196 | function gModule:replace(callback)
197 | local out = callback(self)
198 | local revmodules = {}
199 | for i,m in ipairs(self.modules) do
200 | revmodules[m] = i
201 | end
202 | for i,node in ipairs(self.forwardnodes) do
203 | if node.data.module then
204 | local m = node.data.module
205 | node.data.module = m:replace(callback)
206 | self.modules[revmodules[m]] = node.data.module
207 | end
208 | end
209 | return out
210 | end
211 |
212 | function gModule:map(gm, func)
213 | for i,node in ipairs(self.forwardnodes) do
214 | local gmnode = gm.forwardnodes[i]
215 | assert(gmnode, 'trying to map another gModule with a different structure')
216 | if node.data.module then
217 | assert(gmnode.data.module, 'trying to map another gModule with a different structure')
218 | func(node.data.module, gmnode.data.module)
219 | end
220 | end
221 | end
222 |
223 | --[[ Recursively applies type(type_str) to any tensors in the argument. If the
224 | argument is a tensor, type(type_str) is applied; if the argument is an array,
225 | this function recurses into it. ]]
226 | local function recursiveType(param, type_str)
227 | if torch.type(param) == 'table' then
228 | for i = 1, #param do
229 | param[i] = recursiveType(param[i], type_str)
230 | end
231 | elseif torch.typename(param) and
232 | torch.typename(param):find('torch%..+Tensor') then
233 | param = param:type(type_str)
234 | end
235 | return param
236 | end
237 |
238 | function gModule:type(type, tensorCache)
239 | if not type then
240 | return self._type
241 | end
242 |
243 | tensorCache = tensorCache or {}
244 |
245 | local function applyTypeToTable(table)
246 | for key, value in pairs(table) do
247 | table[key] = recursiveType(table[key], type)
248 | end
249 | end
250 |
251 | -- Convert any stored data in self, and in the in and out nodes
252 | applyTypeToTable(self)
253 | if self.innode then applyTypeToTable(self.innode.data) end
254 | if self.outnode then applyTypeToTable(self.outnode.data) end
255 |
256 | -- Loop through modules and convert data
257 | for _, m in ipairs(self.modules) do
258 | m:type(type, tensorCache)
259 | end
260 |
261 | for i,node in ipairs(self.backwardnodes) do
262 | if node.data.gradOutputBuffer ~= nil then
263 | node.data.gradOutputBuffer =
264 | recursiveType(node.data.gradOutputBuffer, type)
265 | end
266 | for k, child in ipairs(node.children) do
267 | applyTypeToTable(child.data)
268 | end
269 | end
270 |
271 | for i,node in ipairs(self.forwardnodes) do
272 | if node.data.input ~= nil then
273 | node.data.input = recursiveType(node.data.input, type)
274 | end
275 | for k, child in ipairs(node.children) do
276 | applyTypeToTable(child.data)
277 | end
278 | end
279 |
280 | self._type = type
281 | return self
282 | end
283 |
284 | function gModule:updateOutput(input)
285 | return self:runForwardFunction('updateOutput',input)
286 | end
287 |
288 | function gModule:clearState()
289 | local ret = parent.clearState(self)
290 | for _,node in ipairs(self.backwardnodes) do
291 | node.data.gradOutput = nil
292 | node.data.gradOutputBuffer = nil
293 | end
294 | for _,node in ipairs(self.forwardnodes) do
295 | node.data.input = nil
296 | end
297 | return ret
298 | end
299 |
300 | function gModule:runForwardFunction(func,input)
301 | if type(func) == "string" then
302 | local func_name = func
303 | func = function(module,input) return module[func_name](module,input) end
304 | end
305 | -- For backward compatibility, we allow self.nInputs to be missing.
306 | local nInputs = self.nInputs or #self.innode.children
307 | -- We treat the input as a list of inputs.
308 | if nInputs <= 1 then
309 | input={input}
310 | elseif type(input) ~= "table" then
311 | error(string.format("expecting table of %s inputs", nInputs))
312 | end
313 | local function neteval(node)
314 | local function propagate(node,x)
315 | for i,child in ipairs(node.children) do
316 | child.data.input = child.data.input or {}
317 | local mapindex = child.data.mapindex[node.data]
318 | assert(not child.data.input[mapindex], "each input should have one source")
319 | child.data.input[mapindex] = x
320 | end
321 | end
322 | if node.data.selectindex then
323 | assert(not node.data.module, "the selectindex-handling nodes should have no module")
324 | local input = node.data.input
325 | assert(#input == 1, "only the split node should be the input")
326 | assert(istable(input[1]), "the input for a split should be a table")
327 | input = input[1][node.data.selectindex]
328 | propagate(node,input)
329 | else
330 | local input = node.data.input
331 |
332 | -- a parameter node is captured
333 | if input == nil and node.data.module ~= nil then
334 | input = {}
335 | end
336 | if #input == 1 then
337 | input = input[1]
338 | end
339 | -- forward through this node
340 | -- If no module is present, the node behaves like nn.Identity.
341 | local output
342 | if not node.data.module then
343 | output = input
344 | else
345 | output = func(node.data.module,input)
346 | end
347 | if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
348 | error(string.format("split(%s) cannot split %s outputs",
349 | node.data.nSplitOutputs,
350 | #output))
351 | end
352 | -- propagate the output to children
353 | propagate(node,output)
354 | end
355 | if self.verbose then
356 | print(' V : ' .. node:label())
357 | end
358 | end
359 |
360 | local innode = self.innode
361 | if #input ~= nInputs then
362 | error(string.format('Got %s inputs instead of %s', #input, nInputs))
363 | end
364 | -- first clear the input states
365 | for _,node in ipairs(self.forwardnodes) do
366 | local input = node.data.input
367 | while input and #input>0 do
368 | table.remove(input)
369 | end
370 | end
371 | -- Set the starting input.
372 | -- We copy instead of modifying the passed input.
373 | innode.data.input = innode.data.input or {}
374 | for i, item in ipairs(input) do
375 | innode.data.input[i] = item
376 | end
377 |
378 | -- run the forward pass
379 | for i,node in ipairs(self.forwardnodes) do
380 | neteval(node)
381 | end
382 |
383 | self.output = self.outnode.data.input
384 | if #self.outnode.children == 1 then
385 | self.output = self.output[1]
386 | end
387 | return self.output
388 | end
389 |
390 | function gModule:updateGradInput(input,gradOutput)
391 | local function neteval(node)
392 | if node.data.selectindex then
393 | assert(not node.data.module, "the selectindex-handling nodes should have no module")
394 | assert(#node.children == 1, "only the split node should be the input")
395 | local child = node.children[1]
396 | local go = getTotalGradOutput(node)
397 | child.data.gradOutput = child.data.gradOutput or {}
398 | assert(#child.data.gradOutput <= 1, "the split node should be used only once")
399 | -- The data.gradOutput holds the to-be-summed gradients.
400 | child.data.gradOutput[1] = child.data.gradOutput[1] or {}
401 | assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
402 | child.data.gradOutput[1][node.data.selectindex] = go
403 | else
404 | local gradOutput = getTotalGradOutput(node)
405 | -- updateGradInput through this node
406 | -- If no module is present, the node behaves like nn.Identity.
407 | local gradInput
408 | if not node.data.module then
409 | gradInput = gradOutput
410 | else
411 | local input = node.data.input
412 | -- a parameter node is captured
413 | if input == nil and node.data.module ~= nil then
414 | input = {}
415 | end
416 | if #input == 1 then
417 | input = input[1]
418 | end
419 | local module = node.data.module
420 | gradInput = module:updateGradInput(input,gradOutput)
421 | end
422 | -- propagate the output to children
423 | for i,child in ipairs(node.children) do
424 | child.data.gradOutput = child.data.gradOutput or {}
425 | local mapindex = node.data.mapindex[child.data]
426 | local gi
427 | if #node.children == 1 then
428 | gi = gradInput
429 | else
430 | gi = gradInput[mapindex]
431 | end
432 | table.insert(child.data.gradOutput,gi)
433 | end
434 | end
435 | if self.verbose then
436 | print(' V : ' .. node:label())
437 | end
438 | end
439 | local outnode = self.outnode
440 | if #outnode.children > 1 and #gradOutput ~= #outnode.children then
441 | error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
442 | end
443 | for _,node in ipairs(self.backwardnodes) do
444 | local gradOutput = node.data.gradOutput
445 | while gradOutput and #gradOutput >0 do
446 | table.remove(gradOutput)
447 | end
448 | end
449 | -- Set the starting gradOutput.
450 | outnode.data.gradOutput = outnode.data.gradOutput or {}
451 | outnode.data.gradOutput[1] = gradOutput
452 |
453 | for i,node in ipairs(self.backwardnodes) do
454 | neteval(node)
455 | end
456 |
457 | assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
458 | self.gradInput = self.innode.data.gradOutput[1]
459 | return self.gradInput
460 | end
461 |
462 | function gModule:accGradParameters(input,gradOutput,lr)
463 | local function neteval(node)
464 | if node.data.module then
465 | local module = node.data.module
466 | local gradOutput = node.data.gradOutput[1]
467 | if #node.data.gradOutput > 1 then
468 | gradOutput = node.data.gradOutputBuffer
469 | end
470 | local input = node.data.input
471 | -- a parameter node is captured
472 | if input == nil and node.data.module ~= nil then
473 | input = {}
474 | end
475 | if #input == 1 then
476 | input = input[1]
477 | end
478 | -- accGradParameters through this node
479 | module:accGradParameters(input,gradOutput,lr)
480 | end
481 | if self.verbose then
482 | print(' V : ' .. node:label())
483 | end
484 | end
485 | local outnode = self.outnode
486 | if #outnode.children > 1 and #gradOutput ~= #outnode.children then
487 | error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
488 | end
489 | for i,node in ipairs(self.backwardnodes) do
490 | neteval(node)
491 | end
492 | end
493 |
494 | function gModule:read(file)
495 | local data = file:readObject()
496 | for k, v in pairs(data) do
497 | self[k] = v
498 | end
499 |
500 | -- Initialize the modules table if necessary.
501 | if not self.modules then
502 | self.modules = {}
503 | for _, node in ipairs(self.forwardnodes) do
504 | if node.data.module then
505 | table.insert(self.modules, node.data.module)
506 | end
507 | end
508 | end
509 | end
510 |
511 |
512 | function gModule:__tostring__()
513 | return self.name or torch.type(self)
514 | end
515 |
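516 | -- Usage sketch: updateOutput above delegates to runForwardFunction, which can
517 | -- also evaluate an arbitrary per-module function over the forward graph, e.g.
518 | --   gmod:runForwardFunction('updateOutput', input) -- same as gmod:forward(input)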
--------------------------------------------------------------------------------
/graphinspecting.lua:
--------------------------------------------------------------------------------
1 |
2 | -- The findCurrentNode() depends on the names of the
3 | -- local variables in the nngraph.gModule source code.
4 | local function findCurrentNode()
5 | for level = 2, math.huge do
6 | local info = debug.getinfo(level, "n")
7 | if info == nil then
8 | return nil
9 | end
10 |
11 | local funcName = info.name
12 | if funcName == "neteval" then
13 | local varName, node = debug.getlocal(level, 1)
14 | if varName == "node" then
15 | return node
16 | end
17 | end
18 | end
19 | end
20 |
21 | -- Runs the func and calls onError(failedNode, ...) on an error.
22 | -- The stack trace is inspected to find the failedNode.
23 | local function runChecked(func, onError, ...)
24 | -- The current node needs to be searched for before unwinding the stack.
25 | local failedNode
26 | local function errorHandler(message)
27 | -- The stack traceback is added only if not already present.
28 | if not string.find(message, 'stack traceback:\n', 1, true) then
29 | message = debug.traceback(message, 2)
30 | end
31 | failedNode = findCurrentNode()
32 | return message
33 | end
34 |
35 | local ok, result = xpcall(func, errorHandler)
36 | if ok then
37 | return result
38 | end
39 |
40 | onError(failedNode, ...)
41 | -- Passing level 0 avoids adding additional error position info
42 | -- to the message.
43 | error(result, 0)
44 | end
45 |
46 | local function customToDot(graph, title, failedNode)
47 | local str = graph:todot(title)
48 | if not failedNode then
49 | return str
50 | end
51 |
52 | local failedNodeId = nil
53 | for i, node in ipairs(graph.nodes) do
54 | if node.data == failedNode.data then
55 | failedNodeId = node.id
56 | break
57 | end
58 | end
59 |
60 | if failedNodeId ~= nil then
61 | -- The closing '}' is removed.
62 | -- And red fillcolor is specified for the failedNode.
63 | str = string.gsub(str, '}%s*$', '')
64 | str = str .. string.format('n%s[style=filled, fillcolor=red];\n}',
65 | failedNodeId)
66 | end
67 | return str
68 | end
69 |
70 | local function saveSvg(svgPathPrefix, dotStr)
71 | io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
72 | local dotPath = svgPathPrefix .. '.dot'
73 | local dotFile = io.open(dotPath, 'w')
74 | dotFile:write(dotStr)
75 | dotFile:close()
76 |
77 | local svgPath = svgPathPrefix .. '.svg'
78 | local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath)
79 | os.execute(cmd)
80 | end
81 |
82 | local function onError(failedNode, gmodule)
83 | local nInputs = gmodule.nInputs or #gmodule.innode.children
84 | local svgPathPrefix = gmodule.name or string.format(
85 | 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children)
86 | if paths.filep(svgPathPrefix .. '.svg') then
87 | svgPathPrefix = svgPathPrefix .. '_' .. paths.basename(os.tmpname())
88 | end
89 | local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode)
90 | saveSvg(svgPathPrefix, dotStr)
91 | end
92 |
93 | local origFuncs = {
94 | runForwardFunction = nn.gModule.runForwardFunction,
95 | updateGradInput = nn.gModule.updateGradInput,
96 | accGradParameters = nn.gModule.accGradParameters,
97 | }
98 |
99 | -- When debug is enabled,
100 | -- a file named gmodule.name .. '.svg' will be saved
101 | -- if an exception occurs in a graph execution.
102 | -- The problematic node will be marked in red.
103 | function nngraph.setDebug(enable)
104 | if not enable then
105 | -- When debug is disabled,
106 | -- the origFuncs are restored on nn.gModule.
107 | for funcName, origFunc in pairs(origFuncs) do
108 | nn.gModule[funcName] = origFunc
109 | end
110 | return
111 | end
112 |
113 | for funcName, origFunc in pairs(origFuncs) do
114 | nn.gModule[funcName] = function(...)
115 | local args = {...}
116 | local gmodule = args[1]
117 | local unpack = unpack or table.unpack
118 | return runChecked(function()
119 | return origFunc(unpack(args))
120 | end, onError, gmodule)
121 | end
122 | end
123 | end
124 |
125 | -- Sets node.data.annotations.name for the found nodes.
126 | -- The local variables at the given stack level are inspected.
127 | -- The default stack level is 1 (the function that called annotateNodes()).
128 | function nngraph.annotateNodes(stackLevel)
129 | stackLevel = stackLevel or 1
130 | for index = 1, math.huge do
131 | local varName, varValue = debug.getlocal(stackLevel + 1, index)
132 | if not varName then
133 | break
134 | end
135 | if torch.typename(varValue) == "nngraph.Node" then
136 | -- An explicit name is preserved.
137 | if not varValue.data.annotations.name then
138 | varValue:annotate({name = varName})
139 | end
140 | end
141 | end
142 | end
143 |
144 | --[[
145 | SVG visualization for gmodule
146 | TODO: add custom coloring with node types
147 | ]]
148 | function nngraph.display(gmodule)
149 | local ffi = require 'ffi'
150 | local cmd
151 | if ffi.os == 'Linux' then
152 | cmd = 'xdg-open'
153 | elseif ffi.os == 'OSX' then
154 | cmd = 'open -a Safari'
155 | end
156 | local fname = os.tmpname()
157 | graph.dot(gmodule.fg, fname, fname)
158 | os.execute(cmd .. ' ' .. fname .. '.svg')
159 | end
160 |
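161 | -- Usage sketch (see the README "Debugging" section for a full example):
162 | --   nngraph.setDebug(true)  -- save an annotated svg when a graph errors
163 | --   nngraph.annotateNodes() -- call inside the function that builds the nodes
164 | --   nngraph.display(gmod)   -- render gmod.fg as an svg and open it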
--------------------------------------------------------------------------------
/init.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'graph'
3 |
4 | nngraph = {}
5 |
6 | require('nngraph.nest')
7 | require('nngraph.node')
8 | require('nngraph.gmodule')
9 | require('nngraph.graphinspecting')
10 | require('nngraph.JustElement')
11 | require('nngraph.JustTable')
12 | require('nngraph.ModuleFromCriterion')
13 |
14 | -- handy functions
15 | local utils = require('nngraph.utils')
16 | local istensor = torch.isTensor
17 | local istable = utils.istable
18 | local istorchclass = utils.istorchclass
19 |
20 | -- simpler todot functions
21 | nngraph.simple_print = require('nngraph.simple_print')
22 |
23 | -- Modify the __call function to hack into nn.Module
24 | local Module = torch.getmetatable('nn.Module')
25 | function Module:__call__(...)
26 | local nArgs = select("#", ...)
27 | assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
28 |
29 | local input = ...
30 | if nArgs == 1 and input == nil then
31 | error(utils.expectingNodeErrorMessage(input, 'inputs', 1))
32 | end
33 | -- Disallow passing empty table, in case someone passes a table with some
34 | -- typo'd variable name in.
35 | if type(input) == 'table' and next(input) == nil then
36 | error('cannot pass an empty table of inputs. To indicate no incoming ' ..
37 | 'connections, leave the second set of parens blank.')
38 | end
39 | if not istable(input) then
40 | input = {input}
41 | end
42 | local mnode = nngraph.Node({module=self})
43 |
44 | local dnode
45 | for i = 1, utils.tableMaxN(input) do
46 | dnode = input[i]
47 | if torch.typename(dnode) ~= 'nngraph.Node' then
48 | error(utils.expectingNodeErrorMessage(dnode, 'inputs', i))
49 | end
50 | mnode:add(dnode,true)
51 | end
52 |
53 | return mnode
54 | end
55 |
56 | local Criterion = torch.getmetatable('nn.Criterion')
57 | function Criterion:__call__(...)
58 | return nn.ModuleFromCriterion(self)(...)
59 | end
60 |
61 |
62 |
63 |
64 | Module.__unm__ = function( obj )
65 | return obj()
66 | end
67 |
68 | Module.__sub__ = function( prev, next )
69 | return next(prev)
70 | end
71 |
72 |
73 | do
74 | local Node = torch.getmetatable('nngraph.Node')
75 | Node.__sub__ = function( prev, next )
76 | return next(prev)
77 | end
78 | end
79 |
80 | return nngraph
81 |
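82 | -- Usage sketch: the operators installed above provide three shorthands
83 | -- for connecting nodes:
84 | --   h = nn.Tanh()(x)  -- __call__: connect node x as the input
85 | --   h0 = - nn.Tanh()  -- __unm__: a node with no declared inputs yet
86 | --   h = x - nn.Tanh() -- __sub__: same as nn.Tanh()(x)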
--------------------------------------------------------------------------------
/nest.lua:
--------------------------------------------------------------------------------
1 |
2 | local function isNode(input)
3 | local typename = torch.typename(input)
4 | return typename and typename == 'nngraph.Node'
5 | end
6 |
7 | local function isNonEmptyList(input)
8 | return type(input) == "table" and #input > 0
9 | end
10 |
11 | local function _nest(input)
12 | if not isNode(input) and not isNonEmptyList(input) then
13 | error('what is this in the nest input? ' .. tostring(input))
14 | end
15 |
16 | if isNode(input) then
17 | return input
18 | end
19 |
20 | if #input == 1 then
21 | return nngraph.JustTable()(input)
22 | end
23 |
24 | local wrappedChildren = {}
25 | for i, child in ipairs(input) do
26 | wrappedChildren[i] = _nest(child)
27 | end
28 | return nn.Identity()(wrappedChildren)
29 | end
30 |
31 | -- Returns a nngraph node to represent a nested structure.
32 | -- Usage example:
33 | -- local in1 = nn.Identity()()
34 | -- local in2 = nn.Identity()()
35 | -- local in3 = nn.Identity()()
36 | -- local ok = nn.CAddTable()(nngraph.nest({in1}))
37 | -- local in1Again = nngraph.nest(in1)
38 | -- local state = nngraph.nest({in1, {in2}, in3})
39 | function nngraph.nest(...)
40 | local nArgs = select("#", ...)
41 | assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
42 |
43 | local input = ...
44 | assert(nArgs > 0 and input ~= nil, 'Pass an input.')
45 | return _nest(input)
46 | end
47 |
--------------------------------------------------------------------------------
/nesting.lua:
--------------------------------------------------------------------------------
1 |
2 | local nesting = {}
3 |
4 | local utils = require('nngraph.utils')
5 |
6 | -- Creates a clone of a tensor or of a table with tensors.
7 | function nesting.cloneNested(obj)
8 | if torch.isTensor(obj) then
9 | return obj:clone()
10 | end
11 |
12 | local result = {}
13 | for key, child in pairs(obj) do
14 | result[key] = nesting.cloneNested(child)
15 | end
16 | return result
17 | end
18 |
19 | -- Fills the obj with the given value.
20 | -- The obj can be a tensor or a table with tensors.
21 | function nesting.fillNested(obj, value)
22 | if torch.isTensor(obj) then
23 | obj:fill(value)
24 | else
25 | for key, child in pairs(obj) do
26 | nesting.fillNested(child, value)
27 | end
28 | end
29 | end
30 |
31 | -- Resizes all tensors in the output.
32 | function nesting.resizeNestedAs(output, input)
33 | if torch.isTensor(output) then
34 | output:resizeAs(input)
35 | else
36 | for key, child in pairs(input) do
37 | -- A new element is added to the output, if needed.
38 | if not output[key] then
39 | output[key] = nesting.cloneNested(child)
40 | else
41 | nesting.resizeNestedAs(output[key], child)
42 | end
43 | end
44 | -- Extra elements are removed from the output.
45 | for key, child in pairs(output) do
46 | if not input[key] then
47 | output[key] = nil
48 | end
49 | end
50 | end
51 | end
52 |
53 | -- Copies all tensors from the input into the output.
54 | function nesting.copyNested(output, input)
55 | if torch.isTensor(output) then
56 | output:copy(input)
57 | else
58 | for key, child in pairs(input) do
59 | nesting.copyNested(output[key], child)
60 | end
61 | -- Extra elements in the output table cause an error.
62 | for key, child in pairs(output) do
63 | if not input[key] then
64 | error('key ' .. tostring(key) ..
65 | ' present in output but not in input')
66 | end
67 | end
68 | end
69 | end
70 |
71 | -- Adds the input to the output.
72 | -- The input can contain nested tables.
73 | -- The output will contain the same nesting of tables.
74 | function nesting.addNestedTo(output, input)
75 | if torch.isTensor(output) then
76 | output:add(input)
77 | else
78 | for key, child in pairs(input) do
79 | assert(output[key] ~= nil, "missing key")
80 | nesting.addNestedTo(output[key], child)
81 | end
82 | end
83 | end
84 |
85 | return nesting
86 |
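87 | -- Usage sketch (illustrative names): the helpers accept a tensor or a nested
88 | -- table of tensors, e.g.
89 | --   nesting.addNestedTo({a, {b}}, {da, {db}}) -- performs a:add(da); b:add(db)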
--------------------------------------------------------------------------------
/nngraph-scm-1.rockspec:
--------------------------------------------------------------------------------
1 | package = "nngraph"
2 | version = "scm-1"
3 |
4 | source = {
5 | url = "git://github.com/torch/nngraph",
6 | tag = "master"
7 | }
8 |
9 | description = {
10 | summary = "This package provides graphical computation for the nn library in Torch7.",
11 | homepage = "https://github.com/torch/nngraph",
12 | license = "UNKNOWN"
13 | }
14 |
15 | dependencies = {
16 | "torch >= 7.0",
17 | "graph",
18 | "nn"
19 | }
20 |
21 | build = {
22 | type = "command",
23 | build_command = [[
24 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE)
25 | ]],
26 | install_command = "cd build && $(MAKE) install"
27 | }
28 |
--------------------------------------------------------------------------------
/node.lua:
--------------------------------------------------------------------------------
1 |
2 | local utils = require('nngraph.utils')
3 | local istensor = torch.isTensor
4 | local istable = utils.istable
5 | local istorchclass = utils.istorchclass
6 | require 'debug'
7 |
8 | local nnNode,parent = torch.class('nngraph.Node','graph.Node')
9 |
10 | function nnNode:__init(data)
11 | parent.__init(self,data)
12 | self.data.annotations = self.data.annotations or {}
13 | self.data.mapindex = self.data.mapindex or {}
14 | self.data.reverseMap = self.data.reverseMap or {}
15 | if not self.data.annotations._debugLabel then
16 | self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
17 | end
18 | end
19 |
20 | --[[ Build a string label which will be used as a tooltip when
21 | making a graph.]]
22 | function nnNode:_makeDebugLabel(dinfo)
23 | if dinfo then
24 | self.data.annotations._debugLabel = string.format('[%s]:%d_%s',
25 | dinfo.short_src,
26 | dinfo.currentline,
27 | dinfo.name or '')
28 | end
29 | end
30 |
31 | -- domap ensures that this node will keep track of the order in which its children are added.
32 | -- mapindex is a forward/backward list
33 | -- index = self.data.mapindex[child.data]
34 | -- child.data = self.data.mapindex[index]
35 | function nnNode:add(child,domap)
36 | parent.add(self,child)
37 | if domap then
38 | local mapindex = self.data.mapindex
39 | local data = child.data
40 | assert(not mapindex[data], "Don't pass the same input twice.")
41 | table.insert(mapindex,data)
42 | mapindex[data] = #mapindex
43 |
44 | -- The "child" that is added here actually represents the input node,
45 | -- so we write into that node to indicate that we are downstream of it.
46 | -- This enables dangling pointer detection.
47 | local revMap = child.data.reverseMap
48 | assert(not revMap[self], 'this connection has already been made!')
49 | revMap[self] = true
50 | end
51 | end
52 |
53 | -- This function returns noutput new nodes; each takes a single
54 | -- component of the output of this node, in the order in which
55 | -- they are returned.
56 | function nnNode:split(noutput)
57 | if noutput == 1 then
58 | return nngraph.JustElement()(self)
59 | end
60 | local debugLabel = self.data.annotations._debugLabel
61 | -- Specify the source location where :split is called.
62 | local dinfo = debug.getinfo(2, 'Sl')
63 | local splitLoc = string.format(' split at [%s]:%d',
64 | dinfo.short_src,
65 | dinfo.currentline)
66 | local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}})
67 | mnode:add(self,true)
68 |
69 | local selectnodes = {}
70 | for i=1,noutput do
71 | local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
72 | node:add(mnode,true)
73 | table.insert(selectnodes,node)
74 | end
75 |
76 | local unpack = unpack or table.unpack -- Lua52 compat
77 | return unpack(selectnodes)
78 | end
79 |
80 | function nnNode:annotate(annotations)
81 | for k, v in pairs(annotations) do
82 | self.data.annotations[k] = v
83 | end
84 |
85 | return self
86 | end
87 |
88 | function nnNode:graphNodeName()
89 | if self.data.annotations.name then
90 | return self.data.annotations.name .. ' (' .. self.id .. ')'
91 | else
92 | return 'Node' .. self.id
93 | end
94 | end
95 |
96 | function nnNode:graphNodeAttributes()
97 | self.data.annotations.graphAttributes =
98 | self.data.annotations.graphAttributes or {}
99 | if not self.data.annotations.graphAttributes.tooltip then
100 | self.data.annotations.graphAttributes.tooltip =
101 | self.data.annotations._debugLabel
102 | end
103 |
104 | return self.data.annotations.graphAttributes
105 | end
106 |
107 | local function getNanFlag(data)
108 | if data:nElement() == 0 then
109 | return ''
110 | end
111 | local isNan = (data:ne(data):sum() > 0)
112 | if isNan then
113 | return 'NaN'
114 | end
115 | if data:max() == math.huge then
116 | return 'inf'
117 | end
118 | if data:min() == -math.huge then
119 | return '-inf'
120 | end
121 | return ''
122 | end
123 |
124 | function nnNode:label()
125 |
126 | local lbl = {}
127 |
128 | local function getstr(data)
129 | if not data then return '' end
130 | if istensor(data) then
131 | local nanFlag = getNanFlag(data)
132 | local tensorType = 'Tensor'
133 | if data:type() ~= torch.Tensor():type() then
134 | tensorType = data:type()
135 | end
136 | return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
137 | elseif istable(data) then
138 | local tstr = {}
139 | for i,v in ipairs(data) do
140 | table.insert(tstr, getstr(v))
141 | end
142 | return '{' .. table.concat(tstr,',') .. '}'
143 | else
144 | return tostring(data):gsub('\n','\\l')
145 | end
146 | end
147 | local function getmapindexstr(mapindex)
148 | local tstr = {}
149 | for i,data in ipairs(mapindex) do
150 | local inputId = 'Node' .. (data.forwardNodeId or '')
151 | table.insert(tstr, inputId)
152 | end
153 | return '{' .. table.concat(tstr,',') .. '}'
154 | end
155 |
156 | for k,v in pairs(self.data) do
157 | local vstr = ''
158 | if k== 'mapindex' then
159 | if #v > 1 then
160 | vstr = getmapindexstr(v)
161 | table.insert(lbl, k .. ' = ' .. vstr)
162 | end
163 | elseif k== 'forwardNodeId' or k== 'annotations' then
164 | -- the forwardNodeId is not displayed in the label.
165 | else
166 | vstr = getstr(v)
167 | table.insert(lbl, k .. ' = ' .. vstr)
168 | end
169 | end
170 |
171 | local desc
172 | if self.data.annotations.description then
173 | desc = 'desc = ' .. self.data.annotations.description .. '\\n'
174 | else
175 | desc = ''
176 | end
177 | return desc .. table.concat(lbl,"\\l")
178 | end
179 |
--------------------------------------------------------------------------------
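A minimal usage sketch of the node API defined above (the sizes and names
here are illustrative, not part of the library):

    require 'nngraph'

    -- :split(2) expands a node producing a 2-element table into two
    -- select-nodes; :annotate attaches metadata that graphNodeName() and
    -- graphNodeAttributes() surface in graph.dot labels and tooltips.
    local x = nn.Identity()()
    local state = nn.Identity()()
    local h, cell = state:split(2)
    local out = nn.CAddTable()({nn.Linear(10, 10)(x), h, cell})
        :annotate({name = 'Sum', description = 'Wx + h + cell',
                   graphAttributes = {color = 'blue'}})
    local net = nn.gModule({x, state}, {out})
    net:forward({torch.randn(10), {torch.randn(10), torch.randn(10)}})
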
/simple_print.lua:
--------------------------------------------------------------------------------
1 | local function removeNodeFromEdges(node_id, edges)
2 | local from_nodes = {}
3 | local to_nodes = {}
4 | -- remove edges
5 | local idx = 1
6 | while idx <= #edges do
7 | local edge = edges[idx]
8 | if edge.source == node_id then
9 | local to_node = edges[idx].target
10 | table.insert(to_nodes, to_node)
11 | table.remove(edges, idx)
12 | elseif edge.target == node_id then
13 | local from_node = edges[idx].source
14 | table.insert(from_nodes, from_node)
15 | table.remove(edges, idx)
16 | else
17 | idx = idx + 1
18 | end
19 | end
20 |
21 | -- add new edges
22 | for _, f in pairs(from_nodes) do
23 | for _, t in pairs(to_nodes) do
24 | local edge = {source = f, target= t}
25 | table.insert(edges, edge)
26 | end
27 | end
28 |
29 | return edges
30 | end
31 |
32 | local function isNodeGood(node)
33 | return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity')
34 | end
35 |
36 | local function reIndexNodes(nodes, edges)
37 | -- make reverse map
38 | local rev_map = {}
39 | for idx = 1, #nodes do
40 | rev_map[nodes[idx].id] = idx
41 | nodes[idx].id = idx
42 | end
43 | for idx = 1, #edges do
44 | local edge = edges[idx]
45 | edge.source = rev_map[edge.source]
46 | edge.target = rev_map[edge.target]
47 | end
48 | return nodes, edges
49 | end
50 |
51 | local function cleanGraph(nodes, edges)
52 | local idx = 1
53 | while idx <= #nodes do
54 | local node = nodes[idx]
55 | if isNodeGood(node.orig_node) then
56 | idx = idx + 1
57 | else
58 | local id = node.id
59 | table.remove(nodes, idx)
60 | edges = removeNodeFromEdges(id, edges)
61 | end
62 | end
63 | return reIndexNodes(nodes, edges)
64 | end
65 |
66 | local function loadGraph(graph)
67 | local nodes = {}
68 | local edges = {}
69 | for _, node in ipairs(graph.nodes) do
70 | local idx = node.id
71 | table.insert(nodes, {id=idx, orig_node = node} )
72 | for ich = 1, #node.children do
73 | table.insert( edges, {source = idx, target = node.children[ich].id})
74 | end
75 | end
76 | nodes, edges = cleanGraph(nodes, edges)
77 | return nodes , edges
78 | end
79 |
80 | local M = {}
81 |
82 | function M.todot(graph, title)
83 | local nodes, edges = loadGraph(graph)
84 | local str = {}
85 | table.insert(str,'digraph G {\n')
86 | if title then
87 | table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
88 | end
89 | table.insert(str,'node [shape = oval]; ')
90 | local nodelabels = {}
91 | for i,node in ipairs(nodes) do
92 | local true_node = node.orig_node
93 | local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"'
94 | nodelabels[i] = 'n' .. true_node.id
95 | table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];')
96 | end
97 | table.insert(str,'\n')
98 | for i,edge in ipairs(edges) do
99 | table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n')
100 | end
101 | table.insert(str,'}')
102 | return table.concat(str,'')
103 | end
104 |
105 | function M.dot(g,title,fname)
106 | local gv = M.todot(g, title)
107 | local fngv = (fname or os.tmpname()) .. '.dot'
108 | local fgv = io.open(fngv,'w')
109 | fgv:write(gv)
110 | fgv:close()
111 | local fnsvg = (fname or os.tmpname()) .. '.svg'
112 | os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv)
113 | if not fname then
114 | require 'qtsvg'
115 | local qs = qt.QSvgWidget(fnsvg)
116 | qs:show()
117 | os.remove(fngv)
118 | os.remove(fnsvg)
119 |     -- print(fngv, fnsvg)
120 | return qs
121 | end
122 | end
123 |
124 | return M
125 |
--------------------------------------------------------------------------------
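A usage sketch for simple_print; on a stock install this module is exposed
as nngraph.simple_print (t8() in test/test_old.lua below uses it the same
way):

    require 'nngraph'

    local in1 = nn.Identity()()
    local out = nn.Tanh()(nn.Linear(4, 4)(in1))
    local mod = nn.gModule({in1}, {out})

    -- todot returns the DOT source as a string; cleanGraph first strips
    -- nn.Identity nodes, so the printed graph is smaller than mod.fg.
    print(nngraph.simple_print.todot(mod.fg, 'example'))
    -- M.dot additionally writes <fname>.dot and <fname>.svg via graphviz:
    -- nngraph.simple_print.dot(mod.fg, 'example', '/tmp/example')
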
/test/speed.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'nngraph'
3 |
4 | function time_benchmark(model, input, n)
5 | local forward_timer = torch.Timer():stop():reset()
6 | local backward_timer = torch.Timer():stop():reset()
7 | local total_timer = torch.Timer():stop():reset()
8 | local gradOut
9 | total_timer:resume()
10 | for i = 1, n do
11 | forward_timer:resume()
12 | local out = model:forward(input)
13 | forward_timer:stop()
14 | if not gradOut then
15 | gradOut = torch.rand(out:size())
16 | end
17 | backward_timer:resume()
18 | model:backward(input, gradOut)
19 | backward_timer:stop()
20 | end
21 | total_timer:stop()
22 |
23 | return {forward = forward_timer:time().real,
24 | backward = backward_timer:time().real,
25 | total = total_timer:time().real}
26 | end
27 |
28 | function report_benchmark(result, title)
29 | local nspace = (80-string.len(title))/2
30 |   local report = {string.rep('#', 80),
31 | string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
32 | string.format('Total Time Spent = %.2f s', result.total),
33 | string.format(' Forward Time = %.2f s', result.forward),
34 | string.format(' Backward Time = %.2f s', result.backward),
35 | string.rep('#', 80)
36 | }
37 | print(table.concat(report, '\n'))
38 | return result
39 | end
40 |
41 | function compare_benchmarks(result, base, title)
42 | local nspace = (80-string.len(title))/2
43 |   local report = {string.rep('#', 80),
44 | string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
45 | string.format('Total Time Spent = %.2f %%', result.total/base.total*100),
46 | string.format(' Forward Time = %.2f %%', result.forward/base.forward*100),
47 | string.format(' Backward Time = %.2f %%', result.backward/base.backward*100),
48 | string.rep('#', 80)
49 | }
50 | print(table.concat(report, '\n'))
51 | return result
52 | end
53 |
54 | function get_models(nhidden_layers, ninput, noutput, nhidden)
55 |
56 | local function get_concat_layer(nfrom, nto)
57 | local concat_module = nn.Sequential()
58 | local concat_layer = nn.ConcatTable()
59 | concat_layer:add(nn.Linear(nfrom, nto))
60 | concat_layer:add(nn.Linear(nfrom, nto))
61 | concat_module:add(concat_layer)
62 | concat_module:add(nn.CAddTable())
63 | concat_module:add(nn.ReLU())
64 | return concat_module
65 | end
66 |
67 | -- NN
68 | local nn_model = nn.Sequential()
69 |   nn_model:add(get_concat_layer(ninput, nhidden))
70 |   for i = 1, nhidden_layers do
71 |     nn_model:add(get_concat_layer(nhidden, nhidden))
72 |   end
73 |   nn_model:add(get_concat_layer(nhidden, noutput))
74 |
75 | -- NN graph
76 | local input = nn.Identity()()
77 | local nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(ninput, nhidden)(input),
78 | nn.Linear(ninput, nhidden)(input)}))
79 | for i = 1, nhidden_layers do
80 | nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, nhidden)(nng_model),
81 | nn.Linear(nhidden, nhidden)(nng_model)}))
82 | end
83 | nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, noutput)(nng_model),
84 | nn.Linear(nhidden, noutput)(nng_model)}))
85 |
86 | nng_model = nn.gModule({input},{nng_model})
87 |
88 | return {nn_model = nn_model, nng_model = nng_model}
89 | end
90 |
91 | function get_options(arg)
92 | local cmd = torch.CmdLine()
93 | cmd:text('nngraph benchmarking')
94 | cmd:option('-niter', 10, 'number of iterations of forward/backward for each model')
95 | cmd:option('-nhidden_layers', 10, 'number of hidden layers')
96 | cmd:option('-input_size', 512, 'size of input')
97 | cmd:option('-batch_size', 16, 'size of batch')
98 | cmd:option('-hidden_size', 1024, 'size of hidden layer')
99 | cmd:option('-output_size', 128, 'size of output layer')
100 | local opt = cmd:parse(arg)
101 | return opt
102 | end
103 |
104 | local opt = get_options(arg)
105 | local models = get_models(opt.nhidden_layers, opt.input_size, opt.output_size, opt.hidden_size)
106 | print(opt)
107 |
108 | local nn_bench = report_benchmark(time_benchmark(models.nn_model, torch.rand(opt.batch_size,opt.input_size), opt.niter), 'NN')
109 | local nng_bench = report_benchmark(time_benchmark(models.nng_model, torch.rand(opt.batch_size,opt.input_size), opt.niter), 'NNGRAPH')
110 | compare_benchmarks(nng_bench, nn_bench, 'NNGRAPH / NN (%)')
111 |
112 |
--------------------------------------------------------------------------------
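The benchmark helpers above are plain globals, so once the definitions in
test/speed.lua are in scope they can be reused on their own; a sketch with
an arbitrary small model:

    local tiny = nn.Sequential():add(nn.Linear(8, 8)):add(nn.ReLU())
    local stats = time_benchmark(tiny, torch.rand(4, 8), 100)
    report_benchmark(stats, 'TINY MLP') -- prints the 80-column banner
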
/test/test_JustElement.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'totem'
3 | require 'nngraph'
4 | local test = {}
5 | local tester = totem.Tester()
6 |
7 | function test.test_output()
8 | local input = {torch.randn(7, 11)}
9 | local module = nngraph.JustElement()
10 | tester:eq(module:forward(input), input[1], "output")
11 | end
12 |
13 | function test.test_grad()
14 | local input = {torch.randn(7, 11)}
15 | local module = nngraph.JustElement()
16 | totem.nn.checkGradients(tester, module, input)
17 | end
18 |
19 | function test.test_split()
20 | local in1 = nn.Identity()()
21 | local output = in1:split(1)
22 | local net = nn.gModule({in1}, {output})
23 |
24 | local input = {torch.randn(7, 11)}
25 | tester:eq(net:forward(input), input[1], "output of split(1)")
26 | end
27 |
28 | tester:add(test):run()
29 |
--------------------------------------------------------------------------------
/test/test_JustTable.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'totem'
3 | require 'nngraph'
4 | local test = {}
5 | local tester = totem.Tester()
6 |
7 | function test.test_output()
8 | local input = torch.randn(7, 11)
9 | local module = nngraph.JustTable()
10 | tester:eq(module:forward(input), {input}, "output")
11 | end
12 |
13 | function test.test_grad()
14 | local input = torch.randn(7, 11)
15 | local module = nngraph.JustTable()
16 | totem.nn.checkGradients(tester, module, input)
17 | end
18 |
19 | tester:add(test):run()
20 |
--------------------------------------------------------------------------------
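JustTable and JustElement are inverse adapters: JustElement maps {x} -> x
(node.lua uses it to implement :split(1)), while JustTable maps x -> {x}.
A quick sketch:

    require 'nngraph'
    local t = torch.ones(2)
    print(nngraph.JustTable():forward(t))     -- {t}
    print(nngraph.JustElement():forward({t})) -- t
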
/test/test_ModuleFromCriterion.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'totem'
3 | require 'nngraph'
4 | local test = {}
5 | local tester = totem.Tester()
6 |
7 | function test.test_call()
8 | local prediction = nn.Identity()()
9 | local target = nn.Identity()()
10 | local mse = nn.MSECriterion()({prediction, target})
11 | local costBits = nn.MulConstant(1/math.log(2))(mse)
12 | local net = nn.gModule({prediction, target}, {costBits})
13 |
14 | local input = {torch.randn(3, 5), torch.rand(3, 5)}
15 | local criterion = nn.MSECriterion()
16 | local output = net:forward(input)
17 | criterion:forward(input[1], input[2])
18 | tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14)
19 |
20 | local gradOutput = torch.randn(1)
21 | local gradInput = net:backward(input, gradOutput)
22 | criterion:backward(input[1], input[2])
23 | tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14)
24 | tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget")
25 | end
26 |
27 | function test.test_grad()
28 | local prediction = nn.Identity()()
29 | local zero = nn.MulConstant(0)(prediction)
30 |   -- The target is created inside the graph so that the zero
31 |   -- gradTarget is ignored by the gradient check.
32 | local target = nn.AddConstant(1.23)(zero)
33 | local mse = nn.MSECriterion()({prediction, target})
34 | local net = nn.gModule({prediction}, {mse})
35 |
36 | local input = torch.randn(4, 7)
37 | totem.nn.checkGradients(tester, net, input)
38 | end
39 |
40 | local function module()
41 | local module = nn.ModuleFromCriterion(nn.MSECriterion())
42 | local input = {torch.randn(3, 5), torch.randn(3, 5)}
43 | return module, input
44 | end
45 |
46 | function test.test_serializable()
47 | local module, input = module()
48 | totem.nn.checkSerializable(tester, module, input)
49 | end
50 |
51 | function test.test_typeCastable()
52 | local module, input = module()
53 | totem.nn.checkTypeCastable(tester, module, input)
54 | end
55 |
56 |
57 | tester:add(test):run()
58 |
--------------------------------------------------------------------------------
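The same wrapper can be used standalone: any nn.Criterion becomes a module
mapping {input, target} to a 1-element loss tensor. A sketch (shapes are
arbitrary):

    require 'nngraph'
    local m = nn.ModuleFromCriterion(nn.MSECriterion())
    local input = {torch.randn(3, 5), torch.randn(3, 5)}
    local loss = m:forward(input)                 -- 1-element tensor
    local grad = m:backward(input, torch.ones(1))
    -- grad[1] is dLoss/dPrediction; grad[2] (the target slot) is zero.
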
/test/test_connectivity.lua:
--------------------------------------------------------------------------------
1 | local totem = require 'totem'
2 | require 'nngraph'
3 | local tests = totem.TestSuite()
4 | local tester = totem.Tester()
5 |
6 | function tests.connectivity()
7 |   -- Store debug info here; debug.getinfo must be called on the same
8 |   -- line where the dangling node is declared.
9 | local dInfo
10 | local input = nn.Identity()()
11 | local lin = nn.Linear(20, 10)(input)
12 | -- The Sigmoid does not connect to the output, so should cause an error
13 | -- when we call gModule.
14 | local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl')
15 | local actualOutput = nn.Tanh()(lin)
16 | local errStr = string.format(
17 | 'node declared on %%[%s%%]:%d_ does not connect to gmodule output',
18 | dInfo.short_src, dInfo.currentline)
19 | tester:assertErrorPattern(
20 | function()
21 | return nn.gModule({input}, {actualOutput})
22 | end,
23 | errStr)
24 | end
25 |
26 | return tester:add(tests):run()
27 |
--------------------------------------------------------------------------------
/test/test_debug.lua:
--------------------------------------------------------------------------------
1 | local totem = require 'totem'
2 | require 'nngraph'
3 | local tests = totem.TestSuite()
4 | local tester = totem.Tester()
5 |
6 | function tests.whatIsThisInTheInput()
7 | tester:assertErrorPattern(
8 | function()
9 | local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens
10 | local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2})
11 | end,
12 | 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
13 | 'only valid thing to pass is an instance of nngraph%.Node')
14 |
15 | tester:assertErrorPattern(
16 | function()
17 | -- pass-through module, again with same mistake
18 | local graphNode, nnModule = nn.Identity()(), nn.Identity()
19 | return nn.gModule({graphNode, nnModule}, {graphNode})
20 | end,
21 | 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
22 | 'only valid thing to pass is an instance of nngraph%.Node')
23 |
24 | tester:assertErrorPattern(
25 | function()
26 | local input = nn.Identity()()
27 | local out1 = nn.Linear(20, 10)(input)
28 | local out2 = nn.Sigmoid()(input)
29 | local unconnectedOut = nn.Linear(2, 3)
30 | return nn.gModule({input}, {out1, out2, unconnectedOut})
31 | end,
32 | 'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' ..
33 | 'only valid thing to pass is an instance of nngraph%.Node')
34 |
35 | -- Check for detecting a nil in the middle of a table.
36 | tester:assertErrorPattern(
37 | function()
38 | local input = nn.Identity()()
39 | local out1 = nn.Tanh()(input)
40 | local out2 = nn.Sigmoid()(input)
41 | -- nil here is simulating a mis-spelt variable name
42 | return nn.gModule({input}, {out1, nil, out2})
43 | end,
44 | 'outputs%[2%] is nil %(typo / bad index%?%)')
45 |
46 | tester:assertErrorPattern(
47 | function()
48 | -- Typo variable name returns nil, meaning an empty table
49 | local input = nn.Identity()({aNonExistentVariable})
50 | end,
51 | 'cannot pass an empty table of inputs%. To indicate no incoming ' ..
52 | 'connections, leave the second set of parens blank%.')
53 | end
54 |
55 | function tests.splitUnused()
56 |   -- debug.getinfo must be called on the same lines as the code under
57 |   -- test, to match what debug.getinfo inside those calls returns.
58 | local dInfoDeclare, dInfoSplit
59 | local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl')
60 | local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl')
61 |
62 | local function willCrash()
63 | return nn.gModule({input}, {output})
64 | end
65 |
66 | -- Work out what strings will be in the error message
67 | local declareLoc = string.format('%%[%s%%]:%d_',
68 | dInfoDeclare.short_src,
69 | dInfoDeclare.currentline)
70 | local splitLoc = string.format('%%[%s%%]:%d',
71 | dInfoSplit.short_src,
72 | dInfoSplit.currentline)
73 |
74 | tester:assertErrorPattern(
75 | willCrash,
76 | '1 of split%(2%) outputs from the node declared at ' ..
77 | declareLoc .. ' split at ' .. splitLoc .. '%-mnode are unused')
78 | end
79 |
80 | tester:add(tests):run()
81 |
--------------------------------------------------------------------------------
/test/test_nest.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'totem'
3 | require 'nngraph'
4 |
5 | local test = {}
6 | local tester = totem.Tester()
7 |
8 | function test.test_output()
9 | local in1 = nn.Identity()()
10 | local in2 = nn.Identity()()
11 | local in3 = nn.Identity()()
12 | local ok = nn.CAddTable()(nngraph.nest({in1}))
13 | local in1Again = nngraph.nest(in1)
14 | local state = nngraph.nest({in1, {in2}, in3})
15 |
16 | local net = nn.gModule(
17 | {in1, in2, in3},
18 | {ok, in1Again, state, nngraph.nest({in3}), nngraph.nest({in1, in2})})
19 |
20 | local val1 = torch.randn(7, 3)
21 | local val2 = torch.randn(2)
22 | local val3 = torch.randn(3)
23 | local expectedOutput = {
24 | val1, val1, {val1, {val2}, val3}, {val3}, {val1, val2},
25 | }
26 | local output = net:forward({val1, val2, val3})
27 | tester:eq(output, expectedOutput, "output")
28 | end
29 |
30 |
31 | return tester:add(test):run()
32 |
33 |
34 |
--------------------------------------------------------------------------------
/test/test_nngraph.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'totem'
3 | require 'nngraph'
4 | local test = {}
5 | local tester = totem.Tester()
6 |
7 | local function checkGradients(...)
8 | totem.nn.checkGradients(tester, ...)
9 | end
10 |
11 | function test.test_oneOutput()
12 | local in1 = nn.Identity()()
13 | local out1 = nn.Identity()(in1)
14 | local module = nn.gModule({in1}, {out1})
15 |
16 | local input = torch.Tensor({1})
17 | module:forward(input)
18 | tester:eq(module.output, torch.Tensor{1}, "output")
19 | local gradInput = module:backward(input, torch.Tensor({-123}))
20 | tester:eq(gradInput, torch.Tensor{-123}, "gradInput")
21 |
22 | local input2 = torch.Tensor({2})
23 | module:forward(input2)
24 | tester:eq(module.output, torch.Tensor{2}, "output for input2")
25 | gradInput = module:backward(input2, torch.Tensor({-2}))
26 | tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput")
27 | end
28 |
29 |
30 | function test.test_twoOutputs()
31 | local in1 = nn.Identity()()
32 | local out1 = nn.Identity()(in1)
33 | local out2 = nn.Identity()(in1)
34 | local module = nn.gModule({in1}, {out1, out2})
35 |
36 | local input = torch.Tensor({1})
37 | module:forward(input)
38 | local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})})
39 | tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
40 | checkGradients(module, input)
41 | end
42 |
43 | function test.test_twoGradOutputs()
44 | local in1 = nn.Sigmoid()()
45 | local splitTable = nn.SplitTable(1)({in1})
46 | local out1, out2 = splitTable:split(2)
47 | local module = nn.gModule({in1}, {out1, out2})
48 |
49 | local input = torch.randn(2, 3)
50 | local output = module:forward(input)
51 | assert(#output == 2, "wrong number of outputs")
52 | module:backward(input, {torch.randn(3), torch.randn(3)})
53 | checkGradients(module, input)
54 | end
55 |
56 | function test.test_twoInputs()
57 | local in1 = nn.Identity()()
58 | local in2 = nn.Identity()()
59 | local prevH, prevCell = in2:split(2)
60 |
61 | local out1 = nn.CMulTable()({in1, prevH, prevCell})
62 | local module = nn.gModule({in1, in2}, {out1})
63 |
64 | local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
65 | module:forward(input)
66 | local gradInput = module:backward(input, torch.randn(3))
67 | assert(#gradInput == 2, "wrong number of gradInputs")
68 | assert(type(gradInput[2]) == "table", "wrong gradInput[2] type")
69 | checkGradients(module, input)
70 | end
71 |
72 | function test.test_twoInputs2()
73 | local in1 = nn.Sigmoid()()
74 | local in2 = nn.Sigmoid()()
75 | local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)})
76 |
77 | local input = {torch.randn(3), torch.randn(3)}
78 | module:forward(input)
79 | local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
80 | checkGradients(module, input)
81 | end
82 |
83 | function test.test_splitDebugLabels()
84 | local node = nn.Identity()()
85 | node.data.annotations._debugLabel = "node"
86 | local node1, node2 = node:split(2)
87 | assert(node1.data.annotations._debugLabel == "node-1")
88 | assert(node2.data.annotations._debugLabel == "node-2")
89 | end
90 |
91 | function test.test_identity()
92 | local in1 = nn.Identity()()
93 | local in2 = nn.Identity()()
94 | local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)})
95 |
96 | local input = {torch.randn(3), torch.randn(3)}
97 | module:forward(input)
98 | module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
99 | checkGradients(module, input)
100 | end
101 |
102 | function test.test_gradInputType()
103 | local xInput = torch.randn(3)
104 | local h = torch.randn(3)
105 |
106 | local x = nn.Identity()()
107 | local prevRnnState = nn.Identity()()
108 | local prevH1, prevCell = prevRnnState:split(2)
109 | local prevH = prevH1
110 |
111 | local cellOut = nn.CAddTable()({
112 | nn.CMulTable()({x, prevH}),
113 | nn.CMulTable()({prevH, prevCell})})
114 | local module = nn.gModule({x, prevRnnState}, {cellOut})
115 |
116 | local c = torch.randn(h:size())
117 | local prevRnnState = {h, c}
118 | local input = {xInput, prevRnnState}
119 | local output = module:forward(input)
120 |
121 | local gradOutput = torch.randn(h:size())
122 | local gradInput = module:backward(input, gradOutput)
123 |
124 | local unpack = unpack or table.unpack
125 | local gradX, gradPrevState = unpack(gradInput)
126 | local gradPrevH, gradPrevCell = unpack(gradPrevState)
127 | assert(type(gradPrevH) == type(h), "wrong gradPrevH type")
128 |
129 | tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type")
130 | tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
131 | checkGradients(module, input)
132 | end
133 |
134 | function test.test_tabularInput()
135 | local in1 = nn.SplitTable(1)()
136 | local out1 = nn.CAddTable()(in1)
137 | local module = nn.gModule({in1}, {out1})
138 |
139 | local input = torch.randn(2, 3)
140 | checkGradients(module, input)
141 | end
142 |
143 | function test.test_extraTable()
144 | local in1 = nn.Identity()()
145 | local out1 = nn.Identity()(in1)
146 | local module = nn.gModule({in1}, {out1})
147 |
148 | local input = torch.Tensor({123})
149 | tester:eq(module:forward(input), input, "simple output")
150 | tester:eq(module:forward({input}), {input}, "tabular output")
151 | end
152 |
153 | function test.test_accGradParameters()
154 | local input = torch.randn(10)
155 |
156 | local in1 = nn.CMul(input:nElement())()
157 | local out1 = nn.Identity()(in1)
158 | local out2 = nn.Identity()(in1)
159 | local module = nn.gModule({in1}, {out1, out2})
160 | checkGradients(module, input)
161 | end
162 |
163 | function test.test_example1()
164 | local x1 = nn.Linear(20,10)()
165 | local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
166 | local mlp = nn.gModule({x1},{mout})
167 |
168 | local x = torch.rand(20)
169 | checkGradients(mlp, x)
170 | end
171 |
172 | function test.test_example2()
173 | local x1=nn.Linear(20,20)()
174 | local x2=nn.Linear(10,10)()
175 | local m0=nn.Linear(20,1)(nn.Tanh()(x1))
176 | local m1=nn.Linear(10,1)(nn.Tanh()(x2))
177 | local madd=nn.CAddTable()({m0,m1})
178 | local m2=nn.Sigmoid()(madd)
179 | local m3=nn.Tanh()(madd)
180 | local gmod = nn.gModule({x1,x2},{m2,m3})
181 |
182 | local x = torch.rand(20)
183 | local y = torch.rand(10)
184 | checkGradients(gmod, {x, y})
185 | end
186 |
187 | function test.test_example3()
188 | local m = nn.Sequential()
189 | m:add(nn.SplitTable(1))
190 | m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
191 | local input = nn.Identity()()
192 | local input1,input2 = m(input):split(2)
193 | local m3 = nn.JoinTable(1)({input1,input2})
194 | local g = nn.gModule({input},{m3})
195 |
196 | local indata = torch.rand(2,10)
197 | checkGradients(g, indata)
198 | end
199 |
200 | function test.test_example4()
201 | local input = nn.Identity()()
202 | local L1 = nn.Tanh()(nn.Linear(1,2)(input))
203 | local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1})))
204 | local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2})))
205 | local g = nn.gModule({input},{L3})
206 |
207 | local indata = torch.rand(1)
208 | checkGradients(g, indata)
209 | end
210 |
211 | function test.test_type()
212 | local in1 = nn.Linear(20,10)()
213 | local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
214 | local module = nn.gModule({in1}, {out1})
215 | local input = torch.rand(20)
216 | local output = module:forward(input)
217 | local gradOutput = output:clone():normal()
218 | local gradInput = module:backward(input, gradOutput)
219 |
220 | module:backward(input, output)
221 | tester:eq(torch.typename(output), "torch.DoubleTensor")
222 | tester:eq(torch.typename(module.output), "torch.DoubleTensor")
223 | tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor")
224 | tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
225 | tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
226 | tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
227 | tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.DoubleTensor")
228 | tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.DoubleTensor")
229 |
230 | module:float()
231 | tester:eq(torch.typename(module.output), "torch.FloatTensor")
232 | tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
233 | tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
234 | tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
235 | tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
236 | tester:eq(torch.typename(module.forwardnodes[1].children[1].data.input[1]), "torch.FloatTensor")
237 | tester:eq(torch.typename(module.backwardnodes[1].children[1].data.gradOutput[1]), "torch.FloatTensor")
238 | local output = module:forward(input:float())
239 | tester:eq(torch.typename(output), "torch.FloatTensor")
240 | local gradInput = module:backward(input:float(), gradOutput:float())
241 | tester:eq(torch.typename(gradInput), "torch.FloatTensor")
242 |
243 | end
244 |
245 | function test.test_nestedGradInput()
246 | local x = nn.Identity()()
247 | local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
248 | local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
249 | local out = nn.CAddTable()({h1(x), h2(x)})
250 |
251 | local model = nn.gModule({x}, {out})
252 |
253 | local input = {}
254 | input[1] = torch.randn(3, 3)
255 | input[2] = torch.randn(3, 3)
256 | input[3] = torch.randn(3, 3)
257 |
258 | checkGradients(model, input)
259 |
260 | local input = {}
261 | input[1] = torch.randn(2, 3)
262 | input[2] = torch.randn(2, 3)
263 | input[3] = torch.randn(2, 3)
264 |
265 | checkGradients(model, input)
266 | end
267 |
268 | function test.test_unusedInput()
269 | local x = nn.Identity()()
270 | local h = nn.Identity()()
271 | local h2 = nn.Identity()()
272 |
273 | local ok, result = pcall(nn.gModule, {x, h}, {x})
274 | assert(not ok, "the unused input should be detected")
275 | end
276 |
277 | function test.test_unusedChild()
278 | local prevState = nn.Identity()()
279 | local h, cell = prevState:split(2)
280 |
281 | local ok, result = pcall(nn.gModule, {prevState}, {h})
282 | assert(not ok, "the unused cell should be detected")
283 | end
284 |
285 | function test.test_nilInput()
286 | local ok, result = pcall(function() nn.Sigmoid()(nil) end)
287 | assert(not ok, "the nil input should be detected")
288 | end
289 |
290 | function test.test_unusedNode()
291 | local in1 = nn.Identity()()
292 | local in2 = nn.Identity()()
293 | local middleResult = nn.Sigmoid()(in2)
294 | local out1 = nn.Sigmoid()(in1)
295 |
296 | local ok, result = pcall(nn.gModule, {in1, in2}, {out1})
297 | assert(not ok, "the unused middleResult should be detected")
298 | end
299 |
300 | function test.test_usageAfterSplit()
301 | local prevState = nn.Identity()()
302 | local h, cell = prevState:split(2)
303 | local nextState = nn.Identity()(prevState)
304 | local transformed = nn.Sigmoid()(cell)
305 |
306 | local model = nn.gModule({prevState}, {h, nextState, transformed})
307 | local nHidden = 10
308 | local input = {torch.randn(nHidden), torch.randn(nHidden)}
309 | checkGradients(model, input)
310 | end
311 |
312 | function test.test_resizeNestedAs()
313 | local in1 = nn.Identity()()
314 | local out1 = nn.Identity()(in1)
315 | local out2 = nn.Identity()(in1)
316 |
317 | local net = nn.gModule({in1}, {out1, out2})
318 | local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
319 | net:forward(input)
320 | net:backward(input, net.output)
321 | checkGradients(net, input)
322 |
323 | input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
324 | net:forward(input)
325 | net:backward(input, net.output)
326 | checkGradients(net, input)
327 |
328 | input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
329 | net:forward(input)
330 | local gradInput = net:backward(input, net.output)
331 | tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
332 | checkGradients(net, input)
333 | end
334 |
335 |
336 | function test.test_annotateGraph()
337 | local input = nn.Identity()():annotate(
338 | {name = 'Input', description = 'DescA',
339 | graphAttributes = {color = 'red'}})
340 |
341 | local hidden_a = nn.Linear(10, 10)(input):annotate(
342 | {name = 'Hidden A', description = 'DescB',
343 | graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
344 | local hidden_b = nn.Sigmoid()(hidden_a)
345 | local output = nn.Linear(10, 10)(hidden_b)
346 | local net = nn.gModule({input}, {output})
347 |
348 | tester:assert(hidden_a:label():match('DescB') ~= nil)
349 | local fg_tmpfile = os.tmpname()
350 | local bg_tmpfile = os.tmpname()
351 | if not pcall(function() graph.dot(net.fg, 'Test', fg_tmpfile) end) then
352 | return -- prevent graphviz not found error
353 | end
354 | graph.dot(net.fg, 'Test BG', bg_tmpfile)
355 |
356 | local function checkDotFile(tmpfile)
357 | local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
358 | tester:assert(
359 | dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]') ~= nil)
360 | tester:assert(
361 | dotcontent:match(
362 | '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]') ~= nil)
363 | tester:assert(
364 | dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]') ~= nil)
365 | tester:assert(
366 | dotcontent:match(
367 | '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]') ~= nil)
368 | end
369 |
370 | checkDotFile(fg_tmpfile)
371 | checkDotFile(bg_tmpfile)
372 | end
373 |
374 | function test.test_splitMore()
375 | local nSplits = 2
376 | local in1 = nn.Identity()()
377 | local out1, out2 = nn.SplitTable(2)(in1):split(nSplits)
378 |
379 | local model = nn.gModule({in1}, {out1, out2})
380 | local input = torch.randn(10, nSplits + 1)
381 | local ok, result = pcall(model.forward, model, input)
382 | assert(not ok, "the extra input to split should be detected")
383 | end
384 |
385 | function test.test_splitLess()
386 | local nSplits = 3
387 | local in1 = nn.Identity()()
388 | local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits)
389 |
390 | local model = nn.gModule({in1}, {out1, out2, out3})
391 | local input = torch.randn(10, nSplits - 1)
392 | local ok, result = pcall(model.forward, model, input)
393 | assert(not ok, "the missing input to split should be detected")
394 | end
395 |
396 | function test.test_gradOutputZeroOptim()
397 |   -- Lua 5.2 moved the global `unpack` to `table.unpack`; use
398 |   -- whichever exists (the same idiom as in node.lua's :split).
399 |   local unpack = unpack or table.unpack
400 |
401 | -- Make module that produces an expanded zero gradInput tensor
402 | local dummyModule = nn.Module()
403 | dummyModule.updateOutput = function(self, input)
404 | self.output = torch.Tensor(1, 2, 3):uniform()
405 | return self.output
406 | end
407 | dummyModule.updateGradInput = function(self, input, gradOutput)
408 | local zeroTensor = torch.Tensor{0}
409 | :view(unpack(torch.ones(input:dim()):totable()))
410 | :expandAs(input)
411 | self.gradInput = zeroTensor
412 | return self.gradInput
413 | end
414 |
415 | -- First input and final gradOutput
416 | local input = torch.Tensor(1, 2, 3):uniform()
417 | local gradOutput = torch.Tensor(1, 2, 3):uniform()
418 |
419 | -- First case: one intermediary gradOutput is going to be zero
420 | local x = nn.Identity()()
421 | local h1 = dummyModule:clone()(x)
422 | local h2 = nn.Identity()(x)
423 | local y = nn.CAddTable()({h1, h2})
424 | local mod = nn.gModule({x}, {y})
425 |
426 | local ok, result = pcall(nn.Module.forward, mod, input)
427 | assert(ok, "forward should succeed")
428 |
429 |   nn.Module.backward(mod, input, gradOutput)
430 | ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
431 | assert(ok, "backward should succeed")
432 |
433 | -- Second case: all intermediary gradOutputs are going to be zero
434 | local x = nn.Identity()()
435 | local h1 = dummyModule:clone()(x)
436 | local h2 = dummyModule:clone()(x)
437 | local y = nn.CAddTable()({h1, h2})
438 | local mod = nn.gModule({x}, {y})
439 |
440 | local ok, result = pcall(nn.Module.forward, mod, input)
441 | assert(ok, "forward should succeed")
442 |
443 | ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
444 | assert(ok, "backward should succeed")
445 | end
446 |
447 | function test.test_replace()
448 | local i = nn.Identity()()
449 | local l1 = nn.Linear(5, 2)(i)
450 | local sig = nn.Sigmoid()(l1)
451 | local l2 = nn.Linear(2, 5)(sig)
452 | local model = nn.gModule({i}, {l2})
453 |
454 | local input = torch.randn(4, 5)
455 | local gradOutput = torch.randn(4, 5)
456 | tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
457 | tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
458 |
459 | model:replace(function(m)
460 | if torch.type(m) == 'nn.Linear' then
461 | if m.weight:size(1) == 5 then
462 | return nn.Linear(2, 10)
463 | elseif m.weight:size(1) == 2 then
464 | return nn.Linear(10, 2)
465 | end
466 | elseif torch.type(m) == 'nn.Sigmoid' then
467 | return nn.Tanh()
468 | end
469 | return m
470 | end)
471 |
472 | local input = torch.randn(4, 10)
473 | local gradOutput = torch.randn(4, 10)
474 | tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
475 | tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
476 |
477 | tester:ne(model.modules[2], l1, "gModule.modules wasn't updated")
478 | tester:ne(model.modules[3], sig, "gModule.modules wasn't updated")
479 | tester:eq(torch.type(model.modules[3]), 'nn.Tanh', "replace didn't update gModule.modules")
480 | tester:ne(model.modules[4], l2, "gModule.modules wasn't updated")
481 | end
482 |
483 | tester:add(test):run()
484 |
--------------------------------------------------------------------------------
/test/test_old.lua:
--------------------------------------------------------------------------------
1 | require 'nngraph'
2 | function t1()
3 | local x1 = nn.Linear(20,20)()
4 | local x2 = nn.Linear(10,10)()
5 | local m0=nn.Linear(20,1)(nn.Tanh()(x1))
6 | local m1=nn.Linear(10,1)(nn.Tanh()(x2))
7 | local madd=nn.CAddTable()({m0,m1})
8 | local m2=nn.Sigmoid()(madd)
9 | local m3=nn.Tanh()(madd)
10 | local x = torch.rand(20)
11 | local y = torch.rand(10)
12 | gmod = nn.gModule({x1,x2},{m2,m3})
13 | gmod.verbose = true
14 | print('forward')
15 | gmod:updateOutput({x,y})
16 | print('updateGradInput')
17 | gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)})
18 | graph.dot(gmod.fg,'forward')
19 | graph.dot(gmod.bg,'backward')
20 | end
21 |
22 | function t2()
23 | print('compare')
24 | local m0 = nn.Linear(5,10)()
25 | local m1 = nn.Linear(10,20)()
26 | local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1})
27 | gmod = nn.gModule({m0,m1},{m2})
28 |
29 | local nn0 = nn.Linear(5,10)
30 | local nn1 = nn.Linear(10,20)
31 | local nn2 = nn.Linear(30,50)
32 | local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2)
33 |
34 | nn0.weight:copy(m0.data.module.weight)
35 | nn0.bias:copy(m0.data.module.bias)
36 | nn1.weight:copy(m1.data.module.weight)
37 | nn1.bias:copy(m1.data.module.bias)
38 | nn2.weight:copy(m2.data.module.weight)
39 | nn2.bias:copy(m2.data.module.bias)
40 |
41 |
42 | for i=1,5 do
43 | local x,y = torch.rand(5),torch.rand(10)
44 | local xx,yy = x:clone(),y:clone()
45 |
46 | gmod:updateOutput({x,y})
47 | nnmod:updateOutput({xx,yy})
48 | print('fdiff = ', torch.dist(gmod.output,nnmod.output))
49 |
50 | local odx = torch.rand(50)
51 | local odxx = odx:clone()
52 |
53 | gmod:updateGradInput({x,y},odx)
54 | nnmod:updateGradInput({xx,yy},odxx)
55 | graph.dot(gmod.fg,tostring(i))
56 | for i,v in ipairs(gmod.gradInput) do
57 | print('bdiff [' ..i.. '] = ', torch.dist(gmod.gradInput[i],nnmod.gradInput[i]))
58 | end
59 | end
60 |
61 | local gms = {m0,m1,m2}
62 | local nms = {nn0,nn1,nn2}
63 |
64 | for i=1,5 do
65 | local x,y = torch.rand(5),torch.rand(10)
66 | local xx,yy = x:clone(),y:clone()
67 |
68 | gmod:updateOutput({x,y})
69 | nnmod:updateOutput({xx,yy})
70 | print('fdiff = ', torch.dist(gmod.output,nnmod.output))
71 |
72 | local odx = torch.rand(50)
73 | local odxx = odx:clone()
74 |
75 | gmod:zeroGradParameters()
76 | nnmod:zeroGradParameters()
77 |
78 | gmod:updateGradInput({x,y},odx)
79 | nnmod:updateGradInput({xx,yy},odxx)
80 |
81 | gmod:accGradParameters({x,y},odx)
82 | nnmod:accGradParameters({xx,yy},odxx)
83 | graph.dot(gmod.fg)
84 | for i,v in ipairs(gms) do
85 | print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradWeight,nms[i].gradWeight))
86 | print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradBias,nms[i].gradBias))
87 | end
88 | end
89 | end
90 |
91 | function t3()
92 | mlp=nn.Sequential(); --Create a network that takes a Tensor as input
93 | mlp:add(nn.SplitTable(2))
94 | c=nn.ParallelTable() --The two Tensors go through two different Linear
95 | c:add(nn.Linear(10,3)) --Layers in Parallel
96 | c:add(nn.Linear(10,7))
97 | mlp:add(c) --Outputting a table with 2 elements
98 | p=nn.ParallelTable() --These tables go through two more linear layers
99 | p:add(nn.Linear(3,2)) -- separately.
100 | p:add(nn.Linear(7,1))
101 | mlp:add(p)
102 | mlp:add(nn.JoinTable(1)) --Finally, the tables are joined together and output.
103 |
104 | pred=mlp:forward(torch.randn(10,2))
105 | print(pred)
106 |
107 | for i=1,25 do -- A few steps of training such a network..
108 | x=torch.ones(10,2);
109 |  y=torch.Tensor(3); y:copy(x:select(2,1):narrow(1,1,3))
110 | pred=mlp:forward(x)
111 |
112 | criterion= nn.MSECriterion()
113 | local err=criterion:forward(pred,y)
114 | local gradCriterion = criterion:backward(pred,y);
115 | print(x,y)
116 | mlp:zeroGradParameters();
117 | mlp:backward(x, gradCriterion);
118 | mlp:updateParameters(0.05);
119 |
120 | print(err)
121 | end
122 | end
123 |
124 | function t4()
125 | local getInput1 = nn.Identity()()
126 | local getInput2 = nn.Identity()()
127 | local mlp = nn.Tanh()(getInput1)
128 | net = nn.gModule({getInput1, getInput2}, {mlp, getInput2})
129 |
130 |
131 | local input1 = torch.randn(2)
132 | local input2 = torch.randn(5)
133 |
134 | net:forward({input1, input2})
135 | local gradInput = net:backward({input1, input2},
136 | {torch.randn(input1:size()), torch.randn(input2:size())})
137 | print("gradInput[1]:", gradInput[1])
138 | print("gradInput[2]:", gradInput[2])
139 | graph.dot(net.fg)
140 | assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch")
141 |
142 | end
143 |
144 | function t5()
145 | local m = nn.Sequential()
146 | m:add(nn.SplitTable(1))
147 | m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
148 | local input = nn.Identity()()
149 | local input1,input2 = m(input):split(2)
150 | local m3 = nn.JoinTable(1)({input1,input2})
151 |
152 | g = nn.gModule({input},{m3})
153 | graph.dot(g.fg,'init forward')
154 |
155 | local indata = torch.rand(2,10)
156 | local gdata = torch.rand(50)
157 | g:forward(indata)
158 | g:backward(indata,gdata)
159 |
160 | graph.dot(g.fg,'forward')
161 | graph.dot(g.bg,'backward')
162 | end
163 |
164 | function topsort(a)
165 | -- first clone the graph
166 | -- local g = self:clone()
167 | -- local nodes = g.nodes
168 | -- local edges = g.edges
169 | -- for i,node in ipairs(nodes) do
170 | -- node.children = {}
171 | -- end
172 |
173 | -- reverse the graph
174 |   local rg, map = a:reverse()
175 | local rmap = {}
176 | for k,v in pairs(map) do
177 | rmap[v] = k
178 | end
179 |
180 | -- work on the sorted graph
181 |   local sortednodes = {}
182 |   local rootnodes = rg:roots()
183 |
184 | if #rootnodes == 0 then
185 | print('Graph has cycles')
186 | end
187 |
188 | -- run
189 | for i,root in ipairs(rootnodes) do
190 | root:dfs(function(node)
191 | print(node.id,rmap[node].id)
192 | -- print(rmap[node])
193 | table.insert(sortednodes,rmap[node]) end)
194 | end
195 |
196 | if #sortednodes ~= #a.nodes then
197 | print('Graph has cycles')
198 | end
199 | return sortednodes,rg,rootnodes
200 | end
201 |
202 | local my={eq =
203 | function(a,b,s)
204 | if a:dist(b) == 0 then
205 | print('ok')
206 | else
207 | print('error : ' .. s)
208 | print('a : ');print(a)
209 | print('b : ');print(b)
210 | end
211 | end}
212 |
213 | function t8()
214 | local in1 = nn.Identity()()
215 | local m = nn.Linear(10,10)(in1)
216 | local out1 = nn.Tanh()(m)
217 | local out2 = nn.Tanh()(m)
218 | local out = nn.CAddTable(){out1, out2}
219 | local mod = nn.gModule({in1}, {out})
220 |
221 | local dot = nngraph.simple_print.todot(mod.fg, 'bogus')
222 | print (dot)
223 | nngraph.simple_print.dot(mod.fg, 'bogus', 'new')
224 | graph.dot(mod.fg, 'bogus', 'old')
225 | end
226 | -- t2()
227 | t8()
228 |
--------------------------------------------------------------------------------
/utils.lua:
--------------------------------------------------------------------------------
1 | local utils = {}
2 |
3 | function utils.istorchclass(x)
4 | return type(x) == 'table' and torch.typename(x)
5 | end
6 |
7 | function utils.istable(x)
8 | return type(x) == 'table' and not torch.typename(x)
9 | end
10 |
11 | --[[ Returns a useful error message when a nngraph.Node is expected. ]]
12 | function utils.expectingNodeErrorMessage(badVal, array, idx)
13 | if badVal == nil then
14 | return string.format('%s[%d] is nil (typo / bad index?)', array, idx)
15 | elseif torch.isTypeOf(badVal, 'nn.Module') then
16 | local errStr = '%s[%d] is an nn.Module, specifically a %s, but the ' ..
17 | 'only valid thing to pass is an instance of ' ..
18 | 'nngraph.Node. Did you forget a second set of parens, ' ..
19 | 'which convert a nn.Module to a nngraph.Node?'
20 | return string.format(errStr, array, idx, torch.typename(badVal))
21 | else
22 | local errStr = '%s[%d] should be an nngraph.Node but is of type %s'
23 | return string.format(errStr, array, idx,
24 | torch.typename(badVal) or type(badVal))
25 | end
26 | end
27 |
28 | --[[ Lua 5.2+ removed table.maxn; provide a fallback implementation. ]]
29 | if table.maxn then
30 | utils.tableMaxN = table.maxn
31 | else
32 | function utils.tableMaxN(tbl)
33 | local max = 0
34 | for k, v in pairs(tbl) do
35 | if type(k) == 'number' and k > max then
36 | max = k
37 | end
38 | end
39 | return max
40 | end
41 | end
42 | return utils
43 |
--------------------------------------------------------------------------------
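A sketch of the tableMaxN fallback's behavior, assuming the table returned
above is in scope as `utils`; unlike the '#' operator, it counts past holes:

    local t = {1, 2, nil, 4}
    print(utils.tableMaxN(t)) -- 4
    print(#t)                 -- 2 or 4: '#' is unspecified for tables with holes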