├── .gitignore
├── LICENSE.txt
├── README.md
├── convert_gpu_to_cpu.lua
├── data
│   ├── prepro.py
│   ├── prepro_img_resnet.lua
│   ├── prepro_img_vgg16.lua
│   └── prepro_utils.lua
├── dataloader.lua
├── decoders
│   ├── disc.lua
│   └── gen.lua
├── encoders
│   ├── hre-ques-hist.lua
│   ├── hre-ques-im-hist.lua
│   ├── hrea-ques-im-hist.lua
│   ├── lf-att-ques-im-hist.lua
│   ├── lf-ques-hist.lua
│   ├── lf-ques-im-hist.lua
│   ├── lf-ques-im.lua
│   ├── lf-ques.lua
│   ├── mn-att-ques-im-hist.lua
│   ├── mn-ques-hist.lua
│   └── mn-ques-im-hist.lua
├── evaluate.lua
├── generate.lua
├── model.lua
├── model_utils
│   ├── MaskFuture.lua
│   ├── MaskSoftMax.lua
│   ├── MaskTime.lua
│   ├── ReplaceZero.lua
│   ├── optim_updates.lua
│   └── weight-init.lua
├── opts.lua
├── scripts
│   └── download_model.sh
├── train.lua
├── utils.lua
└── vis
    ├── index.html
    └── static
        ├── bootstrap.min.css
        ├── jquery-3.2.1.min.js
        └── main.js
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | logs
3 | tmp
4 | checkpoints
5 |
6 | vis/results
7 |
8 | *scripts*
9 | *internal*
10 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | For visdial software
4 |
5 | Copyright (c) 2017-present, Machine Learning & Perception Lab, Georgia Tech. All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without modification,
8 | are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name of the copyright holder nor the names of its contributors
18 | may be used to endorse or promote products derived from this software
19 | without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VisDial
2 |
3 | Code for the paper
4 |
5 | **[Visual Dialog][1]**
6 | Abhishek Das, Satwik Kottur, Khushi Gupta, Avi Singh, Deshraj Yadav, José M. F. Moura, Devi Parikh, Dhruv Batra
7 | [arxiv.org/abs/1611.08669][1]
8 | [CVPR 2017][10] (Spotlight)
9 |
10 | **Visual Dialog** requires an AI agent to hold a meaningful dialog with humans in natural, conversational language about visual content. Given an image, dialog history, and a follow-up question about the image, the AI agent has to answer the question.
11 |
12 | Demo: [demo.visualdialog.org][11]
13 |
14 |
15 |
16 | This repository contains code for **training**, **evaluating** and **visualizing results** for all combinations of encoder-decoder architectures described in the paper. Specifically, we have 3 encoders: **Late Fusion** (LF), **Hierarchical Recurrent Encoder** (HRE), **Memory Network** (MN), and 2 kinds of decoding: **Generative** (G) and **Discriminative** (D).
17 |
18 |
19 |
20 | If you find this code useful, consider citing our work:
21 |
22 | ```
23 | @inproceedings{visdial,
24 | title={{V}isual {D}ialog},
25 | author={Abhishek Das and Satwik Kottur and Khushi Gupta and Avi Singh
26 | and Deshraj Yadav and Jos\'e M.F. Moura and Devi Parikh and Dhruv Batra},
27 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
28 | year={2017}
29 | }
30 | ```
31 |
32 | ## Setup
33 |
34 | All our code is implemented in [Torch][13] (Lua). Installation instructions are as follows:
35 |
36 | ```sh
37 | git clone https://github.com/torch/distro.git ~/torch --recursive
38 | cd ~/torch; bash install-deps;
39 | TORCH_LUA_VERSION=LUA51 ./install.sh
40 | ```
41 |
42 | Additionally, our code uses the following packages: [torch/torch7][14], [torch/nn][15], [torch/nngraph][16], [Element-Research/rnn][17], [torch/image][18], [lua-cjson][19], [loadcaffe][20], [torch-hdf5][25]. After Torch is installed, these can be installed/updated using:
43 |
44 | ```sh
45 | luarocks install torch
46 | luarocks install nn
47 | luarocks install nngraph
48 | luarocks install image
49 | luarocks install lua-cjson
50 | luarocks install loadcaffe
51 | luarocks install luabitop
52 | luarocks install totem
53 | ```
54 |
55 | **NOTE**: `luarocks install rnn` defaults to [torch/rnn][33]; follow these steps instead to install [Element-Research/rnn][17]:
56 |
57 | ```sh
58 | git clone https://github.com/Element-Research/rnn.git
59 | cd rnn
60 | luarocks make rocks/rnn-scm-1.rockspec
61 | ```
62 |
63 | Installation instructions for torch-hdf5 are given [here][26].
64 |
65 | **NOTE**: torch-hdf5 does not work with some versions of gcc. It is recommended that you use gcc 4.8 / gcc 4.9 with Lua 5.1 for a proper installation of torch-hdf5.
66 |
67 | ### Running on GPUs
68 |
69 | Although our code should work on CPUs, it is *highly* recommended to use GPU acceleration with [CUDA][21]. You'll also need [torch/cutorch][22], [torch/cudnn][31] and [torch/cunn][23].
70 |
71 | ```sh
72 | luarocks install cutorch
73 | luarocks install cunn
74 | luarocks install cudnn
75 | ```
76 |
77 | ## Training your own network
78 |
79 | ### Preprocessing VisDial
80 |
81 | The preprocessing script is in Python, and you'll need to install [NLTK][24].
82 |
83 | ```sh
84 | pip install nltk
85 | pip install numpy
86 | pip install h5py
87 | python -c "import nltk; nltk.download('all')"
88 | ```
89 |
90 | The [VisDial v1.0][27] dataset can be downloaded and preprocessed as specified below. The path provided as `-image_root` must contain four subdirectories: [`train2014`][34] and [`val2014`][35] from the COCO dataset, and `VisualDialog_val2018` and `VisualDialog_test2018`, which can be downloaded from [here][27].
91 |
92 | ```sh
93 | cd data
94 | python prepro.py -download -image_root /path/to/images
95 | cd ..
96 | ```
97 |
98 | To download and preprocess the [VisDial v0.9][27] dataset, pass an extra `-version 0.9` argument when running the script.
99 |
100 | This script will generate the files `data/visdial_data.h5` (contains tokenized captions, questions, answers, image indices) and `data/visdial_params.json` (contains vocabulary mappings and COCO image ids).
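
If you want to sanity-check the generated files, here is a minimal sketch using the `torch-hdf5` and `lua-cjson` packages the codebase already depends on, assuming the default output names and the default `-train_split train`:

```lua
require 'hdf5'
local cjson = require 'cjson'

-- dialog data: e.g. 'ques_train' is (num dialogs) x (10 rounds) x (max ques length)
local quesFile = hdf5.open('data/visdial_data.h5', 'r')
print(quesFile:read('ques_train'):all():size())
quesFile:close()

-- vocabulary mappings and per-split image paths
local f = io.open('data/visdial_params.json', 'r')
local params = cjson.decode(f:read('*a')); f:close()
print(#params['unique_img_train'] .. ' train image paths')
```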
101 |
102 |
103 | ### Extracting image features
104 |
105 | Since we don't finetune the CNN, training is significantly faster if image features are pre-extracted. Currently this repository provides support for extraction from VGG-16 and ResNets. We use image features from [VGG-16][28]. The VGG-16 model can be downloaded and features extracted using:
106 |
107 | ```sh
108 | sh scripts/download_model.sh vgg 16 # works for 19 as well
109 | cd data
110 | # For all models except mn-att-ques-im-hist
111 | th prepro_img_vgg16.lua -imageRoot /path/to/images -gpuid 0
112 | # For mn-att-ques-im-hist
113 | th prepro_img_vgg16.lua -imageRoot /path/to/images -imgSize 448 -layerName pool5 -gpuid 0
114 | ```
115 |
116 | Alternatively, [ResNet models][32] released by Facebook can be used for feature extraction, in the same manner as VGG-16:
117 |
118 | ```sh
119 | sh scripts/download_model.sh resnet 200 # works for 18, 34, 50, 101, 152 as well
120 | cd data
121 | th prepro_img_resnet.lua -imageRoot /path/to/images -cnnModel /path/to/t7/model -gpuid 0
122 | ```
123 |
124 | Running either of these should generate `data/data_img.h5` containing features for `train`, `val` and `test` splits corresponding to VisDial v1.0.
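
As a quick sanity check of the extracted features (a sketch, assuming the default `-outName data_img.h5`), the dataset names are the same ones `dataloader.lua` reads:

```lua
require 'hdf5'
local imgFile = hdf5.open('data/data_img.h5', 'r')
-- relu7 features are N x 4096; pool5 features (with -imgSize 448) are N x 512 x 14 x 14
print(imgFile:read('/images_train'):all():size())
imgFile:close()
```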
125 |
126 |
127 | ### Training
128 |
129 | Finally, we can get to training models! All supported encoders are in the `encoders/` folder (`lf-ques`, `lf-ques-im`, `lf-ques-hist`, `lf-ques-im-hist`, `hre-ques-hist`, `hre-ques-im-hist`, `hrea-ques-im-hist`, `mn-ques-hist`, `mn-ques-im-hist`, `mn-att-ques-im-hist`), and decoders in the `decoders/` folder (`gen` and `disc`).
130 |
131 | **Generative** (`gen`) decoding tries to maximize likelihood of ground-truth response and only has access to single input-output pairs of dialog, while **discriminative** (`disc`) decoding makes use of 100 candidate option responses provided for every round of dialog, and maximizes likelihood of correct option.
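
Schematically, the two decoders reduce to different criteria over different outputs. The sketch below is illustrative only (hypothetical sizes, not the repository's decoder code):

```lua
require 'nn'

local batchSize, numOptions, vocabSize, seqLen = 4, 100, 1000, 8  -- hypothetical sizes

-- Discriminative: one score per candidate option, cross-entropy over the 100 options
local optionScores = torch.randn(batchSize, numOptions)
local gtOption = torch.LongTensor(batchSize):random(1, numOptions)
print(nn.CrossEntropyCriterion():forward(optionScores, gtOption))

-- Generative: per-token log-probabilities over the vocabulary,
-- negative log-likelihood of the ground-truth answer tokens
local logProbs = nn.LogSoftMax():forward(torch.randn(seqLen, vocabSize))
local gtTokens = torch.LongTensor(seqLen):random(1, vocabSize)
print(nn.ClassNLLCriterion():forward(logProbs, gtTokens))
```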
132 |
133 | Encoders and decoders can be arbitrarily plugged together. For example, to train an HRE model with question and history information only (no images), and generative decoding:
134 |
135 | ```sh
136 | th train.lua -encoder hre-ques-hist -decoder gen -gpuid 0
137 | ```
138 |
139 | Similarly, to train a Memory Network model with question, image and history information, and discriminative decoding:
140 |
141 | ```sh
142 | th train.lua -encoder mn-ques-im-hist -decoder disc -gpuid 0
143 | ```
144 |
145 | **Note:** For attention-based encoders, set both the `imgSpatialSize` and `imgFeatureSize` command-line params; feature dimensions are interpreted as `(batch X spatial X spatial X feature)`. For other encoders, `imgSpatialSize` is redundant.
146 |
147 | The training script saves model snapshots at regular intervals in the `checkpoints/` folder.
148 |
149 | It takes about 15-20 epochs to train models with generative decoding to convergence, and 4-8 epochs for discriminative decoding.
150 |
151 | ## Evaluation
152 |
153 | We evaluate model performance by how highly it ranks the human response among 100 candidate response options for every round of dialog, using retrieval metrics: mean reciprocal rank (MRR), R@1, R@5, R@10 and mean rank.
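
For concreteness, a toy sketch (not `evaluate.lua` itself) of how these metrics follow from the rank of the ground-truth answer among the 100 options in each round:

```lua
require 'torch'

-- hypothetical ranks of the ground-truth answer for five rounds of dialog
local ranks = torch.DoubleTensor({1, 3, 12, 2, 7})
print('MRR       : ' .. torch.ones(ranks:size()):cdiv(ranks):mean())
print('R@5       : ' .. ranks:le(5):double():mean())
print('Mean rank : ' .. ranks:mean())
```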
154 |
155 | Model evaluation can be run using:
156 |
157 | ```sh
158 | th evaluate.lua -loadPath checkpoints/model.t7 -gpuid 0
159 | ```
160 |
161 | Note that evaluation requires image features `data/data_img.h5`, tokenized dialogs `data/visdial_data.h5` and vocabulary mappings `data/visdial_params.json`.
162 |
163 | ## Running Beam Search & Visualizing Results
164 |
165 | We also include code for running beam search on your model snapshots. This gives significantly nicer results than argmax decoding, and can be run as follows:
166 |
167 | ```sh
168 | th generate.lua -loadPath checkpoints/model.t7 -maxThreads 50
169 | ```
170 |
171 | This would compute predictions for 50 threads from the `val` split and save results in `vis/results/results.json`.
172 |
173 | ```sh
174 | cd vis
175 | # python 3.6
176 | python -m http.server
177 | # python 2.7
178 | # python -m SimpleHTTPServer
179 | ```
180 |
181 | Now visit `localhost:8000` in your browser to see generated results.
182 |
183 | Sample results from HRE-QIH-G are available [here](https://computing.ece.vt.edu/~abhshkdz/visdial/browse_results/).
184 |
185 | 
186 |
187 | ## Download Extracted Features & Pretrained Models
188 |
189 | ### v0.9
190 |
191 | Extracted features for v0.9 train and val are available for download.
192 |
193 | * [`visdial_data.h5`](https://s3.amazonaws.com/visual-dialog/data/v0.9/visdial_data.h5): Tokenized captions, questions, answers, image indices
194 | * [`visdial_params.json`](https://s3.amazonaws.com/visual-dialog/data/v0.9/visdial_params.json): Vocabulary mappings and COCO image ids
195 | * [`data_img_vgg16_relu7.h5`](https://s3.amazonaws.com/visual-dialog/data/v0.9/data_img_vgg16_relu7.h5): VGG16 `relu7` image features
196 | * [`data_img_vgg16_pool5.h5`](https://s3.amazonaws.com/visual-dialog/data/v0.9/data_img_vgg16_pool5.h5): VGG16 `pool5` image features
197 |
198 | #### Pretrained models
199 |
200 | Trained on v0.9 `train`, results on v0.9 `val`.
201 |
202 |
277 |
278 | ### v1.0
279 |
280 | Extracted features for v1.0 train, val and test are available for download.
281 |
282 | * [`visdial_data_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_data_train.h5): Tokenized captions, questions, answers, image indices, for training on `train`
283 | * [`visdial_params_train.json`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_params_train.json): Vocabulary mappings and COCO image ids for training on `train`
284 | * [`data_img_vgg16_relu7_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_relu7_train.h5): VGG16 `relu7` image features for training on `train`
285 | * [`data_img_vgg16_pool5_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_pool5_train.h5): VGG16 `pool5` image features for training on `train`
286 | * [`visdial_data_trainval.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_data_trainval.h5): Tokenized captions, questions, answers, image indices, for training on `train`+`val`
287 | * [`visdial_params_trainval.json`](https://s3.amazonaws.com/visual-dialog/data/v1.0/visdial_params_trainval.json): Vocabulary mappings and COCO image ids for training on `train`+`val`
288 | * [`data_img_vgg16_relu7_trainval.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_relu7_trainval.h5): VGG16 `relu7` image features for training on `train`+`val`
289 | * [`data_img_vgg16_pool5_trainval.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/data_img_vgg16_pool5_trainval.h5): VGG16 `pool5` image features for training on `train`+`val`
290 |
291 | #### Pretrained models
292 |
293 | Trained on v1.0 `train` + v1.0 `val`, results on v1.0 `test-std`. Leaderboard [here][evalai-leaderboard].
294 |
295 |
334 |
335 | ## License
336 |
337 | BSD
338 |
339 |
340 | [1]: https://arxiv.org/abs/1611.08669
341 | [2]: https://abhishekdas.com
342 | [3]: https://satwikkottur.github.io
343 | [4]: http://www.linkedin.com/in/khushi-gupta-9a678448
344 | [5]: http://people.eecs.berkeley.edu/~avisingh/
345 | [6]: http://deshraj.github.io
346 | [7]: http://users.ece.cmu.edu/~moura/
347 | [8]: http://www.cc.gatech.edu/~parikh/
348 | [9]: http://www.cc.gatech.edu/~dbatra
349 | [10]: http://cvpr2017.thecvf.com/
350 | [11]: http://demo.visualdialog.org
351 | [12]: https://vimeo.com/193092429
352 | [13]: http://torch.ch/
353 | [14]: https://github.com/torch/torch7
354 | [15]: https://github.com/torch/nn
355 | [16]: https://github.com/torch/nngraph
356 | [17]: https://github.com/Element-Research/rnn/
357 | [18]: https://github.com/torch/image
358 | [19]: https://luarocks.org/modules/luarocks/lua-cjson
359 | [20]: https://github.com/szagoruyko/loadcaffe
360 | [21]: https://developer.nvidia.com/cuda-toolkit
361 | [22]: https://github.com/torch/cutorch
362 | [23]: https://github.com/torch/cunn
363 | [24]: http://www.nltk.org/
364 | [25]: https://github.com/deepmind/torch-hdf5
365 | [26]: https://github.com/deepmind/torch-hdf5/blob/master/doc/usage.md
366 | [27]: https://visualdialog.org/data
367 | [28]: http://www.robots.ox.ac.uk/~vgg/research/very_deep/
368 | [31]: https://www.github.com/soumith/cudnn.torch
369 | [32]: https://github.com/facebook/fb.resnet.torch/tree/master/pretrained
370 | [33]: https://github.com/torch/rnn
371 | [34]: http://images.cocodataset.org/zips/train2014.zip
372 | [35]: http://images.cocodataset.org/zips/val2014.zip
373 | [evalai-leaderboard]: https://evalai.cloudcv.org/web/challenges/challenge-page/103/leaderboard/298
374 |
--------------------------------------------------------------------------------
/convert_gpu_to_cpu.lua:
--------------------------------------------------------------------------------
1 | require 'torch';
2 | require 'nn';
3 |
4 | cmd = torch.CmdLine()
5 | cmd:option('-loadPath', 'checkpoints/model.t7')
6 | cmd:option('-savePath', 'checkpoints/model_cpu.t7')
7 | cmd:option('-gpuid', 0)
8 | cmd:option('-backend', 'nn', 'nn|cudnn')
9 | opt = cmd:parse(arg)
10 |
11 | -- check for new save path
12 | if opt.savePath == 'checkpoints/model_cpu.t7' then
13 | opt.savePath = opt.loadPath .. '.cpu.t7'
14 | end
15 |
16 | print(opt)
17 |
18 | if opt.gpuid >= 0 then
19 | require 'cutorch'
20 | require 'cunn'
21 | if opt.backend == 'cudnn' then require 'cudnn' end
22 | cutorch.setDevice(opt.gpuid+1)
23 | torch.setdefaulttensortype('torch.CudaTensor');
24 | else
25 | print('Gotta have a GPU to convert to CPU :(')
26 | os.exit()
27 | end
28 |
29 | print('Loading model')
30 | model = torch.load(opt.loadPath)
31 |
32 | -- convert modelW and optims to cpu
33 | print('Shipping params to CPU')
34 | if model.modelW:type() == 'torch.CudaTensor' then
35 | model.modelW = model.modelW:float()
36 | end
37 |
38 | for k,v in pairs(model.optims) do
39 | if torch.type(v) ~= 'number' and v:type() == 'torch.CudaTensor' then
40 | model.optims[k] = v:float()
41 | end
42 | end
43 |
44 | print('Saving to ' .. opt.savePath)
45 | torch.save(opt.savePath, model)
46 |
--------------------------------------------------------------------------------
/data/prepro.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import h5py
4 | import json
5 | import os
6 | import numpy as np
7 | from nltk.tokenize import word_tokenize
8 | from tqdm import tqdm
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('-download', action='store_true', help='Whether to download VisDial data')
13 | parser.add_argument('-version', default='1.0', choices=['0.9', '1.0'], help='Version of VisDial to be downloaded')
14 | parser.add_argument('-train_split', default='train', help='Choose the data split: train | trainval', choices=['train', 'trainval'])
15 |
16 | # Input files
17 | parser.add_argument('-input_json_train', default='visdial_1.0_train.json', help='Input `train` json file')
18 | parser.add_argument('-input_json_val', default='visdial_1.0_val.json', help='Input `val` json file')
19 | parser.add_argument('-input_json_test', default='visdial_1.0_test.json', help='Input `test` json file')
20 | parser.add_argument('-image_root', default='/path/to/images', help='Path to coco and VisDial val/test images')
21 | parser.add_argument('-input_vocab', default=False, help='Optional vocab file; similar to visdial_params.json')
22 |
23 | # Output files
24 | parser.add_argument('-output_json', default='visdial_params.json', help='Output json file')
25 | parser.add_argument('-output_h5', default='visdial_data.h5', help='Output hdf5 file')
26 |
27 | # Options
28 | parser.add_argument('-max_ques_len', default=20, type=int, help='Max length of questions')
29 | parser.add_argument('-max_ans_len', default=20, type=int, help='Max length of answers')
30 | parser.add_argument('-max_cap_len', default=40, type=int, help='Max length of captions')
31 | parser.add_argument('-word_count_threshold', default=5, type=int, help='Min threshold of word count to include in vocabulary')
32 |
33 |
34 | def tokenize_data(data, word_count=False):
35 | """Tokenize captions, questions and answers, maintain word count
36 | if required.
37 | """
38 | word_counts = {}
39 | dialogs = data['data']['dialogs']
40 |     # dialogs is a nested structure (list of dicts); this is a reference, not a copy
41 |
42 | print("[%s] Tokenizing captions..." % data['split'])
43 | for i, dialog in enumerate(tqdm(dialogs)):
44 | caption = word_tokenize(dialog['caption'])
45 | dialogs[i]['caption_tokens'] = caption
46 |
47 | print("[%s] Tokenizing questions and answers..." % data['split'])
48 | q_tokens, a_tokens = [], []
49 | for q in tqdm(data['data']['questions']):
50 | q_tokens.append(word_tokenize(q + '?'))
51 |
52 | for a in tqdm(data['data']['answers']):
53 | a_tokens.append(word_tokenize(a))
54 | data['data']['question_tokens'] = q_tokens
55 | data['data']['answer_tokens'] = a_tokens
56 |
57 | print("[%s] Filling missing values in dialog, if any..." % data['split'])
58 | for i, dialog in enumerate(tqdm(dialogs)):
59 | # last round of dialog will not have answer for test split
60 | if 'answer' not in dialog['dialog'][-1]:
61 | dialog['dialog'][-1]['answer'] = -1
62 | # right-pad dialog with empty question-answer pairs at the end
63 | dialog['num_rounds'] = len(dialog['dialog'])
64 | while len(dialog['dialog']) < 10:
65 | dialog['dialog'].append({'question': -1, 'answer': -1})
66 | dialogs[i] = dialog
67 |
68 | if word_count:
69 | print("[%s] Building word counts from tokens..." % data['split'])
70 | for i, dialog in enumerate(tqdm(dialogs)):
71 | caption = dialogs[i]['caption_tokens']
72 | all_qa = []
73 | for j in range(10):
74 | all_qa += q_tokens[dialog['dialog'][j]['question']]
75 | all_qa += a_tokens[dialog['dialog'][j]['answer']]
76 | for word in caption + all_qa:
77 | word_counts[word] = word_counts.get(word, 0) + 1
78 | print('\n')
79 | return data, word_counts
80 |
81 |
82 | def encode_vocab(data, word2ind):
83 | """Converts string tokens to indices based on given dictionary."""
84 | dialogs = data['data']['dialogs']
85 | print("[%s] Encoding caption tokens..." % data['split'])
86 | for i, dialog in enumerate(tqdm(dialogs)):
87 | dialogs[i]['caption_tokens'] = [word2ind.get(word, word2ind['UNK']) \
88 | for word in dialog['caption_tokens']]
89 |
90 | print("[%s] Encoding question and answer tokens..." % data['split'])
91 | q_tokens = data['data']['question_tokens']
92 | a_tokens = data['data']['answer_tokens']
93 |
94 | for i, q in enumerate(tqdm(q_tokens)):
95 | q_tokens[i] = [word2ind.get(word, word2ind['UNK']) for word in q]
96 |
97 | for i, a in enumerate(tqdm(a_tokens)):
98 | a_tokens[i] = [word2ind.get(word, word2ind['UNK']) for word in a]
99 |
100 | data['data']['question_tokens'] = q_tokens
101 | data['data']['answer_tokens'] = a_tokens
102 | return data
103 |
104 |
105 | def create_data_mats(data, params, dtype):
106 | num_threads = len(data['data']['dialogs'])
107 | data_mats = {}
108 | data_mats['img_pos'] = np.arange(num_threads, dtype=np.int)
109 |
110 | print("[%s] Creating caption data matrices..." % data['split'])
111 | max_cap_len = params.max_cap_len
112 | captions = np.zeros([num_threads, max_cap_len])
113 | caption_len = np.zeros(num_threads, dtype=np.int)
114 |
115 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])):
116 | caption_len[i] = len(dialog['caption_tokens'][0:max_cap_len])
117 | captions[i][0:caption_len[i]] = dialog['caption_tokens'][0:max_cap_len]
118 | data_mats['cap_length'] = caption_len
119 | data_mats['cap'] = captions
120 |
121 | print("[%s] Creating question and answer data matrices..." % data['split'])
122 | num_rounds = 10
123 | max_ques_len = params.max_ques_len
124 | max_ans_len = params.max_ans_len
125 |
126 | ques = np.zeros([num_threads, num_rounds, max_ques_len])
127 | ans = np.zeros([num_threads, num_rounds, max_ans_len])
128 | ques_length = np.zeros([num_threads, num_rounds], dtype=np.int)
129 | ans_length = np.zeros([num_threads, num_rounds], dtype=np.int)
130 |
131 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])):
132 | for j in range(num_rounds):
133 | if dialog['dialog'][j]['question'] != -1:
134 | ques_length[i][j] = len(data['data']['question_tokens'][
135 | dialog['dialog'][j]['question']][0:max_ques_len])
136 | ques[i][j][0:ques_length[i][j]] = data['data']['question_tokens'][
137 | dialog['dialog'][j]['question']][0:max_ques_len]
138 | if dialog['dialog'][j]['answer'] != -1:
139 | ans_length[i][j] = len(data['data']['answer_tokens'][
140 | dialog['dialog'][j]['answer']][0:max_ans_len])
141 | ans[i][j][0:ans_length[i][j]] = data['data']['answer_tokens'][
142 | dialog['dialog'][j]['answer']][0:max_ans_len]
143 |
144 | data_mats['ques'] = ques
145 | data_mats['ans'] = ans
146 | data_mats['ques_length'] = ques_length
147 | data_mats['ans_length'] = ans_length
148 |
149 | print("[%s] Creating options data matrices..." % data['split'])
150 | # options and answer_index are 1-indexed specifically for lua
151 | options = np.ones([num_threads, num_rounds, 100])
152 | num_rounds_list = np.full(num_threads, 10)
153 |
154 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])):
155 | for j in range(num_rounds):
156 | num_rounds_list[i] = dialog['num_rounds']
157 | # v1.0 test does not have options for all dialog rounds
158 | if 'answer_options' in dialog['dialog'][j]:
159 | options[i][j] += np.array(dialog['dialog'][j]['answer_options'])
160 |
161 | data_mats['num_rounds'] = num_rounds_list
162 | data_mats['opt'] = options
163 |
164 | if dtype != 'test':
165 | print("[%s] Creating ground truth answer data matrices..." % data['split'])
166 | answer_index = np.zeros([num_threads, num_rounds])
167 | for i, dialog in enumerate(tqdm(data['data']['dialogs'])):
168 | for j in range(num_rounds):
169 | answer_index[i][j] = dialog['dialog'][j]['gt_index'] + 1
170 | data_mats['ans_index'] = answer_index
171 |
172 | options_len = np.zeros(len(data['data']['answer_tokens']), dtype=np.int)
173 | options_list = np.zeros([len(data['data']['answer_tokens']), max_ans_len])
174 |
175 | for i, ans_token in enumerate(tqdm(data['data']['answer_tokens'])):
176 | options_len[i] = len(ans_token[0:max_ans_len])
177 | options_list[i][0:options_len[i]] = ans_token[0:max_ans_len]
178 |
179 | data_mats['opt_length'] = options_len
180 | data_mats['opt_list'] = options_list
181 | return data_mats
182 |
183 |
184 | def get_image_ids(data, id2path):
185 | image_ids = [dialog['image_id'] for dialog in data['data']['dialogs']]
186 | for i, image_id in enumerate(image_ids):
187 | image_ids[i] = id2path[image_id]
188 | return image_ids
189 |
190 |
191 | if __name__ == "__main__":
192 | args = parser.parse_args()
193 |
194 | if args.download:
195 | if args.version == '1.0':
196 | os.system('wget https://www.dropbox.com/s/ix8keeudqrd8hn8/visdial_1.0_train.zip')
197 | os.system('wget https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip')
198 | elif args.version == '0.9':
199 | os.system('wget https://computing.ece.vt.edu/~abhshkdz/data/visdial/visdial_0.9_train.zip')
200 | os.system('wget https://computing.ece.vt.edu/~abhshkdz/data/visdial/visdial_0.9_val.zip')
201 | os.system('wget https://www.dropbox.com/s/o7mucbre2zm7i5n/visdial_1.0_test.zip')
202 |
203 | os.system('unzip visdial_%s_train.zip' % args.version)
204 | os.system('unzip visdial_%s_val.zip' % args.version)
205 | os.system('unzip visdial_1.0_test.zip')
206 |
207 | args.input_json_train = 'visdial_%s_train.json' % args.version
208 | args.input_json_val = 'visdial_%s_val.json' % args.version
209 | args.input_json_test = 'visdial_1.0_test.json'
210 |
211 | print('Reading json...')
212 | data_train = json.load(open(args.input_json_train, 'r'))
213 | data_val = json.load(open(args.input_json_val, 'r'))
214 | data_test = json.load(open(args.input_json_test, 'r'))
215 |
216 | # Tokenizing
217 | data_train, word_counts_train = tokenize_data(data_train, True)
218 | data_val, word_counts_val = tokenize_data(data_val, True)
219 | data_test, _ = tokenize_data(data_test)
220 |
221 | if args.input_vocab == False:
222 | word_counts_all = dict(word_counts_train)
223 | # combining the word counts of train and val splits
224 | if args.train_split == 'trainval':
225 | for word, count in word_counts_val.items():
226 | word_counts_all[word] = word_counts_all.get(word, 0) + count
227 |
228 | print('Building vocabulary...')
229 | word_counts_all['UNK'] = args.word_count_threshold
230 | vocab = [word for word in word_counts_all \
231 | if word_counts_all[word] >= args.word_count_threshold]
232 | print('Words: %d' % len(vocab))
233 | word2ind = {word: word_ind + 1 for word_ind, word in enumerate(vocab)}
234 | ind2word = {word_ind: word for word, word_ind in word2ind.items()}
235 | else:
236 | print('Loading vocab from %s...' % args.input_vocab)
237 | vocab_data = json.load(open(args.input_vocab, 'r'))
238 |
239 | word2ind = vocab_data['word2ind']
240 | for i in word2ind:
241 | word2ind[i] = int(word2ind[i])
242 |
243 | ind2word = {}
244 | for i in vocab_data['ind2word']:
245 | ind2word[int(i)] = vocab_data['ind2word'][i]
246 |
247 | print('Encoding based on vocabulary...')
248 | data_train = encode_vocab(data_train, word2ind)
249 | data_val = encode_vocab(data_val, word2ind)
250 | data_test = encode_vocab(data_test, word2ind)
251 |
252 | print('Creating data matrices...')
253 | data_mats_train = create_data_mats(data_train, args, 'train')
254 | data_mats_val = create_data_mats(data_val, args, 'val')
255 | data_mats_test = create_data_mats(data_test, args, 'test')
256 |
257 | if args.train_split == 'trainval':
258 | data_mats_trainval = {}
259 | for key in data_mats_train:
260 | data_mats_trainval[key] = np.concatenate((data_mats_train[key],
261 | data_mats_val[key]), axis = 0)
262 |
263 | print('Saving hdf5 to %s...' % args.output_h5)
264 | f = h5py.File(args.output_h5, 'w')
265 | if args.train_split == 'train':
266 | for key in data_mats_train:
267 | f.create_dataset(key + '_train', dtype='uint32', data=data_mats_train[key])
268 |
269 | for key in data_mats_val:
270 | f.create_dataset(key + '_val', dtype='uint32', data=data_mats_val[key])
271 |
272 | elif args.train_split == 'trainval':
273 | for key in data_mats_trainval:
274 | f.create_dataset(key + '_train', dtype='uint32', data=data_mats_trainval[key])
275 |
276 | for key in data_mats_test:
277 | f.create_dataset(key + '_test', dtype='uint32', data=data_mats_test[key])
278 | f.close()
279 |
280 | out = {}
281 | out['ind2word'] = ind2word
282 | out['word2ind'] = word2ind
283 |
284 | print('Preparing image paths with image_ids...')
285 | id2path = {}
286 | # NOTE: based on assumption that image_id is unique across all splits
287 | for image_path in tqdm(glob.iglob(os.path.join(args.image_root, '*', '*.jpg'))):
288 | id2path[int(image_path[-12:-4])] = '/'.join(image_path.split('/')[-2:])
289 |
290 | out['unique_img_train'] = get_image_ids(data_train, id2path)
291 | out['unique_img_val'] = get_image_ids(data_val, id2path)
292 | out['unique_img_test'] = get_image_ids(data_test, id2path)
293 | if args.train_split == 'trainval':
294 | out['unique_img_train'] += out['unique_img_val']
295 | out.pop('unique_img_val')
296 | print('Saving json to %s...' % args.output_json)
297 | json.dump(out, open(args.output_json, 'w'))
298 |
--------------------------------------------------------------------------------
/data/prepro_img_resnet.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'prepro_utils'
4 |
5 | -------------------------------------------------------------------------------
6 | -- Input arguments and options
7 | -------------------------------------------------------------------------------
8 | cmd = torch.CmdLine()
9 | cmd:text('Extract Image Features from Pretrained ResNet Models (t7 models)')
10 | cmd:text()
11 | cmd:text('Options')
12 | cmd:option('-inputJson', 'visdial_params.json', 'Path to JSON file')
13 | cmd:option('-imageRoot', '/path/to/images/', 'Path to COCO image root')
14 | cmd:option('-cnnModel', '/path/to/t7/model', 'Path to Pretrained T7 Model')
15 | cmd:option('-trainSplit', 'train', 'Which split to use: train | trainval')
16 | cmd:option('-batchSize', 50, 'Batch size')
17 |
18 | cmd:option('-outName', 'data_img.h5', 'Output name')
19 | cmd:option('-gpuid', 0, 'Which gpu to use. -1 = use CPU')
20 |
21 | cmd:option('-imgSize', 224)
22 |
23 | opt = cmd:parse(arg)
24 | print(opt)
25 |
26 | if opt.gpuid >= 0 then
27 | require 'cutorch'
28 | require 'cunn'
29 | require 'cudnn'
30 | cutorch.setDevice(opt.gpuid + 1)
31 | end
32 |
33 | -------------------------------------------------------------------------------
34 | -- Loading model and removing extra layers
35 | -------------------------------------------------------------------------------
36 | model = torch.load(opt.cnnModel);
37 | -- Remove the last fully connected (classification) layer of the model
38 | model:remove()
39 | model:evaluate()
40 |
41 | -------------------------------------------------------------------------------
42 | -- Inferring output dim
43 | -------------------------------------------------------------------------------
44 | local dummy_img = torch.DoubleTensor(1, 3, opt.imgSize, opt.imgSize)
45 |
46 | if opt.gpuid >= 0 then
47 | dummy_img = dummy_img:cuda()
48 | model = model:cuda()
49 | end
50 |
51 | model:forward(dummy_img)
52 | local ndims = model.output:squeeze():size():totable()
53 |
54 | -------------------------------------------------------------------------------
55 | -- Defining function for image preprocessing, like mean subtraction
56 | -------------------------------------------------------------------------------
57 | function preprocessFn(im)
58 | -- mean pixel for torch models trained on imagenet
59 | local meanstd = {
60 | mean = { 0.485, 0.456, 0.406 },
61 | std = { 0.229, 0.224, 0.225 },
62 | }
63 | for i = 1, 3 do
64 | im[i]:add(-meanstd.mean[i])
65 | im[i]:div(meanstd.std[i])
66 | end
67 | return im
68 | end
69 |
70 | -------------------------------------------------------------------------------
71 | -- Extract features and save to HDF
72 | -------------------------------------------------------------------------------
73 | extractFeatures(model, opt, ndims, preprocessFn)
74 |
--------------------------------------------------------------------------------
/data/prepro_img_vgg16.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'loadcaffe'
3 | require 'prepro_utils'
4 |
5 | -------------------------------------------------------------------------------
6 | -- Input arguments and options
7 | -------------------------------------------------------------------------------
8 | cmd = torch.CmdLine()
9 | cmd:text('Extract Image Features from Pretrained VGG16 Models (prototxt + caffemodel)')
10 | cmd:text()
11 | cmd:text('Options')
12 | cmd:option('-inputJson', 'visdial_params.json', 'Path to JSON file')
13 | cmd:option('-imageRoot', '/path/to/images/', 'Path to COCO image root')
14 | cmd:option('-cnnProto', 'models/vgg16/VGG_ILSVRC_16_layers_deploy.prototxt', 'Path to the CNN prototxt')
15 | cmd:option('-cnnModel', 'models/vgg16/VGG_ILSVRC_16_layers.caffemodel', 'Path to the CNN model')
16 | cmd:option('-trainSplit', 'train', 'Which split to use: train | trainval')
17 | cmd:option('-batchSize', 50, 'Batch size')
18 |
19 | cmd:option('-outName', 'data_img.h5', 'Output name')
20 | cmd:option('-gpuid', 0, 'Which gpu to use. -1 = use CPU')
21 | cmd:option('-backend', 'nn', 'nn|cudnn')
22 |
23 | cmd:option('-imgSize', 224)
24 | cmd:option('-layerName', 'relu7')
25 |
26 | opt = cmd:parse(arg)
27 | print(opt)
28 |
29 | if opt.gpuid >= 0 then
30 | require 'cutorch'
31 | require 'cunn'
32 | cutorch.setDevice(opt.gpuid + 1)
33 | end
34 |
35 | -------------------------------------------------------------------------------
36 | -- Loading model and removing extra layers
37 | -------------------------------------------------------------------------------
38 | model = loadcaffe.load(opt.cnnProto, opt.cnnModel, opt.backend);
39 |
40 | for i = #model.modules, 1, -1 do
41 | local layer = model:get(i)
42 | if layer.name == opt.layerName then break end
43 | model:remove()
44 | end
45 | model:evaluate()
46 |
47 | -------------------------------------------------------------------------------
48 | -- Inferring output dim
49 | -------------------------------------------------------------------------------
50 | local dummy_img = torch.DoubleTensor(1, 3, opt.imgSize, opt.imgSize)
51 |
52 | if opt.gpuid >= 0 then
53 | dummy_img = dummy_img:cuda()
54 | model = model:cuda()
55 | end
56 |
57 | model:forward(dummy_img)
58 | local ndims = model.output:squeeze():size():totable()
59 |
60 | -------------------------------------------------------------------------------
61 | -- Defining function for image preprocessing, like mean subtraction
62 | -------------------------------------------------------------------------------
63 | function preprocessFn(im)
64 | -- mean pixel for caffemodels trained on imagenet
65 | local meanPixel = torch.DoubleTensor({103.939, 116.779, 123.68})
66 | im = im:index(1, torch.LongTensor{3, 2, 1}):mul(255.0)
67 | meanPixel = meanPixel:view(3, 1, 1):expandAs(im)
68 | im:add(-1, meanPixel)
69 | return im
70 | end
71 |
72 | -------------------------------------------------------------------------------
73 | -- Extract features and save to HDF
74 | -------------------------------------------------------------------------------
75 | extractFeatures(model, opt, ndims, preprocessFn)
76 |
--------------------------------------------------------------------------------
/data/prepro_utils.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'math'
3 | require 'nn'
4 | require 'image'
5 | require 'xlua'
6 | require 'hdf5'
7 | cjson = require('cjson')
8 |
9 |
10 | function loadImage(imageName, imgSize, preprocessType)
11 | im = image.load(imageName)
12 |
13 | if im:size(1) == 1 then
14 | im = im:repeatTensor(3, 1, 1)
15 | elseif im:size(1) == 4 then
16 | im = im[{{1,3}, {}, {}}]
17 | end
18 |
19 | im = image.scale(im, imgSize, imgSize)
20 | return im
21 | end
22 |
23 | function extractFeaturesSplit(model, opt, ndims, preprocessFn, dtype)
24 | local file = io.open(opt.inputJson, 'r')
25 | local text = file:read()
26 | file:close()
27 | jsonFile = cjson.decode(text)
28 |
29 | local imList = {}
30 | for i, imName in pairs(jsonFile['unique_img_'..dtype]) do
31 | table.insert(imList, string.format('%s/%s', opt.imageRoot, imName))
32 | end
33 |
34 | local sz = #imList
35 | local imFeats = torch.FloatTensor(sz, unpack(ndims))
36 |
37 |     -- imFeats is either 2-D (N x D fully-connected features, e.g. relu7) or 4-D (NCHW conv features, e.g. pool5)
38 | local feature_dims = #imFeats:size()
39 |
40 | print(string.format('Processing %d %s images...', sz, dtype))
41 | for i = 1, sz, opt.batchSize do
42 | xlua.progress(i, sz)
43 | r = math.min(sz, i + opt.batchSize - 1)
44 | ims = torch.DoubleTensor(r - i + 1, 3, opt.imgSize, opt.imgSize)
45 | for j = 1, r - i + 1 do
46 | ims[j] = loadImage(imList[i + j - 1], opt.imgSize)
47 | ims[j] = preprocessFn(ims[j])
48 | end
49 | if opt.gpuid >= 0 then
50 | ims = ims:cuda()
51 | end
52 |
53 |         -- forward pass; model.output holds this batch's features
54 |         -- (when feature_dims == 4, conv features stay in NCHW format here;
55 |         -- the dataloader permutes them to NHWC for attention encoders)
56 |         model:forward(ims)
57 |
58 |         -- copy the batch features into the output tensor
59 | imFeats[{{i, r}, {}}] = model.output:float()
60 | collectgarbage()
61 | end
62 |
63 | return imFeats
64 | end
65 |
66 | function extractFeatures(model, opt, ndims, preprocessFn)
67 | local h5File = hdf5.open(opt.outName, 'w')
68 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'train')
69 | h5File:write('/images_train', imFeats)
70 | if opt.trainSplit == 'train' then
71 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'val')
72 | h5File:write('/images_val', imFeats)
73 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'test')
74 | h5File:write('/images_test', imFeats)
75 | elseif opt.trainSplit == 'trainval' then
76 | imFeats = extractFeaturesSplit(model, opt, ndims, preprocessFn, 'test')
77 | h5File:write('/images_test', imFeats)
78 | end
79 | h5File:close()
80 | end
81 |
--------------------------------------------------------------------------------
/dataloader.lua:
--------------------------------------------------------------------------------
1 | require 'hdf5'
2 | require 'xlua'
3 | local utils = require 'utils'
4 |
5 | local dataloader = {};
6 |
7 | -- read the data
8 | -- params: object itself, command line options,
9 | -- subset of data to load (train, val, test)
10 | function dataloader:initialize(opt, subsets)
11 | -- read additional info like dictionary, etc
12 | print('DataLoader loading json file: ', opt.inputJson)
13 | info = utils.readJSON(opt.inputJson);
14 | for key, value in pairs(info) do dataloader[key] = value; end
15 |
16 |     -- add <START> and <END> to vocabulary
17 |     count = 0;
18 |     for _ in pairs(dataloader['word2ind']) do count = count + 1; end
19 |     dataloader['word2ind']['<START>'] = count + 1;
20 |     dataloader['word2ind']['<END>'] = count + 2;
21 |     count = count + 2;
22 |     dataloader.vocabSize = count;
23 |     print(string.format('Vocabulary size (with <START>, <END>): %d\n', count));
24 |
25 | -- construct ind2word
26 | local ind2word = {};
27 | for word, ind in pairs(dataloader['word2ind']) do
28 | ind2word[ind] = word;
29 | end
30 | dataloader['ind2word'] = ind2word;
31 |
32 | -- read questions, answers and options
33 | print('DataLoader loading h5 file: ', opt.inputQues)
34 | local quesFile = hdf5.open(opt.inputQues, 'r');
35 |
36 | print('DataLoader loading h5 file: ', opt.inputImg)
37 | local imgFile = hdf5.open(opt.inputImg, 'r');
38 | -- number of threads
39 | self.numThreads = {};
40 |
41 | for _, dtype in pairs(subsets) do
42 | -- convert image ids to numbers
43 | for k, v in pairs(dataloader['unique_img_'..dtype]) do
44 | dataloader['unique_img_'..dtype][k] = tonumber(string.match(v, '000%d+'))
45 | end
46 |
47 | -- read question related information
48 | self[dtype..'_ques'] = quesFile:read('ques_'..dtype):all();
49 | self[dtype..'_ques_len'] = quesFile:read('ques_length_'..dtype):all();
50 |
51 | -- read answer related information
52 | self[dtype..'_ans'] = quesFile:read('ans_'..dtype):all();
53 | self[dtype..'_ans_len'] = quesFile:read('ans_length_'..dtype):all();
54 | if dtype ~= 'test' then
55 | self[dtype..'_ans_ind'] = quesFile:read('ans_index_'..dtype):all():long();
56 | end
57 |
58 | -- read image list, if image features are needed
59 | if opt.useIm then
60 | print('Reading image features..')
61 | local imgFeats = imgFile:read('/images_'..dtype):all();
62 |
63 | -- Normalize the image features (if needed)
64 | if opt.imgNorm == 1 then
65 | print('Normalizing image features..')
66 | local nm = torch.sqrt(torch.sum(torch.cmul(imgFeats, imgFeats), 2));
67 | imgFeats = torch.cdiv(imgFeats, nm:expandAs(imgFeats)):float();
68 | end
69 | -- Transpose from N x 512 x 14 x 14 to N x 14 x 14 x 512
70 | if string.match(opt.encoder, 'att') then
71 | imgFeats = imgFeats:permute(1, 3, 4, 2);
72 | end
73 | self[dtype..'_img_fv'] = imgFeats;
74 | -- TODO: make it 1 indexed in processing code
75 | -- currently zero indexed, adjust manually
76 | self[dtype..'_img_pos'] = quesFile:read('img_pos_'..dtype):all():long();
77 | self[dtype..'_img_pos'] = self[dtype..'_img_pos'] + 1;
78 | end
79 |
80 | -- print information for data type
81 | print(string.format('%s:\n\tNo. of threads: %d\n\tNo. of rounds: %d'..
82 | '\n\tMax ques len: %d'..'\n\tMax ans len: %d\n',
83 | dtype, self[dtype..'_ques']:size(1),
84 | self[dtype..'_ques']:size(2),
85 | self[dtype..'_ques']:size(3),
86 | self[dtype..'_ans']:size(3)));
87 |
88 | -- record some stats
89 | if dtype == 'train' then
90 | self.numTrainThreads = self['train_ques']:size(1);
91 | self.numThreads['train'] = self.numTrainThreads;
92 | end
93 | if dtype == 'test' then
94 | self.numTestThreads = self['test_ques']:size(1);
95 | self.numThreads['test'] = self.numTestThreads;
96 | end
97 | if dtype == 'val' then
98 | self.numValThreads = self['val_ques']:size(1);
99 | self.numThreads['val'] = self.numValThreads;
100 | end
101 |
102 | -- record the options
103 | if dtype == 'train' or dtype == 'val' or dtype == 'test' then
104 | self[dtype..'_opt'] = quesFile:read('opt_'..dtype):all():long();
105 | self[dtype..'_opt_len'] = quesFile:read('opt_length_'..dtype):all();
106 | self[dtype..'_opt_list'] = quesFile:read('opt_list_'..dtype):all();
107 | self.numOptions = self[dtype..'_opt']:size(3);
108 | end
109 |
110 | self[dtype..'_num_rounds'] = quesFile:read('num_rounds_'..dtype):all();
111 |
112 | -- assume similar stats across multiple data subsets
113 | -- maximum number of questions per image, ideally 10
114 | self.maxQuesCount = self[dtype..'_ques']:size(2);
115 | -- maximum length of question
116 | self.maxQuesLen = self[dtype..'_ques']:size(3);
117 | -- maximum length of answer
118 | self.maxAnsLen = self[dtype..'_ans']:size(3);
119 |
120 | -- if history is needed
121 | if opt.useHistory then
122 | self[dtype..'_cap'] = quesFile:read('cap_'..dtype):all():long();
123 | self[dtype..'_cap_len'] = quesFile:read('cap_length_'..dtype):all();
124 | end
125 | end
126 | -- done reading, close files
127 | quesFile:close();
128 | imgFile:close();
129 |
130 | -- take desired flags/values from opt
131 | self.useHistory = opt.useHistory;
132 | self.concatHistory = opt.concatHistory;
133 | self.useIm = opt.useIm;
134 | self.maxHistoryLen = opt.maxHistoryLen or 60;
135 |
136 | -- prepareDataset for training
137 | for _, dtype in pairs(subsets) do self:prepareDataset(dtype); end
138 | end
139 |
140 | -- method to prepare questions and answers for retrieval
141 | -- questions : right align
142 | -- answers : prefix with <START> and <END>
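-- (right aligning moves zero padding to the front: e.g. {12, 7, 9, 0, 0} with length 3 becomes {0, 0, 12, 7, 9})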
143 | function dataloader:prepareDataset(dtype)
144 | -- right align the questions
145 | print('Right aligning questions: '..dtype);
146 | self[dtype..'_ques_fwd'] = utils.rightAlign(self[dtype..'_ques'],
147 | self[dtype..'_ques_len']);
148 |
149 | -- if separate captions are needed
150 | if self.useHistory then self:processHistory(dtype); end
151 |     -- prefix options with <START> and <END>, if not train
152 | -- if dtype ~= 'train' then self:processOptions(dtype); end
153 | self:processOptions(dtype)
154 | -- process answers
155 | self:processAnswers(dtype);
156 | end
157 |
158 | -- process answers
159 | function dataloader:processAnswers(dtype)
160 |     -- prefix answers with <START>, <END>; adjust answer lengths
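    -- e.g. an answer {a1, a2, a3} gives decodeIn {<START>, a1, a2, a3} and decodeOut {a1, a2, a3, <END>}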
161 | local answers = self[dtype..'_ans'];
162 | local ansLen = self[dtype..'_ans_len'];
163 |
164 | local numConvs = answers:size(1);
165 | local numRounds = answers:size(2);
166 | local maxAnsLen = answers:size(3);
167 |
168 | local decodeIn = torch.LongTensor(numConvs, numRounds, maxAnsLen+1):zero();
169 | local decodeOut = torch.LongTensor(numConvs, numRounds, maxAnsLen+1):zero();
170 |
171 |     -- decodeIn begins with <START>
172 |     decodeIn[{{}, {}, 1}] = self.word2ind['<START>'];
173 |
174 | -- go over each answer and modify
175 |     local endTokenId = self.word2ind['<END>'];
176 | for thId = 1, numConvs do
177 | for roundId = 1, numRounds do
178 | local length = ansLen[thId][roundId];
179 |
180 | -- only if nonzero
181 | if length > 0 then
182 | decodeIn[thId][roundId][{{2, length + 1}}]
183 | = answers[thId][roundId][{{1, length}}];
184 |
185 | decodeOut[thId][roundId][{{1, length}}]
186 | = answers[thId][roundId][{{1, length}}];
187 | else
188 | if dtype ~= 'test' then
189 | print(string.format('Warning: empty answer at (%d %d %d)',
190 | thId, roundId, length))
191 | end
192 | end
193 | decodeOut[thId][roundId][length+1] = endTokenId;
194 | end
195 | end
196 |
197 | self[dtype..'_ans_len'] = self[dtype..'_ans_len'] + 1;
198 | self[dtype..'_ans_in'] = decodeIn;
199 | self[dtype..'_ans_out'] = decodeOut;
200 | end
201 |
202 | -- process caption as history
203 | function dataloader:processHistory(dtype)
204 | local captions = self[dtype..'_cap'];
205 | local questions = self[dtype..'_ques'];
206 | local quesLen = self[dtype..'_ques_len'];
207 | local capLen = self[dtype..'_cap_len'];
208 | local maxQuesLen = questions:size(3);
209 |
210 | local answers = self[dtype..'_ans'];
211 | local ansLen = self[dtype..'_ans_len'];
212 | local numConvs = answers:size(1);
213 | local numRounds = answers:size(2);
214 | local maxAnsLen = answers:size(3);
215 |
216 | local history, histLen;
217 | if self.concatHistory == true then
218 | self.maxHistoryLen = math.min(numRounds * (maxQuesLen + maxAnsLen), 300);
219 |
220 | history = torch.LongTensor(numConvs, numRounds,
221 | self.maxHistoryLen):zero();
222 | histLen = torch.LongTensor(numConvs, numRounds):zero();
223 | else
224 | history = torch.LongTensor(numConvs, numRounds,
225 | maxQuesLen+maxAnsLen):zero();
226 | histLen = torch.LongTensor(numConvs, numRounds):zero();
227 | end
228 |
229 | -- go over each question and append it with answer
230 | for thId = 1, numConvs do
231 | local lenC = capLen[thId];
232 | local lenH; -- length of history
233 | for roundId = 1, numRounds do
234 | if roundId == 1 then
235 | -- first round has caption as history
236 | history[thId][roundId][{{1, maxQuesLen + maxAnsLen}}]
237 | = captions[thId][{{1, maxQuesLen + maxAnsLen}}];
238 | lenH = math.min(lenC, maxQuesLen + maxAnsLen);
239 | else
240 | local lenQ = quesLen[thId][roundId-1];
241 | local lenA = ansLen[thId][roundId-1];
242 | -- if concatHistory, string together all previous QAs
243 | if self.concatHistory == true then
244 | history[thId][roundId][{{1, lenH}}]
245 | = history[thId][roundId-1][{{1, lenH}}];
246 |                     history[thId][roundId][{{lenH+1}}] = self.word2ind['<END>'];
247 | if lenQ > 0 then
248 | history[thId][roundId][{{lenH+2, lenH+1+lenQ}}]
249 | = questions[thId][roundId-1][{{1, lenQ}}];
250 | end
251 | if lenA > 0 then
252 | history[thId][roundId][{{lenH+1+lenQ+1, lenH+1+lenQ+lenA}}]
253 | = answers[thId][roundId-1][{{1, lenA}}];
254 | end
255 | lenH = lenH + lenQ + lenA + 1
256 | -- else, history is just previous round QA
257 | else
258 | if lenQ > 0 then
259 | history[thId][roundId][{{1, lenQ}}]
260 | = questions[thId][roundId-1][{{1, lenQ}}];
261 | end
262 | if lenA > 0 then
263 | history[thId][roundId][{{lenQ + 1, lenQ + lenA}}]
264 | = answers[thId][roundId-1][{{1, lenA}}];
265 | end
266 | lenH = lenA + lenQ;
267 | end
268 | end
269 | -- save the history length
270 | histLen[thId][roundId] = lenH;
271 | end
272 | end
273 |
274 | -- right align history and then save
275 | print('Right aligning history: '..dtype);
276 | self[dtype..'_hist'] = utils.rightAlign(history, histLen);
277 | self[dtype..'_hist_len'] = histLen;
278 | end
279 |
280 | -- process options
281 | function dataloader:processOptions(dtype)
282 | local lengths = self[dtype..'_opt_len'];
283 | local answers = self[dtype..'_ans'];
284 | local maxAnsLen = answers:size(3);
285 | local answers = self[dtype..'_opt_list'];
286 | local numConvs = answers:size(1);
287 |
288 | local ansListLen = answers:size(1);
289 | local decodeIn = torch.LongTensor(ansListLen, maxAnsLen + 1):zero();
290 | local decodeOut = torch.LongTensor(ansListLen, maxAnsLen + 1):zero();
291 |
292 |     -- decodeIn begins with <START>
293 |     decodeIn[{{}, 1}] = self.word2ind['<START>'];
294 |
295 | -- go over each answer and modify
296 |     local endTokenId = self.word2ind['<END>'];
297 | for id = 1, ansListLen do
298 | -- print progress for number of images
299 | if id % 100 == 0 then
300 | xlua.progress(id, numConvs);
301 | end
302 | local length = lengths[id];
303 |
304 | -- only if nonzero
305 | if length > 0 then
306 | decodeIn[id][{{2, length + 1}}] = answers[id][{{1, length}}];
307 |
308 | decodeOut[id][{{1, length}}] = answers[id][{{1, length}}];
309 | decodeOut[id][length + 1] = endTokenId;
310 | else
311 | print(string.format('Warning: empty answer for %s at %d',
312 | dtype, id))
313 | end
314 | end
315 |
316 | self[dtype..'_opt_len'] = self[dtype..'_opt_len'] + 1;
317 | self[dtype..'_opt_in'] = decodeIn;
318 | self[dtype..'_opt_out'] = decodeOut;
319 |
320 | collectgarbage();
321 | end
322 |
323 | -- method to grab the next training batch
324 | function dataloader.getTrainBatch(self, params, batchSize)
325 | local size = batchSize or params.batchSize;
326 | local inds = torch.LongTensor(size):random(1, params.numTrainThreads);
327 |
328 | -- Index question, answers, image features for batch
329 | local batchOutput = self:getIndexData(inds, params, 'train')
330 | if params.decoder == 'disc' then
331 | local optionOutput = self:getIndexOption(inds, params, 'train')
332 | batchOutput['options'] = optionOutput:view(optionOutput:size(1)
333 | * optionOutput:size(2), optionOutput:size(3), -1)
334 | batchOutput['answer_ind'] = batchOutput['answer_ind']:view(batchOutput['answer_ind']
335 | :size(1) * batchOutput['answer_ind']:size(2))
336 | end
337 |
338 | return batchOutput
339 | end
340 |
341 | -- method to grab the next test/val batch, for evaluation of a given size
342 | function dataloader.getTestBatch(self, startId, params, dtype)
343 | local batchSize = params.batchSize
344 | -- get the next start id and fill up current indices till then
345 | local nextStartId;
346 | if dtype == 'val' then
347 | nextStartId = math.min(self.numValThreads+1, startId + batchSize);
348 | end
349 | if dtype == 'test' then
350 | nextStartId = math.min(self.numTestThreads+1, startId + batchSize);
351 | end
352 |
353 | -- dumb way to get range (complains if cudatensor is default)
354 | local inds = torch.LongTensor(nextStartId - startId);
355 | for ii = startId, nextStartId - 1 do inds[ii - startId + 1] = ii; end
356 |
357 | -- Index question, answers, image features for batch
358 | local batchOutput = self:getIndexData(inds, params, dtype);
359 | local optionOutput = self:getIndexOption(inds, params, dtype);
360 |
361 | if params.decoder == 'disc' then
362 | batchOutput['options'] = optionOutput:view(optionOutput:size(1)
363 | * optionOutput:size(2), optionOutput:size(3), -1)
364 | if dtype ~= 'test' then
365 | batchOutput['answer_ind'] = batchOutput['answer_ind']:view(batchOutput['answer_ind']
366 | :size(1) * batchOutput['answer_ind']:size(2))
367 | end
368 | elseif params.decoder == 'gen' then
369 | -- merge both the tables and return
370 | for key, value in pairs(optionOutput) do batchOutput[key] = value; end
371 | end
372 | batchOutput['num_rounds'] = self[dtype..'_num_rounds']:index(1, inds):long()
373 |
374 | return batchOutput, nextStartId;
375 | end
376 |
377 | -- get batch from data subset given the indices
378 | function dataloader.getIndexData(self, inds, params, dtype)
379 | -- get the question lengths
380 | local batchQuesLen = self[dtype..'_ques_len']:index(1, inds);
381 | local maxQuesLen = torch.max(batchQuesLen);
382 | -- get questions
383 | local quesFwd = self[dtype..'_ques_fwd']:index(1, inds)
384 | [{{}, {}, {-maxQuesLen, -1}}];
385 |
386 | local history;
387 | if self.useHistory then
388 | local batchHistLen = self[dtype..'_hist_len']:index(1, inds);
389 | local maxHistLen = math.min(torch.max(batchHistLen), self.maxHistoryLen);
390 | history = self[dtype..'_hist']:index(1, inds)
391 | [{{}, {}, {-maxHistLen, -1}}];
392 | end
393 |
394 | local imgFeats;
395 | if self.useIm then
396 | local imgInds = self[dtype..'_img_pos']:index(1, inds);
397 | imgFeats = self[dtype..'_img_fv']:index(1, imgInds);
398 | end
399 |
400 | -- get the answer lengths
401 | local batchAnsLen = self[dtype..'_ans_len']:index(1, inds);
402 | local maxAnsLen = torch.max(batchAnsLen);
403 | -- answer labels (decode input and output)
404 | local answerIn = self[dtype..'_ans_in']
405 | :index(1, inds)[{{}, {}, {1, maxAnsLen}}];
406 | local answerOut = self[dtype..'_ans_out']
407 | :index(1, inds)[{{}, {}, {1, maxAnsLen}}];
408 |
409 | local output = {};
410 | if params.gpuid >= 0 then
411 | output['ques_fwd'] = quesFwd:cuda();
412 | output['answer_in'] = answerIn:cuda();
413 | output['answer_out'] = answerOut:cuda();
414 | if history then output['hist'] = history:cuda(); end
415 | if caption then output['cap'] = caption:cuda(); end
416 | if imgFeats then output['img_feat'] = imgFeats:cuda(); end
417 | else
418 | output['ques_fwd'] = quesFwd:contiguous();
419 | output['answer_in'] = answerIn:contiguous();
420 | output['answer_out'] = answerOut:contiguous();
421 | if history then output['hist'] = history:contiguous(); end
422 | if caption then output['cap'] = caption:contiguous(); end
423 | if imgFeats then output['img_feat'] = imgFeats:contiguous(); end
424 | end
425 |
426 | if dtype ~= 'test' then
427 | local answerInd = self[dtype..'_ans_ind']:index(1, inds);
428 | output['answer_ind'] = params.gpuid >= 0 and answerInd:cuda() or answerInd:contiguous();
429 | end
430 |
431 | return output;
432 | end
433 |
434 | -- get batch from options given the indices
435 | function dataloader.getIndexOption(self, inds, params, dtype)
436 | local output = {};
437 | if params.decoder == 'gen' then
438 | local optionIn, optionOut
439 |
440 | local optInds = self[dtype..'_opt']:index(1, inds);
441 | local indVector = optInds:view(-1);
442 |
443 | local batchOptLen = self[dtype..'_opt_len']:index(1, indVector);
444 | local maxOptLen = torch.max(batchOptLen);
445 |
446 | optionIn = self[dtype..'_opt_in']:index(1, indVector);
447 | optionIn = optionIn:view(optInds:size(1), optInds:size(2),
448 | optInds:size(3), -1);
449 | optionIn = optionIn[{{}, {}, {}, {1, maxOptLen}}];
450 |
451 | optionOut = self[dtype..'_opt_out']:index(1, indVector);
452 | optionOut = optionOut:view(optInds:size(1), optInds:size(2),
453 | optInds:size(3), -1);
454 | optionOut = optionOut[{{}, {}, {}, {1, maxOptLen}}];
455 |
456 | if params.gpuid >= 0 then
457 | output['option_in'] = optionIn:cuda();
458 | output['option_out'] = optionOut:cuda();
459 | else
460 | output['option_in'] = optionIn:contiguous();
461 | output['option_out'] = optionOut:contiguous();
462 | end
463 | elseif params.decoder == 'disc' then
464 | local optInds = self[dtype .. '_opt']:index(1, inds)
465 | local indVector = optInds:view(-1)
466 |
467 | local optionIn = self[dtype .. '_opt_list']:index(1, indVector)
468 |
469 | optionIn = optionIn:view(optInds:size(1), optInds:size(2), optInds:size(3), -1)
470 | output = optionIn
471 |
472 | if params.gpuid >= 0 then
473 | output = output:cuda()
474 | end
475 | end
476 |
477 | return output;
478 | end
479 |
480 | return dataloader;
481 |
--------------------------------------------------------------------------------
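
A note on the option handling above: for the generative decoder, `getIndexOption` flattens the per-round option ids, gathers all option token sequences in a single `index` call, and views the result back to batch x rounds x options x length. A minimal sketch of that gather-and-reshape pattern, with made-up sizes and random stand-ins for the real preprocessed tensors:

```lua
-- Hypothetical sizes; optList / optInds stand in for the tensors read from the h5 files.
require 'torch'

local numThreads, numRounds, numOptions, maxLen = 4, 10, 100, 20
local optList = torch.LongTensor(5000, maxLen):random(1, 7000)           -- unique option token sequences
local optInds = torch.LongTensor(numThreads, numRounds, numOptions):random(1, 5000)

local indVector = optInds:view(-1)                           -- flatten to (4*10*100) ids
local gathered  = optList:index(1, indVector)                -- (4*10*100) x maxLen tokens
local options   = gathered:view(optInds:size(1), optInds:size(2),
                                optInds:size(3), -1)         -- 4 x 10 x 100 x 20
print(options:size())
```
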
/decoders/disc.lua:
--------------------------------------------------------------------------------
1 | local decoderNet = {}
2 |
3 | function decoderNet.model(params, enc)
4 | local optionLSTM = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
5 | optionLSTM.batchfirst = true
6 |
7 | local optionEnc = {}
8 | local numOptions = 100
9 | for i = 1, numOptions do
10 | optionEnc[i] = nn.Sequential()
11 | optionEnc[i]:add(nn.Select(2,i))
12 | optionEnc[i]:add(enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'));
13 | optionEnc[i]:add(optionLSTM:clone('weight', 'bias', 'gradWeight', 'gradBias'))
14 | optionEnc[i]:add(nn.Select(2,-1))
15 | optionEnc[i]:add(nn.Reshape(1,params.rnnHiddenSize, true)) -- True ensures that the first dimension remains the batch size
16 | end
17 | local optionEncConcat = nn.Concat(2)
18 | for i = 1, numOptions do
19 | optionEncConcat:add(optionEnc[i])
20 | end
21 |
22 | local jointModel = nn.ParallelTable()
23 | jointModel:add(optionEncConcat)
24 | jointModel:add(nn.Reshape(params.rnnHiddenSize, 1, true))
25 |
26 | local dec = nn.Sequential()
27 | dec:add(jointModel)
28 | dec:add(nn.MM())
29 | dec:add(nn.Squeeze())
30 |
31 | return dec;
32 | end
33 |
34 | -- dummy forwardConnect
35 | function decoderNet.forwardConnect(enc, dec, encOut, seqLen) end
36 |
37 | -- dummy backwardConnect
38 | function decoderNet.backwardConnect(enc, dec) end
39 |
40 | return decoderNet;
41 |
--------------------------------------------------------------------------------
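
The discriminative decoder encodes each of the 100 candidate answers with a shared option LSTM and scores them against the encoder output via a batched dot product (the `nn.MM` + `nn.Squeeze` at the end). A minimal sketch of just that scoring step, with hypothetical sizes and random stand-ins for the option encodings:

```lua
require 'nn'

local batch, hidden, numOptions = 2, 512, 100
local optionFeats = torch.randn(batch, numOptions, hidden)   -- stand-in for optionEncConcat's output
local encOut      = torch.randn(batch, hidden)               -- encoder summary per (image, round)

local scorer = nn.Sequential()
    :add(nn.ParallelTable()
        :add(nn.Identity())                                  -- B x 100 x H option features
        :add(nn.Reshape(hidden, 1, true)))                   -- B x H x 1 encoder state
    :add(nn.MM())                                            -- batched dot products: B x 100 x 1
    :add(nn.Squeeze())                                       -- B x 100 scores

local scores = scorer:forward({optionFeats, encOut})
print(scores:size())   -- 2 x 100; these scores feed nn.CrossEntropyCriterion in model.lua
```
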
/decoders/gen.lua:
--------------------------------------------------------------------------------
1 | local decoderNet = {}
2 |
3 | function decoderNet.model(params, enc)
4 | -- use `nngraph`
5 | nn.FastLSTM.usenngraph = true
6 |
7 | -- decoder network
8 | local dec = nn.Sequential()
9 | -- use the same embedding for both encoder and decoder lstm
10 | local embedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
11 | dec:add(embedNet);
12 |
13 | dec.rnnLayers = {};
14 | -- check if decoder has different hidden size
15 | local hiddenSize = (params.ansHiddenSize ~= 0) and params.ansHiddenSize
16 | or params.rnnHiddenSize;
17 | for layer = 1, params.numLayers do
18 | local inputSize = (layer == 1) and params.embedSize or hiddenSize;
19 | dec.rnnLayers[layer] = nn.SeqLSTM(inputSize, hiddenSize);
20 | dec.rnnLayers[layer]:maskZero();
21 | dec:add(dec.rnnLayers[layer]);
22 | end
23 | dec:add(nn.Sequencer(nn.MaskZero(nn.Linear(hiddenSize, params.vocabSize), 1)))
24 | dec:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1)))
25 |
26 | return dec;
27 | end
28 |
29 | -- transfer the hidden state from encoder to decoder
30 | function decoderNet.forwardConnect(enc, dec, encOut, seqLen)
31 | if enc.rnnLayers ~= nil then
32 | for ii = 1, #enc.rnnLayers do
33 | dec.rnnLayers[ii].userPrevOutput = enc.rnnLayers[ii].output[seqLen];
34 | dec.rnnLayers[ii].userPrevCell = enc.rnnLayers[ii].cell[seqLen];
35 | end
36 |
37 | -- the last decoder layer is seeded with the (fused) encoder output
38 | dec.rnnLayers[#enc.rnnLayers].userPrevOutput = encOut;
39 | else
40 | dec.rnnLayers[#dec.rnnLayers].userPrevOutput = encOut
41 | end
42 | end
43 |
44 | -- transfer gradients from decoder to encoder
45 | function decoderNet.backwardConnect(enc, dec)
46 | if enc.rnnLayers ~= nil then
47 | -- borrow gradients from decoder
48 | for ii = 1, #dec.rnnLayers do
49 | enc.rnnLayers[ii].userNextGradCell = dec.rnnLayers[ii].userGradPrevCell;
50 | if ii ~= #dec.rnnLayers then
51 | enc.rnnLayers[ii].gradPrevOutput = dec.rnnLayers[ii].userGradPrevOutput;
52 | end
53 | end
54 |
55 | -- return the gradients for the last layer
56 | return dec.rnnLayers[#enc.rnnLayers].userGradPrevOutput;
57 | else
58 | return dec.rnnLayers[#dec.rnnLayers].userGradPrevOutput
59 | end
60 | end
61 |
62 | -- connecting decoder to itself; useful while sampling
63 | function decoderNet.decoderConnect(dec)
64 | for ii = 1, #dec.rnnLayers do
65 | dec.rnnLayers[ii].userPrevCell = dec.rnnLayers[ii].cell[1]
66 | dec.rnnLayers[ii].userPrevOutput = dec.rnnLayers[ii].output[1]
67 | end
68 | end
69 |
70 | return decoderNet;
71 |
--------------------------------------------------------------------------------
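
`forwardConnect` seeds the decoder's SeqLSTM layers with the encoder's final time step through rnn's `userPrevOutput`/`userPrevCell` fields, and `backwardConnect` later reads the matching `userGradPrev*` fields to pass gradients back. A minimal sketch of the forward hand-off for a single layer (toy sizes, zero-masking omitted):

```lua
require 'nn'
require 'rnn'

local seqLen, batch, hidden = 5, 2, 8
local encLSTM = nn.SeqLSTM(hidden, hidden)
local decLSTM = nn.SeqLSTM(hidden, hidden)

local encOutSeq = encLSTM:forward(torch.randn(seqLen, batch, hidden))  -- seqLen x batch x hidden

-- hand the final encoder state to the decoder before its forward pass
decLSTM.userPrevOutput = encOutSeq[seqLen]
decLSTM.userPrevCell   = encLSTM.cell[seqLen]

local decOut = decLSTM:forward(torch.randn(seqLen, batch, hidden))
print(decOut:size())   -- seqLen x batch x hidden, conditioned on the encoder state
```
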
/encoders/hre-ques-hist.lua:
--------------------------------------------------------------------------------
1 | local encoderNet = {}
2 |
3 | function encoderNet.model(params)
4 | local dropout = params.dropout or 0.5
5 |
6 | -- Use `nngraph`
7 | nn.FastLSTM.usenngraph = true;
8 |
9 | -- encoder network
10 | local enc = nn.Sequential();
11 |
12 | -- create the two branches
13 | local concat = nn.ConcatTable();
14 |
15 | -- word branch, along with embedding layer
16 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
17 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
18 |
19 | -- language model
20 | enc.rnnLayers = {};
21 | for layer = 1, params.numLayers do
22 | local inputSize = (layer==1) and (params.embedSize)
23 | or params.rnnHiddenSize;
24 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
25 | enc.rnnLayers[layer]:maskZero();
26 |
27 | wordBranch:add(enc.rnnLayers[layer]);
28 | end
29 | wordBranch:add(nn.Select(1, -1));
30 |
31 | -- make clones for embed layer
32 | local qEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
33 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
34 |
35 | -- create two branches
36 | local histBranch = nn.Sequential()
37 | :add(nn.SelectTable(2))
38 | :add(hEmbedNet);
39 | enc.histLayers = {};
40 |
41 | -- number of layers to read the history
42 | for layer = 1, params.numLayers do
43 | local inputSize = (layer == 1) and params.embedSize
44 | or params.rnnHiddenSize;
45 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
46 | enc.histLayers[layer]:maskZero();
47 |
48 | histBranch:add(enc.histLayers[layer]);
49 | end
50 | histBranch:add(nn.Select(1, -1));
51 |
52 | -- add concatTable and join
53 | concat:add(wordBranch)
54 | concat:add(histBranch)
55 | enc:add(concat);
56 |
57 | -- another concat table
58 | local concat2 = nn.ConcatTable();
59 | enc:add(nn.JoinTable(1, 1))
60 |
61 | --change the view of the data
62 | -- always split it back wrt batch size and then do transpose
63 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize));
64 | enc:add(nn.Transpose({1, 2}));
65 | enc:add(nn.View(params.maxQuesCount, -1, 2*params.rnnHiddenSize))
66 | enc:add(nn.SeqLSTM(2*params.rnnHiddenSize, params.rnnHiddenSize))
67 | enc:add(nn.Transpose({1, 2}));
68 | enc:add(nn.View(-1, params.rnnHiddenSize))
69 |
70 | return enc;
71 | end
72 |
73 | return encoderNet
74 |
--------------------------------------------------------------------------------
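
The View/Transpose block at the end is what makes this encoder hierarchical: the (batch*rounds) x 2H round encodings are regrouped so that rounds become the time axis of a second, dialog-level SeqLSTM, then flattened back to one row per round. A minimal sketch of that reshape with toy sizes:

```lua
require 'nn'
require 'rnn'

local batch, rounds, hidden = 3, 10, 16
local perRound = torch.randn(batch * rounds, 2 * hidden)   -- one row per (image, round)

local dialogRNN = nn.Sequential()
    :add(nn.View(-1, rounds, 2 * hidden))    -- batch x rounds x 2H
    :add(nn.Transpose({1, 2}))               -- rounds x batch x 2H (time-major)
    :add(nn.SeqLSTM(2 * hidden, hidden))     -- dialog-level RNN over rounds
    :add(nn.Transpose({1, 2}))               -- batch x rounds x H
    :add(nn.View(-1, hidden))                -- back to (batch*rounds) x H

print(dialogRNN:forward(perRound):size())    -- 30 x 16
```
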
/encoders/hre-ques-im-hist.lua:
--------------------------------------------------------------------------------
1 | require 'model_utils.MaskTime'
2 |
3 | local encoderNet = {}
4 |
5 | function encoderNet.model(params)
6 | local dropout = params.dropout or 0.5
7 |
8 | -- Use `nngraph`
9 | nn.FastLSTM.usenngraph = true;
10 |
11 | -- encoder network
12 | local enc = nn.Sequential();
13 |
14 | -- create the input branches
15 | local concat = nn.ConcatTable();
16 |
17 | -- word branch, along with embedding layer
18 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
19 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
20 |
21 | -- make clones for embed layer
22 | local qEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
23 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
24 |
25 | -- create two branches
26 | local histBranch = nn.Sequential()
27 | :add(nn.SelectTable(3))
28 | :add(hEmbedNet);
29 | enc.histLayers = {};
30 | -- number of layers to read the history
31 | for layer = 1, params.numLayers do
32 | local inputSize = (layer == 1) and params.embedSize
33 | or params.rnnHiddenSize;
34 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
35 | enc.histLayers[layer]:maskZero();
36 |
37 | histBranch:add(enc.histLayers[layer]);
38 | end
39 | histBranch:add(nn.Select(1, -1));
40 |
41 | -- image branch
42 | -- embedding for images
43 | local imgPar = nn.ParallelTable()
44 | :add(nn.Identity())
45 | :add(nn.Sequential()
46 | -- :add(nn.Dropout(0.5))
47 | :add(nn.Linear(params.imgFeatureSize,
48 | params.imgEmbedSize)));
49 | -- select words and image only
50 | local imageBranch = nn.Sequential()
51 | :add(nn.NarrowTable(1, 2))
52 | :add(imgPar)
53 | :add(nn.MaskTime(params.imgEmbedSize));
54 |
55 | -- add concatTable and join
56 | concat:add(wordBranch)
57 | concat:add(imageBranch)
58 | concat:add(histBranch)
59 | enc:add(concat);
60 |
61 | -- another concat table
62 | local concat2 = nn.ConcatTable();
63 |
64 | -- select words + image, and history
65 | local wordImageBranch = nn.Sequential()
66 | :add(nn.NarrowTable(1, 2))
67 | :add(nn.JoinTable(-1))
68 |
69 | -- language model
70 | enc.rnnLayers = {};
71 | for layer = 1, params.numLayers do
72 | local inputSize = (layer==1) and (params.imgEmbedSize+params.embedSize)
73 | or params.rnnHiddenSize;
74 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
75 | enc.rnnLayers[layer]:maskZero();
76 |
77 | wordImageBranch:add(enc.rnnLayers[layer]);
78 | end
79 | wordImageBranch:add(nn.Select(1, -1));
80 |
81 | -- add both the branches (wordImage, select history) to concat2
82 | concat2:add(wordImageBranch):add(nn.SelectTable(3));
83 | enc:add(concat2);
84 |
85 | -- join both the tensors
86 | enc:add(nn.JoinTable(-1));
87 |
88 | -- change the view of the data
89 | -- always split it back wrt batch size and then do transpose
90 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize));
91 | enc:add(nn.Transpose({1, 2}));
92 | enc:add(nn.SeqLSTM(2*params.rnnHiddenSize, params.rnnHiddenSize))
93 | enc:add(nn.Transpose({1, 2}));
94 | enc:add(nn.View(-1, params.rnnHiddenSize))
95 |
96 | return enc;
97 | end
98 |
99 | return encoderNet
100 |
--------------------------------------------------------------------------------
/encoders/hrea-ques-im-hist.lua:
--------------------------------------------------------------------------------
1 | require 'model_utils.MaskTime'
2 | require 'model_utils.MaskFuture'
3 | require 'model_utils.ReplaceZero'
4 |
5 | local encoderNet = {}
6 |
7 | function encoderNet.model(params)
8 | local dropout = params.dropout or 0.5
9 |
10 | -- Use `nngraph`
11 | nn.FastLSTM.usenngraph = true;
12 |
13 | -- encoder network
14 | local enc = nn.Sequential();
15 |
16 | -- create the input branches
17 | local concat = nn.ConcatTable();
18 |
19 | -- word branch, along with embedding layer
20 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
21 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
22 |
23 | -- make clones for embed layer
24 | local qEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
25 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
26 |
27 | -- create two branches
28 | local histBranch = nn.Sequential()
29 | :add(nn.SelectTable(3))
30 | :add(hEmbedNet);
31 | enc.histLayers = {};
32 | -- number of layers to read the history
33 | for layer = 1, params.numLayers do
34 | local inputSize = (layer == 1) and params.embedSize
35 | or params.rnnHiddenSize;
36 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
37 | enc.histLayers[layer]:maskZero();
38 |
39 | histBranch:add(enc.histLayers[layer]);
40 | end
41 | histBranch:add(nn.Select(1, -1));
42 |
43 | -- image branch
44 | -- embedding for images
45 | local imgPar = nn.ParallelTable()
46 | :add(nn.Identity())
47 | :add(nn.Sequential()
48 | :add(nn.Dropout(0.5))
49 | :add(nn.Linear(params.imgFeatureSize,
50 | params.imgEmbedSize)));
51 | -- select words and image only
52 | local imageBranch = nn.Sequential()
53 | :add(nn.NarrowTable(1, 2))
54 | :add(imgPar)
55 | :add(nn.MaskTime(params.imgEmbedSize));
56 |
57 | -- add concatTable and join
58 | concat:add(wordBranch)
59 | concat:add(imageBranch)
60 | concat:add(histBranch)
61 | enc:add(concat);
62 |
63 | -- another concat table
64 | local concat2 = nn.ConcatTable();
65 |
66 | -- select words + image, and history
67 | local wordImageBranch = nn.Sequential()
68 | :add(nn.NarrowTable(1, 2))
69 | :add(nn.JoinTable(-1))
70 |
71 | -- language model
72 | enc.rnnLayers = {};
73 | for layer = 1, params.numLayers do
74 | local inputSize = (layer==1) and (params.imgEmbedSize+params.embedSize)
75 | or params.rnnHiddenSize;
76 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
77 | enc.rnnLayers[layer]:maskZero();
78 |
79 | wordImageBranch:add(enc.rnnLayers[layer]);
80 | end
81 | wordImageBranch:add(nn.Select(1, -1));
82 |
83 | -- add both the branches (wordImage, select history) to concat2
84 | concat2:add(wordImageBranch):add(nn.SelectTable(3));
85 | enc:add(concat2);
86 |
87 | -- single layer neural network
88 | -- create attention based history
89 | local prepare = nn.Sequential()
90 | :add(nn.Linear(params.rnnHiddenSize, 1))
91 | :add(nn.View(-1, params.maxQuesCount))
92 | local wordHistBranch = nn.Sequential()
93 | :add(nn.ParallelTable()
94 | :add(prepare:clone())
95 | :add(prepare:clone()))
96 | :add(nn.ParallelTable()
97 | :add(nn.Replicate(10, 3))
98 | :add(nn.Replicate(10, 2)))
99 | :add(nn.CAddTable())
100 | --:add(nn.Tanh())
101 | :add(nn.MaskFuture(params.maxQuesCount))
102 | :add(nn.View(-1, params.maxQuesCount))
103 | :add(nn.ReplaceZero(-1*math.huge))
104 | :add(nn.SoftMax())
105 | :add(nn.View(-1, params.maxQuesCount, params.maxQuesCount))
106 | :add(nn.Replicate(params.rnnHiddenSize, 4));
107 |
108 | local histOnlyBranch = nn.Sequential()
109 | :add(nn.SelectTable(2))
110 | :add(nn.View(-1, params.maxQuesCount, params.rnnHiddenSize))
111 | :add(nn.Replicate(params.maxQuesCount, 2))
112 |
113 | -- add another concatTable to create attention over history
114 | local concat3 = nn.ConcatTable()
115 | :add(wordHistBranch)
116 | :add(histOnlyBranch)
117 | :add(nn.SelectTable(1)) -- append attended history with question
118 | enc:add(concat3);
119 |
120 | -- parallel table to multiply first two tables, and leave the third one untouched
121 | local multiplier = nn.Sequential()
122 | :add(nn.NarrowTable(1, 2))
123 | :add(nn.CMulTable())
124 | :add(nn.Sum(3))
125 | :add(nn.View(-1, params.rnnHiddenSize));
126 | local concat4 = nn.ConcatTable()
127 | :add(multiplier)
128 | :add(nn.SelectTable(3));
129 | enc:add(concat4);
130 |
131 | -- join both the tensors (att over history and encoded question)
132 | enc:add(nn.JoinTable(-1));
133 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize))
134 | enc:add(nn.Transpose({1, 2}))
135 | enc:add(nn.SeqLSTM(2 * params.rnnHiddenSize, params.rnnHiddenSize))
136 | enc:add(nn.Transpose({1, 2}))
137 | enc:add(nn.View(-1, params.rnnHiddenSize))
138 |
139 | return enc;
140 | end
141 |
142 | return encoderNet
143 |
--------------------------------------------------------------------------------
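
The attention over history in this encoder has to ignore facts from future rounds; `MaskFuture` followed by `ReplaceZero(-inf)` accomplishes that before the `nn.SoftMax`. A minimal plain-torch sketch of the same idea for a single dialog (the exact behaviour of the custom modules is an assumption here; "future" is taken to mean strictly later rounds):

```lua
require 'nn'

local rounds = 10
local scores = torch.randn(rounds, rounds)     -- row i: question i scored against every fact j
local future = torch.triu(torch.ones(rounds, rounds), 1):byte()   -- j > i marks future facts

scores[future] = -math.huge                    -- excluded entries contribute exp(-inf) = 0
local att = nn.SoftMax():forward(scores)       -- each row sums to 1 over past facts only
print(att[1])                                  -- round 1 can only attend to the first fact
```
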
/encoders/lf-att-ques-im-hist.lua:
--------------------------------------------------------------------------------
1 | local encoderNet = {}
2 |
3 | function encoderNet.model(params)
4 |
5 | local inputs = {}
6 | local outputs = {}
7 |
8 | table.insert(inputs, nn.Identity()()) -- question
9 | table.insert(inputs, nn.Identity()()) -- img feats
10 | table.insert(inputs, nn.Identity()()) -- history
11 |
12 | local ques = inputs[1]
13 | local img_feats = inputs[2]
14 | local hist = inputs[3]
15 |
16 | -- word embed layer
17 | local wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
18 |
19 | -- make clones for embed layer
20 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques));
21 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist));
22 |
23 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
24 | lst1:maskZero()
25 |
26 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
27 | lst2:maskZero()
28 |
29 | local h1 = lst1(hEmbed)
30 | local h2 = lst2(h1)
31 | local h3 = nn.Select(1, -1)(h2)
32 |
33 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
34 | lst3:maskZero()
35 |
36 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
37 | lst4:maskZero()
38 |
39 | local q1 = lst3(qEmbed)
40 | local q2 = lst4(q1)
41 | local q3 = nn.Select(1, -1)(q2)
42 |
43 | local qh = nn.Tanh()(nn.Linear(2*params.rnnHiddenSize, params.rnnHiddenSize)(nn.JoinTable(1,1)({q3,h3})))
44 |
45 | -- image attention (inspired by SAN, Yang et al., CVPR16)
46 | local img_tr_size = params.rnnHiddenSize
47 | local rnn_size = params.rnnHiddenSize
48 | local common_embedding_size = params.commonEmbeddingSize
49 | local num_attention_layer = 1
50 |
51 | local u = qh
52 | local img_tr = nn.Dropout(0.5)(
53 | nn.Tanh()(
54 | nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, img_tr_size)(
55 | nn.Linear(params.imgFeatureSize, img_tr_size)(
56 | nn.View(params.imgFeatureSize):setNumInputDims(2)(img_feats)))))
57 |
58 | for i = 1, num_attention_layer do
59 |
60 | -- linear layer: 14x14x1024 -> 14x14x512
61 | local img_common = nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, common_embedding_size)(
62 | nn.Linear(img_tr_size, common_embedding_size)(
63 | nn.View(-1, img_tr_size)(img_tr)))
64 |
65 | -- replicate lstm state 196 times
66 | local ques_common = nn.Linear(rnn_size, common_embedding_size)(u)
67 | local ques_repl = nn.Replicate(params.imgSpatialSize * params.imgSpatialSize, 2)(ques_common)
68 |
69 | -- add image and question features (both 196x512)
70 | local img_ques_common = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({img_common, ques_repl})))
71 | local h = nn.Linear(common_embedding_size, 1)(nn.View(-1, common_embedding_size)(img_ques_common))
72 | local p = nn.SoftMax()(nn.View(-1, params.imgSpatialSize * params.imgSpatialSize)(h))
73 |
74 | -- weighted sum of image features
75 | local p_att = nn.View(1, -1):setNumInputDims(1)(p)
76 | local img_tr_att = nn.MM(false, false)({p_att, img_tr})
77 | local img_tr_att_feat = nn.View(-1, img_tr_size)(img_tr_att)
78 |
79 | -- add image feature vector and question vector
80 | u = nn.CAddTable()({img_tr_att_feat, u})
81 |
82 | end
83 |
84 | local o = nn.Tanh()(nn.Linear(rnn_size, rnn_size)(nn.Dropout(0.5)(u)))
85 | -- SAN stuff ends
86 |
87 | table.insert(outputs, o)
88 |
89 | local enc = nn.gModule(inputs, outputs)
90 | enc.wordEmbed = wordEmbed
91 |
92 | return enc;
93 | end
94 |
95 | return encoderNet
96 |
--------------------------------------------------------------------------------
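
The attention block above (adapted from Stacked Attention Networks) projects the image grid and the question summary into a common space, scores every spatial cell, and takes the attention-weighted sum of the image features. A minimal single-hop sketch with hypothetical sizes (14x14 grid, 512-d features throughout):

```lua
require 'nn'

local batch, cells, imgDim, hidden, common = 2, 14 * 14, 512, 512, 512
local imgGrid = torch.randn(batch, cells, imgDim)   -- conv features, one row per spatial cell
local quesVec = torch.randn(batch, hidden)          -- fused question/history summary

-- project both modalities into the common space and add
local imgCommon  = nn.Linear(imgDim, common):forward(imgGrid:view(-1, imgDim)):view(batch, cells, common)
local quesCommon = nn.Replicate(cells, 2):forward(nn.Linear(hidden, common):forward(quesVec))
local joint      = nn.Tanh():forward(imgCommon + quesCommon)               -- B x 196 x common

-- one score per cell, softmax, then a weighted sum of the original image features
local scores = nn.Linear(common, 1):forward(joint:view(-1, common)):view(batch, cells)
local probs  = nn.SoftMax():forward(scores)                                -- B x 196
local attImg = nn.MM():forward({probs:view(batch, 1, cells), imgGrid})     -- B x 1 x imgDim
print(attImg:view(batch, imgDim):size())
```
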
/encoders/lf-ques-hist.lua:
--------------------------------------------------------------------------------
1 | local encoderNet = {};
2 |
3 | function encoderNet.model(params)
4 | local dropout = params.dropout or 0.5
5 | -- Use nngraph
6 | nn.FastLSTM.usenngraph = true;
7 |
8 | -- encoder network
9 | local enc = nn.Sequential();
10 |
11 | -- create the two branches
12 | local concat = nn.ConcatTable();
13 |
14 | -- word branch, along with embedding layer
15 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
16 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
17 |
18 | -- language model
19 | enc.rnnLayers = {};
20 | for layer = 1, params.numLayers do
21 | local inputSize = (layer==1) and (params.embedSize)
22 | or params.rnnHiddenSize;
23 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
24 | enc.rnnLayers[layer]:maskZero();
25 |
26 | wordBranch:add(enc.rnnLayers[layer]);
27 | end
28 | wordBranch:add(nn.Select(1, -1));
29 |
30 | -- make clones for embed layer
31 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
32 |
33 | -- create two branches
34 | local histBranch = nn.Sequential()
35 | :add(nn.SelectTable(2))
36 | :add(hEmbedNet);
37 | enc.histLayers = {};
38 | -- number of layers to read the history
39 | for layer = 1, params.numLayers do
40 | local inputSize = (layer == 1) and params.embedSize
41 | or params.rnnHiddenSize;
42 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
43 | enc.histLayers[layer]:maskZero();
44 |
45 | histBranch:add(enc.histLayers[layer]);
46 | end
47 | histBranch:add(nn.Select(1, -1));
48 |
49 | -- add concatTable and join
50 | concat:add(wordBranch)
51 | concat:add(histBranch)
52 | enc:add(concat);
53 |
54 | enc:add(nn.JoinTable(2))
55 | if dropout > 0 then
56 | enc:add(nn.Dropout(dropout))
57 | end
58 | enc:add(nn.Linear(2 * params.rnnHiddenSize, params.rnnHiddenSize))
59 | enc:add(nn.Tanh())
60 |
61 | return enc;
62 | end
63 |
64 | return encoderNet
65 |
--------------------------------------------------------------------------------
/encoders/lf-ques-im-hist.lua:
--------------------------------------------------------------------------------
1 | local encoderNet = {}
2 |
3 | function encoderNet.model(params)
4 | local dropout = params.dropout or 0.5
5 | -- Use `nngraph`
6 | nn.FastLSTM.usenngraph = true;
7 |
8 | -- encoder network
9 | local enc = nn.Sequential();
10 |
11 | -- create the input branches
12 | local concat = nn.ConcatTable();
13 |
14 | -- word branch, along with embedding layer
15 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
16 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
17 |
18 | -- language model
19 | enc.rnnLayers = {};
20 | for layer = 1, params.numLayers do
21 | local inputSize = (layer==1) and (params.embedSize)
22 | or params.rnnHiddenSize;
23 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
24 | enc.rnnLayers[layer]:maskZero();
25 |
26 | wordBranch:add(enc.rnnLayers[layer]);
27 | end
28 | wordBranch:add(nn.Select(1, -1));
29 |
30 | -- make clones for embed layer
31 | local hEmbedNet = enc.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias');
32 |
33 | -- create two branches
34 | local histBranch = nn.Sequential()
35 | :add(nn.SelectTable(3))
36 | :add(hEmbedNet);
37 | enc.histLayers = {};
38 | -- number of layers to read the history
39 | for layer = 1, params.numLayers do
40 | local inputSize = (layer == 1) and params.embedSize
41 | or params.rnnHiddenSize;
42 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
43 | enc.histLayers[layer]:maskZero();
44 |
45 | histBranch:add(enc.histLayers[layer]);
46 | end
47 | histBranch:add(nn.Select(1, -1));
48 |
49 | concat:add(wordBranch)
50 | concat:add(nn.SelectTable(2))
51 | concat:add(histBranch)
52 | enc:add(concat);
53 |
54 | enc:add(nn.JoinTable(2))
55 | if dropout > 0 then
56 | enc:add(nn.Dropout(dropout))
57 | end
58 | enc:add(nn.Linear(2 * params.rnnHiddenSize + params.imgFeatureSize, params.rnnHiddenSize))
59 | enc:add(nn.Tanh())
60 |
61 | return enc;
62 | end
63 |
64 | return encoderNet
65 |
--------------------------------------------------------------------------------
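
All of the lf-* encoders end with the same "late fusion" head: the per-round encodings are concatenated along the feature dimension and projected back to the hidden size with Dropout + Linear + Tanh. A minimal sketch with toy sizes (a 4096-d image feature is assumed here, as with fc7-style preprocessing):

```lua
require 'nn'

local batch, hidden, imgDim = 4, 512, 4096
local fusion = nn.Sequential()
    :add(nn.JoinTable(2))                           -- concatenate along the feature dimension
    :add(nn.Dropout(0.5))
    :add(nn.Linear(2 * hidden + imgDim, hidden))    -- project back to the RNN hidden size
    :add(nn.Tanh())

local quesEnc = torch.randn(batch, hidden)
local imgFeat = torch.randn(batch, imgDim)
local histEnc = torch.randn(batch, hidden)
print(fusion:forward({quesEnc, imgFeat, histEnc}):size())   -- batch x hidden
```
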
/encoders/lf-ques-im.lua:
--------------------------------------------------------------------------------
1 | local encoderNet = {}
2 |
3 | function encoderNet.model(params)
4 | local dropout = params.dropout or 0.5
5 | -- Use `nngraph`
6 | nn.FastLSTM.usenngraph = true;
7 |
8 | -- encoder network
9 | local enc = nn.Sequential();
10 |
11 | -- create the two branches
12 | local concat = nn.ConcatTable();
13 |
14 | -- word branch, along with embedding layer
15 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
16 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
17 |
18 | -- language model
19 | enc.rnnLayers = {};
20 | for layer = 1, params.numLayers do
21 | local inputSize = (layer==1) and (params.embedSize)
22 | or params.rnnHiddenSize;
23 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
24 | enc.rnnLayers[layer]:maskZero();
25 |
26 | wordBranch:add(enc.rnnLayers[layer]);
27 | end
28 | wordBranch:add(nn.Select(1, -1));
29 |
30 | -- add concatTable and join
31 | concat:add(wordBranch)
32 | concat:add(nn.SelectTable(2))
33 | enc:add(concat);
34 |
35 | enc:add(nn.JoinTable(2))
36 | if dropout > 0 then
37 | enc:add(nn.Dropout(dropout))
38 | end
39 | enc:add(nn.Linear(params.rnnHiddenSize + params.imgFeatureSize, params.rnnHiddenSize))
40 | enc:add(nn.Tanh())
41 |
42 | return enc;
43 | end
44 |
45 | return encoderNet
46 |
--------------------------------------------------------------------------------
/encoders/lf-ques.lua:
--------------------------------------------------------------------------------
1 | local encoderNet = {};
2 |
3 | function encoderNet.model(params)
4 | local dropout = params.dropout or 0.5;
5 | -- Use `nngraph`
6 | nn.FastLSTM.usenngraph = true;
7 |
8 | -- encoder network
9 | local enc = nn.Sequential();
10 |
11 | -- word branch, along with embedding layer
12 | enc.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
13 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(enc.wordEmbed);
14 |
15 | -- language model
16 | enc.rnnLayers = {};
17 | for layer = 1, params.numLayers do
18 | local inputSize = (layer==1) and (params.embedSize)
19 | or params.rnnHiddenSize;
20 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize);
21 | enc.rnnLayers[layer]:maskZero();
22 |
23 | wordBranch:add(enc.rnnLayers[layer]);
24 | end
25 | wordBranch:add(nn.Select(1, -1));
26 |
27 | enc:add(wordBranch);
28 |
29 | if dropout > 0 then
30 | enc:add(nn.Dropout(dropout))
31 | end
32 | enc:add(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize))
33 | enc:add(nn.Tanh())
34 |
35 | return enc;
36 | end
37 |
38 | return encoderNet;
39 |
--------------------------------------------------------------------------------
/encoders/mn-att-ques-im-hist.lua:
--------------------------------------------------------------------------------
1 | require 'model_utils.MaskSoftMax'
2 |
3 | local encoderNet = {}
4 |
5 | function encoderNet.model(params)
6 |
7 | local inputs = {}
8 | local outputs = {}
9 |
10 | table.insert(inputs, nn.Identity()()) -- question
11 | table.insert(inputs, nn.Identity()()) -- img feats
12 | table.insert(inputs, nn.Identity()()) -- history
13 | table.insert(inputs, nn.Identity()()) -- 10x10 mask
14 |
15 | local ques = inputs[1]
16 | local img_feats = inputs[2]
17 | local hist = inputs[3]
18 | local mask = inputs[4]
19 |
20 | -- word embed layer
21 | local wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
22 |
23 | -- make clones for embed layer
24 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques));
25 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist));
26 |
27 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
28 | lst1:maskZero()
29 |
30 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
31 | lst2:maskZero()
32 |
33 | local h1 = lst1(hEmbed)
34 | local h2 = lst2(h1)
35 | local h3 = nn.Select(1, -1)(h2)
36 |
37 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
38 | lst3:maskZero()
39 |
40 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
41 | lst4:maskZero()
42 |
43 | local q1 = lst3(qEmbed)
44 | local q2 = lst4(q1)
45 | local q3 = nn.Select(1, -1)(q2)
46 |
47 | -- View as batch x rounds
48 | local qEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(q3)
49 | local hEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(h3)
50 |
51 | -- Inner product
52 | -- q is Bx10xE, h is Bx10xE
53 | -- qh is Bx10x10, rows correspond to questions, columns to facts
54 | local qh = nn.MM(false, true)({qEmbedView, hEmbedView})
55 | local qhView = nn.View(-1, params.maxQuesCount)(qh)
56 | local qhprobs = nn.MaskSoftMax(){qhView, mask}
57 | local qhView2 = nn.View(-1, params.maxQuesCount, params.maxQuesCount)(qhprobs)
58 |
59 | -- Weighted sum of h features
60 | -- h is Bx10xE, qhView2 is Bx10x10
61 | local hAtt = nn.MM(){qhView2, hEmbedView}
62 | local hAttView = nn.View(-1, params.rnnHiddenSize)(hAtt)
63 |
64 | local hAttTr = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.Dropout(0.5)(hAttView)))
65 | local qh2 = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.CAddTable(){hAttTr, nn.View(-1, params.rnnHiddenSize)(qEmbedView)}))
66 |
67 | -- image attention (inspired by SAN, Yang et al., CVPR16)
68 | local img_tr_size = params.rnnHiddenSize
69 | local rnn_size = params.rnnHiddenSize
70 | local common_embedding_size = params.commonEmbeddingSize or 512
71 | local num_attention_layer = params.numAttentionLayers or 1
72 |
73 | local u = qh2
74 | local img_tr = nn.Dropout(0.5)(
75 | nn.Tanh()(
76 | nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, img_tr_size)(
77 | nn.Linear(params.imgFeatureSize, img_tr_size)(
78 | nn.View(params.imgFeatureSize):setNumInputDims(2)(img_feats)))))
79 |
80 | for i = 1, num_attention_layer do
81 |
82 | -- linear layer: 14x14x1024 -> 14x14x512
83 | local img_common = nn.View(-1, params.imgSpatialSize * params.imgSpatialSize, common_embedding_size)(
84 | nn.Linear(img_tr_size, common_embedding_size)(
85 | nn.View(-1, img_tr_size)(img_tr)))
86 |
87 | -- replicate lstm state 196 times
88 | local ques_common = nn.Linear(rnn_size, common_embedding_size)(u)
89 | local ques_repl = nn.Replicate(params.imgSpatialSize * params.imgSpatialSize, 2)(ques_common)
90 |
91 | -- add image and question features (both 196x512)
92 | local img_ques_common = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({img_common, ques_repl})))
93 | local h = nn.Linear(common_embedding_size, 1)(nn.View(-1, common_embedding_size)(img_ques_common))
94 | local p = nn.SoftMax()(nn.View(-1, params.imgSpatialSize * params.imgSpatialSize)(h))
95 |
96 | -- weighted sum of image features
97 | local p_att = nn.View(1, -1):setNumInputDims(1)(p)
98 | local img_tr_att = nn.MM(false, false)({p_att, img_tr})
99 | local img_tr_att_feat = nn.View(-1, img_tr_size)(img_tr_att)
100 |
101 | -- add image feature vector and question vector
102 | u = nn.CAddTable()({img_tr_att_feat, u})
103 |
104 | end
105 |
106 | local o = nn.Tanh()(nn.Linear(rnn_size, rnn_size)(nn.Dropout(0.5)(u)))
107 | -- SAN stuff ends
108 |
109 | table.insert(outputs, o)
110 |
111 | local enc = nn.gModule(inputs, outputs)
112 | enc.wordEmbed = wordEmbed
113 |
114 | return enc;
115 | end
116 |
117 | return encoderNet
118 |
--------------------------------------------------------------------------------
/encoders/mn-ques-hist.lua:
--------------------------------------------------------------------------------
1 | require 'model_utils.MaskSoftMax'
2 |
3 | local encoderNet = {}
4 |
5 | function encoderNet.model(params)
6 |
7 | local inputs = {}
8 | local outputs = {}
9 |
10 | table.insert(inputs, nn.Identity()()) -- question
11 | table.insert(inputs, nn.Identity()()) -- history
12 | table.insert(inputs, nn.Identity()()) -- 10x10 mask
13 |
14 | local ques = inputs[1]
15 | local hist = inputs[2]
16 | local mask = inputs[3]
17 |
18 | -- word embed layer
19 | local wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
20 |
21 | -- make clones for embed layer
22 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques));
23 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist));
24 |
25 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
26 | lst1:maskZero()
27 |
28 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
29 | lst2:maskZero()
30 |
31 | local h1 = lst1(hEmbed)
32 | local h2 = lst2(h1)
33 | local h3 = nn.Select(1, -1)(h2)
34 |
35 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
36 | lst3:maskZero()
37 |
38 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
39 | lst4:maskZero()
40 |
41 | local q1 = lst3(qEmbed)
42 | local q2 = lst4(q1)
43 | local q3 = nn.Select(1, -1)(q2)
44 |
45 | -- View as batch x rounds
46 | local qEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(q3)
47 | local hEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(h3)
48 |
49 | -- Inner product
50 | -- q is Bx10xE, h is Bx10xE
51 | -- qh is Bx10x10, rows correspond to questions, columns to facts
52 | local qh = nn.MM(false, true)({qEmbedView, hEmbedView})
53 | local qhView = nn.View(-1, params.maxQuesCount)(qh)
54 | local qhprobs = nn.MaskSoftMax(){qhView, mask}
55 | local qhView2 = nn.View(-1, params.maxQuesCount, params.maxQuesCount)(qhprobs)
56 |
57 | -- Weighted sum of h features
58 | -- h is Bx10xE, qhView2 is Bx10x10
59 | local hAtt = nn.MM(){qhView2, hEmbedView}
60 | local hAttView = nn.View(-1, params.rnnHiddenSize)(hAtt)
61 |
62 | local hAttTr = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.Dropout(0.5)(hAttView)))
63 | local qh2 = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.CAddTable(){hAttTr, nn.View(-1, params.rnnHiddenSize)(qEmbedView)}))
64 |
65 | table.insert(outputs, qh2)
66 |
67 | local enc = nn.gModule(inputs, outputs)
68 | enc.wordEmbed = wordEmbed
69 |
70 | return enc;
71 | end
72 |
73 | return encoderNet
74 |
--------------------------------------------------------------------------------
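
The mn-* encoders compare every question encoding against all history ("fact") encodings with a batched inner product, normalise the scores, and take a weighted sum of the facts. A minimal sketch with toy sizes; the custom `MaskSoftMax` (which restricts each question to facts from earlier rounds) is replaced by a plain SoftMax here for brevity:

```lua
require 'nn'

local batch, rounds, hidden = 2, 10, 16
local q = torch.randn(batch, rounds, hidden)                  -- question encodings, B x 10 x E
local h = torch.randn(batch, rounds, hidden)                  -- fact/history encodings, B x 10 x E

local scores = nn.MM(false, true):forward({q, h})             -- B x 10 x 10 question-fact scores
local probs  = nn.SoftMax():forward(scores:view(-1, rounds))  -- softmax over facts per question
probs = probs:view(batch, rounds, rounds)
local hAtt   = nn.MM():forward({probs, h})                    -- B x 10 x E attended history
print(hAtt:size())
```
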
/encoders/mn-ques-im-hist.lua:
--------------------------------------------------------------------------------
1 | require 'model_utils.MaskSoftMax'
2 |
3 | local encoderNet = {}
4 |
5 | function encoderNet.model(params)
6 |
7 | local inputs = {}
8 | local outputs = {}
9 |
10 | table.insert(inputs, nn.Identity()()) -- question
11 | table.insert(inputs, nn.Identity()()) -- img feats
12 | table.insert(inputs, nn.Identity()()) -- history
13 | table.insert(inputs, nn.Identity()()) -- 10x10 mask
14 |
15 | local ques = inputs[1]
16 | local img_feats = inputs[2]
17 | local hist = inputs[3]
18 | local mask = inputs[4]
19 |
20 | -- word embed layer
21 | local wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize);
22 |
23 | -- make clones for embed layer
24 | local qEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(ques));
25 | local hEmbed = nn.Dropout(0.5)(wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias')(hist));
26 |
27 | local lst1 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
28 | lst1:maskZero()
29 |
30 | local lst2 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
31 | lst2:maskZero()
32 |
33 | local h1 = lst1(hEmbed)
34 | local h2 = lst2(h1)
35 | local h3 = nn.Select(1, -1)(h2)
36 |
37 | local lst3 = nn.SeqLSTM(params.embedSize, params.rnnHiddenSize)
38 | lst3:maskZero()
39 |
40 | local lst4 = nn.SeqLSTM(params.rnnHiddenSize, params.rnnHiddenSize)
41 | lst4:maskZero()
42 |
43 | local q1 = lst3(qEmbed)
44 | local q2 = lst4(q1)
45 | local q3 = nn.Select(1, -1)(q2)
46 |
47 | local qi = nn.JoinTable(1,1)({q3,img_feats})
48 | local qi_proj = nn.Tanh()(nn.Linear(params.imgFeatureSize + params.rnnHiddenSize, params.rnnHiddenSize)(qi))
49 |
50 | -- View as batch x rounds
51 | local qEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(qi_proj)
52 | local hEmbedView = nn.View(-1, params.maxQuesCount, params.rnnHiddenSize)(h3)
53 |
54 | -- Inner product
55 | -- q is Bx10xE, h is Bx10xE
56 | -- qh is Bx10x10, rows correspond to questions, columns to facts
57 | local qh = nn.MM(false, true)({qEmbedView, hEmbedView})
58 | local qhView = nn.View(-1, params.maxQuesCount)(qh)
59 | local qhprobs = nn.MaskSoftMax(){qhView, mask}
60 | local qhView2 = nn.View(-1, params.maxQuesCount, params.maxQuesCount)(qhprobs)
61 |
62 | -- Weighted sum of h features
63 | -- h is Bx10xE, qhView2 is Bx10x10
64 | local hAtt = nn.MM(){qhView2, hEmbedView}
65 | local hAttView = nn.View(-1, params.rnnHiddenSize)(hAtt)
66 |
67 | local hAttTr = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.Dropout(0.5)(hAttView)))
68 | local qh2 = nn.Tanh()(nn.Linear(params.rnnHiddenSize, params.rnnHiddenSize)(nn.CAddTable(){hAttTr, nn.View(-1, params.rnnHiddenSize)(qEmbedView)}))
69 |
70 | table.insert(outputs, qh2)
71 |
72 | local enc = nn.gModule(inputs, outputs)
73 | enc.wordEmbed = wordEmbed
74 |
75 | return enc;
76 | end
77 |
78 | return encoderNet
79 |
--------------------------------------------------------------------------------
/evaluate.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'rnn'
3 | require 'nngraph'
4 | utils = dofile('utils.lua');
5 |
6 | -------------------------------------------------------------------------------
7 | -- Input arguments and options
8 | -------------------------------------------------------------------------------
9 | cmd = torch.CmdLine()
10 | cmd:text()
11 | cmd:text('Test the VisDial model for retrieval')
12 | cmd:text()
13 | cmd:text('Options')
14 |
15 | -- Data input settings
16 | cmd:option('-inputImg','data/data_img.h5','path to h5 file with image features')
17 | cmd:option('-inputQues','data/visdial_data.h5','path to h5 file with preprocessed questions')
18 | cmd:option('-inputJson','data/visdial_params.json','path to json file with vocab and other info')
19 |
20 | cmd:option('-loadPath', 'checkpoints/model.t7', 'path to saved model')
21 | cmd:option('-split', 'val', 'split to evaluate on')
22 | cmd:option('-useGt', false, 'whether to use ground truth for retrieving ranks')
23 |
24 | -- Inference params
25 | cmd:option('-batchSize', 30, 'Batch size (number of threads); adjust based on GPU memory')
26 | cmd:option('-gpuid', 0, 'GPU id to use')
27 | cmd:option('-backend', 'cudnn', 'nn|cudnn')
28 |
29 | cmd:option('-saveRanks', false, 'Whether to save ranks or not');
30 | cmd:option('-saveRankPath', 'logs/ranks.json');
31 |
32 | local opt = cmd:parse(arg);
33 |
34 | if opt.useGt and opt.split == 'test' then
35 | print('Warning: no ground truth available in the test split; setting useGt to false.')
36 | opt.useGt = false
37 | end
38 | print(opt)
39 |
40 | -- seed for reproducibility
41 | torch.manualSeed(1234);
42 |
43 | -- set default tensor based on gpu usage
44 | if opt.gpuid >= 0 then
45 | require 'cutorch'
46 | require 'cunn'
47 | if opt.backend == 'cudnn' then require 'cudnn' end
48 | cutorch.setDevice(opt.gpuid+1)
49 | cutorch.manualSeed(1234)
50 | torch.setdefaulttensortype('torch.CudaTensor');
51 | else
52 | torch.setdefaulttensortype('torch.FloatTensor');
53 | end
54 |
55 | ------------------------------------------------------------------------
56 | -- Read saved model and parameters
57 | ------------------------------------------------------------------------
58 | local savedModel = torch.load(opt.loadPath)
59 |
60 | -- transfer all options to model
61 | local modelParams = savedModel.modelParams
62 |
63 | opt.imgNorm = modelParams.imgNorm
64 | opt.encoder = modelParams.encoder
65 | opt.decoder = modelParams.decoder
66 | modelParams.gpuid = opt.gpuid
67 | modelParams.batchSize = opt.batchSize
68 | modelParams.useGt = opt.useGt
69 |
70 | -- add flags for various configurations
71 | -- additionally check if it's an imitation of the discriminative model
72 | if string.match(opt.encoder, 'hist') then opt.useHistory = true; end
73 | if string.match(opt.encoder, 'im') then opt.useIm = true; end
74 | -- check if history is to be concatenated (only for late fusion encoder)
75 | if string.match(opt.encoder, 'lf') then opt.concatHistory = true end
76 |
77 | ------------------------------------------------------------------------
78 | -- Loading dataset
79 | ------------------------------------------------------------------------
80 | local dataloader = dofile('dataloader.lua')
81 | dataloader:initialize(opt, {opt.split});
82 | collectgarbage();
83 |
84 | ------------------------------------------------------------------------
85 | -- Setup the model
86 | ------------------------------------------------------------------------
87 | require 'model'
88 | local model = Model(modelParams)
89 |
90 | -- copy the weights from loaded model
91 | model.wrapperW:copy(savedModel.modelW);
92 |
93 | ------------------------------------------------------------------------
94 | -- Evaluation
95 | ------------------------------------------------------------------------
96 | print('Evaluating..')
97 | local ranks;
98 | if opt.useGt then
99 | ranks = model:retrieve(dataloader, opt.split);
100 | else
101 | ranks = model:predict(dataloader, opt.split);
102 | end
103 |
104 | if opt.saveRanks == true then
105 | print(string.format('Writing ranks to %s', opt.saveRankPath));
106 | utils.writeJSON(opt.saveRankPath, ranks);
107 | end
108 |
--------------------------------------------------------------------------------
/generate.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'rnn'
3 | require 'nngraph'
4 | utils = dofile('utils.lua');
5 |
6 | -------------------------------------------------------------------------------
7 | -- Input arguments and options
8 | -------------------------------------------------------------------------------
9 | cmd = torch.CmdLine()
10 | cmd:text()
11 | cmd:text('Generate answers from the VisDial model')
12 | cmd:text()
13 | cmd:text('Options')
14 |
15 | -- Data input settings
16 | cmd:option('-inputImg','data/data_img.h5','path to h5 file with image features')
17 | cmd:option('-inputQues','data/visdial_data.h5','path to h5 file with preprocessed questions')
18 | cmd:option('-inputJson','data/visdial_params.json','path to json file with vocab and other info')
19 |
20 | cmd:option('-loadPath', 'checkpoints/model.t7', 'path to saved model')
21 | cmd:option('-resultPath', 'vis/results', 'path to save generated results')
22 |
23 | -- sampling params
24 | cmd:option('-beamSize', 5, 'Beam size')
25 | cmd:option('-beamLen', 20, 'Beam length')
26 | cmd:option('-sampleWords', 0, 'Whether to sample')
27 | cmd:option('-temperature', 1.0, 'Sampling temperature')
28 | cmd:option('-maxThreads', 50, 'Max threads')
29 | cmd:option('-gpuid', 0, 'GPU id to use')
30 | cmd:option('-backend', 'cudnn', 'nn|cudnn')
31 |
32 | local opt = cmd:parse(arg);
33 | print(opt)
34 |
35 | -- seed for reproducibility
36 | torch.manualSeed(1234);
37 |
38 | -- set default tensor based on gpu usage
39 | if opt.gpuid >= 0 then
40 | require 'cutorch'
41 | require 'cunn'
42 | if opt.backend == 'cudnn' then require 'cudnn' end
43 | cutorch.setDevice(opt.gpuid+1)
44 | cutorch.manualSeed(1234)
45 | torch.setdefaulttensortype('torch.CudaTensor');
46 | else
47 | torch.setdefaulttensortype('torch.FloatTensor');
48 | end
49 |
50 | ------------------------------------------------------------------------
51 | -- Read saved model and parameters
52 | ------------------------------------------------------------------------
53 | local savedModel = torch.load(opt.loadPath)
54 |
55 | -- transfer all options to model
56 | local modelParams = savedModel.modelParams
57 | opt.imgNorm = modelParams.imgNorm
58 | opt.encoder = modelParams.encoder
59 | opt.decoder = modelParams.decoder
60 | modelParams.gpuid = opt.gpuid
61 |
62 | -- add flags for various configurations
63 | -- additionally check if it's an imitation of the discriminative model
64 | if string.match(opt.encoder, 'hist') then
65 | opt.useHistory = true;
66 | end
67 | if string.match(opt.encoder, 'im') then opt.useIm = true; end
68 |
69 | ------------------------------------------------------------------------
70 | -- Loading dataset
71 | ------------------------------------------------------------------------
72 | local dataloader = dofile('dataloader.lua')
73 | dataloader:initialize(opt, {'val'});
74 | collectgarbage();
75 |
76 | ------------------------------------------------------------------------
77 | -- Setup the model
78 | ------------------------------------------------------------------------
79 | require 'model'
80 | local model = Model(modelParams)
81 |
82 | -- copy the weights from loaded model
83 | model.wrapperW:copy(savedModel.modelW);
84 |
85 | ------------------------------------------------------------------------
86 | -- Generating
87 | ------------------------------------------------------------------------
88 | sampleParams = {
89 | beamSize = opt.beamSize,
90 | beamLen = opt.beamLen,
91 | maxThreads = opt.maxThreads,
92 | sampleWords = opt.sampleWords,
93 | temperature = opt.temperature
94 | }
95 |
96 | local answers = model:generateAnswers(dataloader, 'val', sampleParams)
97 | local output = {opts = opt, data = answers}
98 |
99 | -- save the file to json
100 | local savePath = string.format('%s/results.json', opt.resultPath);
101 | paths.mkdir(opt.resultPath)
102 | utils.writeJSON(savePath, output);
103 | print('Writing the results to '.. savePath);
104 |
105 |
--------------------------------------------------------------------------------
/model.lua:
--------------------------------------------------------------------------------
1 | -- abstract class for models
2 | require 'model_utils.optim_updates'
3 | require 'xlua'
4 | require 'hdf5'
5 |
6 | local utils = require 'utils'
7 |
8 | local Model = torch.class('Model');
9 |
10 | -- initialize
11 | function Model:__init(params)
12 | print('Setting up model..')
13 | self.params = params
14 |
15 | print('Encoder: ', params.encoder)
16 | print('Decoder: ', params.decoder)
17 |
18 | -- build the model - encoder, decoder
19 | local encFile = string.format('encoders/%s.lua', params.encoder);
20 | local encoder = dofile(encFile);
21 |
22 | local decFile = string.format('decoders/%s.lua', params.decoder);
23 | local decoder = dofile(decFile);
24 |
25 | enc = encoder.model(params)
26 | dec = decoder.model(params, enc)
27 |
28 | local decMethods = {'forwardConnect', 'backwardConnect', 'decoderConnect'}
29 | for key, value in pairs(decMethods) do self[value] = decoder[value]; end
30 |
31 | -- criterion
32 | if params.decoder == 'gen' then
33 | self.criterion = nn.ClassNLLCriterion();
34 | self.criterion.sizeAverage = false;
35 | self.criterion = nn.SequencerCriterion(
36 | nn.MaskZeroCriterion(self.criterion, 1));
37 | elseif params.decoder == 'disc' then
38 | self.criterion = nn.CrossEntropyCriterion()
39 | end
40 |
41 | -- wrap the models
42 | self.wrapper = nn.Sequential():add(enc):add(dec);
43 |
44 | -- initialize weights
45 | self.wrapper = require('model_utils/weight-init')(self.wrapper, params.weightInit);
46 |
47 | -- ship to gpu if necessary
48 | if params.gpuid >= 0 then
49 | self.wrapper = self.wrapper:cuda();
50 | self.criterion = self.criterion:cuda();
51 | end
52 |
53 | self.encoder = self.wrapper:get(1);
54 | self.decoder = self.wrapper:get(2);
55 | self.wrapperW, self.wrapperdW = self.wrapper:getParameters();
56 |
57 | self.wrapper:training();
58 |
59 | -- setup the optimizer
60 | self.optims = {};
61 | self.optims.learningRate = params.learningRate;
62 | end
63 |
64 | -------------------------------------------------------------------------------
65 | -- One iteration of training -- forward and backward pass
66 | function Model:trainIteration(dataloader)
67 | -- clear the gradients
68 | self.wrapper:zeroGradParameters();
69 |
70 | -- grab a training batch
71 | local batch = dataloader:getTrainBatch(self.params);
72 |
73 | -- call the internal function for model specific actions
74 | local curLoss = self:forwardBackward(batch);
75 |
76 | if self.params.decoder == 'gen' then
77 | -- count the number of tokens
78 | local numTokens = torch.sum(batch['answer_out']:gt(0));
79 |
80 | -- update the running average of loss
81 | if runningLoss > 0 then
82 | runningLoss = 0.95 * runningLoss + 0.05 * curLoss/numTokens;
83 | else
84 | runningLoss = curLoss/numTokens;
85 | end
86 | elseif self.params.decoder == 'disc' then
87 | -- update the running average of loss
88 | if runningLoss > 0 then
89 | runningLoss = 0.95 * runningLoss + 0.05 * curLoss
90 | else
91 | runningLoss = curLoss
92 | end
93 | end
94 |
95 | -- clamp gradients
96 | self.wrapperdW:clamp(-5.0, 5.0);
97 |
98 | -- update parameters
99 | adam(self.wrapperW, self.wrapperdW, self.optims);
100 |
101 | -- decay learning rate, if needed
102 | if self.optims.learningRate > self.params.minLRate then
103 | self.optims.learningRate = self.optims.learningRate *
104 | self.params.lrDecayRate;
105 | end
106 | end
107 | ---------------------------------------------------------------------
108 | -- validation performance on test/val
109 | function Model:evaluate(dataloader, dtype)
110 | -- change to evaluate mode
111 | self.wrapper:evaluate();
112 |
113 | local curLoss = 0;
114 | local startId = 1;
115 | local numThreads = dataloader.numThreads[dtype];
116 |
117 | local numTokens = 0;
118 | while startId <= numThreads do
119 | -- print progress
120 | xlua.progress(startId, numThreads);
121 |
122 | -- grab a validation batch
123 | local batch, nextStartId
124 | = dataloader:getTestBatch(startId, self.params, dtype);
125 | -- count the number of tokens
126 | numTokens = numTokens + torch.sum(batch['answer_out']:gt(0));
127 | -- forward pass to compute loss
128 | curLoss = curLoss + self:forwardBackward(batch, true);
129 | startId = nextStartId;
130 | end
131 |
132 | -- print the results
133 | curLoss = curLoss / numTokens;
134 | print(string.format('\n%s\tLoss: %f\t Perplexity: %f\n', dtype,
135 | curLoss, math.exp(curLoss)));
136 |
137 | -- change back to training
138 | self.wrapper:training();
139 | end
140 |
141 | -- retrieval performance on val
142 | function Model:retrieve(dataloader, dtype)
143 | -- change to evaluate mode
144 | self.wrapper:evaluate();
145 |
146 | local curLoss = 0;
147 | local startId = 1;
148 | self.params.numOptions = 100;
149 | local numThreads = dataloader.numThreads[dtype];
150 | print('numThreads', numThreads)
151 |
152 | local ranks = torch.Tensor(numThreads, self.params.maxQuesCount);
153 | ranks:fill(self.params.numOptions + 1);
154 |
155 | while startId <= numThreads do
156 | -- print progress
157 | xlua.progress(startId, numThreads);
158 |
159 | -- grab a batch
160 | local batch, nextStartId =
161 | dataloader:getTestBatch(startId, self.params, dtype);
162 |
163 | -- Call retrieve function for specific model, and store ranks
164 | ranks[{{startId, nextStartId - 1}, {}}] = self:retrieveBatch(batch);
165 | startId = nextStartId;
166 | end
167 |
168 | print(string.format('\n%s - Retrieval:', dtype))
169 | utils.processRanks(ranks);
170 |
171 | -- change back to training
172 | self.wrapper:training();
173 |
174 | local retrieval = {};
175 | local ranks = torch.totable(ranks:double());
176 | for i = 1, #dataloader['unique_img_'..dtype] do
177 | for j = 1, dataloader[dtype..'_num_rounds'][i] do
178 | table.insert(retrieval, {
179 | image_id = dataloader['unique_img_'..dtype][i];
180 | round_id = j;
181 | ranks = ranks[i][j]
182 | })
183 | end
184 | end
185 | -- collect garbage
186 | collectgarbage();
187 |
188 | return retrieval;
189 | end
190 |
191 | -- prediction on val/test
192 | function Model:predict(dataloader, dtype)
193 | -- change to evaluate mode
194 | self.wrapper:evaluate();
195 |
196 | local curLoss = 0;
197 | local startId = 1;
198 | local numThreads = dataloader.numThreads[dtype];
199 | self.params.numOptions = 100;
200 | print('numThreads', numThreads)
201 |
202 | local ranks = torch.Tensor(numThreads, 10, self.params.numOptions);
203 | ranks:fill(self.params.numOptions + 1);
204 |
205 | while startId <= numThreads do
206 | -- print progress
207 | xlua.progress(startId, numThreads);
208 |
209 | -- grab a batch
210 | local batch, nextStartId =
211 | dataloader:getTestBatch(startId, self.params, dtype);
212 |
213 | -- Call retrieve function for specific model, and store ranks
214 | ranks[{{startId, nextStartId - 1}, {}}] = self:retrieveBatch(batch)
215 | :view(nextStartId - startId, -1, self.params.numOptions);
216 | startId = nextStartId;
217 | end
218 |
219 | -- change back to training
220 | self.wrapper:training();
221 |
222 | local prediction = {};
223 | local ranks = torch.totable(ranks:double());
224 | for i = 1, #dataloader['unique_img_'..dtype] do
225 | -- rank list for all rounds in val split and last round in test split
226 | if dtype == 'test' then
227 | table.insert(prediction, {
228 | image_id = dataloader['unique_img_'..dtype][i];
229 | round_id = dataloader[dtype..'_num_rounds'][i];
230 | ranks = ranks[i][dataloader[dtype..'_num_rounds'][i]]
231 | })
232 | else
233 | for j = 1, dataloader[dtype..'_num_rounds'][i] do
234 | table.insert(prediction, {
235 | image_id = dataloader['unique_img_'..dtype][i];
236 | round_id = j;
237 | ranks = ranks[i][j]
238 | })
239 | end
240 | end
241 | end
242 | -- collect garbage
243 | collectgarbage();
244 |
245 | return prediction;
246 | end
247 |
248 | -- forward + backward pass
249 | function Model:forwardBackward(batch, onlyForward, encOutOnly)
250 | local onlyForward = onlyForward or false;
251 | local encOutOnly = encOutOnly or false
252 | local inputs = {}
253 |
254 | -- transpose for timestep first
255 | local batchQues = batch['ques_fwd']
256 | batchQues = batchQues:view(-1, batchQues:size(3)):t()
257 | table.insert(inputs, batchQues)
258 |
259 | if self.params.useIm == true then
260 | local imgFeats = batch['img_feat']
261 | -- if attention, then conv layer features
262 | if string.match(self.params.encoder, 'att') then
263 | imgFeats = imgFeats:view(-1, 1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize)
264 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1, 1, 1)
265 | imgFeats = imgFeats:view(-1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize)
266 | else
267 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize)
268 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1)
269 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize)
270 | end
271 | table.insert(inputs, imgFeats)
272 | end
273 |
274 | if self.params.useHistory == true then
275 | local batchHist = batch['hist']
276 | batchHist = batchHist:view(-1, batchHist:size(3)):t()
277 | table.insert(inputs, batchHist)
278 | end
279 |
280 | if string.match(self.params.encoder, 'mn') then
281 | local mask = torch.ones(10, 10):byte() -- 10x10 (question round, history fact) mask for the memory-network encoders
282 | for i = 1, 10 do
283 | for j = 1, 10 do
284 | if j <= i then
285 | mask[i][j] = 0
286 | end
287 | end
288 | end
289 | if self.params.gpuid >= 0 then
290 | mask = mask:cuda()
291 | end
292 | local maskRepeat = torch.repeatTensor(mask, batch['hist']:size(1), 1)
293 | table.insert(inputs, maskRepeat)
294 | end
295 |
296 | -- encoder forward pass
297 | local encOut = self.encoder:forward(inputs)
298 |
299 | -- coupled enc-dec (only for gen)
300 | self.forwardConnect(self.encoder, self.decoder, encOut, batchQues:size(1));
301 |
302 | if encOutOnly == true then return encOut end
303 |
304 | -- decoder forward pass
305 | local curLoss = 0
306 | if self.params.decoder == 'gen' then
307 | local answerIn = batch['answer_in'];
308 | answerIn = answerIn:view(-1, answerIn:size(3)):t();
309 |
310 | local answerOut = batch['answer_out'];
311 | answerOut = answerOut:view(-1, answerOut:size(3)):t();
312 |
313 | local decOut = self.decoder:forward(answerIn);
314 | curLoss = self.criterion:forward(decOut, answerOut);
315 |
316 | -- backward pass
317 | if onlyForward == false then
318 | local gradCriterionOut = self.criterion:backward(decOut, answerOut);
319 | self.decoder:backward(answerIn, gradCriterionOut);
320 |
321 | --backward connect decoder and encoder (only for gen)
322 | local gradDecOut = self.backwardConnect(self.encoder, self.decoder);
323 | self.encoder:backward(inputs, gradDecOut)
324 | end
325 | elseif self.params.decoder == 'disc' then
326 | local options = batch['options']
327 | local answerInd = batch['answer_ind']
328 |
329 | local decOut = self.decoder:forward({options, encOut})
330 | curLoss = self.criterion:forward(decOut, answerInd)
331 |
332 | -- backward pass
333 | if onlyForward == false then
334 | local gradCriterionOut = self.criterion:backward(decOut, answerInd)
335 | local t = self.decoder:backward({options, encOut}, gradCriterionOut)
336 |
337 | self.encoder:backward(inputs, t[2])
338 | end
339 | end
340 |
341 | return curLoss;
342 | end
343 |
344 | function Model:retrieveBatch(batch)
345 | local inputs = {}
346 |
347 | local batchQues = batch['ques_fwd'];
348 | batchQues = batchQues:view(-1, batchQues:size(3)):t();
349 | table.insert(inputs, batchQues)
350 |
351 | if self.params.useIm == true then
352 | local imgFeats = batch['img_feat']
353 | -- if attention, then conv layer features
354 | if string.match(self.params.encoder, 'att') then
355 | imgFeats = imgFeats:view(-1, 1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize)
356 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1, 1, 1)
357 | imgFeats = imgFeats:view(-1, self.params.imgSpatialSize, self.params.imgSpatialSize, self.params.imgFeatureSize)
358 | else
359 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize)
360 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1)
361 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize)
362 | end
363 | table.insert(inputs, imgFeats)
364 | end
365 |
366 | if self.params.useHistory == true then
367 | local batchHist = batch['hist']
368 | batchHist = batchHist:view(-1, batchHist:size(3)):t()
369 | table.insert(inputs, batchHist)
370 | end
371 |
372 | if string.match(self.params.encoder, 'mn') then
373 | local mask = torch.ones(10, 10):byte()
374 | for i = 1, 10 do
375 | for j = 1, 10 do
376 | if j <= i then
377 | mask[i][j] = 0
378 | end
379 | end
380 | end
381 | if self.params.gpuid >= 0 then
382 | mask = mask:cuda()
383 | end
384 | local maskRepeat = torch.repeatTensor(mask, batch['hist']:size(1), 1)
385 | table.insert(inputs, maskRepeat)
386 | end
387 |
388 | -- forward pass
389 | local encOut = self.encoder:forward(inputs)
390 | local batchSize = batchQues:size(2);
391 |
392 | if self.params.decoder == 'gen' then
393 | local optionIn = batch['option_in'];
394 | optionIn = optionIn:view(-1, optionIn:size(3), optionIn:size(4));
395 |
396 | local optionOut = batch['option_out'];
397 | optionOut = optionOut:view(-1, optionOut:size(3), optionOut:size(4));
398 | optionIn = optionIn:transpose(1, 2):transpose(2, 3);
399 | optionOut = optionOut:transpose(1, 2):transpose(2, 3);
400 |
401 | -- tensor holds the likelihood for all the options
402 | local optionLhood = torch.Tensor(self.params.numOptions, batchSize);
403 |
404 | -- repeat for each option and get gt rank
405 | for opId = 1, self.params.numOptions do
406 | -- forward connect encoder and decoder
407 | self.forwardConnect(self.encoder, self.decoder, encOut, batchQues:size(1));
408 |
409 | local curOptIn = optionIn[opId];
410 | local curOptOut = optionOut[opId];
411 | local decOut = self.decoder:forward(curOptIn);
412 |
413 | -- compute the probabilities for each answer, based on its tokens
414 | optionLhood[opId] = utils.computeLhood(curOptOut, decOut);
415 | end
416 | -- gtPosition can be nil if ground truth does not exist
417 | local gtPosition = self.params.useGt and batch['answer_ind'] or nil;
418 |
419 | -- return the ranks for this batch
420 | return utils.computeRanks(optionLhood:t(), gtPosition);
421 | elseif self.params.decoder == 'disc' then
422 | local options = batch['options']
423 | local decOut = self.decoder:forward({options, encOut})
424 | local gtPosition = self.params.useGt and batch['answer_ind'] or nil;
425 |
426 | -- return the ranks for this batch
427 | return utils.computeRanks(decOut, gtPosition)
428 | end
429 |
430 | end
431 |
432 | function Model:generateAnswers(dataloader, dtype, params)
433 | -- check decoder
434 | if self.params.decoder == 'disc' then
435 | print('Sampling/beam search only for generative model')
436 | os.exit()
437 | end
438 |
439 | -- setting the options for beam search / sampling
440 | params = params or {};
441 |
442 | -- sample or take max
443 | local sampleWords = params.sampleWords and params.sampleWords == 1 or false;
444 | local temperature = params.temperature or 1.0;
445 | local beamSize = params.beamSize or 5;
446 | local beamLen = params.beamLen or 20;
447 |
448 | print('Beam size', beamSize)
449 | print('Beam length', beamLen)
450 |
451 | -- start/end token indices
452 | local startToken = dataloader.word2ind['<START>'];
453 | local endToken = dataloader.word2ind['<END>'];
454 | local numThreads = params.maxThreads or dataloader.numThreads[dtype];
455 | print('No. of threads', numThreads)
456 |
457 | local answerTable = {}
458 | for convId = 1, numThreads do
459 | xlua.progress(convId, numThreads);
460 | self.wrapper:evaluate()
461 |
462 | local inds = torch.LongTensor(1):fill(convId);
463 | local batch = dataloader:getIndexData(inds, self.params, dtype);
464 | local numQues = batch['ques_fwd']:size(1) * batch['ques_fwd']:size(2);
465 |
466 | local encOut = self:forwardBackward(batch, true, true)
467 | local threadAnswers = {}
468 |
469 | if sampleWords == false then
470 | -- do beam search for each example now
471 | for iter = 1, 10 do
472 | local encInSeq = batch['ques_fwd']:view(-1, batch['ques_fwd']:size(3)):t();
473 | encInSeq = encInSeq[{{},{iter}}]:squeeze():float()
474 |
475 | -- beams
476 | local beams = torch.LongTensor(beamLen, beamSize):zero();
477 |
478 | -- initial hidden states for the beam at current round of dialog
479 | local hiddenBeams = {};
480 | if self.encoder.rnnLayers ~=nil then
481 | for level = 1, #self.encoder.rnnLayers do
482 | if hiddenBeams[level] == nil then hiddenBeams[level] = {} end
483 | hiddenBeams[level]['output'] = self.encoder.rnnLayers[level].output[batch['ques_fwd']:size(3)][iter];
484 | hiddenBeams[level]['cell'] = self.encoder.rnnLayers[level].cell[batch['ques_fwd']:size(3)][iter];
485 | if level == #self.encoder.rnnLayers then
486 | hiddenBeams[#self.encoder.rnnLayers]['output'] = encOut[iter]
487 | end
488 | hiddenBeams[level]['output'] = torch.repeatTensor(hiddenBeams[level]['output'], beamSize, 1);
489 | hiddenBeams[level]['cell'] = torch.repeatTensor(hiddenBeams[level]['cell'], beamSize, 1);
490 | end
491 | -- hiddenBeams[]['cell'] is beam_nums x 512
492 | -- hiddenBeams[]['output'] is beam_nums x 512
493 | else
494 | for level = 1, #self.decoder.rnnLayers do
495 | if hiddenBeams[level] == nil then hiddenBeams[level] = {} end
496 | if level == #self.decoder.rnnLayers then
497 | hiddenBeams[level]['output'] = torch.repeatTensor(encOut[iter], beamSize, 1)
498 | else
499 | hiddenBeams[level]['output'] = torch.Tensor(beamSize, encOut:size(2)):zero()
500 | end
501 | hiddenBeams[level]['cell'] = hiddenBeams[level]['output']:clone():zero()
502 | end
503 | end
504 |
505 | -- for first step, initialize with start symbols
506 | beams[1] = startToken;
507 | scores = torch.DoubleTensor(beamSize):zero();
508 | finishBeams = {}; -- accumulate beams that are done
509 |
510 | for step = 2, beamLen do
511 |
512 | -- candidates for the current iteration
513 | cands = {};
514 |
515 | -- if step == 2, explore only one beam (all beams start with the start token)
516 | local exploreSize = (step == 2) and 1 or beamSize;
517 |
518 | -- first copy the hidden states to the decoder
519 | for level = 1, #self.decoder.rnnLayers do
520 | self.decoder.rnnLayers[level].userPrevOutput = hiddenBeams[level]['output']
521 | self.decoder.rnnLayers[level].userPrevCell = hiddenBeams[level]['cell']
522 | end
523 |
524 | -- decoder forward pass
525 | decOut = self.decoder:forward(beams[{{step-1}}]);
526 | decOut = decOut:squeeze(); -- decOut is beam_nums x vocab_size
527 |
528 | -- iterate separately for each possible word of beam
529 | for wordId = 1, exploreSize do
530 | local curHidden = {};
531 | for level = 1, #self.decoder.rnnLayers do
532 | if curHidden[level] == nil then curHidden[level] = {} end
533 | curHidden[level]['output'] = self.decoder.rnnLayers[level].output[{{1},{wordId}}]:clone():squeeze(); -- rnnLayers[].output is 1 x beam_nums x 512
534 | curHidden[level]['cell'] = self.decoder.rnnLayers[level].cell[{{1},{wordId}}]:clone():squeeze();
535 | end
536 |
537 | -- sort and get the top probabilities
538 | if beamSize == 1 then
539 | topProb, topInd = torch.topk(decOut, beamSize, true);
540 | else
541 | topProb, topInd = torch.topk(decOut[wordId], beamSize, true);
542 | end
543 |
544 | for candId = 1, beamSize do
545 | local candBeam = beams[{{}, {wordId}}]:clone();
546 | -- get the updated cost for each explored candidate, pool
547 | candBeam[step] = topInd[candId];
548 | if topInd[candId] == endToken then
549 | table.insert(finishBeams, {beam = candBeam:double():squeeze(), length = step, score = scores[wordId] + topProb[candId]});
550 | else
551 | table.insert(cands, {score = scores[wordId] + topProb[candId],
552 | beam = candBeam,
553 | hidden = curHidden});
554 | end
555 | end
556 | end
557 |
558 | -- sort the candidates and stick to beam size
559 | table.sort(cands, function (a, b) return a.score > b.score; end);
560 |
561 | for candId = 1, math.min(#cands, beamSize) do
562 | beams[{{}, {candId}}] = cands[candId].beam;
563 |
564 | --recursive copy
565 | for level = 1, #self.decoder.rnnLayers do
566 | hiddenBeams[level]['output'][candId] = cands[candId].hidden[level]['output']:clone();
567 | hiddenBeams[level]['cell'][candId] = cands[candId].hidden[level]['cell']:clone();
568 | end
569 |
570 | scores[candId] = cands[candId].score;
571 | end
572 | end
573 |
574 | table.sort(finishBeams, function (a, b) return a.score > b.score; end);
575 |
576 | local quesWords = encInSeq:double():squeeze()
577 | local ansWords = finishBeams[1].beam:squeeze();
578 |
579 | local quesText = utils.idToWords(quesWords, dataloader.ind2word);
580 | local ansText = utils.idToWords(ansWords, dataloader.ind2word);
581 |
582 | table.insert(threadAnswers, {question = quesText, answer = ansText})
583 | end
584 | else
585 | local answerIn = torch.Tensor(1, numQues):fill(startToken)
586 | local answer = {answerIn:t():double()}
587 | for timeStep = 1, beamLen do
588 | -- one pass through decoder
589 | local decOut = self.decoder:forward(answerIn):squeeze()
590 | -- connect decoder to itself
591 | self.decoderConnect(self.decoder)
592 |
593 | local nextToken = torch.multinomial(torch.exp(decOut / temperature), 1)
594 | table.insert(answer, nextToken:double())
595 | answerIn:copy(nextToken)
596 | end
597 | answer = nn.JoinTable(-1):forward(answer)
598 |
599 | for iter = 1, 10 do
600 | local quesWords = batch['ques_fwd'][{{1}, {iter}, {}}]:squeeze():double()
601 | local ansWords = answer[{{iter}, {}}]:squeeze()
602 |
603 | local quesText = utils.idToWords(quesWords, dataloader.ind2word)
604 | local ansText = utils.idToWords(ansWords, dataloader.ind2word)
605 |
606 | table.insert(threadAnswers, {question = quesText, answer = ansText})
607 | end
608 | end
609 | self.wrapper:training()
610 | table.insert(answerTable, {image_id = dataloader['unique_img_'..dtype][convId], dialog = threadAnswers})
611 | end
612 | return answerTable
613 | end
614 |
615 | return Model;
616 |
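617 | -- Note on the memory-network history mask built in forwardBackward/retrieveBatch:
618 | -- mask[i][j] == 1 marks history round j as a *future* round relative to the
619 | -- current round i (only j <= i are left at 0), so the attention softmax inside
620 | -- the 'mn' encoders never looks at rounds after the current one.
621 | -- e.g. for 3 rounds the mask would be {{0,1,1}, {0,0,1}, {0,0,0}}.
622 | --
623 | -- Note on beam search in generateAnswers: at every step each live beam is
624 | -- expanded with its beamSize most likely next words; candidates that emit the
625 | -- end token are moved to finishBeams, and the rest are re-sorted by cumulative
626 | -- score (sum of word log-probabilities) and pruned back to beamSize.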
--------------------------------------------------------------------------------
/model_utils/MaskFuture.lua:
--------------------------------------------------------------------------------
1 | -- module that masks out future positions (entries above the diagonal are zeroed)
2 | local MaskFuture, Parent = torch.class('nn.MaskFuture', 'nn.Module')
3 |
4 | function MaskFuture:__init(numClasses)
5 | Parent.__init(self);
6 | self.mask = torch.Tensor(1, numClasses, numClasses):fill(1);
7 | -- extract the upper diagonal matrix
8 | self.mask[1] = torch.triu(self.mask[1], 1);
9 | self.gradInput = torch.Tensor();
10 | self.output = torch.Tensor();
11 | end
12 |
13 | function MaskFuture:updateOutput(input)
14 | local batchSize = input:size(1);
15 |
16 | self.output:resizeAs(input):copy(input);
17 | -- expand mask based on input
18 | self.output[self.mask:expandAs(input)] = 0;
19 |
20 | return self.output;
21 | end
22 |
23 | function MaskFuture:updateGradInput(input, gradOutput)
24 | -- gradients flow through unchanged except at masked (future) positions
25 | --self.gradInput:resizeAs(input):zero();
26 | -- zero out the gradients based on the mask
27 | self.gradInput:resizeAs(gradOutput):copy(gradOutput);
28 | self.gradInput[self.mask:expandAs(input)] = 0;
29 |
30 | return self.gradInput;
31 | end
32 |
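33 | -- Usage sketch (hypothetical shapes): for an input of size batch x N x N,
34 | -- nn.MaskFuture(N) zeroes every entry above the diagonal, e.g.
35 | --   local m = nn.MaskFuture(3)
36 | --   local out = m:forward(torch.ones(2, 3, 3))
37 | --   -- each out[k] is lower triangular: entry (i, j) survives only if j <= i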
--------------------------------------------------------------------------------
/model_utils/MaskSoftMax.lua:
--------------------------------------------------------------------------------
1 | -- Author: Jiasen Lu
2 | -- From https://raw.githubusercontent.com/jiasenlu/HieCoAttenVQA/master/misc/maskSoftmax.lua
3 | local MaskSoftMax, _ = torch.class('nn.MaskSoftMax', 'nn.Module')
4 |
5 | function MaskSoftMax:updateOutput(input)
6 | local data = input[1]
7 | local mask = input[2]
8 | if(mask:type() == 'torch.CudaTensor') then
9 | mask = mask:cudaByte()
10 | end
11 |
12 | data:maskedFill(mask, -9999999)
13 | if(mask:type() == 'torch.CudaByteTensor') then
14 | mask = mask:cuda()
15 | end
16 | data.THNN.SoftMax_updateOutput(
17 | data:cdata(),
18 | self.output:cdata()
19 | )
20 | return self.output
21 | end
22 |
23 | function MaskSoftMax:updateGradInput(input, gradOutput)
24 | local data = input[1]
25 | local mask = input[2]
26 | if(mask:type() == 'torch.CudaTensor') then
27 | mask = mask:cudaByte()
28 | end
29 |
30 | data:maskedFill(mask, -9999999)
31 | if(mask:type() == 'torch.CudaByteTensor') then
32 | mask = mask:cuda()
33 | end
34 |
35 | data.THNN.SoftMax_updateGradInput(
36 | data:cdata(),
37 | gradOutput:cdata(),
38 | self.gradInput:cdata(),
39 | self.output:cdata()
40 | )
41 | if not self.dummy_out then
42 | self.dummy_out = mask:clone()
43 | end
44 | self.dummy_out:resizeAs(mask):zero()
45 | return {self.gradInput, self.dummy_out}
46 | end
47 |
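48 | -- Usage sketch (hypothetical shapes): softmax over scores with padded
49 | -- positions masked out (mask == 1 means "ignore this position"):
50 | --   local sm = nn.MaskSoftMax()
51 | --   local scores = torch.randn(2, 5)
52 | --   local mask = torch.ByteTensor(2, 5):zero(); mask[{{}, {4, 5}}] = 1  -- last two positions are padding
53 | --   local probs = sm:forward({scores, mask})  -- masked positions get ~0 probability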
--------------------------------------------------------------------------------
/model_utils/MaskTime.lua:
--------------------------------------------------------------------------------
1 | -- module that expands a feature vector over all timesteps and zeroes it at padded (zero) positions
2 | local MaskTime, Parent = torch.class('nn.MaskTime', 'nn.Module')
3 |
4 | function MaskTime:__init(featSize)
5 | Parent.__init(self);
6 | self.mask = torch.Tensor();
7 | self.seqLen = nil;
8 | self.featSize = featSize;
9 | self.gradInput = {torch.Tensor(), torch.Tensor()};
10 | end
11 |
12 | function MaskTime:updateOutput(input)
13 | local seqLen = input[1]:size(1);
14 | local batchSize = input[1]:size(2);
15 |
16 | -- expand the feature vector
17 | self.output:resizeAs(input[2]):copy(input[2]);
18 | self.output = self.output:view(1, batchSize, self.featSize);
19 | self.output = self.output:repeatTensor(seqLen, 1, 1);
20 |
21 | -- expand the word mask
22 | self.mask = input[1]:eq(0);
23 | self.mask = self.mask:view(seqLen, batchSize, 1)
24 | :expand(seqLen, batchSize, self.featSize);
25 | self.output[self.mask] = 0;
26 |
27 | return self.output;
28 | end
29 |
30 | function MaskTime:updateGradInput(input, gradOutput)
31 | -- the first component is zero gradients
32 | self.gradInput[1]:resizeAs(input[1]):zero();
33 | -- second component has zeroed out gradients
34 | -- sum along first dimension
35 | self.gradInput[2]:resizeAs(gradOutput):copy(gradOutput);
36 | self.gradInput[2][self.mask] = 0;
37 | self.gradInput[2] = self.gradInput[2]:sum(1):squeeze();
38 |
39 | return self.gradInput;
40 | end
41 |
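42 | -- Usage sketch (hypothetical sizes): broadcast a per-example feature over every
43 | -- timestep of a padded word sequence, zeroing the padded steps:
44 | --   local mt = nn.MaskTime(512)
45 | --   local words = torch.LongTensor(20, 4):random(0, 100)  -- 0 marks padding
46 | --   local feats = torch.Tensor(4, 512)
47 | --   local out = mt:forward({words, feats})  -- 20 x 4 x 512, zeroed where words == 0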
--------------------------------------------------------------------------------
/model_utils/ReplaceZero.lua:
--------------------------------------------------------------------------------
1 | -- new module to replace zero with a given value
2 | local ReplaceZero, Parent = torch.class('nn.ReplaceZero', 'nn.Module')
3 |
4 | function ReplaceZero:__init(constant)
5 | Parent.__init(self);
6 | if not constant then
7 | error(' constant must be specified')
8 | end
9 | self.constant = constant;
10 | self.mask = torch.Tensor();
11 | end
12 |
13 | function ReplaceZero:updateOutput(input)
14 | self.output:resizeAs(input):copy(input);
15 | self.mask = input:eq(0);
16 | self.output[self.mask] = self.constant;
17 | return self.output;
18 | end
19 |
20 | function ReplaceZero:updateGradInput(input, gradOutput)
21 | self.gradInput:resizeAs(gradOutput):copy(gradOutput);
22 | -- remove the gradients at those points
23 | self.gradInput[self.mask] = 0;
24 | return self.gradInput;
25 | end
26 |
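27 | -- Usage sketch: replace zero entries (e.g. padding indices) with a constant,
28 | -- while keeping gradients at those positions zero:
29 | --   local rz = nn.ReplaceZero(1)
30 | --   rz:forward(torch.Tensor{0, 3, 0, 5})  --> {1, 3, 1, 5}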
--------------------------------------------------------------------------------
/model_utils/optim_updates.lua:
--------------------------------------------------------------------------------
1 | -- Author: Andrej Karpathy https://github.com/karpathy
2 | -- Project: neuraltalk2 https://github.com/karpathy/neuraltalk2
3 | -- Slightly modified by Xiao Lin for initial values of rmsprop.
4 |
5 | -- optim, simple as it should be, written from scratch. That's how I roll
6 |
7 | function sgd(x, dx, lr)
8 | x:add(-lr, dx)
9 | end
10 |
11 | function sgdm(x, dx, lr, alpha, state)
12 | -- sgd with momentum, standard update
13 | if not state.v then
14 | state.v = x.new(#x):zero()
15 | end
16 | state.v:mul(alpha)
17 | state.v:add(lr, dx)
18 | x:add(-1, state.v)
19 | end
20 |
21 | function sgdmom(x, dx, lr, alpha, state)
22 | -- sgd momentum, uses nesterov update (reference: http://cs231n.github.io/neural-networks-3/#sgd)
23 | if not state.m then
24 | state.m = x.new(#x):zero()
25 | state.tmp = x.new(#x)
26 | end
27 | state.tmp:copy(state.m)
28 | state.m:mul(alpha):add(-lr, dx)
29 | x:add(-alpha, state.tmp)
30 | x:add(1+alpha, state.m)
31 | end
32 |
33 | function adagrad(x, dx, lr, epsilon, state)
34 | if not state.m then
35 | state.m = x.new(#x):zero()
36 | state.tmp = x.new(#x)
37 | end
38 | -- calculate new mean squared values
39 | state.m:addcmul(1.0, dx, dx)
40 | -- perform update
41 | state.tmp:sqrt(state.m):add(epsilon)
42 | x:addcdiv(-lr, dx, state.tmp)
43 | end
44 |
45 | -- rmsprop implementation, simple as it should be
46 | function rmsprop(x, dx, state)
47 | local alpha = state.alpha or 0.99;
48 | local learningRate = state.learningRate or 1e-2;
49 | local epsilon = state.epsilon or 1e-8;
50 | if not state.m then
51 | state.m = x.new(#x):zero()
52 | state.tmp = x.new(#x)
53 | end
54 | -- calculate new (leaky) mean squared values
55 | state.m:mul(alpha)
56 | state.m:addcmul(1.0-alpha, dx, dx)
57 | -- perform update
58 | state.tmp:sqrt(state.m):add(epsilon)
59 | x:addcdiv(-learningRate, dx, state.tmp)
60 | end
61 |
62 | function adam(x, dx, state)
63 | local beta1 = state.beta1 or 0.9
64 | local beta2 = state.beta2 or 0.999
65 | local epsilon = state.epsilon or 1e-8
66 | local lr = state.learningRate or 1e-2;
67 |
68 | if not state.m then
69 | -- Initialization
70 | state.t = 0
71 | -- Exponential moving average of gradient values
72 | state.m = x.new(#dx):zero()
73 | -- Exponential moving average of squared gradient values
74 | state.v = x.new(#dx):zero()
75 | -- A tmp tensor to hold the sqrt(v) + epsilon
76 | state.tmp = x.new(#dx):zero()
77 | end
78 |
79 | -- Decay the first and second moment running average coefficient
80 | state.m:mul(beta1):add(1-beta1, dx)
81 | state.v:mul(beta2):addcmul(1-beta2, dx, dx)
82 | state.tmp:copy(state.v):sqrt():add(epsilon)
83 |
84 | state.t = state.t + 1
85 | local biasCorrection1 = 1 - beta1^state.t
86 | local biasCorrection2 = 1 - beta2^state.t
87 | local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1
88 |
89 | -- perform update
90 | x:addcdiv(-stepSize, state.m, state.tmp)
91 | end
92 |
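93 | -- Usage sketch (names are placeholders): these updates operate in-place on a
94 | -- flattened parameter tensor and its gradient, with per-optimizer state kept
95 | -- in a table across iterations:
96 | --   local params, gradParams = model:getParameters()
97 | --   local optimState = {learningRate = 1e-3, beta1 = 0.9, beta2 = 0.999}
98 | --   adam(params, gradParams, optimState)   -- call once per iteration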
--------------------------------------------------------------------------------
/model_utils/weight-init.lua:
--------------------------------------------------------------------------------
1 | -- From https://raw.githubusercontent.com/e-lab/torch-toolbox/master/Weight-init/weight-init.lua
2 | --
3 | -- Different weight initialization methods
4 | --
5 | -- > model = require('weight-init')(model, 'heuristic')
6 | --
7 | require("nn")
8 |
9 |
10 | -- "Efficient backprop"
11 | -- Yann Lecun, 1998
12 | local function w_init_heuristic(fan_in, fan_out)
13 | return math.sqrt(1/(3*fan_in))
14 | end
15 |
16 |
17 | -- "Understanding the difficulty of training deep feedforward neural networks"
18 | -- Xavier Glorot, 2010
19 | local function w_init_xavier(fan_in, fan_out)
20 | return math.sqrt(2/(fan_in + fan_out))
21 | end
22 |
23 |
24 | -- "Understanding the difficulty of training deep feedforward neural networks"
25 | -- Xavier Glorot, 2010
26 | local function w_init_xavier_caffe(fan_in, fan_out)
27 | return math.sqrt(1/fan_in)
28 | end
29 |
30 |
31 | -- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
32 | -- Kaiming He, 2015
33 | local function w_init_kaiming(fan_in, fan_out)
34 | return math.sqrt(4/(fan_in + fan_out))
35 | end
36 |
37 |
38 | local function w_init(net, arg)
39 | -- choose initialization method
40 | local method = nil
41 | if arg == 'heuristic' then method = w_init_heuristic
42 | elseif arg == 'xavier' then method = w_init_xavier
43 | elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
44 | elseif arg == 'kaiming' then method = w_init_kaiming
45 | else
46 | assert(false)
47 | end
48 |
49 | -- loop over all modules and re-initialize the supported layer types
50 | for i = 1, #net.modules do
51 | local m = net.modules[i]
52 | if m.__typename == 'nn.SpatialConvolution' then
53 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
54 | elseif m.__typename == 'nn.SpatialConvolutionMM' then
55 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
56 | elseif m.__typename == 'nn.LateralConvolution' then
57 | m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
58 | elseif m.__typename == 'nn.VerticalConvolution' then
59 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
60 | elseif m.__typename == 'nn.HorizontalConvolution' then
61 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
62 | elseif m.__typename == 'nn.Linear' then
63 | m:reset(method(m.weight:size(2), m.weight:size(1)))
64 | elseif m.__typename == 'nn.TemporalConvolution' then
65 | m:reset(method(m.weight:size(2), m.weight:size(1)))
66 | end
67 |
68 | if m.bias then
69 | m.bias:zero()
70 | end
71 | end
72 | return net
73 | end
74 |
75 |
76 | return w_init
77 |
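78 | -- Example (minimal sketch; the require path assumes this file is reachable on package.path):
79 | --   local w_init = require 'model_utils.weight-init'
80 | --   local net = nn.Sequential():add(nn.Linear(10, 20)):add(nn.Tanh()):add(nn.Linear(20, 2))
81 | --   net = w_init(net, 'xavier')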
--------------------------------------------------------------------------------
/opts.lua:
--------------------------------------------------------------------------------
1 | cmd = torch.CmdLine()
2 | cmd:text('Train the Visual Dialog model')
3 | cmd:text()
4 | cmd:text('Options')
5 | -- Data input settings
6 | cmd:option('-inputImg', 'data/data_img.h5', 'HDF5 file with image features')
7 | cmd:option('-inputQues', 'data/visdial_data.h5', 'HDF5 file with preprocessed questions')
8 | cmd:option('-inputJson', 'data/visdial_params.json', 'JSON file with info and vocab')
9 | cmd:option('-savePath', 'checkpoints/', 'Path to save checkpoints')
10 | cmd:option('-saveIter', 2, 'Save model checkpoint after every saveIter epochs')
11 |
12 | -- specify encoder/decoder
13 | cmd:option('-encoder', 'lf-ques-hist', 'Name of the encoder to use')
14 | cmd:option('-decoder', 'gen', 'Name of the decoder to use (gen/disc)')
15 | cmd:option('-imgNorm', 1, 'normalize the image feature. 1=yes, 0=no')
16 |
17 | -- model params
18 | cmd:option('-imgEmbedSize', 300, 'Size of the multimodal embedding')
19 | cmd:option('-imgFeatureSize', 4096, 'Channel size of the image feature')
20 | cmd:option('-imgSpatialSize', 14, 'Spatial size of image features (for attention-based encoders).')
21 | cmd:option('-embedSize', 300, 'Size of input word embeddings')
22 | cmd:option('-rnnHiddenSize', 512, 'Size of the LSTM state')
23 | cmd:option('-maxHistoryLen', 60, 'Maximum history to consider when using concatenated QA pairs')
24 | cmd:option('-numLayers', 2, 'Number of layers in LSTM')
25 | cmd:option('-commonEmbeddingSize', 512, 'Common embedding size in MN-ATT-QIH')
26 | cmd:option('-numAttentionLayers', 1, 'No. of attention hops in MN-ATT-QIH')
27 |
28 | cmd:option('-loadPath', '', 'Checkpoint path to load from')
29 |
30 | -- optimization params
31 | cmd:option('-batchSize', 40, 'Batch size (number of threads) (adjust based on GPU memory)')
32 | cmd:option('-learningRate', 1e-3, 'Learning rate')
33 | cmd:option('-weightInit', 'xavier', 'Weight initialization strategy: xavier|heuristic|kaiming')
34 | cmd:option('-dropout', 0.5, 'Dropout')
35 | cmd:option('-numEpochs', 100, 'Epochs')
36 | cmd:option('-LRateDecay', 10, 'After lr_decay epochs lr reduces to 0.1*lr')
37 | cmd:option('-lrDecayRate', 0.9997592083, 'Decay for learning rate')
38 | cmd:option('-minLRate', 5e-5, 'Minimum learning rate')
39 | cmd:option('-gpuid', 0, 'GPU id to use')
40 | cmd:option('-backend', 'cudnn', 'nn|cudnn')
41 |
42 | local opts = cmd:parse(arg);
43 |
44 | -- if save path is not given, use a default timestamped folder
45 | -- get the current time
46 | local curTime = os.date('*t', os.time());
47 | -- create another folder to avoid clutter
48 | local modelPath = string.format('checkpoints/model-%d-%d-%d-%d:%d:%d-%s-%s/',
49 | curTime.month, curTime.day, curTime.year,
50 | curTime.hour, curTime.min, curTime.sec,
51 | opts.encoder, opts.decoder)
52 | if opts.savePath == 'checkpoints/' then opts.savePath = modelPath end;
53 |
54 | -- check for inputs required
55 | if string.match(opts.encoder, 'hist') then opts.useHistory = true end
56 | if string.match(opts.encoder, 'im') then opts.useIm = true end
57 |
58 | -- check if history is to be concatenated (only for late fusion encoder)
59 | if string.match(opts.encoder, 'lf') then opts.concatHistory = true end
60 |
61 | -- attention is always on conv features, not fc7
62 | if string.match(opts.encoder, 'att') then
63 | if opts.inputImg == 'data/data_img.h5' then
64 | opts.inputImg = 'data/data_img_pool5.h5'
65 | end
66 | opts.imgNorm = 0
67 | end
68 |
69 | return opts;
70 |
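71 | -- Example invocations (all flags are defined above):
72 | --   th train.lua -encoder lf-ques-im-hist -decoder disc -gpuid 0
73 | --   th train.lua -encoder hre-ques-im-hist -decoder gen -batchSize 20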
--------------------------------------------------------------------------------
/scripts/download_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | print_usage () {
4 | echo "Usage: download_model.sh [vgg|resnet] [layers]"
5 | echo "For vgg, 'layers' can be one of {16, 19}"
6 | echo "For resnet, 'layers' can be one of {18, 34, 50, 101, 152, 200}"
7 | }
8 |
9 | if [ "$1" = "vgg" ]
10 | then
11 | if [ "$2" = "16" ]
12 | then
13 | mkdir -p data/models/vgg16
14 | cd data/models/vgg16
15 | wget https://gist.githubusercontent.com/ksimonyan/211839e770f7b538e2d8/raw/ded9363bd93ec0c770134f4e387d8aaaaa2407ce/VGG_ILSVRC_16_layers_deploy.prototxt
16 | wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel
17 | elif [ "$2" = "19" ]
18 | then
19 | mkdir -p data/models/vgg19
20 | cd data/models/vgg19
21 | wget https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/f43eeefc869d646b449aa6ce66f87bf987a1c9b5/VGG_ILSVRC_19_layers_deploy.prototxt
22 | wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel
23 | else
24 | print_usage
25 | fi
26 | elif [ "$1" = "resnet" ]
27 | then
28 | if echo "18 34 50 101 152 200" | grep -w $2 > /dev/null
29 | then
30 | mkdir -p data/models/resnet
31 | cd data/models/resnet
32 | wget https://d2j0dndfm35trm.cloudfront.net/resnet-$2.t7
33 | else
34 | print_usage
35 | fi
36 | else
37 | print_usage
38 | fi
39 |
40 | cd ../../..
41 |
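42 | # Examples (run from the repository root):
43 | #   sh scripts/download_model.sh vgg 16
44 | #   sh scripts/download_model.sh resnet 152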
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'nngraph'
3 | require 'rnn'
4 |
5 | ------------------------------------------------------------------------
6 | -- Input arguments and options
7 | ------------------------------------------------------------------------
8 | local opt = require 'opts';
9 | print(opt)
10 |
11 | -- seed for reproducibility
12 | torch.manualSeed(1234);
13 |
14 | -- set default tensor based on gpu usage
15 | if opt.gpuid >= 0 then
16 | require 'cutorch'
17 | require 'cunn'
18 | if opt.backend == 'cudnn' then require 'cudnn' end
19 | cutorch.setDevice(opt.gpuid+1)
20 | cutorch.manualSeed(1234)
21 | torch.setdefaulttensortype('torch.CudaTensor');
22 | else
23 | torch.setdefaulttensortype('torch.FloatTensor');
24 | end
25 |
26 | -- transfer all options to model
27 | local modelParams = opt;
28 |
29 | ------------------------------------------------------------------------
30 | -- Read saved model and parameters
31 | ------------------------------------------------------------------------
32 | local savedModel = false;
33 | if opt.loadPath ~= '' then
34 | savedModel = torch.load(opt.loadPath);
35 | modelParams = savedModel.modelParams;
36 |
37 | opt.imgNorm = modelParams.imgNorm;
38 | opt.encoder = modelParams.encoder;
39 | opt.decoder = modelParams.decoder;
40 | modelParams.gpuid = opt.gpuid;
41 | modelParams.batchSize = opt.batchSize;
42 | end
43 |
44 | ------------------------------------------------------------------------
45 | -- Loading dataset
46 | ------------------------------------------------------------------------
47 | local dataloader = dofile('dataloader.lua');
48 | dataloader:initialize(opt, {'train'});
49 | collectgarbage();
50 |
51 | ------------------------------------------------------------------------
52 | -- Setting model parameters
53 | ------------------------------------------------------------------------
54 | -- transfer parameters from dataloader to model
55 | paramNames = {'numTrainThreads', 'numTestThreads', 'numValThreads',
56 | 'vocabSize', 'maxQuesCount', 'maxQuesLen', 'maxAnsLen'};
57 | for _, value in pairs(paramNames) do
58 | modelParams[value] = dataloader[value];
59 | end
60 |
61 | -- path to save the model
62 | local modelPath = opt.savePath
63 |
64 | -- creating the directory to save the model
65 | paths.mkdir(modelPath);
66 |
67 | -- Iterations per epoch
68 | modelParams.numIterPerEpoch = math.ceil(modelParams.numTrainThreads /
69 | modelParams.batchSize);
70 | print(string.format('\n%d iter per epoch.', modelParams.numIterPerEpoch));
71 |
72 | ------------------------------------------------------------------------
73 | -- Setup the model
74 | ------------------------------------------------------------------------
75 | require 'model'
76 | local model = Model(modelParams);
77 |
78 | if opt.loadPath ~= '' then
79 | model.wrapperW:copy(savedModel.modelW);
80 | model.optims.learningRate = savedModel.optims.learningRate;
81 | end
82 |
83 | ------------------------------------------------------------------------
84 | -- Training
85 | ------------------------------------------------------------------------
86 | print('Training..')
87 | collectgarbage()
88 |
89 | runningLoss = 0;
90 | for iter = 1, modelParams.numEpochs * modelParams.numIterPerEpoch do
91 | -- forward and backward propagation
92 | model:trainIteration(dataloader);
93 |
94 | -- evaluate on val and save model
95 | if iter % (modelParams.saveIter * modelParams.numIterPerEpoch) == 0 then
96 | local currentEpoch = iter / modelParams.numIterPerEpoch
97 |
98 | -- save model and optimization parameters
99 | torch.save(string.format(modelPath .. 'model_epoch_%d.t7', currentEpoch),
100 | {modelW = model.wrapperW,
101 | optims = model.optims,
102 | modelParams = modelParams})
103 | -- validation accuracy
104 | -- model:retrieve(dataloader, 'val');
105 | end
106 |
107 | -- print after every few iterations
108 | if iter % 100 == 0 then
109 | local currentEpoch = iter / modelParams.numIterPerEpoch;
110 |
111 | -- print current time, running average, learning rate, iteration, epoch
112 | print(string.format('[%s][Epoch:%.02f][Iter:%d][Loss:%.05f][lr:%f]',
113 | os.date(), currentEpoch, iter, runningLoss,
114 | model.optims.learningRate))
115 | end
116 | if iter % 10 == 0 then collectgarbage(); end
117 | end
118 |
119 | -- Saving the final model
120 | torch.save(modelPath .. 'model_final.t7', {modelW = model.wrapperW:float(),
121 | modelParams = modelParams});
122 |
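123 | -- To resume training from a checkpoint saved above, pass -loadPath
124 | -- (the path below is illustrative):
125 | --   th train.lua -loadPath checkpoints/<run-dir>/model_epoch_2.t7 -gpuid 0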
--------------------------------------------------------------------------------
/utils.lua:
--------------------------------------------------------------------------------
1 | -- script containing supporting code/methods
2 | local utils = {};
3 | cjson = require 'cjson'
4 |
5 | -- right align the question tokens in 3d volume
6 | function utils.rightAlign(sequences, lengths)
7 | -- clone the sequences
8 | local rAligned = sequences:clone():fill(0);
9 | local numDims = sequences:dim();
10 |
11 | if numDims == 3 then
12 | local M = sequences:size(3); -- maximum length of question
13 | local numImgs = sequences:size(1); -- number of images
14 | local maxCount = sequences:size(2); -- number of questions / image
15 |
16 | for imId = 1, numImgs do
17 | for quesId = 1, maxCount do
18 | -- do only for non zero sequence counts
19 | if lengths[imId][quesId] == 0 then
20 | break;
21 | end
22 |
23 | -- copy based on the sequence length
24 | rAligned[imId][quesId][{{M - lengths[imId][quesId] + 1, M}}] =
25 | sequences[imId][quesId][{{1, lengths[imId][quesId]}}];
26 | end
27 | end
28 | else if numDims == 2 then
29 | -- handle 2 dimensional matrices as well
30 | local M = sequences:size(2); -- maximum length of question
31 | local numImgs = sequences:size(1); -- number of images
32 |
33 | for imId = 1, numImgs do
34 | -- do only for non zero sequence counts
35 | if lengths[imId] > 0 then
36 | -- copy based on the sequence length
37 | rAligned[imId][{{M - lengths[imId] + 1, M}}] =
38 | sequences[imId][{{1, lengths[imId]}}];
39 | end
40 | end
41 | end
42 | end
43 |
44 | return rAligned;
45 | end
46 |
47 | -- translate a given tensor/table to sentence
48 | function utils.idToWords(vector, ind2word)
49 | local sentence = '';
50 |
51 | local nextWord;
52 | for wordId = 1, vector:size(1) do
53 | if vector[wordId] > 0 then
54 | nextWord = ind2word[vector[wordId]];
55 | sentence = sentence..' '..nextWord;
56 | end
57 |
58 | -- stop once the end token is reached
59 | if nextWord == '<END>' then break; end
60 | end
61 |
62 | return sentence;
63 | end
64 |
65 | -- read a JSON file into a lua table
66 | function utils.readJSON(fileName)
67 | local file = io.open(fileName, 'r');
68 | local text = file:read();
69 | file:close();
70 |
71 | -- convert and save information
72 | return cjson.decode(text);
73 | end
74 |
75 | -- save a lua table to a JSON file
76 | function utils.writeJSON(fileName, luaTable)
77 | -- serialize lua table
78 | local text = cjson.encode(luaTable)
79 |
80 | local file = io.open(fileName, 'w');
81 | file:write(text);
82 | file:close();
83 | end
84 |
85 | -- compute the likelihood given the gt words and predicted probabilities
86 | function utils.computeLhood(words, predProbs)
87 | -- compute the probabilities for each answer, based on its tokens
88 | -- convert to 2d matrix
89 | local predVec = predProbs:view(-1, predProbs:size(3));
90 | local indices = words:contiguous():view(-1, 1);
91 | local mask = indices:eq(0);
92 | -- assign proxy values to avoid 0 index errors
93 | indices[mask] = 1;
94 | local logProbs = predVec:gather(2, indices);
95 | -- neutralize other values
96 | logProbs[mask] = 0;
97 | logProbs = logProbs:viewAs(words);
98 | -- sum up for each sentence
99 | logProbs = logProbs:sum(1):squeeze();
100 |
101 | return logProbs;
102 | end
103 |
104 | -- process the scores and obtain the ranks
105 | -- input: scores for all options, ground truth positions
106 | function utils.computeRanks(scores, gtPos)
107 | -- sort in descending order - largest score gets highest rank
108 | local sorted, rankedIdx = scores:sort(2, true)
109 |
110 | -- convert from ranked_idx to ranks
111 | local ranks = rankedIdx:clone():fill(0)
112 | for i = 1, rankedIdx:size(1) do
113 | for j = 1, rankedIdx:size(2) do
114 | ranks[{i, rankedIdx[{i, j}]}] = j
115 | end
116 | end
117 |
118 | if gtPos then
119 | gtPos = gtPos:view(-1)
120 | local gtRanks = torch.LongTensor(gtPos:size(1))
121 | for i = 1, gtPos:size(1) do
122 | gtRanks[i] = ranks[{i, gtPos[i]}]
123 | end
124 | ranks = gtRanks
125 | end
126 |
127 | return ranks:double()
128 | end
129 |
130 | -- process the ranks and print metrics
131 | function utils.processRanks(ranks)
132 | -- print the results
133 | local numQues = ranks:size(1) * ranks:size(2);
134 |
135 | local numOptions = 100;
136 |
137 | -- convert ranks to double, vector and remove zeros
138 | ranks = ranks:double():view(-1);
139 | -- none of the values should be 0, since the ground truth is among the options
140 | if torch.sum(ranks:le(0)) > 0 then
141 | numZero = torch.sum(ranks:le(0));
142 | print(string.format('Warning: some of ranks are zero : %d', numZero))
143 | ranks = ranks[ranks:gt(0)];
144 | end
145 |
146 | if torch.sum(ranks:ge(numOptions + 1)) > 0 then
147 | numGreater = torch.sum(ranks:ge(numOptions + 1));
148 | print(string.format('Warning: some of ranks >100 : %d', numGreater))
149 | ranks = ranks[ranks:le(numOptions + 1)];
150 | end
151 |
152 | ------------------------------------------------
153 | print(string.format('\tNo. questions: %d', numQues))
154 | print(string.format('\tr@1: %f', torch.sum(torch.le(ranks, 1))/numQues))
155 | print(string.format('\tr@5: %f', torch.sum(torch.le(ranks, 5))/numQues))
156 | print(string.format('\tr@10: %f', torch.sum(torch.le(ranks, 10))/numQues))
157 | print(string.format('\tmedianR: %f', torch.median(ranks:view(-1))[1]))
158 | print(string.format('\tmeanR: %f', torch.mean(ranks)))
159 | print(string.format('\tmeanRR: %f', torch.mean(ranks:cinv())))
160 | end
161 |
162 | return utils;
163 |
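164 | -- Worked example for utils.rightAlign (2-D case): with sequences = {{5, 7, 0, 0}}
165 | -- and lengths = {2}, the non-zero tokens are shifted to the right edge,
166 | -- giving {{0, 0, 5, 7}}.
167 | --
168 | -- Worked example for utils.computeRanks: scores are sorted per row in descending
169 | -- order, so the option with the highest score gets rank 1; when gtPos is given,
170 | -- only the rank of the ground-truth option is returned (e.g. scores {0.1, 0.9, 0.4}
171 | -- with gtPos = 2 yields rank 1).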
--------------------------------------------------------------------------------
/vis/index.html:
--------------------------------------------------------------------------------
1 | <!-- markup lost in extraction; recoverable content: an HTML page titled
2 | "VisDial Results" that loads static/bootstrap.min.css, static/jquery-3.2.1.min.js
3 | and static/main.js, and provides the #heading and #main elements that main.js fills -->
--------------------------------------------------------------------------------
/vis/static/main.js:
--------------------------------------------------------------------------------
1 | (function(){
2 | function pad(num, size) {
3 | var s = "000000000000" + num;
4 | return s.substr(s.length-size);
5 | }
6 | if (!String.prototype.format) {
7 | String.prototype.format = function() {
8 | var args = arguments;
9 | return this.replace(/{(\d+)}/g, function(match, number) {
10 | return typeof args[number] != 'undefined'
11 | ? args[number]
12 | : match
13 | ;
14 | });
15 | };
16 | }
17 | $.get('results/results.json', function(data) {
18 | var image_root = "https://vision.ece.vt.edu/mscoco/images/val2014/";
19 | if (data.opts.sampleWords == 0)
20 | $('#heading').html('Encoder: ' + data.opts.encoder
21 | + ', Decoder: ' + data.opts.decoder + ', Beam size: ' + data.opts.beamSize + ', Max beam length: ' + data.opts.beamLen);
22 | else
23 | $('#heading').html('Encoder: ' + data.opts.encoder
24 | + ', Decoder: ' + data.opts.decoder + ', Temperature: ' + data.opts.temperature);
25 |
26 | var html = '';
27 | for (var i in data.data) {
28 | if (i % 4 == 0)
29 | html += "<div class='row'>"  // NOTE: tags in the strings below were lost; this markup is an approximate reconstruction
30 | html += "<div class='col-md-3'>"// + data.data[i].image_id
31 | html += "<img src='{0}COCO_val2014_{1}.jpg' class='img-responsive'>".format(image_root, pad(parseInt(data.data[i].image_id), 12))
32 | html += "<div class='dialog'>"
33 | for (var j = 0; j < 10; j++) {
34 | html += "<p>" + data.data[i].dialog[j].question + "<br>" + data.data[i].dialog[j].answer + "</p>"
35 | }
36 | html += "</div></div>"
37 | if (i % 4 == 3)
38 | html += "</div>"
39 | }
40 | $('#main').html(html);
41 | })
42 | })();
43 |
--------------------------------------------------------------------------------