├── .gitignore
├── README.md
├── data
    └── download_data.sh
├── evaluate.lua
├── extract_fc7.lua
├── lstm.lua
├── models
    └── download_models.sh
├── predict.lua
├── train.lua
└── utils
    ├── DataLoader.lua
    ├── JSON.lua
    └── misc.lua


/.gitignore:
--------------------------------------------------------------------------------
data/*
models/*
checkpoints/
!data/download_data.sh
!models/download_models.sh

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# neural-vqa

[![Join the chat at https://gitter.im/abhshkdz/neural-vqa](https://badges.gitter.im/abhshkdz/neural-vqa.svg)](https://gitter.im/abhshkdz/neural-vqa?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)

This is an experimental Torch implementation of the
VIS + LSTM visual question answering model from the paper
[Exploring Models and Data for Image Question Answering][2]
by Mengye Ren, Ryan Kiros & Richard Zemel.

![Model architecture](http://i.imgur.com/UXAPlqe.png)

## Setup

Requirements:

- [Torch][10]
- [loadcaffe][9]

Download the [MSCOCO][11] train+val images and [VQA][1] data using `sh data/download_data.sh`. Extract all the downloaded zip files inside the `data` folder.

```
unzip Annotations_Train_mscoco.zip
unzip Questions_Train_mscoco.zip
unzip train2014.zip

unzip Annotations_Val_mscoco.zip
unzip Questions_Val_mscoco.zip
unzip val2014.zip
```

If you have already downloaded them, copy the `train2014` and `val2014` image folders
and the VQA JSON files to the `data` folder.

Download the [VGG-19][7] Caffe model and prototxt using `sh models/download_models.sh`.
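
After extraction, `extract_fc7.lua` locates each image from its COCO image id: the split name appears twice in the path and the id is zero-padded to 12 digits. The sketch below mirrors the `string.format` pattern used in that script, in plain Lua, so you can check that your images sit where the script expects them. `coco_image_path` is a hypothetical helper for illustration only (it is not a function in this repository), and it joins paths with `/` rather than `path.join`.

```lua
-- Hypothetical helper (not part of this repository): mirrors the
-- string.format pattern extract_fc7.lua uses to locate each COCO image.
local function coco_image_path(image_dir, split, image_id)
    -- split is 'train' or 'val'; %.12d zero-pads the id to 12 digits
    local name = string.format('%s2014/COCO_%s2014_%.12d.jpg', split, split, image_id)
    -- plain '/' join here; the script itself uses path.join
    return image_dir .. '/' .. name
end

print(coco_image_path('data', 'train', 405541))
-- data/train2014/COCO_train2014_000000405541.jpg
```

This produces the same image path that the Testing example below passes to `predict.lua`.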

### Known issues

- To avoid memory issues with LuaJIT, install Torch with Lua 5.1 (`TORCH_LUA_VERSION=LUA51 ./install.sh`).
  More instructions [here][4].
- If working with plain Lua, [luaffifb][8] may be needed for [loadcaffe][9],
  unless using pre-extracted fc7 features.

## Usage

### Extract image features

```
th extract_fc7.lua -split train
th extract_fc7.lua -split val
```

#### Options

- `batch_size`: Batch size. Default is 10.
- `split`: train/val. Default is `train`.
- `gpuid`: 0-indexed id of GPU to use. Default is -1 = CPU.
- `proto_file`: Path to the `deploy.prototxt` file for the VGG Caffe model. Default is `models/VGG_ILSVRC_19_layers_deploy.prototxt`.
- `model_file`: Path to the `.caffemodel` file for the VGG Caffe model. Default is `models/VGG_ILSVRC_19_layers.caffemodel`.
- `data_dir`: Data directory. Default is `data`.
- `feat_layer`: Layer to extract features from. Default is `relu7`.
- `input_image_dir`: Image directory. Default is `data`.


### Training

```
th train.lua
```

#### Options

- `rnn_size`: Size of LSTM internal state. Default is 512.
- `num_layers`: Number of layers in LSTM. Default is 2.
- `embedding_size`: Size of word embeddings. Default is 512.
- `learning_rate`: Learning rate. Default is 4e-4.
- `learning_rate_decay`: Learning rate decay factor. Default is 0.95.
- `learning_rate_decay_after`: In number of epochs, when to start decaying the learning rate. Default is 15.
- `alpha`: Alpha for Adam. Default is 0.8.
- `beta`: Beta for Adam. Default is 0.999.
- `epsilon`: Denominator term for smoothing. Default is 1e-8.
- `batch_size`: Batch size. Default is 200.
- `max_epochs`: Number of full passes through the training data. Default is 50.
- `dropout`: Dropout for regularization (probability of dropping an input). Default is 0.5.
- `init_from`: Initialize network parameters from the checkpoint at this path.
- `save_every`: Number of iterations after which to checkpoint. Default is 1000.
- `train_fc7_file`: Path to fc7 features of training set. Default is `data/train_fc7.t7`.
- `train_fc7_image_id_file`: Path to fc7 image ids of training set. Default is `data/train_fc7_image_id.t7`.
- `val_fc7_file`: Path to fc7 features of validation set. Default is `data/val_fc7.t7`.
- `val_fc7_image_id_file`: Path to fc7 image ids of validation set. Default is `data/val_fc7_image_id.t7`.
- `data_dir`: Data directory. Default is `data`.
- `checkpoint_dir`: Checkpoint directory. Default is `checkpoints`.
- `savefile`: Filename to save checkpoint to. Default is `vqa`.
- `gpuid`: 0-indexed id of GPU to use. Default is -1 = CPU.

### Testing

```
th predict.lua -checkpoint_file checkpoints/vqa_epoch23.26_0.4610.t7 -input_image_path data/train2014/COCO_train2014_000000405541.jpg -question 'What is the cat on?'
```

#### Options

- `checkpoint_file`: Path to the model checkpoint to initialize network parameters from.
- `input_image_path`: Path to the input image.
- `question`: Question string.

## Sample predictions

Randomly sampled image-question pairs from the VQA test set,
and answers predicted by the VIS+LSTM model.

![](http://i.imgur.com/V3nHbo9.jpg)

Q: What animals are those?
A: Sheep

![](http://i.imgur.com/QRBi6qb.jpg)

Q: What color is the frisbee that's upside down?
A: Red

![](http://i.imgur.com/tiOqJfH.jpg)

Q: What is flying in the sky?
A: Kite

![](http://i.imgur.com/4ZmOoUF.jpg)

Q: What color is court?
A: Blue

![](http://i.imgur.com/1D6NxvD.jpg)

Q: What is in the standing person's hands?
A: Bat

![](http://i.imgur.com/tY9BT1I.jpg)

Q: Are they riding horses both the same color?
A: No

![](http://i.imgur.com/hzwj0NS.jpg)

Q: What shape is the plate?
A: Round

![](http://i.imgur.com/n1Kn1vZ.jpg)

Q: Is the man wearing socks?
A: Yes

![](http://i.imgur.com/dXhNKP6.jpg)

Q: What is over the woman's left shoulder?
A: Fork

![](http://i.imgur.com/thzv03r.jpg)

Q: Where are the pink flowers?
A: On wall

## Implementation Details

- Last hidden layer image features from [VGG-19][6]
- Zero-padded question sequences for batched implementation
- Training questions are filtered for `top_n` answers,
  `top_n = 1000` by default (~87% coverage)

## Pretrained model and data files

To reproduce results shown on this page or try your own
image-question pairs, download the following and run
`predict.lua` with the appropriate paths.

- vqa\_epoch23.26\_0.4610.t7 (Serialized using Lua51) [[GPU](https://drive.google.com/file/d/0B8qwt8PA_oxpSWhRQ1NKYkxhYnc/view?usp=sharing)] [[CPU](https://drive.google.com/file/d/0B8qwt8PA_oxpbGJQY0EyZ2phYTg/view?usp=sharing)]
- [answers_vocab.t7](https://drive.google.com/file/d/0B8qwt8PA_oxpNE1RdWlMLWlNcVk/view?usp=sharing)
- [questions_vocab.t7](https://drive.google.com/file/d/0B8qwt8PA_oxpd2Y4MXIzb0pxSWM/view?usp=sharing)
- [data.t7](https://drive.google.com/file/d/0B8qwt8PA_oxpejVuTFVsZTJDSUU/view?usp=sharing)

## References

- [Exploring Models and Data for Image Question Answering][2], Ren et al., NIPS15
- [VQA: Visual Question Answering][3], Antol et al., ICCV15

## License

[MIT][12]

[1]: http://visualqa.org/
[2]: http://arxiv.org/abs/1505.02074
[3]: http://arxiv.org/abs/1505.00468
[4]: https://github.com/torch/distro
[5]: http://nlp.stanford.edu/projects/glove/
[6]: http://arxiv.org/abs/1409.1556
[7]: https://gist.github.com/ksimonyan/3785162f95cd2d5fee77#file-readme-md
[8]: https://github.com/facebook/luaffifb
[9]: https://github.com/szagoruyko/loadcaffe
[10]: http://torch.ch/
[11]: http://mscoco.org/
[12]: https://abhshkdz.mit-license.org/

--------------------------------------------------------------------------------
/data/download_data.sh:
--------------------------------------------------------------------------------
#!/bin/sh

cd data

wget -c http://visualqa.org/data/mscoco/vqa/Annotations_Train_mscoco.zip
wget -c http://visualqa.org/data/mscoco/vqa/Questions_Train_mscoco.zip
wget -c http://msvocds.blob.core.windows.net/coco2014/train2014.zip

wget -c http://visualqa.org/data/mscoco/vqa/Annotations_Val_mscoco.zip
wget -c http://visualqa.org/data/mscoco/vqa/Questions_Val_mscoco.zip
wget -c http://msvocds.blob.core.windows.net/coco2014/val2014.zip

cd ..

--------------------------------------------------------------------------------
/evaluate.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'nngraph'

local utils = require 'utils.misc'
local DataLoader = require 'utils.DataLoader'

local LSTM = require 'lstm'

cmd = torch.CmdLine()
cmd:text('Options')

-- model params
cmd:option('-rnn_size', 512, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'Number of layers in LSTM')
cmd:option('-embedding_size', 512, 'size of word embeddings')
-- optimization
cmd:option('-batch_size', 64, 'batch size')
-- bookkeeping
cmd:option('-checkpoint_file', 'checkpoints/vqa_epoch23.26_0.4610.t7', 'Checkpoint file to use for predictions')
cmd:option('-data_dir', 'data', 'data directory')
cmd:option('-seed', 981723, 'Torch manual random number generator seed')
cmd:option('-train_fc7_file', 'data/train_fc7.t7', 'Path to fc7 features of training set')
cmd:option('-train_fc7_image_id_file', 'data/train_fc7_image_id.t7', 'Path to fc7 image ids of training set')
cmd:option('-val_fc7_file', 'data/val_fc7.t7', 'Path to fc7 features of validation set')
cmd:option('-val_fc7_image_id_file', 'data/val_fc7_image_id.t7', 'Path to fc7 image ids of validation set')
-- gpu/cpu
cmd:option('-gpuid', -1, '0-indexed id of GPU to use. -1 = CPU')

opt = cmd:parse(arg or {})
torch.manualSeed(opt.seed)

if opt.gpuid >= 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        cutorch.setDevice(opt.gpuid + 1) -- torch is 1-indexed
        cutorch.manualSeed(opt.seed)
    else
        print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.')
        print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.')
        print('Falling back to CPU mode')
        opt.gpuid = -1
    end
end

loader = DataLoader.create(opt.data_dir, opt.batch_size, opt)

print('loading checkpoint from ' .. opt.checkpoint_file)
checkpoint = torch.load(opt.checkpoint_file)

lstm_clones = {}
lstm_clones = utils.clone_many_times(checkpoint.protos.lstm, loader.q_max_length + 1)

checkpoint.protos.ltw:evaluate()
checkpoint.protos.lti:evaluate()

q_vocab_size = checkpoint.vocab_size

a_iv = {}
for i,v in pairs(loader.a_vocab_mapping) do
    a_iv[v] = i
end

q_iv = {}
for i,v in pairs(loader.q_vocab_mapping) do
    q_iv[v] = i
end

if q_vocab_size ~= loader.q_vocab_size then
    print('Vocab size of checkpoint and data are different.')
end

init_state = {}
for L = 1, opt.num_layers do
    local h_init = torch.zeros(opt.batch_size, opt.rnn_size)
    if opt.gpuid >=0 then h_init = h_init:cuda() end
    table.insert(init_state, h_init:clone())
    table.insert(init_state, h_init:clone())
end

local init_state_global = utils.clone_list(init_state)

count = 0
for i = 1, loader.batch_data.val.nbatches do
    q_batch, a_batch, i_batch = loader:next_batch('val')

    -- 1st index of `nn.LookupTable` is for zeros
    q_batch = q_batch + 1

    qf = checkpoint.protos.ltw:forward(q_batch)

    imf = checkpoint.protos.lti:forward(i_batch)

    if opt.gpuid >= 0 then
        imf = imf:cuda()
    end

    rnn_state = {[0] = init_state_global}

    for t = 1, loader.q_max_length do
        lst = lstm_clones[t]:forward{qf:select(2,t), unpack(rnn_state[t-1])}
        rnn_state[t] = {}
        for i = 1, #init_state do table.insert(rnn_state[t], lst[i]) end
    end

    lst = lstm_clones[loader.q_max_length + 1]:forward{imf, unpack(rnn_state[loader.q_max_length])}

    prediction = checkpoint.protos.sm:forward(lst[#lst])

    _, idx = prediction:max(2)
    for j = 1, opt.batch_size do
        if idx[j][1] == a_batch[j] then
            count = count + 1
        end
    end

    print(count .. '/' .. i * opt.batch_size)
end

print(count / (loader.batch_data.val.nbatches * opt.batch_size))

--------------------------------------------------------------------------------
/extract_fc7.lua:
--------------------------------------------------------------------------------

require 'torch'
require 'nn'
require 'image'

local utils = require 'utils.misc'
local DataLoader = require 'utils.DataLoader'

require 'loadcaffe'

cmd = torch.CmdLine()
cmd:text('Options')

cmd:option('-batch_size', 10, 'batch size')
cmd:option('-split', 'train', 'train/val')
cmd:option('-debug', 0, 'set debug = 1 for lots of prints')
-- bookkeeping
cmd:option('-seed', 981723, 'Torch manual random number generator seed')
cmd:option('-proto_file', 'models/VGG_ILSVRC_19_layers_deploy.prototxt')
cmd:option('-model_file', 'models/VGG_ILSVRC_19_layers.caffemodel')
cmd:option('-data_dir', 'data', 'Data directory.')
cmd:option('-feat_layer', 'relu7', 'Layer to extract features from')
cmd:option('-input_image_dir', 'data', 'Image directory')
-- gpu/cpu
cmd:option('-gpuid', -1, '0-indexed id of GPU to use. -1 = CPU')

opt = cmd:parse(arg or {})
torch.manualSeed(opt.seed)

if opt.gpuid >= 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        cutorch.setDevice(opt.gpuid + 1)
        cutorch.manualSeed(opt.seed)
    else
        print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.')
        print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.')
        print('Falling back to CPU mode')
        opt.gpuid = -1
    end
end

loader = DataLoader.create(opt.data_dir, opt.batch_size, opt, 'fc7_feat')

cnn = loadcaffe.load(opt.proto_file, opt.model_file)
if opt.gpuid >= 0 then
    cnn = cnn:cuda()
end

cnn_fc7 = nn.Sequential()

for i = 1, #cnn.modules do
    local layer = cnn:get(i)
    local name = layer.name
    cnn_fc7:add(layer)
    if name == opt.feat_layer then
        break
    end
end

cnn_fc7:evaluate()

if opt.gpuid >= 0 then
    cnn_fc7 = cnn_fc7:cuda()
end

tmp_image_id = {}
for i = 1, #loader.data[opt.split] do
    tmp_image_id[loader.data[opt.split][i].image_id] = 1
end

image_id = {}
idx = 1
for i, v in pairs(tmp_image_id) do
    image_id[idx] = i
    idx = idx + 1
end

fc7 = torch.DoubleTensor(#image_id, 4096)
idx = 1

if opt.gpuid >= 0 then
    fc7 = fc7:cuda()
end

repeat
    local timer = torch.Timer()
    img_batch = torch.zeros(opt.batch_size, 3, 224, 224)
    img_id_batch = {}
    for i = 1, opt.batch_size do
        if not image_id[idx] then
            break
        end
        local fp = path.join(opt.input_image_dir, string.format('%s2014/COCO_%s2014_%.12d.jpg', opt.split, opt.split, image_id[idx]))
        if opt.debug == 1 then
            print(idx)
            print(fp)
        end
        img_batch[i] = utils.preprocess(image.scale(image.load(fp, 3), 224, 224))
        img_id_batch[i] = image_id[idx]
        idx = idx + 1
    end

    if opt.gpuid >= 0 then
        img_batch = img_batch:cuda()
    end

    fc7_batch = cnn_fc7:forward(img_batch:narrow(1, 1, #img_id_batch))

    for i = 1, fc7_batch:size(1) do
        if opt.debug == 1 then
            print(idx - fc7_batch:size(1) + i - 1)
        end
        fc7[idx - fc7_batch:size(1) + i - 1]:copy(fc7_batch[i])
    end

    if opt.gpuid >= 0 then
        cutorch.synchronize()
    end

    local time = timer:time().real
    print(idx-1 .. '/' .. #image_id .. " " .. time)
    collectgarbage()
until idx > #image_id -- idx points one past the last processed image, so stop only after the final one

torch.save(path.join(opt.data_dir, opt.split .. '_fc7.t7'), fc7)
torch.save(path.join(opt.data_dir, opt.split .. '_fc7_image_id.t7'), image_id)

--------------------------------------------------------------------------------
/lstm.lua:
--------------------------------------------------------------------------------
local LSTM = {}
function LSTM.create(input_size, rnn_size, num_layers)

    local inputs = {}

    table.insert(inputs, nn.Identity()()) -- x

    for L = 1, num_layers do
        table.insert(inputs, nn.Identity()()) -- c
        table.insert(inputs, nn.Identity()()) -- h
    end

    local x, input_size_L
    local outputs = {}

    for L = 1, num_layers do
        local prev_c = inputs[2*L]
        local prev_h = inputs[2*L + 1]

        if L == 1 then
            x = inputs[1]
            input_size_L = input_size
        else
            x = outputs[2*(L-1)]
            input_size_L = rnn_size
        end

        local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
        local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
        local preactivations = nn.CAddTable()({i2h, h2h})

        -- gates
        local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
        local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)

        -- input
        local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
        local in_transform = nn.Tanh()(in_chunk)

        local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
        local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
        local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)

        local next_c = nn.CAddTable()({
            nn.CMulTable()({forget_gate, prev_c}),
            nn.CMulTable()({in_gate, in_transform})
        })

        local c_transform = nn.Tanh()(next_c)
        local next_h = nn.CMulTable()({out_gate, c_transform})

        table.insert(outputs, next_c)
        table.insert(outputs, next_h)
    end

    return nn.gModule(inputs, outputs)
end

return LSTM
--------------------------------------------------------------------------------
/models/download_models.sh:
--------------------------------------------------------------------------------
#!/bin/sh

cd models

wget -c https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/bb2b4fe0a9bb0669211cf3d0bc949dfdda173e9e/VGG_ILSVRC_19_layers_deploy.prototxt
wget -c http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel

cd ..

--------------------------------------------------------------------------------
/predict.lua:
--------------------------------------------------------------------------------

require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'image'

local utils = require 'utils.misc'
local DataLoader = require 'utils.DataLoader'

require 'loadcaffe'
local LSTM = require 'lstm'

cmd = torch.CmdLine()
cmd:text('Options')

-- model params
cmd:option('-rnn_size', 512, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'Number of layers in LSTM')
cmd:option('-embedding_size', 512, 'size of word embeddings')
-- optimization
cmd:option('-batch_size', 64, 'batch size')
-- bookkeeping
cmd:option('-proto_file', 'models/VGG_ILSVRC_19_layers_deploy.prototxt')
cmd:option('-model_file', 'models/VGG_ILSVRC_19_layers.caffemodel')
cmd:option('-checkpoint_file', 'checkpoints/vqa_epoch23.26_0.4610.t7', 'Checkpoint file to use for predictions')
cmd:option('-data_dir', 'data', 'data directory')
cmd:option('-seed', 981723, 'Torch manual random number generator seed')
cmd:option('-feat_layer', 'fc7', 'Layer to extract features from')
cmd:option('-train_fc7_file', 'data/train_fc7.t7', 'Path to fc7 features of training set')
cmd:option('-train_fc7_image_id_file', 'data/train_fc7_image_id.t7', 'Path to fc7 image ids of training set')
cmd:option('-val_fc7_file', 'data/val_fc7.t7', 'Path to fc7 features of validation set')
cmd:option('-val_fc7_image_id_file', 'data/val_fc7_image_id.t7', 'Path to fc7 image ids of validation set')
cmd:option('-input_image_path', 'data/train2014/COCO_train2014_000000405541.jpg', 'Image path')
cmd:option('-question', 'What is the cat on?', 'Question string')
-- gpu/cpu
cmd:option('-gpuid', -1, '0-indexed id of GPU to use. -1 = CPU')

opt = cmd:parse(arg or {})
torch.manualSeed(opt.seed)

if opt.gpuid >= 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        cutorch.setDevice(opt.gpuid + 1) -- torch is 1-indexed
        cutorch.manualSeed(opt.seed)
    else
        print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.')
        print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.')
        print('Falling back to CPU mode')
        opt.gpuid = -1
    end
end

loader = DataLoader.create(opt.data_dir, opt.batch_size, opt, 'predict')

-- load model checkpoint

print('loading checkpoint from ' .. opt.checkpoint_file)
checkpoint = torch.load(opt.checkpoint_file)

lstm_clones = {}
lstm_clones = utils.clone_many_times(checkpoint.protos.lstm, loader.q_max_length + 1)

checkpoint.protos.ltw:evaluate()
checkpoint.protos.lti:evaluate()

q_vocab_size = checkpoint.vocab_size

a_iv = {}
for i,v in pairs(loader.a_vocab_mapping) do
    a_iv[v] = i
end

q_iv = {}
for i,v in pairs(loader.q_vocab_mapping) do
    q_iv[v] = i
end

if q_vocab_size ~= loader.q_vocab_size then
    print('Vocab size of checkpoint and data are different.')
end

cnn = loadcaffe.load(opt.proto_file, opt.model_file)

function predict(input_image_path, question_string)

    -- extract image features

    if opt.gpuid >= 0 then
        cnn = cnn:cuda()
    end

    local cnn_fc7 = nn.Sequential()

    for i = 1, #cnn.modules do
        local layer = cnn:get(i)
        local name = layer.name
        cnn_fc7:add(layer)
        if name == opt.feat_layer then
            break
        end
    end

    if opt.gpuid >= 0 then
        cnn_fc7 = cnn_fc7:cuda()
    end

    local img = utils.preprocess(image.scale(image.load(input_image_path, 3), 224, 224))

    if opt.gpuid >= 0 then
        img = img:cuda()
    end

    local fc7 = cnn_fc7:forward(img)
    local imf = checkpoint.protos.lti:forward(fc7)

    -- extract question features

    local question = torch.ShortTensor(loader.q_max_length):zero()

    local idx = 1
    local words = {}
    for token in string.gmatch(question_string, "%a+") do
        words[idx] = token
        idx = idx + 1
    end

    for i = 1, #words do
        question[loader.q_max_length - #words + i] = loader.q_vocab_mapping[words[i]] or loader.q_vocab_mapping['UNK']
    end

    if opt.gpuid >= 0 then
        question = question:cuda()
    end

    -- 1st index of `nn.LookupTable` is for zeros
    question = question + 1

    local qf = checkpoint.protos.ltw:forward(question)

    -- lstm + softmax

    local init_state = {}
    for L = 1, opt.num_layers do
        local h_init = torch.zeros(1, opt.rnn_size)
        if opt.gpuid >=0 then h_init = h_init:cuda() end
        table.insert(init_state, h_init:clone())
        table.insert(init_state, h_init:clone())
    end

    local rnn_state = {[0] = init_state}

    for t = 1, loader.q_max_length do
        local lst = lstm_clones[t]:forward{qf:select(1,t):view(1,-1), unpack(rnn_state[t-1])}
        rnn_state[t] = {}
        for i = 1, #init_state do table.insert(rnn_state[t], lst[i]) end
    end

    local lst = lstm_clones[loader.q_max_length + 1]:forward{imf:view(1,-1), unpack(rnn_state[loader.q_max_length])}

    local prediction = checkpoint.protos.sm:forward(lst[#lst])

    local _, idx = prediction:max(2)

    print(a_iv[idx[1][1]])
end

predict(opt.input_image_path, opt.question)

--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
--[[
Torch implementation of the VIS + LSTM model from the paper
'Exploring Models and Data for Image Question Answering'
by Mengye Ren, Ryan Kiros & Richard Zemel.

This implementation passes the question embeddings
first and then the image embeddings into the LSTM,
and does a softmax over the answer vocabulary.
]]--

require 'torch'
require 'nn'
require 'nngraph'
require 'optim'

local utils = require 'utils.misc'
local DataLoader = require 'utils.DataLoader'

local LSTM = require 'lstm'

cmd = torch.CmdLine()
cmd:text('Options')

-- model params
cmd:option('-rnn_size', 512, 'Size of LSTM internal state')
cmd:option('-num_layers', 2, 'Number of layers in LSTM')
cmd:option('-embedding_size', 512, 'Size of word embeddings')
-- optimization
cmd:option('-learning_rate', 4e-4, 'Learning rate')
cmd:option('-learning_rate_decay', 0.95, 'Learning rate decay')
cmd:option('-learning_rate_decay_after', 15, 'In number of epochs, when to start decaying the learning rate')
cmd:option('-alpha', 0.8, 'alpha for adam')
cmd:option('-beta', 0.999, 'beta used for adam')
cmd:option('-epsilon', 1e-8, 'epsilon that goes into denominator for smoothing')
cmd:option('-batch_size', 200, 'Batch size')
cmd:option('-max_epochs', 50, 'Number of full passes through the training data')
cmd:option('-dropout', 0.5, 'Dropout')
cmd:option('-init_from', '', 'Initialize network parameters from checkpoint at this path')
-- bookkeeping
cmd:option('-seed', 981723, 'Torch manual random number generator seed')
cmd:option('-save_every', 1000, 'Number of iterations after which to checkpoint')
cmd:option('-train_fc7_file', 'data/train_fc7.t7', 'Path to fc7 features of training set')
cmd:option('-train_fc7_image_id_file', 'data/train_fc7_image_id.t7', 'Path to fc7 image ids of training set')
cmd:option('-val_fc7_file', 'data/val_fc7.t7', 'Path to fc7 features of validation set')
cmd:option('-val_fc7_image_id_file', 'data/val_fc7_image_id.t7', 'Path to fc7 image ids of validation set')
cmd:option('-data_dir', 'data', 'Data directory')
cmd:option('-checkpoint_dir', 'checkpoints', 'Checkpoint directory')
cmd:option('-savefile', 'vqa', 'Filename to save checkpoint to')
-- gpu/cpu
cmd:option('-gpuid', -1, '0-indexed id of GPU to use. -1 = CPU')

-- parse command-line parameters
opt = cmd:parse(arg or {})
print(opt)
torch.manualSeed(opt.seed)

-- gpu setup
if opt.gpuid >= 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        cutorch.setDevice(opt.gpuid + 1) -- torch is 1-indexed
        cutorch.manualSeed(opt.seed)
    else
        print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.')
        print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.')
        print('Falling back to CPU mode')
        opt.gpuid = -1
    end
end

-- initialize the data loader
-- checks whether the .t7 data files exist;
-- if they don't, or if they're stale,
-- they're created from scratch and loaded
local loader = DataLoader.create(opt.data_dir, opt.batch_size, opt)

-- create the directory for saving snapshots of the model at different times during training
if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end

local do_random_init = true
if string.len(opt.init_from) > 0 then
    -- initializing model from checkpoint
    print('Loading model from checkpoint ' .. opt.init_from)
    local checkpoint = torch.load(opt.init_from)
    protos = checkpoint.protos
    do_random_init = false
else
    -- model definition
    -- components: ltw, lti, lstm and sm
    protos = {}

    -- ltw: lookup table + dropout for question words
    -- each word of the question gets mapped to its index in the vocabulary
    -- and then is passed through ltw to get a vector of size `embedding_size`
    -- lookup table dimensions are (`vocab_size` + 1) x `embedding_size`;
    -- the extra first row is reserved for the zero-padding index
    protos.ltw = nn.Sequential()
    protos.ltw:add(nn.LookupTable(loader.q_vocab_size+1, opt.embedding_size))
    protos.ltw:add(nn.Dropout(opt.dropout))

    -- lti: fully connected layer + tanh + dropout for image features
    -- activations from the last fully connected layer of the deep convnet (VGG in this case)
    -- are passed through lti to get a vector of size `embedding_size`
    -- linear layer dimensions are 4096 (size of fc7 layer) x `embedding_size`
    protos.lti = nn.Sequential()
protos.lti:add(nn.Linear(4096, opt.embedding_size)) 110 | protos.lti:add(nn.Tanh()) 111 | protos.lti:add(nn.Dropout(opt.dropout)) 112 | 113 | -- lstm: long short-term memory cell which takes a vector of size `embedding_size` at every time step 114 | -- hidden state h_t of LSTM cell in first layer is passed as input x_t of cell in second layer and so on. 115 | protos.lstm = LSTM.create(opt.embedding_size, opt.rnn_size, opt.num_layers) 116 | 117 | -- sm: linear layer + softmax over the answer vocabulary 118 | -- linear layer dimensions are `rnn_size` x `answer_vocab_size` 119 | protos.sm = nn.Sequential() 120 | protos.sm:add(nn.Linear(opt.rnn_size, loader.a_vocab_size)) 121 | protos.sm:add(nn.LogSoftMax()) 122 | 123 | -- negative log-likelihood loss 124 | protos.criterion = nn.ClassNLLCriterion() 125 | 126 | -- pass over the model to gpu 127 | if opt.gpuid >= 0 then 128 | protos.ltw = protos.ltw:cuda() 129 | protos.lti = protos.lti:cuda() 130 | protos.lstm = protos.lstm:cuda() 131 | protos.sm = protos.sm:cuda() 132 | protos.criterion = protos.criterion:cuda() 133 | end 134 | end 135 | 136 | -- put all trainable model parameters into one flattened parameters tensor 137 | params, grad_params = utils.combine_all_parameters(protos.lti, protos.lstm, protos.sm) 138 | 139 | print('Parameters: ' .. params:size(1)) 140 | print('Batches: ' .. 
loader.batch_data.train.nbatches) 141 | 142 | -- initialize model parameters 143 | if do_random_init then 144 | params:uniform(-0.08, 0.08) 145 | end 146 | 147 | -- make clones of the LSTM model that share parameters across time steps (unrolling) 148 | lstm_clones = {} 149 | lstm_clones = utils.clone_many_times(protos.lstm, loader.q_max_length + 1) 150 | 151 | -- initialize h_0 and c_0 of LSTM to zero tensors and store in `init_state` 152 | init_state = {} 153 | for L = 1, opt.num_layers do 154 | local h_init = torch.zeros(opt.batch_size, opt.rnn_size) 155 | if opt.gpuid >=0 then h_init = h_init:cuda() end 156 | table.insert(init_state, h_init:clone()) 157 | table.insert(init_state, h_init:clone()) 158 | end 159 | 160 | -- make a clone of `init_state` as it's going to be modified later 161 | local init_state_global = utils.clone_list(init_state) 162 | 163 | -- closure to calculate accuracy over validation set 164 | feval_val = function(max_batches) 165 | 166 | local count = 0 167 | local n = loader.batch_data.val.nbatches 168 | 169 | -- set `n` to `max_batches` if provided 170 | if max_batches ~= nil then n = math.min(n, max_batches) end 171 | 172 | -- switch to evaluation mode so that dropout is disabled during validation 173 | protos.ltw:evaluate() 174 | protos.lti:evaluate() 175 | 176 | for i = 1, n do 177 | 178 | -- load question batch, answer batch and image features batch 179 | q_batch, a_batch, i_batch = loader:next_batch('val') 180 | 181 | -- 1st index of `nn.LookupTable` is reserved for zeros 182 | q_batch = q_batch + 1 183 | 184 | -- forward the question features through ltw 185 | qf = protos.ltw:forward(q_batch) 186 | 187 | -- forward the image features through lti 188 | imf = protos.lti:forward(i_batch) 189 | 190 | -- convert to CudaTensor if using gpu 191 | if opt.gpuid >= 0 then 192 | imf = imf:cuda() 193 | end 194 | 195 | -- set the state at 0th time step of LSTM 196 | rnn_state = {[0] = init_state_global} 197 | 198 | -- LSTM forward pass for question features 199 | for t = 
1, loader.q_max_length do 200 | lst = lstm_clones[t]:forward{qf:select(2,t), unpack(rnn_state[t-1])} 201 | -- at every time step, set the rnn state (h_t, c_t) to be passed as input in next time step 202 | rnn_state[t] = {} 203 | for i = 1, #init_state do table.insert(rnn_state[t], lst[i]) end 204 | end 205 | 206 | -- after completing the unrolled LSTM forward pass with question features, forward the image features 207 | lst = lstm_clones[loader.q_max_length + 1]:forward{imf, unpack(rnn_state[loader.q_max_length])} 208 | 209 | -- forward the hidden state at the last time step to get softmax over answers 210 | prediction = protos.sm:forward(lst[#lst]) 211 | 212 | -- count number of correct answers 213 | _, idx = prediction:max(2) 214 | for j = 1, opt.batch_size do 215 | if idx[j][1] == a_batch[j] then 216 | count = count + 1 217 | end 218 | end 219 | 220 | end 221 | 222 | -- set to training mode once done with validation 223 | protos.ltw:training() 224 | protos.lti:training() 225 | 226 | -- return accuracy 227 | return count / (n * opt.batch_size) 228 | 229 | end 230 | 231 | -- closure to run a forward and backward pass and return loss and gradient parameters 232 | feval = function(x) 233 | 234 | -- get latest parameters 235 | if x ~= params then 236 | params:copy(x) 237 | end 238 | grad_params:zero() 239 | 240 | -- load question batch, answer batch and image features batch 241 | q_batch, a_batch, i_batch = loader:next_batch() 242 | 243 | -- slightly hackish; 1st index of `nn.LookupTable` is reserved for zeros 244 | q_batch = q_batch + 1 245 | 246 | -- forward the question features through ltw 247 | qf = protos.ltw:forward(q_batch) 248 | 249 | -- forward the image features through lti 250 | imf = protos.lti:forward(i_batch) 251 | 252 | -- convert to CudaTensor if using gpu 253 | if opt.gpuid >= 0 then 254 | imf = imf:cuda() 255 | end 256 | 257 | ------------ forward pass ------------ 258 | 259 | -- set initial loss 260 | loss = 0 261 | 262 | -- set the state at 0th 
time step of LSTM 263 | rnn_state = {[0] = init_state_global} 264 | 265 | -- LSTM forward pass for question features 266 | for t = 1, loader.q_max_length do 267 | lst = lstm_clones[t]:forward{qf:select(2,t), unpack(rnn_state[t-1])} 268 | -- at every time step, set the rnn state (h_t, c_t) to be passed as input to the next time step 269 | rnn_state[t] = {} 270 | for i = 1, #init_state do table.insert(rnn_state[t], lst[i]) end 271 | end 272 | 273 | -- after completing the unrolled LSTM forward pass with question features, forward the image features 274 | lst = lstm_clones[loader.q_max_length + 1]:forward{imf, unpack(rnn_state[loader.q_max_length])} 275 | 276 | -- forward the hidden state at the last time step to get softmax over answers 277 | prediction = protos.sm:forward(lst[#lst]) 278 | 279 | -- calculate loss 280 | loss = protos.criterion:forward(prediction, a_batch) 281 | 282 | ------------ backward pass ------------ 283 | 284 | -- backprop through loss and softmax 285 | dloss = protos.criterion:backward(prediction, a_batch) 286 | doutput_t = protos.sm:backward(lst[#lst], dloss) 287 | 288 | -- gradients w.r.t. the LSTM state at the last time step: zeros, except for the top layer's hidden state (last LSTM output), which receives the gradient from sm 289 | drnn_state = {[loader.q_max_length + 1] = utils.clone_list(init_state, true)} 290 | drnn_state[loader.q_max_length + 1][opt.num_layers * 2] = doutput_t 291 | 292 | -- backprop for last time step (image features) 293 | dlst = lstm_clones[loader.q_max_length + 1]:backward({imf, unpack(rnn_state[loader.q_max_length])}, drnn_state[loader.q_max_length + 1]) 294 | 295 | -- backprop into the image embedding layer (lti) 296 | protos.lti:backward(i_batch, dlst[1]) 297 | 298 | -- dlst[1] is the gradient w.r.t. the input x_t; the rest are gradients w.r.t. the previous time step's LSTM state 299 | drnn_state[loader.q_max_length] = {} 300 | for i,v in pairs(dlst) do 301 | if i > 1 then 302 | drnn_state[loader.q_max_length][i-1] = v 303 | end 304 | end 305 | -- buffer for gradients w.r.t. the question word embeddings, filled in one time step at a time 306 | dqf = torch.Tensor(qf:size()):zero() 307 | if opt.gpuid >= 0 then 308 | dqf = dqf:cuda() 309 | end 310 | 311 | -- backprop into the LSTM for the remaining time steps 312 | for 
t = loader.q_max_length, 1, -1 do 313 | dlst = lstm_clones[t]:backward({qf:select(2, t), unpack(rnn_state[t-1])}, drnn_state[t]) 314 | dqf:select(2, t):copy(dlst[1]) 315 | drnn_state[t-1] = {} 316 | for i,v in pairs(dlst) do 317 | if i > 1 then 318 | drnn_state[t-1][i-1] = v 319 | end 320 | end 321 | end 322 | 323 | -- ltw is not part of `params`: zero its gradient buffers, backprop into it and update it directly with plain SGD 324 | protos.ltw:zeroGradParameters() 325 | protos.ltw:backward(q_batch, dqf) 326 | protos.ltw:updateParameters(opt.learning_rate) 327 | 328 | -- clip gradients element-wise to [-5, 5] 329 | grad_params:clamp(-5, 5) 330 | 331 | return loss, grad_params 332 | 333 | end 334 | 335 | -- optim state for Adam; optim.adam reads its decay rates from `beta1` and `beta2` (keys named `alpha`/`beta` are ignored) 336 | local optim_state = {learningRate = opt.learning_rate, beta1 = opt.alpha, beta2 = opt.beta, epsilon = opt.epsilon} 337 | 338 | -- train / val loop! 339 | losses = {} 340 | iterations = opt.max_epochs * loader.batch_data.train.nbatches 341 | print('Max iterations: ' .. iterations) 342 | lloss = 0 343 | for i = 1, iterations do 344 | _, local_loss = optim.adam(feval, params, optim_state) 345 | 346 | losses[#losses + 1] = local_loss[1] 347 | 348 | lloss = lloss + local_loss[1] 349 | local epoch = i / loader.batch_data.train.nbatches 350 | -- report the average loss over the last 10 iterations 351 | if i%10 == 0 then 352 | print('epoch ' .. epoch .. ' loss ' .. lloss / 10) 353 | lloss = 0 354 | end 355 | 356 | -- decay the learning rate at the end of each epoch, starting at epoch `learning_rate_decay_after` 357 | if i % loader.batch_data.train.nbatches == 0 and opt.learning_rate_decay < 1 then 358 | if epoch >= opt.learning_rate_decay_after then 359 | local decay_factor = opt.learning_rate_decay 360 | optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it 361 | print('decayed learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate) 362 | end 363 | end 364 | 365 | -- Calculate validation accuracy and save model snapshot 366 | if i % opt.save_every == 0 or i == iterations then 367 | print('Checkpointing. 
Calculating validation accuracy..') 368 | local val_acc = feval_val() 369 | local savefile = string.format('%s/%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_acc) 370 | print('Saving checkpoint to ' .. savefile) 371 | local checkpoint = {} 372 | checkpoint.opt = opt 373 | checkpoint.protos = protos 374 | checkpoint.vocab_size = loader.q_vocab_size 375 | torch.save(savefile, checkpoint) 376 | end 377 | 378 | if i%10 == 0 then 379 | collectgarbage() 380 | end 381 | end 382 | -------------------------------------------------------------------------------- /utils/DataLoader.lua: -------------------------------------------------------------------------------- 1 | -- Messy but works. 2 | 3 | local DataLoader = {} 4 | DataLoader.__index = DataLoader 5 | 6 | function DataLoader.create(data_dir, batch_size, opt, mode) 7 | 8 | local self = {} 9 | setmetatable(self, DataLoader) 10 | 11 | self.mode = mode or 'train' 12 | 13 | local train_questions_file = path.join(data_dir, 'MultipleChoice_mscoco_train2014_questions.json') 14 | local train_annotations_file = path.join(data_dir, 'mscoco_train2014_annotations.json') 15 | 16 | local val_questions_file = path.join(data_dir, 'MultipleChoice_mscoco_val2014_questions.json') 17 | local val_annotations_file = path.join(data_dir, 'mscoco_val2014_annotations.json') 18 | 19 | local questions_vocab_file = path.join(data_dir, 'questions_vocab.t7') 20 | local answers_vocab_file = path.join(data_dir, 'answers_vocab.t7') 21 | 22 | local tensor_file = path.join(data_dir, 'data.t7') 23 | 24 | -- fetch file attributes to determine if we need to rerun preprocessing 25 | 26 | local run_prepro = false 27 | if not (path.exists(questions_vocab_file) and path.exists(answers_vocab_file) and path.exists(tensor_file)) then 28 | print('questions_vocab.t7, answers_vocab.t7 or data.t7 files do not exist. 
Running preprocessing...') 29 | run_prepro = true 30 | else 31 | local train_questions_attr = lfs.attributes(train_questions_file) 32 | local questions_vocab_attr = lfs.attributes(questions_vocab_file) 33 | local tensor_attr = lfs.attributes(tensor_file) 34 | 35 | if train_questions_attr.modification > questions_vocab_attr.modification or train_questions_attr.modification > tensor_attr.modification then 36 | print('questions_vocab.t7 or data.t7 detected as stale. Re-running preprocessing...') 37 | run_prepro = true 38 | end 39 | end 40 | if run_prepro then 41 | -- construct a tensor with all the data, and vocab file 42 | print('one-time setup: preprocessing...') 43 | DataLoader.json_to_tensor(train_questions_file, train_annotations_file, val_questions_file, val_annotations_file, questions_vocab_file, answers_vocab_file, tensor_file) 44 | end 45 | 46 | print('Loading data files...') 47 | local data = torch.load(tensor_file) 48 | if mode == 'fc7_feat' then 49 | self.data = data 50 | collectgarbage() 51 | return self 52 | end 53 | 54 | self.q_max_length = data.q_max_length 55 | self.q_vocab_mapping = torch.load(questions_vocab_file) 56 | self.a_vocab_mapping = torch.load(answers_vocab_file) 57 | 58 | self.q_vocab_size = 0 59 | for _ in pairs(self.q_vocab_mapping) do 60 | self.q_vocab_size = self.q_vocab_size + 1 61 | end 62 | 63 | self.a_vocab_size = 0 64 | for _ in pairs(self.a_vocab_mapping) do 65 | self.a_vocab_size = self.a_vocab_size + 1 66 | end 67 | 68 | self.batch_size = batch_size 69 | 70 | if mode == 'predict' then 71 | collectgarbage() 72 | return self 73 | end 74 | 75 | self.train_nbatches = 0 76 | self.val_nbatches = 0 77 | 78 | -- Load train into batches 79 | 80 | print('Loading train fc7 features from ' .. 
opt.train_fc7_file) 81 | local fc7 = torch.load(opt.train_fc7_file) 82 | local fc7_image_id = torch.load(opt.train_fc7_image_id_file) 83 | local fc7_mapping = {} 84 | for i, v in pairs(fc7_image_id) do 85 | fc7_mapping[v] = i 86 | end 87 | 88 | self.batch_data = {['train'] = {}, ['val'] = {}} 89 | 90 | self.batch_data.train = { 91 | ['question'] = torch.ShortTensor(self.batch_size * math.floor(#data.train / self.batch_size), data.q_max_length), 92 | ['answer'] = torch.ShortTensor(self.batch_size * math.floor(#data.train / self.batch_size)), 93 | ['image_feat'] = torch.DoubleTensor(self.batch_size * math.floor(#data.train / self.batch_size), 4096), 94 | ['image_id'] = {}, 95 | ['nbatches'] = math.floor(#data.train / self.batch_size) 96 | } 97 | 98 | if opt.gpuid >= 0 then 99 | self.batch_data.train.image_feat = self.batch_data.train.image_feat:cuda() 100 | end 101 | 102 | for i = 1, self.batch_size * self.batch_data.train.nbatches do 103 | self.batch_data.train.question[i] = data.train[i]['question'] 104 | self.batch_data.train.answer[i] = data.train[i]['answer'] 105 | self.batch_data.train.image_feat[i] = fc7[fc7_mapping[data.train[i]['image_id']]] 106 | self.batch_data.train.image_id[i] = data.train[i]['image_id'] 107 | end 108 | 109 | if opt.gpuid >= 0 then 110 | self.batch_data.train.question = self.batch_data.train.question:cuda() 111 | self.batch_data.train.answer = self.batch_data.train.answer:cuda() 112 | end 113 | 114 | -- Load val into batches 115 | 116 | print('Loading val fc7 features from ' .. 
opt.val_fc7_file) 117 | local fc7 = torch.load(opt.val_fc7_file) 118 | local fc7_image_id = torch.load(opt.val_fc7_image_id_file) 119 | local fc7_mapping = {} 120 | for i, v in pairs(fc7_image_id) do 121 | fc7_mapping[v] = i 122 | end 123 | 124 | self.batch_data.val = { 125 | ['question'] = torch.ShortTensor(self.batch_size * math.floor(#data.val / self.batch_size), data.q_max_length), 126 | ['answer'] = torch.ShortTensor(self.batch_size * math.floor(#data.val / self.batch_size)), 127 | ['image_feat'] = torch.DoubleTensor(self.batch_size * math.floor(#data.val / self.batch_size), 4096), 128 | ['image_id'] = {}, 129 | ['nbatches'] = math.floor(#data.val / self.batch_size) 130 | } 131 | 132 | if opt.gpuid >= 0 then 133 | self.batch_data.val.image_feat = self.batch_data.val.image_feat:cuda() 134 | end 135 | 136 | for i = 1, self.batch_size * self.batch_data.val.nbatches do 137 | self.batch_data.val.question[i] = data.val[i]['question'] 138 | self.batch_data.val.answer[i] = data.val[i]['answer'] 139 | self.batch_data.val.image_feat[i] = fc7[fc7_mapping[data.val[i]['image_id']]] 140 | self.batch_data.val.image_id[i] = data.val[i]['image_id'] 141 | end 142 | 143 | if opt.gpuid >= 0 then 144 | self.batch_data.val.question = self.batch_data.val.question:cuda() 145 | self.batch_data.val.answer = self.batch_data.val.answer:cuda() 146 | end 147 | 148 | self.train_batch_idx = 1 149 | self.val_batch_idx = 1 150 | 151 | collectgarbage() 152 | return self 153 | 154 | end 155 | 156 | function DataLoader:next_batch(split) 157 | split = split or 'train' 158 | if split == 'train' then 159 | if self.train_batch_idx - 1 == self.batch_data.train.nbatches then self.train_batch_idx = 1 end 160 | local question = self.batch_data.train.question:narrow(1, (self.train_batch_idx - 1) * self.batch_size + 1, self.batch_size) 161 | local answer = self.batch_data.train.answer:narrow(1, (self.train_batch_idx - 1) * self.batch_size + 1, self.batch_size) 162 | local image = 
self.batch_data.train.image_feat:narrow(1, (self.train_batch_idx - 1) * self.batch_size + 1, self.batch_size) 163 | local image_id = {unpack(self.batch_data.train.image_id, (self.train_batch_idx - 1) * self.batch_size + 1, self.train_batch_idx * self.batch_size)} 164 | 165 | self.train_batch_idx = self.train_batch_idx + 1 166 | return question, answer, image, image_id 167 | else 168 | if self.val_batch_idx - 1 == self.batch_data.val.nbatches then self.val_batch_idx = 1 end 169 | local question = self.batch_data.val.question:narrow(1, (self.val_batch_idx - 1) * self.batch_size + 1, self.batch_size) 170 | local answer = self.batch_data.val.answer:narrow(1, (self.val_batch_idx - 1) * self.batch_size + 1, self.batch_size) 171 | local image = self.batch_data.val.image_feat:narrow(1, (self.val_batch_idx - 1) * self.batch_size + 1, self.batch_size) 172 | local image_id = {unpack(self.batch_data.val.image_id, (self.val_batch_idx - 1) * self.batch_size + 1, self.val_batch_idx * self.batch_size)} 173 | 174 | self.val_batch_idx = self.val_batch_idx + 1 175 | return question, answer, image, image_id 176 | end 177 | end 178 | 179 | function DataLoader.json_to_tensor(in_train_q, in_train_a, in_val_q, in_val_a, out_vocab_q, out_vocab_a, out_tensor) 180 | 181 | local JSON = (loadfile "utils/JSON.lua")() 182 | 183 | print('creating vocabulary mapping...') 184 | 185 | -- build answer vocab using train+val 186 | 187 | local f = torch.DiskFile(in_train_a) 188 | c = f:readString('*a') 189 | local train_a = JSON:decode(c) 190 | f:close() 191 | 192 | f = torch.DiskFile(in_val_a) 193 | c = f:readString('*a') 194 | local val_a = JSON:decode(c) 195 | f:close() 196 | 197 | local unordered = {} 198 | 199 | for i = 1, #train_a['annotations'] do 200 | local token = train_a['annotations'][i]['multiple_choice_answer'] 201 | if not unordered[token] then 202 | unordered[token] = 1 203 | else 204 | unordered[token] = unordered[token] + 1 205 | end 206 | end 207 | 208 | for i = 1, 
#val_a['annotations'] do 209 | local token = val_a['annotations'][i]['multiple_choice_answer'] 210 | if not unordered[token] then 211 | unordered[token] = 1 212 | else 213 | unordered[token] = unordered[token] + 1 214 | end 215 | end 216 | 217 | local sorted_a = get_keys_sorted_by_value(unordered, function(a, b) return a > b end) 218 | 219 | local top_n = 1000 220 | local ordered = {} 221 | for i = 1, top_n do 222 | ordered[#ordered + 1] = sorted_a[i] 223 | end 224 | ordered[#ordered + 1] = "UNK" 225 | table.sort(ordered) 226 | 227 | local a_vocab_mapping = {} 228 | for i, word in ipairs(ordered) do 229 | a_vocab_mapping[word] = i 230 | end 231 | 232 | -- build question vocab using train+val 233 | 234 | f = torch.DiskFile(in_train_q) 235 | c = f:readString('*a') 236 | local train_q = JSON:decode(c) 237 | f:close() 238 | 239 | f = torch.DiskFile(in_val_q) 240 | c = f:readString('*a') 241 | local val_q = JSON:decode(c) 242 | f:close() 243 | 244 | unordered = {} 245 | max_length = 0 246 | 247 | for i = 1, #train_q['questions'] do 248 | local count = 0 249 | if a_vocab_mapping[train_a['annotations'][i]['multiple_choice_answer']] then 250 | for token in word_iter(train_q['questions'][i]['question']) do 251 | if not unordered[token] then 252 | unordered[token] = 1 253 | else 254 | unordered[token] = unordered[token] + 1 255 | end 256 | count = count + 1 257 | end 258 | if count > max_length then max_length = count end 259 | end 260 | end 261 | 262 | for i = 1, #val_q['questions'] do 263 | local count = 0 264 | for token in word_iter(val_q['questions'][i]['question']) do 265 | if not unordered[token] then 266 | unordered[token] = 1 267 | else 268 | unordered[token] = unordered[token] + 1 269 | end 270 | count = count + 1 271 | end 272 | if count > max_length then max_length = count end 273 | end 274 | 275 | local threshold = 0 276 | local ordered = {} 277 | for token, count in pairs(unordered) do 278 | if count > threshold then 279 | ordered[#ordered + 1] = token 280 | 
end 281 | end 282 | ordered[#ordered + 1] = "UNK" 283 | table.sort(ordered) 284 | 285 | local q_vocab_mapping = {} 286 | for i, word in ipairs(ordered) do 287 | q_vocab_mapping[word] = i 288 | end 289 | 290 | print('putting data into tensor...') 291 | 292 | -- save train+val data 293 | 294 | local data = { 295 | train = {}, 296 | val = {}, 297 | q_max_length = max_length 298 | } 299 | 300 | print('q max length: ' .. max_length) 301 | 302 | local idx = 1 303 | 304 | for i = 1, #train_q['questions'] do 305 | if a_vocab_mapping[train_a['annotations'][i]['multiple_choice_answer']] then 306 | local question = {} 307 | local wl = 0 308 | for token in word_iter(train_q['questions'][i]['question']) do 309 | wl = wl + 1 310 | question[wl] = q_vocab_mapping[token] or q_vocab_mapping["UNK"] 311 | end 312 | data.train[idx] = { 313 | image_id = train_a['annotations'][i]['image_id'], 314 | question = torch.ShortTensor(max_length):zero(), 315 | answer = a_vocab_mapping[train_a['annotations'][i]['multiple_choice_answer']] or a_vocab_mapping["UNK"] 316 | } 317 | for j = 1, wl do 318 | data.train[idx]['question'][max_length - wl + j] = question[j] 319 | end 320 | idx = idx + 1 321 | end 322 | end 323 | 324 | idx = 1 325 | 326 | for i = 1, #val_q['questions'] do 327 | local question = {} 328 | local wl = 0 329 | for token in word_iter(val_q['questions'][i]['question']) do 330 | wl = wl + 1 331 | question[wl] = q_vocab_mapping[token] or q_vocab_mapping["UNK"] 332 | end 333 | data.val[idx] = { 334 | image_id = val_a['annotations'][i]['image_id'], 335 | question = torch.ShortTensor(max_length):zero(), 336 | answer = a_vocab_mapping[val_a['annotations'][i]['multiple_choice_answer']] or a_vocab_mapping["UNK"] 337 | } 338 | for j = 1, wl do 339 | data.val[idx]['question'][max_length - wl + j] = question[j] 340 | end 341 | idx = idx + 1 342 | end 343 | 344 | -- save output preprocessed files 345 | print('saving ' .. 
out_vocab_q) 346 | torch.save(out_vocab_q, q_vocab_mapping) 347 | print('saving ' .. out_vocab_a) 348 | torch.save(out_vocab_a, a_vocab_mapping) 349 | print('saving ' .. out_tensor) 350 | torch.save(out_tensor, data) 351 | 352 | end 353 | 354 | function word_iter(str) 355 | return string.gmatch(str, "%a+") 356 | end 357 | 358 | function get_keys_sorted_by_value(tbl, sort_fn) 359 | local keys = {} 360 | for key in pairs(tbl) do 361 | table.insert(keys, key) 362 | end 363 | 364 | table.sort(keys, function(a, b) 365 | return sort_fn(tbl[a], tbl[b]) 366 | end) 367 | 368 | return keys 369 | end 370 | 371 | return DataLoader 372 | -------------------------------------------------------------------------------- /utils/JSON.lua: -------------------------------------------------------------------------------- 1 | -- -*- coding: utf-8 -*- 2 | -- 3 | -- Simple JSON encoding and decoding in pure Lua. 4 | -- 5 | -- Copyright 2010-2014 Jeffrey Friedl 6 | -- http://regex.info/blog/ 7 | -- 8 | -- Latest version: http://regex.info/blog/lua/json 9 | -- 10 | -- This code is released under a Creative Commons CC-BY "Attribution" License: 11 | -- http://creativecommons.org/licenses/by/3.0/deed.en_US 12 | -- 13 | -- It can be used for any purpose so long as the copyright notice above, 14 | -- the web-page links above, and the 'AUTHOR_NOTE' string below are 15 | -- maintained. Enjoy. 16 | -- 17 | local VERSION = 20141223.14 -- version history at end of file 18 | local AUTHOR_NOTE = "-[ JSON.lua package by Jeffrey Friedl (http://regex.info/blog/lua/json) version 20141223.14 ]-" 19 | 20 | -- 21 | -- The 'AUTHOR_NOTE' variable exists so that information about the source 22 | -- of the package is maintained even in compiled versions. It's also 23 | -- included in OBJDEF below mostly to quiet warnings about unused variables. 24 | -- 25 | local OBJDEF = { 26 | VERSION = VERSION, 27 | AUTHOR_NOTE = AUTHOR_NOTE, 28 | } 29 | 30 | 31 | -- 32 | -- Simple JSON encoding and decoding in pure Lua. 
33 | -- http://www.json.org/ 34 | -- 35 | -- 36 | -- JSON = assert(loadfile "JSON.lua")() -- one-time load of the routines 37 | -- 38 | -- local lua_value = JSON:decode(raw_json_text) 39 | -- 40 | -- local raw_json_text = JSON:encode(lua_table_or_value) 41 | -- local pretty_json_text = JSON:encode_pretty(lua_table_or_value) -- "pretty printed" version for human readability 42 | -- 43 | -- 44 | -- 45 | -- DECODING (from a JSON string to a Lua table) 46 | -- 47 | -- 48 | -- JSON = assert(loadfile "JSON.lua")() -- one-time load of the routines 49 | -- 50 | -- local lua_value = JSON:decode(raw_json_text) 51 | -- 52 | -- If the JSON text is for an object or an array, e.g. 53 | -- { "what": "books", "count": 3 } 54 | -- or 55 | -- [ "Larry", "Curly", "Moe" ] 56 | -- 57 | -- the result is a Lua table, e.g. 58 | -- { what = "books", count = 3 } 59 | -- or 60 | -- { "Larry", "Curly", "Moe" } 61 | -- 62 | -- 63 | -- The encode and decode routines accept an optional second argument, 64 | -- "etc", which is not used during encoding or decoding, but upon error 65 | -- is passed along to error handlers. It can be of any type (including nil). 66 | -- 67 | -- 68 | -- 69 | -- ERROR HANDLING 70 | -- 71 | -- With most errors during decoding, this code calls 72 | -- 73 | -- JSON:onDecodeError(message, text, location, etc) 74 | -- 75 | -- with a message about the error, and if known, the JSON text being 76 | -- parsed and the byte count where the problem was discovered. You can 77 | -- replace the default JSON:onDecodeError() with your own function. 78 | -- 79 | -- The default onDecodeError() merely augments the message with data 80 | -- about the text and the location if known (and if a second 'etc' 81 | -- argument had been provided to decode(), its value is tacked onto the 82 | -- message as well), and then calls JSON.assert(), which itself defaults 83 | -- to Lua's built-in assert(), and can also be overridden. 
84 | -- 85 | -- For example, in an Adobe Lightroom plugin, you might use something like 86 | -- 87 | -- function JSON:onDecodeError(message, text, location, etc) 88 | -- LrErrors.throwUserError("Internal Error: invalid JSON data") 89 | -- end 90 | -- 91 | -- or even just 92 | -- 93 | -- function JSON.assert(condition, message) 94 | -- LrErrors.throwUserError("Internal Error: " .. message) 95 | -- end 96 | -- 97 | -- If JSON:decode() is passed a nil, this is called instead: 98 | -- 99 | -- JSON:onDecodeOfNilError(message, nil, nil, etc) 100 | -- 101 | -- and if JSON:decode() is passed HTML instead of JSON, this is called: 102 | -- 103 | -- JSON:onDecodeOfHTMLError(message, text, nil, etc) 104 | -- 105 | -- The use of the fourth 'etc' argument allows stronger coordination 106 | -- between decoding and error reporting, especially when you provide your 107 | -- own error-handling routines. Continuing with the Adobe Lightroom 108 | -- plugin example: 109 | -- 110 | -- function JSON:onDecodeError(message, text, location, etc) 111 | -- local note = "Internal Error: invalid JSON data" 112 | -- if type(etc) == 'table' and etc.photo then 113 | -- note = note .. " while processing for " .. etc.photo:getFormattedMetadata('fileName') 114 | -- end 115 | -- LrErrors.throwUserError(note) 116 | -- end 117 | -- 118 | -- : 119 | -- : 120 | -- 121 | -- for i, photo in ipairs(photosToProcess) do 122 | -- : 123 | -- : 124 | -- local data = JSON:decode(someJsonText, { photo = photo }) 125 | -- : 126 | -- : 127 | -- end 128 | -- 129 | -- 130 | -- 131 | -- 132 | -- 133 | -- DECODING AND STRICT TYPES 134 | -- 135 | -- Because both JSON objects and JSON arrays are converted to Lua tables, 136 | -- it's not normally possible to tell which original JSON type a 137 | -- particular Lua table was derived from, or guarantee decode-encode 138 | -- round-trip equivalency. 139 | -- 140 | -- However, if you enable strictTypes, e.g. 
141 | -- 142 | -- JSON = assert(loadfile "JSON.lua")() --load the routines 143 | -- JSON.strictTypes = true 144 | -- 145 | -- then the Lua table resulting from the decoding of a JSON object or 146 | -- JSON array is marked via Lua metatable, so that when re-encoded with 147 | -- JSON:encode() it ends up as the appropriate JSON type. 148 | -- 149 | -- (This is not the default because other routines may not work well with 150 | -- tables that have a metatable set, for example, Lightroom API calls.) 151 | -- 152 | -- 153 | -- ENCODING (from a lua table to a JSON string) 154 | -- 155 | -- JSON = assert(loadfile "JSON.lua")() -- one-time load of the routines 156 | -- 157 | -- local raw_json_text = JSON:encode(lua_table_or_value) 158 | -- local pretty_json_text = JSON:encode_pretty(lua_table_or_value) -- "pretty printed" version for human readability 159 | -- local custom_pretty = JSON:encode(lua_table_or_value, etc, { pretty = true, indent = "| ", align_keys = false }) 160 | -- 161 | -- On error during encoding, this code calls: 162 | -- 163 | -- JSON:onEncodeError(message, etc) 164 | -- 165 | -- which you can override in your local JSON object. 166 | -- 167 | -- The 'etc' in the error call is the second argument to encode() 168 | -- and encode_pretty(), or nil if it wasn't provided. 
169 | -- 170 | -- 171 | -- PRETTY-PRINTING 172 | -- 173 | -- An optional third argument, a table of options, allows a bit of 174 | -- configuration about how the encoding takes place: 175 | -- 176 | -- pretty = JSON:encode(val, etc, { 177 | -- pretty = true, -- if false, no other options matter 178 | -- indent = " ", -- this provides for a three-space indent per nesting level 179 | -- align_keys = false, -- see below 180 | -- }) 181 | -- 182 | -- encode() and encode_pretty() are identical except that encode_pretty() 183 | -- provides a default options table if none given in the call: 184 | -- 185 | -- { pretty = true, align_keys = false, indent = " " } 186 | -- 187 | -- For example, if 188 | -- 189 | -- JSON:encode(data) 190 | -- 191 | -- produces: 192 | -- 193 | -- {"city":"Kyoto","climate":{"avg_temp":16,"humidity":"high","snowfall":"minimal"},"country":"Japan","wards":11} 194 | -- 195 | -- then 196 | -- 197 | -- JSON:encode_pretty(data) 198 | -- 199 | -- produces: 200 | -- 201 | -- { 202 | -- "city": "Kyoto", 203 | -- "climate": { 204 | -- "avg_temp": 16, 205 | -- "humidity": "high", 206 | -- "snowfall": "minimal" 207 | -- }, 208 | -- "country": "Japan", 209 | -- "wards": 11 210 | -- } 211 | -- 212 | -- The following three lines return identical results: 213 | -- JSON:encode_pretty(data) 214 | -- JSON:encode_pretty(data, nil, { pretty = true, align_keys = false, indent = " " }) 215 | -- JSON:encode (data, nil, { pretty = true, align_keys = false, indent = " " }) 216 | -- 217 | -- An example of setting your own indent string: 218 | -- 219 | -- JSON:encode_pretty(data, nil, { pretty = true, indent = "| " }) 220 | -- 221 | -- produces: 222 | -- 223 | -- { 224 | -- | "city": "Kyoto", 225 | -- | "climate": { 226 | -- | | "avg_temp": 16, 227 | -- | | "humidity": "high", 228 | -- | | "snowfall": "minimal" 229 | -- | }, 230 | -- | "country": "Japan", 231 | -- | "wards": 11 232 | -- } 233 | -- 234 | -- An example of setting align_keys to true: 235 | -- 236 | -- 
JSON:encode_pretty(data, nil, { pretty = true, indent = " ", align_keys = true }) 237 | -- 238 | -- produces: 239 | -- 240 | -- { 241 | -- "city": "Kyoto", 242 | -- "climate": { 243 | -- "avg_temp": 16, 244 | -- "humidity": "high", 245 | -- "snowfall": "minimal" 246 | -- }, 247 | -- "country": "Japan", 248 | -- "wards": 11 249 | -- } 250 | -- 251 | -- which I must admit is kinda ugly, sorry. This was the default for 252 | -- encode_pretty() prior to version 20141223.14. 253 | -- 254 | -- 255 | -- AMBIGUOUS SITUATIONS DURING THE ENCODING 256 | -- 257 | -- During the encode, if a Lua table being encoded contains both string 258 | -- and numeric keys, it fits neither JSON's idea of an object, nor its 259 | -- idea of an array. To get around this, when any string key exists (or 260 | -- when non-positive numeric keys exist), numeric keys are converted to 261 | -- strings. 262 | -- 263 | -- For example, 264 | -- JSON:encode({ "one", "two", "three", SOMESTRING = "some string" }) 265 | -- produces the JSON object 266 | -- {"1":"one","2":"two","3":"three","SOMESTRING":"some string"} 267 | -- 268 | -- To prohibit this conversion and instead make it an error condition, set 269 | -- JSON.noKeyConversion = true 270 | -- 271 | 272 | 273 | 274 | 275 | -- 276 | -- SUMMARY OF METHODS YOU CAN OVERRIDE IN YOUR LOCAL LUA JSON OBJECT 277 | -- 278 | -- assert 279 | -- onDecodeError 280 | -- onDecodeOfNilError 281 | -- onDecodeOfHTMLError 282 | -- onEncodeError 283 | -- 284 | -- If you want to create a separate Lua JSON object with its own error handlers, 285 | -- you can reload JSON.lua or use the :new() method.
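The key-conversion rule for ambiguous tables can be illustrated with a small, self-contained sketch in plain Lua. `json_kind` and `converted_keys` are hypothetical helpers written only for this illustration; they are not functions exported by JSON.lua, and they simplify the real encoder (they assume numeric keys are integers and ignore sparse arrays).

```lua
-- Hypothetical helpers illustrating the key-conversion rule; not part of JSON.lua.
-- Simplification: numeric keys are assumed to be integers.

-- Returns "array" if every key is a positive number, otherwise "object".
local function json_kind(tbl)
  for key in pairs(tbl) do
    if type(key) ~= "number" or key <= 0 then
      return "object" -- a string (or non-positive numeric) key forces an object
    end
  end
  return "array"
end

-- Returns the table's keys converted to strings and sorted, i.e. the member
-- names an object encoding would use once numeric keys become strings.
local function converted_keys(tbl)
  local keys = {}
  for key in pairs(tbl) do
    keys[#keys + 1] = tostring(key)
  end
  table.sort(keys)
  return keys
end

local mixed = { "one", "two", "three", SOMESTRING = "some string" }
print(json_kind({ "one", "two", "three" }))     -- array
print(json_kind(mixed))                         -- object
print(table.concat(converted_keys(mixed), ",")) -- 1,2,3,SOMESTRING
```

Setting `JSON.noKeyConversion = true`, as described above, turns the "object" case for mixed tables into an error instead of a silent conversion.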
286 | -- 287 | --------------------------------------------------------------------------- 288 | 289 | local default_pretty_indent = " " 290 | local default_pretty_options = { pretty = true, align_keys = false, indent = default_pretty_indent } 291 | 292 | local isArray = { __tostring = function() return "JSON array" end } isArray.__index = isArray 293 | local isObject = { __tostring = function() return "JSON object" end } isObject.__index = isObject 294 | 295 | 296 | function OBJDEF:newArray(tbl) 297 | return setmetatable(tbl or {}, isArray) 298 | end 299 | 300 | function OBJDEF:newObject(tbl) 301 | return setmetatable(tbl or {}, isObject) 302 | end 303 | 304 | local function unicode_codepoint_as_utf8(codepoint) 305 | -- 306 | -- codepoint is a number 307 | -- 308 | if codepoint <= 127 then 309 | return string.char(codepoint) 310 | 311 | elseif codepoint <= 2047 then 312 | -- 313 | -- 110yyyxx 10xxxxxx <-- useful notation from http://en.wikipedia.org/wiki/Utf8 314 | -- 315 | local highpart = math.floor(codepoint / 0x40) 316 | local lowpart = codepoint - (0x40 * highpart) 317 | return string.char(0xC0 + highpart, 318 | 0x80 + lowpart) 319 | 320 | elseif codepoint <= 65535 then 321 | -- 322 | -- 1110yyyy 10yyyyxx 10xxxxxx 323 | -- 324 | local highpart = math.floor(codepoint / 0x1000) 325 | local remainder = codepoint - 0x1000 * highpart 326 | local midpart = math.floor(remainder / 0x40) 327 | local lowpart = remainder - 0x40 * midpart 328 | 329 | highpart = 0xE0 + highpart 330 | midpart = 0x80 + midpart 331 | lowpart = 0x80 + lowpart 332 | 333 | -- 334 | -- Check for an invalid character (thanks Andy R. at Adobe). 335 | -- See table 3.7, page 93, in http://www.unicode.org/versions/Unicode5.2.0/ch03.pdf#G28070 336 | -- 337 | if ( highpart == 0xE0 and midpart < 0xA0 ) or 338 | ( highpart == 0xED and midpart > 0x9F ) or 339 | ( highpart == 0xF0 and midpart < 0x90 ) or 340 | ( highpart == 0xF4 and midpart > 0x8F ) 341 | then 342 | return "?" 
343 | else 344 | return string.char(highpart, 345 | midpart, 346 | lowpart) 347 | end 348 | 349 | else 350 | -- 351 | -- 11110zzz 10zzyyyy 10yyyyxx 10xxxxxx 352 | -- 353 | local highpart = math.floor(codepoint / 0x40000) 354 | local remainder = codepoint - 0x40000 * highpart 355 | local midA = math.floor(remainder / 0x1000) 356 | remainder = remainder - 0x1000 * midA 357 | local midB = math.floor(remainder / 0x40) 358 | local lowpart = remainder - 0x40 * midB 359 | 360 | return string.char(0xF0 + highpart, 361 | 0x80 + midA, 362 | 0x80 + midB, 363 | 0x80 + lowpart) 364 | end 365 | end 366 | 367 | function OBJDEF:onDecodeError(message, text, location, etc) 368 | if text then 369 | if location then 370 | message = string.format("%s at char %d of: %s", message, location, text) 371 | else 372 | message = string.format("%s: %s", message, text) 373 | end 374 | end 375 | 376 | if etc ~= nil then 377 | message = message .. " (" .. OBJDEF:encode(etc) .. ")" 378 | end 379 | 380 | if self.assert then 381 | self.assert(false, message) 382 | else 383 | assert(false, message) 384 | end 385 | end 386 | 387 | OBJDEF.onDecodeOfNilError = OBJDEF.onDecodeError 388 | OBJDEF.onDecodeOfHTMLError = OBJDEF.onDecodeError 389 | 390 | function OBJDEF:onEncodeError(message, etc) 391 | if etc ~= nil then 392 | message = message .. " (" .. OBJDEF:encode(etc) .. 
")" 393 | end 394 | 395 | if self.assert then 396 | self.assert(false, message) 397 | else 398 | assert(false, message) 399 | end 400 | end 401 | 402 | local function grok_number(self, text, start, etc) 403 | -- 404 | -- Grab the integer part 405 | -- 406 | local integer_part = text:match('^-?[1-9]%d*', start) 407 | or text:match("^-?0", start) 408 | 409 | if not integer_part then 410 | self:onDecodeError("expected number", text, start, etc) 411 | end 412 | 413 | local i = start + integer_part:len() 414 | 415 | -- 416 | -- Grab an optional decimal part 417 | -- 418 | local decimal_part = text:match('^%.%d+', i) or "" 419 | 420 | i = i + decimal_part:len() 421 | 422 | -- 423 | -- Grab an optional exponential part 424 | -- 425 | local exponent_part = text:match('^[eE][-+]?%d+', i) or "" 426 | 427 | i = i + exponent_part:len() 428 | 429 | local full_number_text = integer_part .. decimal_part .. exponent_part 430 | local as_number = tonumber(full_number_text) 431 | 432 | if not as_number then 433 | self:onDecodeError("bad number", text, start, etc) 434 | end 435 | 436 | return as_number, i 437 | end 438 | 439 | 440 | local function grok_string(self, text, start, etc) 441 | 442 | if text:sub(start,start) ~= '"' then 443 | self:onDecodeError("expected string's opening quote", text, start, etc) 444 | end 445 | 446 | local i = start + 1 -- +1 to bypass the initial quote 447 | local text_len = text:len() 448 | local VALUE = "" 449 | while i <= text_len do 450 | local c = text:sub(i,i) 451 | if c == '"' then 452 | return VALUE, i + 1 453 | end 454 | if c ~= '\\' then 455 | VALUE = VALUE .. c 456 | i = i + 1 457 | elseif text:match('^\\b', i) then 458 | VALUE = VALUE .. "\b" 459 | i = i + 2 460 | elseif text:match('^\\f', i) then 461 | VALUE = VALUE .. "\f" 462 | i = i + 2 463 | elseif text:match('^\\n', i) then 464 | VALUE = VALUE .. "\n" 465 | i = i + 2 466 | elseif text:match('^\\r', i) then 467 | VALUE = VALUE .. 
"\r" 468 | i = i + 2 469 | elseif text:match('^\\t', i) then 470 | VALUE = VALUE .. "\t" 471 | i = i + 2 472 | else 473 | local hex = text:match('^\\u([0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF])', i) 474 | if hex then 475 | i = i + 6 -- bypass what we just read 476 | 477 | -- We have a Unicode codepoint. It could be standalone, or if in the proper range and 478 | -- followed by another in a specific range, it'll be a two-code surrogate pair. 479 | local codepoint = tonumber(hex, 16) 480 | if codepoint >= 0xD800 and codepoint <= 0xDBFF then 481 | -- it's a hi surrogate... see whether we have a following low 482 | local lo_surrogate = text:match('^\\u([dD][cdefCDEF][0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF])', i) 483 | if lo_surrogate then 484 | i = i + 6 -- bypass the low surrogate we just read 485 | codepoint = 0x2400 + (codepoint - 0xD800) * 0x400 + tonumber(lo_surrogate, 16) 486 | else 487 | -- not a proper low, so we'll just leave the first codepoint as is and spit it out. 488 | end 489 | end 490 | VALUE = VALUE .. unicode_codepoint_as_utf8(codepoint) 491 | 492 | else 493 | 494 | -- just pass through what's escaped 495 | VALUE = VALUE .. 
text:match('^\\(.)', i) 496 | i = i + 2 497 | end 498 | end 499 | end 500 | 501 | self:onDecodeError("unclosed string", text, start, etc) 502 | end 503 | 504 | local function skip_whitespace(text, start) 505 | 506 | local _, match_end = text:find("^[ \n\r\t]+", start) -- [http://www.ietf.org/rfc/rfc4627.txt] Section 2 507 | if match_end then 508 | return match_end + 1 509 | else 510 | return start 511 | end 512 | end 513 | 514 | local grok_one -- assigned later 515 | 516 | local function grok_object(self, text, start, etc) 517 | if text:sub(start,start) ~= '{' then 518 | self:onDecodeError("expected '{'", text, start, etc) 519 | end 520 | 521 | local i = skip_whitespace(text, start + 1) -- +1 to skip the '{' 522 | 523 | local VALUE = self.strictTypes and self:newObject { } or { } 524 | 525 | if text:sub(i,i) == '}' then 526 | return VALUE, i + 1 527 | end 528 | local text_len = text:len() 529 | while i <= text_len do 530 | local key, new_i = grok_string(self, text, i, etc) 531 | 532 | i = skip_whitespace(text, new_i) 533 | 534 | if text:sub(i, i) ~= ':' then 535 | self:onDecodeError("expected colon", text, i, etc) 536 | end 537 | 538 | i = skip_whitespace(text, i + 1) 539 | 540 | local new_val, new_i = grok_one(self, text, i, etc) 541 | 542 | VALUE[key] = new_val 543 | 544 | -- 545 | -- Expect now either '}' to end things, or a ',' to allow us to continue.
546 | -- 547 | i = skip_whitespace(text, new_i) 548 | 549 | local c = text:sub(i,i) 550 | 551 | if c == '}' then 552 | return VALUE, i + 1 553 | end 554 | 555 | if text:sub(i, i) ~= ',' then 556 | self:onDecodeError("expected comma or '}'", text, i, etc) 557 | end 558 | 559 | i = skip_whitespace(text, i + 1) 560 | end 561 | 562 | self:onDecodeError("unclosed '{'", text, start, etc) 563 | end 564 | 565 | local function grok_array(self, text, start, etc) 566 | if text:sub(start,start) ~= '[' then 567 | self:onDecodeError("expected '['", text, start, etc) 568 | end 569 | 570 | local i = skip_whitespace(text, start + 1) -- +1 to skip the '[' 571 | local VALUE = self.strictTypes and self:newArray { } or { } 572 | if text:sub(i,i) == ']' then 573 | return VALUE, i + 1 574 | end 575 | 576 | local VALUE_INDEX = 1 577 | 578 | local text_len = text:len() 579 | while i <= text_len do 580 | local val, new_i = grok_one(self, text, i, etc) 581 | 582 | -- can't table.insert(VALUE, val) here because it's a no-op if val is nil 583 | VALUE[VALUE_INDEX] = val 584 | VALUE_INDEX = VALUE_INDEX + 1 585 | 586 | i = skip_whitespace(text, new_i) 587 | 588 | -- 589 | -- Expect now either ']' to end things, or a ',' to allow us to continue.
590 | -- 591 | local c = text:sub(i,i) 592 | if c == ']' then 593 | return VALUE, i + 1 594 | end 595 | if text:sub(i, i) ~= ',' then 596 | self:onDecodeError("expected comma or ']'", text, i, etc) 597 | end 598 | i = skip_whitespace(text, i + 1) 599 | end 600 | self:onDecodeError("unclosed '['", text, start, etc) 601 | end 602 | 603 | 604 | grok_one = function(self, text, start, etc) 605 | -- Skip any whitespace 606 | start = skip_whitespace(text, start) 607 | 608 | if start > text:len() then 609 | self:onDecodeError("unexpected end of string", text, nil, etc) 610 | end 611 | 612 | if text:find('^"', start) then 613 | return grok_string(self, text, start, etc) 614 | 615 | elseif text:find('^[-0123456789 ]', start) then 616 | return grok_number(self, text, start, etc) 617 | 618 | elseif text:find('^%{', start) then 619 | return grok_object(self, text, start, etc) 620 | 621 | elseif text:find('^%[', start) then 622 | return grok_array(self, text, start, etc) 623 | 624 | elseif text:find('^true', start) then 625 | return true, start + 4 626 | 627 | elseif text:find('^false', start) then 628 | return false, start + 5 629 | 630 | elseif text:find('^null', start) then 631 | return nil, start + 4 632 | 633 | else 634 | self:onDecodeError("can't parse JSON", text, start, etc) 635 | end 636 | end 637 | 638 | function OBJDEF:decode(text, etc) 639 | if type(self) ~= 'table' or self.__index ~= OBJDEF then 640 | OBJDEF:onDecodeError("JSON:decode must be called in method format", nil, nil, etc) 641 | end 642 | 643 | if text == nil then 644 | self:onDecodeOfNilError(string.format("nil passed to JSON:decode()"), nil, nil, etc) 645 | elseif type(text) ~= 'string' then 646 | self:onDecodeError(string.format("expected string argument to JSON:decode(), got %s", type(text)), nil, nil, etc) 647 | end 648 | 649 | if text:match('^%s*$') then 650 | return nil 651 | end 652 | 653 | if text:match('^%s*<') then 654 | -- Can't be JSON...
we'll assume it's HTML 655 | self:onDecodeOfHTMLError(string.format("html passed to JSON:decode()"), text, nil, etc) 656 | end 657 | 658 | -- 659 | -- Ensure that it's not UTF-32 or UTF-16. 660 | -- Those are perfectly valid encodings for JSON (as per RFC 4627 section 3), 661 | -- but this package can't handle them. 662 | -- 663 | if text:sub(1,1):byte() == 0 or (text:len() >= 2 and text:sub(2,2):byte() == 0) then 664 | self:onDecodeError("JSON package groks only UTF-8, sorry", text, nil, etc) 665 | end 666 | 667 | local success, value = pcall(grok_one, self, text, 1, etc) 668 | 669 | if success then 670 | return value 671 | else 672 | -- if JSON:onDecodeError() didn't abort out of the pcall, we'll have received the error message here as "value", so pass it along as an assert. 673 | if self.assert then 674 | self.assert(false, value) 675 | else 676 | assert(false, value) 677 | end 678 | -- and if we're still here, return a nil and throw the error message on as a second arg 679 | return nil, value 680 | end 681 | end 682 | 683 | local function backslash_replacement_function(c) 684 | if c == "\n" then 685 | return "\\n" 686 | elseif c == "\r" then 687 | return "\\r" 688 | elseif c == "\t" then 689 | return "\\t" 690 | elseif c == "\b" then 691 | return "\\b" 692 | elseif c == "\f" then 693 | return "\\f" 694 | elseif c == '"' then 695 | return '\\"' 696 | elseif c == '\\' then 697 | return '\\\\' 698 | else 699 | return string.format("\\u%04x", c:byte()) 700 | end 701 | end 702 | 703 | local chars_to_be_escaped_in_JSON_string 704 | = '[' 705 | .. '"' -- class sub-pattern to match a double quote 706 | .. '%\\' -- class sub-pattern to match a backslash 707 | .. '%z' -- class sub-pattern to match a null 708 | .. '\001' .. '-' .. '\031' -- class sub-pattern to match control characters 709 | .. ']' 710 | 711 | local function json_string_literal(value) 712 | local newval = value:gsub(chars_to_be_escaped_in_JSON_string, backslash_replacement_function) 713 | return '"' .. 
newval .. '"' 714 | end 715 | 716 | local function object_or_array(self, T, etc) 717 | -- 718 | -- We need to inspect all the keys... if there are any strings, we'll convert to a JSON 719 | -- object. If there are only numbers, it's a JSON array. 720 | -- 721 | -- If we'll be converting to a JSON object, we'll want to sort the keys so that the 722 | -- end result is deterministic. 723 | -- 724 | local string_keys = { } 725 | local number_keys = { } 726 | local number_keys_must_be_strings = false 727 | local maximum_number_key 728 | 729 | for key in pairs(T) do 730 | if type(key) == 'string' then 731 | table.insert(string_keys, key) 732 | elseif type(key) == 'number' then 733 | table.insert(number_keys, key) 734 | if key <= 0 or key >= math.huge then 735 | number_keys_must_be_strings = true 736 | elseif not maximum_number_key or key > maximum_number_key then 737 | maximum_number_key = key 738 | end 739 | else 740 | self:onEncodeError("can't encode table with a key of type " .. type(key), etc) 741 | end 742 | end 743 | 744 | if #string_keys == 0 and not number_keys_must_be_strings then 745 | -- 746 | -- An empty table, or a numeric-only array 747 | -- 748 | if #number_keys > 0 then 749 | return nil, maximum_number_key -- an array 750 | elseif tostring(T) == "JSON array" then 751 | return nil 752 | elseif tostring(T) == "JSON object" then 753 | return { } 754 | else 755 | -- have to guess, so we'll pick array, since empty arrays are likely more common than empty objects 756 | return nil 757 | end 758 | end 759 | 760 | table.sort(string_keys) 761 | 762 | local map 763 | if #number_keys > 0 then 764 | -- 765 | -- If we're here then we have either mixed string/number keys, or numbers inappropriate for a JSON array 766 | -- It's not ideal, but we'll turn the numbers into strings so that we can at least create a JSON object. 
767 | -- 768 | 769 | if self.noKeyConversion then 770 | self:onEncodeError("a table with both numeric and string keys could be an object or array; aborting", etc) 771 | end 772 | 773 | -- 774 | -- Have to make a shallow copy of the source table so we can remap the numeric keys to be strings 775 | -- 776 | map = { } 777 | for key, val in pairs(T) do 778 | map[key] = val 779 | end 780 | 781 | table.sort(number_keys) 782 | 783 | -- 784 | -- Throw numeric keys in there as strings 785 | -- 786 | for _, number_key in ipairs(number_keys) do 787 | local string_key = tostring(number_key) 788 | if map[string_key] == nil then 789 | table.insert(string_keys , string_key) 790 | map[string_key] = T[number_key] 791 | else 792 | self:onEncodeError("conflict converting table with mixed-type keys into a JSON object: key " .. number_key .. " exists both as a string and a number.", etc) 793 | end 794 | end 795 | end 796 | 797 | return string_keys, nil, map 798 | end 799 | 800 | -- 801 | -- Encode 802 | -- 803 | -- 'options' is nil, or a table with possible keys: 804 | -- pretty -- if true, return a pretty-printed version 805 | -- indent -- a string (usually of spaces) used to indent each nested level 806 | -- align_keys -- if true, align all the keys when formatting a table 807 | -- 808 | local encode_value -- must predeclare because it calls itself 809 | function encode_value(self, value, parents, etc, options, indent) 810 | 811 | if value == nil then 812 | return 'null' 813 | 814 | elseif type(value) == 'string' then 815 | return json_string_literal(value) 816 | 817 | elseif type(value) == 'number' then 818 | if value ~= value then 819 | -- 820 | -- NaN (Not a Number). 821 | -- JSON has no NaN, so we have to fudge the best we can. This should really be a package option. 822 | -- 823 | return "null" 824 | elseif value >= math.huge then 825 | -- 826 | -- Positive infinity. JSON has no INF, so we have to fudge the best we can. This should 827 | -- really be a package option. 
Note: at least with some implementations, positive infinity 828 | -- is both ">= math.huge" and "<= -math.huge", which makes no sense but that's how it is. 829 | -- Negative infinity is properly "<= -math.huge". So, we must be sure to check the ">=" 830 | -- case first. 831 | -- 832 | return "1e+9999" 833 | elseif value <= -math.huge then 834 | -- 835 | -- Negative infinity. 836 | -- JSON has no INF, so we have to fudge the best we can. This should really be a package option. 837 | -- 838 | return "-1e+9999" 839 | else 840 | return tostring(value) 841 | end 842 | 843 | elseif type(value) == 'boolean' then 844 | return tostring(value) 845 | 846 | elseif type(value) ~= 'table' then 847 | self:onEncodeError("can't convert " .. type(value) .. " to JSON", etc) 848 | 849 | else 850 | -- 851 | -- A table to be converted to either a JSON object or array. 852 | -- 853 | local T = value 854 | 855 | if type(options) ~= 'table' then 856 | options = {} 857 | end 858 | if type(indent) ~= 'string' then 859 | indent = "" 860 | end 861 | 862 | if parents[T] then 863 | self:onEncodeError("table " .. tostring(T) .. " is a child of itself", etc) 864 | else 865 | parents[T] = true 866 | end 867 | 868 | local result_value 869 | 870 | local object_keys, maximum_number_key, map = object_or_array(self, T, etc) 871 | if maximum_number_key then 872 | -- 873 | -- An array... 874 | -- 875 | local ITEMS = { } 876 | for i = 1, maximum_number_key do 877 | table.insert(ITEMS, encode_value(self, T[i], parents, etc, options, indent)) 878 | end 879 | 880 | if options.pretty then 881 | result_value = "[ " .. table.concat(ITEMS, ", ") .. " ]" 882 | else 883 | result_value = "[" .. table.concat(ITEMS, ",") .. 
"]" 884 | end 885 | 886 | elseif object_keys then 887 | -- 888 | -- An object 889 | -- 890 | local TT = map or T 891 | 892 | if options.pretty then 893 | 894 | local KEYS = { } 895 | local max_key_length = 0 896 | for _, key in ipairs(object_keys) do 897 | local encoded = encode_value(self, tostring(key), parents, etc, options, indent) 898 | if options.align_keys then 899 | max_key_length = math.max(max_key_length, #encoded) 900 | end 901 | table.insert(KEYS, encoded) 902 | end 903 | local key_indent = indent .. tostring(options.indent or "") 904 | local subtable_indent = key_indent .. string.rep(" ", max_key_length) .. (options.align_keys and " " or "") 905 | local FORMAT = "%s%" .. string.format("%d", max_key_length) .. "s: %s" 906 | 907 | local COMBINED_PARTS = { } 908 | for i, key in ipairs(object_keys) do 909 | local encoded_val = encode_value(self, TT[key], parents, etc, options, subtable_indent) 910 | table.insert(COMBINED_PARTS, string.format(FORMAT, key_indent, KEYS[i], encoded_val)) 911 | end 912 | result_value = "{\n" .. table.concat(COMBINED_PARTS, ",\n") .. "\n" .. indent .. "}" 913 | 914 | else 915 | 916 | local PARTS = { } 917 | for _, key in ipairs(object_keys) do 918 | local encoded_val = encode_value(self, TT[key], parents, etc, options, indent) 919 | local encoded_key = encode_value(self, tostring(key), parents, etc, options, indent) 920 | table.insert(PARTS, string.format("%s:%s", encoded_key, encoded_val)) 921 | end 922 | result_value = "{" .. table.concat(PARTS, ",") .. "}" 923 | 924 | end 925 | else 926 | -- 927 | -- An empty array/object... 
we'll treat it as an array, though it should really be an option 928 | -- 929 | result_value = "[]" 930 | end 931 | 932 | parents[T] = false 933 | return result_value 934 | end 935 | end 936 | 937 | 938 | function OBJDEF:encode(value, etc, options) 939 | if type(self) ~= 'table' or self.__index ~= OBJDEF then 940 | OBJDEF:onEncodeError("JSON:encode must be called in method format", etc) 941 | end 942 | return encode_value(self, value, {}, etc, options or nil) 943 | end 944 | 945 | function OBJDEF:encode_pretty(value, etc, options) 946 | if type(self) ~= 'table' or self.__index ~= OBJDEF then 947 | OBJDEF:onEncodeError("JSON:encode_pretty must be called in method format", etc) 948 | end 949 | return encode_value(self, value, {}, etc, options or default_pretty_options) 950 | end 951 | 952 | function OBJDEF.__tostring() 953 | return "JSON encode/decode package" 954 | end 955 | 956 | OBJDEF.__index = OBJDEF 957 | 958 | function OBJDEF:new(args) 959 | local new = { } 960 | 961 | if args then 962 | for key, val in pairs(args) do 963 | new[key] = val 964 | end 965 | end 966 | 967 | return setmetatable(new, OBJDEF) 968 | end 969 | 970 | return OBJDEF:new() 971 | 972 | -- 973 | -- Version history: 974 | -- 975 | -- 20141223.14 The encode_pretty() routine produced fine results for small datasets, but isn't really 976 | -- appropriate for anything large, so with help from Alex Aulbach I've made the encode routines 977 | -- more flexible, and changed the default encode_pretty() to be more generally useful. 978 | -- 979 | -- Added a third 'options' argument to the encode() and encode_pretty() routines, to control 980 | -- how the encoding takes place. 981 | -- 982 | -- Updated docs to add assert() call to the loadfile() line, just as good practice so that 983 | -- if there is a problem loading JSON.lua, the appropriate error message will percolate up. 
984 | -- 985 | -- 20140920.13 Put back (in a way that doesn't cause warnings about unused variables) the author string, 986 | -- so that the source of the package, and its version number, are visible in compiled copies. 987 | -- 988 | -- 20140911.12 Minor lua cleanup. 989 | -- Fixed internal reference to 'JSON.noKeyConversion' to reference 'self' instead of 'JSON'. 990 | -- (Thanks to SmugMug's David Parry for these.) 991 | -- 992 | -- 20140418.11 JSON nulls embedded within an array were being ignored, such that 993 | -- ["1",null,null,null,null,null,"seven"], 994 | -- would return 995 | -- {1,"seven"} 996 | -- It's now fixed to properly return 997 | -- {1, nil, nil, nil, nil, nil, "seven"} 998 | -- Thanks to "haddock" for catching the error. 999 | -- 1000 | -- 20140116.10 The user's JSON.assert() wasn't always being used. Thanks to "blue" for the heads up. 1001 | -- 1002 | -- 20131118.9 Update for Lua 5.3... it seems that tostring(2/1) produces "2.0" instead of "2", 1003 | -- and this caused some problems. 1004 | -- 1005 | -- 20131031.8 Unified the code for encode() and encode_pretty(); they had been stupidly separate, 1006 | -- and had of course diverged (encode_pretty didn't get the fixes that encode got, so 1007 | -- sometimes produced incorrect results; thanks to Mattie for the heads up). 1008 | -- 1009 | -- Handle encoding tables with non-positive numeric keys (unlikely, but possible). 1010 | -- 1011 | -- If a table has both numeric and string keys, or its numeric keys are inappropriate 1012 | -- (such as being non-positive or infinite), the numeric keys are turned into 1013 | -- string keys appropriate for a JSON object. 
So, as before, 1014 | -- JSON:encode({ "one", "two", "three" }) 1015 | -- produces the array 1016 | -- ["one","two","three"] 1017 | -- but now something with mixed key types like 1018 | -- JSON:encode({ "one", "two", "three", SOMESTRING = "some string" }) 1019 | -- instead of throwing an error produces an object: 1020 | -- {"1":"one","2":"two","3":"three","SOMESTRING":"some string"} 1021 | -- 1022 | -- To maintain the prior throw-an-error semantics, set 1023 | -- JSON.noKeyConversion = true 1024 | -- 1025 | -- 20131004.7 Release under a Creative Commons CC-BY license, which I should have done from day one, sorry. 1026 | -- 1027 | -- 20130120.6 Comment update: added a link to the specific page on my blog where this code can 1028 | -- be found, so that folks who come across the code outside of my blog can find updates 1029 | -- more easily. 1030 | -- 1031 | -- 20111207.5 Added support for the 'etc' arguments, for better error reporting. 1032 | -- 1033 | -- 20110731.4 More feedback from David Kolf on how to make the tests for NaN/infinity system-independent. 1034 | -- 1035 | -- 20110730.3 Incorporated feedback from David Kolf at http://lua-users.org/wiki/JsonModules: 1036 | -- 1037 | -- * When encoding Lua for JSON, sparse numeric arrays are now handled by 1038 | -- spitting out full arrays, such that 1039 | -- JSON:encode({"one", "two", [10] = "ten"}) 1040 | -- returns 1041 | -- ["one","two",null,null,null,null,null,null,null,"ten"] 1042 | -- 1043 | -- In 20100810.2 and earlier, only up to the first non-null value would have been retained. 1044 | -- 1045 | -- * When encoding Lua for JSON, numeric value NaN gets spit out as null, and infinity as "1e+9999". 1046 | -- Version 20100810.2 and earlier created invalid JSON in both cases. 1047 | -- 1048 | -- * Unicode surrogate pairs are now detected when decoding JSON.
1049 | -- 1050 | -- 20100810.2 added some checking to ensure that an invalid Unicode character couldn't leak in to the UTF-8 encoding 1051 | -- 1052 | -- 20100731.1 initial public release 1053 | -- 1054 | -------------------------------------------------------------------------------- /utils/misc.lua: -------------------------------------------------------------------------------- 1 | local utils = {} 2 | 3 | -- Preprocess an image before passing it to a Caffe model. 4 | -- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR, 5 | -- and subtract the mean pixel. 6 | function utils.preprocess(img) 7 | local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) 8 | local perm = torch.LongTensor{3, 2, 1} 9 | img = img:index(1, perm):mul(256.0) 10 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 11 | img:add(-1, mean_pixel) 12 | return img 13 | end 14 | 15 | 16 | -- Undo the above preprocessing. 17 | function utils.deprocess(img) 18 | local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) 19 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 20 | img = img + mean_pixel 21 | local perm = torch.LongTensor{3, 2, 1} 22 | img = img:index(1, perm):div(256.0) 23 | return img 24 | end 25 | 26 | -- from https://github.com/karpathy/char-rnn/blob/master/util/model_utils.lua 27 | -- which is in turn adapted from https://github.com/wojciechz/learning_to_execute 28 | 29 | function utils.combine_all_parameters(...) 
30 | --[[ like module:getParameters, but operates on many modules ]]-- 31 | 32 | -- get parameters 33 | local networks = {...} 34 | local parameters = {} 35 | local gradParameters = {} 36 | for i = 1, #networks do 37 | local net_params, net_grads = networks[i]:parameters() 38 | 39 | if net_params then 40 | for _, p in pairs(net_params) do 41 | parameters[#parameters + 1] = p 42 | end 43 | for _, g in pairs(net_grads) do 44 | gradParameters[#gradParameters + 1] = g 45 | end 46 | end 47 | end 48 | 49 | local function storageInSet(set, storage) 50 | local storageAndOffset = set[torch.pointer(storage)] 51 | if storageAndOffset == nil then 52 | return nil 53 | end 54 | local _, offset = unpack(storageAndOffset) 55 | return offset 56 | end 57 | 58 | -- this function flattens arbitrary lists of parameters, 59 | -- even complex shared ones 60 | local function flatten(parameters) 61 | if not parameters or #parameters == 0 then 62 | return torch.Tensor() 63 | end 64 | local Tensor = parameters[1].new 65 | 66 | local storages = {} 67 | local nParameters = 0 68 | for k = 1,#parameters do 69 | local storage = parameters[k]:storage() 70 | if not storageInSet(storages, storage) then 71 | storages[torch.pointer(storage)] = {storage, nParameters} 72 | nParameters = nParameters + storage:size() 73 | end 74 | end 75 | 76 | local flatParameters = Tensor(nParameters):fill(1) 77 | local flatStorage = flatParameters:storage() 78 | 79 | for k = 1,#parameters do 80 | local storageOffset = storageInSet(storages, parameters[k]:storage()) 81 | parameters[k]:set(flatStorage, 82 | storageOffset + parameters[k]:storageOffset(), 83 | parameters[k]:size(), 84 | parameters[k]:stride()) 85 | parameters[k]:zero() 86 | end 87 | 88 | local maskParameters= flatParameters:float():clone() 89 | local cumSumOfHoles = flatParameters:float():cumsum(1) 90 | local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] 91 | local flatUsedParameters = Tensor(nUsedParameters) 92 | local flatUsedStorage = 
flatUsedParameters:storage() 93 | 94 | for k = 1,#parameters do 95 | local offset = cumSumOfHoles[parameters[k]:storageOffset()] 96 | parameters[k]:set(flatUsedStorage, 97 | parameters[k]:storageOffset() - offset, 98 | parameters[k]:size(), 99 | parameters[k]:stride()) 100 | end 101 | 102 | for _, storageAndOffset in pairs(storages) do 103 | local k, v = unpack(storageAndOffset) 104 | flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) 105 | end 106 | 107 | if cumSumOfHoles:sum() == 0 then 108 | flatUsedParameters:copy(flatParameters) 109 | else 110 | local counter = 0 111 | for k = 1,flatParameters:nElement() do 112 | if maskParameters[k] == 0 then 113 | counter = counter + 1 114 | flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] 115 | end 116 | end 117 | assert (counter == nUsedParameters) 118 | end 119 | return flatUsedParameters 120 | end 121 | 122 | -- flatten parameters and gradients 123 | local flatParameters = flatten(parameters) 124 | local flatGradParameters = flatten(gradParameters) 125 | 126 | -- return new flat vector that contains all discrete parameters 127 | return flatParameters, flatGradParameters 128 | end 129 | 130 | function utils.clone_many_times(net, T) 131 | local clones = {} 132 | for i = 1, T do 133 | clones[i] = net:clone('weight', 'bias', 'gradWeight', 'gradBias') 134 | end 135 | return clones 136 | end 137 | 138 | function utils.clone_list(tensor_list, zero_too) 139 | -- takes a list of tensors and returns a list of cloned tensors 140 | local out = {} 141 | for k,v in pairs(tensor_list) do 142 | out[k] = v:clone() 143 | if zero_too then out[k]:zero() end 144 | end 145 | return out 146 | end 147 | 148 | return utils --------------------------------------------------------------------------------
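The decoder in `utils/JSON.lua` compresses two pieces of arithmetic into a few lines. `grok_string` joins a UTF-16 surrogate pair with `0x2400 + (hi - 0xD800) * 0x400 + lo`, where `0x2400` is `0x10000 - 0xDC00` folded into one constant. `unicode_codepoint_as_utf8` then builds the UTF-8 bytes using division and subtraction instead of bit operators, so it runs on plain Lua 5.1. The sketch below re-derives that arithmetic on its own so the constants can be checked in isolation. It is not code from the package: the names `combine_surrogates` and `codepoint_to_utf8` are hypothetical, and the sketch omits the package's substitution of `?` for lone surrogates.

```lua
-- Standalone sketch (hypothetical helpers, not part of utils/JSON.lua):
-- re-derives the surrogate-pair and UTF-8 arithmetic used by grok_string
-- and unicode_codepoint_as_utf8. Arithmetic only, so it runs on Lua 5.1.

-- Join a high surrogate (0xD800-0xDBFF) and a low surrogate (0xDC00-0xDFFF).
-- 0x10000 + (hi - 0xD800) * 0x400 + (lo - 0xDC00) is algebraically the same
-- as the 0x2400 + (hi - 0xD800) * 0x400 + lo form used in JSON.lua.
local function combine_surrogates(hi, lo)
  assert(hi >= 0xD800 and hi <= 0xDBFF, "not a high surrogate")
  assert(lo >= 0xDC00 and lo <= 0xDFFF, "not a low surrogate")
  return 0x10000 + (hi - 0xD800) * 0x400 + (lo - 0xDC00)
end

-- Encode a code point as UTF-8. Dividing by 0x40 and flooring acts as a
-- right shift by 6 bits; subtracting the shifted part back keeps the low bits.
-- Unlike JSON.lua, this sketch does not replace lone surrogates with "?".
local function codepoint_to_utf8(cp)
  if cp <= 0x7F then
    return string.char(cp)
  elseif cp <= 0x7FF then
    -- 110yyyxx 10xxxxxx
    local hi = math.floor(cp / 0x40)
    return string.char(0xC0 + hi, 0x80 + cp - hi * 0x40)
  elseif cp <= 0xFFFF then
    -- 1110yyyy 10yyyyxx 10xxxxxx
    local hi = math.floor(cp / 0x1000)
    local rem = cp - hi * 0x1000
    local mid = math.floor(rem / 0x40)
    return string.char(0xE0 + hi, 0x80 + mid, 0x80 + rem - mid * 0x40)
  else
    -- 11110zzz 10zzyyyy 10yyyyxx 10xxxxxx
    local hi = math.floor(cp / 0x40000)
    local rem = cp - hi * 0x40000
    local midA = math.floor(rem / 0x1000)
    rem = rem - midA * 0x1000
    local midB = math.floor(rem / 0x40)
    return string.char(0xF0 + hi, 0x80 + midA, 0x80 + midB, 0x80 + rem - midB * 0x40)
  end
end
```

For instance, the escape pair `\ud83d\ude00` combines to U+1F600, whose UTF-8 form is the four bytes `F0 9F 98 80`.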