├── README.md
├── matlab
│   ├── processtextfile.m
│   ├── README.md
│   ├── LSTMeval.m
│   ├── mLSTMeval.m
│   ├── LSTMhutter.m
│   ├── LSTMdynamic.m
│   ├── mLSTMhutter.m
│   └── mLSTMdynamic.m
├── chainer
│   ├── WN.py
│   ├── README.md
│   ├── eval.py
│   ├── mLSTMWN.py
│   └── train.py
└── LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains Chainer and MATLAB code for the multiplicative LSTM (mLSTM) used for some of the results in the paper. The training configuration for the initial experiments on Hutter Prize is given in the matlab folder. Code for the most successful mLSTM configuration we have found to date is given in the chainer folder. The Chainer code was used to achieve 1.24 bits/char on Hutter Prize and 1.27 bits/char on text8. Each folder has its own separate instructions.
--------------------------------------------------------------------------------
/matlab/processtextfile.m:
--------------------------------------------------------------------------------
1 | function sequence=processtextfile(fname,outfname)
2 | %converts a text file into a sequence of numbers to be used in RNN experiments
3 | %include argument outfname to save "sequence" to a .mat file, which allows "sequence" to be
4 | %loaded more quickly with the command load(outfname)
5 | 
6 | 
7 | fid = fopen(fname);
8 | bytes=fread(fid,'*uint8');
9 | u = unique(bytes);
10 | sequence = zeros(length(bytes),1,'single');
11 | for i=1:length(u)
12 |     f = logical(bytes==u(i));
13 |     sequence(f) = i;
14 | end
15 | 
16 | if exist('outfname','var')
17 | 
18 |     save(outfname,'sequence')
19 |     disp('saved')
20 | end
21 | end
--------------------------------------------------------------------------------
/chainer/WN.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | from chainer import functions as F
4 | from chainer import links as L
5 | 
6 | 
7 | class 
WN(L.Linear):
8 | 
9 |     def __init__(self, *args, **kwargs):
10 |         super(WN, self).__init__(*args, **kwargs)
11 |         self.add_param('g', self.W.data.shape[0])
12 |         norm = np.linalg.norm(self.W.data, axis=1)
13 |         self.g.data[...] = norm
14 | 
15 |     def __call__(self):
16 |         """Returns the weight-normalized weight matrix.
17 |         Each row of ``W`` keeps its direction, and its scale is replaced
18 |         by the learned gain ``g``, i.e. ``g * W / ||W||``.
19 |         Returns:
20 |             ~chainer.Variable: The weight-normalized weight matrix.
21 |         """
22 |         norm = F.batch_l2_norm_squared(self.W) ** 0.5
23 |         norm_broadcasted = F.broadcast_to(
24 |             F.expand_dims(norm, 1), self.W.data.shape)
25 |         g_broadcasted = F.broadcast_to(
26 |             F.expand_dims(self.g, 1), self.W.data.shape)
27 |         return g_broadcasted * self.W / norm_broadcasted
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 | 
3 | Copyright (c) 2016, benkrause
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 | 
--------------------------------------------------------------------------------
/chainer/README.md:
--------------------------------------------------------------------------------
1 | This is a Chainer implementation of weight-normalized multiplicative LSTM with variational dropout.
2 | 
3 | Requirements: python3, chainer==1.24
4 | 
5 | The code does not run on more recent versions of Chainer; install Chainer with:
6 | 
7 |     pip install chainer==1.24
8 | 
9 | This code is set up by default to run on text8 and Hutter Prize.
10 | 
11 | Hutter Prize can be downloaded here: http://mattmahoney.net/dc/enwik8.zip
12 | text8 can be downloaded here: http://mattmahoney.net/dc/text8.zip
13 | 
14 | To train the model, put the unzipped text file (either text8 or enwik8) in the directory and run:
15 | 
16 |     python train.py --file filename
17 | 
18 | The default settings were used to obtain 1.24 bits/char on Hutter Prize and 1.27 bits/char on text8. We did not save the initialization seed, but it should be possible to reproduce similar results. The model takes about 1 week to train on a GTX 1080 Ti and requires about 9 GB of GPU memory. The model and log file are stored in the directory specified by --out ("result" by default). The model is saved intermittently throughout training, at every log update.
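For reference, the bits/char numbers reported here are just the average softmax cross-entropy converted from nats to base 2; eval.py multiplies the loss by 1.4427 (≈ 1/ln 2) to do this conversion. A minimal sketch (the function name is illustrative, not part of the code):

```python
import math

def nats_to_bits_per_char(loss_nats):
    """Convert an average cross-entropy in nats to bits/char.

    This is the 1.4427 factor used in eval.py, since 1/ln(2) ~= 1.4427.
    """
    return loss_nats / math.log(2)

# A per-character cross-entropy of about 0.86 nats is roughly 1.24 bits/char.
print(round(nats_to_bits_per_char(0.86), 2))
```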
To evaluate the test set error after training, run:
19 | 
20 |     python eval.py --file filename
21 | 
22 | train.py assumes the training set is the first 90M characters, and eval.py assumes the test set is the last 5M characters. This is the case for both Hutter Prize and text8.
23 | 
24 | To train faster and with less memory, you could try:
25 | 
26 |     python train.py --file filename --epoch 10 --edrop 0.2 --unit 1900 --bproplen 100
27 | 
28 | This configuration should finish roughly 5 times faster and use under 4 GB of GPU memory, but will not obtain as strong results. The hidden and embedding sizes from training must be specified during evaluation, so if you use the above training configuration, evaluate with:
29 | 
30 |     python eval.py --file filename --unit 1900
31 | 
--------------------------------------------------------------------------------
/matlab/README.md:
--------------------------------------------------------------------------------
1 | This code provides an implementation of multiplicative LSTM and a stacked LSTM baseline which can be quickly set up for character-level language modelling of the Hutter Prize dataset. Slight modifications would allow application to other character-level modelling tasks.
2 | 
3 | Instructions for use
4 | 
5 | 1. Download enwik8.zip from http://www.mattmahoney.net/dc/textdata and place the unzipped file, named "enwik8", in the directory.
6 | 
7 | 2. Run mLSTMhutter.m or LSTMhutter.m from MATLAB to train the mLSTM or LSTM model on the data.
8 | 
9 | 3. Run mLSTMeval.m or LSTMeval.m from MATLAB to evaluate the model trained in the previous step on the test set.
10 | 
11 | 4. Run mLSTMdynamic.m or LSTMdynamic.m from MATLAB to dynamically evaluate the model on the test set.
12 | 
13 | Files
14 | 
15 | mLSTMhutter.m
16 | -Trains a multiplicative LSTM on the Hutter Prize dataset, saves network parameter values and writes continually to a log file. Takes ~2 days to run on a GTX 970 GPU.
17 | 
18 | mLSTMeval.m
19 | -Performs static evaluation of mLSTM on the test set, loading the network parameters saved during training. Takes a few minutes to run on a GTX 970 GPU.
20 | 
21 | mLSTMdynamic.m
22 | -Performs dynamic evaluation of mLSTM on the test set, loading the network parameters saved during training. Takes ~1 day to run on a GTX 970 GPU.
23 | 
24 | 
25 | LSTMhutter.m
26 | -Trains an LSTM on the Hutter Prize dataset, saves parameter values and writes continually to a log file. Takes ~2 days to run on a GTX 970 GPU.
27 | 
28 | LSTMeval.m
29 | -Performs static evaluation of LSTM on the test set, loading the network parameters saved during training. Takes a few minutes to run on a GTX 970 GPU.
30 | 
31 | LSTMdynamic.m
32 | -Performs dynamic evaluation of LSTM on the test set, loading the network parameters saved during training. Takes ~1 day to run on a GTX 970 GPU.
33 | 
34 | processtextfile.m
35 | -Reads the training set text file into the appropriate format.
36 | 
37 | 
38 | Dependencies:
39 | MATLAB R2014a or newer, with the Parallel Computing Toolbox
40 | A CUDA-enabled GPU with at least 4 GB of RAM
41 | 
42 | If no GPU is available, the code can be run on the CPU by commenting out the "gpuDevice(1)" command at the top of the files, and removing the comments from the "gpuArray" and "gather" functions at the end of the files. However, running the experiments on a CPU using the full training set would be quite slow. The training set size can be reduced by setting the "maxtrain" variable.
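The preprocessing done by processtextfile.m (mapping each distinct byte of the text file to a small integer index) mirrors the get_char function in the chainer folder. A minimal Python sketch of the same encoding (the function name is illustrative; note MATLAB indices start at 1, Python's at 0):

```python
def bytes_to_sequence(data):
    """Map each distinct byte value in `data` to a 0-based integer index.

    Same idea as processtextfile.m / get_char: sort the unique byte values,
    then replace each byte by its position in that sorted list.
    """
    unique = sorted(set(data))
    mapping = {b: i for i, b in enumerate(unique)}
    return [mapping[b] for b in data], mapping

seq, mapping = bytes_to_sequence(b"abca")
print(seq)  # -> [0, 1, 2, 0]
```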
43 | -------------------------------------------------------------------------------- /chainer/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import chainer 4 | import chainer.functions as F 5 | import chainer.links as L 6 | from chainer import Variable 7 | from chainer import training 8 | import sys 9 | import cupy as cp 10 | import os 11 | from mLSTMWN import mLSTM 12 | from WN import WN 13 | from chainer.optimizers import Adam 14 | from chainer import serializers 15 | 16 | def get_char(fname): 17 | fid = open(fname,'rb') 18 | byte_array = fid.read() 19 | text = [0]*len(byte_array) 20 | for i in range(0,len(byte_array)): 21 | text[i] = int(byte_array[i]) 22 | unique = list(set(text)) 23 | unique.sort() 24 | 25 | mapping = dict(zip(unique,list(range(0,len(unique))))) 26 | for i in range(0,len(text)): 27 | text[i] = mapping[text[i]] 28 | return text, mapping 29 | 30 | def ortho_init(shape): 31 | # From Lasagne and Keras. 
Reference: Saxe et al., http://arxiv.org/abs/1312.6120
32 | 
33 |     flat_shape = (shape[0], np.prod(shape[1:]))
34 |     a = np.random.normal(0.0, 1.0, flat_shape)
35 |     a = a.astype(dtype=np.float32)
36 | 
37 |     u, _, v = np.linalg.svd(a, full_matrices=False)
38 | 
39 |     q = u if u.shape == flat_shape else v
40 |     q = q.reshape(shape)
41 |     return q
42 | 
43 | class RNNForLM(chainer.Chain):
44 | 
45 |     def __init__(self, nvocab, nunits, nembd=400, train=True):
46 |         super(RNNForLM, self).__init__(
47 |             embed=L.EmbedID(nvocab, nembd),
48 |             WhxWN=WN(nembd, nunits*4),
49 |             WmxWN=WN(nembd, nunits),
50 |             WmhWN=WN(nunits, nunits),
51 |             WhmWN=WN(nunits, nunits*4),
52 | 
53 |             l1=mLSTM(out_size=nunits),
54 |             l2=L.Linear(nunits, nvocab)
55 |         )
56 |         nparam = 0
57 |         for param in self.params():
58 |             print(param.data.shape)
59 |             nparam += param.data.size
60 | 
61 |         print('nparam')
62 |         print(nparam)
63 | 
64 |         self.train = train
65 | 
66 |     def reset_state(self):
67 |         self.l1.reset_state()
68 | 
69 |     def applyWN(self):
70 |         self.Whx = self.WhxWN()
71 |         self.Wmx = self.WmxWN()
72 |         self.Wmh = self.WmhWN()
73 |         self.Whm = self.WhmWN()
74 | 
75 |     def __call__(self, x, mask, mask2):
76 | 
77 | 
78 |         h0 = self.embed(x)*mask2
79 | 
80 |         h1 = self.l1(h0, self.Whx, self.Wmx, self.Wmh, self.Whm)
81 |         h1 = h1*mask
82 |         self.l1.h = h1
83 |         y = self.l2(h1)
84 | 
85 |         return y
86 | 
87 | def test(model, inputs, targets):
88 |     inputs = Variable(inputs)
89 |     targets = Variable(targets)
90 | 
91 |     targets.to_gpu()
92 |     inputs.to_gpu()
93 |     model.applyWN()
94 |     model.train = False
95 |     loss = 0
96 |     for j in range(inputs.shape[1]):
97 |         output = model(inputs[:, j], 1, 1)
98 |         loss = loss + F.softmax_cross_entropy(output, targets[:, j])
99 |         loss.unchain_backward()
100 | 
101 |     model.train = True
102 | 
103 |     model.reset_state()
104 | 
105 |     finalloss = loss.data/inputs.shape[1]
106 |     return finalloss
107 | 
108 | def main():
109 | 
110 |     parser = argparse.ArgumentParser()
111 |     parser.add_argument('--batchsize', '-b', type=int, default=10,
112 | help='test or val batch size, 5M mod batchsize must be 0') 113 | 114 | parser.add_argument('--gpu', '-g', type=int, default=0, 115 | help='GPU ID (negative value indicates CPU)') 116 | 117 | parser.add_argument('--model', '-o', default='result', 118 | help='Directory to load model from') 119 | 120 | parser.add_argument('--file', default="enwik8", 121 | help='path to text file for testing') 122 | parser.add_argument('--unit', '-u', type=int, default=2800, 123 | help='Number of LSTM units, must match model') 124 | parser.add_argument('--embd', type=int, default=400, 125 | help='Number of embedding units, must match model') 126 | parser.add_argument('--val', action='store_true', 127 | help='set for validation error, test by default') 128 | 129 | args = parser.parse_args() 130 | 131 | nembd = args.embd 132 | 133 | nbatch = args.batchsize 134 | 135 | filename= args.file 136 | 137 | text,mapping = get_char(filename) 138 | sequence = np.array(text).astype(np.int32) 139 | 140 | if args.val: 141 | start = 90000000-1 142 | else: 143 | start = 95000000-1 144 | neval = 5000000 145 | 146 | ival = sequence[start:start+neval] 147 | tval = sequence[start+1:start+neval+1] 148 | 149 | #uses subset of validation set 150 | ival = ival.reshape(args.batchsize,ival.shape[0]//args.batchsize) 151 | tval = tval.reshape(args.batchsize,tval.shape[0]//args.batchsize) 152 | #test = sequence[ntrain+nval:ntrain+nval+ntest] 153 | nvocab = max(sequence) + 1 # train is just an array of integers 154 | print('#vocab =', nvocab) 155 | # Prepare an RNNLM model 156 | rnn = RNNForLM(nvocab, args.unit,args.embd) 157 | modelname = os.path.join(args.model,'model') 158 | serializers.load_npz(modelname, rnn) 159 | model = L.Classifier(rnn) 160 | model.compute_accuracy = False # we only want the perplexity 161 | if args.gpu >= 0: 162 | chainer.cuda.get_device(args.gpu).use() # make the GPU current 163 | model.to_gpu() 164 | 165 | print('starting') 166 | 167 | start = 0 168 | loss_sum = 0; 169 | 170 | 
vloss = test(rnn, ival, tval)
171 |     vloss = 1.4427*vloss  # convert nats to bits/char (1/ln 2)
172 |     print('loss (bits/char): ' + str(vloss))
173 | 
174 | if __name__ == '__main__':
175 |     main()
--------------------------------------------------------------------------------
/chainer/mLSTMWN.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import six
3 | 
4 | import chainer
5 | from chainer.functions.activation import lstm
6 | from chainer.functions.array import concat
7 | from chainer.functions.array import split_axis
8 | 
9 | from chainer import initializers
10 | from chainer import link
11 | from chainer import Function
12 | import chainer.functions as F
13 | from chainer.links.connection import linear
14 | 
15 | from chainer import variable
16 | 
17 | 
18 | class LSTMBase(link.Chain):
19 | 
20 |     def __init__(self, out_size, bias_init=0, forget_bias_init=1):
21 |         super(LSTMBase, self).__init__()
22 | 
23 |         self.bias_init = bias_init
24 |         self.forget_bias_init = forget_bias_init
25 |         self.out_size = out_size
26 |         self.state_size = out_size
27 |         if out_size is not None:
28 |             self._initialize_params()
29 | 
30 |     def _initialize_params(self):
31 | 
32 |         bias_initializer = initializers.Zero()
33 |         self.add_param('b', self.state_size*4, initializer=bias_initializer)
34 |         a, i, f, o = lstm._extract_gates(self.b.data.reshape(1, 4 * self.state_size, 1))
35 |         initializers.init_weight(a, self.bias_init)
36 |         initializers.init_weight(i, self.bias_init)
37 |         initializers.init_weight(f, self.forget_bias_init)
38 |         initializers.init_weight(o, self.bias_init)
39 | 
40 | 
41 | 
42 | class StatelessLSTM(LSTMBase):
43 | 
44 |     """Stateless LSTM layer.
45 |     This is a fully-connected LSTM layer as a chain. Unlike the
46 |     :func:`~chainer.functions.lstm` function, this chain holds upward and
47 |     lateral connections as child links. This link doesn't keep cell and
48 |     hidden states.
49 | Args: 50 | in_size (int): Dimension of input vectors. If ``None``, parameter 51 | initialization will be deferred until the first forward data pass 52 | at which time the size will be determined. 53 | out_size (int): Dimensionality of output vectors. 54 | Attributes: 55 | upward (chainer.links.Linear): Linear layer of upward connections. 56 | lateral (chainer.links.Linear): Linear layer of lateral connections. 57 | """ 58 | 59 | def __call__(self, c, h, x): 60 | """Returns new cell state and updated output of LSTM. 61 | Args: 62 | c (~chainer.Variable): Cell states of LSTM units. 63 | h (~chainer.Variable): Output at the previous time step. 64 | x (~chainer.Variable): A new batch from the input sequence. 65 | Returns: 66 | tuple of ~chainer.Variable: Returns ``(c_new, h_new)``, where 67 | ``c_new`` represents new cell state, and ``h_new`` is updated 68 | output of LSTM units. 69 | """ 70 | if self.upward.has_uninitialized_params: 71 | in_size = x.size // x.shape[0] 72 | self.upward._initialize_params(in_size) 73 | self._initialize_params() 74 | if self.upward2.has_uninitialized_params: 75 | in_size = x.size // x.shape[0] 76 | self.upward2._initialize_params(in_size) 77 | self._initialize_params() 78 | 79 | 80 | 81 | 82 | lstm_in = self.upward(x) 83 | if h is not None: 84 | lstm_in += self.lateral(h) 85 | if c is None: 86 | xp = self.xp 87 | c = variable.Variable( 88 | xp.zeros((x.shape[0], self.state_size), dtype=x.dtype), 89 | volatile='auto') 90 | return lstm.lstm(c, lstm_in) 91 | 92 | 93 | 94 | class mLSTM(LSTMBase): 95 | 96 | """Fully-connected LSTM layer. 97 | This is a fully-connected LSTM layer as a chain. Unlike the 98 | :func:`~chainer.functions.lstm` function, which is defined as a stateless 99 | activation function, this chain holds upward and lateral connections as 100 | child links. 101 | It also maintains *states*, including the cell state and the output 102 | at the previous time step. Therefore, it can be used as a *stateful LSTM*. 
This link supports variable length inputs. The mini-batch size of the
104 |     current input must be equal to or smaller than that of the previous one.
105 |     The mini-batch size of ``c`` and ``h`` is determined as that of the first
106 |     input ``x``.
107 |     When the mini-batch size of the ``i``-th input is smaller than that of the
108 |     previous input, this link only updates ``c[0:len(x)]`` and ``h[0:len(x)]``
109 |     and doesn't change the rest of ``c`` and ``h``.
110 |     So, please sort input sequences in descending order of length before
111 |     applying the function.
112 |     Args:
113 |         in_size (int): Dimension of input vectors. If ``None``, parameter
114 |             initialization will be deferred until the first forward data pass
115 |             at which time the size will be determined.
116 |         out_size (int): Dimensionality of output vectors.
117 |         lateral_init: A callable that takes ``numpy.ndarray`` or
118 |             ``cupy.ndarray`` and edits its value.
119 |             It is used for initialization of the lateral connections.
120 |             May be ``None`` to use default initialization.
121 |         upward_init: A callable that takes ``numpy.ndarray`` or
122 |             ``cupy.ndarray`` and edits its value.
123 |             It is used for initialization of the upward connections.
124 |             May be ``None`` to use default initialization.
125 |         bias_init: A callable that takes ``numpy.ndarray`` or
126 |             ``cupy.ndarray`` and edits its value.
127 |             It is used for initialization of the biases of the cell input,
128 |             input gate and output gate of the upward connection.
129 |             May be a scalar, in which case the bias is
130 |             initialized by this value.
131 |             May be ``None`` to use default initialization.
132 |         forget_bias_init: A callable that takes ``numpy.ndarray`` or
133 |             ``cupy.ndarray`` and edits its value.
134 |             It is used for initialization of the bias of the forget gate of
135 |             the upward connection.
136 |             May be a scalar, in which case the bias is
137 |             initialized by this value.
138 |             May be ``None`` to use default initialization.
139 | Attributes: 140 | upward (~chainer.links.Linear): Linear layer of upward connections. 141 | lateral (~chainer.links.Linear): Linear layer of lateral connections. 142 | c (~chainer.Variable): Cell states of LSTM units. 143 | h (~chainer.Variable): Output at the previous time step. 144 | """ 145 | 146 | def __init__(self,out_size, **kwargs): 147 | super(mLSTM, self).__init__(out_size, **kwargs) 148 | self.reset_state() 149 | 150 | def to_cpu(self): 151 | super(mLSTM, self).to_cpu() 152 | if self.c is not None: 153 | self.c.to_cpu() 154 | if self.h is not None: 155 | self.h.to_cpu() 156 | 157 | def to_gpu(self, device=None): 158 | super(mLSTM, self).to_gpu(device) 159 | if self.c is not None: 160 | self.c.to_gpu(device) 161 | if self.h is not None: 162 | self.h.to_gpu(device) 163 | 164 | def set_state(self, c, h): 165 | """Sets the internal state. 166 | It sets the :attr:`c` and :attr:`h` attributes. 167 | Args: 168 | c (~chainer.Variable): A new cell states of LSTM units. 169 | h (~chainer.Variable): A new output at the previous time step. 170 | """ 171 | assert isinstance(c, chainer.Variable) 172 | assert isinstance(h, chainer.Variable) 173 | c_ = c 174 | h_ = h 175 | if self.xp == numpy: 176 | c_.to_cpu() 177 | h_.to_cpu() 178 | else: 179 | c_.to_gpu() 180 | h_.to_gpu() 181 | self.c = c_ 182 | self.h = h_ 183 | 184 | def reset_state(self): 185 | """Resets the internal state. 186 | It sets ``None`` to the :attr:`c` and :attr:`h` attributes. 187 | """ 188 | self.c = self.h = None 189 | 190 | def __call__(self, x,Whx,Wmx,Wmh,Whm): 191 | """Updates the internal state and returns the LSTM outputs. 192 | Args: 193 | x (~chainer.Variable): A new batch from the input sequence. 194 | Returns: 195 | ~chainer.Variable: Outputs of updated LSTM units. 
196 | """ 197 | # if self.upward.has_uninitialized_params: 198 | # in_size = x.size // x.shape[0] 199 | # self.upward._initialize_params(in_size) 200 | # self._initialize_params() 201 | # if self.upward2.has_uninitialized_params: 202 | # in_size = x.size // x.shape[0] 203 | # self.upward2._initialize_params(in_size) 204 | # self._initialize_params() 205 | 206 | batch = x.shape[0] 207 | # Whx = self.upward() 208 | 209 | # Wmx = self.upward2() 210 | 211 | factor_in = F.linear(x,Wmx) 212 | lstm_in = F.linear(x,Whx,self.b) 213 | 214 | h_rest = None 215 | if self.h is not None: 216 | h_size = self.h.shape[0] 217 | if batch == 0: 218 | h_rest = self.h 219 | elif h_size < batch: 220 | msg = ('The batch size of x must be equal to or less than the ' 221 | 'size of the previous state h.') 222 | raise TypeError(msg) 223 | elif h_size > batch: 224 | h_update, h_rest = split_axis.split_axis( 225 | self.h, [batch], axis=0) 226 | # Wmh = self.lateral1() 227 | 228 | mult_in = F.linear(h_update,Wmh) 229 | 230 | mult_out = mult_in*factor_in 231 | # Whm = self.lateral2() 232 | lstm_in += F.linear(mult_out,Whm) 233 | 234 | else: 235 | # Wmh = self.lateral1() 236 | 237 | mult_in = F.linear(self.h,Wmh) 238 | 239 | mult_out = mult_in*factor_in 240 | # Whm = self.lateral2() 241 | lstm_in += F.linear(mult_out,Whm) 242 | 243 | if self.c is None: 244 | xp = self.xp 245 | self.c = variable.Variable(xp.zeros((batch, self.state_size), dtype=x.dtype),volatile='auto') 246 | self.c, y = lstm.lstm(self.c, lstm_in) 247 | 248 | if h_rest is None: 249 | self.h = y 250 | elif len(y.data) == 0: 251 | self.h = h_rest 252 | else: 253 | self.h = concat.concat([y, h_rest], axis=0) 254 | 255 | return y 256 | -------------------------------------------------------------------------------- /chainer/train.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import argparse 5 | import numpy as np 6 | import chainer 7 | import chainer.functions as F 8 | import 
chainer.links as L
9 | from chainer import Variable
10 | from chainer import training
11 | import sys
12 | import cupy as cp
13 | import os
14 | from mLSTMWN import mLSTM
15 | from WN import WN
16 | from chainer.optimizers import Adam
17 | from chainer import serializers
18 | 
19 | def get_char(fname):
20 |     fid = open(fname, 'rb')
21 |     byte_array = fid.read()
22 |     text = [0]*len(byte_array)
23 |     for i in range(0, len(byte_array)):
24 |         text[i] = int(byte_array[i])
25 |     unique = list(set(text))
26 |     unique.sort()
27 | 
28 |     mapping = dict(zip(unique, list(range(0, len(unique)))))
29 |     for i in range(0, len(text)):
30 |         text[i] = mapping[text[i]]
31 |     return text, mapping
32 | 
33 | def ortho_init(shape):
34 |     # From Lasagne and Keras. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
35 | 
36 | 
37 |     flat_shape = (shape[0], np.prod(shape[1:]))
38 |     a = np.random.normal(0.0, 1.0, flat_shape)
39 |     a = a.astype(dtype=np.float32)
40 | 
41 |     u, _, v = np.linalg.svd(a, full_matrices=False)
42 | 
43 |     q = u if u.shape == flat_shape else v
44 |     q = q.reshape(shape)
45 |     return q
46 | 
47 | class RNNForLM(chainer.Chain):
48 | 
49 |     def __init__(self, nvocab, nunits, nembd=400, train=True):
50 |         super(RNNForLM, self).__init__(
51 |             embed=L.EmbedID(nvocab, nembd),
52 |             WhxWN=WN(nembd, nunits*4),
53 |             WmxWN=WN(nembd, nunits),
54 |             WmhWN=WN(nunits, nunits),
55 |             WhmWN=WN(nunits, nunits*4),
56 | 
57 |             l1=mLSTM(out_size=nunits),
58 |             l2=L.Linear(nunits, nvocab)
59 |         )
60 |         nparam = 0
61 |         for param in self.params():
62 |             print(param.data.shape)
63 |             nparam += param.data.size
64 | 
65 |         print('nparam')
66 |         print(nparam)
67 | 
68 |         self.l1.b.data[2::4] = 3  # forget gate bias init (gate order: a, i, f, o)
69 |         Wembd = np.random.uniform(-.2, .2, self.embed.W.data.shape)
70 |         Wembd = Wembd.astype(dtype=np.float32)
71 |         self.embed.W.data = Wembd
72 | 
73 |         self.WhxWN.W.data = ortho_init(self.WhxWN.W.data.shape)
74 |         norm = np.linalg.norm(self.WhxWN.W.data, axis=1)
75 |         self.WhxWN.g.data = norm
76 | 
77 |         self.WmxWN.W.data = 
ortho_init(self.WmxWN.W.data.shape)
78 |         norm = np.linalg.norm(self.WmxWN.W.data, axis=1)
79 |         self.WmxWN.g.data = norm
80 | 
81 |         self.WmhWN.W.data = ortho_init(self.WmhWN.W.data.shape)
82 |         norm = np.linalg.norm(self.WmhWN.W.data, axis=1)
83 |         self.WmhWN.g.data = norm
84 | 
85 |         self.WhmWN.W.data = ortho_init(self.WhmWN.W.data.shape)
86 |         norm = np.linalg.norm(self.WhmWN.W.data, axis=1)
87 |         self.WhmWN.g.data = norm
88 | 
89 |         self.l2.W.data = ortho_init(self.l2.W.data.shape)
90 | 
91 |         self.train = train
92 | 
93 |     def reset_state(self):
94 |         self.l1.reset_state()
95 | 
96 |     def applyWN(self):
97 |         self.Whx = self.WhxWN()
98 |         self.Wmx = self.WmxWN()
99 |         self.Wmh = self.WmhWN()
100 |         self.Whm = self.WhmWN()
101 | 
102 | 
103 | 
104 |     def __call__(self, x, mask, mask2):
105 | 
106 | 
107 |         h0 = self.embed(x)*mask2
108 | 
109 |         h1 = self.l1(h0, self.Whx, self.Wmx, self.Wmh, self.Whm)
110 |         h1 = h1*mask
111 |         self.l1.h = h1
112 |         y = self.l2(h1)
113 | 
114 |         return y
115 | 
116 | def test(model, inputs, targets):
117 |     inputs = Variable(inputs)
118 |     targets = Variable(targets)
119 | 
120 |     targets.to_gpu()
121 |     inputs.to_gpu()
122 |     model.applyWN()
123 |     model.train = False
124 |     loss = 0
125 |     for j in range(inputs.shape[1]):
126 |         output = model(inputs[:, j], 1, 1)
127 |         loss = loss + F.softmax_cross_entropy(output, targets[:, j])
128 |         loss.unchain_backward()
129 | 
130 |     model.train = True
131 | 
132 |     model.reset_state()
133 | 
134 |     finalloss = loss.data/inputs.shape[1]
135 |     return finalloss
136 | 
137 | def main():
138 | 
139 |     parser = argparse.ArgumentParser()
140 |     parser.add_argument('--batchsize', '-b', type=int, default=100,
141 |                         help='Number of examples in each mini-batch')
142 |     parser.add_argument('--bproplen', '-l', type=int, default=200,
143 |                         help='Number of characters in each mini-batch '
144 |                              '(= length of truncated BPTT)')
145 |     parser.add_argument('--epoch', '-e', type=int, default=40,
146 |                         help='Number of sweeps over the dataset to train')
147 |     parser.add_argument('--gpu', '-g', 
type=int, default=0,
148 |                         help='GPU ID (negative value indicates CPU)')
149 | 
150 |     parser.add_argument('--out', '-o', default='result',
151 |                         help='Directory to output the result')
152 | 
153 |     parser.add_argument('--file', default="enwik8",
154 |                         help='path to text file for training')
155 |     parser.add_argument('--unit', '-u', type=int, default=2800,
156 |                         help='Number of LSTM units')
157 |     parser.add_argument('--embd', type=int, default=400,
158 |                         help='Number of embedding units')
159 |     parser.add_argument('--hdrop', type=float, default=0.2,
160 |                         help='hidden state dropout (variational)')
161 |     parser.add_argument('--edrop', type=float, default=0.5,
162 |                         help='embedding dropout')
163 | 
164 |     args = parser.parse_args()
165 | 
166 |     nembd = args.embd
167 |     #number of training iterations per model save, log write, and validation set evaluation
168 |     interval = 100
169 | 
170 |     pdrop = args.hdrop
171 | 
172 |     pdrope = args.edrop
173 | 
174 |     #initial learning rate
175 |     alpha0 = .001
176 |     #inverse of linear decay rate towards 0
177 |     dec_it = 12*9000
178 |     #minimum learning rate
179 |     alpha_min = .00007
180 | 
181 |     #first ntrain characters of the dataset will be used for training
182 |     ntrain = 90000000
183 | 
184 | 
185 |     seqlen = args.bproplen
186 |     nbatch = args.batchsize
187 | 
188 |     filename = args.file
189 | 
190 |     text, mapping = get_char(filename)
191 |     sequence = np.array(text).astype(np.int32)
192 | 
193 |     itrain = sequence[0:ntrain]
194 |     ttrain = sequence[1:ntrain+1]
195 |     fullseql = int(ntrain/nbatch)
196 | 
197 |     itrain = itrain.reshape(nbatch, fullseql)
198 |     ttrain = ttrain.reshape(nbatch, fullseql)
199 | 
200 |     #doesn't use the full validation set
201 |     nval = 500000
202 |     ival = sequence[ntrain:ntrain+nval]
203 |     tval = sequence[ntrain+1:ntrain+nval+1]
204 | 
205 |     ival = ival.reshape(ival.shape[0]//1000, 1000)
206 |     tval = tval.reshape(tval.shape[0]//1000, 1000)
207 |     #test = sequence[ntrain+nval:ntrain+nval+ntest]
208 | 
209 | 
210 |     nvocab = max(sequence) + 1  # train is just an 
array of integers
211 |     print('#vocab =', nvocab)
212 | 
213 |     # Prepare an RNNLM model
214 |     rnn = RNNForLM(nvocab, args.unit, args.embd)
215 |     model = L.Classifier(rnn)
216 |     model.compute_accuracy = False  # we only want the perplexity
217 |     if args.gpu >= 0:
218 |         chainer.cuda.get_device(args.gpu).use()  # make the GPU current
219 |         model.to_gpu()
220 | 
221 |     # Set up an optimizer
222 |     optimizer = Adam(alpha=alpha0)
223 |     optimizer.setup(model)
224 |     resultdir = args.out
225 | 
226 |     print('starting')
227 |     nepoch = args.epoch
228 | 
229 |     start = 0
230 |     loss_sum = 0
231 | 
232 |     if not os.path.isdir(resultdir):
233 |         os.mkdir(resultdir)
234 | 
235 |     vloss = test(rnn, ival, tval)
236 |     vloss = 1.4427*vloss  # convert nats to bits/char (1/ln 2)
237 |     f = open(os.path.join(resultdir, 'log'), 'w')
238 |     outstring = "Initial Validation loss (bits/char): " + str(vloss) + '\n'
239 |     f.write(outstring)
240 |     f.close()
241 | 
242 |     i = 0
243 |     epoch_num = 0
244 |     it_num = 0
245 | 
246 |     while True:
247 |         # Get the result of the forward pass.
248 |         fin = start+seqlen
249 | 
250 |         if fin > itrain.shape[1]:
251 |             start = 0
252 |             fin = start+seqlen
253 |             epoch_num = epoch_num+1
254 |             if epoch_num == nepoch:
255 |                 break
256 | 
257 |         inputs = itrain[:, start:fin]
258 |         targets = ttrain[:, start:fin]
259 |         start = fin
260 | 
261 |         inputs = Variable(inputs)
262 |         targets = Variable(targets)
263 | 
264 |         targets.to_gpu()
265 |         inputs.to_gpu()
266 |         it_num += 1
267 |         loss = 0
268 |         rnn.applyWN()
269 | 
270 |         #make hidden dropout mask
271 |         mask = cp.zeros((inputs.shape[0], args.unit), dtype=cp.float32)
272 |         ind = cp.nonzero(cp.random.rand(inputs.shape[0], args.unit) > pdrop)
273 |         mask[ind] = 1/(1-pdrop)
274 | 
275 |         #make embedding dropout mask
276 |         mask2 = cp.zeros((inputs.shape[0], nembd), dtype=cp.float32)
277 |         ind = cp.nonzero(cp.random.rand(inputs.shape[0], nembd) > pdrope)
278 |         mask2[ind] = 1/(1-pdrope)
279 | 
280 |         for j in range(seqlen):
281 | 
282 |             output = rnn(inputs[:, j], mask, mask2)
283 |             loss = loss + F.softmax_cross_entropy(output, targets[:, j])
284 | 
285 |         loss = loss/seqlen
286 | 
287 |         # Zero all gradients before updating them.
288 |         rnn.zerograds()
289 |         loss_sum += loss.data
290 | 
291 |         # Calculate and update all gradients.
292 |         loss.backward()
293 | 
294 | 
295 |         # Use the optimizer to move all parameters of the network
296 |         # to values which will reduce the loss.
297 | optimizer.update() 298 | # decays learning rate linearly 299 | optimizer.alpha = alpha0*(dec_it-it_num)/float(dec_it) 300 | # prevents learning rate from going below minimum 301 | if optimizer.alpha1 233 | 234 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).outs.v(:,:,t-1); 235 | else 236 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).outs.vp0 ; 237 | end 238 | 239 | if l>1 240 | 241 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).hweights.matrix*network.hidden(l-1).outs.v(:,:,t) ; 242 | end 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | network.hidden(l).ins.v(network.hidden(l).gateind,:,t)= bsxfun(@plus,network.hidden(l).ins.v(network.hidden(l).gateind,:,t),network.hidden(l).biases.v(network.hidden(l).gateind,:)); 253 | network.hidden(l).ins.v(network.hidden(l).gateind,:,t) = sigmoid(network.hidden(l).ins.v(network.hidden(l).gateind,:,t)); 254 | 255 | 256 | 257 | network.hidden(l).ins.state(:,:,t)=network.hidden(l).ins.v(network.hidden(l).hidind,:,t).*network.hidden(l).ins.v(network.hidden(l).writeind,:,t); 258 | 259 | 260 | 261 | if t>1 262 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t) + network.hidden(l).ins.state(:,:,t-1).*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 263 | else 264 | 265 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t) +network.hidden(l).ins.statep0.*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 266 | end 267 | 268 | 269 | 270 | network.hidden(l).outs.v(:,:,t) = network.hidden(l).ins.state(:,:,t).*network.hidden(l).ins.v(network.hidden(l).readind,:,t); 271 | network.hidden(l).outs.v(:,:,t)=bsxfun(@plus,network.hidden(l).outs.v(:,:,t),network.hidden(l).biases.v(network.hidden(l).hidind,:)); 272 | network.hidden(l).outs.v(:,:,t) = tanh(network.hidden(l).outs.v(:,:,t)); 273 | 274 
| network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*network.hidden(l).outs.v(:,:,t); 275 | 276 | 277 | 278 | if t==size(inputs,3) 279 | network.hidden(l).outs.last = network.hidden(l).outs.v(:,:,t); 280 | network.hidden(l).ins.last = network.hidden(l).ins.state(:,:,t); 281 | end 282 | 283 | end 284 | end 285 | 286 | network.output.outs.v = network.output.fx(network.output.outs.v); 287 | 288 | 289 | end 290 | function [network] = ForwardPasstest(network, inputs) 291 | 292 | inputs = gpuArray(inputs); 293 | network.input.outs.v=inputs; 294 | 295 | 296 | for l=1:length(network.hidden) 297 | hidden(l).outs.vp = network.hidden(l).outs.vp0 ; 298 | hidden(l).ins.statep = network.hidden(l).ins.statep0; 299 | end 300 | 301 | for t=1:size(inputs,3); 302 | for l=1:length(network.hidden); 303 | 304 | 305 | 306 | 307 | hidden(l).ins.v= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | hidden(l).ins.v = hidden(l).ins.v + network.hidden(l).weights.matrix*hidden(l).outs.vp; 322 | 323 | if l>1 324 | 325 | hidden(l).ins.v = hidden(l).ins.v + network.hidden(l).hweights.matrix*hidden(l-1).outs.v; 326 | end 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | hidden(l).ins.v(network.hidden(l).gateind,:)= bsxfun(@plus,hidden(l).ins.v(network.hidden(l).gateind,:),network.hidden(l).biases.v(network.hidden(l).gateind,:)); 335 | hidden(l).ins.v(network.hidden(l).gateind,:) = sigmoid(hidden(l).ins.v(network.hidden(l).gateind,:)); 336 | 337 | 338 | 339 | hidden(l).ins.state=hidden(l).ins.v(network.hidden(l).hidind,:).*hidden(l).ins.v(network.hidden(l).writeind,:); 340 | 341 | ttemp = t-1; 342 | 343 | 344 | hidden(l).ins.state = hidden(l).ins.state + hidden(l).ins.statep.*hidden(l).ins.v(network.hidden(l).keepind,:); 345 | 346 | 347 | 348 | 349 | hidden(l).outs.v = hidden(l).ins.state.*hidden(l).ins.v(network.hidden(l).readind,:); 350 | 
hidden(l).outs.v=bsxfun(@plus,hidden(l).outs.v,network.hidden(l).biases.v(network.hidden(l).hidind,:)); 351 | hidden(l).outs.v = tanh(hidden(l).outs.v); 352 | 353 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*hidden(l).outs.v; 354 | 355 | hidden(l).outs.vp = hidden(l).outs.v; 356 | hidden(l).ins.statep = hidden(l).ins.state; 357 | 358 | 359 | 360 | 361 | end 362 | end 363 | 364 | network.output.outs.v = network.output.fx(network.output.outs.v); 365 | 366 | 367 | 368 | end 369 | 370 | function network = computegradient(network, targets,omat) 371 | oind = find(omat); 372 | 373 | network.output.outs.j(oind) = network.output.outs.v(oind)- targets(oind); 374 | 375 | network.output.outs.j = network.output.outs.j.*network.output.dx(network.output.outs.v); 376 | 377 | 378 | 379 | 380 | for l=1:length(network.hidden) 381 | hidden(l).ins.statej = gpuArray(zeros(network.nhidden(l),size(targets,2),'single')); 382 | hidden(l).outs.j = gpuArray(zeros(network.nhidden(l),size(targets,2),'single')); 383 | hidden(l).ins.j = gpuArray(zeros(network.nhidden(l)*4,size(network.input.outs.v,2),'single')); 384 | end 385 | for t=size(network.input.outs.v,3):-1:1; 386 | 387 | 388 | 389 | for l= length(network.hidden):-1:1 390 | 391 | network.output.weights(l).gradient = network.output.weights(l).gradient + network.output.outs.j(:,:,t)*network.hidden(l).outs.v(:,:,t)'; 392 | 393 | hidden(l).outs.j = hidden(l).outs.j + network.output.weights(l).matrix'*network.output.outs.j(:,:,t) ; 394 | 395 | 396 | 397 | 398 | hidden(l).outs.j = hidden(l).outs.j.*tanhdir(network.hidden(l).outs.v(:,:,t)); 399 | network.hidden(l).biases.j(network.hidden(l).hidind,:) = network.hidden(l).biases.j(network.hidden(l).hidind,:) + sum(hidden(l).outs.j,2); 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | hidden(l).ins.j(network.hidden(l).readind,:) = hidden(l).outs.j.* network.hidden(l).ins.state(:,:,t) ; 408 | 409 | hidden(l).ins.statej = hidden(l).ins.statej + 
hidden(l).outs.j.*network.hidden(l).ins.v(network.hidden(l).readind,:,t); 410 | 411 | ttemp = t-1; 412 | if ttemp>0 413 | 414 | hidden(l).ins.statejp = hidden(l).ins.statej.*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 415 | 416 | 417 | 418 | 419 | end 420 | 421 | 422 | 423 | 424 | if ttemp>0 425 | 426 | 427 | hidden(l).ins.j(network.hidden(l).keepind,:) = network.hidden(l).ins.state(:,:,t-1).*hidden(l).ins.statej; 428 | else 429 | 430 | hidden(l).ins.j(network.hidden(l).keepind,:)= network.hidden(l).ins.statep0.*hidden(l).ins.statej; 431 | 432 | end 433 | 434 | 435 | hidden(l).ins.j(network.hidden(l).writeind,:) = network.hidden(l).ins.v(network.hidden(l).hidind,:,t).*hidden(l).ins.statej; 436 | hidden(l).ins.j(network.hidden(l).hidind,:) = network.hidden(l).ins.v(network.hidden(l).writeind,:,t).*hidden(l).ins.statej; 437 | 438 | hidden(l).ins.j(network.hidden(l).gateind,:)= hidden(l).ins.j(network.hidden(l).gateind,:).*sigdir(network.hidden(l).ins.v(network.hidden(l).gateind,:,t)); 439 | network.hidden(l).biases.j(network.hidden(l).gateind,:) = network.hidden(l).biases.j(network.hidden(l).gateind,:) + sum(hidden(l).ins.j(network.hidden(l).gateind,:),2); 440 | 441 | 442 | 443 | if t-1>0 444 | hidden(l).outs.jp = network.hidden(l).weights.matrix'*hidden(l).ins.j; 445 | 446 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + hidden(l).ins.j*network.hidden(l).outs.v(:,:,t-1)'; 447 | 448 | 449 | 450 | else 451 | 452 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + hidden(l).ins.j*network.hidden(l).outs.vp0'; 453 | 454 | 455 | end 456 | 457 | if l>1 458 | hidden(l-1).outs.j = hidden(l-1).outs.j + network.hidden(l).hweights.matrix'*hidden(l).ins.j; 459 | network.hidden(l).hweights.gradient = network.hidden(l).hweights.gradient + hidden(l).ins.j*network.hidden(l-1).outs.v(:,:,t)'; 460 | end 461 | 462 | network.hidden(l).iweights.gradient = network.hidden(l).iweights.gradient + 
(hidden(l).ins.j)*network.input.outs.v(:,:,t)'; 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | hidden(l).outs.j = hidden(l).outs.jp; 471 | hidden(l).ins.statej = hidden(l).ins.statejp; 472 | end 473 | 474 | 475 | 476 | end 477 | 478 | end 479 | 480 | 481 | function [lcost] = evalCrossEntropy(output,targets,omat) 482 | 483 | 484 | 485 | oind = find(omat); 486 | 487 | ldiff = targets.*log2(output); 488 | 489 | 490 | lcost = -1*sum(ldiff(:)); 491 | 492 | 493 | 494 | 495 | end 496 | 497 | function network = updateV(network, dW) 498 | 499 | ninput = network.input.n; 500 | noutput = network.output.n; 501 | 502 | 503 | start = 1; 504 | last = 0; 505 | 506 | for l=1:length(network.hidden) 507 | nhidden = network.nhidden(l); 508 | 509 | last = last + numel(network.hidden(l).iweights.matrix); 510 | network.hidden(l).iweights.matrix = reshape(dW(start:last),4*nhidden,ninput)+ network.hidden(l).iweights.matrix ; 511 | start = last + 1; 512 | 513 | last = last + numel(network.hidden(l).biases.v); 514 | network.hidden(l).biases.v = reshape(dW(start:last),4*nhidden,1)+ network.hidden(l).biases.v ; 515 | start = last + 1; 516 | 517 | for i=1:length(network.hidden(l).weights); 518 | last = last + numel(network.hidden(l).weights(i).matrix); 519 | network.hidden(l).weights(i).matrix = reshape(dW(start:last),4*nhidden,nhidden)+network.hidden(l).weights(i).matrix; 520 | start = last+1; 521 | if l>1 522 | last = last + numel(network.hidden(l).hweights(i).matrix); 523 | network.hidden(l).hweights(i).matrix = reshape(dW(start:last),4*nhidden,network.nhidden(l-1))+network.hidden(l).hweights(i).matrix; 524 | start = last+1; 525 | 526 | end 527 | 528 | 529 | 530 | end 531 | 532 | 533 | 534 | 535 | 536 | 537 | last = last+ numel(network.output.weights(l).matrix); 538 | network.output.weights(l).matrix = reshape(dW(start:last),noutput,nhidden)+ network.output.weights(l).matrix ; 539 | start=last+1; 540 | 541 | end 542 | 543 | end 544 | 545 | function vect=weights2vect(allvects) 546 | lsum = 
0; 547 | lengths = cell(length(allvects),1); 548 | for i=1:length(allvects) 549 | lsum = lsum + numel(allvects{i}); 550 | lengths{i}= lsum; 551 | 552 | 553 | end 554 | vect = zeros(lsum,1,'single'); 555 | 556 | vect(1:lengths{1}) = gather(reshape(allvects{1},lengths{1},1)); 557 | for i=2:length(allvects) 558 | vect(lengths{i-1}+1:lengths{i}) = gather(reshape(allvects{i},lengths{i}-lengths{i-1},1)); 559 | end 560 | 561 | 562 | end 563 | 564 | 565 | 566 | function network = initpass(network,nbatch,maxt) 567 | 568 | ninput = network.input.n; 569 | 570 | noutput = network.output.n; 571 | 572 | 573 | for l=1:length(network.hidden) 574 | 575 | nhidden = network.nhidden(l); 576 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 577 | if ~network.last 578 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 579 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 580 | else 581 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 582 | 583 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 584 | end 585 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 586 | for i=1:length(network.hidden(l).weights); 587 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 588 | 589 | 590 | end 591 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 592 | if l>1 593 | network.hidden(l).hweights(i).gradient = gpuArray(zeros(nhidden*4,network.nhidden(l-1),'single')); 594 | 595 | end 596 | 597 | 598 | 599 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 600 | network.hidden(l).ins.v = gpuArray(zeros(4*nhidden,nbatch,maxt,'single')); 601 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 602 | 603 | 604 | 605 | end 606 | 607 | 608 | 609 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 610 | network.input.outs.v = 
gpuArray(zeros(ninput,nbatch,maxt,'single')); 611 | 612 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 613 | 614 | 615 | end 616 | 617 | function network = initpasstest(network,nbatch,maxt) 618 | 619 | ninput = network.input.n; 620 | 621 | noutput = network.output.n; 622 | 623 | 624 | for l=1:length(network.hidden) 625 | 626 | nhidden = network.nhidden(l); 627 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 628 | if ~network.last 629 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 630 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 631 | else 632 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 633 | 634 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 635 | end 636 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 637 | for i=1:length(network.hidden(l).weights); 638 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 639 | 640 | 641 | end 642 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 643 | if l>1 644 | network.hidden(l).hweights(i).gradient = gpuArray(zeros(nhidden*4,network.nhidden(l-1),'single')); 645 | 646 | end 647 | 648 | 649 | 650 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,length(network.storeind),'single')); 651 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,length(network.storeind),'single')); 652 | 653 | 654 | 655 | end 656 | 657 | 658 | 659 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 660 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 661 | 662 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 663 | 664 | 665 | end 666 | function network = initnetwork(ninput,nhidden,noutput) 667 | 668 | 669 | 670 | 671 | network.input.n = ninput; 672 | network.nhidden = nhidden; 673 | network.output.n = noutput; 674 | 675 | 676 | 677 | 
678 | 679 | 680 | 681 | 682 | 683 | for j = 1:(length(network.nhidden)) 684 | 685 | nhidden = network.nhidden(j); 686 | network.hidden(j).hidind = (1:nhidden)'; 687 | network.hidden(j).writeind = (nhidden+1:2*nhidden)'; 688 | network.hidden(j).keepind = (2*nhidden+1:3*nhidden)'; 689 | network.hidden(j).readind = (3*nhidden+1:4*nhidden)'; 690 | network.hidden(j).gateind = (nhidden+1:4*nhidden)'; 691 | 692 | network.hidden(j).iweights.matrix = gpuArray(.1*(randn(nhidden*4,ninput,'single'))); 693 | network.hidden(j).biases.v = gpuArray(zeros(4*nhidden,1,'single')); 694 | network.hidden(j).biases.v(network.hidden(j).keepind)=3; 695 | network.hidden(j).biases.v(network.hidden(j).readind)=-2; 696 | network.hidden(j).biases.v(network.hidden(j).writeind)=0; 697 | network.hidden(j).iweights.gated = 0; 698 | 699 | 700 | 701 | network.hidden(j).weights.matrix =gpuArray(.0001*(randn(nhidden*4,nhidden,'single'))); 702 | 703 | 704 | 705 | 706 | if j>1 707 | network.hidden(j).hweights.matrix =gpuArray(.01*(randn(nhidden*4,network.nhidden(j-1),'single'))); 708 | end 709 | 710 | 711 | network.hidden(j).fx = @sigmoid; 712 | network.hidden(j).dx = @sigdir; 713 | 714 | 715 | network.output.weights(j).matrix = gpuArray(.1*(randn(noutput,nhidden,'single'))); 716 | 717 | 718 | 719 | end 720 | 721 | 722 | network.nparam = length(weights2vect(getW(network))); 723 | 724 | 725 | 726 | network.output.fx = @softmax; 727 | network.output.dx = @softdir; 728 | network.errorFunc = @evalCrossEntropy; 729 | network.output.getHessian = @CrossEntropyHessian; 730 | 731 | 732 | 733 | 734 | 735 | end 736 | 737 | 738 | function J = getJ(network) 739 | jtot=1; 740 | J = cell(jtot,1); 741 | c=1; 742 | for l=1:length(network.hidden) 743 | J{c}= network.hidden(l).iweights.gradient; 744 | c=c+1; 745 | network.hidden(l).biases.j(network.hidden(l).hidind)=0; 746 | J{c}=network.hidden(l).biases.j; 747 | c=c+1; 748 | 749 | for i = 1:length(network.hidden(l).weights); 750 | 
J{c}=network.hidden(l).weights(i).gradient; 751 | c=c+1; 752 | if l>1 753 | J{c}=network.hidden(l).hweights(i).gradient; 754 | c=c+1; 755 | end 756 | 757 | 758 | end 759 | 760 | J{c} = network.output.weights(l).gradient; 761 | c=c+1; 762 | end 763 | 764 | 765 | 766 | end 767 | function W = getW(network) 768 | jtot=1; 769 | W = cell(jtot,1); 770 | c=1; 771 | for l=1:length(network.hidden) 772 | 773 | W{c}= network.hidden(l).iweights.matrix; 774 | c=c+1; 775 | 776 | W{c}= network.hidden(l).biases.v; 777 | c=c+1; 778 | 779 | 780 | for i = 1:length(network.hidden(l).weights); 781 | W{c}=network.hidden(l).weights(i).matrix; 782 | c=c+1; 783 | if l>1 784 | W{c}=network.hidden(l).hweights(i).matrix; 785 | c=c+1; 786 | end 787 | 788 | end 789 | 790 | W{c} = network.output.weights(l).matrix; 791 | c=c+1; 792 | end 793 | 794 | 795 | 796 | end 797 | 798 | 799 | 800 | 801 | 802 | function f= sigmoid(x) 803 | 804 | 805 | f= 1./(1+ exp(-1.*x)); 806 | end 807 | 808 | function o = softdir(x); 809 | 810 | o=ones(size(x),'single'); 811 | 812 | 813 | end 814 | function o = softmax(x) 815 | 816 | o=bsxfun(@times,1./sum(exp(x),1),exp(x)); 817 | end 818 | function dir = sigdir( y ) 819 | 820 | dir = y.*(1-y); 821 | 822 | 823 | end 824 | function dir = tanhdir( y ) 825 | 826 | dir = (1-y.*y); 827 | 828 | 829 | end 830 | 831 | 832 | 833 | 834 | %function m=gather(m) 835 | 836 | %end 837 | %function m=gpuArray(m) 838 | 839 | %end -------------------------------------------------------------------------------- /matlab/mLSTMeval.m: -------------------------------------------------------------------------------- 1 | function mLSTMeval 2 | 3 | 4 | 5 | gpuDevice(1) 6 | 7 | sequence = processtextfile('enwik8'); 8 | 9 | weightsfname = 'mLSTMhutter.mat'; 10 | 11 | nunits = 205; 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | network = initnetwork(nunits,[1900],nunits); 20 | 21 | network.last = 0; 22 | 23 | 24 | 25 | network=updateV(network,-1*weights2vect(getW(network))); 26 | 27 | wfile = 
load(weightsfname,'W'); 28 | W = wfile.W; 29 | 30 | network = updateV(network,W); 31 | 32 | disp(network.nparam) 33 | 34 | 35 | 36 | seqlen = 50; 37 | 38 | 39 | 40 | 41 | 42 | 43 | megabatch = 1000000; 44 | minibatch = 10000; 45 | jit = megabatch/minibatch; 46 | 47 | tic 48 | 49 | start = 95000000;; 50 | 51 | fin = start+1*megabatch-1; 52 | 53 | 54 | 55 | serr = 0; 56 | 57 | network.last = 0; 58 | nbatch = minibatch/seqlen; 59 | 60 | for k=1:5 61 | 62 | 63 | [in0,ta0] = byte2input(sequence(start:fin),sequence(start+1:fin+1),nunits,seqlen); 64 | 65 | start = start+megabatch; 66 | fin = fin+megabatch; 67 | 68 | 69 | 70 | for j=1:jit 71 | 72 | 73 | 74 | in1= in0(:,j:jit:size(in0,2),:); 75 | ta1= ta0(:,j:jit:size(ta0,2),:); 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | network.zero = 1; 88 | 89 | network = initpass(network,size(in1,2),size(in1,3)); 90 | 91 | 92 | network = ForwardPass(network,in1); 93 | network.output.outs.v = .999999*network.output.outs.v + .000001*(ones(size(network.output.outs.v)))/size(network.output.outs.v,1); 94 | 95 | 96 | 97 | terr = evalCrossEntropy(network.output.outs.v,ta1,ones(size(ta1),'single')); 98 | 99 | serr = serr+ (terr/(nbatch*seqlen)); 100 | 101 | err = serr/(jit*(k-1)+j); 102 | 103 | 104 | network.last = 1; 105 | 106 | if mod(j,100)==0 107 | 108 | disp(err) 109 | end 110 | 111 | end 112 | 113 | end 114 | disp('final error') 115 | disp(err) 116 | 117 | end 118 | 119 | function [inputs,targets] =byte2input(inputs,targets,nunits,seqlen) 120 | 121 | 122 | inputs= single(inputs); 123 | targets = single(targets); 124 | 125 | in = zeros(nunits,1,length(inputs),'single'); 126 | targ = zeros(nunits,1,length(targets),'single'); 127 | 128 | ind = sub2ind([nunits,1,length(inputs)],inputs,ones(length(inputs),1),(1:length(inputs))'); 129 | tind = sub2ind([nunits,1,length(inputs)],targets,ones(length(inputs),1),(1:length(inputs))'); 130 | in(ind)=1; 131 | targ(tind) = 1; 132 | 133 | 
inputs=permute(reshape(in,[size(in,1),seqlen,size(in,3)/seqlen]),[1,3,2]); 134 | targets=permute(reshape(targ,[size(targ,1),seqlen,size(targ,3)/seqlen]),[1,3,2]); 135 | 136 | 137 | 138 | 139 | end 140 | 141 | function gradient = getGradient(network,inputs,targets,npar) 142 | 143 | 144 | pbatch = size(inputs,2)/npar; 145 | citer = 1:pbatch:size(inputs,2); 146 | oind = ones(size(targets)); 147 | in = cell(npar,1);ta = cell(npar,1);oi = cell(npar,1); 148 | for ci=1:length(citer) 149 | c = citer(ci); 150 | in{ci} = inputs(:,c:c+pbatch-1,:); 151 | ta{ci} = targets(:,c:c+pbatch-1,:); 152 | oi{ci} = oind(:,c:c+pbatch-1,:); 153 | end 154 | 155 | gradient = zeros(network.nparam,1); 156 | 157 | 158 | for z=1:npar 159 | net = initpass(network,size(in{z},2),size(in{z},3)); 160 | net = ForwardPass(net,in{z}); 161 | net = computegradient(net,ta{z},ones(size(ta{z}))); 162 | 163 | 164 | gradient = gradient + weights2vect(getJ(net)); 165 | end 166 | 167 | 168 | 169 | 170 | end 171 | 172 | 173 | 174 | function [err] = test(network,inputs,targets,oind) 175 | errsum = 0; 176 | errcount = 0; 177 | 178 | nbatch = size(inputs,2); 179 | 180 | 181 | input = inputs; 182 | network = initpasstest(network,nbatch,size(input,3)); 183 | network = ForwardPasstest(network,input); 184 | network.output.outs.v = .999999*network.output.outs.v + .000001*(ones(size(network.output.outs.v)))/size(network.output.outs.v,1); 185 | [terr]=network.errorFunc(network.output.outs.v,targets,oind); 186 | errsum = errsum + terr; 187 | 188 | errcount = errcount+1; 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | err = errsum/errcount; 198 | 199 | 200 | 201 | 202 | 203 | 204 | end 205 | 206 | function network = ForwardPass(network, inputs) 207 | 208 | inputs = gpuArray(inputs); 209 | network.input.outs.v = inputs; 210 | 211 | 212 | for t=1:size(inputs,3); 213 | l=1; 214 | 215 | 216 | 217 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + 
network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 218 | 219 | if t-1>0 220 | 221 | 222 | network.hidden(l).intermediates.v(:,:,t) = network.hidden(l).mweights.matrix*network.hidden(l).outs.v(:,:,t-1); 223 | 224 | else 225 | 226 | network.hidden(l).intermediates.v(:,:,t) = network.hidden(l).mweights.matrix*network.hidden(l).outs.vp0; 227 | end 228 | 229 | network.hidden(l).factor.v(:,:,t) = network.hidden(l).fweights.matrix*inputs(:,:,t); 230 | network.hidden(l).mult.v(:,:,t) = network.hidden(l).factor.v(:,:,t).*network.hidden(l).intermediates.v(:,:,t); 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).mult.v(:,:,t); 239 | network.hidden(l).ins.v(network.gateind,:,t)= bsxfun(@plus,network.hidden(l).ins.v(network.gateind,:,t),network.hidden(l).biases.v(network.gateind,:)); 240 | network.hidden(l).ins.v(network.gateind,:,t) = sigmoid(network.hidden(l).ins.v(network.gateind,:,t)); 241 | 242 | 243 | 244 | 245 | 246 | network.hidden(l).ins.state(:,:,t)= network.hidden(l).ins.v(network.hidind,:,t).*network.hidden(l).ins.v(network.writeind,:,t); 247 | 248 | 249 | 250 | if t-1>0 251 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t)+network.hidden(l).ins.state(:,:,t-1).*network.hidden(l).ins.v(network.keepind,:,t); 252 | else 253 | network.hidden(l).ins.state(:,:,t) =network.hidden(l).ins.state(:,:,t)+ network.hidden(l).ins.statep0.*network.hidden(l).ins.v(network.keepind,:,t);; 254 | end 255 | 256 | 257 | network.hidden(l).outs.v(:,:,t) = network.hidden(l).ins.state(:,:,t).*network.hidden(l).ins.v(network.readind,:,t);; 258 | network.hidden(l).outs.v(:,:,t)=bsxfun(@plus,network.hidden(l).outs.v(:,:,t),network.hidden(l).biases.v(network.hidind,:)); 259 | network.hidden(l).outs.v(:,:,t) = tanh(network.hidden(l).outs.v(:,:,t)); 260 | 261 | 262 | 263 | 264 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + 
network.output.weights(l).matrix*network.hidden(l).outs.v(:,:,t); 265 | 266 | 267 | if t==size(inputs,3) 268 | network.hidden(l).outs.last = network.hidden(l).outs.v(:,:,t); 269 | network.hidden(l).ins.last = network.hidden(l).ins.state(:,:,t); 270 | end 271 | 272 | 273 | end 274 | 275 | network.output.outs.v = network.output.fx(network.output.outs.v); 276 | 277 | 278 | 279 | end 280 | function network = ForwardPasstest(network, inputs) 281 | 282 | inputs = gpuArray(inputs); 283 | network.input.outs.v = inputs; 284 | 285 | 286 | 287 | hidden.outs.vp = gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 288 | hidden.ins.statep = gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 289 | hidden.mult.v=gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 290 | 291 | for l=1:length(network.hidden) 292 | hidden(l).outs.vp = network.hidden(l).outs.vp0 ; 293 | hidden(l).ins.statep = network.hidden(l).ins.statep0; 294 | end 295 | s=1; 296 | for t=1:size(inputs,3); 297 | l=1; 298 | 299 | 300 | 301 | 302 | hidden.ins.v= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | hidden.intermediates.v = network.hidden(l).mweights.matrix*hidden(l).outs.vp; 316 | hidden.factor.v = network.hidden(l).fweights.matrix*inputs(:,:,t); 317 | hidden.mult.v = hidden(l).factor.v.*hidden(l).intermediates.v; 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | hidden.ins.v = hidden(l).ins.v + network.hidden.weights.matrix*hidden(l).mult.v; 329 | hidden(l).ins.v(network.gateind,:)= bsxfun(@plus,hidden(l).ins.v(network.gateind,:),network.hidden(l).biases.v(network.gateind,:)); 330 | hidden.ins.v(network.gateind,:) = sigmoid(hidden.ins.v(network.gateind,:)); 331 | 332 | 333 | 334 | hidden.ins.state=hidden.ins.v(network.hidind,:).*hidden.ins.v(network.writeind,:); 335 | 336 | 337 | 338 | 339 | hidden.ins.state = hidden.ins.state + hidden(l).ins.statep.*hidden.ins.v(network.keepind,:); 
340 | 341 | 342 | 343 | 344 | hidden.outs.v = hidden(l).ins.state.*hidden.ins.v(network.readind,:); 345 | hidden(l).outs.v=bsxfun(@plus,hidden(l).outs.v,network.hidden(l).biases.v(network.hidind,:)); 346 | hidden.outs.v = tanh(hidden(l).outs.v); 347 | 348 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*hidden(l).outs.v; 349 | 350 | hidden.outs.vp = hidden(l).outs.v; 351 | hidden.ins.statep = hidden(l).ins.state; 352 | 353 | 354 | 355 | end 356 | 357 | 358 | network.output.outs.v = network.output.fx(network.output.outs.v); 359 | 360 | 361 | 362 | end 363 | 364 | 365 | 366 | 367 | function network = computegradient(network, targets,omat) 368 | oind = find(omat); 369 | 370 | network.output.outs.j(oind) = network.output.outs.v(oind)- targets(oind); 371 | 372 | network.output.outs.j = network.output.outs.j.*network.output.dx(network.output.outs.v); 373 | 374 | 375 | 376 | 377 | 378 | hidden.ins.statej = gpuArray(zeros(network.nhidden,size(targets,2),'single')); 379 | hidden.outs.j = gpuArray(zeros(network.nhidden,size(targets,2),'single')); 380 | hidden.ins.j = gpuArray(zeros(network.nhidden*4,size(network.input.outs.v,2),'single')); 381 | for t=size(network.input.outs.v,3):-1:1; 382 | 383 | 384 | 385 | l=1; 386 | 387 | network.output.weights(l).gradient = network.output.weights(l).gradient + network.output.outs.j(:,:,t)*network.hidden(l).outs.v(:,:,t)'; 388 | 389 | hidden.outs.j = hidden.outs.j + network.output.weights(l).matrix'*network.output.outs.j(:,:,t) ; 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | hidden(l).outs.j = hidden(l).outs.j.*tanhdir(network.hidden(l).outs.v(:,:,t)); 398 | network.hidden(l).biases.j(network.hidind,:) = network.hidden(l).biases.j(network.hidind,:) + sum(hidden(l).outs.j,2);%biases only have 1 dimension 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | hidden(l).ins.j(network.readind,:) = hidden(l).outs.j.* network.hidden(l).ins.state(:,:,t) ; 407 | 408 | hidden(l).ins.statej = 
hidden(l).ins.statej + hidden(l).outs.j.*network.hidden(l).ins.v(network.readind,:,t); 409 | 410 | 411 | 412 | 413 | ttemp = t-1; 414 | if ttemp>0 415 | 416 | hidden(l).ins.statejp = hidden(l).ins.statej.*network.hidden(l).ins.v(network.keepind,:,t); 417 | 418 | 419 | end 420 | 421 | 422 | 423 | 424 | if ttemp>0 425 | 426 | 427 | hidden(l).ins.j(network.keepind,:) = network.hidden(l).ins.state(:,:,t-1).*hidden(l).ins.statej; 428 | else 429 | 430 | 431 | hidden(l).ins.j(network.keepind,:)= network.hidden(l).ins.statep0.*hidden(l).ins.statej; 432 | end 433 | 434 | 435 | hidden(l).ins.j(network.writeind,:) = network.hidden(l).ins.v(network.hidind,:,t).*hidden(l).ins.statej; 436 | hidden(l).ins.j(network.hidind,:) = network.hidden(l).ins.v(network.writeind,:,t).*hidden(l).ins.statej; 437 | 438 | hidden(l).ins.j(network.gateind,:)= hidden(l).ins.j(network.gateind,:).*sigdir(network.hidden(l).ins.v(network.gateind,:,t)); 439 | network.hidden(l).biases.j(network.gateind,:) = network.hidden(l).biases.j(network.gateind,:) + sum(hidden(l).ins.j(network.gateind,:),2); 440 | 441 | 442 | 443 | hidden(l).mult.j = network.hidden(l).weights.matrix'*hidden(l).ins.j; 444 | 445 | 446 | 447 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + (hidden(l).ins.j)*network.hidden(l).mult.v(:,:,t)'; 448 | 449 | 450 | 451 | 452 | if t-1>0 453 | 454 | hidden(l).intermediates.j = hidden(l).mult.j.*network.hidden(l).factor.v(:,:,t); 455 | 456 | hidden(l).factor.j= hidden(l).mult.j.*network.hidden(l).intermediates.v(:,:,t); 457 | 458 | network.hidden(l).fweights.gradient = network.hidden(l).fweights.gradient + hidden(l).factor.j*network.input.outs.v(:,:,t)'; 459 | 460 | hidden(l).outs.jp = network.hidden(l).mweights.matrix'*hidden(l).intermediates.j; 461 | network.hidden(l).mweights.gradient = network.hidden(l).mweights.gradient + hidden(l).intermediates.j*network.hidden(l).outs.v(:,:,t-1)'; 462 | else 463 | hidden(l).intermediates.j = 
hidden(l).mult.j.*network.hidden(l).factor.v(:,:,t); 464 | 465 | hidden(l).factor.j= hidden(l).mult.j.*network.hidden(l).intermediates.v(:,:,t); 466 | 467 | network.hidden(l).fweights.gradient = network.hidden(l).fweights.gradient + hidden(l).factor.j*network.input.outs.v(:,:,t)'; 468 | 469 | hidden(l).outs.jp = network.hidden(l).mweights.matrix'*hidden(l).intermediates.j; 470 | network.hidden(l).mweights.gradient = network.hidden(l).mweights.gradient + hidden(l).intermediates.j*network.hidden(l).outs.vp0'; 471 | 472 | 473 | 474 | 475 | end 476 | 477 | 478 | network.hidden(l).iweights.gradient = network.hidden(l).iweights.gradient + (hidden(l).ins.j)*network.input.outs.v(:,:,t)'; 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | hidden(l).outs.j = hidden(l).outs.jp; 487 | hidden(l).ins.statej = hidden(l).ins.statejp; 488 | 489 | 490 | 491 | end 492 | 493 | end 494 | 495 | 496 | function [lcost] = evalCrossEntropy(output,targets,omat) 497 | 498 | 499 | 500 | oind = find(omat); 501 | 502 | ldiff = targets.*log2(output); 503 | 504 | 505 | lcost = -1*sum(ldiff(:)); 506 | 507 | 508 | 509 | 510 | end 511 | 512 | function network = updateV(network, dW) 513 | 514 | ninput = network.input.n; 515 | noutput = network.output.n; 516 | 517 | 518 | start = 1; 519 | last = 0; 520 | 521 | for l=1:length(network.hidden) 522 | nhidden = network.nhidden(l); 523 | 524 | last = last + numel(network.hidden(l).iweights.matrix); 525 | network.hidden(l).iweights.matrix = reshape(dW(start:last),4*nhidden,ninput)+ network.hidden(l).iweights.matrix ; 526 | start = last + 1; 527 | 528 | last = last + numel(network.hidden(l).biases.v); 529 | network.hidden(l).biases.v = reshape(dW(start:last),4*nhidden,1)+ network.hidden(l).biases.v ; 530 | start = last + 1; 531 | 532 | 533 | last = last + numel(network.hidden(l).fweights.matrix); 534 | network.hidden(l).fweights.matrix = reshape(dW(start:last),nhidden,ninput)+ network.hidden(l).fweights.matrix ; 535 | start = last + 1; 536 | 537 | 538 | for 
i=1:length(network.hidden(l).weights); 539 | last = last + numel(network.hidden(l).weights(i).matrix); 540 | network.hidden(l).weights(i).matrix = reshape(dW(start:last),4*nhidden,nhidden)+network.hidden(l).weights(i).matrix; 541 | start = last+1; 542 | last = last + numel(network.hidden(l).mweights(i).matrix); 543 | network.hidden(l).mweights(i).matrix = reshape(dW(start:last),nhidden,nhidden)+network.hidden(l).mweights(i).matrix; 544 | start = last+1; 545 | 546 | 547 | end 548 | 549 | 550 | 551 | 552 | 553 | 554 | last = last+ numel(network.output.weights(l).matrix); 555 | network.output.weights(l).matrix = reshape(dW(start:last),noutput,nhidden)+ network.output.weights(l).matrix ; 556 | start=last+1; 557 | 558 | end 559 | 560 | end 561 | 562 | function vect=weights2vect(allvects) 563 | lsum = 0; 564 | lengths = cell(length(allvects),1); 565 | for i=1:length(allvects) 566 | lsum = lsum + numel(allvects{i}); 567 | lengths{i}= lsum; 568 | 569 | 570 | end 571 | vect = zeros(lsum,1,'single'); 572 | 573 | vect(1:lengths{1}) = gather(reshape(allvects{1},lengths{1},1)); 574 | for i=2:length(allvects) 575 | vect(lengths{i-1}+1:lengths{i}) = gather(reshape(allvects{i},lengths{i}-lengths{i-1},1)); 576 | end 577 | 578 | 579 | end 580 | 581 | 582 | 583 | 584 | function network = initpass(network,nbatch,maxt) 585 | 586 | ninput = network.input.n; 587 | 588 | noutput = network.output.n; 589 | 590 | for l=1:length(network.hidden) 591 | 592 | nhidden = network.nhidden(l); 593 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 594 | 595 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 596 | if ~network.last 597 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 598 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 599 | else 600 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 601 | 602 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 603 | end 604 | 
network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 605 | 606 | 607 | network.hidden(l).ins.v = gpuArray(zeros(nhidden*4,nbatch,maxt,'single')); 608 | 609 | 610 | network.hidden(l).intermediates.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 611 | 612 | network.hidden(l).factor.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 613 | 614 | network.hidden(l).mult.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 615 | 616 | 617 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 618 | 619 | 620 | 621 | 622 | 623 | for i=1:length(network.hidden(l).weights); 624 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 625 | network.hidden(l).mweights(i).gradient = gpuArray(zeros(nhidden,nhidden,'single')); 626 | 627 | end 628 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 629 | network.hidden(l).fweights.gradient = gpuArray(zeros(nhidden,ninput,'single')); 630 | 631 | 632 | 633 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 634 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 635 | 636 | 637 | 638 | end 639 | 640 | 641 | 642 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 643 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 644 | 645 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 646 | 647 | 648 | end 649 | 650 | function network = initpasstest(network,nbatch,maxt) 651 | 652 | ninput = network.input.n; 653 | 654 | noutput = network.output.n; 655 | 656 | for l=1:length(network.hidden) 657 | 658 | nhidden = network.nhidden(l); 659 | 660 | 661 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 662 | if ~network.last 663 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 664 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 665 | else 666 | 
network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 667 | 668 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 669 | end 670 | 671 | 672 | 673 | 674 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,'single')); 675 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,'single')); 676 | 677 | 678 | 679 | end 680 | 681 | 682 | 683 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 684 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 685 | 686 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 687 | 688 | 689 | end 690 | 691 | 692 | function network = initnetwork(ninput,nhidden,noutput) 693 | 694 | 695 | network.input.n = ninput; 696 | network.nhidden = nhidden; 697 | network.output.n = noutput; 698 | 699 | 700 | 701 | 702 | network.hidind = (1:nhidden)'; 703 | network.writeind = (nhidden+1:2*nhidden)'; 704 | network.keepind = (2*nhidden+1:3*nhidden)'; 705 | network.readind = (3*nhidden+1:4*nhidden)'; 706 | network.gateind = (nhidden+1:4*nhidden)'; 707 | for j = 1:1 708 | nhidden = network.nhidden(j); 709 | 710 | 711 | network.hidden(j).iweights.matrix = gpuArray(.1*(randn(nhidden*4,ninput,'single'))); 712 | network.hidden(j).fweights.matrix = gpuArray(.1*(randn(nhidden,ninput,'single'))); 713 | network.hidden(j).iweights.gated = 0; 714 | 715 | 716 | network.hidden(j).biases.v = gpuArray(zeros(4*nhidden,1,'single')); 717 | network.hidden(j).biases.v(network.keepind)=3; 718 | 719 | 720 | network.hidden(j).weights.matrix =gpuArray(.02*(randn(nhidden*4,nhidden,'single'))); 721 | network.hidden(j).mweights.matrix =gpuArray(.02*(randn(nhidden,nhidden,'single'))); 722 | 723 | 724 | 725 | 726 | network.hidden(j).fx = @sigmoid; 727 | network.hidden(j).dx = @sigdir; 728 | network.output.weights(j).matrix = gpuArray(.1*(randn(noutput,nhidden,'single'))); 729 | network.output.weights(j).utime = 0; 730 | network.output.weights(j).index = j; 731 | network.output.weights(j).gated 
= 0; 732 | 733 | 734 | end 735 | 736 | 737 | network.nparam = length(weights2vect(getW(network))); 738 | 739 | 740 | network.output.fx = @softmax; 741 | network.output.dx = @softdirXent; 742 | network.errorFunc = @evalCrossEntropy; 743 | 744 | end 745 | 746 | 747 | function J = getJ(network) 748 | jtot=1; 749 | J = cell(jtot,1); 750 | c=1; 751 | for l=1:length(network.hidden) 752 | J{c}= network.hidden(l).iweights.gradient; 753 | c=c+1; 754 | J{c}= network.hidden(l).biases.j; 755 | c=c+1; 756 | J{c}= network.hidden(l).fweights.gradient; 757 | c=c+1; 758 | 759 | 760 | for i = 1:length(network.hidden(l).weights); 761 | J{c}=network.hidden(l).weights(i).gradient; 762 | c=c+1; 763 | J{c}=network.hidden(l).mweights(i).gradient; 764 | c=c+1; 765 | 766 | 767 | end 768 | 769 | J{c} = 1*network.output.weights(l).gradient; 770 | c=c+1; 771 | end 772 | 773 | 774 | 775 | end 776 | function W = getW(network) 777 | jtot=1; 778 | W = cell(jtot,1); 779 | c=1; 780 | for l=1:length(network.hidden) 781 | 782 | W{c}= network.hidden(l).iweights.matrix; 783 | c=c+1; 784 | W{c}= network.hidden(l).biases.v; 785 | c=c+1; 786 | W{c}= network.hidden(l).fweights.matrix; 787 | c=c+1; 788 | 789 | 790 | 791 | for i = 1:length(network.hidden(l).weights); 792 | W{c}=network.hidden(l).weights(i).matrix; 793 | c=c+1; 794 | W{c}=network.hidden(l).mweights(i).matrix; 795 | c=c+1; 796 | 797 | end 798 | 799 | W{c} = network.output.weights(l).matrix; 800 | c=c+1; 801 | end 802 | 803 | 804 | 805 | end 806 | 807 | 808 | 809 | 810 | 811 | function f= sigmoid(x) 812 | 813 | 814 | f= 1./(1+ exp(-1.*x)); 815 | end 816 | 817 | function o = softdirXent(x); 818 | 819 | o=ones(size(x),'single'); 820 | 821 | 822 | end 823 | 824 | function dir = sigdir( y ) 825 | 826 | dir = y.*(1-y); 827 | 828 | 829 | end 830 | function dir = tanhdir( y ) 831 | 832 | dir = (1-y.*y); 833 | 834 | 835 | end 836 | function o = softmax(x) 837 | 838 | o=bsxfun(@times,1./sum(exp(x),1),exp(x)); 839 | end 840 | 841 | %function m=gather(m) 
842 | 843 | %end 844 | %function m=gpuArray(m) 845 | 846 | %end -------------------------------------------------------------------------------- /matlab/LSTMhutter.m: -------------------------------------------------------------------------------- 1 | function LSTMhutter 2 | 3 | 4 | %to run on CPU, comment out gpuDevice command and uncomment function gather 5 | %and function gpuArray at the end of the file 6 | gpuDevice(1) 7 | 8 | logfname = 'LSTMhutter.txt'; 9 | weightsfname = 'LSTMhutter.mat'; 10 | 11 | sequence = processtextfile('enwik8'); 12 | 13 | 14 | valstart = 90*10^6; 15 | nunits = max(sequence); 16 | 17 | valind = valstart:(valstart+(5*10^5)-1); 18 | [vinputs,vtargets] = byte2input(sequence(valind),sequence(valind+1),nunits,1000); 19 | voind = single(ones(size(vtargets))); 20 | 21 | 22 | 23 | 24 | 25 | nhid = [1250,1250]; 26 | network = initnetwork(nunits,nhid,nunits); 27 | 28 | 29 | network.last = 0; 30 | 31 | 32 | disp(network.nparam) 33 | 34 | 35 | 36 | 37 | log = fopen(logfname,'w'); 38 | nepochs = 1000; 39 | 40 | [verr] = test(network,vinputs,vtargets,voind); 41 | disp(verr) 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | berr = verr; 53 | 54 | npar = 1; 55 | 56 | 57 | 58 | MS = zeros(network.nparam,1,'single'); 59 | 60 | %hyperparameters 61 | thresh = 2; 62 | threshdecay = .9999; 63 | dec = .9; 64 | 65 | 66 | numepochs=4; 67 | 68 | megabatch = 1000000; 69 | minibatch = 10000; 70 | jit = megabatch/minibatch; 71 | 72 | 73 | tic 74 | start = 1; 75 | fin = start+megabatch-1; 76 | 77 | 78 | 79 | %set training set size 80 | maxtrain = 95*10^6; 81 | 82 | co = 0; 83 | for i=1:nepochs 84 | 85 | 86 | if i>1 87 | start = start+megabatch; 88 | fin = fin+megabatch; 89 | if fin > maxtrain 90 | start = 1; 91 | fin = start+megabatch-1; 92 | co=co+1; 93 | end 94 | if co==numepochs 95 | return 96 | end 97 | end 98 | 99 | seqlen = 100; 100 | 101 | 102 | 103 | [in1,ta1] = byte2input(single(sequence(start:fin)),single(sequence(start+1:fin+1)),nunits,seqlen); 104 | 
network.last = 0; 105 | disp(jit) 106 | jit = 100; 107 | 108 | 109 | for j=1:jit 110 | 111 | 112 | thresh = threshdecay*thresh; 113 | if thresh < .01 114 | thresh = .01; 115 | end 116 | 117 | 118 | 119 | 120 | in= in1(:,j:jit:size(in1,2),:); 121 | ta= ta1(:,j:jit:size(ta1,2),:); 122 | 123 | 124 | 125 | 126 | 127 | 128 | gradient = getGradient(network,in,ta,1)/minibatch; 129 | 130 | if (1-1/(j) < dec) && i==1 131 | dec1 = 1-1/j; 132 | 133 | else 134 | dec1 = dec; 135 | end 136 | 137 | 138 | MS = dec1*MS + (1-dec1)*gradient.^2; 139 | dW = gradient./(sqrt(MS)+.000001); 140 | 141 | dW(isnan(dW)) = 0; 142 | norm = sqrt(dW'*dW); 143 | 144 | dW = dW*(thresh/norm); 145 | 146 | dW = -1*dW; 147 | 148 | 149 | network = updateV(network,dW); 150 | network = initpass(network,size(in,2),size(in,3)); 151 | network = ForwardPass(network, in); 152 | network.last = 1; 153 | toc 154 | 155 | 156 | 157 | 158 | 159 | if j==jit; 160 | network.last = 0; 161 | randind = floor(rand()*(maxtrain-(5*10^5))); 162 | randind = randind:(randind+5*10^5 -1); 163 | [intrain,targtrain] = byte2input(sequence(randind),sequence(randind+1),nunits,1000); 164 | 165 | 166 | [err] = test(network,intrain,targtrain,ones(size(targtrain),'single')); 167 | [verr]= test(network,vinputs,vtargets,voind); 168 | 169 | 170 | 'err' 171 | display(err); 172 | display(verr); 173 | derr = (berr - verr)/(berr); 174 | 175 | 176 | bitchar = err/(size(targtrain,3)*size(targtrain,2)); 177 | 178 | vbitchar = verr/(size(vtargets,3)*size(vtargets,2)); 179 | 180 | fprintf(log,'iter: %i train %f val %f \n',i,bitchar,vbitchar); 181 | disp(derr) 182 | 183 | berr = verr; 184 | W = weights2vect(getW(network)); 185 | save(weightsfname,'W'); 186 | 187 | end 188 | end 189 | 190 | 191 | end 192 | toc 193 | 194 | 195 | 196 | 197 | end 198 | function [inputs,targets] =byte2input(inputs,targets,nunits,seqlen) 199 | 200 | 201 | inputs= single(inputs); 202 | targets = single(targets); 203 | 204 | in = zeros(nunits,1,length(inputs),'single'); 205 | 
targ = zeros(nunits,1,length(targets),'single'); 206 | ind = sub2ind([nunits,1,length(inputs)],inputs,ones(length(inputs),1),(1:length(inputs))'); 207 | tind = sub2ind([nunits,1,length(inputs)],targets,ones(length(inputs),1),(1:length(inputs))'); 208 | in(ind)=1; 209 | targ(tind) = 1; 210 | 211 | inputs=permute(reshape(in,[size(in,1),seqlen,size(in,3)/seqlen]),[1,3,2]); 212 | targets=permute(reshape(targ,[size(targ,1),seqlen,size(targ,3)/seqlen]),[1,3,2]); 213 | 214 | 215 | 216 | 217 | end 218 | 219 | 220 | 221 | function gradient = getGradient(network,inputs,targets,npar) 222 | 223 | 224 | pbatch = size(inputs,2)/npar; 225 | citer = 1:pbatch:size(inputs,2); 226 | oind = ones(size(targets)); 227 | in = cell(npar,1);ta = cell(npar,1);oi = cell(npar,1); 228 | for ci=1:length(citer) 229 | c = citer(ci); 230 | in{ci} = inputs(:,c:c+pbatch-1,:); 231 | ta{ci} = targets(:,c:c+pbatch-1,:); 232 | oi{ci} = oind(:,c:c+pbatch-1,:); 233 | end 234 | 235 | gradient = zeros(network.nparam,1); 236 | 237 | for z=1:npar 238 | 239 | net = initpass(network,size(in{z},2),size(in{z},3)); 240 | net = ForwardPass(net,in{z}); 241 | net = computegradient(net,ta{z},ones(size(ta{z}))); 242 | 243 | 244 | gradient = gradient + weights2vect(getJ(net)); 245 | end 246 | 247 | 248 | 249 | 250 | end 251 | 252 | function [err] = test(network,inputs,targets,oind) 253 | errsum = 0; 254 | errcount = 0; 255 | 256 | nbatch = size(inputs,2); 257 | 258 | 259 | input = inputs; 260 | network = initpasstest(network,nbatch,size(input,3)); 261 | network = ForwardPasstest(network,input); 262 | network.output.outs.v = .999999*network.output.outs.v + .000001*(ones(size(network.output.outs.v)))/size(network.output.outs.v,1); 263 | 264 | [terr]=network.errorFunc(network.output.outs.v,targets,oind); 265 | errsum = errsum + terr; 266 | 267 | errcount = errcount+1; 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | err = errsum/errcount; 277 | 278 | 279 | 280 | 281 | 282 | 283 | end 284 | 285 | 286 | 287 | 288 | 
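For orientation before the ForwardPass routine that follows: the cell it implements is a slightly nonstandard LSTM. The hidind rows of the stacked pre-activation form a purely linear candidate (no squashing and no bias at this stage), the writeind/keepind/readind rows are sigmoid gates with biases, and the tanh is applied only after the read gate and the output bias. A minimal NumPy sketch of one timestep under that reading; `lstm_step`, `W_in`, `W_rec` and `b` are illustrative names, not identifiers from this code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W_in, W_rec, b):
    """One timestep of the gated update used by ForwardPass.

    Rows of the stacked pre-activation follow the hidind / writeind /
    keepind / readind index vectors set up in initnetwork.
    """
    n = h_prev.shape[0]
    z = W_in @ x + W_rec @ h_prev               # (4n,) stacked pre-activations
    cand  = z[:n]                               # hidind: linear candidate, no bias here
    write = sigmoid(z[n:2*n]   + b[n:2*n])      # write (input) gate
    keep  = sigmoid(z[2*n:3*n] + b[2*n:3*n])    # keep (forget) gate, bias init +3
    read  = sigmoid(z[3*n:4*n] + b[3*n:4*n])    # read (output) gate
    c = cand * write + c_prev * keep            # ins.state accumulation
    h = np.tanh(c * read + b[:n])               # tanh applied after read gate and bias
    return h, c
```

Here `b` stacks the four bias blocks the same way biases.v does, so the keep-gate block corresponds to the +3 initialisation in initnetwork.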
function network = ForwardPass(network, inputs) 289 | 290 | inputs = gpuArray(inputs); 291 | network.input.outs.v=inputs; 292 | 293 | 294 | 295 | s=1; 296 | for t=1:size(inputs,3); 297 | for l=1:length(network.hidden); 298 | 299 | 300 | 301 | 302 | network.hidden(l).ins.v(:,:,t)= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | if t>1 316 | 317 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).outs.v(:,:,t-1); 318 | else 319 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).outs.vp0 ; 320 | end 321 | 322 | if l>1 323 | 324 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).hweights.matrix*network.hidden(l-1).outs.v(:,:,t) ; 325 | end 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | network.hidden(l).ins.v(network.hidden(l).gateind,:,t)= bsxfun(@plus,network.hidden(l).ins.v(network.hidden(l).gateind,:,t),network.hidden(l).biases.v(network.hidden(l).gateind,:)); 336 | network.hidden(l).ins.v(network.hidden(l).gateind,:,t) = sigmoid(network.hidden(l).ins.v(network.hidden(l).gateind,:,t)); 337 | 338 | 339 | 340 | network.hidden(l).ins.state(:,:,t)=network.hidden(l).ins.v(network.hidden(l).hidind,:,t).*network.hidden(l).ins.v(network.hidden(l).writeind,:,t); 341 | 342 | 343 | 344 | if t>1 345 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t) + network.hidden(l).ins.state(:,:,t-1).*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 346 | else 347 | 348 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t) +network.hidden(l).ins.statep0.*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 349 | end 350 | 351 | 352 | 353 | network.hidden(l).outs.v(:,:,t) = network.hidden(l).ins.state(:,:,t).*network.hidden(l).ins.v(network.hidden(l).readind,:,t); 354 
| network.hidden(l).outs.v(:,:,t)=bsxfun(@plus,network.hidden(l).outs.v(:,:,t),network.hidden(l).biases.v(network.hidden(l).hidind,:)); 355 | network.hidden(l).outs.v(:,:,t) = tanh(network.hidden(l).outs.v(:,:,t)); 356 | 357 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*network.hidden(l).outs.v(:,:,t); 358 | 359 | 360 | 361 | if t==size(inputs,3) 362 | network.hidden(l).outs.last = network.hidden(l).outs.v(:,:,t); 363 | network.hidden(l).ins.last = network.hidden(l).ins.state(:,:,t); 364 | end 365 | 366 | end 367 | end 368 | 369 | network.output.outs.v = network.output.fx(network.output.outs.v); 370 | 371 | 372 | end 373 | function [network] = ForwardPasstest(network, inputs) 374 | 375 | inputs = gpuArray(inputs); 376 | network.input.outs.v=inputs; 377 | 378 | 379 | for l=1:length(network.hidden) 380 | hidden(l).outs.vp = network.hidden(l).outs.vp0 ; 381 | hidden(l).ins.statep = network.hidden(l).ins.statep0; 382 | end 383 | 384 | for t=1:size(inputs,3); 385 | for l=1:length(network.hidden); 386 | 387 | 388 | 389 | 390 | hidden(l).ins.v= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | hidden(l).ins.v = hidden(l).ins.v + network.hidden(l).weights.matrix*hidden(l).outs.vp; 405 | 406 | if l>1 407 | 408 | hidden(l).ins.v = hidden(l).ins.v + network.hidden(l).hweights.matrix*hidden(l-1).outs.v; 409 | end 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | hidden(l).ins.v(network.hidden(l).gateind,:)= bsxfun(@plus,hidden(l).ins.v(network.hidden(l).gateind,:),network.hidden(l).biases.v(network.hidden(l).gateind,:)); 418 | hidden(l).ins.v(network.hidden(l).gateind,:) = sigmoid(hidden(l).ins.v(network.hidden(l).gateind,:)); 419 | 420 | 421 | 422 | hidden(l).ins.state=hidden(l).ins.v(network.hidden(l).hidind,:).*hidden(l).ins.v(network.hidden(l).writeind,:); 423 | 424 | ttemp = t-1; 425 | 426 | 427 | hidden(l).ins.state = 
hidden(l).ins.state + hidden(l).ins.statep.*hidden(l).ins.v(network.hidden(l).keepind,:); 428 | 429 | 430 | 431 | 432 | hidden(l).outs.v = hidden(l).ins.state.*hidden(l).ins.v(network.hidden(l).readind,:); 433 | hidden(l).outs.v=bsxfun(@plus,hidden(l).outs.v,network.hidden(l).biases.v(network.hidden(l).hidind,:)); 434 | hidden(l).outs.v = tanh(hidden(l).outs.v); 435 | 436 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*hidden(l).outs.v; 437 | 438 | hidden(l).outs.vp = hidden(l).outs.v; 439 | hidden(l).ins.statep = hidden(l).ins.state; 440 | 441 | 442 | 443 | 444 | end 445 | end 446 | 447 | network.output.outs.v = network.output.fx(network.output.outs.v); 448 | 449 | 450 | 451 | end 452 | 453 | function network = computegradient(network, targets,omat) 454 | oind = find(omat); 455 | 456 | network.output.outs.j(oind) = network.output.outs.v(oind)- targets(oind); 457 | 458 | network.output.outs.j = network.output.outs.j.*network.output.dx(network.output.outs.v); 459 | 460 | 461 | 462 | 463 | for l=1:length(network.hidden) 464 | hidden(l).ins.statej = gpuArray(zeros(network.nhidden(l),size(targets,2),'single')); 465 | hidden(l).outs.j = gpuArray(zeros(network.nhidden(l),size(targets,2),'single')); 466 | hidden(l).ins.j = gpuArray(zeros(network.nhidden(l)*4,size(network.input.outs.v,2),'single')); 467 | end 468 | for t=size(network.input.outs.v,3):-1:1; 469 | 470 | 471 | 472 | for l= length(network.hidden):-1:1 473 | 474 | network.output.weights(l).gradient = network.output.weights(l).gradient + network.output.outs.j(:,:,t)*network.hidden(l).outs.v(:,:,t)'; 475 | 476 | hidden(l).outs.j = hidden(l).outs.j + network.output.weights(l).matrix'*network.output.outs.j(:,:,t) ; 477 | 478 | 479 | 480 | 481 | hidden(l).outs.j = hidden(l).outs.j.*tanhdir(network.hidden(l).outs.v(:,:,t)); 482 | network.hidden(l).biases.j(network.hidden(l).hidind,:) = network.hidden(l).biases.j(network.hidden(l).hidind,:) + sum(hidden(l).outs.j,2); 483 
| 484 | 485 | 486 | 487 | 488 | 489 | 490 | hidden(l).ins.j(network.hidden(l).readind,:) = hidden(l).outs.j.* network.hidden(l).ins.state(:,:,t) ; 491 | 492 | hidden(l).ins.statej = hidden(l).ins.statej + hidden(l).outs.j.*network.hidden(l).ins.v(network.hidden(l).readind,:,t); 493 | 494 | ttemp = t-1; 495 | if ttemp>0 496 | 497 | hidden(l).ins.statejp = hidden(l).ins.statej.*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 498 | 499 | 500 | 501 | 502 | end 503 | 504 | 505 | 506 | 507 | if ttemp>0 508 | 509 | 510 | hidden(l).ins.j(network.hidden(l).keepind,:) = network.hidden(l).ins.state(:,:,t-1).*hidden(l).ins.statej; 511 | else 512 | 513 | hidden(l).ins.j(network.hidden(l).keepind,:)= network.hidden(l).ins.statep0.*hidden(l).ins.statej; 514 | 515 | end 516 | 517 | 518 | hidden(l).ins.j(network.hidden(l).writeind,:) = network.hidden(l).ins.v(network.hidden(l).hidind,:,t).*hidden(l).ins.statej; 519 | hidden(l).ins.j(network.hidden(l).hidind,:) = network.hidden(l).ins.v(network.hidden(l).writeind,:,t).*hidden(l).ins.statej; 520 | 521 | hidden(l).ins.j(network.hidden(l).gateind,:)= hidden(l).ins.j(network.hidden(l).gateind,:).*sigdir(network.hidden(l).ins.v(network.hidden(l).gateind,:,t)); 522 | network.hidden(l).biases.j(network.hidden(l).gateind,:) = network.hidden(l).biases.j(network.hidden(l).gateind,:) + sum(hidden(l).ins.j(network.hidden(l).gateind,:),2); 523 | 524 | 525 | 526 | if t-1>0 527 | hidden(l).outs.jp = network.hidden(l).weights.matrix'*hidden(l).ins.j; 528 | 529 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + hidden(l).ins.j*network.hidden(l).outs.v(:,:,t-1)'; 530 | 531 | 532 | 533 | else 534 | 535 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + hidden(l).ins.j*network.hidden(l).outs.vp0'; 536 | 537 | 538 | end 539 | 540 | if l>1 541 | hidden(l-1).outs.j = hidden(l-1).outs.j + network.hidden(l).hweights.matrix'*hidden(l).ins.j; 542 | network.hidden(l).hweights.gradient = 
network.hidden(l).hweights.gradient + hidden(l).ins.j*network.hidden(l-1).outs.v(:,:,t)'; 543 | end 544 | 545 | network.hidden(l).iweights.gradient = network.hidden(l).iweights.gradient + (hidden(l).ins.j)*network.input.outs.v(:,:,t)'; 546 | 547 | 548 | 549 | 550 | 551 | 552 | 553 | hidden(l).outs.j = hidden(l).outs.jp; 554 | hidden(l).ins.statej = hidden(l).ins.statejp; 555 | end 556 | 557 | 558 | 559 | end 560 | 561 | end 562 | 563 | 564 | function [lcost] = evalCrossEntropy(output,targets,omat) 565 | 566 | 567 | 568 | oind = find(omat); 569 | 570 | ldiff = targets.*log2(output); 571 | 572 | 573 | lcost = -1*sum(ldiff(:)); 574 | 575 | 576 | 577 | 578 | end 579 | 580 | function network = updateV(network, dW) 581 | 582 | ninput = network.input.n; 583 | noutput = network.output.n; 584 | 585 | 586 | start = 1; 587 | last = 0; 588 | 589 | for l=1:length(network.hidden) 590 | nhidden = network.nhidden(l); 591 | 592 | last = last + numel(network.hidden(l).iweights.matrix); 593 | network.hidden(l).iweights.matrix = reshape(dW(start:last),4*nhidden,ninput)+ network.hidden(l).iweights.matrix ; 594 | start = last + 1; 595 | 596 | last = last + numel(network.hidden(l).biases.v); 597 | network.hidden(l).biases.v = reshape(dW(start:last),4*nhidden,1)+ network.hidden(l).biases.v ; 598 | start = last + 1; 599 | 600 | for i=1:length(network.hidden(l).weights); 601 | last = last + numel(network.hidden(l).weights(i).matrix); 602 | network.hidden(l).weights(i).matrix = reshape(dW(start:last),4*nhidden,nhidden)+network.hidden(l).weights(i).matrix; 603 | start = last+1; 604 | if l>1 605 | last = last + numel(network.hidden(l).hweights(i).matrix); 606 | network.hidden(l).hweights(i).matrix = reshape(dW(start:last),4*nhidden,network.nhidden(l-1))+network.hidden(l).hweights(i).matrix; 607 | start = last+1; 608 | 609 | end 610 | 611 | 612 | 613 | end 614 | 615 | 616 | 617 | 618 | 619 | 620 | last = last+ numel(network.output.weights(l).matrix); 621 | network.output.weights(l).matrix = 
reshape(dW(start:last),noutput,nhidden)+ network.output.weights(l).matrix ; 622 | start=last+1; 623 | 624 | end 625 | 626 | end 627 | 628 | function vect=weights2vect(allvects) 629 | lsum = 0; 630 | lengths = cell(length(allvects),1); 631 | for i=1:length(allvects) 632 | lsum = lsum + numel(allvects{i}); 633 | lengths{i}= lsum; 634 | 635 | 636 | end 637 | vect = zeros(lsum,1,'single'); 638 | 639 | vect(1:lengths{1}) = gather(reshape(allvects{1},lengths{1},1)); 640 | for i=2:length(allvects) 641 | vect(lengths{i-1}+1:lengths{i}) = gather(reshape(allvects{i},lengths{i}-lengths{i-1},1)); 642 | end 643 | 644 | 645 | end 646 | 647 | 648 | 649 | function network = initpass(network,nbatch,maxt) 650 | 651 | ninput = network.input.n; 652 | 653 | noutput = network.output.n; 654 | 655 | 656 | for l=1:length(network.hidden) 657 | 658 | nhidden = network.nhidden(l); 659 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 660 | if ~network.last 661 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 662 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 663 | else 664 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 665 | 666 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 667 | end 668 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 669 | for i=1:length(network.hidden(l).weights); 670 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 671 | 672 | 673 | end 674 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 675 | if l>1 676 | network.hidden(l).hweights(i).gradient = gpuArray(zeros(nhidden*4,network.nhidden(l-1),'single')); 677 | 678 | end 679 | 680 | 681 | 682 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 683 | network.hidden(l).ins.v = gpuArray(zeros(4*nhidden,nbatch,maxt,'single')); 684 | network.hidden(l).ins.state = 
gpuArray(zeros(nhidden,nbatch,maxt,'single')); 685 | 686 | 687 | 688 | end 689 | 690 | 691 | 692 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 693 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 694 | 695 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 696 | 697 | 698 | end 699 | 700 | function network = initpasstest(network,nbatch,maxt) 701 | 702 | ninput = network.input.n; 703 | 704 | noutput = network.output.n; 705 | 706 | 707 | for l=1:length(network.hidden) 708 | 709 | nhidden = network.nhidden(l); 710 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 711 | if ~network.last 712 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 713 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 714 | else 715 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 716 | 717 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 718 | end 719 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 720 | for i=1:length(network.hidden(l).weights); 721 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 722 | 723 | 724 | end 725 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 726 | if l>1 727 | network.hidden(l).hweights(i).gradient = gpuArray(zeros(nhidden*4,network.nhidden(l-1),'single')); 728 | 729 | end 730 | 731 | 732 | 733 | % network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 734 | % network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 735 | 736 | 737 | 738 | end 739 | 740 | 741 | 742 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 743 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 744 | 745 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 746 | 747 | 748 | end 749 | function network = 
initnetwork(ninput,nhidden,noutput) 750 | 751 | 752 | 753 | 754 | network.input.n = ninput; 755 | network.nhidden = nhidden; 756 | network.output.n = noutput; 757 | 758 | 759 | 760 | 761 | 762 | 763 | 764 | 765 | 766 | for j = 1:(length(network.nhidden)) 767 | 768 | nhidden = network.nhidden(j); 769 | network.hidden(j).hidind = (1:nhidden)'; 770 | network.hidden(j).writeind = (nhidden+1:2*nhidden)'; 771 | network.hidden(j).keepind = (2*nhidden+1:3*nhidden)'; 772 | network.hidden(j).readind = (3*nhidden+1:4*nhidden)'; 773 | network.hidden(j).gateind = (nhidden+1:4*nhidden)'; 774 | 775 | network.hidden(j).iweights.matrix = gpuArray(.1*(randn(nhidden*4,ninput,'single'))); 776 | network.hidden(j).biases.v = gpuArray(zeros(4*nhidden,1,'single')); 777 | network.hidden(j).biases.v(network.hidden(j).keepind)=3; 778 | network.hidden(j).biases.v(network.hidden(j).readind)=-2; 779 | network.hidden(j).biases.v(network.hidden(j).writeind)=0; 780 | network.hidden(j).iweights.gated = 0; 781 | 782 | 783 | 784 | network.hidden(j).weights.matrix =gpuArray(.0001*(randn(nhidden*4,nhidden,'single'))); 785 | 786 | 787 | 788 | 789 | if j>1 790 | network.hidden(j).hweights.matrix =gpuArray(.01*(randn(nhidden*4,network.nhidden(j-1),'single'))); 791 | end 792 | 793 | 794 | network.hidden(j).fx = @sigmoid; 795 | network.hidden(j).dx = @sigdir; 796 | 797 | 798 | network.output.weights(j).matrix = gpuArray(.1*(randn(noutput,nhidden,'single'))); 799 | 800 | 801 | 802 | end 803 | 804 | 805 | network.nparam = length(weights2vect(getW(network))); 806 | 807 | 808 | 809 | network.output.fx = @softmax; 810 | network.output.dx = @softdir; 811 | network.errorFunc = @evalCrossEntropy; 812 | network.output.getHessian = @CrossEntropyHessian; 813 | 814 | 815 | 816 | 817 | 818 | end 819 | 820 | 821 | function J = getJ(network) 822 | jtot=1; 823 | J = cell(jtot,1); 824 | c=1; 825 | for l=1:length(network.hidden) 826 | J{c}= network.hidden(l).iweights.gradient; 827 | c=c+1; 828 | 
network.hidden(l).biases.j(network.hidden(l).hidind)=0; 829 | J{c}=network.hidden(l).biases.j; 830 | c=c+1; 831 | 832 | for i = 1:length(network.hidden(l).weights); 833 | J{c}=network.hidden(l).weights(i).gradient; 834 | c=c+1; 835 | if l>1 836 | J{c}=network.hidden(l).hweights(i).gradient; 837 | c=c+1; 838 | end 839 | 840 | 841 | end 842 | 843 | J{c} = network.output.weights(l).gradient; 844 | c=c+1; 845 | end 846 | 847 | 848 | 849 | end 850 | function W = getW(network) 851 | jtot=1; 852 | W = cell(jtot,1); 853 | c=1; 854 | for l=1:length(network.hidden) 855 | 856 | W{c}= network.hidden(l).iweights.matrix; 857 | c=c+1; 858 | 859 | W{c}= network.hidden(l).biases.v; 860 | c=c+1; 861 | 862 | 863 | for i = 1:length(network.hidden(l).weights); 864 | W{c}=network.hidden(l).weights(i).matrix; 865 | c=c+1; 866 | if l>1 867 | W{c}=network.hidden(l).hweights(i).matrix; 868 | c=c+1; 869 | end 870 | 871 | end 872 | 873 | W{c} = network.output.weights(l).matrix; 874 | c=c+1; 875 | end 876 | 877 | 878 | 879 | end 880 | 881 | 882 | 883 | 884 | 885 | function f= sigmoid(x) 886 | 887 | 888 | f= 1./(1+ exp(-1.*x)); 889 | end 890 | 891 | function o = softdir(x); 892 | 893 | o=ones(size(x),'single'); 894 | 895 | 896 | end 897 | function o = softmax(x) 898 | 899 | o=bsxfun(@times,1./sum(exp(x),1),exp(x)); 900 | end 901 | function dir = sigdir( y ) 902 | 903 | dir = y.*(1-y); 904 | 905 | 906 | end 907 | function dir = tanhdir( y ) 908 | 909 | dir = (1-y.*y); 910 | 911 | 912 | end 913 | 914 | 915 | 916 | 917 | %function m=gather(m) 918 | 919 | %end 920 | %function m=gpuArray(m) 921 | 922 | %end 923 | 924 | 925 | 926 | 927 | 928 | 929 | 930 | 931 | 932 | -------------------------------------------------------------------------------- /matlab/LSTMdynamic.m: -------------------------------------------------------------------------------- 1 | function LSTMdynamic 2 | 3 | 4 | gpuDevice(1) 5 | 6 | sequence = processtextfile('enwik8'); 7 | 8 | weightsfname = 'LSTMhutter.mat'; 9 | 10 | 11 | 
nunits = 205; 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | network = initnetwork(nunits,[1250,1250],nunits); 20 | 21 | network.last = 0; 22 | 23 | 24 | 25 | network=updateV(network,-1*weights2vect(getW(network))); 26 | 27 | wfile = load(weightsfname,'W'); 28 | W = wfile.W; 29 | 30 | network = updateV(network,W); 31 | 32 | disp(network.nparam) 33 | 34 | 35 | seqlen = 50; 36 | network.storeind =[20;40]; 37 | 38 | 39 | tic 40 | 41 | start = 96000000; 42 | fin = start+seqlen-1; 43 | 44 | 45 | serr = 0; 46 | dec = .99; 47 | epsilon = .0001; 48 | 49 | network.last = 0; 50 | network0 = network; 51 | alpha = .01; 52 | network = RMSinit(network); 53 | for i=1:(4*10^6)/seqlen 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | [in1,ta1] = byte2input(sequence(start:fin),sequence(start+1:fin+1),nunits,seqlen); 66 | start = start + seqlen; 67 | fin = fin + seqlen; 68 | network = initpass(network,size(in1,2),size(in1,3)); 69 | 70 | network = ForwardPass(network,in1); 71 | network.output.outs.v = .999999*network.output.outs.v + .000001*(ones(size(network.output.outs.v)))/size(network.output.outs.v,1); 72 | terr = evalCrossEntropy(network.output.outs.v,ta1,ones(size(ta1),'single')); 73 | 74 | serr = serr+ (terr/seqlen); 75 | 76 | err = serr/i; 77 | 78 | network = computegradient(network,ta1,ones(size(ta1),'single')); 79 | 80 | 81 | 82 | if 1-1/i < dec 83 | dec1 = 1-1/i; 84 | 85 | else 86 | dec1 = dec; 87 | end 88 | 89 | network = RMSprop(network,network0,dec1,alpha,epsilon,seqlen); 90 | network = initpass(network,size(in1,2),size(in1,3)); 91 | network = ForwardPass(network, in1); 92 | network.last = 1; 93 | 94 | if mod(i,100)==0 95 | disp(i) 96 | disp(err) 97 | end 98 | 99 | end 100 | disp('final error') 101 | disp(err) 102 | 103 | end 104 | 105 | function network = RMSinit(network) 106 | 107 | 108 | 109 | for l=1:length(network.hidden); 110 | network.hidden(l).iweights.MS = gpuArray(zeros(size(network.hidden(l).iweights.matrix),'single')); 111 | network.hidden(l).weights.MS = 
gpuArray(zeros(size(network.hidden(l).weights.matrix),'single')); 112 | 113 | 114 | network.output.weights(l).MS = gpuArray(zeros(size(network.output.weights(l).matrix),'single')); 115 | end 116 | end 117 | function network = RMSprop(network,network0,dec,alpha,epsilon,n) 118 | 119 | 120 | for l=1:length(network.hidden); 121 | network.hidden(l).iweights.gradient=network.hidden(l).iweights.gradient/n; 122 | network.hidden(l).weights.gradient=network.hidden(l).weights.gradient/n; 123 | 124 | network.output.weights(l).gradient=network.output.weights(l).gradient/n; 125 | 126 | 127 | 128 | 129 | network.hidden(l).iweights.MS = network.hidden(l).iweights.MS*dec + network.hidden(l).iweights.gradient.^2*(1-dec); 130 | network.hidden(l).weights.MS = network.hidden(l).weights.MS*dec + network.hidden(l).weights.gradient.^2*(1-dec); 131 | 132 | 133 | 134 | network.output.weights(l).MS = network.output.weights(l).MS*dec + network.output.weights(l).gradient.^2*(1-dec); 135 | 136 | 137 | 138 | 139 | network.hidden(l).iweights.dW = -1*network.hidden(l).iweights.gradient./(sqrt(network.hidden(l).iweights.MS)+.000001); 140 | 141 | 142 | network.hidden(l).weights.dW = -1*network.hidden(l).weights.gradient./(sqrt(network.hidden(l).weights.MS)+.000001); 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | network.output.weights(l).dW = -1*network.output.weights(l).gradient./(sqrt(network.output.weights(l).MS)+.000001); 152 | 153 | end 154 | 155 | 156 | 157 | 158 | 159 | for l=1:length(network.hidden) 160 | network.hidden(l).iweights.dW = network.hidden(l).iweights.dW*epsilon; 161 | network.hidden(l).weights.dW = network.hidden(l).weights.dW*epsilon; 162 | 163 | network.output.weights(l).dW = network.output.weights(l).dW*epsilon; 164 | 165 | network.hidden(l).iweights.dW = network.hidden(l).iweights.dW+alpha*(network0.hidden(l).iweights.matrix-network.hidden(l).iweights.matrix); 166 | network.hidden(l).weights.dW = 
network.hidden(l).weights.dW+alpha*(network0.hidden(l).weights.matrix-network.hidden(l).weights.matrix); 167 | 168 | network.output.weights(l).dW = network.output.weights(l).dW+alpha*(network0.output.weights(l).matrix-network.output.weights(l).matrix); 169 | 170 | 171 | 172 | 173 | network.hidden(l).iweights.matrix = network.hidden(l).iweights.matrix + network.hidden(l).iweights.dW; 174 | network.hidden(l).weights.matrix = network.hidden(l).weights.matrix + network.hidden(l).weights.dW; 175 | 176 | network.output.weights(l).matrix = network.output.weights(l).matrix + network.output.weights(l).dW; 177 | end 178 | end 179 | function [inputs,targets] =byte2input(inputs,targets,nunits,seqlen) 180 | 181 | 182 | inputs= single(inputs); 183 | targets = single(targets); 184 | 185 | in = zeros(nunits,1,length(inputs),'single'); 186 | targ = zeros(nunits,1,length(targets),'single'); 187 | ind = sub2ind([nunits,1,length(inputs)],inputs,ones(length(inputs),1),(1:length(inputs))'); 188 | tind = sub2ind([nunits,1,length(inputs)],targets,ones(length(inputs),1),(1:length(inputs))'); 189 | in(ind)=1; 190 | targ(tind) = 1; 191 | 192 | inputs=permute(reshape(in,[size(in,1),seqlen,size(in,3)/seqlen]),[1,3,2]); 193 | targets=permute(reshape(targ,[size(targ,1),seqlen,size(targ,3)/seqlen]),[1,3,2]); 194 | 195 | 196 | 197 | 198 | end 199 | 200 | 201 | 202 | function gradient = getGradient(network,inputs,targets,npar) 203 | 204 | 205 | pbatch = size(inputs,2)/npar; 206 | citer = 1:pbatch:size(inputs,2); 207 | oind = ones(size(targets)); 208 | in = cell(npar,1);ta = cell(npar,1);oi = cell(npar,1); 209 | for ci=1:length(citer) 210 | c = citer(ci); 211 | in{ci} = inputs(:,c:c+pbatch-1,:); 212 | ta{ci} = targets(:,c:c+pbatch-1,:); 213 | oi{ci} = oind(:,c:c+pbatch-1,:); 214 | end 215 | 216 | gradient = zeros(network.nparam,1); 217 | for z=1:npar 218 | 219 | net = initpass(network,size(in{z},2),size(in{z},3)); 220 | net = ForwardPass(net,in{z}); 221 | net = 
computegradient(net,ta{z},ones(size(ta{z}))); 222 | 223 | 224 | gradient = gradient + weights2vect(getJ(net)); 225 | end 226 | 227 | 228 | 229 | 230 | end 231 | 232 | function [err] = test(network,inputs,targets,oind) 233 | errsum = 0; 234 | errcount = 0; 235 | 236 | nbatch = size(inputs,2); 237 | 238 | 239 | input = inputs; 240 | network = initpasstest(network,nbatch,size(input,3)); 241 | network = ForwardPasstest(network,input); 242 | [terr]=network.errorFunc(network.output.outs.v,targets,oind); 243 | errsum = errsum + terr; 244 | 245 | errcount = errcount+1; 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | err = errsum/errcount; 255 | 256 | 257 | 258 | 259 | 260 | 261 | end 262 | 263 | 264 | 265 | 266 | function network = ForwardPass(network, inputs) 267 | 268 | inputs = gpuArray(inputs); 269 | network.input.outs.v=inputs; 270 | 271 | 272 | 273 | s=1; 274 | for t=1:size(inputs,3); 275 | for l=1:length(network.hidden); 276 | 277 | 278 | 279 | 280 | network.hidden(l).ins.v(:,:,t)= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | if t>1 294 | 295 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).outs.v(:,:,t-1); 296 | else 297 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).outs.vp0 ; 298 | end 299 | 300 | if l>1 301 | 302 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).hweights.matrix*network.hidden(l-1).outs.v(:,:,t) ; 303 | end 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | network.hidden(l).ins.v(network.hidden(l).gateind,:,t)= bsxfun(@plus,network.hidden(l).ins.v(network.hidden(l).gateind,:,t),network.hidden(l).biases.v(network.hidden(l).gateind,:)); 314 | network.hidden(l).ins.v(network.hidden(l).gateind,:,t) = sigmoid(network.hidden(l).ins.v(network.hidden(l).gateind,:,t)); 
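The gated update implemented in `ForwardPass` above can be hard to read because one `4*nhidden` preactivation vector is sliced by `hidind`/`writeind`/`keepind`/`readind`. As a minimal NumPy sketch of one time step (all names below are illustrative, not from the MATLAB code; note the nonstandard placement of `tanh` after the read gate, following this code rather than the textbook LSTM):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, Wi, Wh, b):
    # One step of the LSTM variant in ForwardPass: the 4*nhidden
    # preactivation is split into candidate (hidind), write, keep and
    # read blocks; biases and sigmoids hit the gate blocks only.
    n = h_prev.shape[0]
    pre = Wi @ x + Wh @ h_prev                    # ins.v before gating
    cand = pre[:n]                                # hidind block stays linear
    gates = sigmoid(pre[n:] + b[n:])              # write/keep/read blocks
    write, keep, read = gates[:n], gates[n:2 * n], gates[2 * n:]
    c = cand * write + c_prev * keep              # ins.state
    h = np.tanh(c * read + b[:n])                 # outs.v: tanh after read gate
    return h, c
```

Unlike the standard LSTM, the candidate block gets no tanh and the squashing happens on `state .* read + bias`, which is the order of operations in `ForwardPass` above.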
315 | 316 | 317 | 318 | network.hidden(l).ins.state(:,:,t)=network.hidden(l).ins.v(network.hidden(l).hidind,:,t).*network.hidden(l).ins.v(network.hidden(l).writeind,:,t); 319 | 320 | 321 | 322 | if t>1 323 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t) + network.hidden(l).ins.state(:,:,t-1).*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 324 | else 325 | 326 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t) +network.hidden(l).ins.statep0.*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 327 | end 328 | 329 | 330 | 331 | network.hidden(l).outs.v(:,:,t) = network.hidden(l).ins.state(:,:,t).*network.hidden(l).ins.v(network.hidden(l).readind,:,t); 332 | network.hidden(l).outs.v(:,:,t)=bsxfun(@plus,network.hidden(l).outs.v(:,:,t),network.hidden(l).biases.v(network.hidden(l).hidind,:)); 333 | network.hidden(l).outs.v(:,:,t) = tanh(network.hidden(l).outs.v(:,:,t)); 334 | 335 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*network.hidden(l).outs.v(:,:,t); 336 | 337 | 338 | 339 | if t==size(inputs,3) 340 | network.hidden(l).outs.last = network.hidden(l).outs.v(:,:,t); 341 | network.hidden(l).ins.last = network.hidden(l).ins.state(:,:,t); 342 | end 343 | 344 | end 345 | end 346 | 347 | network.output.outs.v = network.output.fx(network.output.outs.v); 348 | 349 | 350 | end 351 | function [network] = ForwardPasstest(network, inputs) 352 | 353 | inputs = gpuArray(inputs); 354 | network.input.outs.v=inputs; 355 | 356 | 357 | for l=1:length(network.hidden) 358 | hidden(l).outs.vp = network.hidden(l).outs.vp0 ; 359 | hidden(l).ins.statep = network.hidden(l).ins.statep0; 360 | end 361 | 362 | for t=1:size(inputs,3); 363 | for l=1:length(network.hidden); 364 | 365 | 366 | 367 | 368 | hidden(l).ins.v= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | hidden(l).ins.v = 
hidden(l).ins.v + network.hidden(l).weights.matrix*hidden(l).outs.vp; 383 | 384 | if l>1 385 | 386 | hidden(l).ins.v = hidden(l).ins.v + network.hidden(l).hweights.matrix*hidden(l-1).outs.v; 387 | end 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | hidden(l).ins.v(network.hidden(l).gateind,:)= bsxfun(@plus,hidden(l).ins.v(network.hidden(l).gateind,:),network.hidden(l).biases.v(network.hidden(l).gateind,:)); 396 | hidden(l).ins.v(network.hidden(l).gateind,:) = sigmoid(hidden(l).ins.v(network.hidden(l).gateind,:)); 397 | 398 | 399 | 400 | hidden(l).ins.state=hidden(l).ins.v(network.hidden(l).hidind,:).*hidden(l).ins.v(network.hidden(l).writeind,:); 401 | 402 | ttemp = t-1; 403 | 404 | 405 | hidden(l).ins.state = hidden(l).ins.state + hidden(l).ins.statep.*hidden(l).ins.v(network.hidden(l).keepind,:); 406 | 407 | 408 | 409 | 410 | hidden(l).outs.v = hidden(l).ins.state.*hidden(l).ins.v(network.hidden(l).readind,:); 411 | hidden(l).outs.v=bsxfun(@plus,hidden(l).outs.v,network.hidden(l).biases.v(network.hidden(l).hidind,:)); 412 | hidden(l).outs.v = tanh(hidden(l).outs.v); 413 | 414 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*hidden(l).outs.v; 415 | 416 | hidden(l).outs.vp = hidden(l).outs.v; 417 | hidden(l).ins.statep = hidden(l).ins.state; 418 | 419 | 420 | 421 | 422 | end 423 | end 424 | 425 | network.output.outs.v = network.output.fx(network.output.outs.v); 426 | 427 | 428 | 429 | end 430 | 431 | function network = computegradient(network, targets,omat) 432 | oind = find(omat); 433 | 434 | network.output.outs.j(oind) = network.output.outs.v(oind)- targets(oind); 435 | 436 | network.output.outs.j = network.output.outs.j.*network.output.dx(network.output.outs.v); 437 | 438 | 439 | 440 | 441 | for l=1:length(network.hidden) 442 | hidden(l).ins.statej = gpuArray(zeros(network.nhidden(l),size(targets,2),'single')); 443 | hidden(l).outs.j = gpuArray(zeros(network.nhidden(l),size(targets,2),'single')); 444 | 
hidden(l).ins.j = gpuArray(zeros(network.nhidden(l)*4,size(network.input.outs.v,2),'single')); 445 | end 446 | for t=size(network.input.outs.v,3):-1:1; 447 | 448 | 449 | 450 | for l= length(network.hidden):-1:1 451 | 452 | network.output.weights(l).gradient = network.output.weights(l).gradient + network.output.outs.j(:,:,t)*network.hidden(l).outs.v(:,:,t)'; 453 | 454 | hidden(l).outs.j = hidden(l).outs.j + network.output.weights(l).matrix'*network.output.outs.j(:,:,t) ; 455 | 456 | 457 | 458 | 459 | hidden(l).outs.j = hidden(l).outs.j.*tanhdir(network.hidden(l).outs.v(:,:,t)); 460 | network.hidden(l).biases.j(network.hidden(l).hidind,:) = network.hidden(l).biases.j(network.hidden(l).hidind,:) + sum(hidden(l).outs.j,2); 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | hidden(l).ins.j(network.hidden(l).readind,:) = hidden(l).outs.j.* network.hidden(l).ins.state(:,:,t) ; 469 | 470 | hidden(l).ins.statej = hidden(l).ins.statej + hidden(l).outs.j.*network.hidden(l).ins.v(network.hidden(l).readind,:,t); 471 | 472 | ttemp = t-1; 473 | if ttemp>0 474 | 475 | hidden(l).ins.statejp = hidden(l).ins.statej.*network.hidden(l).ins.v(network.hidden(l).keepind,:,t); 476 | 477 | 478 | 479 | 480 | end 481 | 482 | 483 | 484 | 485 | if ttemp>0 486 | 487 | 488 | hidden(l).ins.j(network.hidden(l).keepind,:) = network.hidden(l).ins.state(:,:,t-1).*hidden(l).ins.statej; 489 | else 490 | 491 | hidden(l).ins.j(network.hidden(l).keepind,:)= network.hidden(l).ins.statep0.*hidden(l).ins.statej; 492 | 493 | end 494 | 495 | 496 | hidden(l).ins.j(network.hidden(l).writeind,:) = network.hidden(l).ins.v(network.hidden(l).hidind,:,t).*hidden(l).ins.statej; 497 | hidden(l).ins.j(network.hidden(l).hidind,:) = network.hidden(l).ins.v(network.hidden(l).writeind,:,t).*hidden(l).ins.statej; 498 | 499 | hidden(l).ins.j(network.hidden(l).gateind,:)= hidden(l).ins.j(network.hidden(l).gateind,:).*sigdir(network.hidden(l).ins.v(network.hidden(l).gateind,:,t)); 500 | 
network.hidden(l).biases.j(network.hidden(l).gateind,:) = network.hidden(l).biases.j(network.hidden(l).gateind,:) + sum(hidden(l).ins.j(network.hidden(l).gateind,:),2); 501 | 502 | 503 | 504 | if t-1>0 505 | hidden(l).outs.jp = network.hidden(l).weights.matrix'*hidden(l).ins.j; 506 | 507 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + hidden(l).ins.j*network.hidden(l).outs.v(:,:,t-1)'; 508 | 509 | 510 | 511 | else 512 | 513 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + hidden(l).ins.j*network.hidden(l).outs.vp0'; 514 | 515 | 516 | end 517 | 518 | if l>1 519 | hidden(l-1).outs.j = hidden(l-1).outs.j + network.hidden(l).hweights.matrix'*hidden(l).ins.j; 520 | network.hidden(l).hweights.gradient = network.hidden(l).hweights.gradient + hidden(l).ins.j*network.hidden(l-1).outs.v(:,:,t)'; 521 | end 522 | 523 | network.hidden(l).iweights.gradient = network.hidden(l).iweights.gradient + (hidden(l).ins.j)*network.input.outs.v(:,:,t)'; 524 | 525 | 526 | 527 | 528 | 529 | 530 | 531 | hidden(l).outs.j = hidden(l).outs.jp; 532 | hidden(l).ins.statej = hidden(l).ins.statejp; 533 | end 534 | 535 | 536 | 537 | end 538 | 539 | end 540 | 541 | 542 | function [lcost] = evalCrossEntropy(output,targets,omat) 543 | 544 | 545 | 546 | oind = find(omat); 547 | 548 | ldiff = targets.*log2(output); 549 | 550 | 551 | lcost = -1*sum(ldiff(:)); 552 | 553 | 554 | 555 | 556 | end 557 | 558 | function network = updateV(network, dW) 559 | 560 | ninput = network.input.n; 561 | noutput = network.output.n; 562 | 563 | 564 | start = 1; 565 | last = 0; 566 | 567 | for l=1:length(network.hidden) 568 | nhidden = network.nhidden(l); 569 | 570 | last = last + numel(network.hidden(l).iweights.matrix); 571 | network.hidden(l).iweights.matrix = reshape(dW(start:last),4*nhidden,ninput)+ network.hidden(l).iweights.matrix ; 572 | start = last + 1; 573 | 574 | last = last + numel(network.hidden(l).biases.v); 575 | network.hidden(l).biases.v = 
reshape(dW(start:last),4*nhidden,1)+ network.hidden(l).biases.v ; 576 | start = last + 1; 577 | 578 | for i=1:length(network.hidden(l).weights); 579 | last = last + numel(network.hidden(l).weights(i).matrix); 580 | network.hidden(l).weights(i).matrix = reshape(dW(start:last),4*nhidden,nhidden)+network.hidden(l).weights(i).matrix; 581 | start = last+1; 582 | if l>1 583 | last = last + numel(network.hidden(l).hweights(i).matrix); 584 | network.hidden(l).hweights(i).matrix = reshape(dW(start:last),4*nhidden,network.nhidden(l-1))+network.hidden(l).hweights(i).matrix; 585 | start = last+1; 586 | 587 | end 588 | 589 | 590 | 591 | end 592 | 593 | 594 | 595 | 596 | 597 | 598 | last = last+ numel(network.output.weights(l).matrix); 599 | network.output.weights(l).matrix = reshape(dW(start:last),noutput,nhidden)+ network.output.weights(l).matrix ; 600 | start=last+1; 601 | 602 | end 603 | 604 | end 605 | 606 | function vect=weights2vect(allvects) 607 | lsum = 0; 608 | lengths = cell(length(allvects),1); 609 | for i=1:length(allvects) 610 | lsum = lsum + numel(allvects{i}); 611 | lengths{i}= lsum; 612 | 613 | 614 | end 615 | vect = zeros(lsum,1,'single'); 616 | 617 | vect(1:lengths{1}) = gather(reshape(allvects{1},lengths{1},1)); 618 | for i=2:length(allvects) 619 | vect(lengths{i-1}+1:lengths{i}) = gather(reshape(allvects{i},lengths{i}-lengths{i-1},1)); 620 | end 621 | 622 | 623 | end 624 | 625 | 626 | 627 | function network = initpass(network,nbatch,maxt) 628 | 629 | ninput = network.input.n; 630 | 631 | noutput = network.output.n; 632 | 633 | 634 | for l=1:length(network.hidden) 635 | 636 | nhidden = network.nhidden(l); 637 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 638 | if ~network.last 639 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 640 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 641 | else 642 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 643 | 644 | 
network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 645 | end 646 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 647 | for i=1:length(network.hidden(l).weights); 648 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 649 | 650 | 651 | end 652 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 653 | if l>1 654 | network.hidden(l).hweights(i).gradient = gpuArray(zeros(nhidden*4,network.nhidden(l-1),'single')); 655 | 656 | end 657 | 658 | 659 | 660 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 661 | network.hidden(l).ins.v = gpuArray(zeros(4*nhidden,nbatch,maxt,'single')); 662 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 663 | 664 | 665 | 666 | end 667 | 668 | 669 | 670 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 671 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 672 | 673 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 674 | 675 | 676 | end 677 | 678 | function network = initpasstest(network,nbatch,maxt) 679 | 680 | ninput = network.input.n; 681 | 682 | noutput = network.output.n; 683 | 684 | 685 | for l=1:length(network.hidden) 686 | 687 | nhidden = network.nhidden(l); 688 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 689 | if ~network.last 690 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 691 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 692 | else 693 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 694 | 695 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 696 | end 697 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 698 | for i=1:length(network.hidden(l).weights); 699 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 700 | 701 | 702 | end 703 | 
network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 704 | if l>1 705 | network.hidden(l).hweights(i).gradient = gpuArray(zeros(nhidden*4,network.nhidden(l-1),'single')); 706 | 707 | end 708 | 709 | 710 | 711 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,length(network.storeind),'single')); 712 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,length(network.storeind),'single')); 713 | 714 | 715 | 716 | end 717 | 718 | 719 | 720 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 721 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 722 | 723 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 724 | 725 | 726 | end 727 | function network = initnetwork(ninput,nhidden,noutput) 728 | 729 | 730 | 731 | 732 | network.input.n = ninput; 733 | network.nhidden = nhidden; 734 | network.output.n = noutput; 735 | 736 | 737 | 738 | 739 | 740 | 741 | 742 | 743 | 744 | for j = 1:(length(network.nhidden)) 745 | 746 | nhidden = network.nhidden(j); 747 | network.hidden(j).hidind = (1:nhidden)'; 748 | network.hidden(j).writeind = (nhidden+1:2*nhidden)'; 749 | network.hidden(j).keepind = (2*nhidden+1:3*nhidden)'; 750 | network.hidden(j).readind = (3*nhidden+1:4*nhidden)'; 751 | network.hidden(j).gateind = (nhidden+1:4*nhidden)'; 752 | 753 | network.hidden(j).iweights.matrix = gpuArray(.1*(randn(nhidden*4,ninput,'single'))); 754 | network.hidden(j).biases.v = gpuArray(zeros(4*nhidden,1,'single')); 755 | network.hidden(j).biases.v(network.hidden(j).keepind)=3; 756 | network.hidden(j).biases.v(network.hidden(j).readind)=-2; 757 | network.hidden(j).biases.v(network.hidden(j).writeind)=0; 758 | network.hidden(j).iweights.gated = 0; 759 | 760 | 761 | 762 | network.hidden(j).weights.matrix =gpuArray(.0001*(randn(nhidden*4,nhidden,'single'))); 763 | 764 | 765 | 766 | 767 | if j>1 768 | network.hidden(j).hweights.matrix 
=gpuArray(.01*(randn(nhidden*4,network.nhidden(j-1),'single'))); 769 | end 770 | 771 | 772 | network.hidden(j).fx = @sigmoid; 773 | network.hidden(j).dx = @sigdir; 774 | 775 | 776 | network.output.weights(j).matrix = gpuArray(.1*(randn(noutput,nhidden,'single'))); 777 | 778 | 779 | 780 | end 781 | 782 | 783 | network.nparam = length(weights2vect(getW(network))); 784 | 785 | 786 | 787 | network.output.fx = @softmax; 788 | network.output.dx = @softdir; 789 | network.errorFunc = @evalCrossEntropy; 790 | network.output.getHessian = @CrossEntropyHessian; 791 | 792 | 793 | 794 | 795 | 796 | end 797 | 798 | 799 | function J = getJ(network) 800 | jtot=1; 801 | J = cell(jtot,1); 802 | c=1; 803 | for l=1:length(network.hidden) 804 | J{c}= network.hidden(l).iweights.gradient; 805 | c=c+1; 806 | network.hidden(l).biases.j(network.hidden(l).hidind)=0; 807 | J{c}=network.hidden(l).biases.j; 808 | c=c+1; 809 | 810 | for i = 1:length(network.hidden(l).weights); 811 | J{c}=network.hidden(l).weights(i).gradient; 812 | c=c+1; 813 | if l>1 814 | J{c}=network.hidden(l).hweights(i).gradient; 815 | c=c+1; 816 | end 817 | 818 | 819 | end 820 | 821 | J{c} = network.output.weights(l).gradient; 822 | c=c+1; 823 | end 824 | 825 | 826 | 827 | end 828 | function W = getW(network) 829 | jtot=1; 830 | W = cell(jtot,1); 831 | c=1; 832 | for l=1:length(network.hidden) 833 | 834 | W{c}= network.hidden(l).iweights.matrix; 835 | c=c+1; 836 | 837 | W{c}= network.hidden(l).biases.v; 838 | c=c+1; 839 | 840 | 841 | for i = 1:length(network.hidden(l).weights); 842 | W{c}=network.hidden(l).weights(i).matrix; 843 | c=c+1; 844 | if l>1 845 | W{c}=network.hidden(l).hweights(i).matrix; 846 | c=c+1; 847 | end 848 | 849 | end 850 | 851 | W{c} = network.output.weights(l).matrix; 852 | c=c+1; 853 | end 854 | 855 | 856 | 857 | end 858 | 859 | 860 | 861 | 862 | 863 | function f= sigmoid(x) 864 | 865 | 866 | f= 1./(1+ exp(-1.*x)); 867 | end 868 | 869 | function o = softdir(x); 870 | 871 | o=ones(size(x),'single'); 872 
| 873 | 874 | end 875 | function o = softmax(x) 876 | 877 | o=bsxfun(@times,1./sum(exp(x),1),exp(x)); 878 | end 879 | function dir = sigdir( y ) 880 | 881 | dir = y.*(1-y); 882 | 883 | 884 | end 885 | function dir = tanhdir( y ) 886 | 887 | dir = (1-y.*y); 888 | 889 | 890 | end 891 | 892 | 893 | 894 | 895 | %function m=gather(m) 896 | 897 | %end 898 | %function m=gpuArray(m) 899 | 900 | %end -------------------------------------------------------------------------------- /matlab/mLSTMhutter.m: -------------------------------------------------------------------------------- 1 | function mLSTMhutter 2 | %# 3 | 4 | %to run on CPU, comment out gpuDevice command and uncomment function gather 5 | %and function gpuArray and end of file 6 | gpuDevice(1) 7 | logfname = 'mLSTMhutter.txt'; 8 | weightsfname = 'mLSTMhutter.mat'; 9 | 10 | 11 | 12 | sequence = processtextfile('enwik8'); 13 | 14 | 15 | valstart = 90*10^6; 16 | nunits = max(sequence); 17 | valind = valstart:(valstart+(5*10^5)-1); 18 | [vinputs,vtargets] = byte2input(sequence(valind),sequence(valind+1),nunits,1000); 19 | 20 | 21 | 22 | voind = single(ones(size(vtargets))); 23 | 24 | 25 | 26 | 27 | 28 | 29 | nhid = 1900; 30 | network = initnetwork(nunits,nhid,nunits); 31 | 32 | 33 | network.last = 0; 34 | 35 | 36 | disp(network.nparam) 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | log = fopen(logfname,'w'); 45 | 46 | 47 | [verr] = test(network,vinputs,vtargets,voind); 48 | 49 | 50 | 51 | 52 | 53 | %hyperparameters 54 | dec = .9; 55 | threshdecay = .9999; 56 | thresh = 10; 57 | 58 | numepochs=4; 59 | 60 | berr = verr; 61 | 62 | 63 | 64 | 65 | 66 | MS = zeros(network.nparam,1,'single'); 67 | 68 | 69 | 70 | megabatch = 1000000; 71 | minibatch = 10000; 72 | jit = megabatch/minibatch; 73 | 74 | 75 | tic 76 | start = 1; 77 | fin = start+megabatch-1; 78 | 79 | 80 | %set training set size 81 | maxtrain = 95*10^6; 82 | 83 | 84 | 85 | neph=0; 86 | for i=1:999999999999 87 | 88 | 89 | 90 | if i>1 91 | start = start+megabatch; 92 | fin 
= fin+megabatch; 93 | if fin > maxtrain 94 | neph = neph+1; 95 | start = 1; 96 | fin = start+megabatch-1; 97 | end 98 | end 99 | 100 | 101 | if neph==numepochs 102 | return 103 | end 104 | seqlen = 100; 105 | 106 | 107 | 108 | [in1,ta1] = byte2input(single(sequence(start:fin)),single(sequence(start+1:fin+1)),nunits,seqlen); 109 | network.last = 0; 110 | disp(jit) 111 | jit = 100; 112 | 113 | 114 | for j=1:jit 115 | 116 | 117 | thresh = threshdecay*thresh; 118 | if thresh < .01 119 | thresh = .01; 120 | end 121 | 122 | 123 | 124 | 125 | in= in1(:,j:jit:size(in1,2),:); 126 | ta= ta1(:,j:jit:size(ta1,2),:); 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | gradient = getGradient(network,in,ta,1)/minibatch; 135 | 136 | if 1-1/j < dec && i==1 137 | dec1 = 1-1/j; 138 | 139 | else 140 | dec1 = dec; 141 | end 142 | 143 | 144 | MS = dec1*MS + (1-dec1)*gradient.^2; 145 | 146 | dW = (gradient./(sqrt(MS)+.000001)); 147 | dW(isnan(dW)) = 0; 148 | norm = sqrt(dW'*dW); 149 | 150 | dW = dW*(thresh/norm); 151 | 152 | dW = -1*dW; 153 | 154 | 155 | 156 | network = updateV(network,dW); 157 | network = initpass(network,size(in,2),size(in,3)); 158 | network = ForwardPass(network, in); 159 | network.last = 1; 160 | toc 161 | 162 | 163 | 164 | 165 | 166 | if j==jit; 167 | network.last = 0; 168 | randind = floor(rand()*(maxtrain-(5*10^5))); 169 | randind = randind:(randind+5*10^5 -1); 170 | [intrain,targtrain] = byte2input(sequence(randind),sequence(randind+1),nunits,1000); 171 | 172 | 173 | [err] = test(network,intrain,targtrain,ones(size(targtrain),'single')); 174 | [verr]= test(network,vinputs,vtargets,voind); 175 | 176 | 177 | 'err' 178 | display(err); 179 | display(verr); 180 | derr = (berr - verr)/(berr); 181 | 182 | 183 | bitchar = err/(size(targtrain,3)*size(targtrain,2)); 184 | 185 | vbitchar = verr/(size(vtargets,3)*size(vtargets,2)); 186 | 187 | fprintf(log,'iter: %i train %f val %f \n',i,bitchar,vbitchar); 188 | disp(derr) 189 | 190 | berr = verr; 191 | W = 
weights2vect(getW(network)); 192 | save(weightsfname,'W'); 193 | 194 | break; 195 | end 196 | end 197 | 198 | 199 | end 200 | 201 | 202 | 203 | end 204 | 205 | 206 | 207 | function [inputs,targets] =byte2input(inputs,targets,nunits,seqlen) 208 | 209 | 210 | inputs= single(inputs); 211 | targets = single(targets); 212 | 213 | in = zeros(nunits,1,length(inputs),'single'); 214 | targ = zeros(nunits,1,length(targets),'single'); 215 | 216 | ind = sub2ind([nunits,1,length(inputs)],inputs,ones(length(inputs),1),(1:length(inputs))'); 217 | tind = sub2ind([nunits,1,length(inputs)],targets,ones(length(inputs),1),(1:length(inputs))'); 218 | in(ind)=1; 219 | targ(tind) = 1; 220 | 221 | inputs=permute(reshape(in,[size(in,1),seqlen,size(in,3)/seqlen]),[1,3,2]); 222 | targets=permute(reshape(targ,[size(targ,1),seqlen,size(targ,3)/seqlen]),[1,3,2]); 223 | 224 | 225 | 226 | 227 | end 228 | 229 | function gradient = getGradient(network,inputs,targets,npar) 230 | 231 | 232 | pbatch = size(inputs,2)/npar; 233 | citer = 1:pbatch:size(inputs,2); 234 | oind = ones(size(targets)); 235 | in = cell(npar,1);ta = cell(npar,1);oi = cell(npar,1); 236 | for ci=1:length(citer) 237 | c = citer(ci); 238 | in{ci} = inputs(:,c:c+pbatch-1,:); 239 | ta{ci} = targets(:,c:c+pbatch-1,:); 240 | oi{ci} = oind(:,c:c+pbatch-1,:); 241 | end 242 | 243 | gradient = zeros(network.nparam,1); 244 | for z=1:npar 245 | net = initpass(network,size(in{z},2),size(in{z},3)); 246 | net = ForwardPass(net,in{z}); 247 | net = computegradient(net,ta{z},ones(size(ta{z}))); 248 | 249 | 250 | gradient = gradient + weights2vect(getJ(net)); 251 | end 252 | 253 | 254 | 255 | 256 | end 257 | 258 | 259 | 260 | function [err] = test(network,inputs,targets,oind) 261 | errsum = 0; 262 | errcount = 0; 263 | 264 | nbatch = size(inputs,2); 265 | 266 | 267 | input = inputs; 268 | network = initpasstest(network,nbatch,size(input,3)); 269 | network = ForwardPasstest(network,input); 270 | 
[terr]=network.errorFunc(network.output.outs.v,targets,oind); 271 | errsum = errsum + terr; 272 | 273 | errcount = errcount+1; 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | err = errsum/errcount; 283 | 284 | 285 | 286 | 287 | 288 | 289 | end 290 | 291 | 292 | function network = ForwardPass(network, inputs) 293 | 294 | inputs = gpuArray(inputs); 295 | network.input.outs.v = inputs; 296 | 297 | 298 | for t=1:size(inputs,3); 299 | l=1; 300 | 301 | 302 | 303 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 304 | 305 | if t-1>0 306 | 307 | 308 | network.hidden(l).intermediates.v(:,:,t) = network.hidden(l).mweights.matrix*network.hidden(l).outs.v(:,:,t-1); 309 | 310 | else 311 | 312 | network.hidden(l).intermediates.v(:,:,t) = network.hidden(l).mweights.matrix*network.hidden(l).outs.vp0; 313 | end 314 | 315 | network.hidden(l).factor.v(:,:,t) = network.hidden(l).fweights.matrix*inputs(:,:,t); 316 | network.hidden(l).mult.v(:,:,t) = network.hidden(l).factor.v(:,:,t).*network.hidden(l).intermediates.v(:,:,t); 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).mult.v(:,:,t); 325 | network.hidden(l).ins.v(network.gateind,:,t)= bsxfun(@plus,network.hidden(l).ins.v(network.gateind,:,t),network.hidden(l).biases.v(network.gateind,:)); 326 | network.hidden(l).ins.v(network.gateind,:,t) = sigmoid(network.hidden(l).ins.v(network.gateind,:,t)); 327 | 328 | 329 | 330 | 331 | 332 | network.hidden(l).ins.state(:,:,t)= network.hidden(l).ins.v(network.hidind,:,t).*network.hidden(l).ins.v(network.writeind,:,t); 333 | 334 | 335 | 336 | if t-1>0 337 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t)+network.hidden(l).ins.state(:,:,t-1).*network.hidden(l).ins.v(network.keepind,:,t); 338 | else 339 | network.hidden(l).ins.state(:,:,t) 
=network.hidden(l).ins.state(:,:,t)+ network.hidden(l).ins.statep0.*network.hidden(l).ins.v(network.keepind,:,t);; 340 | end 341 | 342 | 343 | network.hidden(l).outs.v(:,:,t) = network.hidden(l).ins.state(:,:,t).*network.hidden(l).ins.v(network.readind,:,t);; 344 | network.hidden(l).outs.v(:,:,t)=bsxfun(@plus,network.hidden(l).outs.v(:,:,t),network.hidden(l).biases.v(network.hidind,:)); 345 | network.hidden(l).outs.v(:,:,t) = tanh(network.hidden(l).outs.v(:,:,t)); 346 | 347 | 348 | 349 | 350 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*network.hidden(l).outs.v(:,:,t); 351 | 352 | 353 | if t==size(inputs,3) 354 | network.hidden(l).outs.last = network.hidden(l).outs.v(:,:,t); 355 | network.hidden(l).ins.last = network.hidden(l).ins.state(:,:,t); 356 | end 357 | 358 | 359 | end 360 | 361 | network.output.outs.v = network.output.fx(network.output.outs.v); 362 | 363 | 364 | 365 | end 366 | function network = ForwardPasstest(network, inputs) 367 | 368 | inputs = gpuArray(inputs); 369 | network.input.outs.v = inputs; 370 | 371 | 372 | 373 | hidden.outs.vp = gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 374 | hidden.ins.statep = gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 375 | hidden.mult.v=gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 376 | 377 | for l=1:length(network.hidden) 378 | hidden(l).outs.vp = network.hidden(l).outs.vp0 ; 379 | hidden(l).ins.statep = network.hidden(l).ins.statep0; 380 | end 381 | s=1; 382 | for t=1:size(inputs,3); 383 | l=1; 384 | 385 | 386 | 387 | 388 | hidden.ins.v= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | hidden.intermediates.v = network.hidden(l).mweights.matrix*hidden(l).outs.vp; 401 | hidden.factor.v = network.hidden(l).fweights.matrix*inputs(:,:,t); 402 | hidden.mult.v = hidden(l).factor.v.*hidden(l).intermediates.v; 403 | 404 | 405 | hidden.ins.v = 
hidden(l).ins.v + network.hidden.weights.matrix*hidden(l).mult.v; 406 | hidden(l).ins.v(network.gateind,:)= bsxfun(@plus,hidden(l).ins.v(network.gateind,:),network.hidden(l).biases.v(network.gateind,:)); 407 | hidden.ins.v(network.gateind,:) = sigmoid(hidden.ins.v(network.gateind,:)); 408 | 409 | 410 | 411 | hidden.ins.state=hidden.ins.v(network.hidind,:).*hidden.ins.v(network.writeind,:); 412 | 413 | 414 | 415 | 416 | hidden.ins.state = hidden.ins.state + hidden(l).ins.statep.*hidden.ins.v(network.keepind,:); 417 | 418 | 419 | 420 | 421 | hidden.outs.v = hidden(l).ins.state.*hidden.ins.v(network.readind,:); 422 | hidden(l).outs.v=bsxfun(@plus,hidden(l).outs.v,network.hidden(l).biases.v(network.hidind,:)); 423 | hidden.outs.v = tanh(hidden(l).outs.v); 424 | 425 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*hidden(l).outs.v; 426 | 427 | hidden.outs.vp = hidden(l).outs.v; 428 | hidden.ins.statep = hidden(l).ins.state; 429 | 430 | 431 | 432 | end 433 | 434 | 435 | network.output.outs.v = network.output.fx(network.output.outs.v); 436 | 437 | 438 | 439 | end 440 | 441 | 442 | 443 | 444 | function network = computegradient(network, targets,omat) 445 | oind = find(omat); 446 | 447 | network.output.outs.j(oind) = network.output.outs.v(oind)- targets(oind); 448 | 449 | network.output.outs.j = network.output.outs.j.*network.output.dx(network.output.outs.v); 450 | 451 | 452 | 453 | 454 | hidden.ins.statej = gpuArray(zeros(network.nhidden,size(targets,2),'single')); 455 | hidden.outs.j = gpuArray(zeros(network.nhidden,size(targets,2),'single')); 456 | hidden.ins.j = gpuArray(zeros(network.nhidden*4,size(network.input.outs.v,2),'single')); 457 | for t=size(network.input.outs.v,3):-1:1; 458 | 459 | 460 | 461 | l=1; 462 | 463 | network.output.weights(l).gradient = network.output.weights(l).gradient + network.output.outs.j(:,:,t)*network.hidden(l).outs.v(:,:,t)'; 464 | 465 | hidden.outs.j = hidden.outs.j + 
network.output.weights(l).matrix'*network.output.outs.j(:,:,t) ; 466 | 467 | 468 | 469 | 470 | 471 | 472 | hidden(l).outs.j = hidden(l).outs.j.*tanhdir(network.hidden(l).outs.v(:,:,t)); 473 | network.hidden(l).biases.j(network.hidind,:) = network.hidden(l).biases.j(network.hidind,:) + sum(hidden(l).outs.j,2);%biases only have 1 dimension 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | hidden(l).ins.j(network.readind,:) = hidden(l).outs.j.* network.hidden(l).ins.state(:,:,t) ; 482 | 483 | hidden(l).ins.statej = hidden(l).ins.statej + hidden(l).outs.j.*network.hidden(l).ins.v(network.readind,:,t); 484 | 485 | 486 | 487 | ttemp = t-1; 488 | if ttemp>0 489 | 490 | hidden(l).ins.statejp = hidden(l).ins.statej.*network.hidden(l).ins.v(network.keepind,:,t); 491 | % network.hidden(l).outs.j(:,:,t) = network.hidden(l).outs.j(:,:,t)+ network.hidden(l).weights(i).matrix'*network.hidden(l).outs.j(:,:,ttemp); 492 | 493 | end 494 | 495 | 496 | 497 | 498 | if ttemp>0 499 | 500 | 501 | 502 | hidden(l).ins.j(network.keepind,:) = network.hidden(l).ins.state(:,:,t-1).*hidden(l).ins.statej; 503 | else 504 | 505 | 506 | hidden(l).ins.j(network.keepind,:)= network.hidden(l).ins.statep0.*hidden(l).ins.statej; 507 | end 508 | 509 | 510 | hidden(l).ins.j(network.writeind,:) = network.hidden(l).ins.v(network.hidind,:,t).*hidden(l).ins.statej; 511 | hidden(l).ins.j(network.hidind,:) = network.hidden(l).ins.v(network.writeind,:,t).*hidden(l).ins.statej; 512 | 513 | hidden(l).ins.j(network.gateind,:)= hidden(l).ins.j(network.gateind,:).*sigdir(network.hidden(l).ins.v(network.gateind,:,t)); 514 | network.hidden(l).biases.j(network.gateind,:) = network.hidden(l).biases.j(network.gateind,:) + sum(hidden(l).ins.j(network.gateind,:),2); 515 | 516 | 517 | hidden(l).mult.j = network.hidden(l).weights.matrix'*hidden(l).ins.j; 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + 
(hidden(l).ins.j)*network.hidden(l).mult.v(:,:,t)'; 527 | 528 | 529 | 530 | 531 | if t-1>0 532 | 533 | hidden(l).intermediates.j = hidden(l).mult.j.*network.hidden(l).factor.v(:,:,t); 534 | 535 | hidden(l).factor.j= hidden(l).mult.j.*network.hidden(l).intermediates.v(:,:,t); 536 | 537 | network.hidden(l).fweights.gradient = network.hidden(l).fweights.gradient + hidden(l).factor.j*network.input.outs.v(:,:,t)'; 538 | 539 | hidden(l).outs.jp = network.hidden(l).mweights.matrix'*hidden(l).intermediates.j; 540 | network.hidden(l).mweights.gradient = network.hidden(l).mweights.gradient + hidden(l).intermediates.j*network.hidden(l).outs.v(:,:,t-1)'; 541 | else 542 | hidden(l).intermediates.j = hidden(l).mult.j.*network.hidden(l).factor.v(:,:,t); 543 | 544 | hidden(l).factor.j= hidden(l).mult.j.*network.hidden(l).intermediates.v(:,:,t); 545 | 546 | network.hidden(l).fweights.gradient = network.hidden(l).fweights.gradient + hidden(l).factor.j*network.input.outs.v(:,:,t)'; 547 | 548 | hidden(l).outs.jp = network.hidden(l).mweights.matrix'*hidden(l).intermediates.j; 549 | network.hidden(l).mweights.gradient = network.hidden(l).mweights.gradient + hidden(l).intermediates.j*network.hidden(l).outs.vp0'; 550 | 551 | 552 | 553 | 554 | end 555 | 556 | 557 | network.hidden(l).iweights.gradient = network.hidden(l).iweights.gradient + (hidden(l).ins.j)*network.input.outs.v(:,:,t)'; 558 | 559 | 560 | 561 | 562 | 563 | 564 | 565 | hidden(l).outs.j = hidden(l).outs.jp; 566 | hidden(l).ins.statej = hidden(l).ins.statejp; 567 | 568 | 569 | 570 | end 571 | 572 | end 573 | 574 | 575 | function [lcost] = evalCrossEntropy(output,targets,omat) 576 | 577 | 578 | 579 | oind = find(omat); 580 | 581 | ldiff = targets.*log2(output); 582 | 583 | lcost = -1*sum(ldiff(:)); 584 | 585 | 586 | 587 | 588 | end 589 | 590 | function network = updateV(network, dW) 591 | 592 | ninput = network.input.n; 593 | noutput = network.output.n; 594 | 595 | 596 | start = 1; 597 | last = 0; 598 | 599 | for 
l=1:length(network.hidden) 600 | nhidden = network.nhidden(l); 601 | 602 | last = last + numel(network.hidden(l).iweights.matrix); 603 | network.hidden(l).iweights.matrix = reshape(dW(start:last),4*nhidden,ninput)+ network.hidden(l).iweights.matrix ; 604 | start = last + 1; 605 | 606 | last = last + numel(network.hidden(l).biases.v); 607 | network.hidden(l).biases.v = reshape(dW(start:last),4*nhidden,1)+ network.hidden(l).biases.v ; 608 | start = last + 1; 609 | 610 | 611 | last = last + numel(network.hidden(l).fweights.matrix); 612 | network.hidden(l).fweights.matrix = reshape(dW(start:last),nhidden,ninput)+ network.hidden(l).fweights.matrix ; 613 | start = last + 1; 614 | 615 | 616 | for i=1:length(network.hidden(l).weights); 617 | last = last + numel(network.hidden(l).weights(i).matrix); 618 | network.hidden(l).weights(i).matrix = reshape(dW(start:last),4*nhidden,nhidden)+network.hidden(l).weights(i).matrix; 619 | start = last+1; 620 | last = last + numel(network.hidden(l).mweights(i).matrix); 621 | network.hidden(l).mweights(i).matrix = reshape(dW(start:last),nhidden,nhidden)+network.hidden(l).mweights(i).matrix; 622 | start = last+1; 623 | 624 | 625 | end 626 | 627 | 628 | 629 | 630 | 631 | 632 | last = last+ numel(network.output.weights(l).matrix); 633 | network.output.weights(l).matrix = reshape(dW(start:last),noutput,nhidden)+ network.output.weights(l).matrix ; 634 | start=last+1; 635 | 636 | end 637 | 638 | end 639 | 640 | function vect=weights2vect(allvects) 641 | lsum = 0; 642 | lengths = cell(length(allvects),1); 643 | for i=1:length(allvects) 644 | lsum = lsum + numel(allvects{i}); 645 | lengths{i}= lsum; 646 | 647 | 648 | end 649 | vect = zeros(lsum,1,'single'); 650 | 651 | vect(1:lengths{1}) = gather(reshape(allvects{1},lengths{1},1)); 652 | for i=2:length(allvects) 653 | vect(lengths{i-1}+1:lengths{i}) = gather(reshape(allvects{i},lengths{i}-lengths{i-1},1)); 654 | end 655 | 656 | 657 | end 658 | 659 | 660 | 661 | 662 | function network = 
initpass(network,nbatch,maxt) 663 | 664 | ninput = network.input.n; 665 | 666 | noutput = network.output.n; 667 | 668 | for l=1:length(network.hidden) 669 | 670 | nhidden = network.nhidden(l); 671 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single')); 672 | 673 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 674 | if ~network.last 675 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 676 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 677 | else 678 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 679 | 680 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 681 | end 682 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 683 | 684 | 685 | network.hidden(l).ins.v = gpuArray(zeros(nhidden*4,nbatch,maxt,'single')); 686 | 687 | 688 | network.hidden(l).intermediates.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 689 | 690 | network.hidden(l).factor.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 691 | 692 | network.hidden(l).mult.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 693 | 694 | 695 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 696 | 697 | 698 | 699 | 700 | 701 | for i=1:length(network.hidden(l).weights); 702 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single')); 703 | network.hidden(l).mweights(i).gradient = gpuArray(zeros(nhidden,nhidden,'single')); 704 | 705 | end 706 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single')); 707 | network.hidden(l).fweights.gradient = gpuArray(zeros(nhidden,ninput,'single')); 708 | 709 | 710 | 711 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 712 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single')); 713 | 714 | 715 | 716 | end 717 | 718 | 719 | 720 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 721 | 
network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 722 | 723 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 724 | 725 | 726 | end 727 | 728 | function network = initpasstest(network,nbatch,maxt) 729 | 730 | ninput = network.input.n; 731 | 732 | noutput = network.output.n; 733 | 734 | for l=1:length(network.hidden) 735 | 736 | nhidden = network.nhidden(l); 737 | 738 | 739 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single')); 740 | if ~network.last 741 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single')); 742 | network.hidden(l).ins.statep0= gpuArray(zeros(nhidden,nbatch,'single')); 743 | else 744 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last ; 745 | 746 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last; 747 | end 748 | 749 | 750 | 751 | 752 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,'single')); 753 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,'single')); 754 | 755 | 756 | 757 | end 758 | 759 | 760 | 761 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single')); 762 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single')); 763 | 764 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single')); 765 | 766 | 767 | end 768 | 769 | function network = initnetwork(ninput,nhidden,noutput) 770 | 771 | 772 | network.input.n = ninput; 773 | network.nhidden = nhidden; 774 | network.output.n = noutput; 775 | 776 | 777 | 778 | 779 | network.hidind = (1:nhidden)'; 780 | network.writeind = (nhidden+1:2*nhidden)'; 781 | network.keepind = (2*nhidden+1:3*nhidden)'; 782 | network.readind = (3*nhidden+1:4*nhidden)'; 783 | network.gateind = (nhidden+1:4*nhidden)'; 784 | for j = 1:1; 785 | nhidden = network.nhidden(j); 786 | 787 | 788 | network.hidden(j).iweights.matrix = gpuArray(.1*(randn(nhidden*4,ninput,'single'))); 789 | network.hidden(j).fweights.matrix = gpuArray(.1*(randn(nhidden,ninput,'single'))); 790 | 791 | 
792 | 793 | network.hidden(j).biases.v = gpuArray(zeros(4*nhidden,1,'single')); 794 | network.hidden(j).biases.v(network.keepind)=3; 795 | 796 | 797 | network.hidden(j).weights.matrix =gpuArray(.02*(randn(nhidden*4,nhidden,'single'))); 798 | network.hidden(j).mweights.matrix =gpuArray(.02*(randn(nhidden,nhidden,'single'))); 799 | 800 | 801 | 802 | 803 | network.hidden(j).fx = @sigmoid; 804 | network.hidden(j).dx = @sigdir; 805 | network.output.weights(j).matrix = gpuArray(.1*(randn(noutput,nhidden,'single'))); 806 | 807 | 808 | 809 | end 810 | 811 | 812 | network.nparam = length(weights2vect(getW(network))); 813 | 814 | 815 | network.output.fx = @softmax; 816 | network.output.dx = @softdirXent; 817 | network.errorFunc = @evalCrossEntropy; 818 | 819 | end 820 | 821 | 822 | function J = getJ(network) 823 | 824 | J = cell(5,1); 825 | c=1; 826 | for l=1:length(network.hidden) 827 | J{c}= network.hidden(l).iweights.gradient; 828 | c=c+1; 829 | J{c}= network.hidden(l).biases.j; 830 | c=c+1; 831 | J{c}= network.hidden(l).fweights.gradient; 832 | c=c+1; 833 | 834 | 835 | for i = 1:length(network.hidden(l).weights); 836 | J{c}=network.hidden(l).weights(i).gradient; 837 | c=c+1; 838 | J{c}=network.hidden(l).mweights(i).gradient; 839 | c=c+1; 840 | 841 | 842 | end 843 | 844 | J{c} = 1*network.output.weights(l).gradient; 845 | c=c+1; 846 | end 847 | 848 | 849 | 850 | end 851 | function W = getW(network) 852 | 853 | W = cell(5,1); 854 | c=1; 855 | for l=1:length(network.hidden) 856 | 857 | W{c}= network.hidden(l).iweights.matrix; 858 | c=c+1; 859 | W{c}= network.hidden(l).biases.v; 860 | c=c+1; 861 | W{c}= network.hidden(l).fweights.matrix; 862 | c=c+1; 863 | 864 | 865 | 866 | for i = 1:length(network.hidden(l).weights); 867 | W{c}=network.hidden(l).weights(i).matrix; 868 | c=c+1; 869 | W{c}=network.hidden(l).mweights(i).matrix; 870 | c=c+1; 871 | 872 | end 873 | 874 | W{c} = network.output.weights(l).matrix; 875 | c=c+1; 876 | end 877 | 878 | 879 | 880 | end 881 | 882 | 883 | 
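The gated recurrence that both MATLAB files implement (a linear write candidate, sigmoid write/keep/read gates, and a tanh applied after the read gate and output bias) can be sketched in NumPy. All names below (`mlstm_step`, the parameter dict `p`) are illustrative, not part of this repository; the four stacked pre-activation blocks mirror the `hidind`/`writeind`/`keepind`/`readind` index ranges used in the MATLAB code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mlstm_step(x, h_prev, c_prev, p):
    """One mLSTM timestep; p holds iweights, fweights, mweights, weights, b."""
    n = h_prev.shape[0]
    # Multiplicative intermediate: input-dependent factor rescales the
    # recurrent projection (factor.v .* intermediates.v in the MATLAB code).
    m = (p["fweights"] @ x) * (p["mweights"] @ h_prev)
    # All four pre-activation blocks stacked, as in ins.v.
    ins = p["iweights"] @ x + p["weights"] @ m
    hid = ins[:n]                                    # linear write candidate (hidind)
    write = sigmoid(ins[n:2*n] + p["b"][n:2*n])      # write gate (writeind)
    keep = sigmoid(ins[2*n:3*n] + p["b"][2*n:3*n])   # keep gate (keepind)
    read = sigmoid(ins[3*n:4*n] + p["b"][3*n:4*n])   # read gate (readind)
    c = hid * write + c_prev * keep                  # cell state (ins.state)
    h = np.tanh(c * read + p["b"][:n])               # hidden output (outs.v)
    return h, c
```

Note that, unlike a standard LSTM, the hidden-to-hidden contribution enters only through the multiplicative intermediate `m`, so the effective recurrent weights change with each input character.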
884 | 885 | 886 | function f= sigmoid(x) 887 | 888 | 889 | f= 1./(1+ exp(-1.*x)); 890 | end 891 | 892 | function o = softdirXent(x); 893 | 894 | o=ones(size(x),'single'); 895 | 896 | 897 | end 898 | 899 | function dir = sigdir( y ) 900 | 901 | dir = y.*(1-y); 902 | 903 | 904 | end 905 | function dir = tanhdir( y ) 906 | 907 | dir = (1-y.*y); 908 | 909 | 910 | end 911 | function o = softmax(x) 912 | 913 | o=bsxfun(@times,1./sum(exp(x),1),exp(x)); 914 | end 915 | 916 | %function m=gather(m) 917 | 918 | %end 919 | %function m=gpuArray(m) 920 | 921 | %end -------------------------------------------------------------------------------- /matlab/mLSTMdynamic.m: -------------------------------------------------------------------------------- 1 | function mLSTMdynamic 2 | 3 | gpuDevice(1) 4 | 5 | sequence = processtextfile('enwik8'); 6 | 7 | weightsfname = 'mLSTMhutter.mat'; 8 | 9 | nunits = max(sequence); 10 | 11 | 12 | 13 | 14 | network = initnetwork(nunits,[1900],nunits); 15 | 16 | network.last = 0; 17 | % 18 | 19 | 20 | network=updateV(network,-1*weights2vect(getW(network))); 21 | 22 | wfile = load(weightsfname,'W'); 23 | W = wfile.W; 24 | 25 | 26 | network = updateV(network,W); 27 | 28 | disp(network.nparam) 29 | 30 | 31 | 32 | 33 | network.zero = 0; 34 | 35 | 36 | seqlen = 50; 37 | 38 | tic 39 | epsilon = .0002; 40 | start = 96000000; 41 | fin = start+seqlen-1; 42 | 43 | serr = 0; 44 | dec = .99; 45 | 46 | 47 | network.last = 0; 48 | network0 = network; 49 | alpha = .01; 50 | network = RMSinit(network); 51 | for i=1:(4*10^6)/seqlen 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | [in1,ta1] = byte2input(sequence(start:fin),sequence(start+1:fin+1),nunits,seqlen); 64 | start = start + seqlen; 65 | fin = fin + seqlen; 66 | network = initpass(network,size(in1,2),size(in1,3)); 67 | network.zero = 1; 68 | network = ForwardPass(network,in1); 69 | 70 | 71 | network.output.outs.v = .999999*network.output.outs.v +
.000001*(ones(size(network.output.outs.v)))/size(network.output.outs.v,1); 72 | 73 | terr = evalCrossEntropy(network.output.outs.v,ta1,ones(size(ta1),'single')); 74 | 75 | serr = serr+ (terr/seqlen); 76 | 77 | err = serr/i; 78 | 79 | network = computegradient(network,ta1,ones(size(ta1),'single')); 80 | 81 | 82 | 83 | if 1-1/i < dec 84 | dec1 = 1-1/i; 85 | 86 | else 87 | dec1 = dec; 88 | end 89 | 90 | network = RMSprop(network,network0,dec1,alpha,epsilon,seqlen); 91 | network.last=1; 92 | network = initpass(network,size(in1,2),size(in1,3)); 93 | network = ForwardPass(network, in1); 94 | 95 | 96 | if mod(i,100)==0 97 | disp(i) 98 | disp(err) 99 | end 100 | 101 | end 102 | disp('final error') 103 | disp(err) 104 | 105 | end 106 | 107 | function network = RMSinit(network) 108 | 109 | 110 | l=1; 111 | network.hidden(l).iweights.MS = gpuArray(zeros(size(network.hidden(l).iweights.matrix),'single')); 112 | network.hidden(l).weights.MS = gpuArray(zeros(size(network.hidden(l).weights.matrix),'single')); 113 | network.hidden(l).fweights.MS = gpuArray(zeros(size(network.hidden(l).fweights.matrix),'single')); 114 | network.hidden(l).mweights.MS = gpuArray(zeros(size(network.hidden(l).mweights.matrix),'single')); 115 | 116 | network.output.weights.MS = gpuArray(zeros(size(network.output.weights.matrix),'single')); 117 | end 118 | function network = RMSprop(network,network0,dec,alpha,epsilon,n) 119 | 120 | l=1; 121 | network.hidden(l).iweights.gradient=network.hidden.iweights.gradient/n; 122 | network.hidden(l).weights.gradient=network.hidden.weights.gradient/n; 123 | network.hidden(l).fweights.gradient=network.hidden.fweights.gradient/n; 124 | network.hidden(l).mweights.gradient=network.hidden.mweights.gradient/n; 125 | network.output.weights.gradient=network.output.weights.gradient/n; 126 | 127 | 128 | 129 | 130 | network.hidden(l).iweights.MS = network.hidden(l).iweights.MS*dec + network.hidden(l).iweights.gradient.^2*(1-dec); 131 | network.hidden(l).weights.MS = 
network.hidden(l).weights.MS*dec + network.hidden(l).weights.gradient.^2*(1-dec); 132 | network.hidden(l).mweights.MS = network.hidden(l).mweights.MS*dec + network.hidden(l).mweights.gradient.^2*(1-dec); 133 | network.hidden(l).fweights.MS = network.hidden(l).fweights.MS*dec + network.hidden(l).fweights.gradient.^2*(1-dec); 134 | 135 | 136 | network.output.weights.MS = network.output.weights.MS*dec + network.output.weights.gradient.^2*(1-dec); 137 | 138 | 139 | 140 | 141 | network.hidden(l).iweights.dW = -1*network.hidden(l).iweights.gradient./sqrt(network.hidden(l).iweights.MS); 142 | network.hidden(l).iweights.dW(isnan(network.hidden(l).iweights.dW))=0; 143 | 144 | network.hidden(l).weights.dW = -1*network.hidden(l).weights.gradient./sqrt(network.hidden(l).weights.MS); 145 | network.hidden(l).weights.dW(isnan(network.hidden(l).weights.dW))=0; 146 | 147 | 148 | network.hidden(l).mweights.dW = -1*network.hidden(l).mweights.gradient./sqrt(network.hidden(l).mweights.MS); 149 | network.hidden(l).mweights.dW(isnan(network.hidden(l).mweights.dW))=0; 150 | 151 | network.hidden(l).fweights.dW = -1*network.hidden(l).fweights.gradient./sqrt(network.hidden(l).fweights.MS); 152 | network.hidden(l).fweights.dW(isnan(network.hidden(l).fweights.dW))=0; 153 | 154 | 155 | 156 | 157 | network.output.weights.dW = -1*network.output.weights.gradient./sqrt(network.output.weights.MS); 158 | network.output.weights.dW(isnan(network.output.weights.dW))=0; 159 | 160 | 161 | 162 | 163 | network.hidden(l).iweights.dW = network.hidden(l).iweights.dW*epsilon; 164 | network.hidden(l).weights.dW = network.hidden(l).weights.dW*epsilon; 165 | network.hidden(l).fweights.dW = network.hidden(l).fweights.dW*epsilon; 166 | network.hidden(l).mweights.dW = network.hidden(l).mweights.dW*epsilon; 167 | network.output.weights.dW = network.output.weights.dW*epsilon; 168 | 169 | network.hidden(l).iweights.dW = 
network.hidden(l).iweights.dW+alpha*(network0.hidden(l).iweights.matrix-network.hidden(l).iweights.matrix); 170 | network.hidden(l).weights.dW = network.hidden(l).weights.dW+alpha*(network0.hidden(l).weights.matrix-network.hidden(l).weights.matrix); 171 | network.hidden(l).fweights.dW = network.hidden(l).fweights.dW+alpha*(network0.hidden(l).fweights.matrix-network.hidden(l).fweights.matrix); 172 | network.hidden(l).mweights.dW = network.hidden(l).mweights.dW+alpha*(network0.hidden(l).mweights.matrix-network.hidden(l).mweights.matrix); 173 | network.output.weights.dW = network.output.weights.dW+alpha*(network0.output.weights.matrix-network.output.weights.matrix); 174 | 175 | 176 | 177 | 178 | network.hidden(l).iweights.matrix = network.hidden(l).iweights.matrix + network.hidden(l).iweights.dW; 179 | network.hidden(l).weights.matrix = network.hidden(l).weights.matrix + network.hidden(l).weights.dW; 180 | network.hidden(l).fweights.matrix = network.hidden(l).fweights.matrix + network.hidden(l).fweights.dW; 181 | network.hidden(l).mweights.matrix = network.hidden(l).mweights.matrix + network.hidden(l).mweights.dW; 182 | network.output.weights.matrix = network.output.weights.matrix + network.output.weights.dW; 183 | 184 | end 185 | 186 | function [inputs,targets] =byte2input(inputs,targets,nunits,seqlen) 187 | 188 | 189 | inputs= single(inputs); 190 | targets = single(targets); 191 | 192 | in = zeros(nunits,1,length(inputs),'single'); 193 | targ = zeros(nunits,1,length(targets),'single'); 194 | 195 | ind = sub2ind([nunits,1,length(inputs)],inputs,ones(length(inputs),1),(1:length(inputs))'); 196 | tind = sub2ind([nunits,1,length(inputs)],targets,ones(length(inputs),1),(1:length(inputs))'); 197 | in(ind)=1; 198 | targ(tind) = 1; 199 | 200 | inputs=permute(reshape(in,[size(in,1),seqlen,size(in,3)/seqlen]),[1,3,2]); 201 | targets=permute(reshape(targ,[size(targ,1),seqlen,size(targ,3)/seqlen]),[1,3,2]); 202 | 203 | 204 | 205 | 206 | end 207 | 208 | function gradient = 
getGradient(network,inputs,targets,npar) 209 | 210 | 211 | pbatch = size(inputs,2)/npar; 212 | citer = 1:pbatch:size(inputs,2); 213 | oind = ones(size(targets)); 214 | in = cell(npar,1);ta = cell(npar,1);oi = cell(npar,1); 215 | for ci=1:length(citer) 216 | c = citer(ci); 217 | in{ci} = inputs(:,c:c+pbatch-1,:); 218 | ta{ci} = targets(:,c:c+pbatch-1,:); 219 | oi{ci} = oind(:,c:c+pbatch-1,:); 220 | end 221 | 222 | gradient = zeros(network.nparam,1); 223 | for z=1:npar 224 | net = initpass(network,size(in{z},2),size(in{z},3)); 225 | net = ForwardPass(net,in{z}); 226 | net = computegradient(net,ta{z},ones(size(ta{z}))); 227 | 228 | 229 | gradient = gradient + weights2vect(getJ(net)); 230 | end 231 | 232 | 233 | 234 | 235 | end 236 | 237 | 238 | 239 | function [err] = test(network,inputs,targets,oind) 240 | errsum = 0; 241 | errcount = 0; 242 | 243 | nbatch = size(inputs,2); 244 | 245 | 246 | input = inputs; 247 | network = initpasstest(network,nbatch,size(input,3)); 248 | network = ForwardPasstest(network,input); 249 | [terr]=network.errorFunc(network.output.outs.v,targets,oind); 250 | errsum = errsum + terr; 251 | 252 | errcount = errcount+1; 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | err = errsum/errcount; 262 | 263 | 264 | 265 | 266 | 267 | 268 | end 269 | 270 | function network = ForwardPass(network, inputs) 271 | 272 | inputs = gpuArray(inputs); 273 | network.input.outs.v = inputs; 274 | 275 | 276 | for t=1:size(inputs,3); 277 | l=1; 278 | 279 | 280 | 281 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 282 | 283 | if t-1>0 284 | 285 | 286 | network.hidden(l).intermediates.v(:,:,t) = network.hidden(l).mweights.matrix*network.hidden(l).outs.v(:,:,t-1); 287 | 288 | else 289 | 290 | network.hidden(l).intermediates.v(:,:,t) = network.hidden(l).mweights.matrix*network.hidden(l).outs.vp0; 291 | end 292 | 293 | network.hidden(l).factor.v(:,:,t) = 
network.hidden(l).fweights.matrix*inputs(:,:,t); 294 | network.hidden(l).mult.v(:,:,t) = network.hidden(l).factor.v(:,:,t).*network.hidden(l).intermediates.v(:,:,t); 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | network.hidden(l).ins.v(:,:,t) = network.hidden(l).ins.v(:,:,t) + network.hidden(l).weights.matrix*network.hidden(l).mult.v(:,:,t); 303 | network.hidden(l).ins.v(network.gateind,:,t)= bsxfun(@plus,network.hidden(l).ins.v(network.gateind,:,t),network.hidden(l).biases.v(network.gateind,:)); 304 | network.hidden(l).ins.v(network.gateind,:,t) = sigmoid(network.hidden(l).ins.v(network.gateind,:,t)); 305 | 306 | 307 | 308 | 309 | 310 | network.hidden(l).ins.state(:,:,t)= network.hidden(l).ins.v(network.hidind,:,t).*network.hidden(l).ins.v(network.writeind,:,t); 311 | 312 | 313 | 314 | if t-1>0 315 | network.hidden(l).ins.state(:,:,t) = network.hidden(l).ins.state(:,:,t)+network.hidden(l).ins.state(:,:,t-1).*network.hidden(l).ins.v(network.keepind,:,t); 316 | else 317 | network.hidden(l).ins.state(:,:,t) =network.hidden(l).ins.state(:,:,t)+ network.hidden(l).ins.statep0.*network.hidden(l).ins.v(network.keepind,:,t); 318 | end 319 | 320 | 321 | network.hidden(l).outs.v(:,:,t) = network.hidden(l).ins.state(:,:,t).*network.hidden(l).ins.v(network.readind,:,t); 322 | network.hidden(l).outs.v(:,:,t)=bsxfun(@plus,network.hidden(l).outs.v(:,:,t),network.hidden(l).biases.v(network.hidind,:)); 323 | network.hidden(l).outs.v(:,:,t) = tanh(network.hidden(l).outs.v(:,:,t)); 324 | 325 | 326 | 327 | 328 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*network.hidden(l).outs.v(:,:,t); 329 | 330 | 331 | if t==size(inputs,3) 332 | network.hidden(l).outs.last = network.hidden(l).outs.v(:,:,t); 333 | network.hidden(l).ins.last = network.hidden(l).ins.state(:,:,t); 334 | end 335 | 336 | 337 | end 338 | 339 | network.output.outs.v = network.output.fx(network.output.outs.v); 340 | 341 | 342 | 343 | end 344 | function network =
ForwardPasstest(network, inputs) 345 | 346 | inputs = gpuArray(inputs); 347 | network.input.outs.v = inputs; 348 | 349 | 350 | 351 | hidden.outs.vp = gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 352 | hidden.ins.statep = gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 353 | hidden.mult.v=gpuArray(zeros(network.nhidden,size(inputs,2),'single')); 354 | 355 | for l=1:length(network.hidden) 356 | hidden(l).outs.vp = network.hidden(l).outs.vp0 ; 357 | hidden(l).ins.statep = network.hidden(l).ins.statep0; 358 | end 359 | s=1; 360 | for t=1:size(inputs,3); 361 | l=1; 362 | 363 | 364 | 365 | 366 | hidden.ins.v= network.hidden(l).iweights.matrix*network.input.outs.v(:,:,t); 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | hidden.intermediates.v = network.hidden(l).mweights.matrix*hidden(l).outs.vp; 381 | hidden.factor.v = network.hidden(l).fweights.matrix*inputs(:,:,t); 382 | hidden.mult.v = hidden(l).factor.v.*hidden(l).intermediates.v; 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | hidden.ins.v = hidden(l).ins.v + network.hidden.weights.matrix*hidden(l).mult.v; 392 | hidden(l).ins.v(network.gateind,:)= bsxfun(@plus,hidden(l).ins.v(network.gateind,:),network.hidden(l).biases.v(network.gateind,:)); 393 | hidden.ins.v(network.gateind,:) = sigmoid(hidden.ins.v(network.gateind,:)); 394 | 395 | 396 | 397 | hidden.ins.state=hidden.ins.v(network.hidind,:).*hidden.ins.v(network.writeind,:); 398 | 399 | 400 | 401 | hidden.ins.state = hidden.ins.state + hidden(l).ins.statep.*hidden.ins.v(network.keepind,:); 402 | 403 | 404 | 405 | 406 | hidden.outs.v = hidden(l).ins.state.*hidden.ins.v(network.readind,:); 407 | hidden(l).outs.v=bsxfun(@plus,hidden(l).outs.v,network.hidden(l).biases.v(network.hidind,:)); 408 | hidden.outs.v = tanh(hidden(l).outs.v); 409 | 410 | network.output.outs.v(:,:,t) = network.output.outs.v(:,:,t) + network.output.weights(l).matrix*hidden(l).outs.v; 411 | 412 | hidden.outs.vp = hidden(l).outs.v; 
413 | hidden.ins.statep = hidden(l).ins.state; 414 | 415 | 416 | 417 | end 418 | 419 | 420 | network.output.outs.v = network.output.fx(network.output.outs.v); 421 | 422 | 423 | 424 | end 425 | 426 | 427 | 428 | 429 | function network = computegradient(network, targets,omat) 430 | oind = find(omat); 431 | 432 | network.output.outs.j(oind) = network.output.outs.v(oind)- targets(oind); 433 | 434 | network.output.outs.j = network.output.outs.j.*network.output.dx(network.output.outs.v); 435 | 436 | 437 | 438 | 439 | 440 | hidden.ins.statej = gpuArray(zeros(network.nhidden,size(targets,2),'single')); 441 | hidden.outs.j = gpuArray(zeros(network.nhidden,size(targets,2),'single')); 442 | hidden.ins.j = gpuArray(zeros(network.nhidden*4,size(network.input.outs.v,2),'single')); 443 | for t=size(network.input.outs.v,3):-1:1; 444 | 445 | 446 | 447 | l=1; 448 | 449 | network.output.weights(l).gradient = network.output.weights(l).gradient + network.output.outs.j(:,:,t)*network.hidden(l).outs.v(:,:,t)'; 450 | 451 | hidden.outs.j = hidden.outs.j + network.output.weights(l).matrix'*network.output.outs.j(:,:,t) ; 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | hidden(l).outs.j = hidden(l).outs.j.*tanhdir(network.hidden(l).outs.v(:,:,t)); 460 | network.hidden(l).biases.j(network.hidind,:) = network.hidden(l).biases.j(network.hidind,:) + sum(hidden(l).outs.j,2);%biases only have 1 dimension 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | hidden(l).ins.j(network.readind,:) = hidden(l).outs.j.* network.hidden(l).ins.state(:,:,t) ; 469 | 470 | hidden(l).ins.statej = hidden(l).ins.statej + hidden(l).outs.j.*network.hidden(l).ins.v(network.readind,:,t); 471 | 472 | 473 | 474 | 475 | ttemp = t-1; 476 | if ttemp>0 477 | 478 | hidden(l).ins.statejp = hidden(l).ins.statej.*network.hidden(l).ins.v(network.keepind,:,t); 479 | 480 | 481 | end 482 | 483 | 484 | 485 | 486 | if ttemp>0 487 | 488 | 489 | 490 | hidden(l).ins.j(network.keepind,:) = 
network.hidden(l).ins.state(:,:,t-1).*hidden(l).ins.statej; 491 | else 492 | 493 | 494 | hidden(l).ins.j(network.keepind,:)= network.hidden(l).ins.statep0.*hidden(l).ins.statej; 495 | end 496 | 497 | 498 | hidden(l).ins.j(network.writeind,:) = network.hidden(l).ins.v(network.hidind,:,t).*hidden(l).ins.statej; 499 | hidden(l).ins.j(network.hidind,:) = network.hidden(l).ins.v(network.writeind,:,t).*hidden(l).ins.statej; 500 | 501 | hidden(l).ins.j(network.gateind,:)= hidden(l).ins.j(network.gateind,:).*sigdir(network.hidden(l).ins.v(network.gateind,:,t)); 502 | network.hidden(l).biases.j(network.gateind,:) = network.hidden(l).biases.j(network.gateind,:) + sum(hidden(l).ins.j(network.gateind,:),2); 503 | 504 | 505 | 506 | hidden(l).mult.j = network.hidden(l).weights.matrix'*hidden(l).ins.j; 507 | 508 | 509 | 510 | 511 | 512 | 513 | network.hidden(l).weights.gradient = network.hidden(l).weights.gradient + (hidden(l).ins.j)*network.hidden(l).mult.v(:,:,t)'; 514 | 515 | 516 | 517 | 518 | if t-1>0 519 | 520 | hidden(l).intermediates.j = hidden(l).mult.j.*network.hidden(l).factor.v(:,:,t); 521 | 522 | hidden(l).factor.j= hidden(l).mult.j.*network.hidden(l).intermediates.v(:,:,t); 523 | 524 | network.hidden(l).fweights.gradient = network.hidden(l).fweights.gradient + hidden(l).factor.j*network.input.outs.v(:,:,t)'; 525 | 526 | hidden(l).outs.jp = network.hidden(l).mweights.matrix'*hidden(l).intermediates.j; 527 | network.hidden(l).mweights.gradient = network.hidden(l).mweights.gradient + hidden(l).intermediates.j*network.hidden(l).outs.v(:,:,t-1)'; 528 | else 529 | hidden(l).intermediates.j = hidden(l).mult.j.*network.hidden(l).factor.v(:,:,t); 530 | 531 | hidden(l).factor.j= hidden(l).mult.j.*network.hidden(l).intermediates.v(:,:,t); 532 | 533 | network.hidden(l).fweights.gradient = network.hidden(l).fweights.gradient + hidden(l).factor.j*network.input.outs.v(:,:,t)'; 534 | 535 | hidden(l).outs.jp = network.hidden(l).mweights.matrix'*hidden(l).intermediates.j; 536 | 
network.hidden(l).mweights.gradient = network.hidden(l).mweights.gradient + hidden(l).intermediates.j*network.hidden(l).outs.vp0'; 537 | 538 | 539 | 540 | 541 | end 542 | 543 | 544 | network.hidden(l).iweights.gradient = network.hidden(l).iweights.gradient + (hidden(l).ins.j)*network.input.outs.v(:,:,t)'; 545 | 546 | 547 | 548 | 549 | 550 | 551 | 552 | hidden(l).outs.j = hidden(l).outs.jp; 553 | hidden(l).ins.statej = hidden(l).ins.statejp; 554 | 555 | 556 | 557 | end 558 | 559 | end 560 | 561 | 562 | function [lcost] = evalCrossEntropy(output,targets,omat) 563 | 564 | 565 | 566 | oind = find(omat); 567 | 568 | ldiff = targets.*log2(output); 569 | 570 | 571 | lcost = -1*sum(ldiff(:)); 572 | 573 | 574 | 575 | 576 | end 577 | 578 | function network = updateV(network, dW) 579 | 580 | ninput = network.input.n; 581 | noutput = network.output.n; 582 | 583 | 584 | start = 1; 585 | last = 0; 586 | 587 | for l=1:length(network.hidden) 588 | nhidden = network.nhidden(l); 589 | 590 | last = last + numel(network.hidden(l).iweights.matrix); 591 | network.hidden(l).iweights.matrix = reshape(dW(start:last),4*nhidden,ninput)+ network.hidden(l).iweights.matrix ; 592 | start = last + 1; 593 | 594 | last = last + numel(network.hidden(l).biases.v); 595 | network.hidden(l).biases.v = reshape(dW(start:last),4*nhidden,1)+ network.hidden(l).biases.v ; 596 | start = last + 1; 597 | 598 | 599 | last = last + numel(network.hidden(l).fweights.matrix); 600 | network.hidden(l).fweights.matrix = reshape(dW(start:last),nhidden,ninput)+ network.hidden(l).fweights.matrix ; 601 | start = last + 1; 602 | 603 | 604 | for i=1:length(network.hidden(l).weights); 605 | last = last + numel(network.hidden(l).weights(i).matrix); 606 | network.hidden(l).weights(i).matrix = reshape(dW(start:last),4*nhidden,nhidden)+network.hidden(l).weights(i).matrix; 607 | start = last+1; 608 | last = last + numel(network.hidden(l).mweights(i).matrix); 609 | network.hidden(l).mweights(i).matrix =
reshape(dW(start:last),nhidden,nhidden)+network.hidden(l).mweights(i).matrix;
610 | start = last+1;
611 | 
612 | 
613 | end
614 | 
615 | 
616 | 
617 | 
618 | 
619 | 
620 | last = last + numel(network.output.weights(l).matrix);
621 | network.output.weights(l).matrix = reshape(dW(start:last),noutput,nhidden) + network.output.weights(l).matrix;
622 | start=last+1;
623 | 
624 | end
625 | 
626 | end
627 | 
628 | function vect=weights2vect(allvects)
629 | lsum = 0;
630 | lengths = cell(length(allvects),1);
631 | for i=1:length(allvects)
632 | lsum = lsum + numel(allvects{i});
633 | lengths{i}= lsum;
634 | 
635 | 
636 | end
637 | vect = zeros(lsum,1,'single');
638 | 
639 | vect(1:lengths{1}) = gather(reshape(allvects{1},lengths{1},1));
640 | for i=2:length(allvects)
641 | vect(lengths{i-1}+1:lengths{i}) = gather(reshape(allvects{i},lengths{i}-lengths{i-1},1));
642 | end
643 | 
644 | 
645 | end
646 | 
647 | 
648 | 
649 | 
650 | function network = initpass(network,nbatch,maxt)
651 | 
652 | ninput = network.input.n;
653 | 
654 | noutput = network.output.n;
655 | 
656 | for l=1:length(network.hidden)
657 | 
658 | nhidden = network.nhidden(l);
659 | network.output.weights(l).gradient = gpuArray(zeros(noutput,nhidden,'single'));
660 | 
661 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single'));
662 | if ~network.last
663 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single'));
664 | network.hidden(l).ins.statep0 = gpuArray(zeros(nhidden,nbatch,'single'));
665 | else
666 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last;
667 | 
668 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last;
669 | end
670 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
671 | 
672 | 
673 | network.hidden(l).ins.v = gpuArray(zeros(nhidden*4,nbatch,maxt,'single'));
674 | 
675 | 
676 | network.hidden(l).intermediates.v = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
677 | 
678 | network.hidden(l).factor.v = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
679 | 
680 | network.hidden(l).mult.v = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
681 | 
682 | 
683 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
684 | 
685 | 
686 | 
687 | 
688 | 
689 | for i=1:length(network.hidden(l).weights)
690 | network.hidden(l).weights(i).gradient = gpuArray(zeros(nhidden*4,nhidden,'single'));
691 | network.hidden(l).mweights(i).gradient = gpuArray(zeros(nhidden,nhidden,'single'));
692 | 
693 | end
694 | network.hidden(l).iweights.gradient = gpuArray(zeros(nhidden*4,ninput,'single'));
695 | network.hidden(l).fweights.gradient = gpuArray(zeros(nhidden,ninput,'single'));
696 | 
697 | 
698 | 
699 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
700 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,maxt,'single'));
701 | 
702 | 
703 | 
704 | end
705 | 
706 | 
707 | 
708 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single'));
709 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single'));
710 | 
711 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single'));
712 | 
713 | 
714 | end
715 | 
716 | function network = initpasstest(network,nbatch,maxt)
717 | 
718 | ninput = network.input.n;
719 | 
720 | noutput = network.output.n;
721 | 
722 | for l=1:length(network.hidden)
723 | 
724 | nhidden = network.nhidden(l);
725 | 
726 | 
727 | network.hidden(l).biases.j = gpuArray(zeros(nhidden*4,1,'single'));
728 | if ~network.last
729 | network.hidden(l).outs.vp0 = gpuArray(zeros(nhidden,nbatch,'single'));
730 | network.hidden(l).ins.statep0 = gpuArray(zeros(nhidden,nbatch,'single'));
731 | else
732 | network.hidden(l).outs.vp0 = network.hidden(l).outs.last;
733 | 
734 | network.hidden(l).ins.statep0 = network.hidden(l).ins.last;
735 | end
736 | 
737 | 
738 | 
739 | 
740 | network.hidden(l).outs.v = gpuArray(zeros(nhidden,nbatch,'single'));
741 | network.hidden(l).ins.state = gpuArray(zeros(nhidden,nbatch,'single'));
742 | 
743 | 
744 | 
745 | end
746 | 
747 | 
748 | 
749 | network.output.outs.j = gpuArray(zeros(noutput,nbatch,maxt,'single'));
750 | network.input.outs.v = gpuArray(zeros(ninput,nbatch,maxt,'single'));
751 | 
752 | network.output.outs.v = gpuArray(zeros(noutput,nbatch,maxt,'single'));
753 | 
754 | 
755 | end
756 | 
757 | 
758 | function network = initnetwork(ninput,nhidden,noutput)
759 | 
760 | 
761 | network.input.n = ninput;
762 | network.nhidden = nhidden;
763 | network.output.n = noutput;
764 | 
765 | 
766 | 
767 | 
768 | network.hidind = (1:nhidden)';
769 | network.writeind = (nhidden+1:2*nhidden)';
770 | network.keepind = (2*nhidden+1:3*nhidden)';
771 | network.readind = (3*nhidden+1:4*nhidden)';
772 | network.gateind = (nhidden+1:4*nhidden)';
773 | for j = 1:1 %only a single hidden layer is initialized here
774 | nhidden = network.nhidden(j);
775 | 
776 | 
777 | network.hidden(j).iweights.matrix = gpuArray(.1*(randn(nhidden*4,ninput,'single')));
778 | network.hidden(j).fweights.matrix = gpuArray(.1*(randn(nhidden,ninput,'single')));
779 | network.hidden(j).iweights.gated = 0;
780 | 
781 | 
782 | network.hidden(j).biases.v = gpuArray(zeros(4*nhidden,1,'single'));
783 | network.hidden(j).biases.v(network.keepind)=3; %bias keep gates so sigmoid starts near 1 and state is retained
784 | 
785 | 
786 | network.hidden(j).weights.matrix = gpuArray(.02*(randn(nhidden*4,nhidden,'single')));
787 | network.hidden(j).mweights.matrix = gpuArray(.02*(randn(nhidden,nhidden,'single')));
788 | 
789 | 
790 | 
791 | 
792 | network.hidden(j).fx = @sigmoid;
793 | network.hidden(j).dx = @sigdir;
794 | network.output.weights(j).matrix = gpuArray(.1*(randn(noutput,nhidden,'single')));
795 | 
796 | 
797 | end
798 | 
799 | 
800 | network.nparam = length(weights2vect(getW(network)));
801 | 
802 | 
803 | network.output.fx = @softmax;
804 | network.output.dx = @softdirXent;
805 | network.errorFunc = @evalCrossEntropy;
806 | 
807 | end
808 | 
809 | 
810 | function J = getJ(network)
811 | jtot=1;
812 | J = cell(jtot,1);
813 | c=1;
814 | for l=1:length(network.hidden)
815 | J{c}= network.hidden(l).iweights.gradient;
816 | c=c+1;
817 | J{c}= network.hidden(l).biases.j;
818 | c=c+1;
819 | J{c}= network.hidden(l).fweights.gradient;
820 | c=c+1;
821 | 
822 | 
823 | for i = 1:length(network.hidden(l).weights)
824 | J{c}=network.hidden(l).weights(i).gradient;
825 | c=c+1;
826 | J{c}=network.hidden(l).mweights(i).gradient;
827 | c=c+1;
828 | 
829 | 
830 | end
831 | 
832 | J{c} = network.output.weights(l).gradient;
833 | c=c+1;
834 | end
835 | 
836 | 
837 | 
838 | end
839 | function W = getW(network)
840 | jtot=1;
841 | W = cell(jtot,1);
842 | c=1;
843 | for l=1:length(network.hidden)
844 | 
845 | W{c}= network.hidden(l).iweights.matrix;
846 | c=c+1;
847 | W{c}= network.hidden(l).biases.v;
848 | c=c+1;
849 | W{c}= network.hidden(l).fweights.matrix;
850 | c=c+1;
851 | 
852 | 
853 | 
854 | for i = 1:length(network.hidden(l).weights)
855 | W{c}=network.hidden(l).weights(i).matrix;
856 | c=c+1;
857 | W{c}=network.hidden(l).mweights(i).matrix;
858 | c=c+1;
859 | 
860 | end
861 | 
862 | W{c} = network.output.weights(l).matrix;
863 | c=c+1;
864 | end
865 | 
866 | 
867 | 
868 | end
869 | 
870 | 
871 | 
872 | 
873 | 
874 | function f = sigmoid(x)
875 | 
876 | 
877 | f = 1./(1+exp(-x));
878 | end
879 | 
880 | function o = softdirXent(x)
881 | 
882 | o = ones(size(x),'single'); %softmax and cross-entropy derivatives cancel, so the local derivative is 1
883 | 
884 | 
885 | end
886 | 
887 | function dir = sigdir( y )
888 | 
889 | dir = y.*(1-y);
890 | 
891 | 
892 | end
893 | function dir = tanhdir( y )
894 | 
895 | dir = (1-y.*y);
896 | 
897 | 
898 | end
899 | function o = softmax(x)
900 | 
901 | ex = exp(bsxfun(@minus,x,max(x,[],1))); o = bsxfun(@rdivide,ex,sum(ex,1)); %subtract the column max for numerical stability
902 | end
903 | %uncomment the stubs below to run on the CPU (they make gather/gpuArray no-ops)
904 | %function m=gather(m)
905 | 
906 | %end
907 | %function m=gpuArray(m)
908 | 
909 | %end
910 | 
--------------------------------------------------------------------------------