├── .gitignore ├── LICENSE.txt ├── README.md ├── convert_gpu_to_cpu.lua ├── data ├── prepro.py ├── prepro_img_resnet.lua ├── prepro_img_vgg16.lua └── prepro_utils.lua ├── dataloader.lua ├── decoders ├── disc.lua └── gen.lua ├── encoders ├── hre-ques-hist.lua ├── hre-ques-im-hist.lua ├── hrea-ques-im-hist.lua ├── lf-att-ques-im-hist.lua ├── lf-ques-hist.lua ├── lf-ques-im-hist.lua ├── lf-ques-im.lua ├── lf-ques.lua ├── mn-att-ques-im-hist.lua ├── mn-ques-hist.lua └── mn-ques-im-hist.lua ├── evaluate.lua ├── generate.lua ├── model.lua ├── model_utils ├── MaskFuture.lua ├── MaskSoftMax.lua ├── MaskTime.lua ├── ReplaceZero.lua ├── optim_updates.lua └── weight-init.lua ├── opts.lua ├── scripts └── download_model.sh ├── train.lua ├── utils.lua └── vis ├── index.html └── static ├── bootstrap.min.css ├── jquery-3.2.1.min.js └── main.js /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | logs 3 | tmp 4 | checkpoints 5 | 6 | vis/results 7 | 8 | *scripts* 9 | *internal* 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For visdial software 4 | 5 | Copyright (c) 2017-present, Machine Learning & Perception Lab, Georgia Tech. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its contributors 18 | may be used to endorse or promote products derived from this software 19 | without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisDial 2 | 3 | Code for the paper 4 | 5 | **[Visual Dialog][1]** 6 | Abhishek Das, Satwik Kottur, Khushi Gupta, Avi Singh, Deshraj Yadav, José M. F. Moura, Devi Parikh, Dhruv Batra 7 | [arxiv.org/abs/1611.08669][1] 8 | [CVPR 2017][10] (Spotlight) 9 | 10 | **Visual Dialog** requires an AI agent to hold a meaningful dialog with humans in natural, conversational language about visual content. 
Given an image, dialog history, and a follow-up question about the image, the AI agent has to answer the question. 11 | 12 | Demo: [demo.visualdialog.org][11] 13 | 14 | 15 | 16 | This repository contains code for **training**, **evaluating** and **visualizing results** for all combinations of encoder-decoder architectures described in the paper. Specifically, we have 3 encoders: **Late Fusion** (LF), **Hierarchical Recurrent Encoder** (HRE), **Memory Network** (MN), and 2 kinds of decoding: **Generative** (G) and **Discriminative** (D). 17 | 18 | [![models](http://i.imgur.com/mdSOZPj.jpg)][1] 19 | 20 | If you find this code useful, consider citing our work: 21 | 22 | ``` 23 | @inproceedings{visdial, 24 | title={{V}isual {D}ialog}, 25 | author={Abhishek Das and Satwik Kottur and Khushi Gupta and Avi Singh 26 | and Deshraj Yadav and Jos\'e M.F. Moura and Devi Parikh and Dhruv Batra}, 27 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 28 | year={2017} 29 | } 30 | ``` 31 | 32 | ## Setup 33 | 34 | All our code is implemented in [Torch][13] (Lua). Installation instructions are as follows: 35 | 36 | ```sh 37 | git clone https://github.com/torch/distro.git ~/torch --recursive 38 | cd ~/torch; bash install-deps; 39 | TORCH_LUA_VERSION=LUA51 ./install.sh 40 | ``` 41 | 42 | Additionally, our code uses the following packages: [torch/torch7][14], [torch/nn][15], [torch/nngraph][16], [Element-Research/rnn][17], [torch/image][18], [lua-cjson][19], [loadcaffe][20], [torch-hdf5][25]. After Torch is installed, these can be installed/updated using: 43 | 44 | ```sh 45 | luarocks install torch 46 | luarocks install nn 47 | luarocks install nngraph 48 | luarocks install image 49 | luarocks install lua-cjson 50 | luarocks install loadcaffe 51 | luarocks install luabitop 52 | luarocks install totem 53 | ``` 54 | 55 | **NOTE**: `luarocks install rnn` defaults to [torch/rnn][33], follow these steps to install [Element-Research/rnn][17]. 56 | 57 | ```sh 58 | git clone https://github.com/Element-Research/rnn.git 59 | cd rnn 60 | luarocks make rocks/rnn-scm-1.rockspec 61 | ``` 62 | 63 | Installation instructions for torch-hdf5 are given [here][26]. 64 | 65 | **NOTE**: torch-hdf5 does not work with few versions of gcc. It is recommended that you use gcc 4.8 / gcc 4.9 with Lua 5.1 for proper installation of torch-hdf5. 66 | 67 | ### Running on GPUs 68 | 69 | Although our code should work on CPUs, it is *highly* recommended to use GPU acceleration with [CUDA][21]. You'll also need [torch/cutorch][22], [torch/cudnn][31] and [torch/cunn][23]. 70 | 71 | ```sh 72 | luarocks install cutorch 73 | luarocks install cunn 74 | luarocks install cudnn 75 | ``` 76 | 77 | ## Training your own network 78 | 79 | ### Preprocessing VisDial 80 | 81 | The preprocessing script is in Python, and you'll need to install [NLTK][24]. 82 | 83 | ```sh 84 | pip install nltk 85 | pip install numpy 86 | pip install h5py 87 | python -c "import nltk; nltk.download('all')" 88 | ``` 89 | 90 | [VisDial v1.0][27] dataset can be downloaded and preprocessed as specified below. The path provided as `-image_root` must have four subdirectories - [`train2014`][34] and [`val2014`][35] as per COCO dataset, `VisualDialog_val2018` and `VisualDialog_test2018` which can be downloaded from [here][27]. 91 | 92 | ```sh 93 | cd data 94 | python prepro.py -download -image_root /path/to/images 95 | cd .. 
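# Optional: to download and preprocess VisDial v0.9 instead (as noted below),
# pass the version flag explicitly, e.g.:
# python prepro.py -download -version 0.9 -image_root /path/to/images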
96 | ``` 97 | 98 | To download and preprocess the [VisDial v0.9][27] dataset, pass an extra `-version 0.9` argument when running the script. 99 | 100 | This script will generate the files `data/visdial_data.h5` (containing tokenized captions, questions, answers, and image indices) and `data/visdial_params.json` (containing vocabulary mappings and COCO image ids). 101 | 102 | 103 | ### Extracting image features 104 | 105 | Since we don't finetune the CNN, training is significantly faster if image features are pre-extracted. Currently this repository supports feature extraction from VGG-16 and ResNet models. We use image features from [VGG-16][28]. The VGG-16 model can be downloaded and features extracted using: 106 | 107 | ```sh 108 | sh scripts/download_model.sh vgg 16 # works for 19 as well 109 | cd data 110 | # For all models except mn-att-ques-im-hist 111 | th prepro_img_vgg16.lua -imageRoot /path/to/images -gpuid 0 112 | # For mn-att-ques-im-hist 113 | th prepro_img_vgg16.lua -imageRoot /path/to/images -imgSize 448 -layerName pool5 -gpuid 0 114 | ``` 115 | 116 | Similarly, [ResNet models][32] released by Facebook can be used for feature extraction, in the same manner as VGG-16: 117 | 118 | ```sh 119 | sh scripts/download_model.sh resnet 200 # works for 18, 34, 50, 101, 152 as well 120 | cd data 121 | th prepro_img_resnet.lua -imageRoot /path/to/images -cnnModel /path/to/t7/model -gpuid 0 122 | ``` 123 | 124 | Running either of these should generate `data/data_img.h5` containing features for the `train`, `val` and `test` splits of VisDial v1.0. 125 | 126 | 127 | ### Training 128 | 129 | Finally, we can get to training models! All supported encoders are in the `encoders/` folder (`lf-ques`, `lf-ques-im`, `lf-ques-hist`, `lf-ques-im-hist`, `hre-ques-hist`, `hre-ques-im-hist`, `hrea-ques-im-hist`, `mn-ques-hist`, `mn-ques-im-hist`, `mn-att-ques-im-hist`), and decoders in the `decoders/` folder (`gen` and `disc`). 130 | 131 | **Generative** (`gen`) decoding maximizes the likelihood of the ground-truth response and only has access to single input-output pairs of dialog, while **discriminative** (`disc`) decoding makes use of the 100 candidate responses provided for every round of dialog and maximizes the likelihood of the correct option. 132 | 133 | Encoders and decoders can be arbitrarily plugged together. For example, to train an HRE model with question and history information only (no images), and generative decoding: 134 | 135 | ```sh 136 | th train.lua -encoder hre-ques-hist -decoder gen -gpuid 0 137 | ``` 138 | 139 | Similarly, to train a Memory Network model with question, image and history information, and discriminative decoding: 140 | 141 | ```sh 142 | th train.lua -encoder mn-ques-im-hist -decoder disc -gpuid 0 143 | ``` 144 | 145 | **Note:** For attention-based encoders, set both the `imgSpatialSize` and `imgFeatureSize` command line params; feature dimensions are interpreted as `(batch X spatial X spatial X feature)`. For other encoders, `imgSpatialSize` is redundant. 146 | 147 | The training script saves model snapshots at regular intervals in the `checkpoints/` folder. 148 | 149 | It takes about 15-20 epochs for models with generative decoding to converge, and 4-8 epochs for discriminative decoding. 150 | 151 | ## Evaluation 152 | 153 | We evaluate model performance by the rank it assigns to the human response among 100 candidate options for every round of dialog, using retrieval metrics: mean reciprocal rank (MRR), R@1, R@5, R@10, and mean rank.
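For reference, here is a minimal Lua/Torch sketch of how these retrieval metrics can be computed from the rank assigned to the ground-truth answer. This is illustrative only (not the repository's `evaluate.lua`); `ranks` and `computeRetrievalMetrics` are assumed names, with `ranks` holding one rank per dialog round.

```lua
-- Illustrative sketch: retrieval metrics from ground-truth answer ranks.
-- `ranks` is assumed to be a 1D tensor where a value of 1 means the
-- ground-truth answer was ranked first among the 100 candidate options.
require 'torch'

function computeRetrievalMetrics(ranks)
    local r = ranks:double()
    local metrics = {}
    -- mean reciprocal rank: average of 1/rank
    metrics.mrr = torch.mean(torch.cdiv(torch.ones(r:nElement()), r))
    -- recall@k: fraction of rounds where the ground truth is ranked within top k
    metrics.r1  = torch.mean(torch.le(r, 1):double())
    metrics.r5  = torch.mean(torch.le(r, 5):double())
    metrics.r10 = torch.mean(torch.le(r, 10):double())
    -- mean rank of the ground-truth answer
    metrics.meanRank = torch.mean(r)
    return metrics
end
```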
154 | 155 | Model evaluation can be run using: 156 | 157 | ```sh 158 | th evaluate.lua -loadPath checkpoints/model.t7 -gpuid 0 159 | ``` 160 | 161 | Note that evaluation requires image features `data/data_img.h5`, tokenized dialogs `data/visdial_data.h5` and vocabulary mappings `data/visdial_params.json`. 162 | 163 | ## Running Beam Search & Visualizing Results 164 | 165 | We also include code for running beam search on your model snapshots. This gives significantly nicer results than argmax decoding, and can be run as follows: 166 | 167 | ```sh 168 | th generate.lua -loadPath checkpoints/model.t7 -maxThreads 50 169 | ``` 170 | 171 | This would compute predictions for 50 threads from the `val` split and save results in `vis/results/results.json`. 172 | 173 | ```sh 174 | cd vis 175 | # python 3.6 176 | python -m http.server 177 | # python 2.7 178 | # python -m SimpleHTTPServer 179 | ``` 180 | 181 | Now visit `localhost:8000` in your browser to see generated results. 182 | 183 | Sample results from HRE-QIH-G available [here](https://computing.ece.vt.edu/~abhshkdz/visdial/browse_results/). 184 | 185 | ![](http://i.imgur.com/R3HJ2E5.gif) 186 | 187 | ## Download Extracted Features & Pretrained Models 188 | 189 | ### v0.9 190 | 191 | Extracted features for v0.9 train and val are available for download. 192 | 193 | * [`visdial_data.h5`](https://s3.amazonaws.com/visual-dialog/data/v0.9/visdial_data.h5): Tokenized captions, questions, answers, image indices 194 | * [`visdial_params.json`](https://s3.amazonaws.com/visual-dialog/data/v0.9/visdial_params.json): Vocabulary mappings and COCO image ids 195 | * [`data_img_vgg16_relu7.h5`](https://s3.amazonaws.com/visual-dialog/data/v0.9/data_img_vgg16_relu7.h5): VGG16 `relu7` image features 196 | * [`data_img_vgg16_pool5.h5`](https://s3.amazonaws.com/visual-dialog/data/v0.9/data_img_vgg16_pool5.h5): VGG16 `pool5` image features 197 | 198 | #### Pretrained models 199 | 200 | Trained on v0.9 `train`, results on v0.9 `val`. 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 |
| Encoder | Decoder | CNN | MRR | R@1 | R@5 | R@10 | MR | Download |
|---------|---------|-----|-----|-----|-----|------|----|----------|
| lf-ques | gen | VGG-16 | 0.5048 | 0.3974 | 0.6067 | 0.6649 | 17.8003 | lf-ques-gen-vgg16-18 |
| lf-ques-hist | gen | VGG-16 | 0.5099 | 0.4012 | 0.6155 | 0.6740 | 17.3974 | lf-ques-hist-gen-vgg16-18 |
| lf-ques-im | gen | VGG-16 | 0.5206 | 0.4206 | 0.6165 | 0.6760 | 17.0578 | lf-ques-im-gen-vgg16-22 |
| lf-ques-im-hist | gen | VGG-16 | 0.5146 | 0.4086 | 0.6205 | 0.6828 | 16.7553 | lf-ques-im-hist-gen-vgg16-26 |
| lf-att-ques-im-hist | gen | VGG-16 | 0.5354 | 0.4354 | 0.6355 | 0.6941 | 16.7663 | lf-att-ques-im-hist-gen-vgg16-80 |
| hre-ques-hist | gen | VGG-16 | 0.5089 | 0.4000 | 0.6154 | 0.6739 | 17.3618 | hre-ques-hist-gen-vgg16-18 |
| hre-ques-im-hist | gen | VGG-16 | 0.5237 | 0.4223 | 0.6228 | 0.6811 | 16.9669 | hre-ques-im-hist-gen-vgg16-14 |
| hrea-ques-im-hist | gen | VGG-16 | 0.5238 | 0.4213 | 0.6244 | 0.6842 | 16.6044 | hrea-ques-im-hist-gen-vgg16-24 |
| mn-ques-hist | gen | VGG-16 | 0.5131 | 0.4057 | 0.6176 | 0.6770 | 17.6253 | mn-ques-hist-gen-vgg16-102 |
| mn-ques-im-hist | gen | VGG-16 | 0.5258 | 0.4229 | 0.6274 | 0.6874 | 16.9871 | mn-ques-im-hist-gen-vgg16-78 |
| mn-att-ques-im-hist | gen | VGG-16 | 0.5341 | 0.4354 | 0.6318 | 0.6903 | 17.0726 | mn-att-ques-im-hist-gen-vgg16-100 |
| lf-ques | disc | VGG-16 | 0.5491 | 0.4113 | 0.7020 | 0.7964 | 7.1519 | lf-ques-disc-vgg16-10 |
| lf-ques-hist | disc | VGG-16 | 0.5724 | 0.4319 | 0.7308 | 0.8251 | 6.2847 | lf-ques-hist-disc-vgg16-8 |
| lf-ques-im | disc | VGG-16 | 0.5745 | 0.4331 | 0.7398 | 0.8340 | 5.9801 | lf-ques-im-disc-vgg16-12 |
| lf-ques-im-hist | disc | VGG-16 | 0.5911 | 0.4490 | 0.7563 | 0.8493 | 5.5493 | lf-ques-im-hist-disc-vgg16-8 |
| lf-att-ques-im-hist | disc | VGG-16 | 0.6079 | 0.4692 | 0.7731 | 0.8635 | 5.1965 | lf-att-ques-im-hist-disc-vgg16-20 |
| hre-ques-hist | disc | VGG-16 | 0.5668 | 0.4265 | 0.7245 | 0.8207 | 6.3701 | hre-ques-hist-disc-vgg16-4 |
| hre-ques-im-hist | disc | VGG-16 | 0.5818 | 0.4461 | 0.7373 | 0.8342 | 5.9647 | hre-ques-im-hist-disc-vgg16-4 |
| hrea-ques-im-hist | disc | VGG-16 | 0.5821 | 0.4456 | 0.7378 | 0.8341 | 5.9646 | hrea-ques-im-hist-disc-vgg16-4 |
| mn-ques-hist | disc | VGG-16 | 0.5831 | 0.4388 | 0.7507 | 0.8434 | 5.8090 | mn-ques-hist-disc-vgg16-20 |
| mn-ques-im-hist | disc | VGG-16 | 0.5971 | 0.4562 | 0.7627 | 0.8539 | 5.4218 | mn-ques-im-hist-disc-vgg16-12 |
| mn-att-ques-im-hist | disc | VGG-16 | 0.6082 | 0.4700 | 0.7724 | 0.8623 | 5.2930 | mn-att-ques-im-hist-disc-vgg16-28 |
277 | 278 | ### v1.0 279 | 280 | Extracted features for v1.0 train, val and test are available for download. 281 | 282 | * [`visdial_data_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_data_train.h5): Tokenized captions, questions, answers, image indices, for training on `train` 283 | * [`visdial_params_train.json`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_params_train.json): Vocabulary mappings and COCO image ids for training on `train` 284 | * [`data_img_vgg16_relu7_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_relu7_train.h5): VGG16 `relu7` image features for training on `train` 285 | * [`data_img_vgg16_pool5_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_pool5_train.h5): VGG16 `pool5` image features for training on `train` 286 | * [`visdial_data_trainval.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_data_trainval.h5): Tokenized captions, questions, answers, image indices, for training on `train`+`val` 287 | * [`visdial_params_trainval.json`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_params_trainval.json): Vocabulary mappings and COCO image ids for training on `train`+`val` 288 | * [`data_img_vgg16_relu7_trainval.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_relu7_trainval.h5): VGG16 `relu7` image features for training on `train`+`val` 289 | * [`data_img_vgg16_pool5_trainval.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_pool5_trainval.h5): VGG16 `pool5` image features for training on `train`+`val` 290 | 291 | #### Pretrained models 292 | 293 | Trained on v1.0 `train` + v1.0 `val`, results on v1.0 `test-std`. Leaderboard [here][evalai-leaderboard]. 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 |
| Encoder | Decoder | CNN | NDCG | MRR | R@1 | R@5 | R@10 | MR | Download |
|---------|---------|-----|------|-----|-----|-----|------|----|----------|
| lf-ques-im-hist | gen | VGG-16 | 0.5121 | 0.4568 | 35.08 | 55.92 | 64.02 | 18.8140 | lf-ques-im-hist-gen-vgg16-24 |
| hre-ques-im-hist | gen | VGG-16 | 0.5245 | 0.4561 | 34.78 | 56.18 | 63.72 | 18.7778 | hre-ques-im-hist-gen-vgg16-20 |
| mn-ques-im-hist | gen | VGG-16 | 0.5280 | 0.4580 | 35.05 | 56.35 | 63.92 | 19.3128 | mn-ques-im-hist-gen-vgg16-92 |
| lf-att-ques-im-hist | gen | VGG-16 | 0.5362 | 0.4697 | 36.58 | 57.40 | 64.48 | 18.9550 | lf-att-ques-im-hist-gen-vgg16-82 |
| mn-att-ques-im-hist | gen | VGG-16 | 0.5367 | 0.4650 | 36.00 | 56.80 | 64.25 | 19.3470 | mn-att-ques-im-hist-gen-vgg16-100 |
| lf-ques-im-hist | disc | VGG-16 | 0.4531 | 0.5542 | 40.95 | 72.45 | 82.83 | 5.9532 | lf-ques-im-hist-disc-vgg16-8 |
| hre-ques-im-hist | disc | VGG-16 | 0.4546 | 0.5416 | 39.93 | 70.45 | 81.50 | 6.4082 | hre-ques-im-hist-disc-vgg16-4 |
| mn-ques-im-hist | disc | VGG-16 | 0.4750 | 0.5549 | 40.98 | 72.30 | 83.30 | 5.9245 | mn-ques-im-hist-disc-vgg16-12 |
| lf-att-ques-im-hist | disc | VGG-16 | 0.4976 | 0.5707 | 42.08 | 74.82 | 85.05 | 5.4092 | lf-att-ques-im-hist-disc-vgg16-24 |
| mn-att-ques-im-hist | disc | VGG-16 | 0.4958 | 0.5690 | 42.42 | 74.00 | 84.35 | 5.5852 | mn-att-ques-im-hist-disc-vgg16-24 |
334 | 335 | ## License 336 | 337 | BSD 338 | 339 | 340 | [1]: https://arxiv.org/abs/1611.08669 341 | [2]: https://abhishekdas.com 342 | [3]: https://satwikkottur.github.io 343 | [4]: http://www.linkedin.com/in/khushi-gupta-9a678448 344 | [5]: http://people.eecs.berkeley.edu/~avisingh/ 345 | [6]: http://deshraj.github.io 346 | [7]: http://users.ece.cmu.edu/~moura/ 347 | [8]: http://www.cc.gatech.edu/~parikh/ 348 | [9]: http://www.cc.gatech.edu/~dbatra 349 | [10]: http://cvpr2017.thecvf.com/ 350 | [11]: http://demo.visualdialog.org 351 | [12]: https://vimeo.com/193092429 352 | [13]: http://torch.ch/ 353 | [14]: https://github.com/torch/torch7 354 | [15]: https://github.com/torch/nn 355 | [16]: https://github.com/torch/nngraph 356 | [17]: https://github.com/Element-Research/rnn/ 357 | [18]: https://github.com/torch/image 358 | [19]: https://luarocks.org/modules/luarocks/lua-cjson 359 | [20]: https://github.com/szagoruyko/loadcaffe 360 | [21]: https://developer.nvidia.com/cuda-toolkit 361 | [22]: https://github.com/torch/cutorch 362 | [23]: https://github.com/torch/cunn 363 | [24]: http://www.nltk.org/ 364 | [25]: https://github.com/deepmind/torch-hdf5 365 | [26]: https://github.com/deepmind/torch-hdf5/blob/master/doc/usage.md 366 | [27]: https://visualdialog.org/data 367 | [28]: http://www.robots.ox.ac.uk/~vgg/research/very_deep/ 368 | [31]: https://www.github.com/soumith/cudnn.torch 369 | [32]: https://github.com/facebook/fb.resnet.torch/tree/master/pretrained 370 | [33]: https://github.com/torch/rnn 371 | [34]: http://images.cocodataset.org/zips/train2014.zip 372 | [35]: http://images.cocodataset.org/zips/val2014.zip 373 | [evalai-leaderboard]: https://evalai.cloudcv.org/web/challenges/challenge-page/103/leaderboard/298 374 | -------------------------------------------------------------------------------- /convert_gpu_to_cpu.lua: -------------------------------------------------------------------------------- 1 | require 'torch'; 2 | require 'nn'; 3 | 4 | cmd = torch.CmdLine() 5 | cmd:option('-loadPath', 'checkpoints/model.t7') 6 | cmd:option('-savePath', 'checkpoints/model_cpu.t7') 7 | cmd:option('-gpuid', 0) 8 | 9 | opt = cmd:parse(arg) 10 | 11 | -- check for new save path 12 | if opt.savePath == 'checkpoints/model_cpu.t7' then 13 | opt.savePath = opt.loadPath .. '.cpu.t7' 14 | end 15 | 16 | print(opt) 17 | 18 | if opt.gpuid >= 0 then 19 | require 'cutorch' 20 | require 'cunn' 21 | if opt.backend == 'cudnn' then require 'cudnn' end 22 | cutorch.setDevice(opt.gpuid+1) 23 | torch.setdefaulttensortype('torch.CudaTensor'); 24 | else 25 | print('Gotta have a GPU to convert to CPU :(') 26 | os.exit() 27 | end 28 | 29 | print('Loading model') 30 | model = torch.load(opt.loadPath) 31 | 32 | -- convert modelW and optims to cpu 33 | print('Shipping params to CPU') 34 | if model.modelW:type() == 'torch.CudaTensor' then 35 | model.modelW = model.modelW:float() 36 | end 37 | 38 | for k,v in pairs(model.optims) do 39 | if torch.type(v) ~= 'number' and v:type() == 'torch.CudaTensor' then 40 | model.optims[k] = v:float() 41 | end 42 | end 43 | 44 | print('Saving to ' .. 
opt.savePath) 45 | torch.save(opt.savePath, model) 46 | -------------------------------------------------------------------------------- /data/prepro.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import h5py 4 | import json 5 | import os 6 | import numpy as np 7 | from nltk.tokenize import word_tokenize 8 | from tqdm import tqdm 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-download', action='store_true', help='Whether to download VisDial data') 13 | parser.add_argument('-version', default='1.0', choices=['0.9', '1.0'], help='Version of VisDial to be downloaded') 14 | parser.add_argument('-train_split', default='train', help='Choose the data split: train | trainval', choices=['train', 'trainval']) 15 | 16 | # Input files 17 | parser.add_argument('-input_json_train', default='visdial_1.0_train.json', help='Input `train` json file') 18 | parser.add_argument('-input_json_val', default='visdial_1.0_val.json', help='Input `val` json file') 19 | parser.add_argument('-input_json_test', default='visdial_1.0_test.json', help='Input `test` json file') 20 | parser.add_argument('-image_root', default='/path/to/images', help='Path to coco and VisDial val/test images') 21 | parser.add_argument('-input_vocab', default=False, help='Optional vocab file; similar to visdial_params.json') 22 | 23 | # Output files 24 | parser.add_argument('-output_json', default='visdial_params.json', help='Output json file') 25 | parser.add_argument('-output_h5', default='visdial_data.h5', help='Output hdf5 file') 26 | 27 | # Options 28 | parser.add_argument('-max_ques_len', default=20, type=int, help='Max length of questions') 29 | parser.add_argument('-max_ans_len', default=20, type=int, help='Max length of answers') 30 | parser.add_argument('-max_cap_len', default=40, type=int, help='Max length of captions') 31 | parser.add_argument('-word_count_threshold', default=5, type=int, help='Min threshold of word count to include in vocabulary') 32 | 33 | 34 | def tokenize_data(data, word_count=False): 35 | """Tokenize captions, questions and answers, maintain word count 36 | if required. 37 | """ 38 | word_counts = {} 39 | dialogs = data['data']['dialogs'] 40 | # dialogs is a nested dict so won't be copied, just a reference 41 | 42 | print("[%s] Tokenizing captions..." % data['split']) 43 | for i, dialog in enumerate(tqdm(dialogs)): 44 | caption = word_tokenize(dialog['caption']) 45 | dialogs[i]['caption_tokens'] = caption 46 | 47 | print("[%s] Tokenizing questions and answers..." % data['split']) 48 | q_tokens, a_tokens = [], [] 49 | for q in tqdm(data['data']['questions']): 50 | q_tokens.append(word_tokenize(q + '?')) 51 | 52 | for a in tqdm(data['data']['answers']): 53 | a_tokens.append(word_tokenize(a)) 54 | data['data']['question_tokens'] = q_tokens 55 | data['data']['answer_tokens'] = a_tokens 56 | 57 | print("[%s] Filling missing values in dialog, if any..." % data['split']) 58 | for i, dialog in enumerate(tqdm(dialogs)): 59 | # last round of dialog will not have answer for test split 60 | if 'answer' not in dialog['dialog'][-1]: 61 | dialog['dialog'][-1]['answer'] = -1 62 | # right-pad dialog with empty question-answer pairs at the end 63 | dialog['num_rounds'] = len(dialog['dialog']) 64 | while len(dialog['dialog']) < 10: 65 | dialog['dialog'].append({'question': -1, 'answer': -1}) 66 | dialogs[i] = dialog 67 | 68 | if word_count: 69 | print("[%s] Building word counts from tokens..." 
% data['split']) 70 | for i, dialog in enumerate(tqdm(dialogs)): 71 | caption = dialogs[i]['caption_tokens'] 72 | all_qa = [] 73 | for j in range(10): 74 | all_qa += q_tokens[dialog['dialog'][j]['question']] 75 | all_qa += a_tokens[dialog['dialog'][j]['answer']] 76 | for word in caption + all_qa: 77 | word_counts[word] = word_counts.get(word, 0) + 1 78 | print('\n') 79 | return data, word_counts 80 | 81 | 82 | def encode_vocab(data, word2ind): 83 | """Converts string tokens to indices based on given dictionary.""" 84 | dialogs = data['data']['dialogs'] 85 | print("[%s] Encoding caption tokens..." % data['split']) 86 | for i, dialog in enumerate(tqdm(dialogs)): 87 | dialogs[i]['caption_tokens'] = [word2ind.get(word, word2ind['UNK']) \ 88 | for word in dialog['caption_tokens']] 89 | 90 | print("[%s] Encoding question and answer tokens..." % data['split']) 91 | q_tokens = data['data']['question_tokens'] 92 | a_tokens = data['data']['answer_tokens'] 93 | 94 | for i, q in enumerate(tqdm(q_tokens)): 95 | q_tokens[i] = [word2ind.get(word, word2ind['UNK']) for word in q] 96 | 97 | for i, a in enumerate(tqdm(a_tokens)): 98 | a_tokens[i] = [word2ind.get(word, word2ind['UNK']) for word in a] 99 | 100 | data['data']['question_tokens'] = q_tokens 101 | data['data']['answer_tokens'] = a_tokens 102 | return data 103 | 104 | 105 | def create_data_mats(data, params, dtype): 106 | num_threads = len(data['data']['dialogs']) 107 | data_mats = {} 108 | data_mats['img_pos'] = np.arange(num_threads, dtype=np.int) 109 | 110 | print("[%s] Creating caption data matrices..." % data['split']) 111 | max_cap_len = params.max_cap_len 112 | captions = np.zeros([num_threads, max_cap_len]) 113 | caption_len = np.zeros(num_threads, dtype=np.int) 114 | 115 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])): 116 | caption_len[i] = len(dialog['caption_tokens'][0:max_cap_len]) 117 | captions[i][0:caption_len[i]] = dialog['caption_tokens'][0:max_cap_len] 118 | data_mats['cap_length'] = caption_len 119 | data_mats['cap'] = captions 120 | 121 | print("[%s] Creating question and answer data matrices..." % data['split']) 122 | num_rounds = 10 123 | max_ques_len = params.max_ques_len 124 | max_ans_len = params.max_ans_len 125 | 126 | ques = np.zeros([num_threads, num_rounds, max_ques_len]) 127 | ans = np.zeros([num_threads, num_rounds, max_ans_len]) 128 | ques_length = np.zeros([num_threads, num_rounds], dtype=np.int) 129 | ans_length = np.zeros([num_threads, num_rounds], dtype=np.int) 130 | 131 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])): 132 | for j in range(num_rounds): 133 | if dialog['dialog'][j]['question'] != -1: 134 | ques_length[i][j] = len(data['data']['question_tokens'][ 135 | dialog['dialog'][j]['question']][0:max_ques_len]) 136 | ques[i][j][0:ques_length[i][j]] = data['data']['question_tokens'][ 137 | dialog['dialog'][j]['question']][0:max_ques_len] 138 | if dialog['dialog'][j]['answer'] != -1: 139 | ans_length[i][j] = len(data['data']['answer_tokens'][ 140 | dialog['dialog'][j]['answer']][0:max_ans_len]) 141 | ans[i][j][0:ans_length[i][j]] = data['data']['answer_tokens'][ 142 | dialog['dialog'][j]['answer']][0:max_ans_len] 143 | 144 | data_mats['ques'] = ques 145 | data_mats['ans'] = ans 146 | data_mats['ques_length'] = ques_length 147 | data_mats['ans_length'] = ans_length 148 | 149 | print("[%s] Creating options data matrices..." 
% data['split']) 150 | # options and answer_index are 1-indexed specifically for lua 151 | options = np.ones([num_threads, num_rounds, 100]) 152 | num_rounds_list = np.full(num_threads, 10) 153 | 154 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])): 155 | for j in range(num_rounds): 156 | num_rounds_list[i] = dialog['num_rounds'] 157 | # v1.0 test does not have options for all dialog rounds 158 | if 'answer_options' in dialog['dialog'][j]: 159 | options[i][j] += np.array(dialog['dialog'][j]['answer_options']) 160 | 161 | data_mats['num_rounds'] = num_rounds_list 162 | data_mats['opt'] = options 163 | 164 | if dtype != 'test': 165 | print("[%s] Creating ground truth answer data matrices..." % data['split']) 166 | answer_index = np.zeros([num_threads, num_rounds]) 167 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])): 168 | for j in range(num_rounds): 169 | answer_index[i][j] = dialog['dialog'][j]['gt_index'] + 1 170 | data_mats['ans_index'] = answer_index 171 | 172 | options_len = np.zeros(len(data['data']['answer_tokens']), dtype=np.int) 173 | options_list = np.zeros([len(data['data']['answer_tokens']), max_ans_len]) 174 | 175 | for i, ans_token in enumerate(tqdm(data['data']['answer_tokens'])): 176 | options_len[i] = len(ans_token[0:max_ans_len]) 177 | options_list[i][0:options_len[i]] = ans_token[0:max_ans_len] 178 | 179 | data_mats['opt_length'] = options_len 180 | data_mats['opt_list'] = options_list 181 | return data_mats 182 | 183 | 184 | def get_image_ids(data, id2path): 185 | image_ids = [dialog['image_id'] for dialog in data['data']['dialogs']] 186 | for i, image_id in enumerate(image_ids): 187 | image_ids[i] = id2path[image_id] 188 | return image_ids 189 | 190 | 191 | if __name__ == "__main__": 192 | args = parser.parse_args() 193 | 194 | if args.download: 195 | if args.version == '1.0': 196 | os.system('wget https://www.dropbox.com/s/ix8keeudqrd8hn8/visdial_1.0_train.zip') 197 | os.system('wget https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip') 198 | elif args.version == '0.9': 199 | os.system('wget https://computing.ece.vt.edu/~abhshkdz/data/visdial/visdial_0.9_train.zip') 200 | os.system('wget https://computing.ece.vt.edu/~abhshkdz/data/visdial/visdial_0.9_val.zip') 201 | os.system('wget https://www.dropbox.com/s/o7mucbre2zm7i5n/visdial_1.0_test.zip') 202 | 203 | os.system('unzip visdial_%s_train.zip' % args.version) 204 | os.system('unzip visdial_%s_val.zip' % args.version) 205 | os.system('unzip visdial_1.0_test.zip') 206 | 207 | args.input_json_train = 'visdial_%s_train.json' % args.version 208 | args.input_json_val = 'visdial_%s_val.json' % args.version 209 | args.input_json_test = 'visdial_1.0_test.json' 210 | 211 | print('Reading json...') 212 | data_train = json.load(open(args.input_json_train, 'r')) 213 | data_val = json.load(open(args.input_json_val, 'r')) 214 | data_test = json.load(open(args.input_json_test, 'r')) 215 | 216 | # Tokenizing 217 | data_train, word_counts_train = tokenize_data(data_train, True) 218 | data_val, word_counts_val = tokenize_data(data_val, True) 219 | data_test, _ = tokenize_data(data_test) 220 | 221 | if args.input_vocab == False: 222 | word_counts_all = dict(word_counts_train) 223 | # combining the word counts of train and val splits 224 | if args.train_split == 'trainval': 225 | for word, count in word_counts_val.items(): 226 | word_counts_all[word] = word_counts_all.get(word, 0) + count 227 | 228 | print('Building vocabulary...') 229 | word_counts_all['UNK'] = args.word_count_threshold 230 | vocab = 
[word for word in word_counts_all \ 231 | if word_counts_all[word] >= args.word_count_threshold] 232 | print('Words: %d' % len(vocab)) 233 | word2ind = {word: word_ind + 1 for word_ind, word in enumerate(vocab)} 234 | ind2word = {word_ind: word for word, word_ind in word2ind.items()} 235 | else: 236 | print('Loading vocab from %s...' % args.input_vocab) 237 | vocab_data = json.load(open(args.input_vocab, 'r')) 238 | 239 | word2ind = vocab_data['word2ind'] 240 | for i in word2ind: 241 | word2ind[i] = int(word2ind[i]) 242 | 243 | ind2word = {} 244 | for i in vocab_data['ind2word']: 245 | ind2word[int(i)] = vocab_data['ind2word'][i] 246 | 247 | print('Encoding based on vocabulary...') 248 | data_train = encode_vocab(data_train, word2ind) 249 | data_val = encode_vocab(data_val, word2ind) 250 | data_test = encode_vocab(data_test, word2ind) 251 | 252 | print('Creating data matrices...') 253 | data_mats_train = create_data_mats(data_train, args, 'train') 254 | data_mats_val = create_data_mats(data_val, args, 'val') 255 | data_mats_test = create_data_mats(data_test, args, 'test') 256 | 257 | if args.train_split == 'trainval': 258 | data_mats_trainval = {} 259 | for key in data_mats_train: 260 | data_mats_trainval[key] = np.concatenate((data_mats_train[key], 261 | data_mats_val[key]), axis = 0) 262 | 263 | print('Saving hdf5 to %s...' % args.output_h5) 264 | f = h5py.File(args.output_h5, 'w') 265 | if args.train_split == 'train': 266 | for key in data_mats_train: 267 | f.create_dataset(key + '_train', dtype='uint32', data=data_mats_train[key]) 268 | 269 | for key in data_mats_val: 270 | f.create_dataset(key + '_val', dtype='uint32', data=data_mats_val[key]) 271 | 272 | elif args.train_split == 'trainval': 273 | for key in data_mats_trainval: 274 | f.create_dataset(key + '_train', dtype='uint32', data=data_mats_trainval[key]) 275 | 276 | for key in data_mats_test: 277 | f.create_dataset(key + '_test', dtype='uint32', data=data_mats_test[key]) 278 | f.close() 279 | 280 | out = {} 281 | out['ind2word'] = ind2word 282 | out['word2ind'] = word2ind 283 | 284 | print('Preparing image paths with image_ids...') 285 | id2path = {} 286 | # NOTE: based on assumption that image_id is unique across all splits 287 | for image_path in tqdm(glob.iglob(os.path.join(args.image_root, '*', '*.jpg'))): 288 | id2path[int(image_path[-12:-4])] = '/'.join(image_path.split('/')[-2:]) 289 | 290 | out['unique_img_train'] = get_image_ids(data_train, id2path) 291 | out['unique_img_val'] = get_image_ids(data_val, id2path) 292 | out['unique_img_test'] = get_image_ids(data_test, id2path) 293 | if args.train_split == 'trainval': 294 | out['unique_img_train'] += out['unique_img_val'] 295 | out.pop('unique_img_val') 296 | print('Saving json to %s...' 
% args.output_json) 297 | json.dump(out, open(args.output_json, 'w')) 298 | -------------------------------------------------------------------------------- /data/prepro_img_resnet.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'prepro_utils' 4 | 5 | ------------------------------------------------------------------------------- 6 | -- Input arguments and options 7 | ------------------------------------------------------------------------------- 8 | cmd = torch.CmdLine() 9 | cmd:text('Extract Image Features from Pretrained ResNet Models (t7 models)') 10 | cmd:text() 11 | cmd:text('Options') 12 | cmd:option('-inputJson', 'visdial_params.json', 'Path to JSON file') 13 | cmd:option('-imageRoot', '/path/to/images/', 'Path to COCO image root') 14 | cmd:option('-cnnModel', '/path/to/t7/model', 'Path to Pretrained T7 Model') 15 | cmd:option('-trainSplit', 'train', 'Which split to use: train | trainval') 16 | cmd:option('-batchSize', 50, 'Batch size') 17 | 18 | cmd:option('-outName', 'data_img.h5', 'Output name') 19 | cmd:option('-gpuid', 0, 'Which gpu to use. -1 = use CPU') 20 | 21 | cmd:option('-imgSize', 224) 22 | 23 | opt = cmd:parse(arg) 24 | print(opt) 25 | 26 | if opt.gpuid >= 0 then 27 | require 'cutorch' 28 | require 'cunn' 29 | require 'cudnn' 30 | cutorch.setDevice(opt.gpuid + 1) 31 | end 32 | 33 | ------------------------------------------------------------------------------- 34 | -- Loading model and removing extra layers 35 | ------------------------------------------------------------------------------- 36 | model = torch.load(opt.cnnModel); 37 | -- Remove the last fully connected + softmax layer of the model 38 | model:remove() 39 | model:evaluate() 40 | 41 | ------------------------------------------------------------------------------- 42 | -- Infering output dim 43 | ------------------------------------------------------------------------------- 44 | local dummy_img = torch.DoubleTensor(1, 3, opt.imgSize, opt.imgSize) 45 | 46 | if opt.gpuid >= 0 then 47 | dummy_img = dummy_img:cuda() 48 | model = model:cuda() 49 | end 50 | 51 | model:forward(dummy_img) 52 | local ndims = model.output:squeeze():size():totable() 53 | 54 | ------------------------------------------------------------------------------- 55 | -- Defining function for image preprocessing, like mean subtraction 56 | ------------------------------------------------------------------------------- 57 | function preprocessFn(im) 58 | -- mean pixel for torch models trained on imagenet 59 | local meanstd = { 60 | mean = { 0.485, 0.456, 0.406 }, 61 | std = { 0.229, 0.224, 0.225 }, 62 | } 63 | for i = 1, 3 do 64 | im[i]:add(-meanstd.mean[i]) 65 | im[i]:div(meanstd.std[i]) 66 | end 67 | return im 68 | end 69 | 70 | ------------------------------------------------------------------------------- 71 | -- Extract features and save to HDF 72 | ------------------------------------------------------------------------------- 73 | extractFeatures(model, opt, ndims, preprocessFn) 74 | -------------------------------------------------------------------------------- /data/prepro_img_vgg16.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'loadcaffe' 3 | require 'prepro_utils' 4 | 5 | ------------------------------------------------------------------------------- 6 | -- Input arguments and options 7 | ------------------------------------------------------------------------------- 8 | cmd = 
torch.CmdLine() 9 | cmd:text('Extract Image Features from Pretrained VGG16 Models (prototxt + caffemodel)') 10 | cmd:text() 11 | cmd:text('Options') 12 | cmd:option('-inputJson', 'visdial_params.json', 'Path to JSON file') 13 | cmd:option('-imageRoot', '/path/to/images/', 'Path to COCO image root') 14 | cmd:option('-cnnProto', 'models/vgg16/VGG_ILSVRC_16_layers_deploy.prototxt', 'Path to the CNN prototxt') 15 | cmd:option('-cnnModel', 'models/vgg16/VGG_ILSVRC_16_layers.caffemodel', 'Path to the CNN model') 16 | cmd:option('-trainSplit', 'train', 'Which split to use: train | trainval') 17 | cmd:option('-batchSize', 50, 'Batch size') 18 | 19 | cmd:option('-outName', 'data_img.h5', 'Output name') 20 | cmd:option('-gpuid', 0, 'Which gpu to use. -1 = use CPU') 21 | cmd:option('-backend', 'nn', 'nn|cudnn') 22 | 23 | cmd:option('-imgSize', 224) 24 | cmd:option('-layerName', 'relu7') 25 | 26 | opt = cmd:parse(arg) 27 | print(opt) 28 | 29 | if opt.gpuid >= 0 then 30 | require 'cutorch' 31 | require 'cunn' 32 | cutorch.setDevice(opt.gpuid + 1) 33 | end 34 | 35 | ------------------------------------------------------------------------------- 36 | -- Loading model and removing extra layers 37 | ------------------------------------------------------------------------------- 38 | model = loadcaffe.load(opt.cnnProto, opt.cnnModel, opt.backend); 39 | 40 | for i = #model.modules, 1, -1 do 41 | local layer = model:get(i) 42 | if layer.name == opt.layerName then break end 43 | model:remove() 44 | end 45 | model:evaluate() 46 | 47 | ------------------------------------------------------------------------------- 48 | -- Infering output dim 49 | ------------------------------------------------------------------------------- 50 | local dummy_img = torch.DoubleTensor(1, 3, opt.imgSize, opt.imgSize) 51 | 52 | if opt.gpuid >= 0 then 53 | dummy_img = dummy_img:cuda() 54 | model = model:cuda() 55 | end 56 | 57 | model:forward(dummy_img) 58 | local ndims = model.output:squeeze():size():totable() 59 | 60 | ------------------------------------------------------------------------------- 61 | -- Defining function for image preprocessing, like mean subtraction 62 | ------------------------------------------------------------------------------- 63 | function preprocessFn(im) 64 | -- mean pixel for caffemodels trained on imagenet 65 | local meanPixel = torch.DoubleTensor({103.939, 116.779, 123.68}) 66 | im = im:index(1, torch.LongTensor{3, 2, 1}):mul(255.0) 67 | meanPixel = meanPixel:view(3, 1, 1):expandAs(im) 68 | im:add(-1, meanPixel) 69 | return im 70 | end 71 | 72 | ------------------------------------------------------------------------------- 73 | -- Extract features and save to HDF 74 | ------------------------------------------------------------------------------- 75 | extractFeatures(model, opt, ndims, preprocessFn) 76 | -------------------------------------------------------------------------------- /data/prepro_utils.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'math' 3 | require 'nn' 4 | require 'image' 5 | require 'xlua' 6 | require 'hdf5' 7 | cjson = require('cjson') 8 | 9 | 10 | function loadImage(imageName, imgSize, preprocessType) 11 | im = image.load(imageName) 12 | 13 | if im:size(1) == 1 then 14 | im = im:repeatTensor(3, 1, 1) 15 | elseif im:size(1) == 4 then 16 | im = im[{{1,3}, {}, {}}] 17 | end 18 | 19 | im = image.scale(im, imgSize, imgSize) 20 | return im 21 | end 22 | 23 | function extractFeaturesSplit(model, opt, ndims, 
preprocessFn, dtype) 24 | local file = io.open(opt.inputJson, 'r') 25 | local text = file:read() 26 | file:close() 27 | jsonFile = cjson.decode(text) 28 | 29 | local imList = {} 30 | for i, imName in pairs(jsonFile['unique_img_'..dtype]) do 31 | table.insert(imList, string.format('%s/%s', opt.imageRoot, imName)) 32 | end 33 | 34 | local sz = #imList 35 | local imFeats = torch.FloatTensor(sz, unpack(ndims)) 36 | 37 | -- feature_dims shall be either 2 (NW format), else 4 (having NCHW format) 38 | local feature_dims = #imFeats:size() 39 | 40 | print(string.format('Processing %d %s images...', sz, dtype)) 41 | for i = 1, sz, opt.batchSize do 42 | xlua.progress(i, sz) 43 | r = math.min(sz, i + opt.batchSize - 1) 44 | ims = torch.DoubleTensor(r - i + 1, 3, opt.imgSize, opt.imgSize) 45 | for j = 1, r - i + 1 do 46 | ims[j] = loadImage(imList[i + j - 1], opt.imgSize) 47 | ims[j] = preprocessFn(ims[j]) 48 | end 49 | if opt.gpuid >= 0 then 50 | ims = ims:cuda() 51 | end 52 | 53 | if feature_dims == 4 then 54 | -- forward pass and permute to get NHWC format 55 | model:forward(ims):permute(1, 3, 4, 2):contiguous():float() 56 | else 57 | model:forward(ims) 58 | end 59 | imFeats[{{i, r}, {}}] = model.output:float() 60 | collectgarbage() 61 | end 62 | 63 | return imFeats 64 | end 65 | 66 | function extractFeatures(model, opt, ndims, preprocessFn) 67 | local h5File = hdf5.open(opt.outName, 'w') 68 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'train') 69 | h5File:write('/images_train', imFeats) 70 | if opt.trainSplit == 'train' then 71 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'val') 72 | h5File:write('/images_val', imFeats) 73 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'test') 74 | h5File:write('/images_test', imFeats) 75 | elseif opt.trainSplit == 'trainval' then 76 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'test') 77 | h5File:write('/images_test', imFeats) 78 | end 79 | h5File:close() 80 | end 81 | -------------------------------------------------------------------------------- /dataloader.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | require 'xlua' 3 | local utils = require 'utils' 4 | 5 | local dataloader = {}; 6 | 7 | -- read the data 8 | -- params: object itself, command line options, 9 | -- subset of data to load (train, val, test) 10 | function dataloader:initialize(opt, subsets) 11 | -- read additional info like dictionary, etc 12 | print('DataLoader loading json file: ', opt.inputJson) 13 | info = utils.readJSON(opt.inputJson); 14 | for key, value in pairs(info) do dataloader[key] = value; end 15 | 16 | -- add and to vocabulary 17 | count = 0; 18 | for _ in pairs(dataloader['word2ind']) do count = count + 1; end 19 | dataloader['word2ind'][''] = count + 1; 20 | dataloader['word2ind'][''] = count + 2; 21 | count = count + 2; 22 | dataloader.vocabSize = count; 23 | print(string.format('Vocabulary size (with ,): %d\n', count)); 24 | 25 | -- construct ind2word 26 | local ind2word = {}; 27 | for word, ind in pairs(dataloader['word2ind']) do 28 | ind2word[ind] = word; 29 | end 30 | dataloader['ind2word'] = ind2word; 31 | 32 | -- read questions, answers and options 33 | print('DataLoader loading h5 file: ', opt.inputQues) 34 | local quesFile = hdf5.open(opt.inputQues, 'r'); 35 | 36 | print('DataLoader loading h5 file: ', opt.inputImg) 37 | local imgFile = hdf5.open(opt.inputImg, 'r'); 38 | -- number of threads 39 | self.numThreads = {}; 40 | 41 | for 
_, dtype in pairs(subsets) do 42 | -- convert image ids to numbers 43 | for k, v in pairs(dataloader['unique_img_'..dtype]) do 44 | dataloader['unique_img_'..dtype][k] = tonumber(string.match(v, '000%d+')) 45 | end 46 | 47 | -- read question related information 48 | self[dtype..'_ques'] = quesFile:read('ques_'..dtype):all(); 49 | self[dtype..'_ques_len'] = quesFile:read('ques_length_'..dtype):all(); 50 | 51 | -- read answer related information 52 | self[dtype..'_ans'] = quesFile:read('ans_'..dtype):all(); 53 | self[dtype..'_ans_len'] = quesFile:read('ans_length_'..dtype):all(); 54 | if dtype ~= 'test' then 55 | self[dtype..'_ans_ind'] = quesFile:read('ans_index_'..dtype):all():long(); 56 | end 57 | 58 | -- read image list, if image features are needed 59 | if opt.useIm then 60 | print('Reading image features..') 61 | local imgFeats = imgFile:read('/images_'..dtype):all(); 62 | 63 | -- Normalize the image features (if needed) 64 | if opt.imgNorm == 1 then 65 | print('Normalizing image features..') 66 | local nm = torch.sqrt(torch.sum(torch.cmul(imgFeats, imgFeats), 2)); 67 | imgFeats = torch.cdiv(imgFeats, nm:expandAs(imgFeats)):float(); 68 | end 69 | -- Transpose from N x 512 x 14 x 14 to N x 14 x 14 x 512 70 | if string.match(opt.encoder, 'att') then 71 | imgFeats = imgFeats:permute(1, 3, 4, 2); 72 | end 73 | self[dtype..'_img_fv'] = imgFeats; 74 | -- TODO: make it 1 indexed in processing code 75 | -- currently zero indexed, adjust manually 76 | self[dtype..'_img_pos'] = quesFile:read('img_pos_'..dtype):all():long(); 77 | self[dtype..'_img_pos'] = self[dtype..'_img_pos'] + 1; 78 | end 79 | 80 | -- print information for data type 81 | print(string.format('%s:\n\tNo. of threads: %d\n\tNo. of rounds: %d'.. 82 | '\n\tMax ques len: %d'..'\n\tMax ans len: %d\n', 83 | dtype, self[dtype..'_ques']:size(1), 84 | self[dtype..'_ques']:size(2), 85 | self[dtype..'_ques']:size(3), 86 | self[dtype..'_ans']:size(3))); 87 | 88 | -- record some stats 89 | if dtype == 'train' then 90 | self.numTrainThreads = self['train_ques']:size(1); 91 | self.numThreads['train'] = self.numTrainThreads; 92 | end 93 | if dtype == 'test' then 94 | self.numTestThreads = self['test_ques']:size(1); 95 | self.numThreads['test'] = self.numTestThreads; 96 | end 97 | if dtype == 'val' then 98 | self.numValThreads = self['val_ques']:size(1); 99 | self.numThreads['val'] = self.numValThreads; 100 | end 101 | 102 | -- record the options 103 | if dtype == 'train' or dtype == 'val' or dtype == 'test' then 104 | self[dtype..'_opt'] = quesFile:read('opt_'..dtype):all():long(); 105 | self[dtype..'_opt_len'] = quesFile:read('opt_length_'..dtype):all(); 106 | self[dtype..'_opt_list'] = quesFile:read('opt_list_'..dtype):all(); 107 | self.numOptions = self[dtype..'_opt']:size(3); 108 | end 109 | 110 | self[dtype..'_num_rounds'] = quesFile:read('num_rounds_'..dtype):all(); 111 | 112 | -- assume similar stats across multiple data subsets 113 | -- maximum number of questions per image, ideally 10 114 | self.maxQuesCount = self[dtype..'_ques']:size(2); 115 | -- maximum length of question 116 | self.maxQuesLen = self[dtype..'_ques']:size(3); 117 | -- maximum length of answer 118 | self.maxAnsLen = self[dtype..'_ans']:size(3); 119 | 120 | -- if history is needed 121 | if opt.useHistory then 122 | self[dtype..'_cap'] = quesFile:read('cap_'..dtype):all():long(); 123 | self[dtype..'_cap_len'] = quesFile:read('cap_length_'..dtype):all(); 124 | end 125 | end 126 | -- done reading, close files 127 | quesFile:close(); 128 | imgFile:close(); 129 | 130 | -- 
take desired flags/values from opt 131 | self.useHistory = opt.useHistory; 132 | self.concatHistory = opt.concatHistory; 133 | self.useIm = opt.useIm; 134 | self.maxHistoryLen = opt.maxHistoryLen or 60; 135 | 136 | -- prepareDataset for training 137 | for _, dtype in pairs(subsets) do self:prepareDataset(dtype); end 138 | end 139 | 140 | -- method to prepare questions and answers for retrieval 141 | -- questions : right align 142 | -- answers : prefix with and 143 | function dataloader:prepareDataset(dtype) 144 | -- right align the questions 145 | print('Right aligning questions: '..dtype); 146 | self[dtype..'_ques_fwd'] = utils.rightAlign(self[dtype..'_ques'], 147 | self[dtype..'_ques_len']); 148 | 149 | -- if separate captions are needed 150 | if self.useHistory then self:processHistory(dtype); end 151 | -- prefix options with and , if not train 152 | -- if dtype ~= 'train' then self:processOptions(dtype); end 153 | self:processOptions(dtype) 154 | -- process answers 155 | self:processAnswers(dtype); 156 | end 157 | 158 | -- process answers 159 | function dataloader:processAnswers(dtype) 160 | --prefix answers with , ; adjust answer lengths 161 | local answers = self[dtype..'_ans']; 162 | local ansLen = self[dtype..'_ans_len']; 163 | 164 | local numConvs = answers:size(1); 165 | local numRounds = answers:size(2); 166 | local maxAnsLen = answers:size(3); 167 | 168 | local decodeIn = torch.LongTensor(numConvs, numRounds, maxAnsLen+1):zero(); 169 | local decodeOut = torch.LongTensor(numConvs, numRounds, maxAnsLen+1):zero(); 170 | 171 | -- decodeIn begins with 172 | decodeIn[{{}, {}, 1}] = self.word2ind['']; 173 | 174 | -- go over each answer and modify 175 | local endTokenId = self.word2ind['']; 176 | for thId = 1, numConvs do 177 | for roundId = 1, numRounds do 178 | local length = ansLen[thId][roundId]; 179 | 180 | -- only if nonzero 181 | if length > 0 then 182 | decodeIn[thId][roundId][{{2, length + 1}}] 183 | = answers[thId][roundId][{{1, length}}]; 184 | 185 | decodeOut[thId][roundId][{{1, length}}] 186 | = answers[thId][roundId][{{1, length}}]; 187 | else 188 | if dtype ~= 'test' then 189 | print(string.format('Warning: empty answer at (%d %d %d)', 190 | thId, roundId, length)) 191 | end 192 | end 193 | decodeOut[thId][roundId][length+1] = endTokenId; 194 | end 195 | end 196 | 197 | self[dtype..'_ans_len'] = self[dtype..'_ans_len'] + 1; 198 | self[dtype..'_ans_in'] = decodeIn; 199 | self[dtype..'_ans_out'] = decodeOut; 200 | end 201 | 202 | -- process caption as history 203 | function dataloader:processHistory(dtype) 204 | local captions = self[dtype..'_cap']; 205 | local questions = self[dtype..'_ques']; 206 | local quesLen = self[dtype..'_ques_len']; 207 | local capLen = self[dtype..'_cap_len']; 208 | local maxQuesLen = questions:size(3); 209 | 210 | local answers = self[dtype..'_ans']; 211 | local ansLen = self[dtype..'_ans_len']; 212 | local numConvs = answers:size(1); 213 | local numRounds = answers:size(2); 214 | local maxAnsLen = answers:size(3); 215 | 216 | local history, histLen; 217 | if self.concatHistory == true then 218 | self.maxHistoryLen = math.min(numRounds * (maxQuesLen + maxAnsLen), 300); 219 | 220 | history = torch.LongTensor(numConvs, numRounds, 221 | self.maxHistoryLen):zero(); 222 | histLen = torch.LongTensor(numConvs, numRounds):zero(); 223 | else 224 | history = torch.LongTensor(numConvs, numRounds, 225 | maxQuesLen+maxAnsLen):zero(); 226 | histLen = torch.LongTensor(numConvs, numRounds):zero(); 227 | end 228 | 229 | -- go over each question and append it 
with answer 230 | for thId = 1, numConvs do 231 | local lenC = capLen[thId]; 232 | local lenH; -- length of history 233 | for roundId = 1, numRounds do 234 | if roundId == 1 then 235 | -- first round has caption as history 236 | history[thId][roundId][{{1, maxQuesLen + maxAnsLen}}] 237 | = captions[thId][{{1, maxQuesLen + maxAnsLen}}]; 238 | lenH = math.min(lenC, maxQuesLen + maxAnsLen); 239 | else 240 | local lenQ = quesLen[thId][roundId-1]; 241 | local lenA = ansLen[thId][roundId-1]; 242 | -- if concatHistory, string together all previous QAs 243 | if self.concatHistory == true then 244 | history[thId][roundId][{{1, lenH}}] 245 | = history[thId][roundId-1][{{1, lenH}}]; 246 | history[thId][roundId][{{lenH+1}}] = self.word2ind['']; 247 | if lenQ > 0 then 248 | history[thId][roundId][{{lenH+2, lenH+1+lenQ}}] 249 | = questions[thId][roundId-1][{{1, lenQ}}]; 250 | end 251 | if lenA > 0 then 252 | history[thId][roundId][{{lenH+1+lenQ+1, lenH+1+lenQ+lenA}}] 253 | = answers[thId][roundId-1][{{1, lenA}}]; 254 | end 255 | lenH = lenH + lenQ + lenA + 1 256 | -- else, history is just previous round QA 257 | else 258 | if lenQ > 0 then 259 | history[thId][roundId][{{1, lenQ}}] 260 | = questions[thId][roundId-1][{{1, lenQ}}]; 261 | end 262 | if lenA > 0 then 263 | history[thId][roundId][{{lenQ + 1, lenQ + lenA}}] 264 | = answers[thId][roundId-1][{{1, lenA}}]; 265 | end 266 | lenH = lenA + lenQ; 267 | end 268 | end 269 | -- save the history length 270 | histLen[thId][roundId] = lenH; 271 | end 272 | end 273 | 274 | -- right align history and then save 275 | print('Right aligning history: '..dtype); 276 | self[dtype..'_hist'] = utils.rightAlign(history, histLen); 277 | self[dtype..'_hist_len'] = histLen; 278 | end 279 | 280 | -- process options 281 | function dataloader:processOptions(dtype) 282 | local lengths = self[dtype..'_opt_len']; 283 | local answers = self[dtype..'_ans']; 284 | local maxAnsLen = answers:size(3); 285 | local answers = self[dtype..'_opt_list']; 286 | local numConvs = answers:size(1); 287 | 288 | local ansListLen = answers:size(1); 289 | local decodeIn = torch.LongTensor(ansListLen, maxAnsLen + 1):zero(); 290 | local decodeOut = torch.LongTensor(ansListLen, maxAnsLen + 1):zero(); 291 | 292 | -- decodeIn begins with 293 | decodeIn[{{}, 1}] = self.word2ind['']; 294 | 295 | -- go over each answer and modify 296 | local endTokenId = self.word2ind['']; 297 | for id = 1, ansListLen do 298 | -- print progress for number of images 299 | if id % 100 == 0 then 300 | xlua.progress(id, numConvs); 301 | end 302 | local length = lengths[id]; 303 | 304 | -- only if nonzero 305 | if length > 0 then 306 | decodeIn[id][{{2, length + 1}}] = answers[id][{{1, length}}]; 307 | 308 | decodeOut[id][{{1, length}}] = answers[id][{{1, length}}]; 309 | decodeOut[id][length + 1] = endTokenId; 310 | else 311 | print(string.format('Warning: empty answer for %s at %d', 312 | dtype, id)) 313 | end 314 | end 315 | 316 | self[dtype..'_opt_len'] = self[dtype..'_opt_len'] + 1; 317 | self[dtype..'_opt_in'] = decodeIn; 318 | self[dtype..'_opt_out'] = decodeOut; 319 | 320 | collectgarbage(); 321 | end 322 | 323 | -- method to grab the next training batch 324 | function dataloader.getTrainBatch(self, params, batchSize) 325 | local size = batchSize or params.batchSize; 326 | local inds = torch.LongTensor(size):random(1, params.numTrainThreads); 327 | 328 | -- Index question, answers, image features for batch 329 | local batchOutput = self:getIndexData(inds, params, 'train') 330 | if params.decoder == 'disc' then 331 | 
local optionOutput = self:getIndexOption(inds, params, 'train') 332 | batchOutput['options'] = optionOutput:view(optionOutput:size(1) 333 | * optionOutput:size(2), optionOutput:size(3), -1) 334 | batchOutput['answer_ind'] = batchOutput['answer_ind']:view(batchOutput['answer_ind'] 335 | :size(1) * batchOutput['answer_ind']:size(2)) 336 | end 337 | 338 | return batchOutput 339 | end 340 | 341 | -- method to grab the next test/val batch, for evaluation of a given size 342 | function dataloader.getTestBatch(self, startId, params, dtype) 343 | local batchSize = params.batchSize 344 | -- get the next start id and fill up current indices till then 345 | local nextStartId; 346 | if dtype == 'val' then 347 | nextStartId = math.min(self.numValThreads+1, startId + batchSize); 348 | end 349 | if dtype == 'test' then 350 | nextStartId = math.min(self.numTestThreads+1, startId + batchSize); 351 | end 352 | 353 | -- dumb way to get range (complains if cudatensor is default) 354 | local inds = torch.LongTensor(nextStartId - startId); 355 | for ii = startId, nextStartId - 1 do inds[ii - startId + 1] = ii; end 356 | 357 | -- Index question, answers, image features for batch 358 | local batchOutput = self:getIndexData(inds, params, dtype); 359 | local optionOutput = self:getIndexOption(inds, params, dtype); 360 | 361 | if params.decoder == 'disc' then 362 | batchOutput['options'] = optionOutput:view(optionOutput:size(1) 363 | * optionOutput:size(2), optionOutput:size(3), -1) 364 | if dtype ~= 'test' then 365 | batchOutput['answer_ind'] = batchOutput['answer_ind']:view(batchOutput['answer_ind'] 366 | :size(1) * batchOutput['answer_ind']:size(2)) 367 | end 368 | elseif params.decoder == 'gen' then 369 | -- merge both the tables and return 370 | for key, value in pairs(optionOutput) do batchOutput[key] = value; end 371 | end 372 | batchOutput['num_rounds'] = self[dtype..'_num_rounds']:index(1, inds):long() 373 | 374 | return batchOutput, nextStartId; 375 | end 376 | 377 | -- get batch from data subset given the indices 378 | function dataloader.getIndexData(self, inds, params, dtype) 379 | -- get the question lengths 380 | local batchQuesLen = self[dtype..'_ques_len']:index(1, inds); 381 | local maxQuesLen = torch.max(batchQuesLen); 382 | -- get questions 383 | local quesFwd = self[dtype..'_ques_fwd']:index(1, inds) 384 | [{{}, {}, {-maxQuesLen, -1}}]; 385 | 386 | local history; 387 | if self.useHistory then 388 | local batchHistLen = self[dtype..'_hist_len']:index(1, inds); 389 | local maxHistLen = math.min(torch.max(batchHistLen), self.maxHistoryLen); 390 | history = self[dtype..'_hist']:index(1, inds) 391 | [{{}, {}, {-maxHistLen, -1}}]; 392 | end 393 | 394 | local imgFeats; 395 | if self.useIm then 396 | local imgInds = self[dtype..'_img_pos']:index(1, inds); 397 | imgFeats = self[dtype..'_img_fv']:index(1, imgInds); 398 | end 399 | 400 | -- get the answer lengths 401 | local batchAnsLen = self[dtype..'_ans_len']:index(1, inds); 402 | local maxAnsLen = torch.max(batchAnsLen); 403 | -- answer labels (decode input and output) 404 | local answerIn = self[dtype..'_ans_in'] 405 | :index(1, inds)[{{}, {}, {1, maxAnsLen}}]; 406 | local answerOut = self[dtype..'_ans_out'] 407 | :index(1, inds)[{{}, {}, {1, maxAnsLen}}]; 408 | 409 | local output = {}; 410 | if params.gpuid >= 0 then 411 | output['ques_fwd'] = quesFwd:cuda(); 412 | output['answer_in'] = answerIn:cuda(); 413 | output['answer_out'] = answerOut:cuda(); 414 | if history then output['hist'] = history:cuda(); end 415 | if caption then output['cap'] = 
caption:cuda(); end 416 | if imgFeats then output['img_feat'] = imgFeats:cuda(); end 417 | else 418 | output['ques_fwd'] = quesFwd:contiguous(); 419 | output['answer_in'] = answerIn:contiguous(); 420 | output['answer_out'] = answerOut:contiguous(); 421 | if history then output['hist'] = history:contiguous(); end 422 | if caption then output['cap'] = caption:contiguous(); end 423 | if imgFeats then output['img_feat'] = imgFeats:contiguous(); end 424 | end 425 | 426 | if dtype ~= 'test' then 427 | local answerInd = self[dtype..'_ans_ind']:index(1, inds); 428 | output['answer_ind'] = params.gpuid >= 0 and answerInd:cuda() or answerInd:contiguous(); 429 | end 430 | 431 | return output; 432 | end 433 | 434 | -- get batch from options given the indices 435 | function dataloader.getIndexOption(self, inds, params, dtype) 436 | local output = {}; 437 | if params.decoder == 'gen' then 438 | local optionIn, optionOut 439 | 440 | local optInds = self[dtype..'_opt']:index(1, inds); 441 | local indVector = optInds:view(-1); 442 | 443 | local batchOptLen = self[dtype..'_opt_len']:index(1, indVector); 444 | local maxOptLen = torch.max(batchOptLen); 445 | 446 | optionIn = self[dtype..'_opt_in']:index(1, indVector); 447 | optionIn = optionIn:view(optInds:size(1), optInds:size(2), 448 | optInds:size(3), -1); 449 | optionIn = optionIn[{{}, {}, {}, {1, maxOptLen}}]; 450 | 451 | optionOut = self[dtype..'_opt_out']:index(1, indVector); 452 | optionOut = optionOut:view(optInds:size(1), optInds:size(2), 453 | optInds:size(3), -1); 454 | optionOut = optionOut[{{}, {}, {}, {1, maxOptLen}}]; 455 | 456 | if params.gpuid >= 0 then 457 | output['option_in'] = optionIn:cuda(); 458 | output['option_out'] = optionOut:cuda(); 459 | else 460 | output['option_in'] = optionIn:contiguous(); 461 | output['option_out'] = optionOut:contiguous(); 462 | end 463 | elseif params.decoder == 'disc' then 464 | local optInds = self[dtype .. '_opt']:index(1, inds) 465 | local indVector = optInds:view(-1) 466 | 467 | local optionIn = self[dtype .. 
'_opt_list']:index(1, indVector) 468 | 469 | optionIn = optionIn:view(optInds:size(1), optInds:size(2), optInds:size(3), -1) 470 | output = optionIn 471 | 472 | if params.gpuid >= 0 then 473 | output = output:cuda() 474 | end 475 | end 476 | 477 | return output; 478 | end 479 | 480 | return dataloader; 481 | -------------------------------------------------------------------------------- /decoders/disc.lua: -------------------------------------------------------------------------------- 1 | local decoderNet = {} 2 | 3 | function decoderNet.model(params, enc) 4 | local optionLSTM = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 5 | optionLSTM.batchfirst = true 6 | 7 | local optionEnc = {} 8 | local numOptions = 100 9 | for i = 1, numOptions do 10 | optionEnc[i] = nn.Sequential() 11 | optionEnc[i]:add(nn.Select(2,i)) 12 | optionEnc[i]:add(enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')); 13 | optionEnc[i]:add(optionLSTM:clone('weight', 'bias', 'gradWeight', 'gradBias')) 14 | optionEnc[i]:add(nn.Select(2,-1)) 15 | optionEnc[i]:add(nn.Reshape(1,params.rnnHiddenSize, true)) -- True ensures that the first dimension remains the batch size 16 | end 17 | optionEncConcat = nn.Concat(2) 18 | for i = 1, numOptions do 19 | optionEncConcat:add(optionEnc[i]) 20 | end 21 | 22 | local jointModel = nn.ParallelTable() 23 | jointModel:add(optionEncConcat) 24 | jointModel:add(nn.Reshape(params.rnnHiddenSize, 1, true)) 25 | 26 | local dec = nn.Sequential() 27 | dec:add(jointModel) 28 | dec:add(nn.MM()) 29 | dec:add(nn.Squeeze()) 30 | 31 | return dec; 32 | end 33 | 34 | -- dummy forwardConnect 35 | function decoderNet.forwardConnect(enc, dec, encOut, seqLen) end 36 | 37 | -- dummy backwardConnect 38 | function decoderNet.backwardConnect(enc, dec) end 39 | 40 | return decoderNet; 41 | -------------------------------------------------------------------------------- /decoders/gen.lua: -------------------------------------------------------------------------------- 1 | local decoderNet = {} 2 | 3 | function decoderNet.model(params, enc) 4 | -- use `nngraph` 5 | nn.FastLSTM.usenngraph = true 6 | 7 | -- decoder network 8 | local dec = nn.Sequential() 9 | -- use the same embedding for both encoder and decoder lstm 10 | local embedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 11 | dec:add(embedNet); 12 | 13 | dec.rnnLayers = {}; 14 | -- check if decoder has different hidden size 15 | local hiddenSize = (params.ansHiddenSize ~= 0) and params.ansHiddenSize 16 | or params.rnnHiddenSize; 17 | for layer = 1, params.numLayers do 18 | local inputSize = (layer == 1) and params.embedSize or hiddenSize; 19 | dec.rnnLayers[layer] = nn.SeqLSTM(inputSize, hiddenSize); 20 | dec.rnnLayers[layer]:maskZero(); 21 | dec:add(dec.rnnLayers[layer]); 22 | end 23 | dec:add(nn.Sequencer(nn.MaskZero(nn.Linear(hiddenSize, params.vocabSize), 1))) 24 | dec:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1))) 25 | 26 | return dec; 27 | end 28 | 29 | -- transfer the hidden state from encoder to decoder 30 | function decoderNet.forwardConnect(enc, dec, encOut, seqLen) 31 | if enc.rnnLayers ~= nil then 32 | for ii = 1, #enc.rnnLayers do 33 | dec.rnnLayers[ii].userPrevOutput = enc.rnnLayers[ii].output[seqLen]; 34 | dec.rnnLayers[ii].userPrevCell = enc.rnnLayers[ii].cell[seqLen]; 35 | end 36 | 37 | -- last layer gets output gradients 38 | dec.rnnLayers[#enc.rnnLayers].userPrevOutput = encOut; 39 | else 40 | dec.rnnLayers[#dec.rnnLayers].userPrevOutput = encOut 41 | end 42 | end 43 | 44 | -- transfer gradients 
from decoder to encoder 45 | function decoderNet.backwardConnect(enc, dec) 46 | if enc.rnnLayers ~= nil then 47 | -- borrow gradients from decoder 48 | for ii = 1, #dec.rnnLayers do 49 | enc.rnnLayers[ii].userNextGradCell = dec.rnnLayers[ii].userGradPrevCell; 50 | if ii ~= #dec.rnnLayers then 51 | enc.rnnLayers[ii].gradPrevOutput = dec.rnnLayers[ii].userGradPrevOutput; 52 | end 53 | end 54 | 55 | -- return the gradients for the last layer 56 | return dec.rnnLayers[#enc.rnnLayers].userGradPrevOutput; 57 | else 58 | return dec.rnnLayers[#dec.rnnLayers].userGradPrevOutput 59 | end 60 | end 61 | 62 | -- connecting decoder to itself; useful while sampling 63 | function decoderNet.decoderConnect(dec) 64 | for ii = 1, #dec.rnnLayers do 65 | dec.rnnLayers[ii].userPrevCell = dec.rnnLayers[ii].cell[1] 66 | dec.rnnLayers[ii].userPrevOutput = dec.rnnLayers[ii].output[1] 67 | end 68 | end 69 | 70 | return decoderNet; 71 | -------------------------------------------------------------------------------- /encoders/hre-ques-hist.lua: -------------------------------------------------------------------------------- 1 | local encoderNet = {} 2 | 3 | function encoderNet.model(params) 4 | local dropout = params.dropout or 0.5 5 | 6 | -- Use `nngraph` 7 | nn.FastLSTM.usenngraph = true; 8 | 9 | -- encoder network 10 | local enc = nn.Sequential(); 11 | 12 | -- create the two branches 13 | local concat = nn.ConcatTable(); 14 | 15 | -- word branch, along with embedding layer 16 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 17 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 18 | 19 | -- language model 20 | enc.rnnLayers = {}; 21 | for layer = 1, params.numLayers do 22 | local inputSize = (layer==1) and (params.embedSize) 23 | or params.rnnHiddenSize; 24 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 25 | enc.rnnLayers[layer]:maskZero(); 26 | 27 | wordBranch:add(enc.rnnLayers[layer]); 28 | end 29 | wordBranch:add(nn.Select(1, -1)); 30 | 31 | -- make clones for embed layer 32 | local qEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 33 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 34 | 35 | -- create two branches 36 | local histBranch = nn.Sequential() 37 | :add(nn.SelectTable(2)) 38 | :add(hEmbedNet); 39 | enc.histLayers = {}; 40 | 41 | -- number of layers to read the history 42 | for layer = 1, params.numLayers do 43 | local inputSize = (layer == 1) and params.embedSize 44 | or params.rnnHiddenSize; 45 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 46 | enc.histLayers[layer]:maskZero(); 47 | 48 | histBranch:add(enc.histLayers[layer]); 49 | end 50 | histBranch:add(nn.Select(1, -1)); 51 | 52 | -- add concatTable and join 53 | concat:add(wordBranch) 54 | concat:add(histBranch) 55 | enc:add(concat); 56 | 57 | -- another concat table 58 | local concat2 = nn.ConcatTable(); 59 | enc:add(nn.JoinTable(1, 1)) 60 | 61 | --change the view of the data 62 | -- always split it back wrt batch size and then do transpose 63 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize)); 64 | enc:add(nn.Transpose({1, 2})); 65 | enc:add(nn.View(params.maxQuesCount, -1, 2*params.rnnHiddenSize)) 66 | enc:add(nn.SeqLSTM(2*params.rnnHiddenSize, params.rnnHiddenSize)) 67 | enc:add(nn.Transpose({1, 2})); 68 | enc:add(nn.View(-1, params.rnnHiddenSize)) 69 | 70 | return enc; 71 | end 72 | 73 | return encoderNet 74 | 
-------------------------------------------------------------------------------- /encoders/hre-ques-im-hist.lua: -------------------------------------------------------------------------------- 1 | require 'model_utils.MaskTime' 2 | 3 | local encoderNet = {} 4 | 5 | function encoderNet.model(params) 6 | local dropout = params.dropout or 0.5 7 | 8 | -- Use `nngraph` 9 | nn.FastLSTM.usenngraph = true; 10 | 11 | -- encoder network 12 | local enc = nn.Sequential(); 13 | 14 | -- create the two branches 15 | local concat = nn.ConcatTable(); 16 | 17 | -- word branch, along with embedding layer 18 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 19 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 20 | 21 | -- make clones for embed layer 22 | local qEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 23 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 24 | 25 | -- create two branches 26 | local histBranch = nn.Sequential() 27 | :add(nn.SelectTable(3)) 28 | :add(hEmbedNet); 29 | enc.histLayers = {}; 30 | -- number of layers to read the history 31 | for layer = 1, params.numLayers do 32 | local inputSize = (layer == 1) and params.embedSize 33 | or params.rnnHiddenSize; 34 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 35 | enc.histLayers[layer]:maskZero(); 36 | 37 | histBranch:add(enc.histLayers[layer]); 38 | end 39 | histBranch:add(nn.Select(1, -1)); 40 | 41 | -- image branch 42 | -- embedding for images 43 | local imgPar = nn.ParallelTable() 44 | :add(nn.Identity()) 45 | :add(nn.Sequential() 46 | -- :add(nn.Dropout(0.5)) 47 | :add(nn.Linear(params.imgFeatureSize, 48 | params.imgEmbedSize))); 49 | -- select words and image only 50 | local imageBranch = nn.Sequential() 51 | :add(nn.NarrowTable(1, 2)) 52 | :add(imgPar) 53 | :add(nn.MaskTime(params.imgEmbedSize)); 54 | 55 | -- add concatTable and join 56 | concat:add(wordBranch) 57 | concat:add(imageBranch) 58 | concat:add(histBranch) 59 | enc:add(concat); 60 | 61 | -- another concat table 62 | local concat2 = nn.ConcatTable(); 63 | 64 | -- select words + image, and history 65 | local wordImageBranch = nn.Sequential() 66 | :add(nn.NarrowTable(1, 2)) 67 | :add(nn.JoinTable(-1)) 68 | 69 | -- language model 70 | enc.rnnLayers = {}; 71 | for layer = 1, params.numLayers do 72 | local inputSize = (layer==1) and (params.imgEmbedSize+params.embedSize) 73 | or params.rnnHiddenSize; 74 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 75 | enc.rnnLayers[layer]:maskZero(); 76 | 77 | wordImageBranch:add(enc.rnnLayers[layer]); 78 | end 79 | wordImageBranch:add(nn.Select(1, -1)); 80 | 81 | -- add both the branches (wordImage, select history) to concat2 82 | concat2:add(wordImageBranch):add(nn.SelectTable(3)); 83 | enc:add(concat2); 84 | 85 | -- join both the tensors 86 | enc:add(nn.JoinTable(-1)); 87 | 88 | -- change the view of the data 89 | -- always split it back wrt batch size and then do transpose 90 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize)); 91 | enc:add(nn.Transpose({1, 2})); 92 | enc:add(nn.SeqLSTM(2*params.rnnHiddenSize, params.rnnHiddenSize)) 93 | enc:add(nn.Transpose({1, 2})); 94 | enc:add(nn.View(-1, params.rnnHiddenSize)) 95 | 96 | return enc; 97 | end 98 | 99 | return encoderNet 100 | -------------------------------------------------------------------------------- /encoders/hrea-ques-im-hist.lua: 
-------------------------------------------------------------------------------- 1 | require 'model_utils.MaskTime' 2 | require 'model_utils.MaskFuture' 3 | require 'model_utils.ReplaceZero' 4 | 5 | local encoderNet = {} 6 | 7 | function encoderNet.model(params) 8 | local dropout = params.dropout or 0.5 9 | 10 | -- Use `nngraph` 11 | nn.FastLSTM.usenngraph = true; 12 | 13 | -- encoder network 14 | local enc = nn.Sequential(); 15 | 16 | -- create the two branches 17 | local concat = nn.ConcatTable(); 18 | 19 | -- word branch, along with embedding layer 20 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 21 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 22 | 23 | -- make clones for embed layer 24 | local qEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 25 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 26 | 27 | -- create two branches 28 | local histBranch = nn.Sequential() 29 | :add(nn.SelectTable(3)) 30 | :add(hEmbedNet); 31 | enc.histLayers = {}; 32 | -- number of layers to read the history 33 | for layer = 1, params.numLayers do 34 | local inputSize = (layer == 1) and params.embedSize 35 | or params.rnnHiddenSize; 36 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 37 | enc.histLayers[layer]:maskZero(); 38 | 39 | histBranch:add(enc.histLayers[layer]); 40 | end 41 | histBranch:add(nn.Select(1, -1)); 42 | 43 | -- image branch 44 | -- embedding for images 45 | local imgPar = nn.ParallelTable() 46 | :add(nn.Identity()) 47 | :add(nn.Sequential() 48 | :add(nn.Dropout(0.5)) 49 | :add(nn.Linear(params.imgFeatureSize, 50 | params.imgEmbedSize))); 51 | -- select words and image only 52 | local imageBranch = nn.Sequential() 53 | :add(nn.NarrowTable(1, 2)) 54 | :add(imgPar) 55 | :add(nn.MaskTime(params.imgEmbedSize)); 56 | 57 | -- add concatTable and join 58 | concat:add(wordBranch) 59 | concat:add(imageBranch) 60 | concat:add(histBranch) 61 | enc:add(concat); 62 | 63 | -- another concat table 64 | local concat2 = nn.ConcatTable(); 65 | 66 | -- select words + image, and history 67 | local wordImageBranch = nn.Sequential() 68 | :add(nn.NarrowTable(1, 2)) 69 | :add(nn.JoinTable(-1)) 70 | 71 | -- language model 72 | enc.rnnLayers = {}; 73 | for layer = 1, params.numLayers do 74 | local inputSize = (layer==1) and (params.imgEmbedSize+params.embedSize) 75 | or params.rnnHiddenSize; 76 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 77 | enc.rnnLayers[layer]:maskZero(); 78 | 79 | wordImageBranch:add(enc.rnnLayers[layer]); 80 | end 81 | wordImageBranch:add(nn.Select(1, -1)); 82 | 83 | -- add both the branches (wordImage, select history) to concat2 84 | concat2:add(wordImageBranch):add(nn.SelectTable(3)); 85 | enc:add(concat2); 86 | 87 | -- single layer neural network 88 | -- create attention based history 89 | local prepare = nn.Sequential() 90 | :add(nn.Linear(params.rnnHiddenSize, 1)) 91 | :add(nn.View(-1, params.maxQuesCount)) 92 | local wordHistBranch = nn.Sequential() 93 | :add(nn.ParallelTable() 94 | :add(prepare:clone()) 95 | :add(prepare:clone())) 96 | :add(nn.ParallelTable() 97 | :add(nn.Replicate(10, 3)) 98 | :add(nn.Replicate(10, 2))) 99 | :add(nn.CAddTable()) 100 | --:add(nn.Tanh()) 101 | :add(nn.MaskFuture(params.maxQuesCount)) 102 | :add(nn.View(-1, params.maxQuesCount)) 103 | :add(nn.ReplaceZero(-1*math.huge)) 104 | :add(nn.SoftMax()) 105 | :add(nn.View(-1, params.maxQuesCount, params.maxQuesCount)) 106 | 
:add(nn.Replicate(params.rnnHiddenSize, 4)); 107 | 108 | local histOnlyBranch = nn.Sequential() 109 | :add(nn.SelectTable(2)) 110 | :add(nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)) 111 | :add(nn.Replicate(params.maxQuesCount, 2)) 112 | 113 | -- add another concatTable to create attention over history 114 | local concat3 = nn.ConcatTable() 115 | :add(wordHistBranch) 116 | :add(histOnlyBranch) 117 | :add(nn.SelectTable(1)) -- append attended history with question 118 | enc:add(concat3); 119 | 120 | -- parallel table to multiply first two tables, and leave the third one untouched 121 | local multiplier = nn.Sequential() 122 | :add(nn.NarrowTable(1, 2)) 123 | :add(nn.CMulTable()) 124 | :add(nn.Sum(3)) 125 | :add(nn.View(-1, params.rnnHiddenSize)); 126 | local concat4 = nn.ConcatTable() 127 | :add(multiplier) 128 | :add(nn.SelectTable(3)); 129 | enc:add(concat4); 130 | 131 | -- join both the tensors (att over history and encoded question) 132 | enc:add(nn.JoinTable(-1)); 133 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize)) 134 | enc:add(nn.Transpose({1, 2})) 135 | enc:add(nn.SeqLSTM(2 * params.rnnHiddenSize, params.rnnHiddenSize)) 136 | enc:add(nn.Transpose({1, 2})) 137 | enc:add(nn.View(-1, params.rnnHiddenSize)) 138 | 139 | return enc; 140 | end 141 | 142 | return encoderNet 143 | -------------------------------------------------------------------------------- /encoders/lf-att-ques-im-hist.lua: -------------------------------------------------------------------------------- 1 | local encoderNet = {} 2 | 3 | function encoderNet.model(params) 4 | 5 | local inputs = {} 6 | local outputs = {} 7 | 8 | table.insert(inputs, nn.Identity()()) -- question 9 | table.insert(inputs, nn.Identity()()) -- img feats 10 | table.insert(inputs, nn.Identity()()) -- history 11 | 12 | local ques = inputs[1] 13 | local img_feats = inputs[2] 14 | local hist = inputs[3] 15 | 16 | -- word embed layer 17 | wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 18 | 19 | -- make clones for embed layer 20 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques)); 21 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist)); 22 | 23 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 24 | lst1:maskZero() 25 | 26 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 27 | lst2:maskZero() 28 | 29 | local h1 = lst1(hEmbed) 30 | local h2 = lst2(h1) 31 | local h3 = nn.Select(1, -1)(h2) 32 | 33 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 34 | lst3:maskZero() 35 | 36 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 37 | lst4:maskZero() 38 | 39 | local q1 = lst3(qEmbed) 40 | local q2 = lst4(q1) 41 | local q3 = nn.Select(1, -1)(q2) 42 | 43 | local qh = nn.Tanh()(nn.Linear(2*params.rnnHiddenSize, params.rnnHiddenSize)(nn.JoinTable(1,1)({q3,h3}))) 44 | 45 | -- image attention (inspired by SAN, Yang et al., CVPR16) 46 | local img_tr_size = params.rnnHiddenSize 47 | local rnn_size = params.rnnHiddenSize 48 | local common_embedding_size = params.commonEmbeddingSize 49 | local num_attention_layer = 1 50 | 51 | local u = qh 52 | local img_tr = nn.Dropout(0.5)( 53 | nn.Tanh()( 54 | nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, img_tr_size)( 55 | nn.Linear(params.imgFeatureSize, img_tr_size)( 56 | nn.View(params.imgFeatureSize):setNumInputDims(2)(img_feats))))) 57 | 58 | for i = 1, num_attention_layer do 59 | 60 | -- linear 
layer: 14x14x1024 -> 14x14x512 61 | local img_common = nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, common_embedding_size)( 62 | nn.Linear(img_tr_size, common_embedding_size)( 63 | nn.View(-1, img_tr_size)(img_tr))) 64 | 65 | -- replicate lstm state 196 times 66 | local ques_common = nn.Linear(rnn_size, common_embedding_size)(u) 67 | local ques_repl = nn.Replicate(params.imgSpatialSize * params.imgSpatialSize, 2)(ques_common) 68 | 69 | -- add image and question features (both 196x512) 70 | local img_ques_common = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({img_common, ques_repl}))) 71 | local h = nn.Linear(common_embedding_size, 1)(nn.View(-1, common_embedding_size)(img_ques_common)) 72 | local p = nn.SoftMax()(nn.View(-1, params.imgSpatialSize * params.imgSpatialSize)(h)) 73 | 74 | -- weighted sum of image features 75 | local p_att = nn.View(1, -1):setNumInputDims(1)(p) 76 | local img_tr_att = nn.MM(false, false)({p_att, img_tr}) 77 | local img_tr_att_feat = nn.View(-1, img_tr_size)(img_tr_att) 78 | 79 | -- add image feature vector and question vector 80 | u = nn.CAddTable()({img_tr_att_feat, u}) 81 | 82 | end 83 | 84 | local o = nn.Tanh()(nn.Linear(rnn_size, rnn_size)(nn.Dropout(0.5)(u))) 85 | -- SAN stuff ends 86 | 87 | table.insert(outputs, o) 88 | 89 | local enc = nn.gModule(inputs, outputs) 90 | enc.wordEmbed = wordEmbed 91 | 92 | return enc; 93 | end 94 | 95 | return encoderNet 96 | -------------------------------------------------------------------------------- /encoders/lf-ques-hist.lua: -------------------------------------------------------------------------------- 1 | local encoderNet = {}; 2 | 3 | function encoderNet.model(params) 4 | local dropout = params.dropout or 0.5 5 | -- Use nngraph 6 | nn.FastLSTM.usenngraph = true; 7 | 8 | -- encoder network 9 | local enc = nn.Sequential(); 10 | 11 | -- create the two branches 12 | local concat = nn.ConcatTable(); 13 | 14 | -- word branch, along with embedding layer 15 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 16 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 17 | 18 | -- language model 19 | enc.rnnLayers = {}; 20 | for layer = 1, params.numLayers do 21 | local inputSize = (layer==1) and (params.embedSize) 22 | or params.rnnHiddenSize; 23 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 24 | enc.rnnLayers[layer]:maskZero(); 25 | 26 | wordBranch:add(enc.rnnLayers[layer]); 27 | end 28 | wordBranch:add(nn.Select(1, -1)); 29 | 30 | -- make clones for embed layer 31 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 32 | 33 | -- create two branches 34 | local histBranch = nn.Sequential() 35 | :add(nn.SelectTable(2)) 36 | :add(hEmbedNet); 37 | enc.histLayers = {}; 38 | -- number of layers to read the history 39 | for layer = 1, params.numLayers do 40 | local inputSize = (layer == 1) and params.embedSize 41 | or params.rnnHiddenSize; 42 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 43 | enc.histLayers[layer]:maskZero(); 44 | 45 | histBranch:add(enc.histLayers[layer]); 46 | end 47 | histBranch:add(nn.Select(1, -1)); 48 | 49 | -- add concatTable and join 50 | concat:add(wordBranch) 51 | concat:add(histBranch) 52 | enc:add(concat); 53 | 54 | enc:add(nn.JoinTable(2)) 55 | if dropout > 0 then 56 | enc:add(nn.Dropout(dropout)) 57 | end 58 | enc:add(nn.Linear(2 * params.rnnHiddenSize, params.rnnHiddenSize)) 59 | enc:add(nn.Tanh()) 60 | 61 | return enc; 62 | end 63 | 64 | return encoderNet 
65 | -------------------------------------------------------------------------------- /encoders/lf-ques-im-hist.lua: -------------------------------------------------------------------------------- 1 | local encoderNet = {} 2 | 3 | function encoderNet.model(params) 4 | local dropout = params.dropout or 0.5 5 | -- Use `nngraph` 6 | nn.FastLSTM.usenngraph = true; 7 | 8 | -- encoder network 9 | local enc = nn.Sequential(); 10 | 11 | -- create the two branches 12 | local concat = nn.ConcatTable(); 13 | 14 | -- word branch, along with embedding layer 15 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 16 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 17 | 18 | -- language model 19 | enc.rnnLayers = {}; 20 | for layer = 1, params.numLayers do 21 | local inputSize = (layer==1) and (params.embedSize) 22 | or params.rnnHiddenSize; 23 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 24 | enc.rnnLayers[layer]:maskZero(); 25 | 26 | wordBranch:add(enc.rnnLayers[layer]); 27 | end 28 | wordBranch:add(nn.Select(1, -1)); 29 | 30 | -- make clones for embed layer 31 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 32 | 33 | -- create two branches 34 | local histBranch = nn.Sequential() 35 | :add(nn.SelectTable(3)) 36 | :add(hEmbedNet); 37 | enc.histLayers = {}; 38 | -- number of layers to read the history 39 | for layer = 1, params.numLayers do 40 | local inputSize = (layer == 1) and params.embedSize 41 | or params.rnnHiddenSize; 42 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 43 | enc.histLayers[layer]:maskZero(); 44 | 45 | histBranch:add(enc.histLayers[layer]); 46 | end 47 | histBranch:add(nn.Select(1, -1)); 48 | 49 | concat:add(wordBranch) 50 | concat:add(nn.SelectTable(2)) 51 | concat:add(histBranch) 52 | enc:add(concat); 53 | 54 | enc:add(nn.JoinTable(2)) 55 | if dropout > 0 then 56 | enc:add(nn.Dropout(dropout)) 57 | end 58 | enc:add(nn.Linear(2 * params.rnnHiddenSize + params.imgFeatureSize, params.rnnHiddenSize)) 59 | enc:add(nn.Tanh()) 60 | 61 | return enc; 62 | end 63 | 64 | return encoderNet 65 | -------------------------------------------------------------------------------- /encoders/lf-ques-im.lua: -------------------------------------------------------------------------------- 1 | local encoderNet = {} 2 | 3 | function encoderNet.model(params) 4 | local dropout = params.dropout or 0.5 5 | -- Use `nngraph` 6 | nn.FastLSTM.usenngraph = true; 7 | 8 | -- encoder network 9 | local enc = nn.Sequential(); 10 | 11 | -- create the two branches 12 | local concat = nn.ConcatTable(); 13 | 14 | -- word branch, along with embedding layer 15 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 16 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 17 | 18 | -- language model 19 | enc.rnnLayers = {}; 20 | for layer = 1, params.numLayers do 21 | local inputSize = (layer==1) and (params.embedSize) 22 | or params.rnnHiddenSize; 23 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 24 | enc.rnnLayers[layer]:maskZero(); 25 | 26 | wordBranch:add(enc.rnnLayers[layer]); 27 | end 28 | wordBranch:add(nn.Select(1, -1)); 29 | 30 | -- add concatTable and join 31 | concat:add(wordBranch) 32 | concat:add(nn.SelectTable(2)) 33 | enc:add(concat); 34 | 35 | enc:add(nn.JoinTable(2)) 36 | if dropout > 0 then 37 | enc:add(nn.Dropout(dropout)) 38 | end 39 | enc:add(nn.Linear(params.rnnHiddenSize + 
params.imgFeatureSize, params.rnnHiddenSize)) 40 | enc:add(nn.Tanh()) 41 | 42 | return enc; 43 | end 44 | 45 | return encoderNet 46 | -------------------------------------------------------------------------------- /encoders/lf-ques.lua: -------------------------------------------------------------------------------- 1 | local encoderNet = {}; 2 | 3 | function encoderNet.model(params) 4 | local dropout = params.dropout or 0.5; 5 | -- Use `nngraph` 6 | nn.FastLSTM.usenngraph = true; 7 | 8 | -- encoder network 9 | local enc = nn.Sequential(); 10 | 11 | -- word branch, along with embedding layer 12 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 13 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed); 14 | 15 | -- language model 16 | enc.rnnLayers = {}; 17 | for layer = 1, params.numLayers do 18 | local inputSize = (layer==1) and (params.embedSize) 19 | or params.rnnHiddenSize; 20 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 21 | enc.rnnLayers[layer]:maskZero(); 22 | 23 | wordBranch:add(enc.rnnLayers[layer]); 24 | end 25 | wordBranch:add(nn.Select(1, -1)); 26 | 27 | enc:add(wordBranch); 28 | 29 | if dropout > 0 then 30 | enc:add(nn.Dropout(dropout)) 31 | end 32 | enc:add(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)) 33 | enc:add(nn.Tanh()) 34 | 35 | return enc; 36 | end 37 | 38 | return encoderNet; 39 | -------------------------------------------------------------------------------- /encoders/mn-att-ques-im-hist.lua: -------------------------------------------------------------------------------- 1 | require 'model_utils.MaskSoftMax' 2 | 3 | local encoderNet = {} 4 | 5 | function encoderNet.model(params) 6 | 7 | local inputs = {} 8 | local outputs = {} 9 | 10 | table.insert(inputs, nn.Identity()()) -- question 11 | table.insert(inputs, nn.Identity()()) -- img feats 12 | table.insert(inputs, nn.Identity()()) -- history 13 | table.insert(inputs, nn.Identity()()) -- 10x10 mask 14 | 15 | local ques = inputs[1] 16 | local img_feats = inputs[2] 17 | local hist = inputs[3] 18 | local mask = inputs[4] 19 | 20 | -- word embed layer 21 | wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 22 | 23 | -- make clones for embed layer 24 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques)); 25 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist)); 26 | 27 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 28 | lst1:maskZero() 29 | 30 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 31 | lst2:maskZero() 32 | 33 | local h1 = lst1(hEmbed) 34 | local h2 = lst2(h1) 35 | local h3 = nn.Select(1, -1)(h2) 36 | 37 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 38 | lst3:maskZero() 39 | 40 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 41 | lst4:maskZero() 42 | 43 | local q1 = lst3(qEmbed) 44 | local q2 = lst4(q1) 45 | local q3 = nn.Select(1, -1)(q2) 46 | 47 | -- View as batch x rounds 48 | local qEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(q3) 49 | local hEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(h3) 50 | 51 | -- Inner product 52 | -- q is Bx10xE, h is Bx10xE 53 | -- qh is Bx10x10, rows correspond to questions, columns to facts 54 | local qh = nn.MM(false, true)({qEmbedView, hEmbedView}) 55 | local qhView = nn.View(-1, params.maxQuesCount)(qh) 56 | local qhprobs = nn.MaskSoftMax(){qhView, mask} 
57 | local qhView2 = nn.View(-1, params.maxQuesCount, params.maxQuesCount)(qhprobs) 58 | 59 | -- Weighted sum of h features 60 | -- h is Bx10xE, qhView2 is Bx10x10 61 | local hAtt = nn.MM(){qhView2, hEmbedView} 62 | local hAttView = nn.View(-1, params.rnnHiddenSize)(hAtt) 63 | 64 | local hAttTr = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.Dropout(0.5)(hAttView))) 65 | local qh2 = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.CAddTable(){hAttTr, nn.View(-1, params.rnnHiddenSize)(qEmbedView)})) 66 | 67 | -- image attention (inspired by SAN, Yang et al., CVPR16) 68 | local img_tr_size = params.rnnHiddenSize 69 | local rnn_size = params.rnnHiddenSize 70 | local common_embedding_size = params.commonEmbeddingSize or 512 71 | local num_attention_layer = params.numAttentionLayers or 1 72 | 73 | local u = qh2 74 | local img_tr = nn.Dropout(0.5)( 75 | nn.Tanh()( 76 | nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, img_tr_size)( 77 | nn.Linear(params.imgFeatureSize, img_tr_size)( 78 | nn.View(params.imgFeatureSize):setNumInputDims(2)(img_feats))))) 79 | 80 | for i = 1, num_attention_layer do 81 | 82 | -- linear layer: 14x14x1024 -> 14x14x512 83 | local img_common = nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, common_embedding_size)( 84 | nn.Linear(img_tr_size, common_embedding_size)( 85 | nn.View(-1, img_tr_size)(img_tr))) 86 | 87 | -- replicate lstm state 196 times 88 | local ques_common = nn.Linear(rnn_size, common_embedding_size)(u) 89 | local ques_repl = nn.Replicate(params.imgSpatialSize * params.imgSpatialSize, 2)(ques_common) 90 | 91 | -- add image and question features (both 196x512) 92 | local img_ques_common = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({img_common, ques_repl}))) 93 | local h = nn.Linear(common_embedding_size, 1)(nn.View(-1, common_embedding_size)(img_ques_common)) 94 | local p = nn.SoftMax()(nn.View(-1, params.imgSpatialSize * params.imgSpatialSize)(h)) 95 | 96 | -- weighted sum of image features 97 | local p_att = nn.View(1, -1):setNumInputDims(1)(p) 98 | local img_tr_att = nn.MM(false, false)({p_att, img_tr}) 99 | local img_tr_att_feat = nn.View(-1, img_tr_size)(img_tr_att) 100 | 101 | -- add image feature vector and question vector 102 | u = nn.CAddTable()({img_tr_att_feat, u}) 103 | 104 | end 105 | 106 | local o = nn.Tanh()(nn.Linear(rnn_size, rnn_size)(nn.Dropout(0.5)(u))) 107 | -- SAN stuff ends 108 | 109 | table.insert(outputs, o) 110 | 111 | local enc = nn.gModule(inputs, outputs) 112 | enc.wordEmbed = wordEmbed 113 | 114 | return enc; 115 | end 116 | 117 | return encoderNet 118 | -------------------------------------------------------------------------------- /encoders/mn-ques-hist.lua: -------------------------------------------------------------------------------- 1 | require 'model_utils.MaskSoftMax' 2 | 3 | local encoderNet = {} 4 | 5 | function encoderNet.model(params) 6 | 7 | local inputs = {} 8 | local outputs = {} 9 | 10 | table.insert(inputs, nn.Identity()()) -- question 11 | table.insert(inputs, nn.Identity()()) -- history 12 | table.insert(inputs, nn.Identity()()) -- 10x10 mask 13 | 14 | local ques = inputs[1] 15 | local hist = inputs[2] 16 | local mask = inputs[3] 17 | 18 | -- word embed layer 19 | wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 20 | 21 | -- make clones for embed layer 22 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques)); 23 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 
'gradWeight', 'gradBias')(hist)); 24 | 25 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 26 | lst1:maskZero() 27 | 28 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 29 | lst2:maskZero() 30 | 31 | local h1 = lst1(hEmbed) 32 | local h2 = lst2(h1) 33 | local h3 = nn.Select(1, -1)(h2) 34 | 35 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 36 | lst3:maskZero() 37 | 38 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 39 | lst4:maskZero() 40 | 41 | local q1 = lst3(qEmbed) 42 | local q2 = lst4(q1) 43 | local q3 = nn.Select(1, -1)(q2) 44 | 45 | -- View as batch x rounds 46 | local qEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(q3) 47 | local hEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(h3) 48 | 49 | -- Inner product 50 | -- q is Bx10xE, h is Bx10xE 51 | -- qh is Bx10x10, rows correspond to questions, columns to facts 52 | local qh = nn.MM(false, true)({qEmbedView, hEmbedView}) 53 | local qhView = nn.View(-1, params.maxQuesCount)(qh) 54 | local qhprobs = nn.MaskSoftMax(){qhView, mask} 55 | local qhView2 = nn.View(-1, params.maxQuesCount, params.maxQuesCount)(qhprobs) 56 | 57 | -- Weighted sum of h features 58 | -- h is Bx10xE, qhView2 is Bx10x10 59 | local hAtt = nn.MM(){qhView2, hEmbedView} 60 | local hAttView = nn.View(-1, params.rnnHiddenSize)(hAtt) 61 | 62 | local hAttTr = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.Dropout(0.5)(hAttView))) 63 | local qh2 = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.CAddTable(){hAttTr, nn.View(-1, params.rnnHiddenSize)(qEmbedView)})) 64 | 65 | table.insert(outputs, qh2) 66 | 67 | local enc = nn.gModule(inputs, outputs) 68 | enc.wordEmbed = wordEmbed 69 | 70 | return enc; 71 | end 72 | 73 | return encoderNet 74 | -------------------------------------------------------------------------------- /encoders/mn-ques-im-hist.lua: -------------------------------------------------------------------------------- 1 | require 'model_utils.MaskSoftMax' 2 | 3 | local encoderNet = {} 4 | 5 | function encoderNet.model(params) 6 | 7 | local inputs = {} 8 | local outputs = {} 9 | 10 | table.insert(inputs, nn.Identity()()) -- question 11 | table.insert(inputs, nn.Identity()()) -- img feats 12 | table.insert(inputs, nn.Identity()()) -- history 13 | table.insert(inputs, nn.Identity()()) -- 10x10 mask 14 | 15 | local ques = inputs[1] 16 | local img_feats = inputs[2] 17 | local hist = inputs[3] 18 | local mask = inputs[4] 19 | 20 | -- word embed layer 21 | wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 22 | 23 | -- make clones for embed layer 24 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques)); 25 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist)); 26 | 27 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 28 | lst1:maskZero() 29 | 30 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 31 | lst2:maskZero() 32 | 33 | local h1 = lst1(hEmbed) 34 | local h2 = lst2(h1) 35 | local h3 = nn.Select(1, -1)(h2) 36 | 37 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize) 38 | lst3:maskZero() 39 | 40 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize) 41 | lst4:maskZero() 42 | 43 | local q1 = lst3(qEmbed) 44 | local q2 = lst4(q1) 45 | local q3 = nn.Select(1, -1)(q2) 46 | 47 | local qi = nn.JoinTable(1,1)({q3,img_feats}) 48 | local qi_proj = 
nn.Tanh()(nn.Linear(params.imgFeatureSize + params.rnnHiddenSize, params.rnnHiddenSize)(qi)) 49 | 50 | -- View as batch x rounds 51 | local qEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(qi_proj) 52 | local hEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(h3) 53 | 54 | -- Inner product 55 | -- q is Bx10xE, h is Bx10xE 56 | -- qh is Bx10x10, rows correspond to questions, columns to facts 57 | local qh = nn.MM(false, true)({qEmbedView, hEmbedView}) 58 | local qhView = nn.View(-1, params.maxQuesCount)(qh) 59 | local qhprobs = nn.MaskSoftMax(){qhView, mask} 60 | local qhView2 = nn.View(-1, params.maxQuesCount, params.maxQuesCount)(qhprobs) 61 | 62 | -- Weighted sum of h features 63 | -- h is Bx10xE, qhView2 is Bx10x10 64 | local hAtt = nn.MM(){qhView2, hEmbedView} 65 | local hAttView = nn.View(-1, params.rnnHiddenSize)(hAtt) 66 | 67 | local hAttTr = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.Dropout(0.5)(hAttView))) 68 | local qh2 = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.CAddTable(){hAttTr, nn.View(-1, params.rnnHiddenSize)(qEmbedView)})) 69 | 70 | table.insert(outputs, qh2) 71 | 72 | local enc = nn.gModule(inputs, outputs) 73 | enc.wordEmbed = wordEmbed 74 | 75 | return enc; 76 | end 77 | 78 | return encoderNet 79 | -------------------------------------------------------------------------------- /evaluate.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'nngraph' 4 | utils = dofile('utils.lua'); 5 | 6 | ------------------------------------------------------------------------------- 7 | -- Input arguments and options 8 | ------------------------------------------------------------------------------- 9 | cmd = torch.CmdLine() 10 | cmd:text() 11 | cmd:text('Test the VisDial model for retrieval') 12 | cmd:text() 13 | cmd:text('Options') 14 | 15 | -- Data input settings 16 | cmd:option('-inputImg','data/data_img.h5','h5file path with image feature') 17 | cmd:option('-inputQues','data/visdial_data.h5','h5file file with preprocessed questions') 18 | cmd:option('-inputJson','data/visdial_params.json','json path with info and vocab') 19 | 20 | cmd:option('-loadPath', 'checkpoints/model.t7', 'path to saved model') 21 | cmd:option('-split', 'val', 'split to evaluate on') 22 | cmd:option('-useGt', false, 'whether to use ground truth for retrieving ranks') 23 | 24 | -- Inference params 25 | cmd:option('-batchSize', 30, 'Batch size (number of threads) (Adjust base on GRAM)') 26 | cmd:option('-gpuid', 0, 'GPU id to use') 27 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 28 | 29 | cmd:option('-saveRanks', false, 'Whether to save ranks or not'); 30 | cmd:option('-saveRankPath', 'logs/ranks.json'); 31 | 32 | local opt = cmd:parse(arg); 33 | 34 | if opt.useGt and opt.split == 'test' then 35 | print('Warning: No ground truth avaiilable in test split, changing useGt to false.') 36 | opt.useGt = false 37 | end 38 | print(opt) 39 | 40 | -- seed for reproducibility 41 | torch.manualSeed(1234); 42 | 43 | -- set default tensor based on gpu usage 44 | if opt.gpuid >= 0 then 45 | require 'cutorch' 46 | require 'cunn' 47 | if opt.backend == 'cudnn' then require 'cudnn' end 48 | cutorch.setDevice(opt.gpuid+1) 49 | cutorch.manualSeed(1234) 50 | torch.setdefaulttensortype('torch.CudaTensor'); 51 | else 52 | torch.setdefaulttensortype('torch.FloatTensor'); 53 | end 54 | 55 | ------------------------------------------------------------------------ 56 | 
-- Read saved model and parameters 57 | ------------------------------------------------------------------------ 58 | local savedModel = torch.load(opt.loadPath) 59 | 60 | -- transfer all options to model 61 | local modelParams = savedModel.modelParams 62 | 63 | opt.imgNorm = modelParams.imgNorm 64 | opt.encoder = modelParams.encoder 65 | opt.decoder = modelParams.decoder 66 | modelParams.gpuid = opt.gpuid 67 | modelParams.batchSize = opt.batchSize 68 | modelParams.useGt = opt.useGt 69 | 70 | -- add flags for various configurations 71 | -- additionally check if its imitation of discriminative model 72 | if string.match(opt.encoder, 'hist') then opt.useHistory = true; end 73 | if string.match(opt.encoder, 'im') then opt.useIm = true; end 74 | -- check if history is to be concatenated (only for late fusion encoder) 75 | if string.match(opt.encoder, 'lf') then opt.concatHistory = true end 76 | 77 | ------------------------------------------------------------------------ 78 | -- Loading dataset 79 | ------------------------------------------------------------------------ 80 | local dataloader = dofile('dataloader.lua') 81 | dataloader:initialize(opt, {opt.split}); 82 | collectgarbage(); 83 | 84 | ------------------------------------------------------------------------ 85 | -- Setup the model 86 | ------------------------------------------------------------------------ 87 | require 'model' 88 | local model = Model(modelParams) 89 | 90 | -- copy the weights from loaded model 91 | model.wrapperW:copy(savedModel.modelW); 92 | 93 | ------------------------------------------------------------------------ 94 | -- Evaluation 95 | ------------------------------------------------------------------------ 96 | print('Evaluating..') 97 | local ranks; 98 | if opt.useGt then 99 | ranks = model:retrieve(dataloader, opt.split); 100 | else 101 | ranks = model:predict(dataloader, opt.split); 102 | end 103 | 104 | if opt.saveRanks == true then 105 | print(string.format('Writing ranks to %s', opt.saveRankPath)); 106 | utils.writeJSON(opt.saveRankPath, ranks); 107 | end 108 | -------------------------------------------------------------------------------- /generate.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'rnn' 3 | require 'nngraph' 4 | utils = dofile('utils.lua'); 5 | 6 | ------------------------------------------------------------------------------- 7 | -- Input arguments and options 8 | ------------------------------------------------------------------------------- 9 | cmd = torch.CmdLine() 10 | cmd:text() 11 | cmd:text('Test the VisDial model for retrieval') 12 | cmd:text() 13 | cmd:text('Options') 14 | 15 | -- Data input settings 16 | cmd:option('-inputImg','data/data_img.h5','h5file path with image feature') 17 | cmd:option('-inputQues','data/visdial_data.h5','h5file file with preprocessed questions') 18 | cmd:option('-inputJson','data/visdial_params.json','json path with info and vocab') 19 | 20 | cmd:option('-loadPath', 'checkpoints/model.t7', 'path to saved model') 21 | cmd:option('-resultPath', 'vis/results', 'path to save generated results') 22 | 23 | -- sampling params 24 | cmd:option('-beamSize', 5, 'Beam size') 25 | cmd:option('-beamLen', 20, 'Beam length') 26 | cmd:option('-sampleWords', 0, 'Whether to sample') 27 | cmd:option('-temperature', 1.0, 'Sampling temperature') 28 | cmd:option('-maxThreads', 50, 'Max threads') 29 | cmd:option('-gpuid', 0, 'GPU id to use') 30 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 31 | 32 | 
local opt = cmd:parse(arg); 33 | print(opt) 34 | 35 | -- seed for reproducibility 36 | torch.manualSeed(1234); 37 | 38 | -- set default tensor based on gpu usage 39 | if opt.gpuid >= 0 then 40 | require 'cutorch' 41 | require 'cunn' 42 | if opt.backend == 'cudnn' then require 'cudnn' end 43 | cutorch.setDevice(opt.gpuid+1) 44 | cutorch.manualSeed(1234) 45 | torch.setdefaulttensortype('torch.CudaTensor'); 46 | else 47 | torch.setdefaulttensortype('torch.FloatTensor'); 48 | end 49 | 50 | ------------------------------------------------------------------------ 51 | -- Read saved model and parameters 52 | ------------------------------------------------------------------------ 53 | local savedModel = torch.load(opt.loadPath) 54 | 55 | -- transfer all options to model 56 | local modelParams = savedModel.modelParams 57 | opt.imgNorm = modelParams.imgNorm 58 | opt.encoder = modelParams.encoder 59 | opt.decoder = modelParams.decoder 60 | modelParams.gpuid = opt.gpuid 61 | 62 | -- add flags for various configurations 63 | -- additionally check if its imitation of discriminative model 64 | if string.match(opt.encoder, 'hist') then 65 | opt.useHistory = true; 66 | end 67 | if string.match(opt.encoder, 'im') then opt.useIm = true; end 68 | 69 | ------------------------------------------------------------------------ 70 | -- Loading dataset 71 | ------------------------------------------------------------------------ 72 | local dataloader = dofile('dataloader.lua') 73 | dataloader:initialize(opt, {'val'}); 74 | collectgarbage(); 75 | 76 | ------------------------------------------------------------------------ 77 | -- Setup the model 78 | ------------------------------------------------------------------------ 79 | require 'model' 80 | local model = Model(modelParams) 81 | 82 | -- copy the weights from loaded model 83 | model.wrapperW:copy(savedModel.modelW); 84 | 85 | ------------------------------------------------------------------------ 86 | -- Generating 87 | ------------------------------------------------------------------------ 88 | sampleParams = { 89 | beamSize = opt.beamSize, 90 | beamLen = opt.beamLen, 91 | maxThreads = opt.maxThreads, 92 | sampleWords = opt.sampleWords, 93 | temperature = opt.temperature 94 | } 95 | 96 | local answers = model:generateAnswers(dataloader, 'val', sampleParams) 97 | local output = {opts = opt, data = answers} 98 | 99 | -- save the file to json 100 | local savePath = string.format('%s/results.json', opt.resultPath); 101 | paths.mkdir(opt.resultPath) 102 | utils.writeJSON(savePath, output); 103 | print('Writing the results to '.. 
savePath); 104 | 105 | -------------------------------------------------------------------------------- /model.lua: -------------------------------------------------------------------------------- 1 | -- abstract class for models 2 | require 'model_utils.optim_updates' 3 | require 'xlua' 4 | require 'hdf5' 5 | 6 | local utils = require 'utils' 7 | 8 | local Model = torch.class('Model'); 9 | 10 | -- initialize 11 | function Model:__init(params) 12 | print('Setting up model..') 13 | self.params = params 14 | 15 | print('Encoder: ', params.encoder) 16 | print('Decoder: ', params.decoder) 17 | 18 | -- build the model - encoder, decoder 19 | local encFile = string.format('encoders/%s.lua', params.encoder); 20 | local encoder = dofile(encFile); 21 | 22 | local decFile = string.format('decoders/%s.lua', params.decoder); 23 | local decoder = dofile(decFile); 24 | 25 | enc = encoder.model(params) 26 | dec = decoder.model(params, enc) 27 | 28 | local decMethods = {'forwardConnect', 'backwardConnect', 'decoderConnect'} 29 | for key, value in pairs(decMethods) do self[value] = decoder[value]; end 30 | 31 | -- criterion 32 | if params.decoder == 'gen' then 33 | self.criterion = nn.ClassNLLCriterion(); 34 | self.criterion.sizeAverage = false; 35 | self.criterion = nn.SequencerCriterion( 36 | nn.MaskZeroCriterion(self.criterion, 1)); 37 | elseif params.decoder == 'disc' then 38 | self.criterion = nn.CrossEntropyCriterion() 39 | end 40 | 41 | -- wrap the models 42 | self.wrapper = nn.Sequential():add(enc):add(dec); 43 | 44 | -- initialize weights 45 | self.wrapper = require('model_utils/weight-init')(self.wrapper, params.weightInit); 46 | 47 | -- ship to gpu if necessary 48 | if params.gpuid >= 0 then 49 | self.wrapper = self.wrapper:cuda(); 50 | self.criterion = self.criterion:cuda(); 51 | end 52 | 53 | self.encoder = self.wrapper:get(1); 54 | self.decoder = self.wrapper:get(2); 55 | self.wrapperW, self.wrapperdW = self.wrapper:getParameters(); 56 | 57 | self.wrapper:training(); 58 | 59 | -- setup the optimizer 60 | self.optims = {}; 61 | self.optims.learningRate = params.learningRate; 62 | end 63 | 64 | ------------------------------------------------------------------------------- 65 | -- One iteration of training -- forward and backward pass 66 | function Model:trainIteration(dataloader) 67 | -- clear the gradients 68 | self.wrapper:zeroGradParameters(); 69 | 70 | -- grab a training batch 71 | local batch = dataloader:getTrainBatch(self.params); 72 | 73 | -- call the internal function for model specific actions 74 | local curLoss = self:forwardBackward(batch); 75 | 76 | if self.params.decoder == 'gen' then 77 | -- count the number of tokens 78 | local numTokens = torch.sum(batch['answer_out']:gt(0)); 79 | 80 | -- update the running average of loss 81 | if runningLoss > 0 then 82 | runningLoss = 0.95 * runningLoss + 0.05 * curLoss/numTokens; 83 | else 84 | runningLoss = curLoss/numTokens; 85 | end 86 | elseif self.params.decoder == 'disc' then 87 | -- update the running average of loss 88 | if runningLoss > 0 then 89 | runningLoss = 0.95 * runningLoss + 0.05 * curLoss 90 | else 91 | runningLoss = curLoss 92 | end 93 | end 94 | 95 | -- clamp gradients 96 | self.wrapperdW:clamp(-5.0, 5.0); 97 | 98 | -- update parameters 99 | adam(self.wrapperW, self.wrapperdW, self.optims); 100 | 101 | -- decay learning rate, if needed 102 | if self.optims.learningRate > self.params.minLRate then 103 | self.optims.learningRate = self.optims.learningRate * 104 | self.params.lrDecayRate; 105 | end 106 | end 107 | 
--------------------------------------------------------------------- 108 | -- validation performance on test/val 109 | function Model:evaluate(dataloader, dtype) 110 | -- change to evaluate mode 111 | self.wrapper:evaluate(); 112 | 113 | local curLoss = 0; 114 | local startId = 1; 115 | local numThreads = dataloader.numThreads[dtype]; 116 | 117 | local numTokens = 0; 118 | while startId <= numThreads do 119 | -- print progress 120 | xlua.progress(startId, numThreads); 121 | 122 | -- grab a validation batch 123 | local batch, nextStartId 124 | = dataloader:getTestBatch(startId, self.params, dtype); 125 | -- count the number of tokens 126 | numTokens = numTokens + torch.sum(batch['answer_out']:gt(0)); 127 | -- forward pass to compute loss 128 | curLoss = curLoss + self:forwardBackward(batch, true); 129 | startId = nextStartId; 130 | end 131 | 132 | -- print the results 133 | curLoss = curLoss / numTokens; 134 | print(string.format('\n%s\tLoss: %f\t Perplexity: %f\n', dtype, 135 | curLoss, math.exp(curLoss))); 136 | 137 | -- change back to training 138 | self.wrapper:training(); 139 | end 140 | 141 | -- retrieval performance on val 142 | function Model:retrieve(dataloader, dtype) 143 | -- change to evaluate mode 144 | self.wrapper:evaluate(); 145 | 146 | local curLoss = 0; 147 | local startId = 1; 148 | self.params.numOptions = 100; 149 | local numThreads = dataloader.numThreads[dtype]; 150 | print('numThreads', numThreads) 151 | 152 | local ranks = torch.Tensor(numThreads, self.params.maxQuesCount); 153 | ranks:fill(self.params.numOptions + 1); 154 | 155 | while startId <= numThreads do 156 | -- print progress 157 | xlua.progress(startId, numThreads); 158 | 159 | -- grab a batch 160 | local batch, nextStartId = 161 | dataloader:getTestBatch(startId, self.params, dtype); 162 | 163 | -- Call retrieve function for specific model, and store ranks 164 | ranks[{{startId, nextStartId - 1}, {}}] = self:retrieveBatch(batch); 165 | startId = nextStartId; 166 | end 167 | 168 | print(string.format('\n%s - Retrieval:', dtype)) 169 | utils.processRanks(ranks); 170 | 171 | -- change back to training 172 | self.wrapper:training(); 173 | 174 | local retrieval = {}; 175 | local ranks = torch.totable(ranks:double()); 176 | for i = 1, #dataloader['unique_img_'..dtype] do 177 | for j = 1, dataloader[dtype..'_num_rounds'][i] do 178 | table.insert(retrieval, { 179 | image_id = dataloader['unique_img_'..dtype][i]; 180 | round_id = j; 181 | ranks = ranks[i][j] 182 | }) 183 | end 184 | end 185 | -- collect garbage 186 | collectgarbage(); 187 | 188 | return retrieval; 189 | end 190 | 191 | -- prediction on val/test 192 | function Model:predict(dataloader, dtype) 193 | -- change to evaluate mode 194 | self.wrapper:evaluate(); 195 | 196 | local curLoss = 0; 197 | local startId = 1; 198 | local numThreads = dataloader.numThreads[dtype]; 199 | self.params.numOptions = 100; 200 | print('numThreads', numThreads) 201 | 202 | local ranks = torch.Tensor(numThreads, 10, self.params.numOptions); 203 | ranks:fill(self.params.numOptions + 1); 204 | 205 | while startId <= numThreads do 206 | -- print progress 207 | xlua.progress(startId, numThreads); 208 | 209 | -- grab a batch 210 | local batch, nextStartId = 211 | dataloader:getTestBatch(startId, self.params, dtype); 212 | 213 | -- Call retrieve function for specific model, and store ranks 214 | ranks[{{startId, nextStartId - 1}, {}}] = self:retrieveBatch(batch) 215 | :view(nextStartId - startId, -1, self.params.numOptions); 216 | startId = nextStartId; 217 | end 218 | 219 | -- 
change back to training 220 | self.wrapper:training(); 221 | 222 | local prediction = {}; 223 | local ranks = torch.totable(ranks:double()); 224 | for i = 1, #dataloader['unique_img_'..dtype] do 225 | -- rank list for all rounds in val split and last round in test split 226 | if dtype == 'test' then 227 | table.insert(prediction, { 228 | image_id = dataloader['unique_img_'..dtype][i]; 229 | round_id = dataloader[dtype..'_num_rounds'][i]; 230 | ranks = ranks[i][dataloader[dtype..'_num_rounds'][i]] 231 | }) 232 | else 233 | for j = 1, dataloader[dtype..'_num_rounds'][i] do 234 | table.insert(prediction, { 235 | image_id = dataloader['unique_img_'..dtype][i]; 236 | round_id = j; 237 | ranks = ranks[i][j] 238 | }) 239 | end 240 | end 241 | end 242 | -- collect garbage 243 | collectgarbage(); 244 | 245 | return prediction; 246 | end 247 | 248 | -- forward + backward pass 249 | function Model:forwardBackward(batch, onlyForward, encOutOnly) 250 | local onlyForward = onlyForward or false; 251 | local encOutOnly = encOutOnly or false 252 | local inputs = {} 253 | 254 | -- transpose for timestep first 255 | local batchQues = batch['ques_fwd'] 256 | batchQues = batchQues:view(-1, batchQues:size(3)):t() 257 | table.insert(inputs, batchQues) 258 | 259 | if self.params.useIm == true then 260 | local imgFeats = batch['img_feat'] 261 | -- if attention, then conv layer features 262 | if string.match(self.params.encoder, 'att') then 263 | imgFeats = imgFeats:view(-1, 1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize) 264 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1, 1, 1) 265 | imgFeats = imgFeats:view(-1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize) 266 | else 267 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize) 268 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1) 269 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize) 270 | end 271 | table.insert(inputs, imgFeats) 272 | end 273 | 274 | if self.params.useHistory == true then 275 | local batchHist = batch['hist'] 276 | batchHist = batchHist:view(-1, batchHist:size(3)):t() 277 | table.insert(inputs, batchHist) 278 | end 279 | 280 | if string.match(self.params.encoder, 'mn') then 281 | local mask = torch.ones(10, 10):byte() 282 | for i = 1, 10 do 283 | for j = 1, 10 do 284 | if j <= i then 285 | mask[i][j] = 0 286 | end 287 | end 288 | end 289 | if self.params.gpuid >= 0 then 290 | mask = mask:cuda() 291 | end 292 | local maskRepeat = torch.repeatTensor(mask, batch['hist']:size(1), 1) 293 | table.insert(inputs, maskRepeat) 294 | end 295 | 296 | -- encoder forward pass 297 | local encOut = self.encoder:forward(inputs) 298 | 299 | -- coupled enc-dec (only for gen) 300 | self.forwardConnect(self.encoder, self.decoder, encOut, batchQues:size(1)); 301 | 302 | if encOutOnly == true then return encOut end 303 | 304 | -- decoder forward pass 305 | local curLoss = 0 306 | if self.params.decoder == 'gen' then 307 | local answerIn = batch['answer_in']; 308 | answerIn = answerIn:view(-1, answerIn:size(3)):t(); 309 | 310 | local answerOut = batch['answer_out']; 311 | answerOut = answerOut:view(-1, answerOut:size(3)):t(); 312 | 313 | local decOut = self.decoder:forward(answerIn); 314 | curLoss = self.criterion:forward(decOut, answerOut); 315 | 316 | -- backward pass 317 | if onlyForward == false then 318 | local gradCriterionOut = self.criterion:backward(decOut, answerOut); 319 | self.decoder:backward(answerIn, gradCriterionOut); 320 | 321 | 
--backward connect decoder and encoder (only for gen) 322 | local gradDecOut = self.backwardConnect(self.encoder, self.decoder); 323 | self.encoder:backward(inputs, gradDecOut) 324 | end 325 | elseif self.params.decoder == 'disc' then 326 | local options = batch['options'] 327 | local answerInd = batch['answer_ind'] 328 | 329 | local decOut = self.decoder:forward({options, encOut}) 330 | curLoss = self.criterion:forward(decOut, answerInd) 331 | 332 | -- backward pass 333 | if onlyForward == false then 334 | local gradCriterionOut = self.criterion:backward(decOut, answerInd) 335 | local t = self.decoder:backward({options, encOut}, gradCriterionOut) 336 | 337 | self.encoder:backward(inputs, t[2]) 338 | end 339 | end 340 | 341 | return curLoss; 342 | end 343 | 344 | function Model:retrieveBatch(batch) 345 | local inputs = {} 346 | 347 | local batchQues = batch['ques_fwd']; 348 | batchQues = batchQues:view(-1, batchQues:size(3)):t(); 349 | table.insert(inputs, batchQues) 350 | 351 | if self.params.useIm == true then 352 | local imgFeats = batch['img_feat'] 353 | -- if attention, then conv layer features 354 | if string.match(self.params.encoder, 'att') then 355 | imgFeats = imgFeats:view(-1, 1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize) 356 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1, 1, 1) 357 | imgFeats = imgFeats:view(-1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize) 358 | else 359 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize) 360 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1) 361 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize) 362 | end 363 | table.insert(inputs, imgFeats) 364 | end 365 | 366 | if self.params.useHistory == true then 367 | local batchHist = batch['hist'] 368 | batchHist = batchHist:view(-1, batchHist:size(3)):t() 369 | table.insert(inputs, batchHist) 370 | end 371 | 372 | if string.match(self.params.encoder, 'mn') then 373 | local mask = torch.ones(10, 10):byte() 374 | for i = 1, 10 do 375 | for j = 1, 10 do 376 | if j <= i then 377 | mask[i][j] = 0 378 | end 379 | end 380 | end 381 | if self.params.gpuid >= 0 then 382 | mask = mask:cuda() 383 | end 384 | local maskRepeat = torch.repeatTensor(mask, batch['hist']:size(1), 1) 385 | table.insert(inputs, maskRepeat) 386 | end 387 | 388 | -- forward pass 389 | local encOut = self.encoder:forward(inputs) 390 | local batchSize = batchQues:size(2); 391 | 392 | if self.params.decoder == 'gen' then 393 | local optionIn = batch['option_in']; 394 | optionIn = optionIn:view(-1, optionIn:size(3), optionIn:size(4)); 395 | 396 | local optionOut = batch['option_out']; 397 | optionOut = optionOut:view(-1, optionOut:size(3), optionOut:size(4)); 398 | optionIn = optionIn:transpose(1, 2):transpose(2, 3); 399 | optionOut = optionOut:transpose(1, 2):transpose(2, 3); 400 | 401 | -- tensor holds the likelihood for all the options 402 | local optionLhood = torch.Tensor(self.params.numOptions, batchSize); 403 | 404 | -- repeat for each option and get gt rank 405 | for opId = 1, self.params.numOptions do 406 | -- forward connect encoder and decoder 407 | self.forwardConnect(self.encoder, self.decoder, encOut, batchQues:size(1)); 408 | 409 | local curOptIn = optionIn[opId]; 410 | local curOptOut = optionOut[opId]; 411 | local decOut = self.decoder:forward(curOptIn); 412 | 413 | -- compute the probabilities for each answer, based on its tokens 414 | optionLhood[opId] = utils.computeLhood(curOptOut, decOut); 
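-- utils.computeLhood (defined in utils.lua) gathers the decoder's
-- log-probability at every ground-truth token of this candidate answer,
-- zeroes out padding positions and sums over the sequence; optionLhood[opId]
-- therefore holds one summed log-likelihood per question, and the options are
-- ranked below via utils.computeRanks(optionLhood:t(), gtPosition)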
415 | end 416 | -- gtPosition can be nil if ground truth does not exist 417 | local gtPosition = self.params.useGt and batch['answer_ind'] or nil; 418 | 419 | -- return the ranks for this batch 420 | return utils.computeRanks(optionLhood:t(), gtPosition); 421 | elseif self.params.decoder == 'disc' then 422 | local options = batch['options'] 423 | local decOut = self.decoder:forward({options, encOut}) 424 | local gtPosition = self.params.useGt and batch['answer_ind'] or nil; 425 | 426 | -- return the ranks for this batch 427 | return utils.computeRanks(decOut, gtPosition) 428 | end 429 | 430 | end 431 | 432 | function Model:generateAnswers(dataloader, dtype, params) 433 | -- check decoder 434 | if self.params.decoder == 'disc' then 435 | print('Sampling/beam search only for generative model') 436 | os.exit() 437 | end 438 | 439 | -- setting the options for beam search / sampling 440 | params = params or {}; 441 | 442 | -- sample or take max 443 | local sampleWords = params.sampleWords and params.sampleWords == 1 or false; 444 | local temperature = params.temperature or 1.0; 445 | local beamSize = params.beamSize or 5; 446 | local beamLen = params.beamLen or 20; 447 | 448 | print('Beam size', beamSize) 449 | print('Beam length', beamLen) 450 | 451 | -- start/end token indices 452 | local startToken = dataloader.word2ind['<START>']; 453 | local endToken = dataloader.word2ind['<END>']; 454 | local numThreads = params.maxThreads or dataloader.numThreads[dtype]; 455 | print('No. of threads', numThreads) 456 | 457 | local answerTable = {} 458 | for convId = 1, numThreads do 459 | xlua.progress(convId, numThreads); 460 | self.wrapper:evaluate() 461 | 462 | local inds = torch.LongTensor(1):fill(convId); 463 | local batch = dataloader:getIndexData(inds, self.params, dtype); 464 | local numQues = batch['ques_fwd']:size(1) * batch['ques_fwd']:size(2); 465 | 466 | local encOut = self:forwardBackward(batch, true, true) 467 | local threadAnswers = {} 468 | 469 | if sampleWords == false then 470 | -- do beam search for each example now 471 | for iter = 1, 10 do 472 | local encInSeq = batch['ques_fwd']:view(-1, batch['ques_fwd']:size(3)):t(); 473 | encInSeq = encInSeq[{{},{iter}}]:squeeze():float() 474 | 475 | -- beams 476 | local beams = torch.LongTensor(beamLen, beamSize):zero(); 477 | 478 | -- initial hidden states for the beam at current round of dialog 479 | local hiddenBeams = {}; 480 | if self.encoder.rnnLayers ~= nil then 481 | for level = 1, #self.encoder.rnnLayers do 482 | if hiddenBeams[level] == nil then hiddenBeams[level] = {} end 483 | hiddenBeams[level]['output'] = self.encoder.rnnLayers[level].output[batch['ques_fwd']:size(3)][iter]; 484 | hiddenBeams[level]['cell'] = self.encoder.rnnLayers[level].cell[batch['ques_fwd']:size(3)][iter]; 485 | if level == #self.encoder.rnnLayers then 486 | hiddenBeams[#self.encoder.rnnLayers]['output'] = encOut[iter] 487 | end 488 | hiddenBeams[level]['output'] = torch.repeatTensor(hiddenBeams[level]['output'], beamSize, 1); 489 | hiddenBeams[level]['cell'] = torch.repeatTensor(hiddenBeams[level]['cell'], beamSize, 1); 490 | end 491 | -- hiddenBeams[]['cell'] is beam_nums x 512 492 | -- hiddenBeams[]['output'] is beam_nums x 512 493 | else 494 | for level = 1, #self.decoder.rnnLayers do 495 | if hiddenBeams[level] == nil then hiddenBeams[level] = {} end 496 | if level == #self.decoder.rnnLayers then 497 | hiddenBeams[level]['output'] = torch.repeatTensor(encOut[iter], beamSize, 1) 498 | else 499 | hiddenBeams[level]['output'] = torch.Tensor(beamSize, encOut:size(2)):zero() 500 |
end 501 | hiddenBeams[level]['cell'] = hiddenBeams[level]['output']:clone():zero() 502 | end 503 | end 504 | 505 | -- for first step, initialize with start symbols 506 | beams[1] = dataloader.word2ind['<START>']; 507 | scores = torch.DoubleTensor(beamSize):zero(); 508 | finishBeams = {}; -- accumulate beams that are done 509 | 510 | for step = 2, beamLen do 511 | 512 | -- candidates for the current iteration 513 | cands = {}; 514 | 515 | -- if step == 2, explore only one beam (all are <START>) 516 | local exploreSize = (step == 2) and 1 or beamSize; 517 | 518 | -- first copy the hidden states to the decoder 519 | for level = 1, #self.decoder.rnnLayers do 520 | self.decoder.rnnLayers[level].userPrevOutput = hiddenBeams[level]['output'] 521 | self.decoder.rnnLayers[level].userPrevCell = hiddenBeams[level]['cell'] 522 | end 523 | 524 | -- decoder forward pass 525 | decOut = self.decoder:forward(beams[{{step-1}}]); 526 | decOut = decOut:squeeze(); -- decOut is beam_nums x vocab_size 527 | 528 | -- iterate separately for each possible word of beam 529 | for wordId = 1, exploreSize do 530 | local curHidden = {}; 531 | for level = 1, #self.decoder.rnnLayers do 532 | if curHidden[level] == nil then curHidden[level] = {} end 533 | curHidden[level]['output'] = self.decoder.rnnLayers[level].output[{{1},{wordId}}]:clone():squeeze(); -- rnnLayers[].output is 1 x beam_nums x 512 534 | curHidden[level]['cell'] = self.decoder.rnnLayers[level].cell[{{1},{wordId}}]:clone():squeeze(); 535 | end 536 | 537 | -- sort and get the top probabilities 538 | if beamSize == 1 then 539 | topProb, topInd = torch.topk(decOut, beamSize, true); 540 | else 541 | topProb, topInd = torch.topk(decOut[wordId], beamSize, true); 542 | end 543 | 544 | for candId = 1, beamSize do 545 | local candBeam = beams[{{}, {wordId}}]:clone(); 546 | -- get the updated cost for each explored candidate, pool 547 | candBeam[step] = topInd[candId]; 548 | if topInd[candId] == endToken then 549 | table.insert(finishBeams, {beam = candBeam:double():squeeze(), length = step, score = scores[wordId] + topProb[candId]}); 550 | else 551 | table.insert(cands, {score = scores[wordId] + topProb[candId], 552 | beam = candBeam, 553 | hidden = curHidden}); 554 | end 555 | end 556 | end 557 | 558 | -- sort the candidates and stick to beam size 559 | table.sort(cands, function (a, b) return a.score > b.score; end); 560 | 561 | for candId = 1, math.min(#cands, beamSize) do 562 | beams[{{}, {candId}}] = cands[candId].beam; 563 | 564 | -- recursive copy 565 | for level = 1, #self.decoder.rnnLayers do 566 | hiddenBeams[level]['output'][candId] = cands[candId].hidden[level]['output']:clone(); 567 | hiddenBeams[level]['cell'][candId] = cands[candId].hidden[level]['cell']:clone(); 568 | end 569 | 570 | scores[candId] = cands[candId].score; 571 | end 572 | end 573 | 574 | table.sort(finishBeams, function (a, b) return a.score > b.score; end); 575 | 576 | local quesWords = encInSeq:double():squeeze() 577 | local ansWords = finishBeams[1].beam:squeeze(); 578 | 579 | local quesText = utils.idToWords(quesWords, dataloader.ind2word); 580 | local ansText = utils.idToWords(ansWords, dataloader.ind2word); 581 | 582 | table.insert(threadAnswers, {question = quesText, answer = ansText}) 583 | end 584 | else 585 | local answerIn = torch.Tensor(1, numQues):fill(startToken) 586 | local answer = {answerIn:t():double()} 587 | for timeStep = 1, beamLen do 588 | -- one pass through decoder 589 | local decOut = self.decoder:forward(answerIn):squeeze() 590 | -- connect decoder to itself 591 |
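-- decoderConnect carries the recurrent state of the forward pass above over to
-- the next time step, so the next call to self.decoder:forward continues the
-- sequence with the freshly sampled token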
self.decoderConnect(self.decoder) 592 | 593 | local nextToken = torch.multinomial(torch.exp(decOut / temperature), 1) 594 | table.insert(answer, nextToken:double()) 595 | answerIn:copy(nextToken) 596 | end 597 | answer = nn.JoinTable(-1):forward(answer) 598 | 599 | for iter = 1, 10 do 600 | local quesWords = batch['ques_fwd'][{{1}, {iter}, {}}]:squeeze():double() 601 | local ansWords = answer[{{iter}, {}}]:squeeze() 602 | 603 | local quesText = utils.idToWords(quesWords, dataloader.ind2word) 604 | local ansText = utils.idToWords(ansWords, dataloader.ind2word) 605 | 606 | table.insert(threadAnswers, {question = quesText, answer = ansText}) 607 | end 608 | end 609 | self.wrapper:training() 610 | table.insert(answerTable, {image_id = dataloader['unique_img_'..dtype][convId], dialog = threadAnswers}) 611 | end 612 | return answerTable 613 | end 614 | 615 | return Model; 616 | -------------------------------------------------------------------------------- /model_utils/MaskFuture.lua: -------------------------------------------------------------------------------- 1 | -- new module to replace zero with a given value 2 | local MaskFuture, Parent = torch.class('nn.MaskFuture', 'nn.Module') 3 | 4 | function MaskFuture:__init(numClasses) 5 | Parent.__init(self); 6 | self.mask = torch.Tensor(1, numClasses, numClasses):fill(1); 7 | -- extract the upper diagonal matrix 8 | self.mask[1] = torch.triu(self.mask[1], 1); 9 | self.gradInput = torch.Tensor(); 10 | self.output = torch.Tensor(); 11 | end 12 | 13 | function MaskFuture:updateOutput(input) 14 | local batchSize = input:size(1); 15 | 16 | self.output:resizeAs(input):copy(input); 17 | -- expand mask based on input 18 | self.output[self.mask:expandAs(input)] = 0; 19 | 20 | return self.output; 21 | end 22 | 23 | function MaskFuture:updateGradInput(input, gradOutput) 24 | -- the first component is zero gradients 25 | --self.gradInput:resizeAs(input):zero(); 26 | -- zero out the gradients based on the mask 27 | self.gradInput:resizeAs(gradOutput):copy(gradOutput); 28 | self.gradInput[self.mask:expandAs(input)] = 0; 29 | 30 | return self.gradInput; 31 | end 32 | -------------------------------------------------------------------------------- /model_utils/MaskSoftMax.lua: -------------------------------------------------------------------------------- 1 | -- Author: Jiasen Lu 2 | -- From https://raw.githubusercontent.com/jiasenlu/HieCoAttenVQA/master/misc/maskSoftmax.lua 3 | local MaskSoftMax, _ = torch.class('nn.MaskSoftMax', 'nn.Module') 4 | 5 | function MaskSoftMax:updateOutput(input) 6 | local data = input[1] 7 | local mask = input[2] 8 | if(mask:type() == 'torch.CudaTensor') then 9 | mask = mask:cudaByte() 10 | end 11 | 12 | data:maskedFill(mask, -9999999) 13 | if(mask:type() == 'torch.CudaByteTensor') then 14 | mask = mask:cuda() 15 | end 16 | data.THNN.SoftMax_updateOutput( 17 | data:cdata(), 18 | self.output:cdata() 19 | ) 20 | return self.output 21 | end 22 | 23 | function MaskSoftMax:updateGradInput(input, gradOutput) 24 | local data = input[1] 25 | local mask = input[2] 26 | if(mask:type() == 'torch.CudaTensor') then 27 | mask = mask:cudaByte() 28 | end 29 | 30 | data:maskedFill(mask, -9999999) 31 | if(mask:type() == 'torch.CudaByteTensor') then 32 | mask = mask:cuda() 33 | end 34 | 35 | data.THNN.SoftMax_updateGradInput( 36 | data:cdata(), 37 | gradOutput:cdata(), 38 | self.gradInput:cdata(), 39 | self.output:cdata() 40 | ) 41 | if not self.dummy_out then 42 | self.dummy_out = mask:clone() 43 | end 44 | self.dummy_out:resizeAs(mask):zero() 45 
| return {self.gradInput, self.dummy_out} 46 | end 47 | -------------------------------------------------------------------------------- /model_utils/MaskTime.lua: -------------------------------------------------------------------------------- 1 | -- new module to replace zero with a given value 2 | local MaskTime, Parent = torch.class('nn.MaskTime', 'nn.Module') 3 | 4 | function MaskTime:__init(featSize) 5 | Parent.__init(self); 6 | self.mask = torch.Tensor(); 7 | self.seqLen = nil; 8 | self.featSize = featSize; 9 | self.gradInput = {torch.Tensor(), torch.Tensor()}; 10 | end 11 | 12 | function MaskTime:updateOutput(input) 13 | local seqLen = input[1]:size(1); 14 | local batchSize = input[1]:size(2); 15 | 16 | -- expand the feature vector 17 | self.output:resizeAs(input[2]):copy(input[2]); 18 | self.output = self.output:view(1, batchSize, self.featSize); 19 | self.output = self.output:repeatTensor(seqLen, 1, 1); 20 | 21 | -- expand the word mask 22 | self.mask = input[1]:eq(0); 23 | self.mask = self.mask:view(seqLen, batchSize, 1) 24 | :expand(seqLen, batchSize, self.featSize); 25 | self.output[self.mask] = 0; 26 | 27 | return self.output; 28 | end 29 | 30 | function MaskTime:updateGradInput(input, gradOutput) 31 | -- the first component is zero gradients 32 | self.gradInput[1]:resizeAs(input[1]):zero(); 33 | -- second component has zeroed out gradients 34 | -- sum along first dimension 35 | self.gradInput[2]:resizeAs(gradOutput):copy(gradOutput); 36 | self.gradInput[2][self.mask] = 0; 37 | self.gradInput[2] = self.gradInput[2]:sum(1):squeeze(); 38 | 39 | return self.gradInput; 40 | end 41 | -------------------------------------------------------------------------------- /model_utils/ReplaceZero.lua: -------------------------------------------------------------------------------- 1 | -- new module to replace zero with a given value 2 | local ReplaceZero, Parent = torch.class('nn.ReplaceZero', 'nn.Module') 3 | 4 | function ReplaceZero:__init(constant) 5 | Parent.__init(self); 6 | if not constant then 7 | error(' constant must be specified') 8 | end 9 | self.constant = constant; 10 | self.mask = torch.Tensor(); 11 | end 12 | 13 | function ReplaceZero:updateOutput(input) 14 | self.output:resizeAs(input):copy(input); 15 | self.mask = input:eq(0); 16 | self.output[self.mask] = self.constant; 17 | return self.output; 18 | end 19 | 20 | function ReplaceZero:updateGradInput(input, gradOutput) 21 | self.gradInput:resizeAs(gradOutput):copy(gradOutput); 22 | -- remove the gradients at those points 23 | self.gradInput[self.mask] = 0; 24 | return self.gradInput; 25 | end 26 | -------------------------------------------------------------------------------- /model_utils/optim_updates.lua: -------------------------------------------------------------------------------- 1 | -- Author: Andrej Karpathy https://github.com/karpathy 2 | -- Project: neuraltalk2 https://github.com/karpathy/neuraltalk2 3 | -- Slightly modified by Xiao Lin for initial values of rmsprop. 4 | 5 | -- optim, simple as it should be, written from scratch. 
That's how I roll 6 | 7 | function sgd(x, dx, lr) 8 | x:add(-lr, dx) 9 | end 10 | 11 | function sgdm(x, dx, lr, alpha, state) 12 | -- sgd with momentum, standard update 13 | if not state.v then 14 | state.v = x.new(#x):zero() 15 | end 16 | state.v:mul(alpha) 17 | state.v:add(lr, dx) 18 | x:add(-1, state.v) 19 | end 20 | 21 | function sgdmom(x, dx, lr, alpha, state) 22 | -- sgd momentum, uses nesterov update (reference: http://cs231n.github.io/neural-networks-3/#sgd) 23 | if not state.m then 24 | state.m = x.new(#x):zero() 25 | state.tmp = x.new(#x) 26 | end 27 | state.tmp:copy(state.m) 28 | state.m:mul(alpha):add(-lr, dx) 29 | x:add(-alpha, state.tmp) 30 | x:add(1+alpha, state.m) 31 | end 32 | 33 | function adagrad(x, dx, lr, epsilon, state) 34 | if not state.m then 35 | state.m = x.new(#x):zero() 36 | state.tmp = x.new(#x) 37 | end 38 | -- calculate new mean squared values 39 | state.m:addcmul(1.0, dx, dx) 40 | -- perform update 41 | state.tmp:sqrt(state.m):add(epsilon) 42 | x:addcdiv(-lr, dx, state.tmp) 43 | end 44 | 45 | -- rmsprop implementation, simple as it should be 46 | function rmsprop(x, dx, state) 47 | local alpha = state.alpha or 0.99; 48 | local learningRate = state.learningRate or 1e-2; 49 | local epsilon = state.epsilon or 1e-8; 50 | if not state.m then 51 | state.m = x.new(#x):zero() 52 | state.tmp = x.new(#x) 53 | end 54 | -- calculate new (leaky) mean squared values 55 | state.m:mul(alpha) 56 | state.m:addcmul(1.0-alpha, dx, dx) 57 | -- perform update 58 | state.tmp:sqrt(state.m):add(epsilon) 59 | x:addcdiv(-learningRate, dx, state.tmp) 60 | end 61 | 62 | function adam(x, dx, state) 63 | local beta1 = state.beta1 or 0.9 64 | local beta2 = state.beta2 or 0.999 65 | local epsilon = state.epsilon or 1e-8 66 | local lr = state.learningRate or 1e-2; 67 | 68 | if not state.m then 69 | -- Initialization 70 | state.t = 0 71 | -- Exponential moving average of gradient values 72 | state.m = x.new(#dx):zero() 73 | -- Exponential moving average of squared gradient values 74 | state.v = x.new(#dx):zero() 75 | -- A tmp tensor to hold the sqrt(v) + epsilon 76 | state.tmp = x.new(#dx):zero() 77 | end 78 | 79 | -- Decay the first and second moment running average coefficient 80 | state.m:mul(beta1):add(1-beta1, dx) 81 | state.v:mul(beta2):addcmul(1-beta2, dx, dx) 82 | state.tmp:copy(state.v):sqrt():add(epsilon) 83 | 84 | state.t = state.t + 1 85 | local biasCorrection1 = 1 - beta1^state.t 86 | local biasCorrection2 = 1 - beta2^state.t 87 | local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1 88 | 89 | -- perform update 90 | x:addcdiv(-stepSize, state.m, state.tmp) 91 | end 92 | -------------------------------------------------------------------------------- /model_utils/weight-init.lua: -------------------------------------------------------------------------------- 1 | -- From https://raw.githubusercontent.com/e-lab/torch-toolbox/master/Weight-init/weight-init.lua 2 | -- 3 | -- Different weight initialization methods 4 | -- 5 | -- > model = require('weight-init')(model, 'heuristic') 6 | -- 7 | require("nn") 8 | 9 | 10 | -- "Efficient backprop" 11 | -- Yann Lecun, 1998 12 | local function w_init_heuristic(fan_in, fan_out) 13 | return math.sqrt(1/(3*fan_in)) 14 | end 15 | 16 | 17 | -- "Understanding the difficulty of training deep feedforward neural networks" 18 | -- Xavier Glorot, 2010 19 | local function w_init_xavier(fan_in, fan_out) 20 | return math.sqrt(2/(fan_in + fan_out)) 21 | end 22 | 23 | 24 | -- "Understanding the difficulty of training deep feedforward neural 
networks" 25 | -- Xavier Glorot, 2010 26 | local function w_init_xavier_caffe(fan_in, fan_out) 27 | return math.sqrt(1/fan_in) 28 | end 29 | 30 | 31 | -- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" 32 | -- Kaiming He, 2015 33 | local function w_init_kaiming(fan_in, fan_out) 34 | return math.sqrt(4/(fan_in + fan_out)) 35 | end 36 | 37 | 38 | local function w_init(net, arg) 39 | -- choose initialization method 40 | local method = nil 41 | if arg == 'heuristic' then method = w_init_heuristic 42 | elseif arg == 'xavier' then method = w_init_xavier 43 | elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe 44 | elseif arg == 'kaiming' then method = w_init_kaiming 45 | else 46 | assert(false) 47 | end 48 | 49 | -- loop over all convolutional modules 50 | for i = 1, #net.modules do 51 | local m = net.modules[i] 52 | if m.__typename == 'nn.SpatialConvolution' then 53 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) 54 | elseif m.__typename == 'nn.SpatialConvolutionMM' then 55 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) 56 | elseif m.__typename == 'nn.LateralConvolution' then 57 | m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1)) 58 | elseif m.__typename == 'nn.VerticalConvolution' then 59 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) 60 | elseif m.__typename == 'nn.HorizontalConvolution' then 61 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) 62 | elseif m.__typename == 'nn.Linear' then 63 | m:reset(method(m.weight:size(2), m.weight:size(1))) 64 | elseif m.__typename == 'nn.TemporalConvolution' then 65 | m:reset(method(m.weight:size(2), m.weight:size(1))) 66 | end 67 | 68 | if m.bias then 69 | m.bias:zero() 70 | end 71 | end 72 | return net 73 | end 74 | 75 | 76 | return w_init 77 | -------------------------------------------------------------------------------- /opts.lua: -------------------------------------------------------------------------------- 1 | cmd = torch.CmdLine() 2 | cmd:text('Train the Visual Dialog model') 3 | cmd:text() 4 | cmd:text('Options') 5 | -- Data input settings 6 | cmd:option('-inputImg', 'data/data_img.h5', 'HDF5 file with image features') 7 | cmd:option('-inputQues', 'data/visdial_data.h5', 'HDF5 file with preprocessed questions') 8 | cmd:option('-inputJson', 'data/visdial_params.json', 'JSON file with info and vocab') 9 | cmd:option('-savePath', 'checkpoints/', 'Path to save checkpoints') 10 | cmd:option('-saveIter', 2, 'Save model checkpoint after every saveIter epochs') 11 | 12 | -- specify encoder/decoder 13 | cmd:option('-encoder', 'lf-ques-hist', 'Name of the encoder to use') 14 | cmd:option('-decoder', 'gen', 'Name of the decoder to use (gen/disc)') 15 | cmd:option('-imgNorm', 1, 'normalize the image feature. 1=yes, 0=no') 16 | 17 | -- model params 18 | cmd:option('-imgEmbedSize', 300, 'Size of the multimodal embedding') 19 | cmd:option('-imgFeatureSize', 4096, 'Channel size of the image feature') 20 | cmd:option('-imgSpatialSize', 14, 'Spatial size of image features (for attention-based encoders).') 21 | cmd:option('-embedSize', 300, 'Size of input word embeddings') 22 | cmd:option('-rnnHiddenSize', 512, 'Size of the LSTM state') 23 | cmd:option('-maxHistoryLen', 60, 'Maximum history to consider when using concatenated QA pairs') 24 | cmd:option('-numLayers', 2, 'Number of layers in LSTM') 25 | cmd:option('-commonEmbeddingSize', 512, 'Common embedding size in MN-ATT-QIH') 26 | cmd:option('-numAttentionLayers', 1, 'No. 
of attention hops in MN-ATT-QIH') 27 | 28 | cmd:option('-loadPath', '', 'Checkpoint path to load from') 29 | 30 | -- optimization params 31 | cmd:option('-batchSize', 40, 'Batch size (number of threads) (Adjust base on GPU memory)') 32 | cmd:option('-learningRate', 1e-3, 'Learning rate') 33 | cmd:option('-weightInit', 'xavier', 'Weight initialization strategy: xavier|heuristic|kaiming') 34 | cmd:option('-dropout', 0.5, 'Dropout') 35 | cmd:option('-numEpochs', 100, 'Epochs') 36 | cmd:option('-LRateDecay', 10, 'After lr_decay epochs lr reduces to 0.1*lr') 37 | cmd:option('-lrDecayRate', 0.9997592083, 'Decay for learning rate') 38 | cmd:option('-minLRate', 5e-5, 'Minimum learning rate') 39 | cmd:option('-gpuid', 0, 'GPU id to use') 40 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 41 | 42 | local opts = cmd:parse(arg); 43 | 44 | -- if save path is not given, use default — time 45 | -- get the current time 46 | local curTime = os.date('*t', os.time()); 47 | -- create another folder to avoid clutter 48 | local modelPath = string.format('checkpoints/model-%d-%d-%d-%d:%d:%d-%s-%s/', 49 | curTime.month, curTime.day, curTime.year, 50 | curTime.hour, curTime.min, curTime.sec, 51 | opts.encoder, opts.decoder) 52 | if opts.savePath == 'checkpoints/' then opts.savePath = modelPath end; 53 | 54 | -- check for inputs required 55 | if string.match(opts.encoder, 'hist') then opts.useHistory = true end 56 | if string.match(opts.encoder, 'im') then opts.useIm = true end 57 | 58 | -- check if history is to be concatenated (only for late fusion encoder) 59 | if string.match(opts.encoder, 'lf') then opts.concatHistory = true end 60 | 61 | -- attention is always on conv features, not fc7 62 | if string.match(opts.encoder, 'att') then 63 | if opts.inputImg == 'data/data_img.h5' then 64 | opts.inputImg = 'data/data_img_pool5.h5' 65 | end 66 | opts.imgNorm = 0 67 | end 68 | 69 | return opts; 70 | -------------------------------------------------------------------------------- /scripts/download_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | print_usage () { 4 | echo "Usage: download_model.sh [vgg|resnet] [layers]" 5 | echo "For vgg, 'layers' can be one of {16, 19}" 6 | echo "For resnet, 'layers' can be one of {18, 34, 50, 101, 152, 200}" 7 | } 8 | 9 | if [ $1 = "vgg" ] 10 | then 11 | if [ $2 = "16" ] 12 | then 13 | mkdir -p data/models/vgg16 14 | cd data/models/vgg16 15 | wget https://gist.githubusercontent.com/ksimonyan/211839e770f7b538e2d8/raw/ded9363bd93ec0c770134f4e387d8aaaaa2407ce/VGG_ILSVRC_16_layers_deploy.prototxt 16 | wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel 17 | elif [ $2 = "19" ] 18 | then 19 | mkdir -p data/models/vgg19 20 | cd data/models/vgg19 21 | wget https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/f43eeefc869d646b449aa6ce66f87bf987a1c9b5/VGG_ILSVRC_19_layers_deploy.prototxt 22 | wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel 23 | else 24 | print_usage 25 | fi 26 | elif [ $1 = "resnet" ] 27 | then 28 | if echo "18 34 50 101 152 200" | grep -w $2 > /dev/null 29 | then 30 | mkdir -p data/models/resnet 31 | cd data/models/resnet 32 | wget https://d2j0dndfm35trm.cloudfront.net/resnet-$2.t7 33 | else 34 | print_usage 35 | fi 36 | else 37 | print_usage 38 | fi 39 | 40 | cd ../../.. 
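# example invocations (matching the usage message above), run from the repo root:
#   sh scripts/download_model.sh vgg 16       # VGG-16 prototxt + caffemodel into data/models/vgg16
#   sh scripts/download_model.sh resnet 152   # ResNet-152 .t7 into data/models/resnet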
41 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'rnn' 4 | 5 | ------------------------------------------------------------------------ 6 | -- Input arguments and options 7 | ------------------------------------------------------------------------ 8 | local opt = require 'opts'; 9 | print(opt) 10 | 11 | -- seed for reproducibility 12 | torch.manualSeed(1234); 13 | 14 | -- set default tensor based on gpu usage 15 | if opt.gpuid >= 0 then 16 | require 'cutorch' 17 | require 'cunn' 18 | if opt.backend == 'cudnn' then require 'cudnn' end 19 | cutorch.setDevice(opt.gpuid+1) 20 | cutorch.manualSeed(1234) 21 | torch.setdefaulttensortype('torch.CudaTensor'); 22 | else 23 | torch.setdefaulttensortype('torch.FloatTensor'); 24 | end 25 | 26 | -- transfer all options to model 27 | local modelParams = opt; 28 | 29 | ------------------------------------------------------------------------ 30 | -- Read saved model and parameters 31 | ------------------------------------------------------------------------ 32 | local savedModel = false; 33 | if opt.loadPath ~= '' then 34 | savedModel = torch.load(opt.loadPath); 35 | modelParams = savedModel.modelParams; 36 | 37 | opt.imgNorm = modelParams.imgNorm; 38 | opt.encoder = modelParams.encoder; 39 | opt.decoder = modelParams.decoder; 40 | modelParams.gpuid = opt.gpuid; 41 | modelParams.batchSize = opt.batchSize; 42 | end 43 | 44 | ------------------------------------------------------------------------ 45 | -- Loading dataset 46 | ------------------------------------------------------------------------ 47 | local dataloader = dofile('dataloader.lua'); 48 | dataloader:initialize(opt, {'train'}); 49 | collectgarbage(); 50 | 51 | ------------------------------------------------------------------------ 52 | -- Setting model parameters 53 | ------------------------------------------------------------------------ 54 | -- transfer parameters from dataloader to model 55 | paramNames = {'numTrainThreads', 'numTestThreads', 'numValThreads', 56 | 'vocabSize', 'maxQuesCount', 'maxQuesLen', 'maxAnsLen'}; 57 | for _, value in pairs(paramNames) do 58 | modelParams[value] = dataloader[value]; 59 | end 60 | 61 | -- path to save the model 62 | local modelPath = opt.savePath 63 | 64 | -- creating the directory to save the model 65 | paths.mkdir(modelPath); 66 | 67 | -- Iterations per epoch 68 | modelParams.numIterPerEpoch = math.ceil(modelParams.numTrainThreads / 69 | modelParams.batchSize); 70 | print(string.format('\n%d iter per epoch.', modelParams.numIterPerEpoch)); 71 | 72 | ------------------------------------------------------------------------ 73 | -- Setup the model 74 | ------------------------------------------------------------------------ 75 | require 'model' 76 | local model = Model(modelParams); 77 | 78 | if opt.loadPath ~= '' then 79 | model.wrapperW:copy(savedModel.modelW); 80 | model.optims.learningRate = savedModel.optims.learningRate; 81 | end 82 | 83 | ------------------------------------------------------------------------ 84 | -- Training 85 | ------------------------------------------------------------------------ 86 | print('Training..') 87 | collectgarbage() 88 | 89 | runningLoss = 0; 90 | for iter = 1, modelParams.numEpochs * modelParams.numIterPerEpoch do 91 | -- forward and backward propagation 92 | model:trainIteration(dataloader); 93 | 94 | -- evaluate on val and save model 
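-- with the default -saveIter 2 (see opts.lua), this block fires every 2 epochs
-- and writes model_epoch_N.t7 (model weights, optimizer state and modelParams)
-- to opt.savePath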
95 | if iter % (modelParams.saveIter * modelParams.numIterPerEpoch) == 0 then 96 | local currentEpoch = iter / modelParams.numIterPerEpoch 97 | 98 | -- save model and optimization parameters 99 | torch.save(string.format(modelPath .. 'model_epoch_%d.t7', currentEpoch), 100 | {modelW = model.wrapperW, 101 | optims = model.optims, 102 | modelParams = modelParams}) 103 | -- validation accuracy 104 | -- model:retrieve(dataloader, 'val'); 105 | end 106 | 107 | -- print after every few iterations 108 | if iter % 100 == 0 then 109 | local currentEpoch = iter / modelParams.numIterPerEpoch; 110 | 111 | -- print current time, running average, learning rate, iteration, epoch 112 | print(string.format('[%s][Epoch:%.02f][Iter:%d][Loss:%.05f][lr:%f]', 113 | os.date(), currentEpoch, iter, runningLoss, 114 | model.optims.learningRate)) 115 | end 116 | if iter % 10 == 0 then collectgarbage(); end 117 | end 118 | 119 | -- Saving the final model 120 | torch.save(modelPath .. 'model_final.t7', {modelW = model.wrapperW:float(), 121 | modelParams = modelParams}); 122 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | -- script containing supporting code/methods 2 | local utils = {}; 3 | cjson = require 'cjson' 4 | 5 | -- right align the question tokens in 3d volume 6 | function utils.rightAlign(sequences, lengths) 7 | -- clone the sequences 8 | local rAligned = sequences:clone():fill(0); 9 | local numDims = sequences:dim(); 10 | 11 | if numDims == 3 then 12 | local M = sequences:size(3); -- maximum length of question 13 | local numImgs = sequences:size(1); -- number of images 14 | local maxCount = sequences:size(2); -- number of questions / image 15 | 16 | for imId = 1, numImgs do 17 | for quesId = 1, maxCount do 18 | -- do only for non zero sequence counts 19 | if lengths[imId][quesId] == 0 then 20 | break; 21 | end 22 | 23 | -- copy based on the sequence length 24 | rAligned[imId][quesId][{{M - lengths[imId][quesId] + 1, M}}] = 25 | sequences[imId][quesId][{{1, lengths[imId][quesId]}}]; 26 | end 27 | end 28 | else if numDims == 2 then 29 | -- handle 2 dimensional matrices as well 30 | local M = sequences:size(2); -- maximum length of question 31 | local numImgs = sequences:size(1); -- number of images 32 | 33 | for imId = 1, numImgs do 34 | -- do only for non zero sequence counts 35 | if lengths[imId] > 0 then 36 | -- copy based on the sequence length 37 | rAligned[imId][{{M - lengths[imId] + 1, M}}] = 38 | sequences[imId][{{1, lengths[imId]}}]; 39 | end 40 | end 41 | end 42 | end 43 | 44 | return rAligned; 45 | end 46 | 47 | -- translate a given tensor/table to sentence 48 | function utils.idToWords(vector, ind2word) 49 | local sentence = ''; 50 | 51 | local nextWord; 52 | for wordId = 1, vector:size(1) do 53 | if vector[wordId] > 0 then 54 | nextWord = ind2word[vector[wordId]]; 55 | sentence = sentence..' 
'..nextWord; 56 | end 57 | 58 | -- stop once the end token is reached 59 | if nextWord == '<END>' then break; end 60 | end 61 | 62 | return sentence; 63 | end 64 | 65 | -- read a json file into a lua table 66 | function utils.readJSON(fileName) 67 | local file = io.open(fileName, 'r'); 68 | local text = file:read(); 69 | file:close(); 70 | 71 | -- decode and return the information 72 | return cjson.decode(text); 73 | end 74 | 75 | -- save a lua table to a json file 76 | function utils.writeJSON(fileName, luaTable) 77 | -- serialize lua table 78 | local text = cjson.encode(luaTable) 79 | 80 | local file = io.open(fileName, 'w'); 81 | file:write(text); 82 | file:close(); 83 | end 84 | 85 | -- compute the likelihood given the gt words and predicted probabilities 86 | function utils.computeLhood(words, predProbs) 87 | -- compute the probabilities for each answer, based on its tokens 88 | -- convert to 2d matrix 89 | local predVec = predProbs:view(-1, predProbs:size(3)); 90 | local indices = words:contiguous():view(-1, 1); 91 | local mask = indices:eq(0); 92 | -- assign proxy values to avoid 0 index errors 93 | indices[mask] = 1; 94 | local logProbs = predVec:gather(2, indices); 95 | -- neutralize other values 96 | logProbs[mask] = 0; 97 | logProbs = logProbs:viewAs(words); 98 | -- sum up for each sentence 99 | logProbs = logProbs:sum(1):squeeze(); 100 | 101 | return logProbs; 102 | end 103 | 104 | -- process the scores and obtain the ranks 105 | -- input: scores for all options, ground truth positions 106 | function utils.computeRanks(scores, gtPos) 107 | -- sort in descending order - largest score gets rank 1 108 | local sorted, rankedIdx = scores:sort(2, true) 109 | 110 | -- convert from ranked_idx to ranks 111 | local ranks = rankedIdx:clone():fill(0) 112 | for i = 1, rankedIdx:size(1) do 113 | for j = 1, 100 do 114 | ranks[{i, rankedIdx[{i, j}]}] = j 115 | end 116 | end 117 | 118 | if gtPos then 119 | gtPos = gtPos:view(-1) 120 | local gtRanks = torch.LongTensor(gtPos:size(1)) 121 | for i = 1, gtPos:size(1) do 122 | gtRanks[i] = ranks[{i, gtPos[i]}] 123 | end 124 | ranks = gtRanks 125 | end 126 | 127 | return ranks:double() 128 | end 129 | 130 | -- process the ranks and print metrics 131 | function utils.processRanks(ranks) 132 | -- print the results 133 | local numQues = ranks:size(1) * ranks:size(2); 134 | 135 | local numOptions = 100; 136 | 137 | -- convert ranks to a double vector and remove zeros 138 | ranks = ranks:double():view(-1); 139 | -- none of the values should be 0, since the gt is among the options 140 | if torch.sum(ranks:le(0)) > 0 then 141 | numZero = torch.sum(ranks:le(0)); 142 | print(string.format('Warning: some ranks are zero: %d', numZero)) 143 | ranks = ranks[ranks:gt(0)]; 144 | end 145 | 146 | if torch.sum(ranks:ge(numOptions + 1)) > 0 then 147 | numGreater = torch.sum(ranks:ge(numOptions + 1)); 148 | print(string.format('Warning: some ranks are > 100: %d', numGreater)) 149 | ranks = ranks[ranks:le(numOptions + 1)]; 150 | end 151 | 152 | ------------------------------------------------ 153 | print(string.format('\tNo.
questions: %d', numQues)) 154 | print(string.format('\tr@1: %f', torch.sum(torch.le(ranks, 1))/numQues)) 155 | print(string.format('\tr@5: %f', torch.sum(torch.le(ranks, 5))/numQues)) 156 | print(string.format('\tr@10: %f', torch.sum(torch.le(ranks, 10))/numQues)) 157 | print(string.format('\tmedianR: %f', torch.median(ranks:view(-1))[1])) 158 | print(string.format('\tmeanR: %f', torch.mean(ranks))) 159 | print(string.format('\tmeanRR: %f', torch.mean(ranks:cinv()))) 160 | end 161 | 162 | return utils; 163 | -------------------------------------------------------------------------------- /vis/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | VisDial Results 5 | 6 | 15 | 16 | 17 |
18 |
19 |
20 |

21 |
22 |
23 |
24 |
25 |
26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /vis/static/main.js: -------------------------------------------------------------------------------- 1 | (function(){ 2 | function pad(num, size) { 3 | var s = "000000000000" + num; 4 | return s.substr(s.length-size); 5 | } 6 | if (!String.prototype.format) { 7 | String.prototype.format = function() { 8 | var args = arguments; 9 | return this.replace(/{(\d+)}/g, function(match, number) { 10 | return typeof args[number] != 'undefined' 11 | ? args[number] 12 | : match 13 | ; 14 | }); 15 | }; 16 | } 17 | $.get('results/results.json', function(data) { 18 | var image_root = "https://vision.ece.vt.edu/mscoco/images/val2014/"; 19 | if (data.opts.sampleWords == 0) 20 | $('#heading').html('Encoder: ' + data.opts.encoder 21 | + ', Decoder: ' + data.opts.decoder + ', Beam size: ' + data.opts.beamSize + ', Max beam length: ' + data.opts.beamLen); 22 | else 23 | $('#heading').html('Encoder: ' + data.opts.encoder 24 | + ', Decoder: ' + data.opts.decoder + ', Temperature: ' + data.opts.temperature); 25 | 26 | var html = ''; 27 | for (var i in data.data) { 28 | if (i % 4 == 0) 29 | html += "
" 30 | html += "
"// + data.data[i].image_id 31 | html += "".format(image_root, pad(parseInt(data.data[i].image_id), 12)) 32 | html += "
    " 33 | for (var j = 0; j < 10; j++) { 34 | html += "
  1. " + data.data[i].dialog[j].question + "" + data.data[i].dialog[j].answer + "
  2. " 35 | } 36 | html += "
" 37 | if (i % 4 == 3) 38 | html += "

" 39 | } 40 | $('#main').html(html); 41 | }) 42 | })(); 43 | --------------------------------------------------------------------------------