├── CONTRIBUTING.md ├── LICENSE ├── Linear.h ├── Nonlinearity.h ├── PATENTS ├── README.md ├── StackRNN.h ├── Vec.h ├── common.h ├── data └── .gitignore ├── makefile ├── script_tasks.sh ├── task.h ├── train_add.cpp ├── train_toy.cpp └── utils.h /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Stack RNN 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 1. Fork the repo and create your branch from `master`. 8 | 2. If you've added code that should be tested, add tests 9 | 3. If you've changed APIs, update the documentation. 10 | 4. Ensure the test suite passes. 11 | 5. Make sure your code lints. 12 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 13 | 14 | ## Contributor License Agreement ("CLA") 15 | In order to accept your pull request, we need you to submit a CLA. You only need 16 | to do this once to work on any of Facebook's open source projects. 17 | 18 | Complete your CLA here: 19 | 20 | ## Issues 21 | We use GitHub issues to track public bugs. Please ensure your description is 22 | clear and has sufficient instructions to be able to reproduce the issue. 23 | 24 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 25 | disclosure of security bugs. In those cases, please go through the process 26 | outlined on that page and do not file a public issue. 27 | 28 | ## License 29 | By contributing to Stack RNN, you agree that your contributions will be licensed 30 | under its BSD license. 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For Stack RNN software 4 | 5 | Copyright (c) 2015-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /Linear.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | * 9 | */ 10 | #ifndef _LINEAR_ 11 | #define _LINEAR_ 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "common.h" 18 | #include "Vec.h" 19 | #include "utils.h" 20 | 21 | namespace rnn { 22 | 23 | // Linear struct: 24 | struct Linear{ 25 | public: 26 | /*** Constructors ***/ 27 | 28 | Linear(){}; 29 | 30 | explicit Linear(const my_int& si, const my_int& so) : 31 | _data(so, si, 0), 32 | _gradient(so, si, 0) {}; 33 | 34 | Linear(const Linear& rhs) { 35 | this->_data = rhs._data; 36 | this->_gradient = rhs._gradient; 37 | this->_gradient.zeros(); 38 | } 39 | 40 | 41 | /*** methods ***/ 42 | 43 | my_int ncol() const { 44 | return this->_data.ncol(); 45 | } 46 | 47 | my_int nrow() const { 48 | return this->_data.nrow(); 49 | } 50 | 51 | void initialize(){ 52 | for(my_int i = 0; i < _data.size(); i++) 53 | _data[i] = random(-0.1, 0.1)+random(-0.1, 0.1)+random(-0.1, 0.1); 54 | } 55 | 56 | my_int size() { return this->_data.size();} 57 | my_int sizeIn() { return this->_data.ncol();} 58 | my_int sizeOut() { return this->_data.nrow();} 59 | 60 | void zeros(){ this->_data.zeros();}; 61 | 62 | /*** forward methods ***/ 63 | 64 | void forward(const my_int& idx, Vec& out){ 65 | assert(out.size() == _data.nrow()); 66 | for(my_int x = 0; x < out.size(); x++) 67 | out[x] += _data(x, idx); 68 | }; 69 | 70 | void forward_transpose(const my_int& idx, Vec& out){ 71 | assert(out.size() == _data.ncol()); 72 | for(my_int x = 0; x < out.size(); x++) 73 | out[x] += _data(idx, x); 74 | }; 75 | 76 | void forward_transpose(const my_int& idx, Vec& out, 77 | const my_int& obegin, const my_int& oend){ 78 | assert(out.size() == _data.ncol()); 79 | for(my_int x = obegin; x < oend; x++) 80 | out[x] += _data(idx, x); 81 | }; 82 | 83 | void forward(const Vec& in, Vec& out, 84 | const my_int& ibegin, const my_int& iend, 85 | const my_int& obegin, const my_int& oend){ 86 | assert(obegin >= 0); 87 | assert(oend <= _data.nrow()); 88 | assert(oend <= out.size()); 89 | assert(ibegin >= 0); 90 | assert(iend <= _data.ncol()); 91 | assert(iend <= in.size()); 92 | matrixXvector(out, in, this->_data, obegin, oend, ibegin, iend, 0); 93 | } 94 | 95 | void forward(const Vec& in, Vec& out){ 96 | matrixXvector(out, in, this->_data, 0, out.size(), 0, in.size(), 0); 97 | }; 98 | 99 | /*** backward methods ***/ 100 | 101 | void backward(Vec& in, const Vec& out){ 102 | matrixXvector(in, out, this->_data, 0, out.size(), 0, in.size(), 1); 103 | }; 104 | 105 | void backward(Vec& in, const Vec& out, 106 | const my_int& ibegin, const my_int& iend, 107 | const my_int& obegin, const my_int& oend){ 108 | assert(obegin >= 0); 109 | assert(oend <= _data.nrow()); 110 | assert(oend <= out.size()); 111 | assert(ibegin >= 0); 112 | assert(iend <= _data.ncol()); 113 | assert(iend <= in.size()); 114 | matrixXvector(in, out, this->_data, obegin, oend, ibegin, iend, 1); 115 | } 116 | 117 | /*** gradient methods ***/ 118 | 119 | void resetGradient(){ 120 | this->_gradient.zeros(); 121 | }; 122 | 123 | void computeGradient(const my_int& idx 
,const Vec& out){ 124 | for(my_int i = 0; i < _gradient.nrow(); i++) 125 | _gradient(i, idx) += out[i]; // gradient += out * in'; 126 | }; 127 | 128 | void computeGradient_transpose(const my_int& idx ,const Vec& out){ 129 | for(my_int i = 0; i < _gradient.ncol(); i++) 130 | _gradient(idx,i) += out[i]; // gradient += out * in'; 131 | }; 132 | 133 | void computeGradient_transpose(const my_int& idx ,const Vec& out, 134 | const my_int& obegin, const my_int& oend){ 135 | for(my_int i = obegin; i < oend; i++) 136 | _gradient(idx,i) += out[i]; // gradient += out * in'; 137 | }; 138 | 139 | void computeGradient(const Vec& in ,const Vec& out){ 140 | computeGradient(in, out, 0, in.size(), 0, out.size()); 141 | }; 142 | 143 | void computeGradient(const Vec& in, const Vec& out, 144 | const my_int& ibegin, const my_int& iend, 145 | const my_int& obegin, const my_int& oend){ 146 | assert(obegin >= 0); 147 | assert(oend <= _gradient.nrow()); 148 | assert(oend <= out.size()); 149 | assert(ibegin >= 0); 150 | assert(iend <= _gradient.ncol()); 151 | assert(iend <= in.size()); 152 | for(my_int o = obegin; o < oend; o++){ 153 | for(my_int i = ibegin; i < iend; i++){ 154 | _gradient(o, i) += out[o] * in[i]; 155 | } 156 | } 157 | } 158 | 159 | void update(const my_real& lr){ 160 | for(my_int i =0; i < this->size(); i++) 161 | this->_data[i] += lr * this->_gradient[i]; 162 | } 163 | 164 | // TODO make that private 165 | 166 | Vec2D _data; 167 | Vec2D _gradient; 168 | 169 | }; 170 | 171 | 172 | }// end namespace rnn 173 | 174 | #endif 175 | -------------------------------------------------------------------------------- /Nonlinearity.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | * 9 | */ 10 | #ifndef _NONLINEARITY_ 11 | #define _NONLINEARITY_ 12 | #include 13 | #include 14 | 15 | #include "common.h" 16 | #include "Vec.h" 17 | 18 | namespace rnn{ 19 | struct Softmax{ 20 | 21 | void static forward(Vec& v, my_int b = 0, my_int e = -1){ 22 | if(e == -1) e = v.size(); 23 | my_real max=v[b], denom = 0; 24 | for(my_int i = b; i < e; i++) 25 | if(v[i] > max) max = v[i]; 26 | 27 | for(my_int i = b; i < e; i++){ 28 | v[i] = exp(v[i]-max); 29 | denom += v[i]; 30 | } 31 | for(my_int i = b; i < e; i++) 32 | v[i] = v[i] / denom; 33 | } 34 | 35 | void static backward(Vec& err, const Vec& v, my_int b = 0, my_int e = -1){ 36 | if(e == -1) e = v.size(); 37 | 38 | Vec grad = err; 39 | for(my_int i = b; i < e; i++){ 40 | grad[i] = err[i] * v[i]; 41 | for(my_int j = b; j < e; j++) 42 | grad[i] -= err[j] * v[j] * v[i]; 43 | } 44 | err = grad; 45 | } 46 | 47 | 48 | }; 49 | 50 | struct Sigmoid{ 51 | 52 | void static forward(my_real& v){ 53 | if(v > 50 ) v = 50; 54 | if(v < -50 ) v = -50; 55 | v = 1 / (1 + exp(-v)); 56 | } 57 | 58 | void static forward(Vec& v, my_int b = -1, my_int e = -1){ 59 | if(b == - 1) b =0; 60 | if(e == -1 ) e = v.size(); 61 | for (my_int i = b; i < e; i++) 62 | { 63 | if(v[i] > 50 ) v[i] = 50; 64 | if(v[i] < -50 ) v[i] = -50; 65 | v[i] = 1 / (1 + exp(-v[i])); 66 | } 67 | } 68 | 69 | void static backward(my_real& err, const my_real& v){ 70 | err = err * v * (1 - v); 71 | } 72 | 73 | 74 | void static backward(Vec& err, const Vec& v, my_int b = 0, my_int e = -1){ 75 | if(e == -1) e = err.size(); 76 | for(my_int i = b; i < e; i++) 77 | err[i] = err[i] * (v[i] * (1 - v[i])); 78 | } 79 | 80 | }; 81 | 82 | 83 | } 84 | 85 | #endif 86 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the Stack-RNN software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. 
Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stack RNN 2 | Stack RNN is a project gathering the code from the paper 3 | *Inferring Algorithmic Patterns with Stack-Augmented Recurrent Nets* by Armand Joulin and Tomas Mikolov ([pdf](http://arxiv.org/abs/1503.01007)). 4 | In this research project, we focus on extending Recurrent Neural Networks (RNN) with a stack to allow them to learn sequences which require 5 | some form of persistent memory. 6 | 7 | Examples are given in the script `script_tasks.sh`. The code is still under construction. 8 | We are working on releasing the code for the list RNN. If you have any suggestion, please let us know (contacts below). 9 | 10 | 11 | ## Examples 12 | To run the code on a task: 13 | ``` 14 | > make toy 15 | > ./train_toy -ntask 1 -nchar 2 -nhid 10 -nstack 1 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 16 | ``` 17 | To run the code on binary addition: 18 | ``` 19 | > make add 20 | > ./train_add 21 | ``` 22 | 23 | ## Requirements 24 | Stack RNN works on: 25 | * Mac OS X 26 | * Linux 27 | 28 | It was not tested on Windows. To compile the code a relatively recent version of g++ is required. 29 | 30 | ## Building Stack RNN 31 | Run `make` to compile everything. 32 | 33 | 34 | ## Options 35 | For more help about the options: 36 | ``` 37 | > make toy 38 | > ./train_toy --help 39 | ``` 40 | Note that `train_add` can take the same options as `train_toy`. 41 | 42 | 43 | ## Join the Stack RNN community 44 | * Paper: http://arxiv.org/abs/1503.01007 45 | * Facebook page: https://www.facebook.com/fair 46 | * Contact: ajoulin@fb.com 47 | 48 | See the CONTRIBUTING file for how to help out. 49 | 50 | ## License 51 | Stack RNN is BSD-licensed. We also provide an additional patent grant 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /StackRNN.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | * 9 | */ 10 | #ifndef _STACK_RNN_ 11 | #define _STACK_RNN_ 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "common.h" 19 | #include "Vec.h" 20 | #include "Linear.h" 21 | #include "Nonlinearity.h" 22 | 23 | 24 | #define EMPTY_STACK_VALUE -1 25 | 26 | namespace rnn 27 | { 28 | 29 | enum {push, pop, noop}; 30 | 31 | struct StackRNN 32 | { 33 | public: 34 | 35 | StackRNN(const std::string& filename) 36 | { 37 | load(filename); 38 | emptyStacks(); 39 | } 40 | 41 | StackRNN(my_int si, 42 | my_int sh, 43 | my_int nstack, 44 | my_int stack_capacity, 45 | my_int so, 46 | my_int sm, 47 | my_int bptt_step, 48 | my_int mod = 1, 49 | bool isnoop = false, 50 | my_int depth = 1, 51 | my_real reg = 0) : 52 | _reg(reg), // regularization by entropy -- NOT USED 53 | _count(0), 54 | _HIDDEN(sh), // size of the hidden layer 55 | _NB_STACK(nstack), // number of stacks 56 | _STACK_SIZE(stack_capacity), // stacks capacity - this is currently fix TODO make it flexible 57 | _ACTION(2 + ((isnoop)?1:0)), // size of the action layer 58 | _TOP_OF_STACK(0), // index of the top of the stack 59 | _BPTT(sm), // length of the bptt 60 | _BPTT_STEP(bptt_step), // step of bptt (how often backprop is perform) 61 | _IN(si), // size of the input layer 62 | _OUT(so), // size of the output layer 63 | _it_mem(_BPTT - 1),// iterator for the circular buffer 64 | _mod(mod), // mod=0 -> no-rec, mod=1 -> rec with stack, mod=2 -> rec through stack+full 65 | _DEPTH(depth), // depth used to predict next hidden units from stacks 66 | _in2hidTranspose(_HIDDEN, _IN), 67 | _hid2act(_NB_STACK, Linear(_HIDDEN, _ACTION)), 68 | _hid2hid(_HIDDEN, _HIDDEN), 69 | _hid2stack(_NB_STACK, Linear(_HIDDEN, _STACK_SIZE)), 70 | _stack2hid(_NB_STACK, Linear(_STACK_SIZE, _HIDDEN)), 71 | _hid2out(_HIDDEN, _OUT), 72 | _in(_BPTT,0), 73 | _hid(_BPTT, Vec (_HIDDEN, 0)), 74 | _act(_NB_STACK, std::vector(_BPTT, Vec (_ACTION, 0))), 75 | _stack(_NB_STACK, std::vector(_BPTT, Vec (_STACK_SIZE, 0))), 76 | _out(_BPTT, Vec(_OUT, 0)), 77 | _targets(_BPTT, 0), 78 | _err_out(_OUT, 0), 79 | _err_hid (_HIDDEN, 0), 80 | _err_stack(_NB_STACK, Vec(_STACK_SIZE, 0)), 81 | _err_act(_NB_STACK, Vec(_ACTION,0)), 82 | _pred_err_stack(_NB_STACK, Vec(_STACK_SIZE,0)), 83 | _pred_err_hid(_HIDDEN,0), 84 | _isemptied(_BPTT, false) 85 | { 86 | this->initialize(); 87 | }; 88 | 89 | 90 | void initialize() 91 | { 92 | 93 | // initialize input to output linear layer: 94 | _in2hidTranspose.initialize(); 95 | 96 | if(_mod != 2) _hid2hid.zeros(); 97 | else _hid2hid.initialize(); 98 | 99 | //initialize transition between action, hidden and top of stack: 100 | for(my_int i = 0; i <_NB_STACK;i++) 101 | { 102 | _hid2act[i].initialize(); 103 | _hid2stack[i].initialize(); 104 | _stack2hid[i].initialize(); 105 | } 106 | 107 | for(my_int s = 0; s <_NB_STACK;s++) 108 | { 109 | for(my_int j = 0; j < _HIDDEN; j++) 110 | for(my_int i = _TOP_OF_STACK +_DEPTH; i < _TOP_OF_STACK + _STACK_SIZE; i++) 111 | _stack2hid[s]._data(j,i) = 0; 112 | for(my_int i = _TOP_OF_STACK +1; i < _TOP_OF_STACK + _STACK_SIZE; i++) 113 | for(my_int j = 0; j < _HIDDEN; j++) 114 | _hid2stack[s]._data(i,j) = 0; 115 | } 116 | 117 | // initialize hidden to output linear layer: 118 | _hid2out.initialize(); 119 | 120 | // initialize the stack with empty value: 121 | emptyStacks(); 122 | }; 123 | 124 | 125 | void emptyStacks() 126 | { 127 | if(_NB_STACK == 0) return; 128 | _count = 0; 129 | my_int m = _it_mem; 130 | _isemptied[m] = true; 131 | for(my_int s = 0; s <_NB_STACK;s++) 132 | for(my_int i = 
_TOP_OF_STACK; i < _TOP_OF_STACK + _STACK_SIZE; i++) 133 | _stack[s][m][i] = EMPTY_STACK_VALUE; 134 | } 135 | 136 | void forward(const my_int& cur, const my_int& target, bool ishard = false) 137 | { 138 | 139 | // increment iterator on memory 140 | my_int old_it = _it_mem; 141 | _it_mem = ( _it_mem + 1) % _in.size(); 142 | 143 | _isemptied[_it_mem] = false; 144 | 145 | // zeros the current hidden states: 146 | _out[_it_mem].zeros(); 147 | _hid[_it_mem].zeros(); 148 | for(my_int s = 0; s <_NB_STACK; s++) 149 | { 150 | _act[s][_it_mem].zeros(); 151 | _stack[s][_it_mem].zeros(); 152 | } 153 | 154 | //copy current word and target word in in memory 155 | _targets[ _it_mem ] = target; 156 | _in[ _it_mem ] = cur; 157 | 158 | // forward propagation from input to hidden: 159 | _in2hidTranspose.forward_transpose(cur, _hid[_it_mem]); 160 | 161 | // forward from hidden to hidden: 162 | // (hidden + top of stack) (t-1) -> hidden (t): 163 | // mod = 1 -> recurrent only through stack 164 | // mod = 2 -> full hidden 165 | // mod 0 -> no recurrent 166 | if( _mod != 0) 167 | { 168 | // previous top of stack -> current hidden 169 | for(my_int s = 0; s <_NB_STACK;s++) 170 | { 171 | _stack2hid[s].forward(_stack[s][old_it], _hid[_it_mem], 172 | _TOP_OF_STACK, _TOP_OF_STACK + _DEPTH, 0, _HIDDEN); 173 | } 174 | } 175 | 176 | if(_mod == 2) 177 | { 178 | // previous hidden (t-1) -> current hidden (t) 179 | _hid2hid.forward(_hid[old_it], _hid[_it_mem]); 180 | } 181 | 182 | // nonlinearity on the hidden: 183 | Sigmoid::forward(_hid[_it_mem]); 184 | 185 | for(my_int s = 0; s <_NB_STACK;s++) 186 | { 187 | // current hidden -> current action: 188 | _hid2act[s].forward(_hid[_it_mem], _act[s][_it_mem]); 189 | 190 | // non linearity 191 | // action 192 | Softmax::forward(_act[s][_it_mem]); 193 | if(ishard) 194 | { 195 | //if it s discretize, i.e. 
take the most probable action: 196 | my_int im =0; my_real pm = _act[s][_it_mem][0]; 197 | _act[s][_it_mem][0] = 0; 198 | for(my_int i = 1; i < _ACTION; i++) 199 | { 200 | if( pm < _act[s][_it_mem][i]) 201 | { 202 | im = i; 203 | pm = _act[s][_it_mem][i]; 204 | } 205 | _act[s][_it_mem][i] = 0; 206 | } 207 | _act[s][_it_mem][im] = 1; 208 | } 209 | 210 | my_real pop_weight = _act[s][_it_mem][pop]; 211 | my_real push_weight = _act[s][_it_mem][push]; 212 | 213 | // (action + hidden) -> (stack): 214 | 215 | // in case of push: 216 | // push from the top to the bottom: 217 | for(my_int i = _TOP_OF_STACK + 1; i < _STACK_SIZE; i++) 218 | _stack[s][_it_mem][i] += _stack[s][old_it][i-1] * push_weight; 219 | 220 | // the push on the top of the stack is weighted by push action: 221 | _stack[s][_it_mem][_TOP_OF_STACK] = 0; 222 | for(my_int i = 0; i < _HIDDEN; i++) 223 | _stack[s][_it_mem][_TOP_OF_STACK] += _hid2stack[s]._data(_TOP_OF_STACK, i) * _hid[_it_mem][i]; 224 | // add a non-linearity on the top of the stack: 225 | if(_stack[s][_it_mem][_TOP_OF_STACK] > 50) 226 | _stack[s][_it_mem][_TOP_OF_STACK] = 50; 227 | if(_stack[s][_it_mem][_TOP_OF_STACK] < -50) 228 | _stack[s][_it_mem][_TOP_OF_STACK] = -50; 229 | _stack[s][_it_mem][_TOP_OF_STACK] = 1 / ( 1 + exp( - _stack[s][_it_mem][_TOP_OF_STACK] ) ); 230 | 231 | _stack[s][_it_mem][_TOP_OF_STACK] *= push_weight; 232 | 233 | // in case of pop: 234 | for(my_int i = _TOP_OF_STACK; i < _STACK_SIZE - 1; i++) 235 | _stack[s][_it_mem][i] += _stack[s][old_it][i+1] * pop_weight; 236 | 237 | // last element of the stack get an empty value: 238 | _stack[s][_it_mem][_STACK_SIZE - 1] += EMPTY_STACK_VALUE * pop_weight; 239 | 240 | // in case of no-op: 241 | if(_ACTION == 3) 242 | { 243 | my_real noop_weight = _act[s][_it_mem][noop]; 244 | for(my_int i = _TOP_OF_STACK; i < _TOP_OF_STACK + _STACK_SIZE; i++) 245 | _stack[s][_it_mem][i] += _stack[s][old_it][i] * noop_weight; 246 | } 247 | } 248 | 249 | // propagation from hidden to out: 250 | _hid2out.forward(_hid[_it_mem], _out[_it_mem]); 251 | Softmax::forward(_out[_it_mem]); 252 | } 253 | 254 | 255 | void backward() 256 | { 257 | 258 | // put gradient to zeros: 259 | _in2hidTranspose.resetGradient(); 260 | _hid2hid.resetGradient(); 261 | _hid2out.resetGradient(); 262 | 263 | for(my_int s = 0; s <_NB_STACK;s++) 264 | { 265 | _hid2stack[s].resetGradient(); 266 | _hid2act[s].resetGradient(); 267 | _stack2hid[s].resetGradient(); 268 | } 269 | 270 | _err_hid.zeros(); 271 | 272 | for(my_int s = 0; s <_NB_STACK;s++) 273 | { 274 | _err_stack[s].zeros(); 275 | _err_act[s].zeros(); 276 | } 277 | 278 | my_int itm = _it_mem, count = 0; 279 | _count++; 280 | 281 | //back prog through time 282 | while(count < std::min(_BPTT,_count)) 283 | { 284 | 285 | if(_mod != 2) _err_hid.zeros(); 286 | 287 | //out -> hidden 288 | if( count < _BPTT_STEP) 289 | { 290 | // backprop through softmax: 291 | for(my_int i = 0; i < _OUT; i++) _err_out[i] = -_out[itm][i]; 292 | _err_out[_targets[itm]] +=1; 293 | 294 | // Compute gradient from hidden -> out 295 | _hid2out.computeGradient(_hid[itm], _err_out); 296 | 297 | //propagate error from out -> hidden 298 | _hid2out.backward(_err_hid, _err_out); 299 | 300 | // clip the error: 301 | hardclipping(_err_hid); 302 | } 303 | 304 | if(_isemptied[itm]) break; 305 | 306 | _pred_err_hid.zeros(); 307 | 308 | my_int old_it = itm - 1; 309 | if(old_it < 0) old_it = _in.size() - 1; 310 | 311 | for(my_int s = 0; s <_NB_STACK;s++) 312 | { 313 | 314 | _err_act[s].zeros(); 315 | _pred_err_stack[s].zeros(); 316 | 317 | 
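// Recap of the forward soft-stack update (see forward() above) whose gradients
// are accumulated by hand below; writing a = _act[s][itm] and prev = _stack[s][old_it]:
//   stack[TOP]  += a[push] * sigmoid( _hid2stack[s](TOP,:) . hid )   (newly pushed value)
//   stack[i]    += a[push] * prev[i-1]       for i > TOP             (push shifts entries down)
//   stack[i]    += a[pop]  * prev[i+1]       for i < STACK_SIZE-1    (pop shifts entries up)
//   stack[last] += a[pop]  * EMPTY_STACK_VALUE
//   stack[i]    += a[noop] * prev[i]         when the no-op action is enabled
// The error on stack(t) is therefore routed back to the hidden layer (through the
// sigmoid of the top element), to the action probabilities, and to stack(t-1)
// via _pred_err_stack.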
if(itm == _it_mem) 318 | { 319 | for(my_int a = 0; a < _ACTION; a++) 320 | { 321 | _err_act[s][a] = _reg * ( log(_act[s][itm][a] + 1e-16) + 1); 322 | } 323 | } 324 | 325 | // gradient of hidden -> top of stack (due to push): 326 | // this is ugly but required: the gradient of hid->stack apply to the value before the sigmoid, 327 | // I don t store that value, so I need to recompute it (it would be better to simply store it...) 328 | my_real tmp_top_stack_in = 0; 329 | for(my_int i = 0; i < _HIDDEN; i++) 330 | { 331 | tmp_top_stack_in += _hid2stack[s]._data(_TOP_OF_STACK, i) * _hid[itm][i]; 332 | } 333 | if(tmp_top_stack_in > 50) tmp_top_stack_in = 50; 334 | if(tmp_top_stack_in < -50) tmp_top_stack_in = -50; 335 | tmp_top_stack_in = 1 / (1 + exp( - tmp_top_stack_in)); 336 | 337 | my_real tmp_top_stack_err = _err_stack[s][_TOP_OF_STACK]; 338 | tmp_top_stack_err *= _act[s][itm][push]; 339 | tmp_top_stack_err *= tmp_top_stack_in * ( 1 - tmp_top_stack_in); 340 | 341 | if(tmp_top_stack_err > 15) tmp_top_stack_err = 15; 342 | if(tmp_top_stack_err < -15) tmp_top_stack_err = -15; 343 | 344 | // gradient if hid -> stack 345 | for(my_int i = 0; i < _HIDDEN; i++) 346 | { 347 | _hid2stack[s]._gradient(_TOP_OF_STACK, i) += _hid[itm][i] * tmp_top_stack_err; 348 | } 349 | // propagate error from stack(t) -> stack(t-1) 350 | for(my_int i = _TOP_OF_STACK; i < _TOP_OF_STACK + _STACK_SIZE - 1; i++) 351 | { 352 | _pred_err_stack[s][i+1] += _err_stack[s][i] * _act[s][itm][pop]; 353 | } 354 | // propagate error from stack(t) -> action[pop] 355 | for(my_int i = _TOP_OF_STACK; i < _TOP_OF_STACK + _STACK_SIZE - 1; i++) 356 | { 357 | _err_act[s][pop] += _err_stack[s][i] * _stack[s][old_it][i+1]; 358 | } 359 | _err_act[s][pop] += _err_stack[s][_TOP_OF_STACK + _STACK_SIZE - 1] * EMPTY_STACK_VALUE; 360 | 361 | // in case of push: 362 | // push from the top to the bottom: 363 | for(my_int i = _TOP_OF_STACK + 1; i < _TOP_OF_STACK + _STACK_SIZE; i++) 364 | { 365 | _pred_err_stack[s][i-1] += _err_stack[s][i] * _act[s][itm][push]; 366 | } 367 | for(my_int i = _TOP_OF_STACK + 1; i < _TOP_OF_STACK + _STACK_SIZE; i++) 368 | { 369 | _err_act[s][push] += _err_stack[s][i] * _stack[s][old_it][i-1]; 370 | } 371 | // propagate error from stack to action + hidden 372 | for(my_int i = 0; i < _HIDDEN; i++) 373 | { 374 | _err_hid[i] += _hid2stack[s]._data(_TOP_OF_STACK, i) * tmp_top_stack_err; 375 | } 376 | _err_act[s][push] += _err_stack[s][_TOP_OF_STACK] * tmp_top_stack_in; 377 | 378 | // in case of no-op action: 379 | if(_ACTION == 3) 380 | { 381 | for(my_int i = _TOP_OF_STACK; i < _TOP_OF_STACK + _STACK_SIZE; i++) 382 | { 383 | _pred_err_stack[s][i] += _err_stack[s][i] * _act[s][itm][noop]; 384 | } 385 | for(my_int i = _TOP_OF_STACK; i < _TOP_OF_STACK + _STACK_SIZE; i++) 386 | { 387 | _err_act[s][noop] += _err_stack[s][i] * _stack[s][old_it][i]; 388 | } 389 | } 390 | hardclipping(_err_act[s]); 391 | hardclipping(_pred_err_stack[s]); 392 | 393 | Softmax::backward(_err_act[s], _act[s][itm]); 394 | 395 | hardclipping(_err_act[s]); 396 | 397 | // gradient of hidden -> action: 398 | _hid2act[s].computeGradient( _hid[itm], _err_act[s]); 399 | 400 | // propagate error from action -> hidden: 401 | _hid2act[s].backward(_err_hid, _err_act[s]); 402 | } 403 | 404 | // at that point: err_hid = err_from_out + err_from_top_stack + err_from_action 405 | //propagate error on hidden layer through non-linearity: 406 | Sigmoid::backward(_err_hid, _hid[itm]); 407 | 408 | // clip the error: 409 | hardclipping(_err_hid); 410 | 411 | // compute 
contribution of the hidden to the gradient of in2hid: 412 | _in2hidTranspose.computeGradient_transpose(_in[itm], _err_hid); 413 | 414 | //propagate error in the past: 415 | 416 | itm = old_it; 417 | 418 | // stop before doing last propagaton from hidden to hidden 419 | if(count == _BPTT - 1) break; 420 | 421 | if(_mod != 0) 422 | { 423 | // compute gradient of (hidden + top of stack) -> hidden 424 | for(my_int s = 0; s <_NB_STACK;s++) 425 | { 426 | _stack2hid[s].computeGradient( _stack[s][itm], _err_hid, 427 | _TOP_OF_STACK, _TOP_OF_STACK + _DEPTH, 428 | 0, _HIDDEN); 429 | // Propagate error from hidden -> top of stack 430 | _stack2hid[s].backward(_pred_err_stack[s], _err_hid, 431 | _TOP_OF_STACK, _TOP_OF_STACK + _DEPTH, 432 | 0, _HIDDEN); 433 | hardclipping(_pred_err_stack[s]); 434 | } 435 | } 436 | if(_mod == 2) 437 | { 438 | // compute gradient of (hidden ) -> hidden 439 | _hid2hid.computeGradient( _hid[itm], _err_hid); 440 | // Propagate error from hidden -> (hidden + top of stack) 441 | _hid2hid.backward(_pred_err_hid, _err_hid); 442 | } 443 | 444 | for(my_int i = 0; i < _HIDDEN; i++) 445 | { 446 | _err_hid[i] = _pred_err_hid[i]; 447 | } 448 | hardclipping(_err_hid); 449 | 450 | for(my_int s = 0; s <_NB_STACK;s++) 451 | { 452 | for(my_int i = 0; i < _STACK_SIZE; i++) 453 | _err_stack[s][i] = _pred_err_stack[s][i]; 454 | hardclipping(_err_stack[s]); 455 | } 456 | count++; 457 | } 458 | } 459 | 460 | void update(const my_real& lr) 461 | { 462 | _hid2out.update(lr); 463 | if(_mod == 2) _hid2hid.update(lr); 464 | for(my_int s = 0; s <_NB_STACK;s++) 465 | { 466 | _hid2act[s].update(lr); 467 | _stack2hid[s].update(lr); 468 | _hid2stack[s].update(lr); 469 | } 470 | _in2hidTranspose.update(lr); 471 | } 472 | 473 | my_real eval(const my_int& target) const { 474 | return _out[_it_mem][target]; 475 | } 476 | 477 | my_int pred() const { 478 | my_int pred = 0; 479 | my_real pv = _out[_it_mem][0]; 480 | for(my_int i = 1; i <_OUT; i++) 481 | { 482 | if(pv < _out[_it_mem][i]) 483 | { 484 | pred = i; pv =_out[_it_mem][i]; 485 | } 486 | } 487 | return pred; 488 | } 489 | 490 | /*************************************************************************************/ 491 | 492 | 493 | void copy(StackRNN shrnn) 494 | { 495 | assert(_IN == shrnn._IN); 496 | assert(_HIDDEN == shrnn._HIDDEN); 497 | assert(_OUT == shrnn._OUT); 498 | assert(_BPTT == shrnn._BPTT); 499 | assert(_ACTION == shrnn._ACTION); 500 | assert(_STACK_SIZE == shrnn._STACK_SIZE); 501 | assert(_NB_STACK == shrnn._NB_STACK); 502 | assert(_DEPTH == shrnn._DEPTH); 503 | 504 | _it_mem = shrnn._it_mem; 505 | _mod = shrnn._mod; 506 | 507 | _in2hidTranspose._data = shrnn._in2hidTranspose._data; 508 | _hid2hid._data = shrnn._hid2hid._data; 509 | _hid2out._data = shrnn._hid2out._data; 510 | 511 | for(my_int s = 0; s < _NB_STACK; s++) 512 | { 513 | _hid2stack[s]._data = shrnn._hid2stack[s]._data; 514 | _hid2act[s]._data = shrnn._hid2act[s]._data; 515 | _stack2hid[s]._data = shrnn._stack2hid[s]._data; 516 | for(my_int m = 0; m < _BPTT; m++) 517 | { 518 | _act[s][m] = shrnn._act[s][m]; 519 | _stack[s][m] = shrnn._stack[s][m]; 520 | } 521 | } 522 | 523 | for(my_int m = 0; m < _BPTT; m++) 524 | { 525 | _out[m] = shrnn._out[m]; 526 | _in[m] = shrnn._in[m]; 527 | _hid[m] = shrnn._hid[m]; 528 | _targets[m] = shrnn._targets[m]; 529 | _isemptied[m] = shrnn._isemptied[m]; 530 | } 531 | 532 | } 533 | 534 | 535 | void save(std::string filename) 536 | { 537 | FILE* f; 538 | f= fopen(filename.c_str(),"w"); 539 | fprintf(f, "%d %d %d %d %d %d %d %d %d %d\n", _IN, 
_ACTION, _HIDDEN, _NB_STACK, _STACK_SIZE, _OUT, 540 | _BPTT, _BPTT_STEP, _mod, _DEPTH); 541 | for(my_int i = 0; i < _in2hidTranspose.size(); i++) fprintf(f, "%f,", _in2hidTranspose._data[i]); 542 | for(my_int i = 0; i < _hid2hid.size(); i++) fprintf(f, "%f,", _hid2hid._data[i]); 543 | for(my_int s = 0; s <_NB_STACK;s++) 544 | { 545 | for(my_int i = 0; i < _hid2act[s].size(); i++) fprintf(f, "%f,", _hid2act[s]._data[i]); 546 | for(my_int i = 0; i < _hid2stack[s].size(); i++) fprintf(f, "%f,", _hid2stack[s]._data[i]); 547 | for(my_int i = 0; i < _stack2hid[s].size(); i++) fprintf(f, "%f,", _stack2hid[s]._data[i]); 548 | } 549 | for(my_int i = 0; i < _hid2out.size(); i++) fprintf(f, "%f,", _hid2out._data[i]); 550 | fclose(f); 551 | } 552 | 553 | void load(const std::string& filename) 554 | { 555 | FILE* f; 556 | f= fopen(filename.c_str(),"r"); 557 | fscanf(f, "%d %d %d %d %d %d %d %d %d %d\n", &_IN, &_ACTION, &_HIDDEN, &_NB_STACK, &_STACK_SIZE, &_OUT, 558 | &_BPTT, &_BPTT_STEP, &_mod, &_DEPTH); 559 | _TOP_OF_STACK = 0; 560 | _in2hidTranspose = Linear(_HIDDEN, _IN); 561 | _hid2hid = Linear(_HIDDEN, _HIDDEN); 562 | _hid2stack = std::vector(_NB_STACK, Linear(_HIDDEN, _STACK_SIZE)); 563 | _stack2hid = std::vector(_NB_STACK, Linear(_STACK_SIZE, _HIDDEN)); 564 | _hid2act = std::vector(_NB_STACK, Linear(_HIDDEN, _ACTION)); 565 | _hid2out = Linear(_HIDDEN, _OUT); 566 | for(my_int i = 0; i < _in2hidTranspose.size(); i++) fscanf(f, "%lf,", &_in2hidTranspose._data[i]); 567 | for(my_int i = 0; i < _hid2hid.size(); i++) fscanf(f, "%lf,", &_hid2hid._data[i]); 568 | for(my_int s = 0; s <_NB_STACK;s++) 569 | { 570 | for(my_int i = 0; i < _hid2act[s].size(); i++) fscanf(f, "%lf,", &_hid2act[s]._data[i]); 571 | for(my_int i = 0; i < _hid2stack[s].size(); i++) fscanf(f, "%lf,", &_hid2stack[s]._data[i]); 572 | for(my_int i = 0; i < _stack2hid[s].size(); i++) fscanf(f, "%lf,", &_stack2hid[s]._data[i]); 573 | } 574 | for(my_int i = 0; i < _hid2out.size(); i++) fscanf(f, "%lf,", &_hid2out._data[i]); 575 | fclose(f); 576 | _isemptied = std::vector< bool >(_BPTT, false); 577 | _it_mem = _BPTT - 1; 578 | _in = std::vector(_BPTT, 0); 579 | _hid = std::vector(_BPTT, Vec (_HIDDEN,0)); 580 | _act = std::vector< std::vector >( 581 | _NB_STACK, std::vector(_BPTT, Vec (_ACTION, 0))); 582 | _stack = std::vector< std::vector >( 583 | _NB_STACK, std::vector(_BPTT, Vec (_STACK_SIZE, 0))); 584 | _out = std::vector(_BPTT, Vec(_OUT,0)); 585 | _targets = std::vector( _BPTT,0); 586 | _err_out = Vec(_OUT, 0); 587 | _err_hid = Vec(_HIDDEN, 0); 588 | _err_act = std::vector ( 589 | _NB_STACK, Vec(_ACTION,0)); 590 | _err_stack = std::vector ( 591 | _NB_STACK, Vec(_STACK_SIZE, 0)); 592 | _pred_err_hid = Vec(_HIDDEN,0); 593 | _pred_err_stack = std::vector ( 594 | _NB_STACK, Vec(_STACK_SIZE,0)); 595 | _reg = 0; 596 | _count = 0; 597 | } 598 | 599 | // TODO: make this private: 600 | 601 | my_real _reg; 602 | 603 | my_int _count; 604 | 605 | my_int _HIDDEN; 606 | my_int _NB_STACK; 607 | my_int _STACK_SIZE; 608 | my_int _ACTION; 609 | my_int _TOP_OF_STACK; 610 | my_int _BPTT; 611 | my_int _BPTT_STEP; 612 | my_int _IN; 613 | my_int _OUT; 614 | my_int _it_mem; 615 | my_int _mod; 616 | my_int _DEPTH; 617 | 618 | Linear _in2hidTranspose; 619 | std::vector< Linear > _hid2act; 620 | Linear _hid2hid; 621 | std::vector< Linear > _hid2stack; 622 | std::vector< Linear > _stack2hid; 623 | Linear _hid2out; 624 | 625 | std::vector _in; 626 | std::vector< Vec > _hid; 627 | std::vector< std::vector< Vec > > _act; 628 | std::vector< std::vector< Vec > > 
_stack; 629 | std::vector< Vec > _out; 630 | std::vector _targets; 631 | 632 | Vec _err_out; 633 | Vec _err_hid; 634 | std::vector< Vec > _err_stack; 635 | std::vector< Vec > _err_act; 636 | std::vector< Vec > _pred_err_stack; 637 | Vec _pred_err_hid; 638 | 639 | std::vector< bool > _isemptied; 640 | 641 | }; 642 | 643 | 644 | 645 | } // end namespace 646 | #endif 647 | 648 | 649 | 650 | 651 | 652 | -------------------------------------------------------------------------------- /Vec.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | * 9 | */ 10 | #ifndef _VEC_ 11 | #define _VEC_ 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include "common.h" 22 | 23 | 24 | /************************************************ 25 | 26 | Vec is a vector class use for matrix computation 27 | along with Vec2D. 28 | 29 | ********************************************/ 30 | 31 | namespace rnn { 32 | 33 | class Vec{ 34 | public: 35 | typedef my_int size_type; 36 | typedef my_real value_type; 37 | typedef my_real* iterator; 38 | typedef const my_real* const_iterator; 39 | typedef my_real& reference; 40 | typedef const my_real& const_reference; 41 | 42 | 43 | /*** Constructors ***/ 44 | 45 | Vec() {create();} 46 | 47 | explicit Vec(size_type s, const_reference v = my_real()) {create(s,v);}; 48 | 49 | Vec( const Vec& v){ create(v.begin(), v.end()); } 50 | 51 | 52 | /*** Destructors ***/ 53 | 54 | ~Vec() { this->uncreate();} 55 | 56 | /*** Iterators ***/ 57 | 58 | iterator begin() { return this->_begin;} 59 | const_iterator begin() const { return this->_begin;} 60 | iterator end() { return this->_end;} 61 | const_iterator end() const { return this->_end;} 62 | 63 | /*** methods ***/ 64 | 65 | void zeros(){ 66 | for(iterator it = this->begin(); it != this->end(); it++) 67 | *it = 0; 68 | } 69 | 70 | size_type size() const { return this->_end - this->_begin;} 71 | 72 | 73 | /*** operators ***/ 74 | 75 | reference operator[] (size_type i){ return this->_begin[i];} 76 | const_reference operator[] (size_type i) const { return this->_begin[i];} 77 | 78 | Vec& operator = (const Vec& rhs) { 79 | if( this != &rhs ){ 80 | this->uncreate(); 81 | this->create( rhs.begin(), rhs.end() ); 82 | } 83 | return *this; 84 | } 85 | 86 | 87 | protected: 88 | iterator _begin; 89 | iterator _end; 90 | /*** private ***/ 91 | 92 | void create(){ _begin = _end = NULL;} 93 | void create(size_type s, const_reference v = my_real()); 94 | void create(const_iterator begin, const_iterator end); 95 | 96 | void uncreate(); 97 | 98 | }; 99 | 100 | 101 | /********************************* Method definition ***************************************/ 102 | 103 | void Vec::create( Vec::size_type n, 104 | Vec::const_reference val){ 105 | this->_begin = (my_real*) calloc(n , sizeof(my_real)); 106 | this->_end = this->_begin + n; 107 | }; 108 | 109 | void Vec::create( Vec::const_iterator b, 110 | Vec::const_iterator e){ 111 | Vec::size_type n = e - b; 112 | this->_begin = (my_real*) calloc(n , sizeof(my_real)); 113 | this->_end = this->_begin + n; 114 | memcpy(this->_begin, b, sizeof(my_real) * n); 115 | } 116 | 117 | void 
Vec::uncreate(){ 118 | if( this->_begin != NULL ){ 119 | free(this->_begin); 120 | } 121 | this->_begin = this->_end = NULL; 122 | } 123 | 124 | class Vec2D : public Vec{ 125 | 126 | private: 127 | my_int _ncol; 128 | my_int _nrow; 129 | 130 | public: 131 | explicit Vec2D() {}; 132 | 133 | explicit Vec2D(my_int nr, 134 | my_int nc, 135 | Vec::const_reference v = my_real()) : 136 | Vec(nr*nc, v), _ncol(nc), _nrow(nr) {}; 137 | 138 | ~Vec2D() { this->uncreate();} 139 | 140 | 141 | my_int nrow() const { return this->_nrow;} 142 | my_int ncol() const { return this->_ncol;} 143 | 144 | Vec2D& operator= (const Vec2D& rhs) { 145 | Vec::operator=(rhs); 146 | this->_ncol = rhs._ncol; 147 | this->_nrow = rhs._nrow; 148 | return *this; 149 | } 150 | 151 | 152 | Vec::reference operator() (my_int i, 153 | my_int j){ 154 | return this->_begin[i * _ncol + j]; 155 | } 156 | Vec::const_reference operator() (my_int i, my_int j) const { 157 | return this->_begin[i * _ncol + j]; 158 | } 159 | 160 | }; 161 | 162 | 163 | } // end namespace rnn 164 | 165 | #endif 166 | -------------------------------------------------------------------------------- /common.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | * 9 | */ 10 | #ifndef _COMMON_ 11 | #define _COMMON_ 12 | 13 | namespace rnn { 14 | 15 | typedef int my_int; 16 | typedef double my_real; 17 | 18 | }; 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. An additional grant 6 | # of patent rights can be found in the PATENTS file in the same directory. 7 | 8 | 9 | CC = g++ 10 | CFLAGS = -std=c++0x -lm -O3 -march=native -Wall -funroll-loops -ffast-math 11 | 12 | all: toy add 13 | 14 | toy : train_toy.cpp 15 | $(CC) $(CFLAGS) $(OPT_DEF) train_toy.cpp -o train_toy 16 | 17 | add : train_add.cpp 18 | $(CC) $(CFLAGS) $(OPT_DEF) train_add.cpp -o train_add 19 | 20 | clean: 21 | rm train_toy train_add 22 | -------------------------------------------------------------------------------- /script_tasks.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2015-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 
8 | # 9 | 10 | 11 | # compile: 12 | make toy 13 | 14 | 15 | # experiments made with our model with 40 hidden units and 10 stacks: 16 | 17 | # a^nb^n 18 | ./train_toy -ntask 1 -nchar 2 -nhid 40 -nstack 10 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 19 | # a^nb^nc^n 20 | ./train_toy -ntask 1 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 21 | # a^nb^nc^nd^n 22 | ./train_toy -ntask 1 -nchar 4 -nhid 40 -nstack 10 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 23 | # a^nb^2n 24 | ./train_toy -ntask 2 -nchar 2 -nhid 40 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nrep 2 25 | # example where discretization helps on a^nb^mc^{n+m}: 26 | ./train_toy -ntask 3 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 27 | ./train_toy -ntask 3 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 -hard 28 | 29 | # memorization (with smaller epochs i.e. nreset = 100 instead of 1000) 30 | ./train_toy -ntask 4 -nchar 3 -nhid 100 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nreset 100 31 | # note that to reproduce the results in the paper, one needs to cycle over the seed and restart the ones which give 32 | # entropy above average. 33 | 34 | 35 | # experiments with noop: 36 | 37 | ./train_toy -ntask 1 -nchar 2 -nhid 40 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nseq 10000 -noop 38 | # data/test_ntask1_nchar2_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep1_hard0_seed1_nseq10000_nmax20 39 | ./train_toy -ntask 1 -nchar 2 -nhid 40 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nseq 10000 -noop -hard 40 | # data/test_ntask1_nchar2_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep1_hard1_seed1_nseq10000_nmax20 41 | ./train_toy -ntask 1 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 -nseq 10000 -noop 42 | ./train_toy -ntask 1 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 -nseq 10000 -noop -hard 43 | # data/test_ntask1_nchar3_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep1_hard0_seed1_nseq10000_nmax10 44 | ./train_toy -ntask 1 -nchar 4 -nhid 40 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nseq 10000 -noop 45 | # data/test_ntask1_nchar4_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep1_hard0_seed1_nseq10000_nmax20 46 | ./train_toy -ntask 1 -nchar 4 -nhid 40 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nseq 5000 -noop -hard 47 | # data/test_ntask1_nchar4_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep1_hard1_seed1_nseq5000_nmax20 48 | 49 | ./train_toy -ntask 2 -nchar 2 -nhid 40 -nstack 10 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 -nseq 10000 -noop -nrep 2 50 | # data/test_ntask2_nchar2_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep2_hard0_seed1_nseq10000_nmax15 51 | ./train_toy -ntask 2 -nchar 2 -nhid 40 -nstack 10 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nseq 1000 -noop -hard -nrep 2 52 | # data/test_ntask2_nchar2_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep2_hard1_seed1_nseq1000_nmax20 53 | 54 | ./train_toy -ntask 3 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 -nseq 1000 -noop 55 | #data/test_ntask3_nchar3_nhid40_nstack10_bptt50_mod1_depth2_noop1_nrep1_hard0_seed1_nseq1000_nmax10 56 | 57 | # experiments made with our model with a small number of hidden units and stacks: 58 | 59 | ./train_toy -ntask 1 -nchar 2 -nhid 10 -nstack 1 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 60 | ./train_toy -ntask 1 -nchar 3 -nhid 10 -nstack 2 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 61 | ./train_toy -ntask 1 -nchar 3 -nhid 20 -nstack 2 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 62 | 
./train_toy -ntask 1 -nchar 4 -nhid 10 -nstack 2 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 63 | ./train_toy -ntask 1 -nchar 4 -nhid 20 -nstack 2 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 -nseq 5000 64 | ./train_toy -ntask 2 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 -nrep 2 65 | ./train_toy -ntask 2 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 -nrep 2 -hard 66 | ./train_toy -ntask 2 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 15 -depth 2 -bptt 50 -mod 1 -nrep 3 67 | ./train_toy -ntask 3 -nchar 3 -nhid 10 -nstack 1 -lr .1 -nmax 20 -depth 2 -bptt 50 -mod 1 68 | ./train_toy -ntask 3 -nchar 3 -nhid 10 -nstack 1 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 -hard 69 | 70 | ./train_toy -ntask 4 -nchar 3 -nhid 40 -nstack 10 -lr .1 -nmax 10 -depth 2 -bptt 50 -mod 1 -nreset 100 71 | #example with depth 1: 72 | ./train_toy -ntask 1 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 20 -depth 1 -bptt 50 -mod 1 73 | ./train_toy -ntask 1 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 20 -depth 1 -bptt 50 -mod 1 -hard 74 | ./train_toy -ntask 2 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 15 -depth 1 -bptt 50 -mod 1 -nrep 2 75 | ./train_toy -ntask 2 -nchar 2 -nhid 20 -nstack 2 -lr .1 -nmax 15 -depth 1 -bptt 50 -mod 1 -nrep 2 -hard 76 | 77 | -------------------------------------------------------------------------------- /task.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | * 9 | */ 10 | #ifndef _TASK_ 11 | #define _TASK_ 12 | #include 13 | #include 14 | 15 | #include "common.h" 16 | 17 | namespace rnn { 18 | 19 | // a^nb^n, a^nb^nc^n, a^nb^nc^nd^n... 
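// e.g. with nchar = 2 the generated sequences look like "aabb", "aaabbb", ...,
// and with nchar = 3 like "aabbcc"; n is drawn at random in [nmin, nmax).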
20 | std::string task1(const my_int nmax, const my_int nmin, my_int nchar){ 21 | my_int n = (rand() % (nmax-nmin)) + nmin ; 22 | std::string p( (nchar-1)* n + n, 'a'); 23 | for(my_int c = 1; c < nchar; c++){ 24 | for(my_int i = c * n; i < (c+1) * n; i++) 25 | p[i] = 'a' + c; 26 | } 27 | return p; 28 | } 29 | 30 | // a^nb^kn n>=1 31 | std::string task2(const my_int nmax, const my_int nmin, my_int nchar, my_int nrep = 2){ 32 | my_int n = (rand() % (nmax-nmin)) + nmin ; 33 | my_int c2 = rand() % (nchar-1) + 1; 34 | std::string p( n + nrep * n, 'a'); 35 | for(my_int i = n; i < n + nrep * n; i++) 36 | p[i] = 'a' + c2; 37 | return p; 38 | } 39 | 40 | 41 | // addition: a^nb^mc^{n+m} 42 | std::string task3(const my_int nmax, const my_int nmin, my_int nchar){ 43 | my_int n = (rand() % (nmax-nmin)) + nmin ; 44 | my_int m = (rand() % (n-1)) + 1 ; 45 | n = n - m; 46 | std::string p( (nchar-2)* (n+m) + m + n, 'a'); 47 | for(my_int i = n ; i < n + m; i++) 48 | p[i] = 'b'; 49 | for(my_int i = n + m ; i < p.size(); i++) 50 | p[i] = 'c'; 51 | return p; 52 | } 53 | 54 | // memorization string (see paper) 55 | std::string task4(const my_int nmax, const my_int nmin, my_int nchar){ 56 | my_int n = (rand() % (nmax-nmin)) + nmin ; 57 | std::string p( 2 * n + 1, 'a'); 58 | for(my_int i = 0 ; i < n; i++) 59 | p[i] = 'a' + (rand() % (nchar-1) + 1); 60 | for(my_int i = 0 ; i < n; i++) 61 | p[p.size()-1 - i] = p[i]; 62 | return p; 63 | } 64 | 65 | // multiplication a^nb^nc^{nm} 66 | std::string task5(const my_int nmax, const my_int nmin, my_int nchar){ 67 | my_int n = (rand() % (nmax-nmin)) + nmin ; 68 | my_int k = (rand() % (n-1)) + 1 ; 69 | n = n - k; 70 | my_int c1 = (rand() % (nchar -2)) + 2; 71 | std::string p( k + n + k * n, 'a'); 72 | for(my_int i = n; i < n + k; i++) 73 | p[i] = 'b'; 74 | for(my_int i = n + k; i < p.size(); i++) 75 | p[i] = 'a' + c1; 76 | return p; 77 | } 78 | 79 | // a^nb^mc^nd^m 80 | std::string task6(const my_int nmax, const my_int nmin, my_int nchar){ 81 | if(nchar != 4) exit(-1); 82 | my_int n = (rand() % (nmax-nmin)) + nmin ; 83 | my_int m = (rand() % (n-1)) + 1 ; 84 | n = n - m; 85 | std::string p( 2 * n + 2 * m, 'a'); 86 | for(my_int i = n ; i < n + m; i++) 87 | p[i] = 'b'; 88 | for(my_int i = n + m ; i < 2 * n + m ; i++) 89 | p[i] = 'c'; 90 | for(my_int i = 2 * n + m ; i < p.size(); i++) 91 | p[i] = 'd'; 92 | return p; 93 | } 94 | 95 | std::string generate_addition(const my_int nmax, const my_int nmin, my_int base){ 96 | if(base < 2) exit(-1); 97 | 98 | my_int i0 = 1; 99 | my_int tln; 100 | if(nmax > nmin+1) 101 | tln =(rand() % (nmax-nmin)) + nmin; 102 | else 103 | tln = nmin; 104 | my_int ln = rand() % (tln+1); 105 | my_int lm = tln - ln; 106 | 107 | std::string n = std::string(ln, '1' + (rand() % (base-1))); 108 | if(ln == 0) {n = "0"; ln = 1;} 109 | for(my_int i = i0; i < ln; i++) 110 | n[i] = '0' + (rand() % base); 111 | std::string m = std::string(lm, '1' + (rand() % (base-1))); 112 | if(lm == 0) {m = "0"; lm = 1;} 113 | for(my_int i = i0; i < lm; i++) 114 | m[i] = '0' + (rand() % base); 115 | 116 | std::string p = n;p += "+";p += m;p += "="; 117 | my_int carry = 0; 118 | my_int in = n.size() -1, im = m.size() -1; 119 | while ( in >= 0 || im >= 0 || carry > 0){ 120 | my_int num = carry; 121 | if( in >= 0) num += n[in] - '0'; 122 | if( im >= 0) num += m[im] - '0'; 123 | p += '0' + (num % base); 124 | carry = num /base; 125 | in--; im--; 126 | } 127 | p+="."; 128 | return p; 129 | } 130 | 131 | std::string generate_next_sequence(const my_int nmax, const my_int nmin, my_int nchar, 
my_int nrep, my_int ntask){ 132 | 133 | if(ntask == 2) 134 | return task2(nmax, nmin, nchar, nrep); 135 | if(ntask == 3){ 136 | return task3(nmax, nmin, 3); 137 | } 138 | if(ntask == 6){ 139 | return task6(nmax, nmin, nchar); 140 | } 141 | if(ntask == 5){ 142 | return task5(nmax, nmin, nchar); 143 | } 144 | if(ntask == 4){ 145 | return task4(nmax, nmin, nchar); 146 | } 147 | return task1(nmax, nmin, nchar); 148 | } 149 | 150 | 151 | } 152 | 153 | #endif 154 | -------------------------------------------------------------------------------- /train_add.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | * 9 | */ 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "common.h" 19 | #include "task.h" 20 | #include "StackRNN.h" 21 | 22 | using namespace std; 23 | using namespace rnn; 24 | 25 | int main(int argc, char **argv){ 26 | 27 | int nhid = 100; 28 | int nstack = 10; 29 | int stack_size = 200; 30 | int bptt = 50; 31 | float lr = 0.1; 32 | int mod = 1; 33 | int nmaxmax = 20; 34 | int nmin = 2; 35 | bool isnoop = true; 36 | bool ishard = false; 37 | int nreset = 10; 38 | int base = 2; 39 | int depth = 2; 40 | int nseq = 10000; 41 | int seed = 22; 42 | bool save = false; 43 | int nvalidmax = 20; 44 | float lrmin = 1e-5; 45 | 46 | int ai = 1; 47 | while(ai < argc){ 48 | if( strcmp( argv[ai], "-nhid") == 0){ 49 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 50 | nhid = atoi(argv[ai+1]); 51 | } 52 | else if( strcmp( argv[ai], "-nseq") == 0){ 53 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 54 | nseq = atoi(argv[ai+1]); 55 | } 56 | else if( strcmp( argv[ai], "-nstack") == 0){ 57 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 58 | nstack = atoi(argv[ai+1]); 59 | } 60 | else if( strcmp( argv[ai], "-stack_size") == 0){ 61 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 62 | stack_size = atoi(argv[ai+1]); 63 | } 64 | else if( strcmp( argv[ai], "-bptt") == 0){ 65 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 66 | bptt = atoi(argv[ai+1]); 67 | } 68 | else if( strcmp( argv[ai], "-mod") == 0){ 69 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 70 | mod = atoi(argv[ai+1]); 71 | } 72 | else if( strcmp( argv[ai], "-lr") == 0){ 73 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 74 | lr = atof(argv[ai+1]); 75 | } 76 | else if( strcmp( argv[ai], "-nreset") == 0){ 77 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 78 | nreset = atoi(argv[ai+1]); 79 | if(nreset < 0) {printf("error nchar should be >= 0\n");return -1;} 80 | } 81 | else if( strcmp( argv[ai], "-base") == 0){ 82 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 83 | base = atoi(argv[ai+1]); 84 | } 85 | else if( strcmp( argv[ai], "-nmin") == 0){ 86 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 
87 | nmin = atoi(argv[ai+1]); 88 | } 89 | else if( strcmp( argv[ai], "-seed") == 0){ 90 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 91 | seed = atoi(argv[ai+1]); 92 | } 93 | else if( strcmp( argv[ai], "-nmax") == 0){ 94 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 95 | nmaxmax = atoi(argv[ai+1]); 96 | } 97 | else if( strcmp( argv[ai], "-nvalidmax") == 0){ 98 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 99 | nvalidmax = atoi(argv[ai+1]); 100 | } 101 | else if( strcmp( argv[ai], "-noop") == 0){ 102 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 103 | isnoop = true; 104 | ai--; 105 | } 106 | else if( strcmp( argv[ai], "-save") == 0){ 107 | save = true; 108 | ai--; 109 | } 110 | else if( strcmp( argv[ai], "-hard") == 0){ 111 | ishard = true; 112 | ai--; 113 | } 114 | else if( strcmp( argv[ai], "-lrmin") == 0){ 115 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 116 | lrmin = atoi(argv[ai+1]); 117 | } 118 | else if( strcmp( argv[ai], "-depth") == 0){ 119 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 120 | depth = atoi(argv[ai+1]); 121 | if(depth < 1) {printf("error in depth...\n"); return -1;} 122 | } 123 | else{ 124 | printf("unknown option: %s\n",argv[ai]); 125 | return -1; 126 | } 127 | ai += 2; 128 | } 129 | srand(seed); 130 | 131 | cout<<"seed: "< dic; 165 | vector rdic(nchar,0); 166 | dic['+'] = 0; rdic[0] = '+'; 167 | dic['='] = 1; rdic[1] = '='; 168 | dic['.'] = 2; rdic[2] = '.'; 169 | for(int i = 0; i < nchar -3; i++) 170 | { dic['0'+i] = 3 + i; rdic[3 + i] = '0' + i;} 171 | 172 | cout<<"create rnn..."; 173 | StackRNN rnn(nchar, nhid, nstack, stack_size, 174 | nchar, bptt, 1, mod, isnoop, depth, 0); 175 | StackRNN back_up_model(nchar, nhid, nstack, stack_size, 176 | nchar, bptt, 1, mod, isnoop, depth, 0); 177 | cout<<"done"<= nmax) nmax = nmin + 1; 182 | int nseqv = 1000; 183 | 184 | string p = generate_addition(nmax, nmin, base); 185 | 186 | int count = 0, neval = 0; 187 | int ne = 0; double lo = 0; 188 | int nepoch = 100; 189 | bool iseval = true; 190 | float last_ent = 0; 191 | double loss; 192 | string spred, sgoal; 193 | 194 | FILE* f; 195 | for(int e = 0; e < nepoch; e++){ 196 | nmax = max(min(e+3,nmaxmax),3); 197 | nmin = 0; 198 | neval = 1; loss = 0; 199 | ne = 1; lo = 0; 200 | 201 | rnn.emptyStacks(); 202 | 203 | /************* TRAIN *************/ 204 | 205 | for(int iseq = 0; iseq < nseq; iseq++) { 206 | 207 | p = generate_addition(nmax, nmin, base); 208 | 209 | if(nreset > 0 && iseq % nreset == 0 ) rnn.emptyStacks(); 210 | //spred += '_'; sgoal += '_'; 211 | 212 | iseval = false; 213 | for(int ip = 0; ip < p.size(); ip++){ 214 | next = dic[p[ip]]; 215 | if(rdic[cur] == '=') iseval = true; 216 | 217 | rnn.forward(cur, next); 218 | 219 | spred += (iseval)? 
rdic[rnn.pred()] : '_'; sgoal += rdic[next]; 220 | if (spred.size() > 30) spred.erase(spred.begin(), spred.end() - 30); 221 | if (sgoal.size() > 30) sgoal.erase(sgoal.begin(), sgoal.end() - 30); 222 | if(ip == 0 && iseq == 0) rnn.emptyStacks(); 223 | if(iseval) { 224 | rnn.backward(); 225 | rnn.update(lr); 226 | lo -= log(rnn.eval(next)) / log(10); ne++; 227 | fprintf(stdout, "\r[train] lr: %.5f\tnmax: %02d\tentropy: %.3f\tgoal: %s pred: %s prog=%.1f%%", 228 | lr, nmax, lo / ne, sgoal.c_str(), spred.c_str(), 100.0 * iseq / nseq); 229 | } 230 | cur = next; 231 | } 232 | 233 | } 234 | fprintf(stdout, "\r[train] lr: %.5f\tnmax: %02d\tentropy: %.3f\tgoal: %s pred: %s\n", 235 | lr, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 236 | 237 | /************* VALID *************/ 238 | 239 | nmax = max(nmaxmax, nvalidmax); 240 | nmin = min(nmaxmax, nvalidmax); 241 | ne = 1; lo = 0; 242 | 243 | rnn.emptyStacks(); 244 | 245 | for(int iseq = 0; iseq < nseqv; iseq++){ 246 | //spred += '_'; sgoal += '_'; 247 | p = generate_addition(nmax, nmin, base); 248 | iseval = false; 249 | for(int ip = 0; ip < p.size(); ip++){ 250 | next = dic[p[ip]]; 251 | if(rdic[cur] == '=') iseval = true; 252 | rnn.forward(cur, next, ishard); 253 | spred += (iseval)? rdic[rnn.pred()] : '_'; sgoal += rdic[next]; 254 | if (spred.size() > 30) spred.erase(spred.begin(), spred.end() - 30); 255 | if (sgoal.size() > 30) sgoal.erase(sgoal.begin(), sgoal.end() - 30); 256 | if(ip == 0 && iseq == 0) rnn.emptyStacks(); 257 | if(iseval){ 258 | lo -= log(rnn.eval(next)) / log(10); 259 | ne++; 260 | fprintf(stdout, "\r[valid] lr: %.5f\tnmax: %d\tentropy: %.3f\tgoal: %s pred: %s prog=%.1f%%", 261 | lr, nmax, lo / ne, sgoal.c_str(), spred.c_str(), 100.0 * iseq / nseqv); 262 | } 263 | cur = next; 264 | } 265 | } 266 | 267 | fprintf(stdout, "\r[valid] lr: %.5f\tnmax: %02d\tentropy: %.3f \tgoal: %s pred: %s\n", 268 | lr, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 269 | 270 | if( e == 0 || lo / ne < last_ent){ 271 | last_ent = lo / ne; 272 | back_up_model.copy(rnn); 273 | back_up_model.save(modelname); 274 | } 275 | else if( e > 0 ){ 276 | if(e > nmaxmax/2){ 277 | lr /= 2; 278 | rnn.copy(back_up_model); 279 | } 280 | if(last_ent < .1) //supervised | < .1 means it works 281 | rnn.copy(back_up_model); 282 | } 283 | if(lr < lrmin) break; 284 | } 285 | 286 | FILE* fseq; 287 | FILE* fres; 288 | fprintf(stdout,"Test set: \n"); 289 | if(save){ 290 | sprintf(buff,"data/test_seqence"); 291 | cout << " Sequence used at test time saved at: "<< buff << endl; 292 | fseq = fopen(buff,"w"); 293 | fres = fopen(testfilename.c_str(),"w"); 294 | fprintf(fres,"validation:\t %f\n", lo / ne); 295 | } 296 | int ntest = 200; 297 | bool begin_seq = true; 298 | cur = nchar - 1; 299 | 300 | rnn.emptyStacks(); 301 | 302 | for(int nm = 2; nm < 60; nm++){ 303 | nmin = nm; nmax = nm + 1; 304 | float corr = 0, ecorr = 0; 305 | int sseq = 0; nseq = 0; 306 | neval = 0; 307 | ne = 0;lo = 0; 308 | if(save) f = fopen(logtestfilename.c_str(),"w"); 309 | 310 | for(int iseq = 0; iseq < ntest; iseq++){ 311 | p = generate_addition(nmax, nmin, base); 312 | iseval = false; 313 | if(nreset > 0 && iseq % nreset == 0 ) rnn.emptyStacks(); 314 | 315 | for(int ip = 0; ip < p.size(); ip++){ 316 | next = dic[p[ip]] ; 317 | if(save) fprintf(fseq, "%c", p[ip]); 318 | 319 | rnn.forward(cur, next, ishard); 320 | 321 | // begin of a sequence / end of evaluation: 322 | if (ip == 0) { 323 | if(iseq != 0){ 324 | neval++; 325 | if( corr == sseq ) ecorr++; 326 | if(save)fprintf(f, "end eval - accuracy: %f 
\n", ecorr / neval); 327 | } 328 | sseq=0; corr = 0; 329 | iseval = false; 330 | } 331 | 332 | if(iseval && next == rnn.pred()) corr++; 333 | if(iseval) sseq++; 334 | 335 | lo -= log(rnn.eval(next)) / log(10); 336 | ne++; 337 | 338 | // begin of evaluation: 339 | if(rdic[next] == '=') { 340 | iseval = true; 341 | if(save) fprintf(f, "begin eval\n"); 342 | } 343 | cur = next; 344 | count++; 345 | } 346 | } 347 | if(save){ 348 | fprintf(fres,"%d \t %f\n", nm, ecorr / neval); 349 | fclose(f); 350 | } 351 | fprintf(stdout,"n: %d \t accuracy: %f \n", nm, ecorr / neval); 352 | } 353 | fprintf(stdout, "\n"); 354 | if(save){ 355 | fclose(fres); 356 | fclose(fseq); 357 | } 358 | return 0; 359 | } 360 | -------------------------------------------------------------------------------- /train_toy.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | * 9 | */ 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "common.h" 18 | #include "task.h" 19 | #include "StackRNN.h" 20 | 21 | using namespace std; 22 | using namespace rnn; 23 | 24 | /**************** 25 | This files is used to learn a model for a given 26 | toy task (e.g., a^nb^n) 27 | 28 | See script_toy.sh to see examples of how to use this file to 29 | reproduce the experiments in our paper 30 | 31 | **************/ 32 | void print_help(){ 33 | printf("train_toy is used ot train a model on simple toy tasks (see Joulin and Mikolov, 2015)\n"); 34 | printf("We print every sequence seperated by a underscore [_]. The model does not see this character.\n"); 35 | printf("usage: train_toy [options]\n"); 36 | printf("options:\n"); 37 | printf("-nhid [integer]\t\t number of units in the hidden layer. Default value: 40\n"); 38 | printf("-nstack [integer]\t number of stacks. Default value: 10\n"); 39 | printf("-depth [integer]\t depth used of the stack to predict the hidden units. Default value: 1\n"); 40 | printf("-stack_size [integer]\t size of the stack container. Default value: 200\n"); 41 | printf("-bptt [integer]\t\t number of step of the back-propagation through tie (BPTT). Default value: 50\n"); 42 | printf("-nseq [integer]\t\t number of sequences used at each training epoch. Default value: 2000\n"); 43 | printf("-mod [integer]\t\t switch between feedforward (-mod 0), recurrence only through stacks (-mod 1) and recurrence through hidden layer and stacks (-mod 2). Default value: 21\n"); 44 | printf("-lr [float]\t\t Learning rate. Default value: 0.1\n"); 45 | printf("-nreset [integer]\t how often the stacks are emptied. Default value: 1000\n"); 46 | printf("-ntask [integer]\t choice the task (see readme or script_toy.sh). Default value: 1 \n"); 47 | printf("-nchar [integer]\t number of characters for a task (works with ntaks - see readme). Default value: 2\n"); 48 | printf("-nrep [integer]\t\t number of repetition in characters for a task (only use for ntask=2 - see readme). Default value: 1\n"); 49 | printf("-seed [integer]\t\t seed for the random number generator. Default value: 1\n"); 50 | printf("-nmax [integer]\t\t the maximum value for n for the tasks (e.g. n in a^nb^n). 
Default value: 10\n"); 51 | printf("-save \t\t\t use to save a bunch of things, like the model, logs... Default: false\n"); 52 | printf("-noop \t\t\t use a no-op action on the stack. Default: false\n"); 53 | printf("-hard \t\t\t use adiscrete actions at validation and test time. Default: false\n"); 54 | printf("Examples: \n ./train_toy -nhid 40 -nstack 10 -depth 2 -ntask 1 -nchar 2 -lr .1 -seed 1\n"); 55 | } 56 | 57 | void print(StackRNN& rnn, FILE* f, int cur, int next){ 58 | int nstack = rnn._NB_STACK; 59 | bool isnoop = (rnn._ACTION == 3); 60 | int naction = rnn._ACTION; 61 | fprintf(f, "cur: %c next: %c pred: %c ", 'a' + cur, 'a' + next, 'a' + rnn.pred()); 62 | fprintf(f, "prob[%c]: %f ", 'a'+next, rnn.eval(next)); 63 | for(int s = 0; s < nstack; s++) { 64 | if(rnn._act[s][rnn._it_mem][push] * naction > 1. ) 65 | fprintf(f, " push[%f] ", rnn._act[s][rnn._it_mem][push]); 66 | if(rnn._act[s][rnn._it_mem][pop] * naction > 1. ) 67 | fprintf(f, " pop[%f] ",rnn._act[s][rnn._it_mem][pop]); 68 | if(isnoop && rnn._act[s][rnn._it_mem][noop] * naction > 1. ) 69 | fprintf(f, " noop[%f] ", rnn._act[s][rnn._it_mem][noop]); 70 | } 71 | for(int s = 0; s < nstack; s++) { 72 | fprintf(f,"stack[%d]: ",s); 73 | for(int d = 0; d < 3; d++) 74 | fprintf(f," [%d]:%.3f" , d, rnn._stack[s][rnn._it_mem][d]); 75 | } 76 | fprintf(f,"\n"); 77 | } 78 | 79 | int main(int argc, char **argv){ 80 | 81 | int nhid = 40; 82 | int nstack = 10; 83 | int stack_size = 200; 84 | int bptt = 50; 85 | float lr = 0.1; 86 | int max_count_train = 10000000; 87 | string modelname = "model"; 88 | int mod = 1; 89 | int nmaxmax = 5; 90 | int nmin = 2; 91 | bool isnoop = false; 92 | bool ishard = false; 93 | int nchar = 2; 94 | int nrep = 1; 95 | int nreset = 1000; 96 | int ntask = 1; 97 | int depth = 2; 98 | int nseq = 2000; 99 | int seed = 1; 100 | double reg = 0; 101 | bool save = false; 102 | 103 | printf("For help: train_toy --help\n"); 104 | 105 | int ai = 1; 106 | while(ai < argc){ 107 | if( strcmp( argv[ai], "--help") == 0){ 108 | print_help(); 109 | return 1; 110 | } 111 | if( strcmp( argv[ai], "-nhid") == 0){ 112 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 113 | nhid = atoi(argv[ai+1]); 114 | } 115 | else if( strcmp( argv[ai], "-nseq") == 0){ 116 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 117 | nseq = atoi(argv[ai+1]); 118 | } 119 | else if( strcmp( argv[ai], "-nstack") == 0){ 120 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 121 | nstack = atoi(argv[ai+1]); 122 | } 123 | else if( strcmp( argv[ai], "-stack_size") == 0){ 124 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 125 | stack_size = atoi(argv[ai+1]); 126 | } 127 | else if( strcmp( argv[ai], "-bptt") == 0){ 128 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 129 | bptt = atoi(argv[ai+1]); 130 | } 131 | else if( strcmp( argv[ai], "-mod") == 0){ 132 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 133 | mod = atoi(argv[ai+1]); 134 | } 135 | else if( strcmp( argv[ai], "-reg") == 0){ 136 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 137 | reg = atof(argv[ai+1]); 138 | } 139 | else if( strcmp( argv[ai], "-lr") == 0){ 140 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 141 | lr = atof(argv[ai+1]); 142 | } 143 | else if( 
strcmp( argv[ai], "-nreset") == 0){ 144 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 145 | nreset = atoi(argv[ai+1]); 146 | if(nreset < 0) {printf("error nchar should be >= 0\n");return -1;} 147 | } 148 | else if( strcmp( argv[ai], "-nrep") == 0){ 149 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 150 | nrep = atoi(argv[ai+1]); 151 | if(nrep < 1) {printf("error nchar should be >= 1\n");return -1;} 152 | } 153 | else if( strcmp( argv[ai], "-ntask") == 0){ 154 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 155 | ntask = atoi(argv[ai+1]); 156 | } 157 | else if( strcmp( argv[ai], "-nchar") == 0){ 158 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 159 | nchar = atoi(argv[ai+1]); 160 | if(nchar < 2) {printf("error nchar should be >= 2\n");return -1;} 161 | } 162 | else if( strcmp( argv[ai], "-ntrain") == 0){ 163 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 164 | max_count_train = atoi(argv[ai+1]); 165 | } 166 | else if( strcmp( argv[ai], "-nmin") == 0){ 167 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 168 | nmin = atoi(argv[ai+1]); 169 | } 170 | else if( strcmp( argv[ai], "-seed") == 0){ 171 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 172 | seed = atoi(argv[ai+1]); 173 | } 174 | else if( strcmp( argv[ai], "-nmax") == 0){ 175 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 176 | nmaxmax = atoi(argv[ai+1]); 177 | } 178 | else if( strcmp( argv[ai], "-noop") == 0){ 179 | isnoop = true; 180 | ai--; 181 | } 182 | else if( strcmp( argv[ai], "-save") == 0){ 183 | save = true; 184 | ai--; 185 | } 186 | else if( strcmp( argv[ai], "-hard") == 0){ 187 | ishard = true; 188 | ai--; 189 | } 190 | else if( strcmp( argv[ai], "-depth") == 0){ 191 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 192 | depth = atoi(argv[ai+1]); 193 | if(depth < 1) {printf("error blabla depth...\n"); return -1;} 194 | } 195 | else if( strcmp( argv[ai], "-name") == 0){ 196 | if(ai + 1 >= argc) { printf("error need argument for option %s\n",argv[ai]); return - 1;} 197 | modelname = argv[ai+1]; 198 | } 199 | else{ 200 | printf("unknown option: %s\n",argv[ai]); 201 | return -1; 202 | } 203 | ai += 2; 204 | } 205 | 206 | cout<<"seed: "<= nmax) nmax = nmin + 1; 252 | 253 | string p = generate_next_sequence(nmax, nmin, nchar, nrep, ntask); 254 | 255 | // string to be print: 256 | string spred(50,'#'); 257 | string sgoal(50,'#'); 258 | 259 | vector sstacks(nstack); 260 | for(int s = 0; s < nstack; s++) 261 | sstacks[s] = string(50,'#'); 262 | 263 | 264 | int count = 0, neval = 0; 265 | int ne = 0; 266 | double lo = 0; 267 | 268 | 269 | int nepoch = 100; 270 | 271 | float last_ent = 0; 272 | FILE* f; 273 | for(int e = 0; e < nepoch; e++){ 274 | 275 | if(save) f = fopen(logfilename.c_str(), "w"); 276 | nmax = max(min(e+3,nmaxmax),3); 277 | neval = 1; loss = 0; 278 | ne = 1; lo = 0; 279 | count = 0; 280 | 281 | rnn.emptyStacks(); 282 | 283 | // train on increasingly more challenging tasks: 284 | for(int iseq = 0; iseq < nseq; iseq++){ 285 | p = generate_next_sequence(nmax, nmin, nchar, nrep, ntask); 286 | 287 | if(save) fprintf(f,"begin sequence\n"); 288 | spred += '_'; sgoal += '_'; 289 | for(int s = 0; s < nstack; s++) sstacks[s] += '_'; 
290 | if(nreset == 1 || (nreset > 0 && iseq % nreset == 0 )) rnn.emptyStacks(); 291 | 292 | for(int ip = 0; ip < p.size(); ip++){ 293 | next = p[ip] - 'a'; 294 | 295 | rnn.forward(cur, next); 296 | 297 | if(ip == 0 && iseq == 0) rnn.emptyStacks(); 298 | else{ 299 | rnn.backward(); 300 | rnn.update(lr); 301 | } 302 | if (ip == 0) { 303 | loss -= log(rnn.eval(next)) / log(10); 304 | neval++; 305 | } 306 | lo -= log(rnn.eval(next)) / log(10); 307 | ne++; 308 | 309 | // print stuff: 310 | if(save) print(rnn, f, cur, next); 311 | 312 | spred += 'a' + rnn.pred(); sgoal += 'a' + next; 313 | for(int s = 0; s < nstack; s++) { 314 | if(rnn._act[s][rnn._it_mem][pop] > 0.7) sstacks[s] += '-'; 315 | else if(rnn._act[s][rnn._it_mem][push] > 0.7) sstacks[s] += '+'; 316 | else if(isnoop && rnn._act[s][rnn._it_mem][noop] > 0.7) sstacks[s] += '|'; 317 | else sstacks[s] += 'X'; 318 | } 319 | if (spred.size() > 30) spred.erase(spred.begin(), spred.end() - 30); 320 | if (sgoal.size() > 30) sgoal.erase(sgoal.begin(), sgoal.end() - 30); 321 | for(int s = 0; s < nstack; s++) if (sstacks[s].size() > 30){ 322 | sstacks[s].erase(sstacks[s].begin(), sstacks[s].end() - 30); 323 | } 324 | if(ip == 0){ 325 | fprintf(stdout, "\r [train] lr: %.5f it=%7d nmax:%d entropy: %.3f goal: %s pred: %s ", 326 | lr, count, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 327 | } 328 | 329 | cur = next; 330 | count++; 331 | 332 | 333 | } 334 | } 335 | 336 | 337 | fprintf(stdout, "\r [train] lr: %.5f it=%7d nmax:%d entropy: %.3f goal: %s pred: %s ", 338 | lr, count, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 339 | for(int s = 0; s < min(nstack,5); s++) 340 | fprintf(stdout, "| actions on stack[%d] = %s", s, sstacks[s].c_str()); 341 | fprintf(stdout," [ - = pop, + = push, | = no-op, X = not determined yet ]"); 342 | fprintf(stdout, "\n"); 343 | 344 | // evaluation on every sequences: 345 | nmax = max(nmaxmax, 20), nmin = 2; 346 | if(nstack==0) nmax = nmaxmax; // else it does not work for standard rnn... 
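// validation draws sequences with n between 2 and max(nmaxmax, 20), i.e. typically longer than anything seen during the training curriculum, to check generalization; stackless models (nstack == 0) are only validated up to nmaxmax: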
347 | neval = 1; loss = 0; 348 | ne = 1; lo = 0; 349 | count = 0; 350 | 351 | rnn.emptyStacks(); 352 | cur = nchar - 1; 353 | 354 | if(save) fprintf(f, "[VALID]\n"); 355 | 356 | for(int iseq = 0; iseq < 1000; iseq++){ 357 | p = generate_next_sequence(nmax, nmin, nchar, nrep, ntask); 358 | spred += '_'; sgoal += '_'; 359 | for(int s = 0; s < nstack; s++) sstacks[s] += '_'; 360 | 361 | if( nreset == 1) rnn.emptyStacks(); 362 | 363 | if(save) fprintf(f,"begin sequence\n"); 364 | for(int ip = 0; ip < p.size(); ip++){ 365 | next = p[ip] - 'a'; 366 | 367 | rnn.forward(cur, next, ishard); 368 | 369 | //if(ip == 0 && iseq == 0) rnn.emptyStacks(); 370 | 371 | if (ip == 0) { 372 | loss -= log(rnn.eval(next)) / log(10); 373 | neval++; 374 | } 375 | lo -= log(rnn.eval(next)) / log(10); 376 | ne++; 377 | 378 | 379 | // printing stuff 380 | if(save) print(rnn, f, cur, next); 381 | 382 | spred += 'a' + rnn.pred(); sgoal += 'a' + next; 383 | for(int s = 0; s < nstack; s++) { 384 | if(rnn._act[s][rnn._it_mem][pop] > 0.7) sstacks[s] += '-'; 385 | else if(rnn._act[s][rnn._it_mem][push] > 0.7) sstacks[s] += '+'; 386 | else if(isnoop && rnn._act[s][rnn._it_mem][noop] > 0.7) sstacks[s] += '|'; 387 | else sstacks[s] += 'X'; 388 | } 389 | if (spred.size() > 30) spred.erase(spred.begin(), spred.end() - 30); 390 | if (sgoal.size() > 30) sgoal.erase(sgoal.begin(), sgoal.end() - 30); 391 | for(int s = 0; s < nstack; s++) if (sstacks[s].size() > 30){ 392 | sstacks[s].erase(sstacks[s].begin(), sstacks[s].end() - 30); 393 | } 394 | 395 | fprintf(stdout, "\r [valid] lr: %.5f it=%7d nmax:%d entropy: %.3f goal: %s pred: %s ", 396 | lr, count, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 397 | 398 | cur = next; 399 | count++; 400 | } 401 | } 402 | if(save)fprintf(f, "\n [valid] lr: %.5f it=%7d nmax:%d entropy: %.3f goal: %s pred: %s \n", 403 | lr, count, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 404 | 405 | 406 | fprintf(stdout, "\r [valid] lr: %.5f it=%7d nmax:%d entropy: %.3f goal: %s pred: %s ", 407 | lr, count, nmax, lo / ne, sgoal.c_str(), spred.c_str()); 408 | for(int s = 0; s < min(nstack,5); s++) 409 | fprintf(stdout, "| actions on stack[%d] = %s", s, sstacks[s].c_str()); 410 | fprintf(stdout," [ - = pop, + = push, | = no-op, X = not determined yet ]"); 411 | fprintf(stdout, "\n"); 412 | 413 | if( e == 0 || lo / ne < last_ent){ 414 | last_ent = lo / ne; 415 | back_up_model.copy(rnn); 416 | back_up_model.save(modelname); 417 | } 418 | else if( e > 0 ){ 419 | if(e > nmaxmax/2){ 420 | lr /= 2; 421 | rnn.copy(back_up_model); 422 | } 423 | } 424 | if(lr < 1e-5) break; 425 | 426 | rnn._reg *= 2; 427 | if(save) fclose(f); 428 | } 429 | 430 | srand(10); 431 | 432 | rnn.copy(back_up_model); 433 | 434 | cout<< testfilename << endl; 435 | cout<< logtestfilename << endl; 436 | count = 0; 437 | 438 | FILE* fseq; 439 | FILE* fres; 440 | fprintf(stdout,"Test set: \n"); 441 | if(save){ 442 | sprintf(buff,"data/test_seqence_ntask%d_nchar%d", ntask, nchar); 443 | cout << " Sequence used at test time saved at: "<< buff << endl; 444 | fseq = fopen(buff,"w"); 445 | fres = fopen(testfilename.c_str(),"w"); 446 | fprintf(fres,"validation:\t %f\n", lo / ne); 447 | } 448 | 449 | int ntest = 200; 450 | bool iseval = false; 451 | rnn.emptyStacks(); 452 | cur = nchar - 1; 453 | 454 | // task =4: 1st element is not part of the evaluation 455 | bool iscountfirstelement = (ntask != 4); 456 | 457 | for(int nm = 2; nm < 60; nm++){ 458 | nmin = nm; nmax = nm + 1; 459 | float corr = 0, ecorr = 0; 460 | int sseq = 0; nseq = 0; 461 | neval = 0; 462 
| ne = 0;lo = 0; 463 | if(save)f = fopen(logtestfilename.c_str(),"w"); 464 | 465 | for(int iseq = 0; iseq < ntest; iseq++){ 466 | 467 | if(ntask >= 7) rnn.emptyStacks(); 468 | p = generate_next_sequence(nmax, nmin, nchar, nrep, ntask); 469 | iseval = false; 470 | 471 | for(int ip = 0; ip < p.size(); ip++){ 472 | next = p[ip] - 'a'; 473 | if(save)fprintf(fseq, "%c", p[ip]); 474 | 475 | rnn.forward(cur, next, ishard); 476 | 477 | //if(ip == 0 && iseq == 0) rnn.emptyStacks(); 478 | 479 | // begin of a sequence / end of evaluation: 480 | if (ip == 0) { 481 | if(iseq != 0){ 482 | neval++; 483 | if( corr == sseq && (!iscountfirstelement || next == rnn.pred())) 484 | ecorr++; 485 | if(save) fprintf(f, "end eval - accuracy: %f \n", ecorr / neval); 486 | } 487 | sseq=0; corr = 0; 488 | iseval = false; 489 | } 490 | 491 | if(iseval && next == rnn.pred()) corr++; 492 | if(iseval) sseq++; 493 | 494 | lo -= log(rnn.eval(next)) / log(10); 495 | ne++; 496 | 497 | // printing stuff 498 | if(save)print(rnn, f, cur, next); 499 | 500 | // begin of evaluation: 501 | if( (ntask == 1 && cur == 0 && next != 0) 502 | || (ntask == 2 && cur == 0 && next!= 0) 503 | || (ntask == 3 && cur == nchar -2 && next == nchar - 1) 504 | || (ntask == 4 && next == 0) 505 | || (ntask == 6 && cur == 1 && next == 2) 506 | || (ntask == 5 && cur == nchar -2 && next == nchar - 1) ){ 507 | iseval = true; 508 | if(save)fprintf(f, "begin eval\n"); 509 | } 510 | 511 | cur = next; 512 | count++; 513 | } 514 | } 515 | if(save){ 516 | fprintf(fres,"%d \t %f\n", nm, ecorr / neval); 517 | fclose(f); 518 | } 519 | fprintf(stdout,"n: %d \t accuracy: %f \n", nm, ecorr / neval); 520 | } 521 | fprintf(stdout, "\n"); 522 | if(save) fclose(fres); 523 | if(save) fclose(fseq); 524 | 525 | return 0; 526 | } 527 | -------------------------------------------------------------------------------- /utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | * 9 | */ 10 | #ifndef _UTILS_ 11 | #define _UTILS_ 12 | #include 13 | #include 14 | 15 | #include "common.h" 16 | #include "Vec.h" 17 | 18 | namespace rnn { 19 | 20 | // utils: 21 | void hardclipping( Vec& v, my_int b = -1, my_int e = -1){ 22 | if(b == -1 ) b = 0; 23 | if(e == -1 ) e = v.size(); 24 | for(my_int i = b; i < e; i++){ 25 | if( v[i] < -15) v[i] = -15; 26 | if( v[i] > 15) v[i] = 15; 27 | } 28 | } 29 | 30 | /* uniform distribution, (0..1] */ 31 | my_real drand() 32 | { 33 | return (rand()+1.0)/(RAND_MAX+1.0); 34 | } 35 | 36 | /* normal distribution, centered on 0, std dev 1 */ 37 | my_real random_normal() 38 | { 39 | return sqrt(-2*log(drand())) * cos(2*M_PI*drand()); 40 | } 41 | 42 | my_real random(my_real min, my_real max){ 43 | return rand()/(my_real)RAND_MAX*(max-min)+min; 44 | } 45 | 46 | void matrixXvector(Vec& dest, const Vec& srcvec, const Vec2D& srcmatrix, 47 | const my_int& obegin, const my_int& oend, 48 | const my_int& ibegin, const my_int& iend, const my_int& type) 49 | { 50 | // type = 0 -> srcmatrix * srcvec 51 | // type = 1 -> srcmatrix^T * srcvec 52 | 53 | assert(srcmatrix.nrow() >= oend); 54 | assert(srcmatrix.ncol() >= iend); 55 | 56 | my_int a, b; 57 | my_real val1, val2, val3, val4; 58 | my_real val5, val6, val7, val8; 59 | 60 | my_int matrix_width = srcmatrix.ncol(); 61 | 62 | 63 | if (type==0) { //ac mod 64 | assert(dest.size() >= oend); 65 | assert(srcvec.size() >= iend); 66 | for (b=0; b<(oend-obegin)/8; b++) { 67 | val1=0; 68 | val2=0; 69 | val3=0; 70 | val4=0; 71 | 72 | val5=0; 73 | val6=0; 74 | val7=0; 75 | val8=0; 76 | 77 | for (a=ibegin; a= iend); 107 | assert(srcvec.size() >= oend); 108 | for (a=0; a<(iend-ibegin)/8; a++) { 109 | val1=0; 110 | val2=0; 111 | val3=0; 112 | val4=0; 113 | 114 | val5=0; 115 | val6=0; 116 | val7=0; 117 | val8=0; 118 | 119 | for (b=obegin; b