├── LICENSE ├── README.md ├── example.lua ├── nnquery ├── ChildlessElement.lua ├── ContainerElement.lua ├── Context.lua ├── Element.lua ├── ElementList.lua ├── ModuleElement.lua ├── NNGraphGModuleElement.lua ├── NNGraphNodeElement.lua ├── init.lua └── tests │ ├── test_elem.lua │ └── test_nngraph.lua └── rocks └── nnquery-scm-1.rockspec /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Brendan Shillingford 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of logviz nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `nnquery`: query large neural network graph structures in Torch 2 | NN modules in Torch are often complex graph structures, like `nn.Container`s and its subclasses and `nn.gModules` (`nngraph`), arbitrarily nested. This makes it tedious to extract nn modules when debugging, monitoring training progress, or testing. 3 | 4 | `nnquery` provides a facility to query these arbitrarily complex DAGs. XPath and CSS are designed to handle trees, whereas this library supports querying DAGs like neural nets. 5 | The API is loosely inspired by a mix of XPath, CSS queries, and .NET's LINQ. 6 | 7 | See below for a simple example, and a more complete example of extracting things from an LSTM. 8 | 9 | ## Installation 10 | Install `nnquery`: 11 | ``` 12 | luarocks install https://raw.githubusercontent.com/bshillingford/nnquery/master/rocks/nnquery-scm-1.rockspec 13 | ``` 14 | Totem is optional, and used for unit tests. 15 | 16 | # Usage 17 | There are two important base classes that nearly everything is derived from: 18 | 19 | * `Element` (full name: `nnquery.Element`) 20 | * `ElementList` 21 | 22 | Every object you wish to query is wrapped in an `Element`, and sequences/collections of these 23 | are represented using `ElementList`s. 24 | 25 | To wrap an object in an `Element` so you can query it: 26 | ```lua 27 | local nnq = require 'nnquery' 28 | local seq = nn.Sequential() 29 | :add(nn.Tanh()) 30 | :add(nn.ReLU()) 31 | 32 | local tanh = nnq(seq):children():first() 33 | ``` 34 | On the last line, 35 | 36 | * `nnq(seq)` wraps `seq` into an `Element`; 37 | * `:children()` returns an `ElementList` of two `Elements` for `seq`'s children; 38 | * `:first()` returns the first `Element` in the `ElementList`. 
39 | 40 | # Realistic example with an LSTM: 41 | This is an example of using various functions in `Element` and `ElementList`: 42 | ```lua 43 | require 'nn' 44 | require 'nngraph' 45 | local nnq = require 'nnquery' 46 | 47 | -- nngraph implementation of LSTM timestep, from Oxford course's practical #6 48 | function create_lstm(opt) 49 | local x = nn.Identity()() 50 | local prev_c = nn.Identity()() 51 | local prev_h = nn.Identity()() 52 | 53 | function new_input_sum() 54 | -- transforms input 55 | local i2h = nn.Linear(opt.rnn_size, opt.rnn_size)(x) 56 | -- transforms previous timestep's output 57 | local h2h = nn.Linear(opt.rnn_size, opt.rnn_size)(prev_h) 58 | return nn.CAddTable()({i2h, h2h}) 59 | end 60 | 61 | local in_gate = nn.Sigmoid()(new_input_sum()) 62 | local forget_gate = nn.Sigmoid()(new_input_sum()) 63 | local out_gate = nn.Sigmoid()(new_input_sum()) 64 | local in_transform = nn.Tanh()(new_input_sum()) 65 | 66 | local next_c = nn.CAddTable()({ 67 | nn.CMulTable()({forget_gate, prev_c}), 68 | nn.CMulTable()({in_gate, in_transform}) 69 | }) 70 | local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) 71 | 72 | nngraph.annotateNodes() 73 | local mod = nn.gModule({x, prev_c, prev_h}, {next_c, next_h}) 74 | mod.name = "LSTM" 75 | return mod 76 | end 77 | 78 | -- Example network 79 | local foo = nn.Sequential() 80 | :add(nn.Module()) 81 | :add(create_lstm{rnn_size=3}) 82 | :add(nn.ReLU()) 83 | :add(nn.ReLU()) 84 | :add(nn.Linear(3, 4)) 85 | 86 | -- Find the LSTM in a few different ways: 87 | local lstm = nnq(foo) -- Wrap the module in an Element object using the default context 88 | -- which allows querying nn containers and nngraph's gmodules. 
89 | :descendants() -- Get all descendants below this node in the graph 90 | :where(function(e) -- Filter Elements by the given predicate 91 | return e:classIs(nnq.NNGraphGModuleElement) 92 | end) 93 | :only() -- Returns the first element in the returned sequence, and 94 | -- asserts that it is the only element in the sequence. 95 | -- (shortcut for list:first() and assert(list:count() == 1)) 96 | local lstm2 = nnq(foo) 97 | :children() -- Returns the contained modules of the nn.Sequential object as an 98 | -- ElementList 99 | :nth(2) -- Grabs the 2nd child of the nn.Sequential 100 | -- (alternate shorthand syntax: nnq(foo):children()[2]) 101 | local lstm3 = nnq(foo) 102 | :descendants() -- Get all descendants, as above 103 | :attr{name='LSTM'} -- Get only the objects with a name attribute set to 'LSTM', 104 | -- where it'll check both raw attributes and attempt to call 105 | -- the function assuming it's a getter method, i.e. check 106 | -- module:name() == 'LSTM'. 107 | :only() 108 | assert(lstm:val() == lstm2:val() and lstm2:val() == lstm3:val(), 109 | 'they should all return the same LSTM gmodule') 110 | 111 | -- Get the output nodes of the nngraph gmodule as an ElementList: 112 | local outputs = lstm:outputs() 113 | -- Two ways to get the count for an ElementList: 114 | print('The LSTM gmodule has '..outputs:count()..' outputs, they are:', outputs) 115 | print('The LSTM gmodule has '..#outputs..' outputs, they are:', outputs) 116 | assert(outputs:first():name() == 'next_c') -- :name() is available on NNGraphNodeElements, 117 | -- as a shortcut for: 118 | assert(outputs:first():val().data.annotations.name == 'next_c') 119 | 120 | -- Let's find the forget gate: 121 | local forget_gate = lstm:descendants():attr{name='forget_gate'}:only() 122 | print(forget_gate) 123 | -- But it's the sigmoid, not the gate's pre-activations, so let's get the sum: 124 | local input_sum = forget_gate:parent() -- This is an alias for :parents():only(). 125 | -- Note: nngraph nodes can have multiple parents, i.e.
126 | -- inputs 127 | assert(torch.isTypeOf(input_sum:val().data.module, nn.CAddTable)) 128 | assert(torch.isTypeOf(input_sum:module(), nn.CAddTable)) -- alias for :val().data.module 129 | ``` 130 | 131 | # Further details: 132 | Wrapping objects into elements and similar operations only make sense relative to a **context**, an instance of `nnquery.Context`, which contains a list of `Element` types and conditions on which to instantiate depending on what type is provided to it. Additionally, the context caches `Element`s, so that wrapping the same object twice returns the same instance of the `Element` subclass. 133 | `nnquery/init.lua` contains the construction of a default context (accessible as `nnquery.default`) that contains all the implemented `Element` types, similarly to this: 134 | ```lua 135 | local ctx = nnq.Context() 136 | ctx:reg(nnq.NNGraphGModuleElement, nnq.NNGraphGModuleElement.isGmodule) 137 | ctx:reg(nnq.NNGraphNodeElement, nnq.NNGraphNodeElement.isNode) 138 | ctx:reg(nnq.ContainerElement, nnq.ContainerElement.isContainer) -- after since gModule IS_A Container 139 | ctx:default(nnq.ChildlessElement) 140 | ``` 141 | 142 | Note that there is no true "root" node, unlike an XML/HTML document; the root is simply the place where the query begins. Therefore, one cannot[*] search for the root's parents, even if the root module is contained in (for example) a container. 143 | 144 | [*] Usually. Unless an element's parents are pre-populated from a previous query. 145 | 146 | # Documentation 147 | Further documentation can be found in doc comment style before class definitions and method definitions in the code itself. 148 | 149 | ***TODO: extract these into markdown format and put links here*** 150 | 151 | # Developing 152 | ## Extending 153 | You may have your own `nn` modules that are not handled by the existing handlers. 
In this case, 154 | you can implement your own `Element` object (see the existing ones for examples), and create your own context that adds a handler for this `Element`. See the default context (see above) for details. 155 | ## Contributing 156 | Bug reports are appreciated, preferably with a pull request for a test that breaks existing code and a patch that fixes it. If you do, please adhere to the (informal) code style in the existing code where appropriate. 157 | 158 | -------------------------------------------------------------------------------- /example.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | local nnq = require 'nnquery' 4 | 5 | -- nngraph implementation of LSTM timestep, from Oxford course's practical #6 6 | local function create_lstm(opt) 7 | local x = nn.Identity()() 8 | local prev_c = nn.Identity()() 9 | local prev_h = nn.Identity()() 10 | 11 | local function new_input_sum() 12 | -- transforms input 13 | local i2h = nn.Linear(opt.rnn_size, opt.rnn_size)(x) 14 | -- transforms previous timestep's output 15 | local h2h = nn.Linear(opt.rnn_size, opt.rnn_size)(prev_h) 16 | return nn.CAddTable()({i2h, h2h}) 17 | end 18 | 19 | local in_gate = nn.Sigmoid()(new_input_sum()) 20 | local forget_gate = nn.Sigmoid()(new_input_sum()) 21 | local out_gate = nn.Sigmoid()(new_input_sum()) 22 | local in_transform = nn.Tanh()(new_input_sum()) 23 | 24 | local next_c = nn.CAddTable()({ 25 | nn.CMulTable()({forget_gate, prev_c}), 26 | nn.CMulTable()({in_gate, in_transform}) 27 | }) 28 | local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) 29 | 30 | nngraph.annotateNodes() 31 | local mod = nn.gModule({x, prev_c, prev_h}, {next_c, next_h}) 32 | mod.name = "LSTM" 33 | return mod 34 | end 35 | 36 | -- Example network 37 | local foo = nn.Sequential() 38 | :add(nn.Module()) 39 | :add(create_lstm{rnn_size=3}) 40 | :add(nn.ReLU()) 41 | :add(nn.ReLU()) 42 | :add(nn.Linear(3, 4)) 43 | 44 | -- Find the
LSTM in a few different ways: 45 | local lstm = nnq(foo) -- Wrap the module in an Element object using the default context 46 | -- which allows querying nn containers and nngraph's gmodules. 47 | :descendants() -- Get all descendants below this node in the graph 48 | :where(function(e) -- Filter Elements by the given predicate 49 | return e:classIs(nnq.NNGraphGModuleElement) 50 | end) 51 | :only() -- Returns the first element in the returned sequence, and 52 | -- asserts that it is the only element in the sequence. 53 | -- (shortcut for list:first() and assert(list:count() == 1)) 54 | local lstm2 = nnq(foo) 55 | :children() -- Returns the contained modules of the nn.Sequential object as an 56 | -- ElementList 57 | :nth(2) -- Grabs the 2nd child of the nn.Sequential 58 | -- (alternate shorthand syntax: nnq(foo):children()[2]) 59 | local lstm3 = nnq(foo) 60 | :descendants() -- Get all descendants, as above 61 | :attr{name='LSTM'} -- Get only the objects with a name attribute set to 'LSTM', 62 | -- where it'll check both raw attributes and attempt to call 63 | -- the function assuming it's a getter method, i.e. check module:name() == 'LSTM'. 64 | :only() -- Asserts there is exactly one match and returns it as an Element 65 | assert(lstm:val() == lstm2:val() and lstm2:val() == lstm3:val(), 66 | 'they should all return the same LSTM gmodule') 67 | 68 | -- Get the output nodes of the nngraph gmodule as an ElementList: 69 | local outputs = lstm:outputs() 70 | -- Two ways to get the count for an ElementList (`#` needs __len on tables: Lua 5.2+, or LuaJIT built with Lua 5.2 compat): 71 | print('The LSTM gmodule has '..outputs:count()..' outputs, they are:', outputs) 72 | print('The LSTM gmodule has '..#outputs..'
outputs, they are:', outputs) 73 | assert(outputs:first():name() == 'next_c') -- :name() is available on NNGraphNodeElements, 74 | -- as a shortcut for: 75 | assert(outputs:first():val().data.annotations.name == 'next_c') 76 | 77 | -- Let's find the forget gate: 78 | local forget_gate = lstm:descendants():attr{name='forget_gate'}:only() 79 | print(forget_gate) 80 | -- But it's the sigmoid, not the gate's pre-activations, so let's get the sum: 81 | local input_sum = forget_gate:parent() -- This is an alias for :parents():only(). 82 | -- Note: nngraph nodes can have multiple parents (i.e. 83 | -- inputs) 84 | assert(torch.isTypeOf(input_sum:val().data.module, nn.CAddTable)) 85 | assert(torch.isTypeOf(input_sum:module(), nn.CAddTable)) -- alias for :val().data.module 86 | 87 | -------------------------------------------------------------------------------- /nnquery/ChildlessElement.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | 3 | local nnquery = require 'nnquery' 4 | 5 | --[[ 6 | Concrete class with no children, using manually added parents. 7 | 8 | Note: `:parents()` already impl'd by `ModuleElement`; this just makes 9 | `:children()` return an empty `ElementList`. 10 | ]] 11 | local ChildlessElement, super = classic.class(..., nnquery.ModuleElement) 12 | 13 | --[[ 14 | Returns children, in this case empty `ElementList`.
15 | ]] 16 | function ChildlessElement:children() 17 | return nnquery.ElementList.create_empty() 18 | end 19 | 20 | return ChildlessElement 21 | -------------------------------------------------------------------------------- /nnquery/ContainerElement.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch' 2 | local classic = require 'classic' 3 | local nn = require 'nn' 4 | 5 | local nnquery = require 'nnquery' 6 | -- Element for modules accepted by `isContainer` below; its children are the entries of the wrapped module's `.modules` table. 7 | local ContainerElement, super = classic.class(..., nnquery.ModuleElement) 8 | -- Wraps each entry of `self._val.modules` through this element's context, calls `_set_parents({self})` on each wrapped child, and returns the children as an `ElementList`. The list is rebuilt from `.modules` on every call. 9 | function ContainerElement:children() 10 | local wrappeds = self._ctx:wrapall(self._val.modules) 11 | for _, wrapped in ipairs(wrappeds) do 12 | wrapped:_set_parents({self}) 13 | end 14 | return nnquery.ElementList.fromtable(wrappeds) 15 | end 16 | -- Match predicate for `Context:reg`: true iff `m` is an `nn.Container` according to `torch.isTypeOf`. 17 | function ContainerElement.static.isContainer(m) 18 | return torch.isTypeOf(m, nn.Container) 19 | end 20 | 21 | return ContainerElement 22 | -------------------------------------------------------------------------------- /nnquery/Context.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Represents a query execution context, which keeps track of a registry of 3 | `Element` types and provides a mechanism for constructing an `Element` 4 | given the registry. 5 | 6 | Caches `Element` instances for wrapped values, guaranteeing that two calls 7 | to `:wrap()` in the same context always return the same `Element`. Note that 8 | this means a reference will be retained and seemingly deleted objects 9 | will not be freed until the cache is cleared. In that case, call `:clear()`. 10 | 11 | Note that caching does not affect mutating children of a value, at least 12 | in the case of `ContainerElement`, which builds up children lists on the 13 | fly every time it is called. For `nngraph` `Element`s, these are not (easily) 14 | mutable, so that implementation caches its nodes and wrapped nodes.
15 | ]] 16 | 17 | local classic = require 'classic' 18 | 19 | local nnquery = require 'nnquery' 20 | 21 | local Context = classic.class(...) 22 | 23 | --[[ 24 | single_match: if true, only allows one handler to match. 25 | Otherwise, the first matching handler is used. 26 | Defaults to false. 27 | ]] 28 | function Context:_init(single_match) 29 | self.single_match = single_match or false 30 | self._reg = {} 31 | -- Cache table mapping vals to Elements, guaranteeing only one Element exists per val 32 | self._wrapcache = {} 33 | end 34 | 35 | --[[ 36 | Clears `Element` cache. 37 | ]] 38 | function Context:clear() 39 | self._wrapcache = {} 40 | end 41 | 42 | --[[ 43 | Register the provided `Element` to handle the cases specified by the 44 | match-checking predicate. Handlers are tried in registration order; if `single_match` 45 | is set, at most one of them may return true for a given object. Returns `self`. 46 | ]] 47 | function Context:reg(cls, check_match) 48 | if type(check_match) ~= 'function' then 49 | error('Match-checking predicate must be a function') 50 | end 51 | if not cls:isSubclassOf(nnquery.Element) then 52 | error('Can only register Element subclasses') 53 | end 54 | table.insert(self._reg, {cls=cls, check_match=check_match}) 55 | return self 56 | end 57 | 58 | --[[ 59 | Registers a default `Element` implementation in case no registered handlers match. 60 | If not specified, `:wrap()` will raise an error. 61 | ]] 62 | function Context:default(elem_class) 63 | if not elem_class:isSubclassOf(nnquery.Element) then 64 | error('Can only register Element subclasses') 65 | end 66 | self._default_cls = elem_class 67 | return self 68 | end 69 | 70 | --[[ 71 | As specified by the registered handlers, wraps the given object in an 72 | instance of `Element` (or subclass). 73 | 74 | Behaviour depends on value of `single_match` provided to ctor. The result (including a default-class instance) is cached per value until `:clear()`.
75 | ]] 76 | function Context:wrap(val) 77 | if self._wrapcache[val] then 78 | return self._wrapcache[val] 79 | end 80 | -- Try the registered handlers in registration order. 81 | local wrapped_reg 82 | local wrapped 83 | for _, reg in ipairs(self._reg) do 84 | if reg.check_match(val) then 85 | if wrapped then 86 | error('More than one handler matched, first ' 87 | .. tostring(wrapped_reg.cls:name()) 88 | .. ', now ' .. tostring(reg.cls:name())) 89 | end 90 | wrapped_reg = reg 91 | -- first arg to ctor is ctx, second is the wrapee 92 | wrapped = reg.cls(self, val) 93 | -- if we're only allowing a single matching handler, keep going 94 | -- to check that no other handlers match 95 | if not self.single_match then 96 | break 97 | end 98 | end 99 | end 100 | if not wrapped_reg then 101 | if self._default_cls then 102 | wrapped = self._default_cls(self, val) -- cached below like any other Element, so parents set on it persist 103 | else 104 | error('No handlers matched, and no default class provided') 105 | end 106 | end 107 | self._wrapcache[val] = wrapped 108 | return wrapped 109 | end 110 | 111 | --[[ 112 | Applies `:wrap()` over a table, returns a table of wrapped objects. 113 | ]] 114 | function Context:wrapall(vals) 115 | local wrappeds = {} 116 | for _,v in ipairs(vals) do 117 | table.insert(wrappeds, self:wrap(v)) 118 | end 119 | return wrappeds 120 | end 121 | 122 | --[[ 123 | Convenient alias for `:wrap()`. 124 | Intended to be used by the user, not by internal code for the sake 125 | of cleanliness. 126 | ]] 127 | function Context:__call(val) 128 | return self:wrap(val) 129 | end 130 | 131 | return Context 132 | 133 | -------------------------------------------------------------------------------- /nnquery/Element.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | 3 | local nnquery = require 'nnquery' 4 | 5 | --[[ 6 | Abstract base class for all elements. 7 | Provides an interface and various common functionality for search and querying.
8 | 9 | Note a few notable design features and/or differences from XPath and CSS queries. 10 | 11 | 1. Modules are best described as a DAG rather than a tree. Hence, there can be multiple 12 | parents (i.e. inputs) rather than just one. 13 | 2. Since the API is OOP rather than XPath or the CSS query selection language, operations 14 | such as `//SomeTag[position() mod 2 = 0]` are implemented as calls to an `ElementList`. 15 | 3. There is no true "root" node, unlike an XML/HTML document; the root is simply the place 16 | where the query begins. Therefore, one cannot search for the root's parents, even if 17 | the root module is contained in (for example) a container. 18 | 19 | The API is inspired in part by XPath and .NET's LINQ. 20 | ]] 21 | local Element = classic.class(...) 22 | 23 | --[[ 24 | Constructor called by the execution context. Can be overridden. 25 | Argument `ctx` is a `Context` instance, and `val` specifies the 26 | contents of this element; usually an nn module. 27 | ]] 28 | function Element:_init(ctx, val) 29 | assert(ctx and ctx:classIs(nnquery.Context), 'a Context ctx must be given') 30 | assert(val, 'val must be given') 31 | self._ctx = ctx 32 | self._val = val 33 | end 34 | 35 | --[[ 36 | Returns the object that this `Element` wraps. 37 | ]] 38 | function Element:val() 39 | return self._val 40 | end 41 | 42 | --[[ 43 | Returns true if the `Element` instances refer to the same element. 44 | 45 | Defaults to comparing (by reference) `val` property, but can be overridden. 46 | ]] 47 | function Element:equals(other) 48 | return self._val == other._val 49 | end 50 | 51 | --[[ 52 | Returns an `ElementList` for children of this element. 53 | 54 | Pure virtual method; must be implemented by the concrete `Element`. 55 | ]] 56 | Element:mustHave('children') 57 | 58 | --[[ 59 | Returns an `ElementList` for parents of this element. 60 | Note that an element can have more than one parent, such as the input nodes to 61 | an nngraph node.
The precise definition of "parent" is implementation-dependent. 62 | 63 | Pure virtual method; must be implemented by the concrete `Element`. 64 | ]] 65 | Element:mustHave('parents') 66 | 67 | --[[ 68 | Alias for `:parents():only()`. 69 | ]] 70 | function Element:parent() 71 | return self:parents():only() 72 | end 73 | Element:final('parent') 74 | 75 | --[[ 76 | Alias for `:children():nth()`. 77 | ]] 78 | function Element:nth_child(...) 79 | return self:children():nth(...) 80 | end 81 | Element:final('nth_child') 82 | 83 | --[[ 84 | Alias for `:children():first()`. 85 | ]] 86 | function Element:first_child() 87 | return self:children():first() 88 | end 89 | Element:final('first_child') 90 | 91 | --[[ 92 | Alias for `:children():last()`. 93 | ]] 94 | function Element:last_child() 95 | return self:children():last() 96 | end 97 | Element:final('last_child') 98 | 99 | --[[ 100 | Returns an `ElementList` for following siblings of this element. 101 | 102 | Raises an error unless this element has exactly one parent. 103 | ]] 104 | function Element:following_siblings() 105 | local parents = self:parents():totable() 106 | if #parents ~= 1 then 107 | error('finds siblings only for elements with precisely one parent') 108 | end 109 | local all_siblings = parents[1]:children() 110 | -- after is exclusive, which is what we want 111 | return all_siblings:after(function(el) 112 | return el:equals(self) 113 | end) 114 | end 115 | Element:final('following_siblings') 116 | 117 | --[[ 118 | Returns an `ElementList` for preceding siblings of this element, where 119 | the first sibling is first child of the parent, and subsequent elements 120 | are progressively closer, where the last is the immediately preceding
121 | element. Raises an error unless this element has exactly one parent. 122 | ]] 123 | function Element:preceding_siblings() 124 | local parents = self:parents():totable() 125 | if #parents ~= 1 then 126 | error('finds siblings only for elements with precisely one parent') 127 | end 128 | local all_siblings = parents[1]:children() 129 | -- before is exclusive, which is what we want 130 | return all_siblings:before(function(el) 131 | return el:equals(self) 132 | end) 133 | end 134 | Element:final('preceding_siblings') 135 | 136 | --[[ 137 | Returns an `ElementList` of all descendants, in the order they are visited by `:dfs()`. 138 | ]] 139 | function Element:descendants() 140 | local descs = {} 141 | self:dfs(function(el) table.insert(descs, el) end) 142 | return nnquery.ElementList.fromtable(descs) 143 | end 144 | 145 | --[[ 146 | Returns an `ElementList` of all ancestors, found by a `:dfs()` that follows `:parents()` instead of `:children()`. 147 | ]] 148 | function Element:ancestors() 149 | local descs = {} 150 | self:dfs(function(el) table.insert(descs, el) end, 'parents') 151 | return nnquery.ElementList.fromtable(descs) 152 | end 153 | 154 | --[[ 155 | Recurses down the DAG below this `Element` in DFS order, calling the callback 156 | at each `Element`. Note that DFS order is not unique, both due to ordering of 157 | children and the structure being a DAG rather than a tree. 158 | 159 | If `children_func_name` is set to `parents`, performs a DFS with the graph's 160 | edges flipped. Defaults to `children`, i.e. a normal DFS. The callback is called on each newly reached element before recursing into it; the starting element itself is not passed to it (for an acyclic graph).
161 | ]] 162 | function Element:dfs(func_visit, children_func_name) 163 | children_func_name = children_func_name or 'children' 164 | local visited_table = {} -- keyed by child:val(), so each wrapped value is visited at most once 165 | local function traverse(el) 166 | for child in el[children_func_name](el):iter() do 167 | if not visited_table[child:val()] then 168 | visited_table[child:val()] = true 169 | func_visit(child) 170 | traverse(child) 171 | end 172 | end 173 | end 174 | traverse(self) 175 | end 176 | 177 | -- TODO: get a queue and implement BFS 178 | 179 | function Element:__tostring() 180 | return string.format('%s[val=%s]', self:class():name(), tostring(self:val())) 181 | end 182 | 183 | return Element 184 | -------------------------------------------------------------------------------- /nnquery/ElementList.lua: -------------------------------------------------------------------------------- 1 | local classic = require 'classic' 2 | 3 | --[[ 4 | Stores sequences of elements as returned by querying operations, and provides various filtering 5 | and aggregation operations on them. 6 | 7 | Most functions either return a single `Element`, a new `ElementList`, or a number (for `:count()`). 8 | 9 | ### Boring details and implementation decisions: 10 | Makes the assumption that traversing an iterator is cheap, so we sometimes perform multiple 11 | passes. This allows future use for large data structures. Some small results are cached, but 12 | most things are constructed on demand, and some large things are not cached as performance 13 | is not an issue: the library is intended mostly to be a debugging tool and for tasks 14 | not part of a hotloop like a training loop, besides when query results are saved and used in 15 | such a loop of course. 16 | Performance optimizations may come later, if needed. 17 | ]] 18 | local EL = classic.class(...) 19 | 20 | --[[ 21 | Constructor meant to be called by an `Element` instance, not by the user. 22 | - new_elem_iter: Constructs an iterator that returns the next element in the sequence.
23 | May be called multiple times to construct multiple iterators. 24 | In some cases, this will be used to construct a full table, but 25 | most operations on the `ElementList` don't need to (like `nth` and similar ops). 26 | ]] 27 | function EL:_init(new_elem_iter) 28 | assert(type(new_elem_iter) == 'function', 'expected iterator factory') 29 | self._newiter = new_elem_iter 30 | 31 | -- FIXME: workaround: classic doesn't support __len metamethod 32 | getmetatable(self).__len = self.count 33 | end 34 | 35 | --[[ 36 | Factory for constructing `ElementList`s from tables of `Element`s. 37 | ]] 38 | function EL.static.fromtable(elements) 39 | assert(type(elements) == 'table', 'expected table') 40 | 41 | -- returns ElementList w/ iterator factory: 42 | local el = EL(function() 43 | local pos = 0 44 | return function() 45 | pos = pos + 1 46 | -- Terminates properly since table[count+1] == nil 47 | return elements[pos] 48 | end 49 | end) 50 | el._count = #elements -- precomputed for efficiency 51 | return el 52 | end 53 | 54 | --[[ 55 | Factory for constructing an empty `ElementList`. 56 | ]] 57 | function EL.static.create_empty() 58 | return EL(function() 59 | return function() return nil end 60 | end) 61 | end 62 | 63 | --[[ 64 | Returns a new iterator over elements of this sequence. 65 | ]] 66 | function EL:iter() 67 | -- . not : because it's a variable, not a method 68 | return self._newiter() 69 | end 70 | 71 | --[[ 72 | Returns the `n`th `Element` of the sequence. 73 | If `n` is negative, counts backward from the end. 74 | The first element is at index 1, and the last is -1. 75 | ]] 76 | function EL:nth(n_) 77 | if type(n_) ~= 'number' or n_ == 0 or math.floor(n_) ~= n_ then 78 | error('expected positive or negative int as argument to :nth(), got: ' .. tostring(n_)) 79 | end 80 | -- Make the index a positive one 81 | local n = n_ 82 | if n < 0 then 83 | -- e.g. 
-1 + count + 1 = count 84 | n = n + self:count() + 1 85 | end 86 | 87 | -- Use a new iterator find the n'th 88 | i = 1 89 | for e in self:iter() do 90 | if i == n then 91 | return e 92 | end 93 | i = i + 1 94 | end 95 | error(string.format('n out of range: %d (computed: %d), count=%d', n_, n, self:count())) 96 | end 97 | 98 | function EL:__index(i) 99 | -- So that nonexistent properties don't go into :nth() 100 | if type(i) == 'number' then 101 | return self:nth(i) 102 | end 103 | end 104 | 105 | --[[ 106 | First `Element`. 107 | ]] 108 | function EL:first() 109 | return self:nth(1) 110 | end 111 | 112 | --[[ 113 | Like `:first()`, but ensures only one `Element` is in the list. 114 | If there is more than one, an error is thrown. 115 | ]] 116 | function EL:only() 117 | if self:count() ~= 1 then 118 | error('expected 1 element, have ' .. tostring(self:count())) 119 | end 120 | return self:first() 121 | end 122 | 123 | --[[ 124 | Last `Element`. 125 | ]] 126 | function EL:last() 127 | return self:nth(-1) 128 | end 129 | 130 | --[[ 131 | Return a subsequence as an `ElementList`, using the index range given, where 132 | both `from` and `to` are ***inclusive***; `to` may be `nil` to indicate slicing to the end, 133 | and both indices can be negative with the same semantics as `nth`. 
]]
function EL:slice(from, to)
  if type(from) ~= 'number' or from == 0 or math.floor(from) ~= from then
    error('expected negative or positive int index for from')
  end
  -- `to` is optional: only validate it when present (math.floor(nil) would raise).
  if to ~= nil and (type(to) ~= 'number' or to == 0 or math.floor(to) ~= to) then
    error('expected negative or positive int index for to')
  end
  -- The order can only be checked up front when both indices count from the same
  -- end of the list; a positive `from` with a negative `to` (or vice versa) depends
  -- on the list's length and is allowed.
  if to ~= nil and (from > 0) == (to > 0) and to < from then
    error('range must satisfy from <= to (unless to is nil or from and to have different signs)')
  end

  -- Make all indices positive (relative to start); only count when needed,
  -- since counting a lazy list walks it fully.
  if from < 0 then
    from = from + self:count() + 1
  end
  if to ~= nil and to < 0 then
    to = to + self:count() + 1
  end

  -- Since the slice may be small, in which case iterating every time would be inefficient,
  -- we explicitly construct the result table.
  local results = {}
  local i = 1
  for e in self:iter() do
    if to ~= nil and i > to then
      break
    end
    if i >= from then
      results[#results + 1] = e
    end
    i = i + 1
  end
  return EL.fromtable(results)
end

--[[
Returns the number of elements produced by the query.
The result is cached on the list after the first call.
]]
function EL:count()
  if self._count == nil then
    local count = 0
    for _ in self:iter() do
      count = count + 1
    end
    self._count = count
  end
  return self._count
end

--[[
With the given predicate, produces an EL with only the elements
***after*** (***exclusively*** unless `incl` is `true`) the
first time `pred` returns true. The predicate is no longer
called once it has returned true.

Predicate has the same form as `:where()`.
]]
function EL:after(pred, incl)
  local true_yet = false
  -- The filter is stateful (true_yet), so it relies on `:where()` building a
  -- table in a single pass rather than re-running the predicate lazily.
  return self:where(function(...)
    if not true_yet and pred(...) then
      true_yet = true
      if incl then
        return true
      end
      return false
    end
    return true_yet
  end)
end

--[[
Similar to `after`, except inclusive.
]]
function EL:iafter(pred)
  return self:after(pred, true)
end

--[[
With the given predicate, produces an EL with only the elements
***before*** the first time `pred` returns true (***exclusively*** unless
`incl` is `true`, in which case that first matching element is kept too).
The predicate is no longer called once it has returned true.

Predicate has the same form as in `:where()`.
]]
function EL:before(pred, incl)
  local true_yet = false
  -- Stateful filter: see the note in `:after()` about relying on table mode.
  return self:where(function(...)
    if not true_yet and pred(...) then
      true_yet = true
      if incl then
        return true
      end
      return false
    end
    return not true_yet
  end)
end

--[[
Similar to `before`, except inclusive.
]]
function EL:ibefore(pred)
  return self:before(pred, true)
end

--[[
Returns true if the given predicate is true for all elements in this
element list, false otherwise. Iteration stops at the first element for
which the predicate is false.

Returns true for an empty list.

Predicate takes the same two arguments as `:where()`: element, index.
]]
function EL:all(pred)
  local idx = 1
  for el in self:iter() do
    if not pred(el, idx) then
      return false
    end
    idx = idx + 1
  end
  return true
end

--[[
Returns true if the given predicate is true for ***any*** element in this
element list, and false if it is true for none of them.

Returns false for an empty list.

Predicate takes the same two arguments as `:where()`: element, index.
]]
function EL:any(pred)
  -- Stop at the first match and always return a real boolean.
  local idx = 1
  for el in self:iter() do
    if pred(el, idx) then
      return true
    end
    idx = idx + 1
  end
  return false
end

--[[
Applies a function to each element in the `ElementList`, in iteration order.
Returns nothing.

Function takes the same two arguments as `:where()`: element, index.
]]
function EL:foreach(f)
  local idx = 1
  for el in self:iter() do
    f(el, idx)
    idx = idx + 1
  end
end

--[[
Returns a table of the `Element`s.

Tables are not cached, so the table can be safely modified and not affect subsequent calls.
]]
function EL:totable()
  local result = {}
  for e in self:iter() do
    result[#result + 1] = e
  end
  return result
end

--[[
Filters by the given predicate function: each `Element` in the `ElementList`
is passed to the predicate, which must return true iff the `Element` is to be kept.
Returns a new `ElementList`.

If `no_table` is true, doesn't construct a table. This is ideal for when the result of
the `:where()` will only be used once or chained. `:where()` is also used to implement
the other filters. Note: with `no_table`, the predicate may be called more than once per
element if the result is iterated more than once.

The function is passed two args: the `Element`, and its index into the `ElementList`.
The index argument needn't be given for lua functions, which ignore extra args.
]]
function EL:where(pred, no_table)
  if type(pred) ~= 'function' then
    error('expected function as argument to :where()')
  end

  if no_table then
    -- Lazy mode: the factory below builds a fresh iterator over self, so pred is
    -- re-run whenever the result is iterated (assuming EL calls it per :iter()).
    return EL(function()
      local pos = 0
      local iter = self:iter()
      return function()
        -- fetch the next matching element, if any, and return if found, else nil
        pos = pos + 1
        local el = iter()
        while el do
          if pred(el, pos) then
            return el
          end
          pos = pos + 1
          el = iter()
        end
      end
    end)
  end

  -- Table mode: evaluate pred exactly once per element and cache the matches,
  -- which is the better choice when pred is rarely (or never) true.
  local results = {}
  local pos = 1
  for el in self:iter() do
    if pred(el, pos) then
      results[#results + 1] = el
    end
    pos = pos + 1
  end
  return EL.fromtable(results)
end

--[[
Given a table mapping `Element` property names to values, returns only the
elements where ***all*** properties equal (using `==`) the provided values.

Each property can be a simple instance variable or a getter method that takes
no arguments. Checking an element stops at its first mismatching property, so
remaining getters are not called for it.
]]
function EL:props(props)
  if type(props) ~= 'table' then
    error('props must be a table of properties to check')
  end
  return self:where(function(el)
    for k, v in pairs(props) do
      local prop_val = el[k]
      if type(prop_val) == 'function' then
        prop_val = prop_val(el) -- getter method needs implicit self
      end
      if prop_val ~= v then
        return false
      end
    end
    return true
  end)
end

--[[
Alias for `props`.
]]
function EL:attr(...)
  return self:props(...)
end

--[[
Same as `:props()` except ***any*** property must match rather than all.
]]
function EL:props_any(props)
  if type(props) ~= 'table' then
    error('props must be a table of properties to check')
  end
  return self:where(function(el)
    for k, v in pairs(props) do
      local prop_val = el[k]
      if type(prop_val) == 'function' then
        prop_val = prop_val(el) -- getter method needs implicit self
      end
      if prop_val == v then
        return true
      end
    end
    return false
  end)
end

--[[
Alias for `props_any`.
]]
function EL:attr_any(...)
  return self:props_any(...)
end

--[[
String form: the class name, then each `Element`'s `tostring()` on its own
line, comma-separated, inside braces.
]]
function EL:__tostring()
  local strs = {}
  for e in self:iter() do
    strs[#strs + 1] = tostring(e)
  end
  return string.format('%s{\n %s\n}',
    self:class():name(),
    table.concat(strs, ',\n '))
end

return EL

-------------------------------------------------------------------------------- /nnquery/ModuleElement.lua: --------------------------------------------------------------------------------
local classic = require 'classic'

local nnquery = require 'nnquery'

--[[
Abstract class that adds manual parent tracking to `Element`.
To be used by `nn.Module`s, with extra functionality added for
returning children. See also `ChildlessElement` for modules with
no children.
]]
local MPE, super = classic.class(..., nnquery.Element)

function MPE:_init(...)
  super._init(self, ...)
  -- Empty until a parent container assigns it via :_set_parents().
  self._parents = {}
end

--[[
Returns an `ElementList` for parents of this element.
Parents must be assigned with `:_set_parents()` (intended to be called by
`ContainerElement`); otherwise the list is empty.
]]
function MPE:parents()
  return nnquery.ElementList.fromtable(self._parents)
end

--[[
Sets the parents (a table of `Element`s, replacing any previous ones),
intended to be used by `ContainerElement` only.
]]
function MPE:_set_parents(parents)
  self._parents = parents
end

return MPE

-------------------------------------------------------------------------------- /nnquery/NNGraphGModuleElement.lua: --------------------------------------------------------------------------------
local classic = require 'classic'

local nnquery = require 'nnquery'

--[[
`nn.gModule` element. Children are input node(s).
]]
local NNGGME, super = classic.class(..., nnquery.ModuleElement)

--[[
Wraps every node of the gModule's forward graph (`fg`) in an Element, wires up
each node Element's gModule reference and parents, and records the input nodes
(sorted by `selectindex` when present).
]]
function NNGGME:_init(...)
  super._init(self, ...)

  -- Find root node first
  local mod = self:val()
  local nInputs = mod.nInputs or #mod.innode.children
  local node = mod.innode
  if nInputs ~= #node.children then
    assert(#node.children == 1, "expected single child root nngraph node")
    node = node.children[1]
  end
  assert(nInputs == #node.children, "at this point, # children should equal # nngraph inputs")
  self._root = node

  -- Sort a *copy* of the input nodes if needed: sorting node.children in place
  -- would reorder the wrapped gModule's own graph as a side effect.
  local innodes = {}
  for i, innode in ipairs(node.children) do
    innodes[i] = innode
  end
  -- Check the length first so an input-less graph doesn't index a nil innodes[1].
  if #innodes > 1 and innodes[1].data.selectindex then
    table.sort(innodes, function(x, y)
      assert(x.data.selectindex and y.data.selectindex, "selectindex should exist for all nodes")
      return x.data.selectindex < y.data.selectindex
    end)
  end
  self._innodes = innodes

  -- Pre-build Elements for nodes and mapping of fg node -> Element
  self._node2el = {}
  -- Wrap all nodes and set their gModule references:
  for _, fg_node in ipairs(mod.fg.nodes) do
    -- this is the only place we should ever wrap a node using ctx
    local el = self._ctx:wrap(fg_node)
    el:_set_gmod(mod, self)
    self._node2el[fg_node] = el
  end

  -- Initialize the parents of all the nodes, now that _node2el is ready:
  for _, el in pairs(self._node2el) do
    el:_init_parents(mod, self)
  end
  -- Override parent of inputs to be this gmodule element:
  for _, innode in ipairs(innodes) do
    self._node2el[innode]._parents = {self}
  end
end

-- not the version in ctx: this uses the cached elements in the GModule element
function NNGGME:_wrap(x)
  return assert(self._node2el[x], 'internal error: node not found')
end

-- Maps _wrap over an array of forward-graph nodes, preserving order.
function NNGGME:_wrapall(tbl)
  local result = {}
  for k, v in ipairs(tbl) do
    result[k] = self:_wrap(v)
  end
  return result
end

--[[
Returns `ElementList` consisting of input nodes of module's forward graph.
]]
function NNGGME:children()
  return self:inputs()
end

--[[
Returns `ElementList` of the graph nodes that hold a module (`node.data.module`),
in the order that nngraph evaluates them on a forwards pass. The elements are
node Elements; call `:module()` on one to get the `nn.Module` itself.
]]
function NNGGME:modules()
  local nodes = {}
  for _, node in ipairs(self:val().forwardnodes) do
    if node.data.module then
      nodes[#nodes + 1] = node
    end
  end
  -- _wrap looks Elements up by graph node, so pass nodes (not their modules).
  return nnquery.ElementList.fromtable(self:_wrapall(nodes))
end

--[[
Returns `ElementList` consisting of input nodes of module's **forward graph**.
]]
function NNGGME:inputs()
  return nnquery.ElementList.fromtable(self:_wrapall(self._innodes))
end

--[[
Returns `ElementList` consisting of output nodes of module's **forward graph**.
]]
function NNGGME:outputs()
  local mod = self:val()
  local leaves = mod.fg:leaves()
  assert(#leaves == 1, "gmodule forward graph should have a single leaf")
  local leaf = leaves[1]

  -- Get node objects in order of output: the single leaf's mapindex lists the
  -- output nodes' data, which carry their id in the forward graph.
  local outnodes = {}
  for _, mi in ipairs(leaf.data.mapindex) do
    table.insert(outnodes, mod.fg.nodes[mi.forwardNodeId])
  end

  return nnquery.ElementList.fromtable(self:_wrapall(outnodes))
end

--[[
Predicate used to register this class with a `Context`: true iff `m` is an `nn.gModule`.
]]
function NNGGME.static.isGmodule(m)
  -- require in here so that the default context can be constructed without nngraph installed
  require 'nngraph'
  return torch.isTypeOf(m, nn.gModule)
end

return NNGGME

-------------------------------------------------------------------------------- /nnquery/NNGraphNodeElement.lua: --------------------------------------------------------------------------------
local classic = require 'classic'

local nnquery = require 'nnquery'

--[[
`Element` for an nngraph node (`nngraph.Node`).
]]
local NE, super = classic.class(..., nnquery.Element)

function NE:_init(...)
  super._init(self, ...)
end

--[[
Records the owning gModule and its Element. Must be called exactly once, by
`NNGraphGModuleElement` when it wraps the nodes of its forward graph.
]]
function NE:_set_gmod(gmod, gmod_el)
  assert(not self._gmod and not self._gmod_el,
    "internal error: gModule already set for this node element")
  self._gmod = gmod
  self._gmod_el = gmod_el
end

--[[
Resolves this node's parents (from its `data.mapindex`) into the Elements
cached by the owning gModule Element. Requires `:_set_gmod()` first.
]]
function NE:_init_parents()
  assert(self._gmod and self._gmod_el, "internal error: don't wrap a gmodule node directly")

  -- Find all the parents:
  self._parents = {}
  for i, mi in ipairs(self:val().data.mapindex) do
    self._parents[i] = self._gmod_el:_wrap(self._gmod.fg.nodes[mi.forwardNodeId])
  end
end

--[[
Alias for `:val().data.module`.
]]
function NE:module()
  return self:val().data.module
end

--[[
Alias for `:val().data.annotations.name`.
]]
function NE:name()
  -- annotations may be absent (see the guard in __tostring); yield nil then
  -- instead of raising an index error.
  local annotations = self:val().data.annotations
  return annotations and annotations.name
end

--[[
Returns an `ElementList` for parent nngraph nodes of this element.
]]
function NE:parents()
  return nnquery.ElementList.fromtable(self._parents)
end

--[[
Returns an `ElementList` for children nngraph nodes of this element.
]]
function NE:children()
  local childelems = self._gmod_el:_wrapall(self:val().children)
  for _, child in ipairs(childelems) do
    -- at this point, should have all parents set by the ctor, incl ourself
    assert(child:classIs(NE), 'All wrappers for nodes should be NodeElements')
  end
  return nnquery.ElementList.fromtable(childelems)
end

--[[
String form: class name followed by the node's module / nSplitOutputs data
fields and all annotations.
]]
function NE:__tostring()
  local val = self:val()
  local prints = {}
  for _, key in ipairs{'module', 'nSplitOutputs'} do
    if val.data[key] then
      table.insert(prints, string.format('d.%s=%s', key, tostring(val.data[key])))
    end
  end
  if val.data.annotations then
    for k, v in pairs(val.data.annotations) do
      -- Annotation values may be booleans/tables (e.g. is_i2h=true); Lua 5.1's
      -- '%s' only accepts strings and numbers, so convert explicitly.
      table.insert(prints, string.format('d.a.%s=%s', tostring(k), tostring(v)))
    end
  end
  return string.format('%s[%s]',
    self:class():name(),
    table.concat(prints, ', '))
end

--[[
Predicate used to register this class with a `Context`: true iff `m` is an `nngraph.Node`.
]]
function NE.static.isNode(m)
  -- require in here so that the default context can be constructed without nngraph installed
  require 'nngraph'
  return torch.isTypeOf(m, nngraph.Node)
end

return NE

-------------------------------------------------------------------------------- /nnquery/init.lua: --------------------------------------------------------------------------------
local classic = require 'classic'

local M = classic.module(...)
-- Declare the module's classes (same order as the individual files are listed).
for _, class_name in ipairs{
  'Context',
  'Element', 'ChildlessElement', 'ContainerElement', 'ModuleElement',
  'NNGraphGModuleElement', 'NNGraphNodeElement',
  'ElementList',
} do
  M:class(class_name)
end

-- TODO: maybe lightweight hierarchy (via classic) for exceptions, instead of string errors.

-- Default context with sensible settings. Registration order matters:
-- the container matcher comes after the gModule one, since gModule IS_A Container.
local default_ctx = M.Context()
default_ctx:reg(M.NNGraphGModuleElement, M.NNGraphGModuleElement.isGmodule)
default_ctx:reg(M.NNGraphNodeElement, M.NNGraphNodeElement.isNode)
default_ctx:reg(M.ContainerElement, M.ContainerElement.isContainer)
default_ctx:default(M.ChildlessElement)
M.default = default_ctx

-- Clone classic's module metatable and make calling the module forward to the default ctx.
local module_mt = {}
for key, value in pairs(getmetatable(M)) do
  module_mt[key] = value
end
module_mt.__call = function(self, ...)
  return self.default(...)
end
setmetatable(M, module_mt)

return M

-------------------------------------------------------------------------------- /nnquery/tests/test_elem.lua: --------------------------------------------------------------------------------
--[[
Tests the following classes (not all in isolation, TODO: write real unit tests w/ mocks):
  * Element (abstract base class)
  * ModuleElement (abstract base class)
  * ContainerElement
  * ChildlessElement
Also tests some basic ElementList functionality.
]]

local totem = require 'totem'
require 'nn'

-- path hack for running in src dir without nnquery installed:
package.path = package.path .. ';../../?/init.lua;../?/init.lua;../../?.lua;../?.lua'
local nnq = require 'nnquery'

local tester = totem.Tester()
local tests = {}

-- Context that only knows about containers; everything else is childless.
local function newctx()
  local ctx = nnq.Context()
  ctx:reg(nnq.ContainerElement, nnq.ContainerElement.isContainer)
  ctx:default(nnq.ChildlessElement)
  return ctx
end

function tests.Val()
  local ctx = newctx()
  local identity = nn.Identity()
  local wrapped = ctx(identity)
  tester:asserteq(wrapped:val(), identity, 'wrong val')
end

function tests.NoSiblings()
  local ctx = newctx()
  local tanh = nn.Tanh()
  local container = nn.Container():add(tanh)

  -- the single child must be the module we added:
  tester:asserteq(ctx(container):children():only():val(), tanh, 'wrong val?')
  -- and, being alone, it has no siblings on either side:
  local sole = ctx(container):children():first()
  tester:asserteq(sole:following_siblings():count(), 0, 'no siblings')
  tester:asserteq(sole:preceding_siblings():count(), 0, 'no siblings')
end

function tests.FollowingSiblingsPrecedingSiblings()
  local ctx = newctx()
  local sigmoid = nn.Sigmoid()
  local tanh = nn.Tanh()
  local relu = nn.ReLU()
  local container = nn.Container()
    :add(nn.Identity())
    :add(sigmoid)
    :add(tanh)
    :add(relu)
    :add(nn.Identity())
    :add(nn.Identity())

  -- the 3rd child must be the tanh:
  local third = ctx(container):children():nth(3)
  tester:asserteq(third:val(), tanh, 'wrong val?')
  -- three siblings follow it (relu first), two precede it (sigmoid last):
  tester:asserteq(third:following_siblings():count(), 3, 'should be 3 siblings after')
  tester:asserteq(third:following_siblings():first():val(), relu, 'wrong element after')
  tester:asserteq(third:preceding_siblings():count(), 2, 'should be 2 siblings before')
  tester:asserteq(third:preceding_siblings():last():val(), sigmoid, 'wrong element before')
end

-- Mostly for the
-- before/after tests below...
local function new_ctx_simple_container1()
  local ctx = newctx()
  local mod = nn.Container()
    :add(nn.Identity())
    :add(nn.ReLU())
    :add(nn.ReLU())
    :add(nn.Identity())
    :add(nn.Identity())
  return ctx, mod, ctx(mod):children()
end

--[[ Tests basic filtering by predicate, in both lazy and table-backed modes. ]]
function tests.Where()
  local _, _, children = new_ctx_simple_container1()
  local function is_identity(x) return torch.isTypeOf(x:val(), nn.Identity) end

  -- lazy result (no_table = true), iterated more than once:
  local lazy = children:where(is_identity, true)
  tester:asserteq(lazy:count(), 3, 'expect 3 nn.Identitys')
  tester:asserteq(#lazy:totable(), 3, 'expect 3 nn.Identitys')
  tester:asserteq(#lazy:totable(), 3, 'expect 3 nn.Identitys again')
  -- table-backed result:
  local eager = children:where(is_identity, false)
  tester:asserteq(eager:count(), 3, 'expect 3 nn.Identitys')
  tester:asserteq(#eager:totable(), 3, 'expect 3 nn.Identitys')
end

--[[ Tests getting elements in a list before/after a predicate is true.
Edge cases: empty/full.
]]
function tests.BeforeAfterEmptyFull()
  local _, mod, children = new_ctx_simple_container1()

  -- empty 'before' (i.e. returns true on first iter):
  tester:asserteq(
    children:before(function() return true end):count(),
    0,
    'should be an empty result')
  -- full 'before' (i.e. returns false on all iter)
  tester:asserteq(
    children:before(function() return false end):count(),
    #mod.modules,
    'should be a full result')

  -- empty 'after' (i.e. returns false on all iter):
  tester:asserteq(
    children:after(function() return false end):count(),
    0,
    'should be an empty result')
  -- *almost* full 'after' (i.e. returns true on first iter),
  -- returns all elements except first
  tester:asserteq(
    children:after(function() return true end):count(),
    #mod.modules - 1,
    'should be an ALMOST full result')
end

--[[ Tests getting elements in a list before/after a predicate is true.
Typical case.
]]
function tests.BeforeAfterTypical()
  local _, mod, children = new_ctx_simple_container1()

  -- get elements before the 2nd (of 5) element:
  local before = children:before(function(_, i) return i == 2 end)
  tester:asserteq(before:count(), 1, 'should be 1 element before')
  tester:asserteq(before:first():val(), mod.modules[1], 'and it should be just the identity')

  -- get elements after the 2nd (of 5) element:
  local after = children:after(function(_, i) return i == 2 end)
  tester:asserteq(after:count(), 3, 'should be 3 elements after')
  tester:asserteq(after:first():val(), mod.modules[3], 'first element after should be the 3rd module')
end

--[[ Tests getting elements in a list before/after a predicate is true, inclusive version.
Edge cases: singleton/empty/full.
]]
function tests.BeforeAfterInclEmptyFull()
  local _, mod, children = new_ctx_simple_container1()

  -- singleton 'ibefore' (i.e. returns true on first iter, which is kept):
  tester:asserteq(
    children:ibefore(function() return true end):count(),
    1,
    'should be a singleton result')
  -- full 'ibefore' (i.e. returns false on all iter)
  tester:asserteq(
    children:ibefore(function() return false end):count(),
    #mod.modules,
    'should be a full result')

  -- empty 'iafter' (i.e. returns false on all iter):
  tester:asserteq(
    children:iafter(function() return false end):count(),
    0,
    'should be an empty result')
  -- full 'iafter' (i.e. returns true on first iter, which is kept):
  tester:asserteq(
    children:iafter(function() return true end):count(),
    #mod.modules,
    'should be a full result')
end

--[[ Tests getting elements in a list before/after a predicate is true, inclusive version.
Typical case.
]]
function tests.BeforeAfterIncl()
  local _, mod, children = new_ctx_simple_container1()

  -- get elements up to and including the 2nd (of 5) element:
  local upto = children:ibefore(function(_, i) return i == 2 end)
  tester:asserteq(upto:count(), 2, 'should be 2 elements up to and including the 2nd')
  tester:asserteq(upto:first():val(), mod.modules[1], 'and the first should be the identity')

  -- get elements from the 2nd (of 5) element onwards:
  local from = children:iafter(function(_, i) return i == 2 end)
  tester:asserteq(from:count(), 4, 'should be 4 elements after')
  tester:asserteq(from:first():val(), mod.modules[2], 'first element should be the 2nd module')
end

--[[ Tests depth first search on an element's descendants. ]]
function tests.DescendantsAndDFS()
  -- Create a context and tree-shaped container
  local ctx = newctx()
  local mod = nn.Container()
    :add(nn.Identity())
    :add(nn.Sequential()
      :add(nn.ParallelTable()
        :add(nn.Identity())
        :add(nn.Sequential()))
      :add(nn.Identity()))
  -- Number of descendants = number of modules constructed above (excluding root) = 6
  tester:asserteq(ctx(mod):descendants():count(), 6, 'should be 6 descendants')

  ctx(mod):dfs(function(el)
    -- Each nn module should be visited precisely once:
    tester:assert(not el:val().test_visited, 'should visit each element precisely once')
    el:val().test_visited = true
    -- Check that parent is set correctly:
    local found_self = false
    for sibling in el:parents():only():children():iter() do
      if el:equals(sibling) then
        found_self = true
      end
    end
    tester:assert(found_self, "parent set incorrectly: cannot find self in parent's children")
  end)
end
return tester:add(tests):run()

-------------------------------------------------------------------------------- /nnquery/tests/test_nngraph.lua: --------------------------------------------------------------------------------
--[[
Tests the nngraph classes, also needs ChildlessElement:
  * NNGraphNodeElement
  * NNGraphGModuleElement
  * ChildlessElement

NOTE: very incomplete.
]]

local totem = require 'totem'
require 'nn'
require 'nngraph'

-- path hack for running in src dir without nnquery installed:
package.path = package.path .. ';../../?/init.lua;../?/init.lua;../../?.lua;../?.lua'
local nnq = require 'nnquery'

local tester = totem.Tester()
local tests = {}

-- Context for nngraph tests: gModules and their nodes get dedicated Elements,
-- anything else is treated as childless.
local function newctx()
  local context = nnq.Context()
  context:reg(nnq.NNGraphGModuleElement, nnq.NNGraphGModuleElement.isGmodule)
  context:reg(nnq.NNGraphNodeElement, nnq.NNGraphNodeElement.isNode)
  context:default(nnq.ChildlessElement)
  return context
end

function tests.Val()
  local identity = nn.Identity()
  local wrapped = newctx()(identity)
  tester:asserteq(wrapped:val(), identity, 'wrong val')
end

-- Helper function to generate one timestep of an LSTM:
-- (source: Oxford practical 6)
-- Modified for test purposes (e.g. is_i2h), rnn_size=1, etc.
local function create_lstm()
  local rnn_size = 1 -- for testing purposes
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- Builds i2h(x) + h2h(prev_h). Declared local so neither helper leaks into _G.
  local function new_input_sum()
    -- transforms input
    local i2h = nn.Linear(rnn_size, rnn_size)(x):annotate{is_i2h=true}
    -- transforms previous timestep's output
    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h):annotate{is_h2h=true}
    return nn.CAddTable()({i2h, h2h})
  end
  local in_gate = nn.Sigmoid()(new_input_sum())
  local forget_gate = nn.Sigmoid()(new_input_sum())
  local out_gate = nn.Sigmoid()(new_input_sum())
  local in_transform = nn.Tanh()(new_input_sum())

  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  -- Names nodes after the local variables above (e.g. 'forget_gate', 'next_c').
  nngraph.annotateNodes()
  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

--[[
Some basic operations on the above LSTM module.

Serves as documentation.
]]
function tests.LSTM()
  local ctx = newctx()
  local lstm = create_lstm()

  -- Get an Element for forget_gate: (twice)
  local forget_gate = ctx(lstm):descendants()
    :where(function(e) return e:val().data.annotations.name == 'forget_gate' end):only()
  local forget_gate_2 = ctx(lstm):descendants():attr{name='forget_gate'}:only()
  tester:asserteq(forget_gate:val(), forget_gate_2:val(), 'both ways should find same node')

  -- Only one parent: input_sum. In turn has two parents: i2h and h2h nn.Linear's
  local input_sum = forget_gate:parent() -- error if more than one, in which case use :parents()
  -- `#` on an ElementList object is not an element count (Lua 5.1/LuaJIT ignore
  -- __len on tables), so use :count().
  tester:asserteq(input_sum:parents():count(), 2, 'wrong number of parents to input_sum')
  -- These names i2h and h2h aren't automatically set by annotateNodes(), since in closure:
  local i2h, h2h = unpack(input_sum:parents():totable())
  tester:assert(i2h:val().data.annotations.is_i2h, 'is not i2h')
  tester:assert(h2h:val().data.annotations.is_h2h, 'is not h2h')

  -- Get the output nodes, verify they are next_c and next_h, resp.
  tester:asserteq(ctx(lstm):outputs():count(), 2, 'should have 2 outputs')
  tester:asserteq(ctx(lstm):outputs():first():name(), 'next_c')
  tester:asserteq(ctx(lstm):outputs():last():name(), 'next_h')
end

return tester:add(tests):run()

-------------------------------------------------------------------------------- /rocks/nnquery-scm-1.rockspec: --------------------------------------------------------------------------------
package = 'nnquery'
version = 'scm-1'

source = {
  -- GitHub no longer serves the unauthenticated git:// protocol.
  url = 'git+https://github.com/bshillingford/nnquery.git',
  branch = 'master'
}

description = {
  summary = 'Query complex graph structures in neural networks',
  detailed = 'Traverse complex neural network graph structures as easily as XPath or CSS',
  homepage = 'https://github.com/bshillingford/nnquery',
  license = 'BSD'
}

dependencies = {
  'lua >= 5.1',
  'torch',
  'classic'
}

build = {
  type = 'builtin',
  modules = {
    nnquery = 'nnquery/init.lua',
    ['nnquery.Element'] = 'nnquery/Element.lua',
    ['nnquery.ChildlessElement'] = 'nnquery/ChildlessElement.lua',
    ['nnquery.ContainerElement'] = 'nnquery/ContainerElement.lua',
    ['nnquery.Context'] = 'nnquery/Context.lua',
    ['nnquery.ElementList'] = 'nnquery/ElementList.lua',
    ['nnquery.ModuleElement'] = 'nnquery/ModuleElement.lua',
    ['nnquery.NNGraphGModuleElement'] = 'nnquery/NNGraphGModuleElement.lua',
    ['nnquery.NNGraphNodeElement'] = 'nnquery/NNGraphNodeElement.lua',
  }
}

--------------------------------------------------------------------------------