├── .gitignore ├── img │   ├── boring.png │   └── cool1.png ├── manifest ├── pretty-nn-scm-1.rockspec ├── README.md └── init.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw* 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /img/boring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atcold/torch-pretty-nn/master/img/boring.png -------------------------------------------------------------------------------- /img/cool1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atcold/torch-pretty-nn/master/img/cool1.png -------------------------------------------------------------------------------- /manifest: -------------------------------------------------------------------------------- 1 | commands = {} 2 | modules = {} 3 | repository = { 4 | ['pretty-nn'] = { 5 | ['scm-1'] = { 6 | { 7 | arch = "rockspec" 8 | } 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /pretty-nn-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = 'pretty-nn' 2 | version = 'scm-1' 3 | 4 | source = {url = 'git+https://github.com/Atcold/torch-pretty-nn'} 5 | 6 | description = { 7 | summary = 'Brings some colours to the boring nn', 8 | detailed = [[ 9 | It helps you make sense of your architecture when things start growing to a 10 | reasonable size.
11 | ]]
12 | }
13 | 
14 | dependencies = {
15 |    'torch >= 7.0',
16 |    'nn >= 1.0'
17 | }
18 | 
19 | build = {
20 |    type = 'builtin',
21 |    modules = {
22 |       ['pretty-nn.init'] = 'init.lua'
23 |    }
24 | }
25 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pretty-nn
2 | 
3 | The `pretty-nn` package brings some colour to the boring `nn` package of Torch.
4 | 
5 | ## Installation
6 | 
7 | ```
8 | luarocks install pretty-nn
9 | ```
10 | 
11 | ## Usage
12 | 
13 | Say we have the following model:
14 | 
15 | ```lua
16 | c = nn.Parallel(1,2)
17 | for i = 1, 3 do
18 |    local t = nn.Sequential()
19 |    t:add(nn.Linear(4, 2))
20 |    t:add(nn.Reshape(2, 1))
21 |    c:add(t)
22 | end
23 | split = nn.Concat(2)
24 | split:add(c)
25 | block = nn.Sequential()
26 | block:add(nn.View(-1))
27 | block:add(nn.Linear(12, 4))
28 | block:add(nn.View(2, 2))
29 | split:add(block)
30 | ```
31 | 
32 | If we print it on screen with `print(split)`, we get something like this:
33 | 
34 | ![Boring](img/boring.png)
35 | 
36 | If we `require 'pretty-nn'` and `print(split)` again, we'll get this:
37 | 
38 | ![Cool](img/cool1.png)
39 | 
40 | You can toggle the prettiness with `yourModelName:prettyPrint()`.
41 | Alternatively, you can explicitly *enable / disable* it with `yourModelName:prettyPrint(true / false)`.
42 | Note that `prettyPrint` is defined on `nn.Container` (so call it on a container such as `nn.Sequential`) and that the setting is global: it is stored in `nn.config.prettyPrint`, so it changes how *every* model is printed, not only `yourModelName`.
43 | 
--------------------------------------------------------------------------------
/init.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | 
3 | -- Shared configuration read by every printer in this file. Reuse an existing
4 | -- nn.config table, if another package already created one, instead of
5 | -- replacing it.
6 | nn.config = nn.config or {}
7 | nn.config.prettyPrint = true
8 | 
9 | -- Toggles (status == nil) or explicitly sets (status ~= nil) coloured printing.
10 | -- NOTE: the flag lives in the global nn.config table, so calling this on any
11 | -- container changes how every model is printed, not only `self`.
12 | function nn.Container:prettyPrint(status)
13 |    if status == nil then
14 |       nn.config.prettyPrint = not nn.config.prettyPrint
15 |    else
16 |       nn.config.prettyPrint = status
17 |    end
18 | end
19 | 
20 | --------------------------------------------------------------------------------
21 | -- Helpers
22 | --------------------------------------------------------------------------------
23 | 
24 | -- Builds a function that wraps a string in the ANSI colour `code` (e.g. '0;34')
25 | -- while nn.config.prettyPrint is on, and returns it untouched otherwise. The
26 | -- flag is checked on every call, so toggling takes effect on the next print.
27 | local function colourizer(code)
28 |    local open, close = '\27[' .. code .. 'm', '\27[0m'
29 |    return function(s)
30 |       if nn.config.prettyPrint then return open .. s .. close end
31 |       return s
32 |    end
33 | end
34 | 
35 | local blue  = colourizer('0;34')
36 | local red   = colourizer('0;31')
37 | local green = colourizer('0;32')
38 | 
39 | -- Returns the printed form of `module` with every line after the first
40 | -- prefixed by `prefix`, so nested containers render as an indented sub-tree.
41 | -- The outer parentheses keep only the string (gsub also returns a count).
42 | local function indentChild(module, prefix)
43 |    return (tostring(module):gsub('\n', '\n' .. prefix))
44 | end
45 | 
46 | -- Shared printer for branching containers (Concat, Parallel and their Table
47 | -- variants): each child is drawn as one branch hanging off `input`, and the
48 | -- branches are joined again at `... -> output`. All connector strings have the
49 | -- same width so nested sub-trees line up under their branch.
50 | local function branchingToString(self, colour)
51 |    local tab, line = '  ', '\n'
52 |    local branch     = colour '  |`-> '
53 |    local lastBranch = colour '   `-> '
54 |    local cont       = colour '  |    '  -- keeps the vertical bar below a non-last branch
55 |    local lastCont   = '       '         -- nothing to connect below the last branch
56 |    local join       = colour '   ... -> '
57 |    local n = #self.modules
58 |    local out = { colour(torch.type(self)), colour ' {', line, tab, colour 'input' }
59 |    for i = 1, n do
60 |       local isLast = (i == n)
61 |       out[#out + 1] = line .. tab .. (isLast and lastBranch or branch)
62 |          .. colour '(' .. i .. colour '): '
63 |          .. indentChild(self.modules[i], tab .. (isLast and lastCont or cont))
64 |    end
65 |    out[#out + 1] = line .. tab .. join .. colour 'output'
66 |    out[#out + 1] = line .. colour '}'
67 |    return table.concat(out)
68 | end
69 | 
70 | --------------------------------------------------------------------------------
71 | -- Sequential
72 | --------------------------------------------------------------------------------
73 | 
74 | -- Prints a one-line data-flow summary, [input -> (1) -> ... -> output],
75 | -- followed by every child module on its own indented line. Drawn in blue.
76 | function nn.Sequential:__tostring__()
77 |    local tab, line = '  ', '\n'
78 |    local arrow = blue ' -> '
79 |    local n = #self.modules
80 |    local out = { blue 'nn.Sequential', blue ' {', line, tab, blue '[input' }
81 |    for i = 1, n do
82 |       out[#out + 1] = arrow .. blue '(' .. i .. blue ')'
83 |    end
84 |    out[#out + 1] = arrow .. blue 'output]'
85 |    for i = 1, n do
86 |       out[#out + 1] = line .. tab .. blue '(' .. i .. blue '): '
87 |          .. indentChild(self.modules[i], tab)
88 |    end
89 |    out[#out + 1] = line .. blue '}'
90 |    return table.concat(out)
91 | end
92 | 
93 | --------------------------------------------------------------------------------
94 | -- Concat
95 | --------------------------------------------------------------------------------
96 | 
97 | -- Drawn in red; ConcatTable shares the same printer.
98 | function nn.Concat:__tostring__()
99 |    return branchingToString(self, red)
100 | end
101 | 
102 | nn.ConcatTable.__tostring__ = nn.Concat.__tostring__
103 | 
104 | --------------------------------------------------------------------------------
105 | -- Parallel
106 | --------------------------------------------------------------------------------
107 | 
108 | -- Drawn in green; ParallelTable shares the same printer.
109 | function nn.Parallel:__tostring__()
110 |    return branchingToString(self, green)
111 | end
112 | 
113 | nn.ParallelTable.__tostring__ = nn.Parallel.__tostring__
114 | 
--------------------------------------------------------------------------------