├── Readme.md
├── cv
│   └── lm_lstm_epoch30.00_1.3904.t7
├── data
│   └── tinyshakespeare
│       ├── data.t7
│       ├── input.txt
│       └── vocab.t7
├── inspect_checkpoint.lua
├── model
│   ├── GRU.lua
│   ├── LSTM.lua
│   └── RNN.lua
├── sample.lua
├── server.py
├── templates
│   └── main.html
├── train.lua
└── util
    ├── CharSplitLMMinibatchLoader.lua
    ├── OneHot.lua
    ├── misc.lua
    └── model_utils.lua

--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # char-rnn
2 | A multi-layer Recurrent Neural Network (RNN, LSTM, and GRU) for training/sampling from character-level language models. The input is a single text file and the model learns to predict the next character in the sequence. More info in the [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) and on the [project page](http://cs.stanford.edu/people/karpathy/char-rnn/). Created by [@karpathy](https://twitter.com/karpathy).
3 | 
4 | # char-rnn-API
5 | An API and web frontend for char-rnn, running on Python/Flask.
6 | Hoping to see many public char-rnn micro-APIs with different models spring up, so we can experiment together more easily. Created by [@samim](https://twitter.com/samim).
7 | 
8 | ![char-rnn-api](https://i.imgur.com/xXY4Jqo.png "char-rnn-api")
9 | 
10 | # instructions
11 | - install torch: http://torch.ch/docs/getting-started.html
12 | - install the `nngraph` and `optim` packages: `luarocks install nngraph` and `luarocks install optim`
13 | - install flask: http://flask.pocoo.org/docs/0.10/installation/
14 | - install flask-cors: `pip install -U flask-cors`
15 | - `git clone https://github.com/samim23/char-rnn-api`
16 | - `python server.py`
17 | - go to http://thisserver.com:8080
18 | 
19 | # API calls
20 | POST a JSON request to: http://thisserver.com/api/v1.0
21 | {"primetext":"mytext", "temperature":"1", "length":"2000", "gpuid":"-1", "model":"model.t7", "seed":"123", "sample":"1"}
22 | 
23 | 
24 | # char-rnn
25 | 
26 | This code implements a **multi-layer Recurrent Neural Network** (RNN, LSTM, and GRU) for training/sampling from character-level language models. The input is a single text file and the model learns to predict the next character in the sequence.
27 | 
28 | The context of this code base is described in detail in my [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).
29 | 
30 | There is also a [project page](http://cs.stanford.edu/people/karpathy/char-rnn/) that has some pointers and datasets.
31 | 
32 | This code is based on Oxford University Machine Learning class [practical 6](https://github.com/oxford-cs-ml-2015/practical6), which is in turn based on [learning to execute](https://github.com/wojciechz/learning_to_execute) code from Wojciech Zaremba. Chunks of it were also developed in collaboration with my labmate [Justin Johnson](https://github.com/jcjohnson/).
33 | 
34 | ## Requirements
35 | 
36 | This code is written in Lua and requires [Torch](http://torch.ch/).
37 | Additionally, you need to install the `nngraph` and `optim` packages using [LuaRocks](https://luarocks.org/), which you will be able to do after installing Torch:
38 | 
39 | ```bash
40 | $ luarocks install nngraph
41 | $ luarocks install optim
42 | ```
43 | 
44 | ## Usage
45 | 
46 | 
47 | ### Data
48 | 
49 | All input data is stored inside the `data/` directory. You'll notice that there is an example dataset included in the repo (in the folder `data/tinyshakespeare`) which consists of a subset of the works of Shakespeare. I'm providing a few more datasets on the [project page](http://cs.stanford.edu/people/karpathy/char-rnn/).
50 | 
51 | **Your own data**: If you'd like to use your own data, create a single file `input.txt` and place it into a folder in `data/`. For example, `data/some_folder/input.txt`. The first time you run the training script it will write two more convenience files into `data/some_folder`. **Note**: If you change the file `input.txt` in place, you currently must delete the two intermediate files manually to force the preprocessing to re-run.
52 | 
53 | Note that if your data is too small (1MB is already considered very small) the RNN won't learn very effectively. Remember that it has to learn everything completely from scratch.
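If your corpus is spread across several files, a short script can concatenate them into the single `input.txt` the loader expects. A minimal sketch (the `raw_texts/*.txt` glob is a placeholder, not part of this repo):

```python
# sketch: merge several text files into the single input.txt that train.lua expects
# (data/some_folder and raw_texts/*.txt are placeholders; adapt to your corpus)
import glob

with open("data/some_folder/input.txt", "w") as out:
    for path in sorted(glob.glob("raw_texts/*.txt")):
        with open(path) as f:
            out.write(f.read())
```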
54 | 
55 | ### Training
56 | 
57 | Start training the model using `train.lua`, for example:
58 | 
59 | ```
60 | $ th train.lua -data_dir data/some_folder -gpuid -1
61 | ```
62 | 
63 | The `-data_dir` flag is the most important one, since it specifies the dataset to use. Notice that in this example we're also setting `-gpuid` to -1, which tells the code to train on the CPU; otherwise it defaults to GPU 0. There are many other flags for various options. Consult `$ th train.lua -help` for comprehensive settings. Here's another example:
64 | 
65 | ```
66 | $ th train.lua -data_dir data/some_folder -rnn_size 512 -num_layers 2 -dropout 0.5
67 | ```
68 | 
69 | While the model is training it will periodically write checkpoint files to the `cv` folder. The frequency of these checkpoints is controlled in iterations by the `eval_val_every` option (e.g. if this is 1 then a checkpoint is written every iteration).
70 | 
71 | We can use these checkpoints to generate text (discussed next).
72 | 
73 | ### Sampling
74 | 
75 | Given a checkpoint file (such as those written to `cv`) we can generate new text. For example:
76 | 
77 | ```
78 | $ th sample.lua cv/some_checkpoint.t7 -gpuid -1
79 | ```
80 | 
81 | Make sure that if your checkpoint was trained with a GPU it is also sampled from with a GPU, and vice versa. Otherwise the code will (currently) complain. As with the train script, see `$ th sample.lua -help` for the full options. One important one is (for example) `-length 10000`, which would generate 10,000 characters (default = 2000).
82 | 
83 | **Temperature**. An important parameter you may want to play with a lot is `-temperature`, which takes a number in the range (0, 1] (notice 0 is not included), default = 1. The temperature divides the predicted log probabilities before the Softmax, so a lower temperature will cause the model to make more likely, but also more boring and conservative predictions. Higher temperatures cause the model to take more chances and increase the diversity of the results, but at the cost of more mistakes.
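To make the effect concrete, here is a small self-contained sketch of temperature scaling (plain Python illustration, not code from this repo): dividing the log probabilities by a temperature below 1 sharpens the distribution toward the most likely character.

```python
# sketch: how temperature reshapes a next-character distribution
# (standalone illustration, not code from this repo)
import math

def apply_temperature(log_probs, temperature):
    scaled = [lp / temperature for lp in log_probs]      # divide before the softmax
    exps = [math.exp(s - max(scaled)) for s in scaled]   # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]

log_probs = [math.log(p) for p in [0.6, 0.3, 0.1]]
print(apply_temperature(log_probs, 1.0))   # unchanged: [0.6, 0.3, 0.1]
print(apply_temperature(log_probs, 0.5))   # sharper:   ~[0.78, 0.20, 0.02]
```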
84 | 
85 | **Priming**. It's also possible to prime the model with some starting text using `-primetext`.
86 | 
87 | Happy sampling!
88 | 
89 | ## Tips and Tricks
90 | 
91 | ### Monitoring Validation Loss vs. Training Loss
92 | If you're somewhat new to Machine Learning or Neural Networks it can take a bit of expertise to get good models. The most important quantity to keep track of is the difference between your training loss (printed during training) and the validation loss (printed once in a while when the RNN is run on the validation data, by default every 1000 iterations). In particular:
93 | 
94 | - If your training loss is much lower than your validation loss then the network is **overfitting**. Solutions to this are to decrease your network size, or to increase dropout. For example you could try a dropout of 0.5, and so on.
95 | - If your training/validation losses are about equal then your model is **underfitting**. Increase the size of your model (either the number of layers or the raw number of neurons per layer).
96 | 
97 | ### Approximate number of parameters
98 | 
99 | The two most important parameters that control the model are `rnn_size` and `num_layers`. I would advise that you always use a `num_layers` of about 3. The `rnn_size` can be adjusted based on how much data you have. The two important quantities to keep track of here are:
100 | 
101 | - The number of parameters in your model. This is printed when you start training.
102 | - The size of your dataset. A 1MB file is approximately 1 million characters.
103 | 
104 | These two should be about the same order of magnitude. It's a little tricky to tell. Here are some examples:
105 | 
106 | - I have a 100MB dataset and I'm using the default parameter settings (which currently print 150K parameters). My data size is significantly larger (100 mil >> 0.15 mil), so I expect to heavily underfit. I am thinking I can comfortably afford to make `rnn_size` larger.
107 | - I have a 10MB dataset and am running a 10 million parameter model. I'm slightly nervous and I'm carefully monitoring my validation loss. If it's larger than my training loss then I may want to increase dropout a bit.
108 | 
109 | ### Best models strategy
110 | 
111 | The winning strategy for obtaining very good models (if you have the compute time) is to always err on the side of making the network larger (as large as you're willing to wait for it to compute) and then try different dropout values (between 0 and 1). Whatever model has the best validation performance (the loss, written in the checkpoint filename; low is good) is the one you should use in the end.
112 | 
113 | It is very common in deep learning to run many different models with many different hyperparameter settings, and in the end take whatever checkpoint gave the best validation performance.
114 | 
115 | By the way, the sizes of your training and validation splits are also parameters. Make sure you have a decent amount of data in your validation set, otherwise the validation performance will be noisy and not very informative.
116 | 
117 | ## License
118 | 
119 | MIT
120 | 
121 | 
122 | ## Datasets
123 | - text from https://cs.stanford.edu/people/karpathy/char-rnn/
124 | 
125 | A cleaner version of this page is coming soon, but for now some fun datasets:
126 | 
127 | - Linux Kernel (6.2MB)
128 | https://cs.stanford.edu/people/karpathy/char-rnn/linux_input.txt
129 | 
130 | The above is only the kernel. The examples in my blog post were trained on the full Linux code base. That is:
131 | 
132 |     $ git clone https://github.com/torvalds/linux.git
133 |     $ cd linux
134 |     $ find . -name "*.[ch]" | shuf | xargs cat > linux.txt
135 | 
136 | (This gives a 474MB file that I plugged in.)
137 | 
138 | - All works of Shakespeare concatenated (4.6MB)
139 | https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt
140 | 
141 | - Leo Tolstoy's War and Peace (3.3MB)
142 | https://cs.stanford.edu/people/karpathy/char-rnn/warpeace_input.txt
143 | 
144 | - Free books (in general, including War and Peace) can be found at https://www.gutenberg.org/.
145 | 
146 | - Wikipedia: 100MB of Wikipedia data from the Hutter Prize: http://prize.hutter1.net/
147 | 
148 | - The Stacks Project: http://stacks.math.columbia.edu/, which is where the LaTeX dataset on Algebraic Geometry came from.
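For completeness, a minimal client sketch for the API described above. It assumes `requests` is installed and the server from `server.py` is reachable locally on port 8080; the field values are the documented defaults and the model name is the checkpoint shipped in `cv/`:

```python
# sketch: POST a sampling request to the char-rnn API (see server.py)
# assumes the server is running locally on port 8080 and `requests` is installed
import requests

payload = {
    "primetext": "mytext", "temperature": "1", "length": "2000",
    "gpuid": "-1", "model": "lm_lstm_epoch30.00_1.3904.t7",
    "seed": "123", "sample": "1",
}
r = requests.post("http://localhost:8080/api/v1.0", json=payload)
print(r.json()["responds"])  # the server wraps the sampled text in {"responds": ...}
```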
153 | 154 | -------------------------------------------------------------------------------- /cv/lm_lstm_epoch30.00_1.3904.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samim23/char-rnn-api/a81d3894f59e2fc4dea069efc0eb418145d80d35/cv/lm_lstm_epoch30.00_1.3904.t7 -------------------------------------------------------------------------------- /data/tinyshakespeare/vocab.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samim23/char-rnn-api/a81d3894f59e2fc4dea069efc0eb418145d80d35/data/tinyshakespeare/vocab.t7 -------------------------------------------------------------------------------- /inspect_checkpoint.lua: -------------------------------------------------------------------------------- 1 | -- simple script that loads a checkpoint and prints its opts 2 | 3 | require 'torch' 4 | require 'nn' 5 | require 'nngraph' 6 | require 'cutorch' 7 | require 'cunn' 8 | 9 | require 'util.OneHot' 10 | require 'util.misc' 11 | 12 | cmd = torch.CmdLine() 13 | cmd:text() 14 | cmd:text('Load a checkpoint and print its options and validation losses.') 15 | cmd:text() 16 | cmd:text('Options') 17 | cmd:argument('-model','model to load') 18 | cmd:option('-gpuid',0,'gpu to use') 19 | cmd:text() 20 | 21 | -- parse input params 22 | opt = cmd:parse(arg) 23 | 24 | print('using CUDA on GPU ' .. opt.gpuid .. '...') 25 | require 'cutorch' 26 | require 'cunn' 27 | cutorch.setDevice(opt.gpuid + 1) 28 | 29 | local model = torch.load(opt.model) 30 | 31 | print('opt:') 32 | print(model.opt) 33 | print('val losses:') 34 | print(model.val_losses) 35 | 36 | -------------------------------------------------------------------------------- /model/GRU.lua: -------------------------------------------------------------------------------- 1 | 2 | local GRU = {} 3 | 4 | --[[ 5 | Creates one timestep of one GRU 6 | Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf 7 | ]]-- 8 | function GRU.gru(input_size, rnn_size, n) 9 | 10 | -- there are n+1 inputs (hiddens on each layer and x) 11 | local inputs = {} 12 | table.insert(inputs, nn.Identity()()) -- x 13 | for L = 1,n do 14 | table.insert(inputs, nn.Identity()()) -- prev_h[L] 15 | end 16 | 17 | function new_input_sum(insize, xv, hv) 18 | local i2h = nn.Linear(insize, rnn_size)(xv) 19 | local h2h = nn.Linear(rnn_size, rnn_size)(hv) 20 | return nn.CAddTable()({i2h, h2h}) 21 | end 22 | 23 | local x, input_size_L 24 | local outputs = {} 25 | for L = 1,n do 26 | 27 | local prev_h = inputs[L+1] 28 | if L == 1 then x = inputs[1] else x = outputs[L-1] end 29 | if L == 1 then input_size_L = input_size else input_size_L = rnn_size end 30 | 31 | -- GRU tick 32 | -- forward the update and reset gates 33 | local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h)) 34 | local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h)) 35 | -- compute candidate hidden state 36 | local gated_hidden = nn.CMulTable()({reset_gate, prev_h}) 37 | local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden) 38 | local p1 = nn.Linear(input_size_L, rnn_size)(x) 39 | local hidden_candidate = nn.Tanh()(nn.CAddTable()({p1,p2})) 40 | -- compute new interpolated hidden state, based on the update gate 41 | local zh = nn.CMulTable()({update_gate, hidden_candidate}) 42 | local zhm1 = nn.CMulTable()({nn.AddConstant(1,false)(nn.MulConstant(-1,false)(update_gate)), prev_h}) 43 | local next_h = nn.CAddTable()({zh, zhm1}) 44 | 45 | table.insert(outputs, next_h) 
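-- note: each call to new_input_sum above instantiates its own nn.Linear modules,
-- so the update and reset gates have separate weights despite identical-looking code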
46 |     end
47 | 
48 |     return nn.gModule(inputs, outputs)
49 | end
50 | 
51 | return GRU
52 | 

--------------------------------------------------------------------------------
/model/LSTM.lua:
--------------------------------------------------------------------------------
1 | 
2 | local LSTM = {}
3 | function LSTM.lstm(input_size, rnn_size, n, dropout)
4 |     dropout = dropout or 0
5 | 
6 |     -- there will be 2*n+1 inputs
7 |     local inputs = {}
8 |     table.insert(inputs, nn.Identity()()) -- x
9 |     for L = 1,n do
10 |         table.insert(inputs, nn.Identity()()) -- prev_c[L]
11 |         table.insert(inputs, nn.Identity()()) -- prev_h[L]
12 |     end
13 | 
14 |     local x, input_size_L
15 |     local outputs = {}
16 |     for L = 1,n do
17 |         -- c,h from previous timesteps
18 |         local prev_h = inputs[L*2+1]
19 |         local prev_c = inputs[L*2]
20 |         -- the input to this layer
21 |         if L == 1 then x = inputs[1] else x = outputs[(L-1)*2] end
22 |         if L == 1 then input_size_L = input_size else input_size_L = rnn_size end
23 |         -- evaluate the input sums at once for efficiency
24 |         local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
25 |         local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
26 |         local all_input_sums = nn.CAddTable()({i2h, h2h})
27 |         -- decode the gates
28 |         local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
29 |         sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
30 |         local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
31 |         local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
32 |         local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
33 |         -- decode the write inputs
34 |         local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
35 |         in_transform = nn.Tanh()(in_transform)
36 |         -- perform the LSTM update
37 |         local next_c = nn.CAddTable()({
38 |             nn.CMulTable()({forget_gate, prev_c}),
39 |             nn.CMulTable()({in_gate, in_transform})
40 |         })
41 |         -- gated cells form the output
42 |         local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
43 |         -- add dropout to output, if desired
44 |         if dropout > 0 then next_h = nn.Dropout(dropout)(next_h) end
45 | 
46 |         table.insert(outputs, next_c)
47 |         table.insert(outputs, next_h)
48 |     end
49 | 
50 |     return nn.gModule(inputs, outputs)
51 | end
52 | 
53 | return LSTM

--------------------------------------------------------------------------------
/model/RNN.lua:
--------------------------------------------------------------------------------
1 | local RNN = {}
2 | 
3 | function RNN.rnn(input_size, rnn_size, n)
4 | 
5 |     -- there are n+1 inputs (hiddens on each layer and x)
6 |     local inputs = {}
7 |     table.insert(inputs, nn.Identity()()) -- x
8 |     for L = 1,n do
9 |         table.insert(inputs, nn.Identity()()) -- prev_h[L]
10 |     end
11 | 
12 |     local x, input_size_L
13 |     local outputs = {}
14 |     for L = 1,n do
15 | 
16 |         local prev_h = inputs[L+1]
17 |         if L == 1 then x = inputs[1] else x = outputs[L-1] end
18 |         if L == 1 then input_size_L = input_size else input_size_L = rnn_size end
19 | 
20 |         -- RNN tick
21 |         local i2h = nn.Linear(input_size_L, rnn_size)(x)
22 |         local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
23 |         local next_h = nn.Tanh()(nn.CAddTable(){i2h, h2h})
24 | 
25 |         table.insert(outputs, next_h)
26 |     end
27 | 
28 |     return nn.gModule(inputs, outputs)
29 | end
30 | 
31 | return RNN

--------------------------------------------------------------------------------
/sample.lua:
--------------------------------------------------------------------------------
1 | 
2 | --[[
3 | 
4 | This file samples characters from a trained model
5 | 
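Example usage (flags are documented in the README; the checkpoint is the one shipped in cv/):
    th sample.lua cv/lm_lstm_epoch30.00_1.3904.t7 -gpuid -1 -length 500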
6 | Code is based on implementation in
7 | https://github.com/oxford-cs-ml-2015/practical6
8 | 
9 | ]]--
10 | 
11 | require 'torch'
12 | require 'nn'
13 | require 'nngraph'
14 | require 'optim'
15 | require 'lfs'
16 | 
17 | require 'util.OneHot'
18 | require 'util.misc'
19 | 
20 | cmd = torch.CmdLine()
21 | cmd:text()
22 | cmd:text('Sample from a character-level language model')
23 | cmd:text()
24 | cmd:text('Options')
25 | -- required:
26 | cmd:argument('-model','model checkpoint to use for sampling')
27 | -- optional parameters
28 | cmd:option('-seed',123,'random number generator\'s seed')
29 | cmd:option('-sample',1,'0 to use max at each timestep, 1 to sample at each timestep')
30 | cmd:option('-primetext'," ",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
31 | cmd:option('-length',2000,'number of characters to sample')
32 | cmd:option('-temperature',1,'temperature of sampling')
33 | cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
34 | cmd:text()
35 | 
36 | -- parse input params
37 | opt = cmd:parse(arg)
38 | 
39 | if opt.gpuid >= 0 then
40 |     print('using CUDA on GPU ' .. opt.gpuid .. '...')
41 |     require 'cutorch'
42 |     require 'cunn'
43 |     cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
44 | end
45 | torch.manualSeed(opt.seed)
46 | 
47 | -- load the model checkpoint
48 | if not lfs.attributes(opt.model, 'mode') then
49 |     print('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?')
50 |     os.exit(1) -- abort here rather than crashing in torch.load below
51 | end
52 | checkpoint = torch.load(opt.model)
53 | 
54 | 
55 | local vocab = checkpoint.vocab
56 | local ivocab = {}
57 | for c,i in pairs(vocab) do ivocab[i] = c end
58 | 
59 | protos = checkpoint.protos
60 | local rnn_idx = #protos.softmax.modules - 1
61 | opt.rnn_size = protos.softmax.modules[rnn_idx].weight:size(2)
62 | 
63 | -- initialize the rnn state
64 | local current_state, state_predict_index
65 | local model = checkpoint.opt.model
66 | 
67 | print('creating an LSTM...')
68 | local num_layers = checkpoint.opt.num_layers or 1 -- or 1 is for backward compatibility
69 | current_state = {}
70 | for L=1,num_layers do -- use the backward-compatible value computed above
71 |     -- c and h for all layers
72 |     local h_init = torch.zeros(1, opt.rnn_size)
73 |     if opt.gpuid >= 0 then h_init = h_init:cuda() end
74 |     table.insert(current_state, h_init:clone())
75 |     table.insert(current_state, h_init:clone())
76 | end
77 | state_predict_index = #current_state -- last one is the top h
78 | local seed_text = opt.primetext
79 | local prev_char
80 | 
81 | protos.rnn:evaluate() -- put in eval mode so that dropout is properly disabled
82 | 
83 | -- do a few seeded timesteps
84 | print('seeding with ' .. seed_text)
85 | for c in seed_text:gmatch'.' do
86 |     prev_char = torch.Tensor{vocab[c]}
87 |     if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
88 |     local embedding = protos.embed:forward(prev_char)
89 |     current_state = protos.rnn:forward{embedding, unpack(current_state)}
90 |     if type(current_state) ~= 'table' then current_state = {current_state} end
91 | end
92 | 
93 | -- start sampling/argmaxing
94 | for i=1, opt.length do
95 | 
96 |     -- softmax from previous timestep
97 |     local next_h = current_state[state_predict_index]
98 |     next_h = next_h / opt.temperature
99 |     local log_probs = protos.softmax:forward(next_h)
100 | 
101 |     if opt.sample == 0 then
102 |         -- use argmax
103 |         local _, prev_char_ = log_probs:max(2)
104 |         prev_char = prev_char_:resize(1)
105 |     else
106 |         -- use sampling
107 |         local probs = torch.exp(log_probs):squeeze()
108 |         prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
109 |     end
110 | 
111 |     -- forward the rnn for next character
112 |     local embedding = protos.embed:forward(prev_char)
113 |     current_state = protos.rnn:forward{embedding, unpack(current_state)}
114 |     if type(current_state) ~= 'table' then current_state = {current_state} end
115 | 
116 |     io.write(ivocab[prev_char[1]])
117 | end
118 | io.write('\n') io.flush()
119 | 

--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
1 | from flask import Flask
2 | from flask import jsonify, render_template, redirect, url_for, request, abort
3 | from flask_cors import CORS, cross_origin  # 'flask.ext.*' imports were removed in modern Flask
4 | import json
5 | import subprocess
6 | 
7 | app = Flask(__name__)
8 | CORS(app)  # previously imported but never applied; enable CORS for the API routes
9 | 
10 | modelsDirectory = 'cv'
11 | 
12 | @app.route('/')
13 | def index():
14 |     return render_template('main.html')
15 | 
16 | @app.route('/api/v1.0', methods=['POST'])
17 | def api_v1():
18 |     if not request.json or not 'primetext' in request.json:
19 |         abort(400)
20 | 
21 |     primetext = request.json['primetext']
22 |     temperature = request.json['temperature']
23 |     length = request.json['length']
24 |     model = request.json['model']
25 |     seed = request.json['seed']
26 |     sample = request.json['sample']
27 |     gpuid = request.json['gpuid']
28 |     # override for public APIs
29 |     gpuid = '-1'
30 | 
31 |     # NOTE: interpolating user input into a shell string is open to command
32 |     # injection; see the hardened sketch after this file.
33 |     searchstring = 'th ../char-rnn/sample.lua ../char-rnn/'+modelsDirectory+'/' + str(model)
34 |     searchstring += ' -gpuid ' + str(gpuid)
35 |     searchstring += ' -primetext "' + str(primetext) + '"'
36 |     searchstring += ' -temperature ' + str(temperature)
37 |     searchstring += ' -length ' + str(length)
38 |     searchstring += ' -seed ' + str(seed)
39 |     searchstring += ' -sample ' + str(sample)
40 | 
41 |     responds = subprocess.Popen(searchstring, shell=True, stdout=subprocess.PIPE).stdout.read()
42 | 
43 |     # strip the three lines of console stats that sample.lua prints before the text
44 |     responds = responds.split('\n', 1)[1].split('\n', 1)[1].split('\n', 1)[1]
45 | 
46 |     return jsonify({'responds': responds}), 201
47 | 
48 | @app.route('/api/v1.0/model', methods=['POST'])
49 | def api_v1_model():
50 |     searchstring = '(cd ../char-rnn/cv/ && ls -t)'
51 |     responds = subprocess.Popen(searchstring, shell=True, stdout=subprocess.PIPE).stdout.read()
52 |     responds = responds.splitlines()
53 |     return jsonify({'models': responds}), 201
54 | 
55 | if __name__ == "__main__":
56 |     app.run(host='0.0.0.0', port=8080)
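As flagged in the comment above, a safer way to launch the sampler is to pass an argument list so no shell ever parses user input. A sketch under the same paths and flags as server.py (`run_sampler` is a hypothetical helper, not part of the repo; the `model` value should additionally be validated against the files in `cv/`):

```python
# sketch: the same sample.lua invocation as server.py, but with an argument list
# and no shell, so user-supplied fields are never parsed as shell syntax
import subprocess

def run_sampler(model, primetext, temperature, length, seed, sample, gpuid="-1"):
    cmd = [
        "th", "../char-rnn/sample.lua", "../char-rnn/cv/" + str(model),
        "-gpuid", str(gpuid), "-primetext", str(primetext),
        "-temperature", str(temperature), "-length", str(length),
        "-seed", str(seed), "-sample", str(sample),
    ]
    return subprocess.check_output(cmd).decode("utf-8")
```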
--------------------------------------------------------------------------------
/templates/main.html:
--------------------------------------------------------------------------------
[markup lost in extraction; the visible page text is preserved below]

char-rnn API

A multi-layer Recurrent Neural Network (RNN, LSTM, and GRU) for training/sampling from character-level language models. Created by @karpathy. The input is a single text file and the model learns to predict the next character in the sequence. More info here and here. API by @samim.

API calls
Post json request to: http://thisserver.com/api/v1.0
{"primetext":"mytext", "temperature":"1", "length":"2000", "gpuid":"-1", "model":"model.t7","seed":"123", "sample":"1" }

API query
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | 
2 | --[[
3 | 
4 | This file trains a character-level multi-layer RNN on text data
5 | 
6 | Code is based on implementation in
7 | https://github.com/oxford-cs-ml-2015/practical6
8 | but modified to have multi-layer support, GPU support, as well as
9 | many other common model/optimization bells and whistles.
10 | The practical6 code is in turn based on
11 | https://github.com/wojciechz/learning_to_execute
12 | which is in turn based on other stuff in Torch, etc... (long lineage)
13 | 
14 | ]]--
15 | 
16 | require 'torch'
17 | require 'nn'
18 | require 'nngraph'
19 | require 'optim'
20 | require 'lfs'
21 | 
22 | require 'util.OneHot'
23 | require 'util.misc'
24 | local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
25 | local model_utils = require 'util.model_utils'
26 | local LSTM = require 'model.LSTM'
27 | 
28 | cmd = torch.CmdLine()
29 | cmd:text()
30 | cmd:text('Train a character-level language model')
31 | cmd:text()
32 | cmd:text('Options')
33 | -- data
34 | cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
35 | -- model params
36 | cmd:option('-rnn_size', 100, 'size of LSTM internal state')
37 | cmd:option('-num_layers', 2, 'number of layers in the LSTM')
38 | cmd:option('-model', 'lstm', 'for now only lstm is supported. keep fixed')
39 | -- optimization
40 | cmd:option('-learning_rate',2e-3,'learning rate')
41 | cmd:option('-decay_rate',0.95,'decay rate for rmsprop')
42 | cmd:option('-dropout',0,'dropout to use just before classifier. 0 = no dropout')
43 | cmd:option('-seq_length',50,'number of timesteps to unroll for')
44 | cmd:option('-batch_size',100,'number of sequences to train on in parallel')
45 | cmd:option('-max_epochs',30,'number of full passes through the training data')
46 | cmd:option('-grad_clip',5,'clip gradients at this value')
47 | cmd:option('-train_frac',0.95,'fraction of data that goes into train set')
48 | cmd:option('-val_frac',0.05,'fraction of data that goes into validation set')
49 | -- note: test_frac will be computed as (1 - train_frac - val_frac)
50 | -- bookkeeping
51 | cmd:option('-seed',123,'torch manual random number generator seed')
52 | cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
53 | cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
54 | cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
55 | cmd:option('-savefile','lstm','filename to autosave the checkpoint to. Will be inside checkpoint_dir/')
56 | -- GPU/CPU
57 | cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
58 | cmd:text()
59 | 
60 | -- parse input params
61 | opt = cmd:parse(arg)
62 | torch.manualSeed(opt.seed)
63 | -- train / val / test split for data, in fractions
64 | local test_frac = math.max(0, 1 - opt.train_frac - opt.val_frac)
65 | local split_sizes = {opt.train_frac, opt.val_frac, test_frac}
66 | 
67 | if opt.gpuid >= 0 then
68 |     print('using CUDA on GPU ' .. opt.gpuid .. '...')
69 |     require 'cutorch'
70 |     require 'cunn'
71 |     cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed!
sigh lua 72 | end 73 | -- create the data loader class 74 | local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes) 75 | local vocab_size = loader.vocab_size -- the number of distinct characters 76 | print('vocab size: ' .. vocab_size) 77 | -- make sure output directory exists 78 | if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end 79 | 80 | -- define the model: prototypes for one timestep, then clone them in time 81 | protos = {} 82 | protos.embed = OneHot(vocab_size) 83 | print('creating an LSTM with ' .. opt.num_layers .. ' layers') 84 | protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout) 85 | -- the initial state of the cell/hidden states 86 | init_state = {} 87 | for L=1,opt.num_layers do 88 | local h_init = torch.zeros(opt.batch_size, opt.rnn_size) 89 | if opt.gpuid >=0 then h_init = h_init:cuda() end 90 | table.insert(init_state, h_init:clone()) 91 | table.insert(init_state, h_init:clone()) 92 | end 93 | state_predict_index = #init_state -- index of blob to make prediction from 94 | -- classifier on top 95 | protos.softmax = nn.Sequential():add(nn.Linear(opt.rnn_size, vocab_size)):add(nn.LogSoftMax()) 96 | -- training criterion (negative log likelihood) 97 | protos.criterion = nn.ClassNLLCriterion() 98 | 99 | -- ship the model to the GPU if desired 100 | if opt.gpuid >= 0 then 101 | for k,v in pairs(protos) do v:cuda() end 102 | end 103 | 104 | -- put the above things into one flattened parameters tensor 105 | params, grad_params = model_utils.combine_all_parameters(protos.embed, protos.rnn, protos.softmax) 106 | params:uniform(-0.08, 0.08) 107 | print('number of parameters in the model: ' .. params:nElement()) 108 | -- make a bunch of clones after flattening, as that reallocates memory 109 | clones = {} 110 | for name,proto in pairs(protos) do 111 | print('cloning ' .. name) 112 | clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters) 113 | end 114 | 115 | -- evaluate the loss over an entire split 116 | function eval_split(split_index, max_batches) 117 | print('evaluating loss over split index ' .. split_index) 118 | local n = loader.split_sizes[split_index] 119 | if max_batches ~= nil then n = math.min(max_batches, n) end 120 | 121 | loader:reset_batch_pointer(split_index) -- move batch iteration pointer for this split to front 122 | local loss = 0 123 | local rnn_state = {[0] = init_state} 124 | 125 | for i = 1,n do -- iterate over batches in the split 126 | -- fetch a batch 127 | local x, y = loader:next_batch(split_index) 128 | if opt.gpuid >= 0 then -- ship the input arrays to GPU 129 | -- have to convert to float because integers can't be cuda()'d 130 | x = x:float():cuda() 131 | y = y:float():cuda() 132 | end 133 | -- forward pass 134 | for t=1,opt.seq_length do 135 | local embedding = clones.embed[t]:forward(x[{{}, t}]) 136 | clones.rnn[t]:evaluate() -- for dropout proper functioning 137 | rnn_state[t] = clones.rnn[t]:forward{embedding, unpack(rnn_state[t-1])} 138 | if type(rnn_state[t]) ~= 'table' then rnn_state[t] = {rnn_state[t]} end 139 | local prediction = clones.softmax[t]:forward(rnn_state[t][state_predict_index]) 140 | loss = loss + clones.criterion[t]:forward(prediction, y[{{}, t}]) 141 | end 142 | -- carry over lstm state 143 | rnn_state[0] = rnn_state[#rnn_state] 144 | print(i .. '/' .. n .. 
'...') 145 | end 146 | 147 | loss = loss / opt.seq_length / n 148 | return loss 149 | end 150 | 151 | -- do fwd/bwd and return loss, grad_params 152 | local init_state_global = clone_list(init_state) 153 | function feval(x) 154 | if x ~= params then 155 | params:copy(x) 156 | end 157 | grad_params:zero() 158 | 159 | ------------------ get minibatch ------------------- 160 | local x, y = loader:next_batch(1) 161 | if opt.gpuid >= 0 then -- ship the input arrays to GPU 162 | -- have to convert to float because integers can't be cuda()'d 163 | x = x:float():cuda() 164 | y = y:float():cuda() 165 | end 166 | ------------------- forward pass ------------------- 167 | local embeddings = {} -- input embeddings 168 | local rnn_state = {[0] = init_state_global} 169 | local predictions = {} -- softmax outputs 170 | local loss = 0 171 | for t=1,opt.seq_length do 172 | embeddings[t] = clones.embed[t]:forward(x[{{}, t}]) 173 | clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag) 174 | rnn_state[t] = clones.rnn[t]:forward{embeddings[t], unpack(rnn_state[t-1])} 175 | -- the following line is needed because nngraph tries to be clever 176 | if type(rnn_state[t]) ~= 'table' then rnn_state[t] = {rnn_state[t]} end 177 | predictions[t] = clones.softmax[t]:forward(rnn_state[t][state_predict_index]) 178 | loss = loss + clones.criterion[t]:forward(predictions[t], y[{{}, t}]) 179 | end 180 | loss = loss / opt.seq_length 181 | ------------------ backward pass ------------------- 182 | local dembeddings = {} 183 | -- initialize gradient at time t to be zeros (there's no influence from future) 184 | local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones 185 | for t=opt.seq_length,1,-1 do 186 | -- backprop through loss, and softmax/linear 187 | local doutput_t = clones.criterion[t]:backward(predictions[t], y[{{}, t}]) 188 | drnn_state[t][state_predict_index] = clones.softmax[t]:backward(rnn_state[t][state_predict_index], doutput_t) 189 | -- backprop through LSTM timestep 190 | local drnn_statet_passin = drnn_state[t] 191 | -- we have to be careful with nngraph again 192 | if #(rnn_state[t]) == 1 then drnn_statet_passin = drnn_state[t][1] end 193 | local dlst = clones.rnn[t]:backward({embeddings[t], unpack(rnn_state[t-1])}, drnn_statet_passin) 194 | drnn_state[t-1] = {} 195 | for k,v in pairs(dlst) do 196 | if k == 1 then 197 | dembeddings[t] = v 198 | else 199 | -- note we do k-1 because first item is dembeddings, and then follow the 200 | -- derivatives of the state, starting at index 2. I know... 201 | drnn_state[t-1][k-1] = v 202 | end 203 | end 204 | -- backprop through embeddings 205 | clones.embed[t]:backward(x[{{}, t}], dembeddings[t]) 206 | end 207 | ------------------------ misc ---------------------- 208 | -- transfer final state to initial state (BPTT) 209 | init_state_global = rnn_state[#rnn_state] -- NOTE: I don't think this needs to be a clone, right? 
210 | -- clip gradient element-wise 211 | grad_params:clamp(-opt.grad_clip, opt.grad_clip) 212 | return loss, grad_params 213 | end 214 | 215 | -- start optimization here 216 | train_losses = {} 217 | val_losses = {} 218 | local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate} 219 | local iterations = opt.max_epochs * loader.ntrain 220 | local iterations_per_epoch = loader.ntrain 221 | local loss0 = nil 222 | for i = 1, iterations do 223 | local epoch = i / loader.ntrain 224 | 225 | local timer = torch.Timer() 226 | local _, loss = optim.rmsprop(feval, params, optim_state) 227 | local time = timer:time().real 228 | 229 | local train_loss = loss[1] -- the loss is inside a list, pop it 230 | train_losses[i] = train_loss 231 | 232 | -- every now and then or on last iteration 233 | if i % opt.eval_val_every == 0 or i == iterations then 234 | -- evaluate loss on validation data 235 | local val_loss = eval_split(2) -- 2 = validation 236 | val_losses[i] = val_loss 237 | 238 | local savefile = string.format('%s/lm_%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss) 239 | print('saving checkpoint to ' .. savefile) 240 | local checkpoint = {} 241 | checkpoint.protos = protos 242 | checkpoint.opt = opt 243 | checkpoint.train_losses = train_losses 244 | checkpoint.val_loss = val_loss 245 | checkpoint.val_losses = val_losses 246 | checkpoint.i = i 247 | checkpoint.epoch = epoch 248 | checkpoint.vocab = loader.vocab_mapping 249 | torch.save(savefile, checkpoint) 250 | end 251 | 252 | if i % opt.print_every == 0 then 253 | print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time)) 254 | end 255 | 256 | if i % 10 == 0 then collectgarbage() end 257 | 258 | -- handle early stopping if things are going really bad 259 | if loss0 == nil then loss0 = loss[1] end 260 | if loss[1] > loss0 * 3 then 261 | print('loss is exploding, aborting.') 262 | break -- halt 263 | end 264 | end 265 | 266 | 267 | -------------------------------------------------------------------------------- /util/CharSplitLMMinibatchLoader.lua: -------------------------------------------------------------------------------- 1 | 2 | -- Modified from https://github.com/oxford-cs-ml-2015/practical6 3 | -- the modification included support for train/val/test splits 4 | 5 | local CharSplitLMMinibatchLoader = {} 6 | CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader 7 | 8 | function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions) 9 | -- split_fractions is e.g. {0.9, 0.05, 0.05} 10 | 11 | local self = {} 12 | setmetatable(self, CharSplitLMMinibatchLoader) 13 | 14 | local input_file = path.join(data_dir, 'input.txt') 15 | local vocab_file = path.join(data_dir, 'vocab.t7') 16 | local tensor_file = path.join(data_dir, 'data.t7') 17 | 18 | -- construct a tensor with all the data 19 | if not (path.exists(vocab_file) or path.exists(tensor_file)) then 20 | print('one-time setup: preprocessing input text file ' .. input_file .. 
'...')
21 |         CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file)
22 |     end
23 | 
24 |     print('loading data files...')
25 |     local data = torch.load(tensor_file)
26 |     self.vocab_mapping = torch.load(vocab_file)
27 | 
28 |     -- cut off the end so that it divides evenly
29 |     local len = data:size(1)
30 |     if len % (batch_size * seq_length) ~= 0 then
31 |         print('cutting off end of data so that the batches/sequences divide evenly')
32 |         data = data:sub(1, batch_size * seq_length
33 |                     * math.floor(len / (batch_size * seq_length)))
34 |     end
35 | 
36 |     -- count vocab
37 |     self.vocab_size = 0
38 |     for _ in pairs(self.vocab_mapping) do
39 |         self.vocab_size = self.vocab_size + 1
40 |     end
41 | 
42 |     -- self.batches is a table of tensors
43 |     print('reshaping tensor...')
44 |     self.batch_size = batch_size
45 |     self.seq_length = seq_length
46 | 
47 |     local ydata = data:clone()
48 |     ydata:sub(1,-2):copy(data:sub(2,-1))
49 |     ydata[-1] = data[1]
50 |     self.x_batches = data:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
51 |     self.nbatches = #self.x_batches
52 |     self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
53 |     assert(#self.x_batches == #self.y_batches)
54 | 
55 |     self.ntrain = math.floor(self.nbatches * split_fractions[1])
56 |     self.nval = math.floor(self.nbatches * split_fractions[2])
57 |     self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)
58 | 
59 |     self.split_sizes = {self.ntrain, self.nval, self.ntest}
60 |     self.batch_ix = {0,0,0}
61 | 
62 |     print(string.format('data load done. Number of batches in train: %d, val: %d, test: %d', self.ntrain, self.nval, self.ntest))
63 |     collectgarbage()
64 |     return self
65 | end
66 | 
67 | function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
68 |     batch_index = batch_index or 0
69 |     self.batch_ix[split_index] = batch_index
70 | end
71 | 
72 | function CharSplitLMMinibatchLoader:next_batch(split_index)
73 |     -- split_index is integer: 1 = train, 2 = val, 3 = test
74 |     self.batch_ix[split_index] = self.batch_ix[split_index] + 1
75 |     if self.batch_ix[split_index] > self.split_sizes[split_index] then
76 |         self.batch_ix[split_index] = 1 -- cycle around to beginning
77 |     end
78 |     -- pull out the correct next batch
79 |     local ix = self.batch_ix[split_index]
80 |     if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
81 |     if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val set sizes
82 |     return self.x_batches[ix], self.y_batches[ix]
83 | end
84 | 
85 | -- *** STATIC method ***
86 | function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
87 |     local timer = torch.Timer()
88 | 
89 |     print('loading text file...')
90 |     local f = torch.DiskFile(in_textfile)
91 |     local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
92 |     f:close()
93 | 
94 |     -- create vocabulary if it doesn't exist yet
95 |     print('creating vocabulary mapping...')
96 |     -- record all characters to a set
97 |     local unordered = {}
98 |     for char in rawdata:gmatch'.' do
99 |         if not unordered[char] then unordered[char] = true end
100 |     end
101 |     -- sort into a table (i.e. keys become 1..N)
102 |     local ordered = {}
103 |     for char in pairs(unordered) do ordered[#ordered + 1] = char end
104 |     table.sort(ordered)
105 |     -- invert `ordered` to create the char->int mapping
106 |     local vocab_mapping = {}
107 |     for i, char in ipairs(ordered) do
108 |         vocab_mapping[char] = i
109 |     end
110 |     -- construct a tensor with all the data
111 |     print('putting data into tensor...')
112 |     local data = torch.ByteTensor(#rawdata) -- store it into 1D first, then rearrange
113 |     for i=1, #rawdata do
114 |         data[i] = vocab_mapping[rawdata:sub(i, i)] -- lua has no string indexing using []
115 |     end
116 | 
117 |     -- save output preprocessed files
118 |     print('saving ' .. out_vocabfile)
119 |     torch.save(out_vocabfile, vocab_mapping)
120 |     print('saving ' .. out_tensorfile)
121 |     torch.save(out_tensorfile, data)
122 | end
123 | 
124 | return CharSplitLMMinibatchLoader
125 | 

--------------------------------------------------------------------------------
/util/OneHot.lua:
--------------------------------------------------------------------------------
1 | 
2 | local OneHot, parent = torch.class('OneHot', 'nn.Module')
3 | 
4 | function OneHot:__init(outputSize)
5 |     parent.__init(self)
6 |     self.outputSize = outputSize
7 |     -- We'll construct one-hot encodings by using the index method to
8 |     -- reshuffle the rows of an identity matrix. To avoid recreating
9 |     -- it every iteration we'll cache it.
10 |     self._eye = torch.eye(outputSize)
11 | end
12 | 
13 | function OneHot:updateOutput(input)
14 |     self.output:resize(input:size(1), self.outputSize):zero()
15 |     if self._eye == nil then self._eye = torch.eye(self.outputSize) end
16 |     self._eye = self._eye:float()
17 |     local longInput = input:long()
18 |     self.output:copy(self._eye:index(1, longInput))
19 |     return self.output
20 | end

--------------------------------------------------------------------------------
/util/misc.lua:
--------------------------------------------------------------------------------
1 | 
2 | -- misc utilities
3 | 
4 | function clone_list(tensor_list, zero_too)
5 |     -- utility function. todo: move away to some utils file?
6 |     -- takes a list of tensors and returns a list of cloned tensors
7 |     local out = {}
8 |     for k,v in pairs(tensor_list) do
9 |         out[k] = v:clone()
10 |         if zero_too then out[k]:zero() end
11 |     end
12 |     return out
13 | end

--------------------------------------------------------------------------------
/util/model_utils.lua:
--------------------------------------------------------------------------------
1 | 
2 | -- adapted from https://github.com/wojciechz/learning_to_execute
3 | -- utilities for combining/flattening parameters in a model
4 | -- the code in this script is more general than it needs to be, which is
5 | -- why it is kind of large
6 | 
7 | require 'torch'
8 | local model_utils = {}
9 | function model_utils.combine_all_parameters(...)
10 | --[[ like module:getParameters, but operates on many modules ]]-- 11 | 12 | -- get parameters 13 | local networks = {...} 14 | local parameters = {} 15 | local gradParameters = {} 16 | for i = 1, #networks do 17 | local net_params, net_grads = networks[i]:parameters() 18 | 19 | if net_params then 20 | for _, p in pairs(net_params) do 21 | parameters[#parameters + 1] = p 22 | end 23 | for _, g in pairs(net_grads) do 24 | gradParameters[#gradParameters + 1] = g 25 | end 26 | end 27 | end 28 | 29 | local function storageInSet(set, storage) 30 | local storageAndOffset = set[torch.pointer(storage)] 31 | if storageAndOffset == nil then 32 | return nil 33 | end 34 | local _, offset = unpack(storageAndOffset) 35 | return offset 36 | end 37 | 38 | -- this function flattens arbitrary lists of parameters, 39 | -- even complex shared ones 40 | local function flatten(parameters) 41 | if not parameters or #parameters == 0 then 42 | return torch.Tensor() 43 | end 44 | local Tensor = parameters[1].new 45 | 46 | local storages = {} 47 | local nParameters = 0 48 | for k = 1,#parameters do 49 | local storage = parameters[k]:storage() 50 | if not storageInSet(storages, storage) then 51 | storages[torch.pointer(storage)] = {storage, nParameters} 52 | nParameters = nParameters + storage:size() 53 | end 54 | end 55 | 56 | local flatParameters = Tensor(nParameters):fill(1) 57 | local flatStorage = flatParameters:storage() 58 | 59 | for k = 1,#parameters do 60 | local storageOffset = storageInSet(storages, parameters[k]:storage()) 61 | parameters[k]:set(flatStorage, 62 | storageOffset + parameters[k]:storageOffset(), 63 | parameters[k]:size(), 64 | parameters[k]:stride()) 65 | parameters[k]:zero() 66 | end 67 | 68 | local maskParameters= flatParameters:float():clone() 69 | local cumSumOfHoles = flatParameters:float():cumsum(1) 70 | local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] 71 | local flatUsedParameters = Tensor(nUsedParameters) 72 | local flatUsedStorage = flatUsedParameters:storage() 73 | 74 | for k = 1,#parameters do 75 | local offset = cumSumOfHoles[parameters[k]:storageOffset()] 76 | parameters[k]:set(flatUsedStorage, 77 | parameters[k]:storageOffset() - offset, 78 | parameters[k]:size(), 79 | parameters[k]:stride()) 80 | end 81 | 82 | for _, storageAndOffset in pairs(storages) do 83 | local k, v = unpack(storageAndOffset) 84 | flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) 85 | end 86 | 87 | if cumSumOfHoles:sum() == 0 then 88 | flatUsedParameters:copy(flatParameters) 89 | else 90 | local counter = 0 91 | for k = 1,flatParameters:nElement() do 92 | if maskParameters[k] == 0 then 93 | counter = counter + 1 94 | flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] 95 | end 96 | end 97 | assert (counter == nUsedParameters) 98 | end 99 | return flatUsedParameters 100 | end 101 | 102 | -- flatten parameters and gradients 103 | local flatParameters = flatten(parameters) 104 | local flatGradParameters = flatten(gradParameters) 105 | 106 | -- return new flat vector that contains all discrete parameters 107 | return flatParameters, flatGradParameters 108 | end 109 | 110 | 111 | 112 | 113 | function model_utils.clone_many_times(net, T) 114 | local clones = {} 115 | 116 | local params, gradParams 117 | if net.parameters then 118 | params, gradParams = net:parameters() 119 | if params == nil then 120 | params = {} 121 | end 122 | end 123 | 124 | local paramsNoGrad 125 | if net.parametersNoGrad then 126 | paramsNoGrad = net:parametersNoGrad() 127 | end 128 | 129 | 
local mem = torch.MemoryFile("w"):binary() 130 | mem:writeObject(net) 131 | 132 | for t = 1, T do 133 | -- We need to use a new reader for each clone. 134 | -- We don't want to use the pointers to already read objects. 135 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 136 | local clone = reader:readObject() 137 | reader:close() 138 | 139 | if net.parameters then 140 | local cloneParams, cloneGradParams = clone:parameters() 141 | local cloneParamsNoGrad 142 | for i = 1, #params do 143 | cloneParams[i]:set(params[i]) 144 | cloneGradParams[i]:set(gradParams[i]) 145 | end 146 | if paramsNoGrad then 147 | cloneParamsNoGrad = clone:parametersNoGrad() 148 | for i =1,#paramsNoGrad do 149 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 150 | end 151 | end 152 | end 153 | 154 | clones[t] = clone 155 | collectgarbage() 156 | end 157 | 158 | mem:close() 159 | return clones 160 | end 161 | 162 | return model_utils 163 | --------------------------------------------------------------------------------