├── .gitignore
├── README.md
├── data_processing.lua
├── model_utils.lua
├── network
│   ├── CosineSimilarity.lua
│   ├── GRU.lua
│   ├── LSTM.lua
│   ├── VanillaRNN.lua
│   ├── joinlayer.lua
│   ├── splitlayer.lua
│   └── unbiased_linear.lua
└── rnn_playground.lua
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.t7
*~
*.txt
/Defunct/*.*
/Experimental/*.*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
NOTES 7/25/15:

Hello!

So, uh, this whole thing is still kind of under construction. This is not a well-polished piece of code. But it is probably fun enough to play with that it is worth sharing.

This is code for building and training RNNs to learn and generate text. Character-level learning is easiest, although you can change the matching pattern. (For example, you could learn two characters at a time, or ignore all punctuation, and so on.)

This is basically a poor man's version of Andrej Karpathy's char-RNN, which came out earlier this summer and was awesome. (In fact, I have flat-out ripped off his model_utils file with no modification. The credit and glory here is not all mine.) I spent a good chunk of time learning how to do things in Torch from his and other examples. Originally, I was planning on doing sequence-to-sequence learning (http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf), but that is a little ambitious for this very modest code, so I'm going to do that in another codebase.

But there are some things I've done here which you may find interesting/useful!

I have implemented LSTM with three potentially desirable features. First, peephole connections, as described here (http://arxiv.org/abs/1503.04069). These let the gates look at the memory cell contents, instead of relying only on the hidden state. Useful! Second, my LSTM code allows you to make layers of variable sizes. So, for example, you could have a deep LSTM network where the first layer's hidden state and memory cells are 512-dimensional vectors, the second layer's are 256-dimensional, and so on. Lastly, forget gate biases are initialized to 1, which is known to help LSTMs perform better at long-term sequence learning tasks. (Naive initialization inhibits gradient flow, and this is a quick fix.)

Vanilla RNN and Gated Recurrent Unit implementations are also present and usable.

I've also made a couple of potentially useful utility layers. JoinLayer and SplitLayer respectively concatenate a table of vectors into one large vector, and split one large vector into a table of smaller vectors. They are probably not maximally memory efficient, but eh, I'll get around to that eventually. (I really should have done the splitting with narrows, but, again, eventually.) There is also an unbiased linear layer in here (a linear layer with just a weight matrix and no bias). This is useful if, for instance, you are making gated units, and there should only be one bias per gate instead of two (which is the case in some of the LSTM examples I have seen out there). There's also CosineSimilarity, which I was going to use for something but never got around to it.

For training, there are two modes, controlled by the xrnn.training_mode variable. In mode 2, you proceed linearly through the text corpus. In mode 1, you randomly sample sequences from anywhere in the text corpus as you train. If your corpus is very different in different places, the random sampling might help smooth things out during training.
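To make that concrete, a minimal session looks something like this (the option values here are just for illustration; the full option list is documented in rnn_playground.lua):

    require 'rnn_playground'
    init({filename='input.txt', layer_sizes={128}, training_mode=1})
    train_network_N_steps(1000)
    sample_from_network({length=500})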
Also, if you want to split your raw text into word tokens instead of characters, I have made that particularly easy. Just use pattern='word' as one of the options. But, uh, to note, that one does not seem to go over super great during training. I'll try adding some better support for it later.

So, none of this is really, uh, user friendly yet. If you want to use it, here's how you do it:

1. Fire up Torch.

2. require 'rnn_playground'

3. Build out your list of args; read the comments in rnn_playground.lua to see what your options are. You can omit any options, though, and it will default to something safe and simple.

4. init(args)

5. Then, to train, run 'train_network_N_steps(N)' with some number as the argument.

6. Or, if you want to proceed through your training corpus all the way some number of times, do 'epochs(N)' where N is reasonable.

7. To sample, run 'sample_from_network({length=$some_number})' with some number as the argument.

8. If you trained it on a corpus of chat data*, try 'test_conv(temperature)' to chat with it!
   temperature is a parameter between 0 and 1, which sharpens the probability distribution as you lower it.
   That is, the sampling becomes more deterministic at lower temperatures, and more random at higher ones.
   I find that temps around 0.65-0.85 are nice.

Saving is easy: just do 'save(filename)'. Loading is also easy: 'load(filename)' does it. But if you want to make a new network with the same options as an old network, do 'loadFromOptions(filename)' and that will do!

I can take requests for additional features. I guess a command line option might be popular? I don't know, I like working with it directly in Torch. I get to play with the guts of it as I go.

There are no validation statistics in this - as a scientist, I have been a very bad scientist about this. Maybe I'll add them later. But I haven't had any trouble with overfitting yet.


* A corpus of chat data - to play nicely with this code - should be a text file where every line is a chatroom message. (Scrub off usernames, timestamps, all that jazz. Just messages separated by linebreaks.)
--------------------------------------------------------------------------------
/data_processing.lua:
--------------------------------------------------------------------------------

-- DATA MANIPULATION FUNCTIONS
function data(filename)
	local f = torch.DiskFile(filename)
	local rawdata = f:readString('*a')
	f:close()
	return rawdata
end

function split_string(rawdata,pattern,lower,wordopt)
	local replace
	if wordopt then
		replace = wordopt.replace or false
	end
	local ptrn
	if pattern == 'word' then
		ptrn = '[^%s]+\n?'
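		-- this matches maximal runs of non-whitespace; the optional trailing
		-- '\n' keeps a linebreak attached to its token, so line endings
		-- survive word-mode tokenization (they are split back off below)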
	else
		ptrn = pattern
	end
	local breakapart = rawdata:gmatch(ptrn)
	local splitstring = {}
	local tokens = {}
	for elem in breakapart do
		tokens = {}
		if lower then elem = elem:lower() end
		if pattern == 'word' then
			local pref = {}
			local front = elem
			local back = {}

			-- strip off punctuation characters and newlines
			for i=1,front:len() do
				local prevchar = front:sub(1,1)
				if prevchar:match('[%p\n]') then
					table.insert(pref,prevchar)
					front = front:sub(2,front:len())
				else
					break
				end
			end
			for i=front:len(),1,-1 do
				local lastchar = front:sub(front:len(),front:len())
				if lastchar:match('[%p\n]') then
					table.insert(back,lastchar)
					front = front:sub(1,front:len()-1)
				else
					break
				end
			end

			-- prefix characters/punctuation to tokens
			for i=1,#pref do
				tokens[#tokens+1] = pref[i]
			end

			-- word to token
			-- time for some common replacements!
			if replace and front then
				local asplit = {}
				local ba = front:gmatch('[^\']+')
				for a in ba do
					table.insert(asplit,a)
				end
				local replaceflag = false
				if #asplit > 1 then
					local prev = asplit[#asplit-1]:lower()
					local last = asplit[#asplit]:lower()
					if last == 'll' then
						asplit[#asplit] = 'will'
						replaceflag = true
					elseif last == 'm' then
						asplit[#asplit] = 'am'
						replaceflag = true
					elseif last == 've' then
						asplit[#asplit] = 'have'
						replaceflag = true
					elseif last == 're' then
						asplit[#asplit] = 'are'
						replaceflag = true
					elseif last == 's' then
						if prev == 'he' or prev == 'she'
							or prev == 'that' or prev == 'this'
							or prev == 'it' or prev == 'how'
							or prev == 'why' or prev == 'who'
							or prev == 'when' or prev == 'what' then
							asplit[#asplit] = 'is'
							replaceflag = true
						end
					end
				end
				if not(replaceflag) then
					tokens[#tokens+1] = front
				else
					for i=1,#asplit do
						tokens[#tokens+1] = asplit[i]
					end
				end
			else
				tokens[1] = front
			end

			-- suffix characters/punctuation to tokens
			for i=#back,1,-1 do
				tokens[#tokens+1] = back[i]
			end
		else
			tokens[1] = elem
		end
		for _,v in ipairs(tokens) do table.insert(splitstring,v) end
	end
	return splitstring
end

function data_processing(rawdata,pattern,lower,wordopt)
	local usemostcommon = false
	local useNmostcommon = 4500
	if wordopt then
		usemostcommon = wordopt.usemostcommon or false
		useNmostcommon = wordopt.useNmostcommon or 4500
	end
	local embeddings = {}
	local deembeddings = {}
	local freq = {}
	local numkeys = 0
	local numwords = 0

	-- split the string and make embeddings/deembeddings/freq
	local splitstring = split_string(rawdata,pattern,lower,wordopt)
	numwords = #splitstring
	local tokenized = torch.zeros(numwords)
	for i=1,numwords do
		if not embeddings[splitstring[i]] then
			numkeys = numkeys + 1
			embeddings[splitstring[i]] = numkeys
			deembeddings[numkeys] = splitstring[i]
			freq[numkeys] = {1,numkeys}
		else
			freq[embeddings[splitstring[i]]][1] = freq[embeddings[splitstring[i]]][1] + 1
		end
		tokenized[i] = embeddings[splitstring[i]]
	end

	-- only take the most frequent entries
	local num_represented = 0
	if usemostcommon then
		numkeys =
math.min(numkeys,useNmostcommon) 147 | table.sort(freq,function(a,b) return a[1]>b[1] end) 148 | local new_embed = {} 149 | local new_deembed = {} 150 | for i=1,numkeys do 151 | new_deembed[i] = deembeddings[freq[i][2]] 152 | new_embed[new_deembed[i]] = i 153 | num_represented = num_represented + freq[i][1] 154 | end 155 | embeddings = new_embed 156 | deembeddings = new_deembed 157 | print('Dictionary captures about ', 100*num_represented/numwords, '% of text.') 158 | -- rebuild tokenized: 159 | for i=1,numwords do 160 | tokenized[i] = embeddings[splitstring[i]] or numkeys + 1 161 | end 162 | end 163 | 164 | return embeddings, deembeddings, numkeys, numwords, tokenized, freq 165 | end 166 | 167 | -------------------------------------------------------------------------------- /model_utils.lua: -------------------------------------------------------------------------------- 1 | -- Credit to Andrej Karpathy's RNN code, as this is directly lifted with (at present) no modification 2 | -- adapted from https://github.com/wojciechz/learning_to_execute 3 | -- utilities for combining/flattening parameters in a model 4 | -- the code in this script is more general than it needs to be, which is 5 | -- why it is kind of a large 6 | 7 | require 'torch' 8 | local model_utils = {} 9 | function model_utils.combine_all_parameters(...) 10 | --[[ like module:getParameters, but operates on many modules ]]-- 11 | 12 | -- get parameters 13 | local networks = {...} 14 | local parameters = {} 15 | local gradParameters = {} 16 | for i = 1, #networks do 17 | local net_params, net_grads = networks[i]:parameters() 18 | 19 | if net_params then 20 | for _, p in pairs(net_params) do 21 | parameters[#parameters + 1] = p 22 | end 23 | for _, g in pairs(net_grads) do 24 | gradParameters[#gradParameters + 1] = g 25 | end 26 | end 27 | end 28 | 29 | local function storageInSet(set, storage) 30 | local storageAndOffset = set[torch.pointer(storage)] 31 | if storageAndOffset == nil then 32 | return nil 33 | end 34 | local _, offset = unpack(storageAndOffset) 35 | return offset 36 | end 37 | 38 | -- this function flattens arbitrary lists of parameters, 39 | -- even complex shared ones 40 | local function flatten(parameters) 41 | if not parameters or #parameters == 0 then 42 | return torch.Tensor() 43 | end 44 | local Tensor = parameters[1].new 45 | 46 | local storages = {} 47 | local nParameters = 0 48 | for k = 1,#parameters do 49 | local storage = parameters[k]:storage() 50 | if not storageInSet(storages, storage) then 51 | storages[torch.pointer(storage)] = {storage, nParameters} 52 | nParameters = nParameters + storage:size() 53 | end 54 | end 55 | 56 | local flatParameters = Tensor(nParameters):fill(1) 57 | local flatStorage = flatParameters:storage() 58 | 59 | for k = 1,#parameters do 60 | local storageOffset = storageInSet(storages, parameters[k]:storage()) 61 | parameters[k]:set(flatStorage, 62 | storageOffset + parameters[k]:storageOffset(), 63 | parameters[k]:size(), 64 | parameters[k]:stride()) 65 | parameters[k]:zero() 66 | end 67 | 68 | local maskParameters= flatParameters:float():clone() 69 | local cumSumOfHoles = flatParameters:float():cumsum(1) 70 | local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] 71 | local flatUsedParameters = Tensor(nUsedParameters) 72 | local flatUsedStorage = flatUsedParameters:storage() 73 | 74 | for k = 1,#parameters do 75 | local offset = cumSumOfHoles[parameters[k]:storageOffset()] 76 | parameters[k]:set(flatUsedStorage, 77 | parameters[k]:storageOffset() - offset, 78 | 
parameters[k]:size(), 79 | parameters[k]:stride()) 80 | end 81 | 82 | for _, storageAndOffset in pairs(storages) do 83 | local k, v = unpack(storageAndOffset) 84 | flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) 85 | end 86 | 87 | if cumSumOfHoles:sum() == 0 then 88 | flatUsedParameters:copy(flatParameters) 89 | else 90 | local counter = 0 91 | for k = 1,flatParameters:nElement() do 92 | if maskParameters[k] == 0 then 93 | counter = counter + 1 94 | flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] 95 | end 96 | end 97 | assert (counter == nUsedParameters) 98 | end 99 | return flatUsedParameters 100 | end 101 | 102 | -- flatten parameters and gradients 103 | local flatParameters = flatten(parameters) 104 | local flatGradParameters = flatten(gradParameters) 105 | 106 | -- return new flat vector that contains all discrete parameters 107 | return flatParameters, flatGradParameters 108 | end 109 | 110 | 111 | 112 | 113 | function model_utils.clone_many_times(net, T) 114 | local clones = {} 115 | 116 | local params, gradParams 117 | if net.parameters then 118 | params, gradParams = net:parameters() 119 | if params == nil then 120 | params = {} 121 | end 122 | end 123 | 124 | local paramsNoGrad 125 | if net.parametersNoGrad then 126 | paramsNoGrad = net:parametersNoGrad() 127 | end 128 | 129 | local mem = torch.MemoryFile("w"):binary() 130 | mem:writeObject(net) 131 | 132 | for t = 1, T do 133 | -- We need to use a new reader for each clone. 134 | -- We don't want to use the pointers to already read objects. 135 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 136 | local clone = reader:readObject() 137 | reader:close() 138 | 139 | if net.parameters then 140 | local cloneParams, cloneGradParams = clone:parameters() 141 | local cloneParamsNoGrad 142 | for i = 1, #params do 143 | cloneParams[i]:set(params[i]) 144 | cloneGradParams[i]:set(gradParams[i]) 145 | end 146 | if paramsNoGrad then 147 | cloneParamsNoGrad = clone:parametersNoGrad() 148 | for i =1,#paramsNoGrad do 149 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 150 | end 151 | end 152 | end 153 | 154 | clones[t] = clone 155 | collectgarbage() 156 | end 157 | 158 | mem:close() 159 | return clones 160 | end 161 | 162 | return model_utils 163 | -------------------------------------------------------------------------------- /network/CosineSimilarity.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local cosines = torch.class('nn.CosineSimilarity','nn.Module') 4 | 5 | function cosines:__init(divbyx) 6 | self.divbyx = divbyx or false 7 | self.output = torch.Tensor() 8 | self.gradInput = {} 9 | 10 | -- Useful things from forward pass to keep 11 | self.wx = torch.Tensor() 12 | self.wc = torch.Tensor() 13 | self.w = torch.Tensor() 14 | self.xc = torch.Tensor() 15 | self.x = torch.Tensor() 16 | end 17 | 18 | function cosines:updateOutput(input) 19 | -- W is an 'nkeys x embveclen' matrix 20 | -- X is an 'nbatch x embveclen' matrix 21 | -- the goal is to output a matrix C that is 'nbatch x nkeys', 22 | -- where C[i,j] is the cosine similarity of the vector X[i] with 23 | -- key embedding W[j]. That is, C[i,j] = dot(W[j],X[i]) / |W[j]|. 
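	-- (If divbyx is set, the output is additionally divided by |X[i]|, giving
	-- the true cosine similarity dot(W[j],X[i]) / (|W[j]| |X[i]|); otherwise
	-- it is normalized by the key norms only.)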
24 | local W, X = input[1], input[2] 25 | 26 | if W:dim() == 1 then W:resize(1,W:size(1)) end 27 | if X:dim() == 1 then X:resize(1,X:size(1)) end 28 | 29 | local w = W:clone() 30 | w = w:cmul(w):sum(2) 31 | self.wc = w:sqrt() 32 | w = self.wc:clone():t() -- This is a vector of the norms |W[j]| 33 | 34 | self.wx = torch.mm(X,W:t()) 35 | w = w:expandAs(self.wx) 36 | self.w = w 37 | 38 | self.output = torch.cdiv(self.wx,self.w) 39 | 40 | local x 41 | if self.divbyx then 42 | x = X:clone() 43 | x = x:cmul(x):sum(2) 44 | self.xc = x:sqrt() 45 | x = self.xc:clone() 46 | x = x:expandAs(self.wx) 47 | self.x = x 48 | self.output:cdiv(x) 49 | end 50 | 51 | return self.output 52 | end 53 | 54 | function cosines:updateGradInput(input,gradOutput) 55 | local W,X = input[1],input[2] 56 | 57 | if W:dim() == 1 then W:resize(1,W:size(1)) end 58 | if X:dim() == 1 then X:resize(1,X:size(1)) end 59 | 60 | --[[ 61 | oh man. OH MAN. are you ready for a bunch of ugly matrix ops? 62 | because i know i am. 63 | 64 | so here is how we are running this show. think of the forward 65 | as four major computational blocks. 66 | block 1: go from X to x, where x is the 'nbatch x nkeys' matrix 67 | with elements x[i,j] = |X[i,1:embveclen]| 68 | block 2: go from W to w, where w is the 'nbatch x nkeys' matrix 69 | with elements w[i,j] = |W[j,1:embveclen]| 70 | block 3: take X and W and produce wx = XW^T 71 | block 4: take wx,w,x, and give q = wx/w/x. 72 | 73 | but we are only using block 1 if divbyx. otherwise no block 1, and 74 | block 4 gives q = wx/w. 75 | 76 | we are going to go graphways backwards through that. 77 | hold on to your butts. 78 | 79 | so, the output from the module is q, which is 'nbatch x nkeys' 80 | gradOuput is therefore also 'nbatch x nkeys' 81 | 82 | first we will get the gradients of q with resp to wx,w,x. 83 | ]] 84 | 85 | -- block 4 backward pass 86 | local g4x, g4w, g4wx 87 | g4wx = torch.ones(self.output:size()):typeAs(W):cdiv(self.w) 88 | g4w = -torch.cdiv(self.output,self.w) 89 | if self.divbyx then 90 | g4wx = g4wx:cdiv(self.x) 91 | g4x = -torch.cdiv(self.output,self.x) 92 | end 93 | 94 | -- now we make the gradients of wx,w,x with respect to the loss 95 | local glx, glw, glwx 96 | glwx = torch.cmul(gradOutput,g4wx) 97 | glw = torch.cmul(gradOutput,g4w) 98 | if self.divbyx then 99 | glx = torch.cmul(gradOutput,g4x) 100 | end 101 | 102 | -- optional block 1 backward pass. 103 | -- Things that happen in block 1: 104 | -- in -a- -b- -c- -d- 105 | -- X --> X cmul X --> sum(2) --> sqrt() --> expand 106 | local g1a,g1b,g1c,g1d 107 | if self.divbyx then 108 | -- First, backwards through the expand: 109 | g1d = glx:sum(2) 110 | -- Next, backwards through the sqrt: 111 | -- for y=sqrt(x), dy/dx = (1/2)(1/y) 112 | g1c = g1d:cmul(self.xc:clone():pow(-1)):mul(0.5) 113 | -- Backwards through the sum: 114 | g1b = g1c:expandAs(X) 115 | -- Backwards through the squaring: 116 | -- for y = x^2, dy/dx = 2x 117 | g1a = torch.cmul(g1b,X):mul(2) 118 | end 119 | 120 | -- block 2 backward pass. Conceptually the same as block 1 backward. 121 | local g2a,g2b,g2c,g2d 122 | g2d = glw:sum(1) 123 | g2c = g2d:cmul(self.wc:clone():pow(-1)):mul(0.5):t() 124 | g2b = g2c:expandAs(W) 125 | g2a = torch.cmul(g2b,W):mul(2) 126 | 127 | -- block 3 backward pass. Simple because W,X -> XW^T is linear in W,X. 
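	-- For wx = X*W^T with upstream gradient glwx, the chain rule gives
	-- dL/dX = glwx * W and dL/dW = glwx^T * X, which is all this is: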
128 | local g3X, g3W 129 | g3X = torch.mm(glwx,W) 130 | g3W = torch.mm(glwx:t(),X) 131 | 132 | -- Sum up for final grads 133 | local gW,gX 134 | gW = g3W + g2a 135 | gX = g3X 136 | if self.divbyx then gX = gX + g1a end 137 | 138 | -- This was done by hand, and so, uh, you know. Who knows if it will work! 139 | -- I have tested it on some very small cases, where yeah, it does work fo sho 100%, 140 | -- but you know. Use at your own risk. 141 | self.gradInput = {gW,gX} 142 | return self.gradInput 143 | end 144 | -------------------------------------------------------------------------------- /network/GRU.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'nngraph' 4 | require 'joinlayer' 5 | require 'splitlayer' 6 | require 'unbiased_linear' 7 | 8 | -- Gated Recurrent Units implementation 9 | function GRU(input_dim,layer_sizes,opt) 10 | local decoder = opt.decoder or false 11 | 12 | local inputs = {} 13 | inputs[1] = nn.Identity()() -- Hidden state input 14 | inputs[2] = nn.Identity()() -- External input 15 | 16 | -- How many elements in all the layers? 17 | local m = 0 18 | for i=1,#layer_sizes do m = m + layer_sizes[i] end 19 | 20 | -- hidden_split gets a table of the split hidden states 21 | local hidden_split 22 | local hidden_states_prev = {} 23 | if #layer_sizes > 1 then 24 | hidden_split = nn.SplitLayer(m,layer_sizes)(inputs[1]) 25 | hidden_states_prev = {hidden_split:split(#layer_sizes)} 26 | else 27 | hidden_states_prev[1] = inputs[1] 28 | end 29 | 30 | -- utility function: gives node with linear transform of input 31 | local function new_input_sum(indim,layer_size,innode,hiddennode) 32 | local i2h = nn.Linear(indim,layer_size)(innode) 33 | local h2h = nn.UnbiasedLinear(layer_size,layer_size)(hiddennode) 34 | return nn.CAddTable()({i2h,h2h}) 35 | end 36 | 37 | local hidden_states_cur = {} 38 | for j=1,#layer_sizes do 39 | local innode, indim 40 | if j==1 then 41 | innode = inputs[2] 42 | indim = input_dim 43 | else 44 | innode = hidden_states_cur[j-1] 45 | indim = layer_sizes[j-1] 46 | end 47 | 48 | local z,r,htil 49 | 50 | -- Update gate, z, and reset gate, r 51 | z = nn.Sigmoid()(new_input_sum(indim,layer_sizes[j],innode,hidden_states_prev[j])) 52 | r = nn.Sigmoid()(new_input_sum(indim,layer_sizes[j],innode,hidden_states_prev[j])) 53 | 54 | -- 1 - z 55 | ztil =nn.AddConstant(1,true)(nn.MulConstant(-1,false)(z)) 56 | 57 | -- Proposed new memory content 58 | htil = nn.Tanh()(nn.CAddTable()({ 59 | nn.Linear(indim,layer_sizes[j])(innode), 60 | nn.CMulTable()({r,nn.Linear(layer_sizes[j],layer_sizes[j])(hidden_states_prev[j])}) 61 | })) 62 | 63 | -- Update 64 | hidden_states_cur[j] = nn.CAddTable()({ 65 | nn.CMulTable()({hidden_states_prev[j],z}), 66 | nn.CMulTable()({htil,ztil}) 67 | }) 68 | 69 | end 70 | 71 | local external_output_base = hidden_states_cur[#layer_sizes] 72 | local external_output 73 | local hidden_state_output 74 | if #layer_sizes > 1 then 75 | hidden_state_output = nn.JoinLayer()(hidden_states_cur) 76 | else 77 | hidden_state_output = hidden_states_cur[1] 78 | end 79 | if not(decoder) then 80 | -- I know this looks dumb... but yes, this is intentional. 81 | -- There's a dumb error with one-layer nets otherwise. 
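		-- (With one layer, hidden_state_output and external_output_base would
		-- be the very same node, and nngraph seems to dislike the same node
		-- appearing twice in the output table; the Identity makes a distinct node.)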
82 | external_output = nn.Identity()(external_output_base) 83 | else 84 | external_output = nn.Linear(layer_sizes[#layer_sizes],input_dim)(external_output_base) 85 | end 86 | local outputs = {hidden_state_output,external_output} 87 | 88 | return nn.gModule(inputs,outputs) 89 | end 90 | -------------------------------------------------------------------------------- /network/LSTM.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'nngraph' 4 | require 'joinlayer' 5 | require 'splitlayer' 6 | require 'unbiased_linear' 7 | 8 | -- layer_sizes is a table whose entries are the number of cells per layer 9 | -- fgate_init is a flag: if true, then initialize the forget gate biases to 1 10 | function LSTM(input_dim,layer_sizes,opt) 11 | local peep = opt.peepholes or false 12 | local decoder = opt.decoder or false 13 | local fgate_init = opt.fgate_init or false 14 | 15 | local inputs = {} 16 | inputs[1] = nn.Identity()() -- Hidden state input 17 | inputs[2] = nn.Identity()() -- External input 18 | 19 | -- How many elements in all the layers? 20 | local m = 0 21 | for i=1,#layer_sizes do m = m + layer_sizes[i] end 22 | 23 | local sizes_with_cells = {} 24 | for i=1,#layer_sizes do 25 | sizes_with_cells[i] = layer_sizes[i] 26 | sizes_with_cells[i + #layer_sizes] = layer_sizes[i] 27 | end 28 | 29 | -- hidden_split gets a table of the split hidden states 30 | local hidden_split = nn.SplitLayer(2*m,sizes_with_cells)(inputs[1]) 31 | local hidden_states_prev = {hidden_split:split(2*#layer_sizes)} 32 | 33 | -- utility function: gives node with linear transform of input 34 | local function new_input_sum(indim,layer_size,innode,hiddennode,biasflag) 35 | local biasflag = biasflag or false 36 | local i2h = nn.Linear(indim,layer_size)(innode) 37 | local h2h = nn.UnbiasedLinear(layer_size,layer_size)(hiddennode) 38 | if biasflag then 39 | i2h.data.module.bias:fill(1) 40 | end 41 | return nn.CAddTable()({i2h,h2h}) 42 | end 43 | 44 | -- we will assume the following structure in the hidden state: 45 | -- the first /k/ entries are the h_j, and the second /k/ entries 46 | -- are the cell memory states c_j. 47 | local hidden_states_cur = {} 48 | for j=1,#layer_sizes do 49 | local innode, indim 50 | if j==1 then 51 | innode = inputs[2] 52 | indim = input_dim 53 | else 54 | innode = hidden_states_cur[j-1] 55 | indim = layer_sizes[j-1] 56 | end 57 | 58 | local zbar, ibar, fbar, obar, z, i, f, o, p_i, p_f, p_o 59 | 60 | -- Input block, input gate, forget gate, output gate linear transforms. 
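		-- In equations (the peephole terms p_* only when opt.peepholes is set):
		--   c_t = sigmoid(ibar + p_i.*c_prev) .* tanh(zbar)
		--         + sigmoid(fbar + p_f.*c_prev) .* c_prev
		--   h_t = sigmoid(obar + p_o.*c_t) .* tanh(c_t)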
		zbar = new_input_sum(indim,layer_sizes[j],innode,hidden_states_prev[j])
		ibar = new_input_sum(indim,layer_sizes[j],innode,hidden_states_prev[j])
		fbar = new_input_sum(indim,layer_sizes[j],innode,hidden_states_prev[j],fgate_init)
		obar = new_input_sum(indim,layer_sizes[j],innode,hidden_states_prev[j])

		-- Input block nonlinear
		z = nn.Tanh()(zbar)

		-- Input and forget gate nonlinear / and possibly peepholes
		if not(peep) then
			i = nn.Sigmoid()(ibar)
			f = nn.Sigmoid()(fbar)
		else
			p_i = nn.CMul(layer_sizes[j])(hidden_states_prev[j+#layer_sizes])
			p_f = nn.CMul(layer_sizes[j])(hidden_states_prev[j+#layer_sizes])
			i = nn.Sigmoid()(nn.CAddTable()({ibar,p_i}))
			f = nn.Sigmoid()(nn.CAddTable()({fbar,p_f}))
		end

		-- Calculate memory cell values
		-- hidden_states_cur[j + #layer_sizes] is c_t for this layer
		hidden_states_cur[j + #layer_sizes] = nn.CAddTable()({
			nn.CMulTable()({z,i}),
			nn.CMulTable()({f,hidden_states_prev[j + #layer_sizes]})
		})

		-- Output gate nonlinear / and possibly peepholes
		if not(peep) then
			o = nn.Sigmoid()(obar)
		else
			p_o = nn.CMul(layer_sizes[j])(hidden_states_cur[j+#layer_sizes])
			o = nn.Sigmoid()(nn.CAddTable()({obar,p_o}))
		end

		-- now make h_t for this layer
		hidden_states_cur[j] = nn.CMulTable()({
			nn.Tanh()(hidden_states_cur[j + #layer_sizes]), o
		})
	end

	local external_output_base = hidden_states_cur[#layer_sizes]
	local external_output
	local hidden_state_output = nn.JoinLayer()(hidden_states_cur)
	if not(decoder) then
		external_output = external_output_base
	else
		external_output = nn.Linear(layer_sizes[#layer_sizes],input_dim)(external_output_base)
	end
	local outputs = {hidden_state_output,external_output}

	return nn.gModule(inputs,outputs)
end
--------------------------------------------------------------------------------
/network/VanillaRNN.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'nngraph'
require 'joinlayer'
require 'splitlayer'
require 'unbiased_linear'

function VanillaRNN(input_dim,layer_sizes,opt)
	local decoder = opt.decoder or false
	local nl_type = opt.nl_type or 'tanh'

	local inputs = {}
	inputs[1] = nn.Identity()() -- Hidden state input
	inputs[2] = nn.Identity()() -- External input

	-- How many elements in all the layers?
18 | local m = 0 19 | for i=1,#layer_sizes do m = m + layer_sizes[i] end 20 | 21 | -- hidden_split gets a table of the split hidden states 22 | local hidden_split 23 | local hidden_states_prev = {} 24 | if #layer_sizes > 1 then 25 | hidden_split = nn.SplitLayer(m,layer_sizes)(inputs[1]) 26 | hidden_states_prev = {hidden_split:split(#layer_sizes)} 27 | else 28 | hidden_states_prev[1] = inputs[1] 29 | end 30 | 31 | local hidden_states_cur = {} 32 | for j=1,#layer_sizes do 33 | local innode, indim 34 | if j==1 then 35 | innode = inputs[2] 36 | indim = input_dim 37 | else 38 | innode = hidden_states_cur[j-1] 39 | indim = layer_sizes[j-1] 40 | end 41 | 42 | local i2h = nn.Linear(indim,layer_sizes[j])(innode) 43 | local h2h = nn.UnbiasedLinear(layer_sizes[j],layer_sizes[j])(hidden_states_prev[j]) 44 | local hbar = nn.CAddTable()({i2h,h2h}) 45 | 46 | -- now make h_t for this layer 47 | if nl_type == 'sigmoid' then 48 | hidden_states_cur[j] = nn.Sigmoid()(hbar) 49 | elseif nl_type == 'relu' then 50 | hidden_states_cur[j] = nn.ReLU()(hbar) 51 | elseif nl_type == 'none' then 52 | hidden_states_cur[j] = hbar 53 | else 54 | hidden_states_cur[j] = nn.Tanh()(hbar) 55 | end 56 | end 57 | 58 | local external_output_base = hidden_states_cur[#layer_sizes] 59 | local external_output 60 | local hidden_state_output 61 | if #layer_sizes > 1 then 62 | hidden_state_output = nn.JoinLayer()(hidden_states_cur) 63 | else 64 | hidden_state_output = hidden_states_cur[1] 65 | end 66 | if not(decoder) then 67 | -- I know this looks dumb... but yes, this is intentional. 68 | -- There's a dumb error with one-layer nets otherwise. 69 | external_output = nn.Identity()(external_output_base) 70 | else 71 | external_output = nn.Linear(layer_sizes[#layer_sizes],input_dim)(external_output_base) 72 | end 73 | local outputs = {hidden_state_output,external_output} 74 | 75 | return nn.gModule(inputs,outputs) 76 | end 77 | -------------------------------------------------------------------------------- /network/joinlayer.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Overview: 5 | This layer accepts a table of vectors and combines them 6 | all into one larger vector. 7 | 8 | Or, if this input is a table of batches of vectors, it 9 | combines them into an appropriate single batch of vectors. 10 | 11 | If the input is a set of k vectors of dimension n_j, then 12 | the output is a vector with N = sum_{j=1}^k n_j entries. 13 | 14 | If the input is a set of k M x n_j batch vectors, then the 15 | output is an M x (sum_{j=1}^k n_j) tensor. 16 | 17 | Table entries should be numbered from 1 to k. 18 | ]] 19 | 20 | local joinlayer = torch.class('nn.JoinLayer','nn.Module') 21 | 22 | function joinlayer:__init() 23 | self.gradInput = {} 24 | self.splitSizes = {} 25 | self.batch = false 26 | end 27 | 28 | 29 | function joinlayer:updateOutput(input) 30 | local X = input 31 | local splitSizes = {} 32 | 33 | -- Check the first entry in the table. If dim = 1, then all 34 | -- other entries should also be vectors. Otherwise, all other 35 | -- entries should be batches. 36 | local batch = X[1]:dim() > 1 37 | local M 38 | if batch then M = X[1]:size(1) end 39 | -- save 'batch' for gradients later 40 | self.batch = batch 41 | 42 | local N = 0 43 | -- Check to make sure that either all table entries are 44 | -- vectors OR that all table entries are M-batches of vectors. 
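	-- (e.g., with hypothetical shapes: joining {Tensor(3), Tensor(5)} yields a
	-- Tensor(8), and joining {Tensor(4,3), Tensor(4,5)} yields a Tensor(4,8).)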
45 | for i=1,#X do 46 | if batch then 47 | assert(X[i]:size(1) == M) 48 | splitSizes[i] = X[i]:size(2) 49 | N = N + X[i]:size(2) 50 | else 51 | assert(X[i]:dim() == 1) 52 | splitSizes[i] = X[i]:size(1) 53 | N = N + X[i]:size(1) 54 | end 55 | end 56 | -- save the splitSizes for gradients later 57 | self.splitSizes = splitSizes 58 | 59 | -- Make the output have the appropriate size and type 60 | local size 61 | if batch then 62 | size = {M,N} 63 | else 64 | size = {N} 65 | end 66 | self.output = torch.Tensor(unpack(size)):typeAs(X[1]) 67 | 68 | -- Build the output 69 | local ptr = 1 70 | for i=1,#X do 71 | if batch then 72 | self.output[{{},{ptr, ptr + splitSizes[i] - 1}}] = X[i] 73 | else 74 | self.output[{{ptr, ptr + splitSizes[i] - 1}}] = X[i] 75 | end 76 | ptr = ptr + splitSizes[i] 77 | end 78 | 79 | return self.output 80 | end 81 | 82 | function joinlayer:updateGradInput(input,gradOutput) 83 | 84 | self.gradInput = {} 85 | local ptr = 1 86 | for i=1,#self.splitSizes do 87 | if self.batch then 88 | self.gradInput[i] = gradOutput[{{},{ptr, ptr + self.splitSizes[i] - 1} }] 89 | else 90 | self.gradInput[i] = gradOutput[{{ptr, ptr + self.splitSizes[i] - 1} }] 91 | end 92 | ptr = ptr + self.splitSizes[i] 93 | end 94 | 95 | return self.gradInput 96 | end 97 | -------------------------------------------------------------------------------- /network/splitlayer.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Overview: 5 | This layer accepts a vector with N elements as an input, 6 | and splits that vector into k sub-vectors. 7 | 8 | Or, if the input is a batch of vectors, it splits each 9 | vector in the batch accordingly. 10 | 11 | This returns a table with k elements. If the input was a 12 | vector, then output[j] has n_j elements, where 13 | 14 | sum_{j=1}^k n_j = N. 15 | 16 | If the input was a batch of M N-vectors, then output[j] has 17 | M x n_j elements. 18 | 19 | ]] 20 | 21 | local splitlayer = torch.class('nn.SplitLayer','nn.Module') 22 | 23 | -- splitSizes should be a table of k elements, 24 | -- {n_1, n_2, ..., n_k}. We must have 25 | -- sum_i n_i = N. 
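-- e.g., with hypothetical sizes: nn.SplitLayer(8,{3,5}) maps a Tensor(8) to
-- {Tensor(3), Tensor(5)}, and an M x 8 batch to {Tensor(M,3), Tensor(M,5)}.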
26 | function splitlayer:__init(N,splitSizes) 27 | self.N = N 28 | self.splitSizes = splitSizes 29 | 30 | -- Check that splitsizes add up to N 31 | local m = 0 32 | for i=1,#self.splitSizes do m = m + self.splitSizes[i] end 33 | assert(m == self.N) 34 | self.output = {} 35 | self.gradInput = torch.Tensor() 36 | end 37 | 38 | function splitlayer:updateOutput(input) 39 | local x = input 40 | 41 | local ptr = 1 42 | 43 | for j=1,#self.splitSizes do 44 | if x:dim() == 1 then 45 | self.output[j] = x[{{ptr, ptr+self.splitSizes[j] - 1}}] 46 | else 47 | self.output[j] = x[{{},{ptr, ptr+self.splitSizes[j] - 1}}] 48 | end 49 | ptr = ptr + self.splitSizes[j] 50 | end 51 | 52 | return self.output 53 | end 54 | 55 | function splitlayer:updateGradInput(input,gradOutput) 56 | self.gradInput = torch.Tensor():typeAs(input):resizeAs(input):zero() 57 | 58 | local ptr = 1 59 | for i=1,#self.splitSizes do 60 | if input:dim() == 1 then 61 | self.gradInput[{{ptr, ptr+self.splitSizes[i] - 1}}] = gradOutput[i] 62 | else 63 | self.gradInput[{{},{ptr, ptr+self.splitSizes[i] - 1}}] = gradOutput[i] 64 | end 65 | ptr = ptr + self.splitSizes[i] 66 | end 67 | 68 | return self.gradInput 69 | end 70 | -------------------------------------------------------------------------------- /network/unbiased_linear.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local UnbiasedLinear, parent = torch.class('nn.UnbiasedLinear', 'nn.Module') 4 | 5 | function UnbiasedLinear:__init(inputSize, outputSize) 6 | parent.__init(self) 7 | 8 | self.weight = torch.Tensor(outputSize, inputSize) 9 | self.gradWeight = torch.Tensor(outputSize, inputSize) 10 | 11 | self:reset() 12 | end 13 | 14 | function UnbiasedLinear:reset(stdv) 15 | if stdv then 16 | stdv = stdv * math.sqrt(3) 17 | else 18 | stdv = 1./math.sqrt(self.weight:size(2)) 19 | end 20 | if nn.oldSeed then 21 | for i=1,self.weight:size(1) do 22 | self.weight:select(1, i):apply(function() 23 | return torch.uniform(-stdv, stdv) 24 | end) 25 | end 26 | else 27 | self.weight:uniform(-stdv, stdv) 28 | end 29 | 30 | return self 31 | end 32 | 33 | function UnbiasedLinear:updateOutput(input) 34 | if input:dim() == 1 then 35 | self.output:resize(self.weight:size(1)) 36 | self.output:zero() 37 | self.output:addmv(1, self.weight, input) 38 | elseif input:dim() == 2 then 39 | local nframe = input:size(1) 40 | self.output:resize(nframe, self.weight:size(1)) 41 | self.output:addmm(0, self.output, 1, input, self.weight:t()) 42 | else 43 | error('input must be vector or matrix') 44 | end 45 | 46 | return self.output 47 | end 48 | 49 | function UnbiasedLinear:updateGradInput(input, gradOutput) 50 | if self.gradInput then 51 | 52 | local nElement = self.gradInput:nElement() 53 | self.gradInput:resizeAs(input) 54 | if self.gradInput:nElement() ~= nElement then 55 | self.gradInput:zero() 56 | end 57 | if input:dim() == 1 then 58 | self.gradInput:addmv(0, 1, self.weight:t(), gradOutput) 59 | elseif input:dim() == 2 then 60 | self.gradInput:addmm(0, 1, gradOutput, self.weight) 61 | end 62 | 63 | return self.gradInput 64 | end 65 | end 66 | 67 | function UnbiasedLinear:accGradParameters(input, gradOutput, scale) 68 | scale = scale or 1 69 | if input:dim() == 1 then 70 | self.gradWeight:addr(scale, gradOutput, input) 71 | elseif input:dim() == 2 then 72 | self.gradWeight:addmm(scale, gradOutput:t(), input) 73 | end 74 | end 75 | 76 | -- we do not need to accumulate parameters when sharing 77 | UnbiasedLinear.sharedAccUpdateGradParameters = 
UnbiasedLinear.accUpdateGradParameters 78 | 79 | 80 | function UnbiasedLinear:__tostring__() 81 | return torch.type(self) .. 82 | string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) 83 | end 84 | -------------------------------------------------------------------------------- /rnn_playground.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'nngraph' 4 | require 'data_processing' 5 | local model_utils = require 'model_utils' 6 | package.path = package.path .. ";./network/?.lua" 7 | 8 | xrnn = {} 9 | 10 | --[[ Arguments: 11 | 12 | ------------------- 13 | -- VERBOSITY -- 14 | ------------------- 15 | report_freq -- reports performance during training this often 16 | conv_verbose -- if true, when conversing, if input symbol is unknown, declare so 17 | 18 | ------------------- 19 | -- COMPUTATIONAL -- 20 | ------------------- 21 | gpu 22 | batch_size 23 | seq_len 24 | collect_often 25 | noclones -- if true, gradients are calculated more slowly but much more memory-efficiently (default false) 26 | (useful for big networks on small graphics cards) 27 | 28 | ------------------- 29 | -- OPTIMIZATION -- 30 | ------------------- 31 | lr -- learning rate 32 | lambda -- RMSprop parameter 33 | gamma -- RMSprop parameter 34 | grad_clip -- boolean. clip gradients? yes or no. (currently always true) 35 | clip_to -- clip gradients to what? (default 5) 36 | 37 | training_mode -- if 1, randomly samples minibatches from anywhere in text corpus (default) 38 | if 2, proceeds linearly through text corpus 39 | 40 | ------------------- 41 | -- DATA -- 42 | ------------------- 43 | pattern -- lua pattern as string, or 'word' (experimental!). (defaults to '.', which matches all chars.) 44 | lower -- processes all text in lowercase (default false) 45 | rawdata -- if none given, rawdata will be obtained from input file 46 | usemostcommon -- use only the most common pattern-matching entries in rawdata as input symbols (default false) 47 | useNmostcommon -- how many? 48 | replace -- some common replacements for word mode (it's to it is, etc.) 49 | filename -- name of file containing training data rawtext (if not supplied, defaults to 'input.txt') 50 | 51 | ------------------- 52 | -- MODEL -- 53 | ------------------- 54 | layer_sizes -- table containing sizes of layers (default {128}) 55 | peepholes -- for LSTM networks (default false) 56 | nl_type -- for Vanilla RNN networks (default 'tanh') 57 | RNN_type -- string, determines RNN type. options are 'LSTM', 'GRU', 'Vanilla'. (default 'LSTM') 58 | 59 | ]] 60 | function init(args) 61 | 62 | -- Verbosity 63 | xrnn.report_freq = args.report_freq or 10 64 | xrnn.conv_verbose = args.conv_verbose or false 65 | 66 | -- Computational 67 | xrnn.gpu = args.gpu or -1 68 | if xrnn.gpu >= 0 then 69 | require 'cutorch' 70 | require 'cunn' 71 | end 72 | xrnn.batch_size = args.batch_size or 30 73 | xrnn.seq_len = args.seq_len or 50 74 | xrnn.collect_often = args.collect_often or false 75 | xrnn.noclones = args.noclones or false 76 | 77 | -- RMSprop and clipping gradients 78 | xrnn.lr = args.lr or 0.001 -- learning rate 79 | xrnn.lambda = args.lambda or 1e-8 80 | xrnn.gamma = args.gamma or 0.95 81 | xrnn.grad_clip = args.grad_clip or true 82 | xrnn.clip_to = args.clip_to or 5 83 | 84 | -- Training 85 | xrnn.training_mode = args.training_mode or 1 86 | 87 | -- Data 88 | xrnn.pattern = args.pattern or '.' 
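	-- (e.g., pattern='.' tokenizes by single characters, pattern='..' takes two
	-- characters at a time, and pattern='word' uses the experimental word
	-- tokenizer in data_processing.lua)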
89 | xrnn.lower = args.lower or false 90 | xrnn.usemostcommon = args.usemostcommon or false 91 | xrnn.useNmostcommon = args.useNmostcommon or 4500 92 | xrnn.replace = args.replace or false 93 | xrnn.wordopt = {usemostcommon = xrnn.usemostcommon, useNmostcommon = xrnn.useNmostcommon, replace = xrnn.replace} 94 | xrnn.filename = args.filename or 'input.txt' 95 | xrnn.rawdata = args.rawdata or data(xrnn.filename) 96 | print('Creating embedding/deembedding tables for characters in data...') 97 | xrnn.embed, xrnn.deembed, xrnn.numkeys, xrnn.numwords, xrnn.tokenized, xrnn.freq_data = data_processing(xrnn.rawdata,xrnn.pattern,xrnn.lower,xrnn.wordopt) 98 | xrnn.numkeys = xrnn.numkeys + 1 -- for unknown character 99 | print('Finished making embed/deembed tables.') 100 | print('Finished making embedded data.') 101 | print('Dictionary has this many keys in it: ',xrnn.numkeys) 102 | 103 | -- Input mode things 104 | xrnn.eye = torch.eye(xrnn.numkeys) 105 | xrnn.decoder = true 106 | xrnn.rnn_input_size = xrnn.numkeys 107 | 108 | -- Networks and parameters 109 | xrnn.layer_sizes = args.layer_sizes or {128} 110 | xrnn.peepholes = args.peepholes or false 111 | xrnn.nl_type = args.nl_type or 'tanh' 112 | xrnn.n_hidden = 0 113 | 114 | xrnn.RNN_type = args.RNN_type or 'LSTM' 115 | 116 | -- Interpret RNN_type options and compute length of hidden state vector 117 | if xrnn.RNN_type == 'LSTM' then 118 | print('Making LSTM...') 119 | for i=1,#xrnn.layer_sizes do xrnn.n_hidden = xrnn.n_hidden + 2*xrnn.layer_sizes[i] end 120 | require 'LSTM' 121 | elseif xrnn.RNN_type == 'GRU' then 122 | print('Making GRU...') 123 | for i=1,#xrnn.layer_sizes do xrnn.n_hidden = xrnn.n_hidden + xrnn.layer_sizes[i] end 124 | require 'GRU' 125 | elseif xrnn.RNN_type == 'Vanilla' then 126 | print('Making Vanilla RNN...') 127 | for i=1,#xrnn.layer_sizes do xrnn.n_hidden = xrnn.n_hidden + xrnn.layer_sizes[i] end 128 | require 'VanillaRNN' 129 | end 130 | 131 | -- Build RNN and make references to its parameters and gradient 132 | if args.RNN then 133 | xrnn.RNN = args.RNN 134 | elseif xrnn.RNN_type == 'LSTM' then 135 | xrnn.RNN = LSTM(xrnn.rnn_input_size,xrnn.layer_sizes,{peepholes=xrnn.peepholes,decoder=xrnn.decoder,fgate_init=true}) 136 | elseif xrnn.RNN_type == 'GRU' then 137 | xrnn.RNN = GRU(xrnn.rnn_input_size,xrnn.layer_sizes,{decoder=xrnn.decoder}) 138 | elseif xrnn.RNN_type == 'Vanilla' then 139 | xrnn.RNN = VanillaRNN(xrnn.rnn_input_size,xrnn.layer_sizes,{decoder=xrnn.decoder,nl_type=xrnn.nl_type}) 140 | end 141 | if xrnn.gpu >= 0 then xrnn.RNN:cuda() end 142 | xrnn.params, xrnn.gradparams = xrnn.RNN:getParameters() 143 | xrnn.v = xrnn.gradparams:clone():zero() 144 | print('RNN is done.') 145 | 146 | -- Make criterion 147 | print('Making criterion...') 148 | local criterion_input_1, criterion_input_2, criterion_out 149 | criterion_input_1 = nn.Identity()() 150 | criterion_input_2 = nn.Identity()() 151 | criterion_out = nn.ClassNLLCriterion()({nn.LogSoftMax()(criterion_input_1),criterion_input_2}) 152 | xrnn.criterion = nn.gModule({criterion_input_1,criterion_input_2},{criterion_out}) 153 | if xrnn.gpu >= 0 then xrnn.criterion:cuda() end 154 | print('Criterion is done.') 155 | 156 | -- Make RNN/criterion clones, if applicable 157 | if not(xrnn.noclones) then 158 | clones = {} 159 | print('Cloning RNN...') 160 | clones.RNN = model_utils.clone_many_times(xrnn.RNN,xrnn.seq_len) 161 | collectgarbage() 162 | print('Cloning criterion...') 163 | clones.criterion = model_utils.clone_many_times(xrnn.criterion,xrnn.seq_len) 164 | 
collectgarbage() 165 | print('Clones are done.') 166 | end 167 | 168 | if xrnn.gpu >=0 then 169 | free,tot = cutorch.getMemoryUsage() 170 | print('Free fraction of memory remaining: ', free/tot) 171 | end 172 | print('Number of trainable parameters: ', xrnn.params:nElement()) 173 | 174 | end 175 | 176 | 177 | -- TRAINING 178 | 179 | function grad_pass_with_clones() 180 | total_loss = 0 181 | for i=1,xrnn.seq_len do 182 | clones.RNN[i]:forward({H[i],X[i]}) 183 | if i < xrnn.seq_len then H[i+1] = clones.RNN[i].output[1] end 184 | loss = clones.criterion[i]:forward({clones.RNN[i].output[2],Y[i]}) 185 | total_loss = total_loss + loss[1] 186 | end 187 | if xrnn.collect_often then collectgarbage() end 188 | 189 | local gradH 190 | for i=xrnn.seq_len,1,-1 do 191 | clones.criterion[i]:backward({clones.RNN[i].output[2],Y[i]},{1}) 192 | if i < xrnn.seq_len then 193 | gradH = clones.RNN[i+1].gradInput[1] 194 | else 195 | gradH = torch.Tensor():typeAs(H[i]):resizeAs(H[i]):zero() 196 | end 197 | clones.RNN[i]:backward({H[i],X[i]},{gradH,clones.criterion[i].gradInput[1]}) 198 | if xrnn.collect_often then collectgarbage() end 199 | end 200 | end 201 | 202 | function grad_pass_no_clones() 203 | total_loss = 0 204 | 205 | outputs = {} 206 | -- fwd pass to get the outputs and hidden states 207 | for i=1,xrnn.seq_len do 208 | xrnn.RNN:forward({H[i],X[i]}) 209 | if i < xrnn.seq_len then H[i+1] = xrnn.RNN.output[1] end 210 | outputs[i] = xrnn.RNN.output[2] 211 | end 212 | if xrnn.collect_often then collectgarbage() end 213 | 214 | gradInputs = {} 215 | -- bwd pass 216 | for i=xrnn.seq_len,1,-1 do 217 | loss = xrnn.criterion:forward({outputs[i],Y[i]}) 218 | total_loss = total_loss + loss[1] 219 | xrnn.criterion:backward({outputs[i],Y[i]},{1}) 220 | xrnn.RNN:forward({H[i],X[i]}) 221 | if i < xrnn.seq_len then 222 | xrnn.RNN:backward({H[i],X[i]},{gradInputs[i], xrnn.criterion.gradInput[1]}) 223 | else 224 | xrnn.RNN:backward({H[i],X[i]},{torch.Tensor():typeAs(H[i]):resizeAs(H[i]):zero(),xrnn.criterion.gradInput[1]}) 225 | end 226 | gradInputs[i-1] = xrnn.RNN.gradInput[1] 227 | end 228 | if xrnn.collect_often then collectgarbage() end 229 | end 230 | 231 | -- First index is time slice 232 | -- Second is element-in-batch 233 | function minibatch_loader() 234 | local i 235 | local preX,postX 236 | local I = torch.Tensor(xrnn.seq_len,xrnn.batch_size):zero():long() 237 | local X = torch.Tensor(xrnn.seq_len,xrnn.batch_size,xrnn.rnn_input_size):zero() 238 | local Y = torch.Tensor(xrnn.seq_len,xrnn.batch_size):zero() 239 | local H = torch.Tensor(xrnn.seq_len,xrnn.batch_size,xrnn.n_hidden):zero() 240 | 241 | if xrnn.training_mode == 2 then 242 | if not(xrnn.pos_in_text) then 243 | xrnn.pos_in_text = 1 244 | end 245 | end 246 | 247 | for n=1,xrnn.batch_size do 248 | if xrnn.training_mode == 1 then 249 | i = torch.ceil(torch.uniform()*(xrnn.numwords - xrnn.seq_len)) 250 | else 251 | i = xrnn.pos_in_text 252 | end 253 | preX = xrnn.tokenized[{{i,i + xrnn.seq_len - 1}}]:long() 254 | postX = xrnn.eye:index(1,preX) 255 | I[{{},{n}}]:copy(preX) 256 | X[{{},{n}}]:copy(postX) 257 | Y[{{},{n}}] = xrnn.tokenized[{{i+1,i + xrnn.seq_len}}] 258 | if xrnn.training_mode == 2 then 259 | xrnn.pos_in_text = xrnn.pos_in_text + xrnn.seq_len 260 | if xrnn.pos_in_text > xrnn.numwords - xrnn.seq_len then 261 | xrnn.pos_in_text = 1 262 | end 263 | end 264 | end 265 | 266 | if xrnn.gpu >= 0 then 267 | X = X:float():cuda() 268 | Y = Y:float():cuda() 269 | H = H:float():cuda() 270 | end 271 | 272 | return X,Y,H,I 273 | end 274 | 275 | 276 | function 
train_network_one_step() 277 | X,Y,H,I = minibatch_loader() 278 | xrnn.gradparams:zero() 279 | 280 | if xrnn.noclones then 281 | grad_pass_no_clones() 282 | else 283 | grad_pass_with_clones() 284 | end 285 | 286 | -- Average over batch and sequence length 287 | xrnn.gradparams:div(xrnn.batch_size):div(xrnn.seq_len) 288 | 289 | if xrnn.grad_clip then 290 | xrnn.gradparams:clamp(-xrnn.clip_to, xrnn.clip_to) 291 | end 292 | -- RMSprop: 293 | local grad = xrnn.gradparams:clone() 294 | grad:pow(2):mul(1 - xrnn.gamma) 295 | xrnn.v:mul(xrnn.gamma):add(grad) 296 | xrnn.gradparams:cdiv(torch.sqrt(xrnn.v):add(xrnn.lambda)) 297 | xrnn.params:add(-xrnn.lr,xrnn.gradparams) 298 | 299 | collectgarbage() 300 | end 301 | 302 | function train_network_N_steps(N) 303 | running_total_loss = 0 304 | for n=1,N do 305 | train_network_one_step() 306 | if n==1 then init_error = total_loss/xrnn.seq_len end 307 | if total_loss/xrnn.seq_len > 3*init_error then 308 | print('Error is exploding. Current error: ', total_loss/xrnn.seq_len) 309 | print('Terminating training here.') 310 | break 311 | end 312 | running_total_loss = running_total_loss + total_loss/xrnn.seq_len 313 | if n % xrnn.report_freq == 0 then 314 | if xrnn.training_mode == 1 then 315 | print('Average Error: ',running_total_loss/xrnn.report_freq,'Num Steps: ',n) 316 | else 317 | print('Average Error: ',running_total_loss/xrnn.report_freq,'Num Steps: ',n,' % thru text: ', round(100*xrnn.pos_in_text/xrnn.numwords)) 318 | end 319 | running_total_loss = 0 320 | end 321 | end 322 | end 323 | 324 | function epochs(N) 325 | local steps_per_epoch = xrnn.numwords/xrnn.seq_len/xrnn.batch_size 326 | xrnn.training_mode = 2 327 | for k=1,N do 328 | train_network_N_steps(steps_per_epoch) 329 | end 330 | end 331 | 332 | function round(x) 333 | return math.floor(x * 1000)/1000 334 | end 335 | 336 | -- SAMPLING 337 | 338 | function tokenize_string(text) 339 | 340 | local splitstring = split_string(text,xrnn.pattern,xrnn.lower,xrnn.wordopt) 341 | local numtokens = #splitstring 342 | local tokenized = torch.zeros(numtokens) 343 | for i=1,numtokens do 344 | tokenized[i] = xrnn.embed[splitstring[i]] or xrnn.numkeys 345 | if tokenized[i] == xrnn.numkeys and xrnn.conv_verbose then 346 | print('Machine Subconscious: I did not recognize the word ' .. splitstring[i] .. 
'.')
			print()
		end
	end

	return tokenized:long(), numtokens
end

function string_to_rnn_input(text)
	local tokenized, numtokens = tokenize_string(text)
	local X = torch.zeros(numtokens,xrnn.rnn_input_size)
	X:copy(xrnn.eye:index(1,tokenized))
	return X
end

function sample_from_network(args)
	local X, xcur, hcur, pred, i, next_text, end_condition
	local length = args.length
	local chatmode = args.chatmode or false
	local toscreen = args.toscreen or not(chatmode)
	local primetext = args.primetext
	local sample_text = ''

	xcur = torch.Tensor(xrnn.rnn_input_size):zero()
	hcur = args.hcur or torch.Tensor(xrnn.n_hidden):zero()

	local softmax = nn.SoftMax()
	if xrnn.gpu >= 0 then
		xcur = xcur:float():cuda()
		hcur = hcur:float():cuda()
		softmax:cuda()
	end

	if not(primetext) then
		-- no primetext: start from a uniformly random symbol
		i = torch.ceil(torch.uniform()*xrnn.numkeys)
		xcur[i] = 1
	else
		-- feed the primetext through the network to build up the hidden state
		X = string_to_rnn_input(primetext)
		if xrnn.gpu >= 0 then
			X = X:cuda()
		end
		for n=1,X:size(1)-1 do
			xrnn.RNN:forward({hcur,X[n]})
			hcur = xrnn.RNN.output[1]
		end
		xcur = X[X:size(1)]
	end

	local function next_character(pred)
		local probs = softmax:forward(pred)
		if args.temperature then
			-- lower temperature sharpens the distribution
			local logprobs = torch.log(probs)
			logprobs:div(args.temperature)
			probs = torch.exp(logprobs)
		end
		probs:div(torch.sum(probs))
		local next_char = torch.multinomial(probs:float(), 1):resize(1):float()
		local i = next_char[1]
		local next_text = xrnn.deembed[i] or ''
		return i, next_text
	end

	local n = 1
	repeat
		xrnn.RNN:forward({hcur,xcur})
		hcur = xrnn.RNN.output[1]
		pred = xrnn.RNN.output[2]
		i, next_text = next_character(pred)
		xcur:zero()
		xcur[i] = 1
		if xrnn.pattern == 'word' then
			if not(i==xrnn.numkeys) then
				sample_text = sample_text .. ' ' .. next_text
			end
		else
			sample_text = sample_text .. next_text
		end
		if length then
			end_condition = (n==length)
			n = n + 1
		else
			end_condition = not(not(next_text:match('\n'))) -- output character includes linebreak
		end
	until end_condition

	if toscreen then
		print('Sample text: ',sample_text)
	end
	if chatmode then
		return sample_text, hcur
	else
		return sample_text
	end
end

function test_conv(temperature)
	local user_input
	local args
	local machine_output
	local hcur = torch.Tensor(xrnn.n_hidden):zero()
	repeat
		user_input = io.read()
		io.write('\n')
		io.flush()
		if not(user_input) then
			user_input = ' '
		end
		args = {hcur = hcur, chatmode=true, primetext = user_input .. '\n', temperature = temperature}
		--if length then args.length = length end
		machine_output, hcur = sample_from_network(args)
		io.write('Machine: ' .. machine_output ..
'\n') 456 | io.flush() 457 | until user_input=="quit" 458 | end 459 | 460 | -- SAVING AND LOADING AND CHECKING MEMORY 461 | 462 | function memory() 463 | free,tot = cutorch.getMemoryUsage() 464 | print(free/tot) 465 | end 466 | 467 | function load(filename,gpu) 468 | -- check some requirements before loading 469 | local gpu = gpu or -1 470 | if gpu>=0 then 471 | require 'cutorch' 472 | require 'cunn' 473 | end 474 | require 'LSTM' 475 | require 'GRU' 476 | require 'VanillaRNN' 477 | 478 | T = torch.load(filename) 479 | xrnn = T.xrnn 480 | init(xrnn) 481 | end 482 | 483 | function loadFromOptions(filename) 484 | T = torch.load(filename) 485 | xrnn = T.xrnn 486 | xrnn.RNN = nil 487 | end 488 | 489 | function save(filename) 490 | torch.save(filename .. '.t7',{xrnn = xrnn}) 491 | end 492 | --------------------------------------------------------------------------------