├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── instance ├── Instance.lua └── Model.lua ├── interfaces ├── Data.lua ├── Interface.lua └── Interfaces.lua ├── layers ├── EpisodeBarrier.lua ├── Index.lua ├── Loss.lua └── Move.lua ├── main.lua ├── play.py └── utils ├── Acc.lua ├── Curriculum.lua ├── DataGenerator.lua ├── Game.lua ├── Visualizer.lua ├── base.lua ├── colors.lua ├── models.lua ├── nngraph.lua └── q_learning.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | movie/* 3 | # Compiled Lua sources 4 | luac.out 5 | *.swp 6 | 7 | # luarocks build files 8 | *.src.rock 9 | *.zip 10 | *.tar.gz 11 | 12 | # Object files 13 | *.o 14 | *.os 15 | *.ko 16 | *.obj 17 | *.elf 18 | 19 | # Precompiled Headers 20 | *.gch 21 | *.pch 22 | 23 | # Libraries 24 | *.lib 25 | *.a 26 | *.la 27 | *.lo 28 | *.def 29 | *.exp 30 | 31 | # Shared objects (inc. Windows DLLs) 32 | *.dll 33 | *.so 34 | *.so.* 35 | *.dylib 36 | 37 | # Executables 38 | *.exe 39 | *.out 40 | *.app 41 | *.i*86 42 | *.x86_64 43 | *.hex 44 | 45 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Learning Simple Algorithms from Examples 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | This project is developed internally at Facebook inside a private repository. 7 | Changes are periodically pushed to the open-source branch. Pull requests are 8 | integrated manually into our private repository first, and they then get 9 | propagated to the public repository with the next push. 10 | 11 | ## Pull Requests 12 | We actively welcome your pull requests. 13 | 14 | 1. Fork the repo and create your branch from `master`. 15 | 2. If you've added code that should be tested, add tests. 16 | 3. 
If you've changed APIs, update the documentation. 17 | 4. Ensure the test suite passes. 18 | 5. Make sure your code lints. 19 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 20 | 21 | ## Contributor License Agreement ("CLA") 22 | In order to accept your pull request, we need you to submit a CLA. You only need 23 | to do this once to work on any of Facebook's open source projects. 24 | 25 | Complete your CLA here: 26 | 27 | ## Issues 28 | We use GitHub issues to track public bugs. Please ensure your description is 29 | clear and has sufficient instructions to be able to reproduce the issue. 30 | 31 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 32 | disclosure of security bugs. In those cases, please go through the process 33 | outlined on that page and do not file a public issue. 34 | 35 | ## Coding Style 36 | 37 | ### C++ 38 | * 2 spaces for indentation rather than tabs 39 | * 80 character line length 40 | * Name classes LikeThis, functions and methods likeThis, data members 41 | likeThis_. 42 | * Most naming and formatting recommendations from 43 | [Google's C++ Coding Style Guide]( 44 | http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml) apply (but 45 | not the restrictions; exceptions and templates are fine.) 
46 | * Feel free to use [boost](http://www.boost.org/), 47 | [folly](https://github.com/facebook/folly) and 48 | [fbthrift](https://github.com/facebook/fbthrift) 49 | 50 | ### Lua 51 | * Inspired by [PEP 8](http://legacy.python.org/dev/peps/pep-0008/) 52 | * 4 spaces for indentation rather than tabs 53 | * 80 character line length 54 | * Name classes LikeThis, functions, methods, and variables like_this, private 55 | methods _like_this 56 | * Use [Penlight](http://stevedonovan.github.io/Penlight/api/index.html); 57 | specifically pl.class for OOP 58 | * Do not use global variables (except with a very good reason) 59 | * Use [new-style modules](http://lua-users.org/wiki/ModulesTutorial); do not 60 | use the module() function 61 | * Assume [LuaJIT 2.0+](http://luajit.org/), so Lua 5.1 code with LuaJIT's 62 | supported [extensions](http://luajit.org/extensions.html); 63 | [FFI](http://luajit.org/ext_ffi.html) is okay. 64 | 65 | ## License 66 | By contributing to ``Learning Algorithms from Examples'', you agree that your contributions will be licensed 67 | under its Apache 2. 68 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For learningSimpleAlgorithms software 4 | 5 | Copyright (c) 2015-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 
16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the learningSimpleAlgorithms software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Learning Simple Algorithms from Examples 2 | ======================================== 3 | 4 | This is a framework to learn simple algorithms such as 5 | copying, multi-digit addition and single digit multiplication 6 | directly from examples. Our framework consists of a set of 7 | interfaces, accessed by a controller. Typical 8 | interfaces are 1-D tapes or 2-D grids that hold the input and output 9 | data. 
10 | The paper can be found at: http://arxiv.org/abs/1511.07275 .
11 | Moreover, the accompanying video https://www.youtube.com/watch?v=GVe6kfJnRAw gives a concise overview of our approach. 12 | 13 | 14 | This software runs in Torch. Type
15 | 16 | `th main.lua` 17 | 18 | to train the model for the addition task. 19 | 20 | The model generates traces of the intermediate solutions while training (in 21 | directory ./movie/). They can be displayed by calling: 22 | 23 | `python play.py` 24 | 25 | -------------------------------------------------------------------------------- /instance/Instance.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- It contains all the information about a single time-step (i.e. gradients 12 | -- of the RNN, actions taken etc.). 13 | local Ins = torch.class('Instance') 14 | 15 | function Ins:__init(time) 16 | self.time = time 17 | assert(type(self.time) == "number") 18 | self:copy_from_root() 19 | local function assign(from, to) 20 | for k, v in pairs(from) do 21 | if torch.isTensor(v) then 22 | to[k] = v[self.time + 1] 23 | elseif type(v) == "table" then 24 | to[k] = {} 25 | assign(v, to[k]) 26 | end 27 | end 28 | end 29 | assign(model.fields, self) 30 | end 31 | 32 | -- Return previous instance in time. 33 | function Ins:prev() 34 | return model(self.time - 1) 35 | end 36 | 37 | function Ins:copy_from_root() 38 | if model.core_network == nil then 39 | model.core_network = create_network() 40 | -- Network is created only once. Other instances contain reference 41 | -- to the same weights. 
42 | paramx, paramdx = model.core_network:getParameters() 43 | paramx:mul(2) 44 | end 45 | self.rnn = create_network() 46 | for i, node in pairs(self.rnn.forwardnodes) do 47 | if node.data.module then 48 | local to = node.data.module 49 | local from = model.core_network.forwardnodes[i].data.module 50 | if to.weight then 51 | to.weight = from.weight 52 | to.gradWeight = from.gradWeight 53 | to.bias = from.bias 54 | to.gradBias = from.gradBias 55 | end 56 | end 57 | end 58 | end 59 | 60 | -- Return next instance in time. 61 | function Ins:child() 62 | return model(self.time + 1) 63 | end 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /instance/Model.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Defines the model, its forward and backward pass. Contains model parameters. 
12 | local Model = torch.class('Model') 13 | 14 | function Model:__init() 15 | self.step = 1 16 | self.trained_chars = 0 17 | self.instances = {} 18 | self.fields = {} 19 | self.samples = {} 20 | self.norm_dw = 0 21 | self.q_size = 1 22 | self.ring = 5 23 | for _, s in pairs(interfaces.name2desc) do 24 | self.q_size = self.q_size * s.size 25 | end 26 | self:clean() 27 | end 28 | 29 | function Model:clean() 30 | local f = self.fields 31 | if self.fields.s ~= nil then 32 | local function zero(obj) 33 | for k, _ in pairs(obj) do 34 | if torch.isTensor(obj[k]) then 35 | obj[k]:zero() 36 | elseif type(obj[k]) == "table" then 37 | zero(obj[k]) 38 | end 39 | end 40 | end 41 | zero(self.fields) 42 | for name, desc in pairs(interfaces.name2desc) do 43 | if desc.size > 0 then 44 | f.actions[name].action:zero():add(desc.size) 45 | end 46 | end 47 | return 48 | end 49 | local tl = params.max_seq_length + 2 50 | local bs = params.batch_size 51 | local rs = params.rnn_size 52 | local ks = params.key_size 53 | local ms = params.memory_size 54 | local quant = params.layers 55 | -- LSTM has twice as many units (half for hidden states, half for cells). 56 | if params.unit == "lstm" then 57 | quant = quant * 2 58 | end 59 | f.s = torch.zeros(tl, quant, bs, rs) -- State. 60 | f.ds = f.s:clone() -- Derivative of state. 61 | f.correct = torch.zeros(tl, bs) -- If prediction is correct. 62 | f.dq = torch.zeros(tl, bs, self.q_size) -- Derivative with respect to q-function. 63 | f.err = torch.zeros(tl, bs) -- Error. 64 | f.q = torch.zeros(tl, bs, self.q_size) -- Q-value. 65 | f.logits = torch.zeros(tl, bs, self.q_size) 66 | f.chosen = torch.zeros(tl, bs) -- Which action have been choosen 67 | f.sampled = torch.zeros(tl, bs) -- If action was sampled. 68 | f.target_idx = torch.zeros(tl, bs) 69 | f.max_idx = torch.zeros(tl, bs) -- Which Q-function index yields the highest value. 
70 | f.data = {} 71 | f.tape = {} 72 | f.total = {} 73 | f.actions = {} 74 | f.d = {} 75 | self:register_empty("data", {"begin", "finish", "time", "sampled", "x", "y", "pred", "task"}) 76 | 77 | f.actions = {} 78 | for name, desc in pairs(interfaces.name2desc) do 79 | f.actions[name] = {} 80 | f.actions[name] = {} 81 | f.actions[name].action = torch.zeros(tl, bs):add(desc.size) 82 | f.actions[name].max = torch.zeros(tl, bs) 83 | end 84 | f.tape.idx = torch.zeros(tl, bs) 85 | end 86 | 87 | function Model:register_empty(name, fields) 88 | if self.fields[name] == nil then 89 | self.fields[name] = {} 90 | end 91 | for _, f in pairs(fields) do 92 | self.fields[name][f] = torch.zeros(params.max_seq_length + 2, params.batch_size) 93 | end 94 | end 95 | 96 | -- model(t) gives an access to the model instantiation ot time t. 97 | function Model:__call(time) 98 | assert(type(time) == "number") 99 | local offset = 5 100 | if params.train == 2 or time < offset then 101 | if model.instances[time] == nil then 102 | model.instances[time] = Instance(time) 103 | end 104 | return model.instances[time] 105 | else 106 | -- While testing on the very long sequences, we have to 107 | -- reuse previous time instances (to not run out of memory). 108 | for i = -1, 1 do 109 | local idx = (time - offset + i) % self.ring + offset 110 | if model.instances[idx] == nil then 111 | model.instances[idx] = Instance(idx) 112 | end 113 | model.instances[idx].time = time + i 114 | end 115 | return model.instances[(time - offset) % self.ring + offset] 116 | end 117 | end 118 | 119 | function Model:reboot() 120 | g_make_deterministic(params.seed) 121 | self.root = model(0) 122 | paramdx:zero() 123 | for _, interface in pairs(interfaces.interfaces) do 124 | interface:clean() 125 | end 126 | interfaces.data:clean() 127 | model:clean() 128 | acc:clean() 129 | end 130 | 131 | -- Computes rewards, and their derivatives. 132 | function Model:rewards() 133 | -- If during test, then exit (params.train \in {1, 2}). 
134 | if params.train == 1 then 135 | return 136 | end 137 | local err = 0 138 | for batch = 1, params.batch_size do 139 | for i = 1, params.seq_length - 1 do 140 | local ins = model(i) 141 | local sample = ins.data.samples[batch] 142 | if sample:eos() and sample.train == 2 then 143 | local diff = _G[params.q_type](ins, batch) 144 | assert(ins.dq[batch]:norm() == 0) 145 | if acc:get_current_acc(sample.complexity) > 0.9 then 146 | local q_decay_lr = params.q_decay_lr or 0 147 | ins.dq[batch]:add(q_decay_lr * (ins.q[batch]:sum() - 1)) 148 | for j = 1, ins.dq:size(2) do 149 | local q = ins.q[batch][j] 150 | if q <= 0 then 151 | ins.dq[batch][j] = ins.dq[batch][j] + q_decay_lr * q 152 | elseif q >= 1 then 153 | ins.dq[batch][j] = ins.dq[batch][j] - q_decay_lr * (q - 1) 154 | end 155 | end 156 | end 157 | local chosen = ins.chosen[batch] 158 | ins.dq[batch][chosen] = ins.dq[batch][chosen] + params.q_lr * diff 159 | err = err + params.q_lr * (diff * diff) / 2 160 | end 161 | end 162 | end 163 | model.err = model.err + err / params.batch_size 164 | end 165 | 166 | -- Forward propagation. 167 | function Model:fp() 168 | -- Sequence length to unroll is chosen dynamically depending on the current complexity. 169 | params.seq_length = 1 170 | model:clean() 171 | interfaces.data:clean() 172 | model.err = 0 173 | local time_offset = 1 174 | while true do 175 | local ins = model(time_offset) 176 | local s = ins.rnn:forward({ins, ins.s})[2] 177 | local sample = ins.data.samples[1] 178 | tensors_copy(model(time_offset + 1).s, s) 179 | model.err = model.err + ins.err:mean() 180 | if params.train == 1 then 181 | io.write(tostring(time_offset) .. 
" ") 182 | io.flush() 183 | end 184 | time_offset = time_offset + 1 185 | if (params.train == 1 and time_offset > params.test_len) or 186 | (params.train == 2 and time_offset > params.seq_length) then 187 | break 188 | end 189 | end 190 | self:rewards() 191 | visualizer:visualize() 192 | self.samples = {} 193 | end 194 | 195 | -- Backpropagation. 196 | function Model:bp() 197 | -- Don't backpropagate with respect to the test data. 198 | if params.train == 1 then 199 | return 200 | end 201 | paramdx:zero() 202 | for time_offset = params.seq_length, 1, -1 do 203 | local ins = model(time_offset) 204 | assert(ins.s ~= nil and ins.ds ~= nil) 205 | local ds = ins.rnn:backward({ins, ins.s}, {torch.zeros(1), ins.ds})[2] 206 | tensors_copy(model(time_offset - 1).ds, ds) 207 | end 208 | paramdx:div(params.batch_size * params.seq_length / 10) 209 | collectgarbage() 210 | end 211 | 212 | -------------------------------------------------------------------------------- /interfaces/Data.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Interface to access input tape (or grid), and output tape. 12 | local Data, parent = torch.class('Data', 'Interface') 13 | 14 | function Data:__init() 15 | assert(type(params.ntasks) == "number") 16 | print(params) 17 | parent.__init(self, {x = {size=params.vocab_size}, -- 18 | pred = {size=params.vocab_size}, -- this are inputs to the interface 19 | sampled = {size=2}}, -- 20 | {xa = {size=math.pow(2, params.dim), now=1}, -- where to move. 
21 | ya = {size=2, now=2}}) -- do prediction or not 22 | 23 | end 24 | 25 | 26 | function Data:view(ins_node) 27 | local function ident(ins, name) 28 | local ret = nil 29 | local delay = 0 30 | if ins.time - delay < 1 then 31 | ret = ins.data[name]:clone() 32 | ret:zero():add(self.view_type[name].size) 33 | else 34 | local past = model(ins.time - delay) 35 | ret = past.data[name]:clone() 36 | if delay > 0 then 37 | for i = 1, params.batch_size do 38 | if past.data.samples[i] ~= ins.data.samples[i] then 39 | ret[i] = self.view_type[name].size 40 | end 41 | end 42 | end 43 | end 44 | return ret 45 | end 46 | return self:emb(ins_node, ident, self.view_type) 47 | end 48 | 49 | function Data:apply(ins_node) 50 | local function fp(self, ins) 51 | local xa = ins.actions.xa.action 52 | for i = 1, params.batch_size do 53 | local sample = ins.data.samples[i] 54 | sample.dir = Game.dirs[sample.dim][xa[i]] 55 | end 56 | return ins 57 | end 58 | return self:new_node(ident, fp, empty_bp, {ins_node}) 59 | end 60 | 61 | function Data:clean() 62 | local child = model.root:child() 63 | child.data.begin:fill(1) 64 | child.data.finish:fill(1):mul(1 / 0) 65 | child.data.time:fill(1) 66 | child.data.pred:fill(params.vocab_size) 67 | child.data.x:fill(1) 68 | child.data.y:fill(1) 69 | child.data.sampled:fill(1) 70 | child.data.samples = {} 71 | for i = 1, params.batch_size do 72 | local sample = curriculum:generateNewSample() 73 | sample.ins = child.time 74 | sample.idx = i 75 | child.data.x[i] = sample:current_input() 76 | child.data.y[i] = sample:current_target() 77 | child.data.samples[i] = sample 78 | table.insert(model.samples, sample) 79 | end 80 | collectgarbage() 81 | end 82 | 83 | function Data:target(ins_node) 84 | local function fp(self, ins) 85 | self.output = zeros() 86 | for idx = 1, params.batch_size do 87 | local sample = ins.data.samples[idx] 88 | if ins.actions.ya.action[idx] == 2 then 89 | local sample = ins.data.samples[idx] 90 | self.output[idx] = ins.data.y[idx] 91 
| end 92 | end 93 | return self.output 94 | end 95 | local target = self:new_node(ident, fp, empty_bp, {ins_node}) 96 | return target 97 | end 98 | 99 | -------------------------------------------------------------------------------- /interfaces/Interface.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- An abstract Interface class. Every interface (i.e. tape, grid, memory) 12 | -- has to inherit from this class. 13 | local Interface = torch.class('Interface') 14 | 15 | function Interface:__init(view_type, actions_type) 16 | self.view_type = view_type 17 | self.actions_type = actions_type 18 | self.name = self.__typename:lower() 19 | Interface.emb = emb 20 | Interface.new_node = new_node 21 | end 22 | 23 | -- Returns the last action from the given Interface. 
24 | function Interface:last_action(ins_node) 25 | local function action(ins, name) 26 | local ret = nil 27 | local delay = 1 28 | if ins.time - delay < 1 then 29 | ret = ins.actions[name].action:clone() 30 | ret:zero():add(self.actions_type[name].size + 1) 31 | else 32 | local past = model(ins.time - delay) 33 | ret = past.actions[name].action:clone() 34 | for i = 1, params.batch_size do 35 | if past.data.samples[i] ~= ins.data.samples[i] then 36 | ret[i] = self.actions_type[name].size + 1 37 | end 38 | end 39 | end 40 | return ret 41 | end 42 | local last_action = emb(self, ins_node, action, self.actions_type) 43 | return last_action 44 | end 45 | 46 | function Interface:__tostring__() 47 | return torch.type(self) 48 | end 49 | -------------------------------------------------------------------------------- /interfaces/Interfaces.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | 11 | local Interfaces = torch.class('Interfaces') 12 | 13 | local unpack = unpack and unpack or table.unpack 14 | 15 | function Interfaces:__init() 16 | self.data = Data() 17 | self.interfaces = {self.data} 18 | self.name2desc = {} 19 | for _, interface in pairs(self.interfaces) do 20 | for name, desc in pairs(interface.actions_type) do 21 | self.name2desc[name] = desc 22 | end 23 | end 24 | self.action_type = {} 25 | for k, v in pairs(self.data.actions_type) do 26 | self.action_type[k] = v 27 | end 28 | if self.tape then 29 | for k, v in pairs(self.tape.actions_type) do 30 | self.action_type[k] = v 31 | end 32 | end 33 | Interfaces.new_node = new_node 34 | end 35 | 36 | function Interfaces:apply(ins_node) 37 | for _, interface in pairs(self.interfaces) do 38 | ins_node = interface:apply(ins_node) 39 | end 40 | return ins_node 41 | end 42 | 43 | function Interfaces:set_sizes() 44 | params.input_size = 0 45 | for _, interface in pairs(self.interfaces) do 46 | for name, desc in pairs(interface.view_type) do 47 | params.input_size = params.input_size + (desc.size + 1) 48 | end 49 | for name, desc in pairs(interface.actions_type) do 50 | params.input_size = params.input_size + (desc.size + 1) 51 | end 52 | end 53 | end 54 | 55 | function Interfaces:q_learning(ins_node, h) 56 | local logits = nn.Linear(params.rnn_size, model.q_size)(nn.Tanh()(h)) 57 | ins_node = self:q_logits(ins_node, logits) 58 | ins_node = self:q_sample(ins_node) 59 | return ins_node 60 | end 61 | 62 | function Interfaces:decode_q(chosen) 63 | local ret = {} 64 | for k, v in pairs(self.action_type) do 65 | ret[k] = (chosen - 1) % v.size + 1 66 | chosen = ((chosen - 1) - (ret[k] - 1)) / v.size + 1 67 | end 68 | return ret 69 | end 70 | 71 | function Interfaces:encode_q(vals) 72 | local ret = 0 73 | local order = {} 74 | local actions = self.action_type 75 | for k, v in pairs(actions) do 76 | table.insert(order, k) 77 | end 78 | for i = 1, #order do 79 | local k = order[#order - i + 
1] 80 | local v = actions[k] 81 | ret = ret * v.size 82 | ret = ret + (vals[k] - 1) 83 | end 84 | return ret + 1 85 | end 86 | 87 | function Interfaces:q_logits(ins_node, logits) 88 | local function fp(self, input) 89 | local ins, logits = unpack(input) 90 | for i = 1, params.batch_size do 91 | local sample = ins.data.samples[i] 92 | ins.target_idx[i] = sample.target_idx 93 | ins.q[i]:copy(logits[i]) 94 | ins.logits[i]:copy(logits[i]) 95 | end 96 | return ins 97 | end 98 | local function bp(self, input, gradOutput) 99 | local ins, logits, logits_dropped = unpack(input) 100 | local dlogits = logits:clone():zero() 101 | for i = 1, params.batch_size do 102 | local sample = ins.data.samples[i] 103 | local mul = sample.normal - ins.target_idx[i] 104 | dlogits[i]:copy(ins.dq[i]) 105 | if sample.train == 1 then 106 | assert(ins.dq[i]:norm() == 0) 107 | end 108 | end 109 | self.gradInput = {torch.zeros(1), dlogits} 110 | return self.gradInput 111 | end 112 | return self:new_node(ident, fp, bp, {ins_node, logits}) 113 | end 114 | 115 | function Interfaces:q_sample(ins_node, logits) 116 | local function sample_actions(sample) 117 | assert(sample.train == 2) 118 | local acc_current = acc:get_current_acc(sample.complexity) 119 | local acc_ranges = {0, 0.9, 1} 120 | -- Probability of choosing a random action. 
121 | local random = {20, 20} 122 | if params.decay_expr ~= nil then 123 | random = {20, -1} 124 | end 125 | for i = 1, #acc_ranges - 1 do 126 | if acc_current >= acc_ranges[i] and 127 | acc_current < acc_ranges[i + 1] and 128 | math.random(random[i]) == 1 then 129 | local a = {} 130 | for k, v in pairs(self.action_type) do 131 | a[k] = math.random(v.size) 132 | end 133 | return a 134 | end 135 | end 136 | return nil 137 | end 138 | 139 | local function fp(self, ins) 140 | local max_idx = ins.max_idx 141 | local chosen = ins.chosen 142 | ins:child().data.sampled:copy(ins.data.sampled) 143 | for i = 1, params.batch_size do 144 | local sample = ins.data.samples[i] 145 | max_idx[i] = argmax(ins.q[i]) 146 | chosen[i] = max_idx[i] 147 | local a = self:decode_q(chosen[i]) 148 | if sample.train == 2 then 149 | a = sample_actions(sample) or a 150 | end 151 | 152 | if self:encode_q(a) ~= max_idx[i] then 153 | ins.sampled[i] = 1 154 | if sample.sampled == nil then 155 | sample.sampled = ins.time 156 | ins:child().data.sampled[i] = 2 157 | end 158 | end 159 | chosen[i] = self:encode_q(a) 160 | assert(chosen[i] >= 1 and chosen[i] <= model.q_size) 161 | for k, v in pairs(self.action_type) do 162 | ins.actions[k].action[i] = a[k] 163 | end 164 | end 165 | return ins 166 | end 167 | return self:new_node(ident, fp, empty_bp, {ins_node}) 168 | end 169 | -------------------------------------------------------------------------------- /layers/EpisodeBarrier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Cleans state between samples (i.e. when the episode is over). 
12 | -- It makes sure that nothing leaks during forward, as well as the backward pass. 13 | local EpisodeBarrier, parent = torch.class('nn.EpisodeBarrier', 'nn.Module') 14 | 15 | function EpisodeBarrier:__init(current) 16 | self.current = current 17 | self.buffers = {} 18 | end 19 | 20 | function EpisodeBarrier:clone_with_barrier(input, time, idx) 21 | if torch.isTensor(input) then 22 | if input:nDimension() == 2 then 23 | if self.buffers[idx] == nil then 24 | self.buffers[idx] = input:clone() 25 | end 26 | self.buffers[idx]:copy(input) 27 | local output = self.buffers[idx] 28 | idx = idx + 1 29 | assert(output:size(1) == time:size(1)) 30 | assert(output:size(1) == params.batch_size) 31 | for i = 1, params.batch_size do 32 | if time[i] == 1 then 33 | output[i]:zero() 34 | end 35 | end 36 | return output 37 | else 38 | local ret = {} 39 | for k = 1, input:size(1) do 40 | idx = idx + 1 41 | ret[k] = self:clone_with_barrier(input[k], time, idx) 42 | end 43 | return ret 44 | end 45 | elseif type(input) == "table" then 46 | local ret = {} 47 | for k, v in pairs(input) do 48 | idx = idx + 1 49 | ret[k] = self:clone_with_barrier(v, time, idx) 50 | end 51 | return ret 52 | else 53 | assert(false) 54 | end 55 | end 56 | 57 | local unpack = unpack and unpack or table.unpack 58 | 59 | function EpisodeBarrier:updateOutput(input) 60 | local ins, input = unpack(input) 61 | return self:clone_with_barrier(input, ins.data.time, 1) 62 | end 63 | 64 | function EpisodeBarrier:updateGradInput(input, gradOutput) 65 | local ins, s = unpack(input) 66 | self.gradInput = {torch.zeros(1), self:clone_with_barrier(gradOutput, ins.data.time, 100)} 67 | return self.gradInput 68 | end 69 | -------------------------------------------------------------------------------- /layers/Index.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Encodes integers in dummy vector representation. 12 | -- i.e. 4 would be encoded as [0, 0, 0, 1, 0, 0 ...] 13 | local Index, parent = torch.class('nn.OneHotIndex', 'nn.Module') 14 | 15 | function Index:__init(inputSize, name) 16 | parent.__init(self) 17 | assert(type(inputSize) == "number") 18 | self.inputSize = inputSize + 1 19 | self.name = name 20 | end 21 | 22 | function Index:updateOutput(input) 23 | self.output:resize(input:size(1), self.inputSize):zero() 24 | for i = 1, input:size(1) do 25 | assert(input[i] >= 1 and input[i] <= self.inputSize) 26 | self.output[i][input[i]] = 1 27 | end 28 | return self.output 29 | end 30 | 31 | function Index:updateGradInput(input, gradOutput) 32 | if self.gradInput then 33 | self.gradInput:resize(input:size()) 34 | return self.gradInput 35 | end 36 | end 37 | 38 | Index.sharedAccUpdateGradParameters = Index.accUpdateGradParameters 39 | -------------------------------------------------------------------------------- /layers/Loss.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Computes cross entropy loss, and reward (it's stored in correct). 
-- Cross-entropy loss over the model's output distribution; also records,
-- per batch slot, whether the prediction was correct (the reward signal).
local Loss, parent = torch.class('nn.Loss', 'nn.Module')

local unpack = unpack and unpack or table.unpack

-- Index of the (first) maximal entry of a 1-d log-probability vector,
-- with a small tolerance for float noise.
local function argmax_index(row)
  local best = row:max()
  for k = 1, row:size(1) do
    if math.abs(best - row[k]) < 1e-8 then
      return k
    end
  end
end

function Loss:__init()
  parent.__init(self)
  self.dlogprob = torch.Tensor()
  -- zeros() is a project helper (presumably utils/base.lua -- TODO confirm).
  self.dtarget = zeros()
  self.dins = torch.zeros(1)
end

function Loss:updateOutput(input)
  local ins, logprob, target = unpack(input)
  -- Advance the child's per-sample clock by one step.
  ins:child().data.time:copy(ins.data.time):add(1)
  ins:child().data.begin:copy(ins.data.begin)
  local correct = ins.correct:zero()
  -- +inf marks "episode not finished yet" (see Move / Curriculum).
  ins:child().data.finish:zero():add(1 / 0)
  local err = ins.err:zero()
  local pred = ins:child().data.pred:zero()
  for i = 1, params.batch_size do
    local sample = ins.data.samples[i]
    pred[i] = params.vocab_size
    if not sample:eos() then
      pred[i] = argmax_index(logprob[i]) or params.vocab_size
      assert(pred[i] ~= 0)
      if target[i] ~= 0 then
        correct[i] = sample:predict(pred[i], logprob[i]) and 1 or 0
        err[i] = -logprob[i][target[i]]
      else
        -- No target at this step: emit the end-of-vocabulary symbol.
        pred[i] = params.vocab_size
      end
      if model.step % 20 == 1 and params.train == 2 then
        assert(#sample.strings + ins.data.begin[i] == ins.time)
        -- Saves trace of an executions. 
        table.insert(sample.strings, tostring(sample))
      end
    end
  end
  return ins
end

function Loss:updateGradInput(input, gradOutput)
  local ins, logprob, target = unpack(input)
  self.dlogprob:resizeAs(logprob):zero()
  for i = 1, target:size(1) do
    local slot_trains = (ins.data.samples[i].train == 2)
    if slot_trains or params.gc == 1 then
      if target[i] ~= 0 then
        -- d(-logprob[target]) / d(logprob) is -1 at the target index.
        self.dlogprob[i][target[i]] = -1
      end
      model.trained_chars = model.trained_chars + 1
    end
  end
  self.gradInput = {self.dins, self.dlogprob, self.dtarget}
  return self.gradInput
end
-------------------------------------------------------------------------------- /layers/Move.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Moves over input tape, or input grid. 
-- Advances every sample one step over its tape/grid; when a sample hits the
-- end of its episode, its accuracy is recorded and a freshly generated sample
-- takes over the batch slot.
local Move, parent = torch.class('nn.Move', 'nn.Module')

function Move:__init()
  parent.__init(self)
  self.dins = torch.zeros(1)
end

function Move:updateOutput(input)
  local ins = input
  local child = ins:child()
  local next_time = child.data.time
  local next_pred = child.data.pred
  local next_sampled = child.data.sampled
  local xs = child.data.x:zero():add(1)
  local ys = child.data.y:zero():add(1)
  local next_samples = {}
  for i = 1, params.batch_size do
    local current = ins.data.samples[i]
    next_samples[i] = current
    current:move()
    if current:eos() then
      acc:record(current)
      next_time[i] = 1
      if params.train == 2 then
        child.data.begin[i] = child.time
        -- Stamp every step of the episode that just ended with its
        -- finish time (they were initialized to +inf).
        for j = math.max(ins.data.begin[i], 1), ins.time do
          assert(model(j).data.finish[i] == 1 / 0)
          model(j).data.finish[i] = ins.time
        end
      end
      -- Replace the finished sample with a fresh one from the curriculum.
      local fresh = curriculum:generateNewSample()
      table.insert(model.samples, fresh)
      fresh.ins = child.time
      fresh.idx = i
      next_pred[i] = params.vocab_size
      next_samples[i] = fresh
      next_sampled[i] = 1
    end
    xs[i] = next_samples[i]:current_input()
    ys[i] = next_samples[i]:current_target()
  end
  child.data.samples = next_samples
  return ins
end

function Move:updateGradInput(input, gradOutput)
  -- Nothing differentiable flows through this layer.
  self.gradInput = self.dins
  return self.gradInput
end
-------------------------------------------------------------------------------- /main.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | ]]-- 9 | 10 | require('nn') 11 | require('torch') 12 | require('nngraph') 13 | require('optim') 14 | include("instance/Model.lua") 15 | include("instance/Instance.lua") 16 | include("layers/Index.lua") 17 | include("layers/Loss.lua") 18 | include("layers/Move.lua") 19 | include("layers/EpisodeBarrier.lua") 20 | include("interfaces/Interface.lua") 21 | include("interfaces/Interfaces.lua") 22 | include("interfaces/Data.lua") 23 | include("utils/base.lua") 24 | include("utils/Acc.lua") 25 | include("utils/nngraph.lua") 26 | include("utils/models.lua") 27 | include("utils/Visualizer.lua") 28 | include("utils/DataGenerator.lua") 29 | include("utils/Curriculum.lua") 30 | include("utils/Game.lua") 31 | include("utils/colors.lua") 32 | include("utils/q_learning.lua") 33 | 34 | lapp = require 'pl.lapp' 35 | 36 | params = lapp[[ 37 | --batch_size (default 20) 38 | --task (default "addition") copy | reverse | walk | addition | addition3 | single_mul 39 | --seed (default 1) random initialization seed 40 | --q_type (default "q_watkins") q_classic | q_watkins 41 | --q_discount (default -1) -1 (dynamic discount) | 0.95 (discount of 0.95) | 1 (no discount) 42 | --q_lr (default 0.1) learning rate over q-function 43 | --unit (default "gru") feedforward | lstm | gru 44 | --test_len (default 200) complexity of the test instances 45 | --max_seq_length (default 50) maximum complexity of the training instances 46 | --layers (default 1) number of layers 47 | --rnn_size (default 200) number of hidden units 48 | --lr (default 0.1) learning rate 49 | --max_grad_norm (default 5) clipping of gradient norm 50 | ]] 51 | 52 | if params.task == "reverse" or 53 | params.task == "ident" then 54 | params.dim = 1 55 | else 56 | params.dim = 2 57 | end 58 | 59 | function create_network() 60 | interfaces:set_sizes() 61 | g_make_deterministic(params.seed) 62 | local ins_node_org = nn.Identity()() 63 | local prev_s = nn.Identity()() 64 | local s = prev_s 65 | local ins_node = ins_node_org 66 | -- 
Ensures that state doesn't leak between consecutive samples. 67 | s = nn.EpisodeBarrier()({ins_node, s}) 68 | local embs = {} 69 | -- Input from the all interfaces. 70 | for name, interface in pairs(interfaces.interfaces) do 71 | embs = merge(embs, interface:last_action(ins_node)) 72 | embs = merge(embs, interface:view(ins_node)) 73 | end 74 | local join = join_table(embs) 75 | local linear = nn.Linear(params.input_size, params.rnn_size) 76 | -- It's an LSTM, GRU, or FF. 77 | local h, next_s = _G[params.unit](linear(join), s) 78 | h = nn.Linear(params.rnn_size, params.rnn_size)(h) 79 | -- Computes Q-function. 80 | ins_node = interfaces:q_learning(ins_node, h) 81 | ins_node = interfaces:apply(ins_node) 82 | local pre_prob = nn.Linear(params.rnn_size, params.vocab_size)(nn.Tanh()(h)) 83 | local prob = nn.LogSoftMax()(pre_prob) 84 | local target = interfaces.data:target(ins_node) 85 | -- Computes loss. 86 | ins_node = nn.Loss()({ins_node, prob, target}) 87 | ins_node = nn.Move()(ins_node) 88 | return nn.gModule({ins_node_org, prev_s}, {ins_node, next_s}) 89 | end 90 | 91 | function update_weights() 92 | -- Gradient clipping. 93 | local norm_dw = paramdx:norm() 94 | if norm_dw ~= norm_dw or norm_dw >= 10000 then 95 | print("\nNORM TOO HIGH", norm_dw) 96 | os.exit(-1) 97 | end 98 | local shrink_factor = 1 99 | if norm_dw > params.max_grad_norm then 100 | shrink_factor = params.max_grad_norm / norm_dw 101 | end 102 | model.norm_dw = norm_dw 103 | paramdx:mul(shrink_factor) 104 | return 0, paramdx 105 | end 106 | 107 | function setup() 108 | torch.setdefaulttensortype('torch.FloatTensor') 109 | g_make_deterministic(params.seed) 110 | params.vocab_size = 24 111 | initial_params = {} 112 | for k, v in pairs(params) do 113 | initial_params[k] = v 114 | end 115 | -- Generates tasks. 116 | data_generator = DataGenerator() 117 | -- Interfaces (i.e. tape, grid). 118 | interfaces = Interfaces() 119 | -- Stores model parameters. 
120 | model = Model() 121 | -- Provides curriculum over the samples. 122 | curriculum = Curriculum() 123 | -- Keeps track of what have been solved. 124 | acc = Acc() 125 | -- Visualizes execution. 126 | visualizer = Visualizer() 127 | model:reboot() 128 | end 129 | 130 | setup() 131 | print("Network parameters:") 132 | print(params) 133 | print("Starting training.") 134 | while true do 135 | model:fp() 136 | model:bp() 137 | model.step = model.step + 1 138 | curriculum:progress() 139 | optim.sgd(update_weights, paramx, {learningRate=params.lr}) 140 | if model.step % 30 == 0 then 141 | collectgarbage() 142 | end 143 | if curriculum.complexity > params.max_seq_length then 144 | os.exit(0) 145 | end 146 | end 147 | -------------------------------------------------------------------------------- /play.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import os 3 | import sys 4 | 5 | assert(len(sys.argv) <= 2) 6 | prefix = "addition" 7 | if len(sys.argv) == 2: 8 | prefix = sys.argv[1] 9 | 10 | def getch(): 11 | import sys, tty, termios 12 | fd = sys.stdin.fileno() 13 | old_settings = termios.tcgetattr(fd) 14 | try: 15 | tty.setraw(sys.stdin.fileno()) 16 | ch = sys.stdin.read(1) 17 | finally: 18 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) 19 | return ch 20 | 21 | os.system("cp movie/* /tmp") 22 | 23 | 24 | content = [] 25 | files = os.listdir('movie') 26 | end = 1 27 | for f in files: 28 | if len(f) <= len(prefix) + 1: 29 | continue 30 | try: 31 | if f[:len(prefix)] == prefix: 32 | end = max(end, int(f[(len(prefix) + 1):])) 33 | except: 34 | pass 35 | 36 | for i in range(end): 37 | with open('movie/%s_%d' % (prefix, i + 1), 'r') as content_file: 38 | content.append(content_file.read()) 39 | 40 | current = 0 41 | 42 | while True: 43 | for i in range(200): 44 | print("") 45 | print("%d / %d" % (current + 1, len(content))) 46 | print(content[current]) 47 | print("\nPress ``s'' to show next frame, 
``a'' the previous frame, and ``q'' to exit.") 48 | ch = getch() 49 | if ch == 'a': 50 | current = (current - 1) % len(content) 51 | elif ch == 's': 52 | current = (current + 1) % len(content) 53 | elif ch == 'q': 54 | exit(0) 55 | 56 | -------------------------------------------------------------------------------- /utils/Acc.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | local Acc = torch.class('Acc') 12 | 13 | function Acc:__init() 14 | self.fields = {"correct", "normal"} 15 | self.min_normal = {10, 50} 16 | self.alpha = 0.9 17 | self.min_acc = 0.995 18 | self:clean() 19 | end 20 | 21 | function Acc:getkey(current_complexity, complexity) 22 | assert(current_complexity ~= nil and complexity ~= nil) 23 | local ret = "current_complexity=" .. tostring(current_complexity) 24 | ret = ret .. 
string.format(", complexity:%03d", complexity) 25 | self.complexity[ret] = complexity 26 | if self.available_complexity[current_complexity] == nil then 27 | self.available_complexity[current_complexity] = {} 28 | end 29 | self.available_complexity[current_complexity][complexity] = true 30 | self.current_complexity[ret] = current_complexity 31 | return ret 32 | end 33 | 34 | function Acc:is_normal() 35 | local count = 0 36 | if self.available_complexity == nil or 37 | self.available_complexity[curriculum.complexity] == nil then 38 | return false 39 | end 40 | for c, _ in pairs(self.available_complexity[curriculum.complexity]) do 41 | if (params.train == 1 and c > curriculum.complexity + 10) or (c >= curriculum.complexity and c <= curriculum.complexity + 10 and params.train == 2) then 42 | local key_c = self:getkey(curriculum.complexity, c) 43 | count = count + 1 44 | if self.strict_normal[key_c] < self.min_normal[params.train] then 45 | return false 46 | end 47 | end 48 | end 49 | if count == 0 then 50 | return false 51 | end 52 | return true 53 | end 54 | 55 | function Acc:is_good(min_acc) 56 | min_acc = min_acc or self.min_acc 57 | local count = 0 58 | if self.available_complexity == nil or 59 | self.available_complexity[curriculum.complexity] == nil then 60 | return false 61 | end 62 | for c, _ in pairs(self.available_complexity[curriculum.complexity]) do 63 | if (params.train == 1 and c > curriculum.complexity + 10) or (c >= curriculum.complexity and c < curriculum.complexity + 10 and params.train == 2) then 64 | local key_c = self:getkey(curriculum.complexity, c) 65 | count = count + 1 66 | if self.strict_acc[key_c] < min_acc then 67 | return false 68 | end 69 | end 70 | end 71 | if count == 0 then 72 | return false 73 | end 74 | return true 75 | end 76 | 77 | function Acc:record(sample) 78 | if not sample:eos() or 79 | sample.train == 2 then 80 | return 81 | end 82 | local key = self:getkey(curriculum.complexity, sample.complexity) 83 | for _, f in 
pairs(self.fields) do 84 | if self[f][key] == nil then 85 | self[f][key] = sample[f] 86 | end 87 | self[f][key] = self[f][key] + sample[f] 88 | end 89 | if self.strict_normal[key] == nil then 90 | self.strict_normal[key] = 0 91 | self.strict_correct[key] = 0 92 | end 93 | if sample.correct == sample.normal then 94 | self.strict_correct[key] = self.strict_correct[key] + 1 95 | elseif sample.normal > 20 and sample.correct == sample.normal - 1 then 96 | self.strict_correct[key] = self.strict_correct[key] + 1 97 | self.correct[key] = self.correct[key] + 1 98 | end 99 | 100 | self.strict_normal[key] = self.strict_normal[key] + 1 101 | self.strict_acc[key] = self.strict_correct[key] / self.strict_normal[key] 102 | self.acc[key] = self.correct[key] / self.normal[key] 103 | 104 | while self.normal[key] > 100 and self.alpha ~= 1 do 105 | self.normal[key] = self.normal[key] * self.alpha 106 | self.correct[key] = self.correct[key] * self.alpha 107 | end 108 | while self.strict_normal[key] > 100 and self.alpha ~= 1 do 109 | self.strict_normal[key] = self.strict_normal[key] * self.alpha 110 | self.strict_correct[key] = self.strict_correct[key] * self.alpha 111 | end 112 | end 113 | 114 | function Acc:clean() 115 | for _, f in pairs(self.fields) do 116 | self[f] = {} 117 | end 118 | self.complexity = {} 119 | self.available_complexity = {} 120 | self.current_complexity = {} 121 | self.acc = {} 122 | self.strict_acc = {} 123 | self.strict_correct = {} 124 | self.strict_normal = {} 125 | end 126 | 127 | function Acc:__tostring__() 128 | local function get_res(field) 129 | local current_complexities = {} 130 | for k, v in pairs(self.current_complexity) do 131 | if current_complexities[v] == nil then 132 | current_complexities[v] = {} 133 | end 134 | table.insert(current_complexities[v], k) 135 | table.sort(current_complexities[v]) 136 | end 137 | local results = "{" 138 | for current_complexity, list in pairs(current_complexities) do 139 | results = results .. 
tostring(current_complexity) .. " : [" 140 | for _, key in pairs(list) do 141 | results = results .. "(" .. g_d(self.complexity[key]) .. ", " .. g_f3(self[field][key]) .. "), " 142 | end 143 | results = results .. "], " 144 | end 145 | return results .. "#\n" 146 | end 147 | return "\nAccuracies for a given complexity.\n" .. 148 | "Accuracy on the whole sample basis = " .. get_res("strict_acc") .. 149 | "Accuracy on the character basis = " .. get_res("acc") 150 | end 151 | 152 | function Acc:get_current_acc(complexity) 153 | complexity = complexity or curriculum.complexity 154 | for key, _ in pairs(self.complexity) do 155 | if self.complexity[key] == complexity and 156 | self.current_complexity[key] == curriculum.complexity then 157 | return self.acc[key] 158 | end 159 | end 160 | return 0 161 | end 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /utils/Curriculum.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | local Curriculum = torch.class('Curriculum') 12 | 13 | function Curriculum:__init() 14 | self.min_complexity = 6 15 | self.complexity = self.min_complexity 16 | self.type_count = 0 -- Number of trained examples. 17 | self.min_min_acc = 0.95 18 | params.train = 2 19 | end 20 | 21 | function Curriculum:progress() 22 | -- Checks if achieved good results on test data. 
23 | if params.train == 1 then 24 | if acc:is_good() then 25 | if acc:is_normal() then 26 | Visualizer:visualize(true) 27 | print("DONE") 28 | print(acc) 29 | os.exit(0) 30 | end 31 | else 32 | self.complexity = self.complexity + 4 33 | self.complexity = math.min(self.complexity, params.test_len - 10) 34 | params.train = 2 35 | model.instances = {} 36 | return 37 | end 38 | end 39 | if acc:is_good(self.min_min_acc) and 40 | acc:is_normal() then 41 | -- When the training seems to work reasonably well, then we start penalizing Q(s, \bullet). 42 | params.q_decay_lr = 0.01 43 | params.decay_expr = 1 44 | end 45 | if not acc:is_normal() or not acc:is_good() then 46 | return 47 | end 48 | params.train = 1 49 | Visualizer:visualize(true) 50 | model:clean() 51 | end 52 | 53 | -- Chooses complexity for a sample. 54 | function Curriculum:pick_complexity() 55 | local complexity = self.complexity - math.random(5) + 1 56 | complexity = math.max(complexity, 2) 57 | local acc_current = acc:get_current_acc() 58 | -- Complexity for a sample when we test. 59 | if params.train == 1 then 60 | local pow = math.log(params.test_len - 6 - self.complexity) / math.log(2) 61 | pow = math.floor(pow) 62 | complexity = self.complexity + math.pow(2, pow - math.random(3) + 1) + 5 63 | end 64 | return complexity 65 | end 66 | 67 | -- Dynamically decides on the number of unrolling steps. 68 | function Curriculum:expandLength(sample) 69 | if params.train == 2 then 70 | params.seq_length = math.max(params.seq_length or 1, sample.complexity + 5) 71 | params.seq_length = math.min(params.seq_length, params.max_seq_length - 1) 72 | else 73 | params.seq_length = math.max(params.seq_length, 20) 74 | end 75 | end 76 | 77 | function Curriculum:generateNewSample() 78 | local complexity = self:pick_complexity() 79 | local sample = data_generator["task_" .. 
params.task](data_generator, complexity) 80 | sample.task_name = params.task 81 | sample.train = params.train 82 | -- 20% of samples from training are used for the validation. 83 | -- The real testing occurs on much longer samples. 84 | if math.random(5) == 1 then 85 | sample.train = 1 86 | end 87 | self:expandLength(sample) 88 | -- Counts number of used samples so far. 89 | self.type_count = self.type_count + 1 90 | return sample 91 | end 92 | 93 | -------------------------------------------------------------------------------- /utils/DataGenerator.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Defines tasks. 12 | local DataGenerator = torch.class('DataGenerator') 13 | 14 | function DataGenerator:__init() 15 | self.task_ids = {} 16 | local tasks = 0 17 | self.base = 10 18 | 19 | for k, _ in pairs(getmetatable(self)) do 20 | if k:sub(1, 4) == "task" then 21 | tasks = tasks + 1 22 | self.task_ids[k:sub(6, #k)] = tasks 23 | end 24 | end 25 | params.ntasks = len(self.task_ids) 26 | end 27 | 28 | function DataGenerator:encode(s) 29 | local ret = "" 30 | if type(s) == "number" then 31 | ret = encode_map[s] 32 | else 33 | ret = "" 34 | for i = 1, s:size(1) do 35 | ret = ret .. 
encode_map[s[i]] 36 | end 37 | end 38 | assert(ret ~= nil, tostring(s)) 39 | return ret 40 | end 41 | 42 | --------------------------------------------------- 43 | 44 | local function simple() 45 | local current = curriculum.complexity 46 | if current == 6 then 47 | return math.min(math.random(2), self.base - 1) 48 | elseif current == 10 then 49 | return math.min(math.random(3), self.base - 1) 50 | elseif current == 14 then 51 | return math.random(self.base) - 1 52 | end 53 | end 54 | 55 | -- Single digit multiplication. 56 | function DataGenerator:task_single_mul(complexity) 57 | local inp = {} 58 | local len = math.max(complexity, 1) 59 | local digit = math.min(math.random(8) + 1, self.base - 1) 60 | for j = 1, len do 61 | inp[len - j + 1] = math.random(self.base) - 1 62 | end 63 | local ret = self:mul(digit, inp) 64 | if #ret.targets < 1 then 65 | return self:task_single_mul(complexity) 66 | end 67 | return ret 68 | end 69 | 70 | function DataGenerator:mul(digit, inp) 71 | local function at(self, x, y) 72 | if x == 1 and y == 1 then 73 | return digit 74 | end 75 | if y == 2 and x >= -#inp + 2 and x <= 1 then 76 | return inp[x + #inp - 1] 77 | end 78 | if x == 3 then 79 | return "q" 80 | end 81 | return 0 82 | end 83 | local task = {{}} 84 | local config = {complexity=complexity, 85 | dim=2, 86 | at=at, 87 | evaluate=function() end} 88 | local ret = Game(task, config) 89 | 90 | local curry_val = 0 91 | for i = 1, #inp + 1 do 92 | local x = inp[#inp - i + 1] 93 | if x == nil or x == self.base then 94 | x = 0 95 | end 96 | local out = x * digit + curry_val 97 | curry_val = math.floor(out / self.base) 98 | local out = out % self.base 99 | table.insert(ret.targets, out) 100 | end 101 | ret.complexity = #inp + 2 102 | ret.normal = #ret.targets 103 | return ret 104 | end 105 | 106 | -- Generic addition (arbitrary number of numbers). 
-- Generic addition of an arbitrary number of rows of digits.
-- inp: array of rows, each row an array of digits (most significant first).
-- Returns a Game whose targets are the digits of the column-wise sum
-- (least significant first, with carries), base self.base.
function DataGenerator:add(inp)
  assert(inp ~= nil)
  -- Grid lookup: row k of the sum problem lives on grid line y == k,
  -- written right-to-left ending at x == 1; "q" terminates, "e" is empty.
  local function at(self, x, y)
    for k = 1, #inp do
      if y == k and x >= -(#inp[k]) + 2 and x <= 1 then
        return inp[k][x + (#inp[k]) - 1]
      end
    end
    if x == 3 then
      return "q"
    end
    return "e"
  end
  local task = {{}}
  -- NOTE(review): the caller always overwrites ret.complexity afterwards,
  -- so no complexity is passed in the config here.
  local config = {dim=2,
                  at=at,
                  evaluate=function() end}
  local ret = Game(task, config)
  local carry = 0
  local max_len = 0
  for i = 1, #inp do
    max_len = math.max(max_len, #inp[i])
  end
  -- Column-wise schoolbook addition, least significant column first.
  for i = 1, max_len + 1 do
    local out = 0
    for k = 1, #inp do
      local x = inp[k][#inp[k] - i + 1]
      if x == nil or x == self.base then
        x = 0
      end
      out = out + x
    end
    out = out + carry
    carry = math.floor(out / self.base)
    table.insert(ret.targets, out % self.base)
  end
  ret.len = (max_len + 1) * #inp
  ret.normal = #ret.targets
  return ret
end

-- Two row addition.
function DataGenerator:task_addition(complexity)
  local inp = {}
  local rows = 2
  local min_len = math.random(3)
  local max_len = 0
  for i = 1, rows do
    local len = math.max(math.floor(complexity / rows), min_len)
    inp[i] = {}
    for j = 1, len do
      inp[i][len - j + 1] = math.random(self.base) - 1
    end
    max_len = math.max(max_len, len)
  end
  local ret = self:add(inp)
  ret.complexity = rows * max_len
  if #ret.targets < 1 then
    -- Degenerate sample: regenerate.
    return self:task_addition(complexity)
  end
  return ret
end

-- Three row addition.
-- We had to provide heavy curriculum to be able to solve this task.
function DataGenerator:task_addition3(complexity)
  local inp = {}
  local rows = 3
  local len = {}
  local max_len = 0
  local min_len = math.random(3) + 1
  local acc_current = acc:get_current_acc(complexity)
  local curr_complexity = curriculum.complexity
  -- Fill the three rows with digits; early in the curriculum the digits
  -- are biased toward small values so the task is learnable.
  for i = 1, rows do
    len[i] = math.max(math.floor(complexity / rows), min_len)
    inp[i] = {}
    for j = 1, len[i] do
      if params.train == 2 and (curr_complexity <= 4 or (complexity <= 16 and math.random(2) == 1)) then
        if acc_current < 0.7 then
          inp[i][j] = math.random(self.base - 1)
          if math.random(2) == 1 then
            inp[i][j] = math.min(math.random(4), self.base - 1)
          end
        else
          inp[i][j] = math.random(self.base - 1)
        end
      else
        inp[i][j] = math.random(self.base) - 1
      end
    end
    max_len = math.max(max_len, len[i])
  end
  -- Draws an "easy" digit; pass harder=true to skip the easiest regime.
  -- (Fixed: the original tested the misspelled `hardner`, which is always
  -- nil, so the harder flag silently never had any effect.)
  local function simple(harder)
    if curr_complexity <= 4 and harder ~= true then
      if math.random(3) <= 2 then
        return 0
      else
        return 1
      end
    elseif curr_complexity <= 8 then
      return math.min(math.random(4), self.base) - 1
    else
      if math.random(3) == 1 then
        return 0
      elseif math.random(2) == 1 then
        return math.min(math.random(4), self.base) - 1
      else
        return math.random(self.base) - 1
      end
    end
  end
  -- Overwrites the digit `offset` positions from the end of `row` with an
  -- easy one, when the row is long enough.
  local function set_simple(row, offset, harder)
    if #inp[row] > offset then
      inp[row][#inp[row] - offset] = simple(harder)
    end
  end
  if curr_complexity <= 12 or (curr_complexity <= 16 and math.random(2) == 1) then
    if math.random(3) <= 2 then
      set_simple(1, 0)
      set_simple(2, 0)
    else
      set_simple(math.random(2), 0)
    end
    if max_len >= 2 then
      set_simple(3, 1)
      set_simple(2, 1, true)
    end
    if max_len >= 3 then
      set_simple(1, 2)
      set_simple(2, 2, true)
    end
    if max_len >= 4 then
      set_simple(3, 3)
      set_simple(2, 3, true)
    end
  end
  local ret = self:add(inp)
  ret.complexity = #inp * 3
  if #ret.targets < 1 then
    -- Degenerate sample: regenerate.  (Fixed: the original retried via
    -- task_addition, producing a two-row sample without the .inp field.)
    return self:task_addition3(complexity)
  end
  ret.inp = inp
  return ret
end

-- Copy task.
function DataGenerator:task_copy(complexity)
  local tabs = {}
  -- Only the terminator is placed; the tape content is generated lazily
  -- by Game:at when visited.
  tabs[complexity + 1] = {"q"}
  local config = {complexity=complexity,
                  dim=1}
  local ret = Game(tabs, config)
  ret.complexity = complexity
  return ret
end

-- Reverse task.
function DataGenerator:task_reverse(complexity)
  -- Hitting "p" turns output on and reverses the walking direction.
  local function fun_p(self)
    self.output = true
    self.dir[1] = -1
    self.dir[2] = 0
    self.dir_idx = 1
  end
  local function fun_init(self)
    self.output = false
  end
  local time = math.floor(complexity / 2) + 1
  local task = {}
  assert(time >= 2)
  for i = 1, time - 1 do
    task[i] = {math.random(self.base) - 1}
  end
  task[time] = {"p"}
  task[time + 1] = {"p"}
  task[0] = {"q"}
  local config = {complexity=complexity,
                  dim=1,
                  p_loc=time,
                  init=fun_init,
                  fun_p=fun_p}
  return Game(task, config)
end

-- Walk task. 
-- The agent follows direction markers ("l", "u", "d") over a 2-d grid.
function DataGenerator:task_walk(complexity)
  local function fun_l(self)
    self.output = true; self.dir[1] = -1; self.dir[2] = 0; self.dir_idx = 1
  end
  local function fun_u(self)
    self.output = true; self.dir[1] = 0; self.dir[2] = -1; self.dir_idx = 4
  end
  local function fun_d(self)
    self.output = true; self.dir[1] = 0; self.dir[2] = 1; self.dir_idx = 3
  end
  local function fun_init(self)
    self.output = false
  end
  -- Only l/u/d markers are ever placed (see attach below); the original
  -- table also listed `fun_r = fun_r`, but no local fun_r exists, so that
  -- entry was nil and the constructor silently dropped it. Game provides
  -- its own fun_r method.
  local funs =
    {fun_l=fun_l,
     fun_u=fun_u,
     fun_d=fun_d}
  -- Wrap Game.at so the episode is cut off after complexity + 3 steps.
  local at_good = Game.at
  local function at(self, x, y)
    if self.time >= complexity + 3 then
      return "q"
    else
      return at_good(self, x, y)
    end
  end
  local time = math.floor((complexity - 2) / 2) + 1
  -- Walks `time` cells in the current direction and drops a random
  -- direction marker there, updating task.dir to the marker's direction.
  local function attach(task)
    task.x = task.x + time * task.dir[1]
    task.y = task.y + time * task.dir[2]
    if task[task.x] == nil then
      task[task.x] = {}
    end
    local char = {"l", "u", "d"}
    local c = char[math.random(3)]
    task[task.x][task.y] = c
    local s = {dir = {}}
    funs["fun_" .. c](s)
    task.dir = table.copy(s.dir)
    return task
  end
  local task = {x=1, y=1, dir={1, 0}}
  task = attach(task)
  -- Strip the bookkeeping fields before handing the grid to Game.
  task.x = nil
  task.y = nil
  task.dir = nil

  local config = safe_merge(funs, {complexity=complexity,
                                   dim=2,
                                   init=fun_init,
                                   at=at})
  return Game(task, config)
end
-------------------------------------------------------------------------------- /utils/Game.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | 11 | -- Our problems are executed on a grid (board). Every problem 12 | -- is the instance of this class. 13 | 14 | local Game = torch.class('Game') 15 | 16 | function init() 17 | if Game.max_symbols ~= nil then 18 | return 19 | end 20 | Game.max_symbols = 10 21 | local chars = {"q", "e", "s", "p", "a", "c", "y", "k", "v"} 22 | Game.dirs = {} 23 | Game.dirs[1] = {{-1, 0}, {1, 0}} -- Directions to move for 1-dimensional tape. 24 | Game.dirs[2] = {{-1, 0}, {1, 0}, {0, 1}, {0, -1}} -- Directions to move for 2-dimensional grid. 25 | Game.dir_names = {"l", "r", "d", "u"} -- Directions: left, right, down, up. 26 | Game.encode_map = {} 27 | Game.decode_map = {} 28 | for i = 1, Game.max_symbols - 1 do 29 | Game.decode_map[i] = i 30 | end 31 | Game.decode_map[0] = Game.max_symbols 32 | chars = merge(chars, Game.dir_names) 33 | for _, c in pairs(chars) do 34 | Game.decode_map[c] = len(Game.decode_map) + 1 35 | end 36 | for k, v in pairs(Game.decode_map) do 37 | Game.encode_map[v] = k 38 | end 39 | end 40 | 41 | function Game:__init(tab, config) 42 | init() 43 | config = config or {} 44 | for k, v in pairs(config) do 45 | self[k] = v 46 | end 47 | self.A = tab 48 | self.targets = {} 49 | self.complexity = 0 50 | self:reset() 51 | self:evaluate() 52 | self:reset() 53 | self.dim = self.dim or 2 54 | end 55 | 56 | function Game:reset() 57 | self.correct = 0 58 | self.x = 1 59 | self.failed = false 60 | self.y = 1 61 | self.predicted = {} 62 | self.strings = {} 63 | self.target_idx = 1 64 | self.dir = {1, 0} 65 | self.output = true 66 | self.grid = "A" 67 | self.time = 1 68 | self.last_prediction_time = 1 69 | self.normal = #self.targets 70 | self.dir_idx = 2 71 | if self.init ~= nil then 72 | self.init(self) 73 | end 74 | end 75 | 76 | --empty. 
-- Control-character handlers: when the board walk reaches a cell holding
-- one of these letters, evaluate() dispatches the matching fun_<letter>.

-- e: no-op.
function Game:fun_e()
end

-- c: no-op.
function Game:fun_c()
end

-- s: suppress output (cells visited while output is off emit no targets).
function Game:fun_s()
  self.output = false
end

-- p: re-enable output.
function Game:fun_p()
  self.output = true
end

-- Change direction to the left.
function Game:fun_l()
  self.dir[1] = -1
  self.dir[2] = 0
  self.dir_idx = 1
end

-- Change direction to down.
function Game:fun_d()
  self.dir[1] = 0
  self.dir[2] = 1
  self.dir_idx = 3
end

-- Change direction to up.
function Game:fun_u()
  self.dir[1] = 0
  self.dir[2] = -1
  self.dir_idx = 4
end

-- Change direction to the right.
-- (An earlier, empty duplicate definition of fun_r was dead code --
-- this later definition always overrode it -- and has been removed.)
function Game:fun_r()
  self.dir[1] = 1
  self.dir[2] = 0
  self.dir_idx = 2
end

-- Used to establish targets: while output is on, every number under the
-- cursor is appended to self.targets. Also counts task complexity.
function Game:produce_target()
  if self.output then
    local v = self:at(self.x, self.y)
    if type(v) == "number" then
      table.insert(self.targets, v)
    end
  end
  self.complexity = self.complexity + 1
end

-- Used during model execution. Loss calls
-- this function with its own predictions.
-- This function tells if predictions are good or not.
function Game:predict(predicted)
  table.insert(self.predicted, predicted)
  local target = self.targets[self.target_idx]
  if target == 0 then
    -- 0 is encoded as params.base in the model's output space.
    target = params.base
  end
  local passed = true
  if predicted ~= target then
    passed = false
    self.failed = true -- one wrong prediction fails the whole episode
  else
    self.correct = self.correct + 1
  end
  self.last_prediction_time = self.time
  self.target_idx = self.target_idx + 1
  return passed
end

-- Walks the board from the current position, dispatching control
-- characters and recording emitted numbers into self.targets, until the
-- terminator "q" is reached.
function Game:evaluate()
  while true do
    local c = self:at(self.x, self.y)
    local f = self["fun_" .. c]
    if f ~= nil then
      f(self)
    end
    self:produce_target()
    self:move()
    if self:at(self.x, self.y) == "q" then
      return
    end
  end
end

-- Encoded symbol under the read head, or the EOS code once the episode ends.
function Game:current_input()
  if self:eos() then
    return params.vocab_size
  end
  local v = self:at(self.x, self.y)
  local ret = Game.decode_map[v]
  assert(ret ~= nil)
  return ret
end

-- Encoded symbol the model should emit next, or the EOS code.
function Game:current_target()
  if self:eos() then
    return params.vocab_size
  end
  local ret = Game.decode_map[self.targets[self.target_idx]]
  assert(ret ~= nil)
  return ret
end

-- The board is infinite, and the value in newly visited locations
-- is generated on the fly.
function Game:at(x, y)
  assert(type(x) == "number")
  assert(type(y) == "number")
  if self.A[x] == nil then
    self.A[x] = {}
  end
  if self.A[x][y] == nil then
    -- Lazily materialize an unvisited cell with a random digit.
    self.A[x][y] = math.random(self.max_symbols) - 1
  end
  local t = type(self.A[x][y])
  assert(t == "number" or t == "string")
  return self.A[x][y]
end

-- Used to visualize current state of task.
function Game:__tostring__()
  local s = "State: x=" .. g_d(self.x) .. ", y=" .. g_d(self.y)
  s = s .. "\ndir=[" .. g_d(self.dir[1]) .. ", " .. g_d(self.dir[2]) .. "]"
  s = s .. "\noutput=" .. tostring(self.output)
  s = s .. "\nScore=" .. tostring(self.correct) .. "/" .. tostring(#self.targets)
  s = s .. "\nLen=" .. tostring(self.complexity)
  s = s .. "\ngrid=" .. tostring(self.grid)
  s = s .. "\n\nA:\n "
  -- Render a window of the board around the cursor (at least [-5, 5]).
  local from_y = math.min(-5, self.y - 2)
  local to_y = math.max(5, self.y + 2)
  if self.dim == 1 then
    from_y = 1
    to_y = 1
  end
  for y = from_y, to_y do
    for x = math.min(-5, self.x - 2), math.max(5, self.x + 2) do
      local v = self:at(x, y)
      if v == params.base then
        v = 0
      end
      if x == self.x and y == self.y then
        s = s .. green .. v .. reset -- highlight the cursor position
      else
        if type(v) == "number" then
          s = s .. v
        else
          s = s .. blue .. v .. reset -- control characters in blue
        end
      end
    end
    s = s .. "\n "
  end
  -- Y: target sequence; P: the model's predictions so far.
  local y = "Y:"
  local p = "P:"
  for i = 1, #self.targets do
    local t = self.targets[i]
    if t == params.base then
      t = 0
    end
    if i == self.target_idx - 1 then
      y = y .. green .. t .. reset
    else
      y = y .. t
    end
  end
  for i = 1, math.min(self.target_idx - 1, #self.predicted) do
    local t = Game.encode_map[self.predicted[i] ]
    if t == params.base then
      t = 0
    end
    if i == self.target_idx - 1 then
      if t == self.targets[i] then
        p = p .. green .. t .. reset
      else
        p = p .. red .. t .. reset
      end
    else
      p = p .. t
    end
  end
  s = s .. "\n" .. y
  s = s .. "\n" .. p
  local xa = ""
  local ya = ""
  local ta = ""
  local wa = ""
  s = s .. "\nxa" .. "=" .. xa
  s = s .. "\nya" .. "=" .. ya
  s = s .. "\nta" .. "=" .. ta
  s = s .. "\nwa" .. "=" .. wa
  s = s .. "\nself.failed = " .. tostring(self.failed)
  s = s .. "\nself.target_idx = " .. tostring(self.target_idx)
  for i = 1, (5 - math.max(5, self.y + 2)) do
    s = s .. "\n"
  end
  return s
end

-- Determines if it's the end of the episode.
function Game:eos(execute)
  -- The episode ends when every target has been consumed, the terminator
  -- "q" is reached while executing, the time budget
  -- (last prediction + complexity + 2) runs out, or the model has
  -- already made a wrong prediction.
  if self.target_idx > #self.targets or
    (self:at(self.x, self.y) == "q" and execute == true) or
    self.time >= self.last_prediction_time + self.complexity + 2 or
    self.failed then
    return true
  else
    return false
  end
end

-- Moves according to the current move direction.
function Game:move(dir)
  -- Touch the current cell first: Game.at lazily materializes unvisited
  -- locations, and we keep that side effect before leaving the cell.
  self:at(self.x, self.y)
  dir = dir or self.dir
  self.x = self.x + dir[1]
  self.y = self.y + dir[2]
  self.time = self.time + 1
end

--------------------------------------------------------------------------------
/utils/Visualizer.lua:
--------------------------------------------------------------------------------
--[[
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
]]--


local Visualizer = torch.class('Visualizer')

function Visualizer:__init()
  -- Timestamp of the last params printout; visualize() throttles on it.
  -- (The original initialized `last_visualize`, but the class reads
  -- `last_params`; both start out nil either way.)
  self.last_params = nil
end

-- Builds one multi-line string summarizing global training progress:
-- accuracy, weight/gradient norms, throughput and curriculum state.
function Visualizer:get_global_desc()
  -- Deliberately global so the start time persists across calls.
  beginning_time = beginning_time or torch.tic()
  local total_cases = model.step * params.seq_length * params.batch_size
  local wps = math.floor(total_cases / torch.toc(beginning_time))
  -- NOTE(review): `since_beginning` leaks into the global namespace;
  -- presumably an accidental missing `local` -- confirm no other file
  -- reads it before tightening the scope.
  since_beginning = torch.toc(beginning_time) / 60
  local pr = {tostring(acc)}
  table.insert(pr, "norm of weights=" .. g_f5(paramx:norm()) ..
    ", norm of gradient=" .. g_f5(model.norm_dw) ..
    ", number of parameters=" .. g_d(paramx:size(1)) ..
    ', characters per second=' .. wps ..
    ', since beginning=' .. g_d(since_beginning) .. ' mins.' ..
    ', lr=' .. g_f3(params.lr))
  table.insert(pr, 'current complexity=' .. tostring(curriculum.complexity) ..
    ', time-step=' .. tostring(model.step) ..
    ', trained characters=' .. tostring(model.trained_chars) ..
    ', number of unrolling steps=' .. tostring(params.seq_length))
  local counts = "Number of training samples so far: " .. curriculum.type_count
  table.insert(pr, counts)
  table.insert(pr, "")
  local s = ""
  for i = 1, #pr do
    s = s .. pr[i] .. "\n"
  end
  return s
end

-- One-line description of a single sample (task name and complexity).
function Visualizer:get_sample_desc(sample)
  local desc = "task_name=" .. tostring(sample.task_name) ..
    ", complexity=" .. tostring(sample.complexity) .. "\n\n"
  return desc
end

-- Writes one frame per recorded time step of `sample` into movie/<task>_<i>.
function Visualizer:save_movie(sample)
  if sample == nil then
    return
  end
  -- Fixed accidental global: `name` is only used within this function.
  local name = sample.task_name
  os.execute("mkdir -p movie")
  os.execute("rm -rf movie/" .. name .. "*")
  local begin = sample.ins
  local idx = sample.idx
  local from = model(begin)
  local global = self:get_global_desc()
  local desc = self:get_sample_desc(sample)
  for i, s in pairs(sample.strings) do
    local f = io.open(string.format("movie/%s_%d", name, i), "w")
    f:write(global)
    f:write(desc)
    f:write(s)
    f:write("\n\n")
    local ins = model(from.time + i - 1)
    f:write(self:actions_desc(ins, idx))
    f:write(self:instance_desc(ins, idx))
    f:close()
  end
end

-- Periodic console progress output; occasionally dumps a movie of one
-- finished training sample.
function Visualizer:visualize(force)
  -- Print the hyperparameters at most once every 20 seconds.
  if force or self.last_params == nil or torch.toc(self.last_params) > 20 then
    -- Fixed: this previously assigned to a GLOBAL `last_params`, leaving
    -- self.last_params forever nil, so params printed on every call.
    self.last_params = torch.tic()
    print(params)
    print("")
  end
  print(self:get_global_desc())
  if model.step % 20 ~= 1 or
    params.train == 1 then
    return
  end
  -- Pick the first finished training sample whose span is valid.
  local random_sample = nil
  for _, sample in pairs(model.samples) do
    if sample:eos() and sample.train == 1 then
      local ins = model(sample.ins)
      local begin = ins.data.begin[sample.idx]
      local finish = ins.data.finish[sample.idx]
      assert(begin >= 1 and finish <= params.seq_length)
      if finish - begin >= 0 then
        random_sample = sample
        break
      end
    end
  end
  self:save_movie(random_sample)
end


-- Human-readable dump of the Q-values/logits for batch entry `idx` of
-- instance `ins`, highlighting the argmax action in blue.
function Visualizer:actions_desc(ins, idx)
  local s = ""
  s = s .. "\nlogits: \n"
  -- Fixed: the original read `ins.max_idx[q]` where `q` was an undefined
  -- global (nil), so the highlight never fired.
  local max_idx = ins.max_idx[idx]
  for pred = 1, ins.q:size(2) do
    local a = interfaces:decode_q(pred)
    for k, v in pairs(a) do
      s = s .. k .. tostring(v) .. ","
    end
    if max_idx == pred then
      s = s .. blue
    end
    s = s .. string.format("logits = %.2f, ", ins.logits[idx][pred])
    if max_idx == pred then
      s = s .. reset
    end
    s = s .. "\n"
  end
  s = s .. string.format("reward = %d", ins.correct[idx])
  return s
end

-- Describes the [begin, finish] span of batch entry `idx` and the
-- instance's time step.
function Visualizer:instance_desc(ins, idx)
  local from = ins.data.begin[idx]
  local to = ins.data.finish[idx]
  return string.format("\nfrom:%d, to:%d, ins:%d", from, to, ins.time)
end

-- Compresses a probability vector to one character per entry:
-- "0" for zero, a digit round(10*p) for mid-range p, "9" near one,
-- and letters 'a'.. for exponentially small probabilities.
function short_prob(x)
  local function number2short_prob(x)
    assert(x >= 0 and x <= 1)
    if x == 0 then
      return "0"
    elseif x > 0.05 and x < 0.95 then
      return tostring(math.floor(10 * x + 0.5))
    elseif x >= 0.95 then
      return "9"
    else
      return string.char(string.byte('a') + math.min(math.floor(-math.log(x)) - 1, 20))
    end
  end
  local ret = ""
  -- Sanity check: x must be (approximately) a probability distribution.
  assert(math.abs(x:sum() - 1) < 1e-3, tostring(x))
  for i = 1, x:size(1) do
    ret = ret .. number2short_prob(x[i])
  end
  return ret
end



--------------------------------------------------------------------------------
/utils/base.lua:
--------------------------------------------------------------------------------
--[[
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
]]--


-- Seeds both Lua's and torch's random number generators.
function g_make_deterministic(seed)
  math.randomseed(seed)
  torch.manualSeed(seed)
end

-- Joins a list of nngraph nodes along dimension 2; singletons pass through.
function join_table(table)
  if #table > 1 then
    return nn.JoinTable(2)(table)
  else
    return table[1]
  end
end

-- Deep copy of nested tables/tensors/scalars. When `val` is given, every
-- number and every tensor entry in the copy is overwritten with `val`.
function clone(from, val)
  if torch.isTensor(from) then
    local ret = from:clone()
    if val ~= nil then
      ret:zero():add(val)
    end
    return ret
  elseif type(from) == "number" then
    if val ~= nil then
      return val
    else
      return from
    end
  elseif type(from) == "string" or
    type(from) == "boolean" then
    return from
  end
  assert(type(from) == "table")
  local ret = {}
  for k, v in pairs(from) do
    ret[k] = clone(v, val)
  end
  return ret
end

-- Display float up to 5 digits after dot.
function g_f5(f)
  local ret = string.format("%.5f", f)
  if #tostring(f) < #ret then
    return tostring(f)
  else
    return ret
  end
end

-- Display float up to 3 digits after dot.
function g_f3(f)
  local ret = string.format("%.3f", f)
  if #tostring(f) < #ret then
    return tostring(f)
  else
    return ret
  end
end

-- Display int.
function g_d(f)
  return string.format("%d", math.floor(f))
end

-- True for plain Lua tables (excludes torch classes, which have a typename).
function istable(x)
  return type(x) == 'table' and not torch.typename(x)
end

-- Merges `b` into a copy of `a`: array parts are concatenated, hash keys
-- overwritten; preserves a's metatable.
function merge(a, b)
  local ret = setmetatable({}, getmetatable(a))
  for k, v in pairs(a) do
    if type(k) == "number" then
      ret[#ret + 1] = v
    else
      ret[k] = v
    end
  end
  if b ~= nil then
    for k, v in pairs(b) do
      if type(k) == "number" then
        ret[#ret + 1] = v
      else
        ret[k] = v
      end
    end
  end
  return ret
end

-- Shallow merge that asserts `a` and `b` have disjoint keys.
function safe_merge(a, b)
  local ret = {}
  for k, v in pairs(a) do
    ret[k] = v
  end
  if b ~= nil then
    for k, v in pairs(b) do
      assert(ret[k] == nil)
      ret[k] = v
    end
  end
  return ret
end

-- Shallow copy that preserves the metatable.
function table.copy(obj)
  local ret = setmetatable({}, getmetatable(obj))
  for k, v in pairs(obj) do
    ret[k] = v
  end
  return ret
end

-- Recursively copies every tensor in `from` into the matching tensor in `to`.
function tensors_copy(to, from)
  if torch.isTensor(from) then
    to:copy(from)
  else
    for k, v in pairs(from) do
      tensors_copy(to[k], v)
    end
  end
end

function one()
  return torch.ones(1)
end

-- Batch-sized zero tensor; 2-D (batch x size) when `size` is given.
function zeros(size)
  if size ~= nil then
    return torch.zeros(params.batch_size, size)
  else
    return torch.zeros(params.batch_size)
  end
end

-- Batch-sized ones tensor; 2-D (batch x size) when `size` is given.
function ones(size)
  if size ~= nil then
    return torch.ones(params.batch_size, size)
  else
    return torch.ones(params.batch_size)
  end
end

-- Splits an nngraph node into `nr` outputs; nr == 1 needs SelectTable
-- because node:split(1) is not allowed.
function split(node, nr)
  assert(nr >= 1)
  if nr == 1 then
    return {nn.SelectTable(1)(node)}
  else
    return {node:split(nr)}
  end
end

-- For strings: visible length, skipping ANSI escape sequences
-- (ESC '[' ... 'm'). For tables: number of entries (unlike #, this also
-- counts non-sequence keys).
function len(T)
  if type(T) == "string" then
    local inside = false
    local count = 0
    for i = 1, #T - 1 do
      if not inside and T:byte(i, i) == 27 and T:sub(i + 1, i + 1) == "[" then
        inside = true
      end
      if not inside then
        count = count + 1
      end
      if inside and T:sub(i, i) == "m" then
        inside = false
      end
    end
    if not inside then
      count = count + 1
    end
    return count
  else
    local count = 0
    for _ in pairs(T) do
      count = count + 1
    end
    return count
  end
end

-- Index of the maximum entry of a 1-D tensor (first maximum wins).
function argmax(x)
  local max = -1 / 0
  local ret = 0
  for j = 1, x:size(1) do
    if max < x[j] then
      max = x[j]
      ret = j
    end
  end
  return ret
end

--------------------------------------------------------------------------------
/utils/colors.lua:
--------------------------------------------------------------------------------
--[[
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
]]--


-- Allows to display colors.
-- Modified code from https://github.com/hoelzro/ansicolors .

local colormt = {}

function colormt:__tostring()
  return self.value
end

function colormt:__concat(other)
  return tostring(self) .. tostring(other)
end

-- Calling a color wraps the string: color .. s .. reset.
function colormt:__call(s)
  -- Fixed: upstream referenced `_M.reset` (a module()-era leftover);
  -- `_M` is nil here, so any call errored. The reset color is exported
  -- as a global by the loop below.
  return self .. s .. _G.reset
end

colormt.__metatable = {}

local function makecolor(value)
  return setmetatable({ value = string.char(27) .. '[' .. tostring(value) .. 'm' }, colormt)
end

local colors = {
  -- attributes
  reset = 0,
  clear = 0,
  bright = 1,
  dim = 2,
  underscore = 4,
  blink = 5,
  reverse = 7,
  hidden = 8,

  -- foreground
  black = 30,
  red = 31,
  green = 32,
  yellow = 33,
  blue = 34,
  magenta = 35,
  cyan = 36,
  white = 37,

  -- background
  onblack = 40,
  onred = 41,
  ongreen = 42,
  onyellow = 43,
  onblue = 44,
  onmagenta = 45,
  oncyan = 46,
  onwhite = 47,
}

-- Export every color as a global (e.g. `green`, `reset`, `onblue`).
for c, v in pairs(colors) do
  _G[c] = makecolor(v)
end

--------------------------------------------------------------------------------
/utils/models.lua:
--------------------------------------------------------------------------------
--[[
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
]]--


-- One GRU layer: next hidden state from input x and previous hidden prev_h.
function gru_core(x, prev_h)
  local function new_input_sum(xv, hv)
    local i2h = nn.Linear(params.rnn_size, params.rnn_size)(xv)
    local h2h = nn.Linear(params.rnn_size, params.rnn_size)(hv)
    return nn.CAddTable()({i2h, h2h})
  end
  local update_gate = nn.Sigmoid()(new_input_sum(x, prev_h))
  local reset_gate = nn.Sigmoid()(new_input_sum(x, prev_h))
  -- compute candidate hidden state
  local gated_hidden = nn.CMulTable()({reset_gate, prev_h})
  local p2 = nn.Linear(params.rnn_size, params.rnn_size)(gated_hidden)
  local p1 = nn.Linear(params.rnn_size, params.rnn_size)(x)
  local hidden_candidate = nn.Tanh()(nn.CAddTable()({p1, p2}))
  -- compute new interpolated hidden state, based on the update gate
  local zh = nn.CMulTable()({update_gate, hidden_candidate})
  local zhm1 = nn.CMulTable()({nn.AddConstant(1, false)(nn.MulConstant(-1, false)(update_gate)), prev_h})
  local next_h = nn.CAddTable()({zh, zhm1})
  return next_h
end

-- Stacks params.layers GRU layers; prev_s holds one hidden state per layer.
function gru(single_input, prev_s)
  local i = {[0] = single_input}
  local next_s = {}
  local split = split(prev_s, params.layers)
  for layer_idx = 1, params.layers do
    local prev_h = split[layer_idx]
    local next_h = gru_core(i[layer_idx - 1], prev_h)
    table.insert(next_s, next_h)
    i[layer_idx] = next_h
  end
  return i[params.layers], nn.Identity()(next_s)
end

-- Stateless baseline: passes input and state straight through.
function feedforward(h, prev_s)
  return h, prev_s
end

-- One LSTM layer: next cell and hidden states from x, prev_c, prev_h.
function lstm_core(x, prev_c, prev_h)
  local i2h = nn.Linear(params.rnn_size, 4 * params.rnn_size)(x)
  local h2h = nn.Linear(params.rnn_size, 4 * params.rnn_size)(prev_h)
  local gates = nn.CAddTable()({i2h, h2h})
  local reshaped_gates = nn.Reshape(4, params.rnn_size)(gates)
  local sliced_gates = nn.SplitTable(2)(reshaped_gates)
  local in_gate = nn.Sigmoid()(nn.SelectTable(1)(sliced_gates))
  local in_transform = nn.Tanh()(nn.SelectTable(2)(sliced_gates))
  local forget_gate = nn.Sigmoid()(nn.SelectTable(3)(sliced_gates))
  local out_gate = nn.Sigmoid()(nn.SelectTable(4)(sliced_gates))
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
  return next_c, next_h
end


-- Stacks params.layers LSTM layers; prev_s interleaves (c, h) per layer.
function lstm(single_input, prev_s)
  local i = {[0] = single_input}
  local next_s = {}
  local split = {prev_s:split(2 * params.layers)}
  for layer_idx = 1, params.layers do
    local prev_c = split[2 * layer_idx - 1]
    local prev_h = split[2 * layer_idx]
    local next_c, next_h = lstm_core(i[layer_idx - 1], prev_c, prev_h)
    table.insert(next_s, next_c)
    table.insert(next_s, next_h)
    i[layer_idx] = next_h
  end
  return i[params.layers], nn.Identity()(next_s)
end


--------------------------------------------------------------------------------
/utils/nngraph.lua:
--------------------------------------------------------------------------------
--[[
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
]]--


function ident(x)
  return x
end

function empty()
  return {torch.zeros(1)}, {torch.zeros(1)}
end

-- Generic "no real gradient" backward: fills self[name] with zeros shaped
-- like `input` (recursing into tables) and returns it.
function empty_bp(self, input, gradOutput, name)
  if name == nil then
    name = "gradInput"
  end
  if istable(input) then
    self[name] = {}
    for k, v in pairs(input) do
      self[name][k] = empty_bp(self, v, 0, k)
    end
  elseif type(input) == "number" or
    (type(input) == 'table' and input.__typename == "Instance") then
    if self[name] == nil then
      self[name] = torch.zeros(1)
    end
  elseif torch.isTensor(input) then
    if self[name] == nil then
      self[name] = input:clone():zero()
    end
  else
    print("input")
    print(input)
    print("gradOutput")
    print(gradOutput)
    assert(false)
  end
  assert(self[name] ~= nil)
  return self[name]
end

function empty_type(self, new_type)
  return self
end

-- Builds one-hot embedding nodes for every sized field in `selection`.
function emb(self, ins_node, fun, selection, which)
  local function fp(name, which)
    local function forward(self, ins)
      return fun(ins, name, which)
    end
    return forward
  end
  assert(type(selection) == "table")
  local ret = {}
  for name, desc in pairs(selection) do
    if desc.size > 0 then
      local tmp = new_node(self, ident, fp(name, which), empty_bp, {ins_node})
      table.insert(ret, nn.OneHotIndex(desc.size, name)(tmp))
    end
  end
  return ret
end

-- Wraps a copy of `self` with custom forward/backward functions into a
-- fresh nngraph node connected to `input_nodes`.
function new_node(self, init, fp, bp, input_nodes, name, which)
  local module = safe_merge(table.copy(self), {updateOutput=fp,
                                               updateGradInput=bp,
                                               accGradParameters=empty,
                                               parameters=empty,
                                               type=empty_type})
  init(module)
  setmetatable(module, getmetatable(self))
  module.trace = debug.traceback()
  if name ~= nil then
    module.node_name = name
  end
  if which ~= nil then
    module.which = which
  end
  local node = nngraph.Node({module=module})
  if not istable(input_nodes) then
    input_nodes = {input_nodes}
  end
  for _, input_node in pairs(input_nodes) do
    if torch.typename(input_node) ~= 'nngraph.Node' then
      error('what is this in the input? ' .. tostring(input_node))
    end
    node:add(input_node, true)
  end
  return node
end

-- Lua 5.1 / 5.2+ compatibility (was the redundant `unpack and unpack or ...`).
local unpack = unpack or table.unpack

-- Patch nn.Module.__call__ to record a creation traceback on every node,
-- which makes nngraph errors much easier to attribute.
local Module = torch.getmetatable('nn.Module')
local call__ = Module.__call__
function Module:__call__(...)
  local input = {...}
  local ret = call__(unpack(merge({self}, input)))
  if ret.data ~= nil and
    ret.data.module ~= nil then
    ret.data.module.trace = debug.traceback()
  end
  return ret
end

--------------------------------------------------------------------------------
/utils/q_learning.lua:
--------------------------------------------------------------------------------
--[[
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
]]--


-- Watkins-style TD error with a dynamic discount derived from the number
-- of remaining targets; returns Q(s, chosen) minus the bootstrapped target.
function q_watkins_dynamic(ins, batch)
  local last = 0
  local diff = 0
  local rewards = 0
  local normal = ins.data.samples[batch].normal
  local my_v = normal - ins.target_idx[batch] + 1
  -- Accumulate rewards until the next stochastically sampled action.
  for k = ins.time, ins.data.finish[batch] do
    if k == ins.time or model(k).sampled[batch] == 0 then
      rewards = rewards + model(k).correct[batch]
    else
      last = k
      break
    end
  end
  local sample = ins.data.samples[batch]
  local chosen = ins.chosen[batch]
  diff = ins.q[batch][chosen]
  assert(my_v > 0)
  if last ~= 0 then
    assert(sample == model(last).data.samples[batch])
    local max_idx = model(last).max_idx[batch]
    local last_v = normal - model(last).target_idx[batch] + 1
    assert(last_v <= my_v)
    diff = diff - rewards / my_v - last_v * model(last).q[batch][max_idx] / my_v
  else
    -- No later sampled step within the span: pure Monte-Carlo return.
    diff = ins.q[batch][chosen] - (rewards / my_v)
  end
  return diff
end

-- Watkins Q-learning TD error with constant discount params.q_discount;
-- delegates to the dynamic variant when the discount is -1.
function q_watkins(ins, batch)
  local discount = params.q_discount
  if discount == -1 then
    return q_watkins_dynamic(ins, batch)
  end
  local last = 0
  local diff = 0
  local rewards = 0
  local mul = 1
  for k = ins.time, ins.data.finish[batch] do
    if k == ins.time or model(k).sampled[batch] == 0 then
      rewards = rewards + model(k).correct[batch] * mul
    else
      last = k
      break
    end
    mul = mul * discount
  end
  local sample = ins.data.samples[batch]
  local chosen = ins.chosen[batch]
  diff = ins.q[batch][chosen] - rewards
  if last ~= 0 then
    assert(sample == model(last).data.samples[batch])
    local max_idx = model(last).max_idx[batch]
    diff = diff - mul * model(last).q[batch][max_idx]
  end
  return diff
end

-- One-step Q-learning TD error with dynamic (remaining-targets) discount.
function q_classic_dynamic(ins, batch)
  local normal = ins.data.samples[batch].normal
  local my_v = normal - ins.target_idx[batch] + 1
  local chosen = ins.chosen[batch]
  local diff = ins.q[batch][chosen] - ins.correct[batch] / my_v
  -- Bootstrap only when the next step still belongs to the same sample.
  if ins:child().data.samples[batch] == ins.data.samples[batch] then
    local last_v = normal - ins:child().target_idx[batch] + 1
    assert(last_v <= my_v)
    local max_idx = ins:child().max_idx[batch]
    local discount = last_v / my_v
    diff = diff - discount * ins:child().q[batch][max_idx]
  end
  return diff
end

-- One-step Q-learning TD error with constant discount params.q_discount;
-- delegates to the dynamic variant when the discount is -1.
function q_classic(ins, batch)
  local discount = params.q_discount
  if discount == -1 then
    return q_classic_dynamic(ins, batch)
  end
  local chosen = ins.chosen[batch]
  local diff = ins.q[batch][chosen] - ins.correct[batch]
  if ins:child().data.samples[batch] == ins.data.samples[batch] then
    local max_idx = ins:child().max_idx[batch]
    diff = diff - discount * ins:child().q[batch][max_idx]
  end
  return diff
end

--------------------------------------------------------------------------------