├── .gitignore
├── LICENSE
├── README.md
├── Training IRLC.ipynb
├── Training SoftCount.ipynb
├── Visualize IRLC.ipynb
├── config.py
├── dataset.py
├── model.py
├── tools
├── compute_softscore.py
├── create_dictionary.py
├── create_how_many_qa_dataset.py
├── detection_features_converter.py
├── download.sh
├── download_hmqa.sh
├── process.sh
└── process_hmqa.sh
└── vis
├── 00-selection_image-335.png
├── 00-selection_image-364.png
├── 01-selection_image-335.png
├── 01-selection_image-364.png
├── 02-selection_image-335.png
├── 02-selection_image-364.png
├── 03-selection_image-335.png
├── 04-selection_image-335.png
├── image_candidates-335.png
├── image_candidates-364.png
├── orig_image-335.png
└── orig_image-364.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | saved_models/
3 | .idea
4 | data
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sanyam Agarwal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # irlc-vqa
2 | Code for **[Interpretable Counting for Visual Question Answering](https://arxiv.org/pdf/1712.08697.pdf)**, implemented for the ICLR 2018 reproducibility challenge.
3 |
4 | ## About
5 | The paper improves upon the state-of-the-art accuracy for counting-based questions in VQA. It does so by enforcing the prior that each count corresponds to a well-defined region in the image rather than being diffused all over it. The model hard-attends over a fixed set of candidate regions (taken from a pre-trained Faster R-CNN network) by fusing them with information from the question. It is trained with a variant of REINFORCE, Self-Critical Training, which is well suited to generating sequences.
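To make the training signal concrete, here is a minimal, hedged sketch of a self-critical REINFORCE update in which a greedy rollout serves as the baseline. The function name and the simple negative-absolute-error reward are illustrative only and are not the paper's exact formulation:

```python
import torch

def self_critical_loss(log_probs, sampled_count, greedy_count, gt_count):
    """Illustrative self-critical policy-gradient loss.

    log_probs:     (B, T) log-probabilities of the sampled selection actions
    sampled_count: (B,)   count obtained by sampling from the policy
    greedy_count:  (B,)   count obtained by greedy decoding (the baseline)
    gt_count:      (B,)   ground-truth count
    """
    # Reward: higher is better; a simple negative absolute error stands in
    # for the paper's reward here.
    reward_sample = -(sampled_count.float() - gt_count.float()).abs()
    reward_greedy = -(greedy_count.float() - gt_count.float()).abs()
    advantage = (reward_sample - reward_greedy).detach()
    # REINFORCE with the greedy rollout as baseline: raise the log-probability
    # of sampled action sequences that beat the greedy decode.
    return -(advantage.unsqueeze(1) * log_probs).sum(dim=1).mean()
```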
6 |
7 | I found the paper quite interesting. Since I could not find any publicly available implementation of it, I decided to implement it as a self-exercise.
8 |
9 |
10 | ## Results (without caption grounding)
11 |
12 | #### SoftCount
13 | | Model | Test Accuracy | Test RMSE | Training Time |
14 | | --- | --- | --- | --- |
15 | | Reported | 49.2 | 2.45 | Unknown |
16 | | This implementation | **49.7** | **2.31** | ~12 minutes (Nvidia-1080 Ti) |
17 |
18 | #### IRLC
19 | | Model | Test Accuracy | Test RMSE | Training Time |
20 | | --- | --- | --- | --- |
21 | | Reported | **56.1** | 2.45 | Unknown |
22 | | This implementation | 55.7* | **2.41** | ~6 hours (Nvidia-1080 Ti) |
23 |
24 | *= Still improving. Work in Progress.
25 |
26 | The **accuracy** was calculated using the [VQA evaluation metric](http://www.visualqa.org/evaluation.html). I used the exact same script for calculating "soft score" as in https://github.com/hengyuan-hu/bottom-up-attention-vqa.
27 |
28 | **RMSE** = root mean squared error with respect to the ground-truth count (see below for how the ground truth was chosen for VQA).
29 |
30 | **Note**: These numbers correspond to the test accuracy and RMSE at the epoch where the accuracy on the development set was highest. The peak test accuracy is usually higher by about a percent.
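For concreteness, here is a minimal per-question sketch of how both metrics are computed in this implementation, mirroring the `evaluate` function in the training notebooks. `count2score` is the per-question tensor mapping each count in [0, 20] to its VQA soft score:

```python
def count_metrics(pred, count, count2score):
    """pred: raw predicted count (1-element float tensor); count: ground-truth
    count; count2score: length-21 tensor of per-count VQA soft scores."""
    nearest_pred = (pred + 0.5).long().clamp(0, 20)    # round and clamp to [0, 20]
    acc = count2score[nearest_pred]                    # per-question VQA accuracy
    se = (count.float() - nearest_pred.float()) ** 2   # squared error; RMSE = sqrt of the mean over the split
    return acc, se
```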
31 |
32 |
33 | ## Key differences from the paper
34 | - A GRU was used instead of an LSTM for generating question embeddings. Experiments with an LSTM led to slower learning and more over-fitting; more hyper-parameter search would be needed to fix this.
35 |
36 | - The Gated Tanh Unit is not used. Instead, a 2-layer Leaky ReLU network (with slight modifications) inspired by https://github.com/hengyuan-hu/bottom-up-attention-vqa is used.
37 |
38 |
39 | ## Filling in missing details in the paper
40 |
41 | #### VQA Ground Truth
42 | I couldn't find any annotations for the "single ground truth" that is required to calculate the REINFORCE reward in IRLC, nor any details in the paper about this issue. So I took as ground truth the label that was reported as the answer the most times. If there is more than one such label, the one with the smallest numerical value is picked (this might explain a lower RMSE); a minimal sketch of this rule is shown below.
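The helper name below is hypothetical and not part of this codebase; it only illustrates the selection rule described above:

```python
from collections import Counter

def pick_ground_truth(annotator_counts):
    """annotator_counts: list of integer counts given by the annotators for one
    question. Returns the most frequent count, breaking ties by the smallest value."""
    freq = Counter(annotator_counts)
    best = max(freq.values())
    return min(count for count, f in freq.items() if f == best)
```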
43 |
44 | #### Number of epochs
45 | The authors mention that they use early stopping based on the development-set accuracy, but I couldn't find an exact criterion for when to stop. So training is run for 100 epochs for IRLC and 20 epochs for SoftCount.
46 |
47 | #### Number of candidate objects
48 | I could not find the value of N, the number of candidate objects taken from the Faster R-CNN, so following https://github.com/hengyuan-hu/bottom-up-attention-vqa I took N = 36.
49 |
50 | ## Minor discrepancies
51 |
52 | #### Number of images due to Visual Genome
53 | From Table 1 in the paper, it would seem that adding the extra data from Visual Genome doesn't change the number of training images (31932). However, while writing the dataloaders for Visual Genome I counted around 45k training images after including it. This is not a big issue, but I note it here so that others can avoid wasting time investigating it.
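For reference, the distinct-image count can be checked with a one-liner mirroring a cell in `Training SoftCount.ipynb` (this assumes `hmqa_train_dset` has already been built as in that notebook):

```python
# Number of distinct training images after merging the VQA and Visual Genome entries
print(len({entry["image_id"] for entry in hmqa_train_dset.entries}))  # ~45k here vs. 31932 in the paper's Table 1
```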
54 |
55 | ## Other Implementation Details
56 |
57 | - This implementation borrows most of its pre-processing and data loading code from https://github.com/hengyuan-hu/bottom-up-attention-vqa
58 |
59 | - The optional "caption grounding" step was skipped since it only improved the accuracy by a percent or so and was not the main focus of the paper.
60 |
61 | - The authors didn't mention the amount of dropout they used. After trying a few values, 0.5 was chosen.
62 |
63 | - The number of samples was set to 32 (instead of 5 as mentioned in the paper) because it was the largest value for which training time did not suffer significantly. The effect of the sample size on accuracy was not tested.
64 |
65 | - All other parameters were kept the same: the optimizer, learning rate, learning-rate schedule, etc. are exactly as described in the paper.
66 |
67 | ## Usage
68 | #### Prerequisites
69 | Make sure you are on a machine with an NVIDIA GPU, Python 3, and about 100 GB of disk space. Python 2 might be required for running some scripts in ./tools (I will try to fix this soon).
70 |
71 | #### Installation
72 | - Install PyTorch v0.4 with CUDA and Python 3.5.
73 | - Install h5py.
74 |
75 | #### Data Setup
76 | All data should be downloaded to a `data/` directory in the root directory of this repository.
77 |
78 | The easiest way to download the data is to run the provided scripts `tools/download.sh` and then `tools/download_hmqa.sh` from the repository root. If a script does not work, it should be easy to examine it and adapt the steps it outlines to your needs. Then run `tools/process.sh` and `tools/process_hmqa.sh` from the repository root to process the data into the correct format. Some of the scripts invoked by `tools/process.sh` might require Python 2 (I am working on fixing this).
79 |
80 | #### Training
81 | Simply execute the cells in the IPython notebook `Training IRLC.ipynb` to start training. The development and testing scores will be printed every epoch. The model is saved every 10 epochs under the `saved_models` directory.
82 |
83 | #### Visualization
84 | Simply follow the `Visualize IRLC.ipynb` notebook. For this you will need to download MS-COCO images.
85 |
86 | Here are sample visualizations:
87 |
88 | Ques: How many sheepskin are grazing?
89 |
90 | Ans: 4
91 |
92 | Pred: 4
93 |
94 |
Original image | Candidate objects in the image
95 |
96 |
97 |
98 | IRLC assigns a probability to each candidate: dark blue boxes are more probable, faint blue boxes less probable. IRLC then picks the most probable candidates first. The picked objects are shown in bright red boxes.
99 |
100 | timestep=0 timestep=1
101 |
102 | timestep=2 timestep=3
103 |
104 | timestep=4
105 |
106 |
107 |
108 | ## Acknowledgements
109 | The repository https://github.com/hengyuan-hu/bottom-up-attention-vqa was a huge help. It would have taken me at least a week to write all the data pre-processing code myself. A big thanks to the authors of that repository!
110 |
--------------------------------------------------------------------------------
/Training SoftCount.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import os\n",
20 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
21 | "# os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\""
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 3,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import torch\n",
31 | "from torch.autograd import Variable\n",
32 | "\n",
33 | "from dataset import Dictionary, HMQAFeatureDataset\n",
34 | "from model import SoftCount\n",
35 | "from config import *\n",
36 | "from datetime import datetime, timedelta\n",
37 | "\n",
38 | "import h5py\n",
39 | "import numpy as np\n",
40 | "import _pickle as pkl\n",
41 | "import json\n",
42 | "import torch.nn.functional as F"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 4,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "loading dictionary from data/dictionary.pkl\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "dictionary = Dictionary.load_from_file('data/dictionary.pkl')"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 6,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "%%time\n",
69 | "print('loading features from train hdf5 file')\n",
70 | "train_h5_loc = './data/train36.hdf5'\n",
71 | "with h5py.File(train_h5_loc, 'r') as hf:\n",
72 | " train_image_features = np.array(hf.get('image_features'))\n",
73 | " train_spatials_features = np.array(hf.get('spatial_features'))\n",
74 | "# np.save( open(\"/tmp/vqa/train_image_features\", \"wb\"), train_image_features)\n",
75 | "# np.save( open(\"/tmp/vqa/train_spatials_features\", \"wb\"), train_spatials_features)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 7,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "name": "stdout",
85 | "output_type": "stream",
86 | "text": [
87 | "CPU times: user 80 ms, sys: 11.7 s, total: 11.8 s\n",
88 | "Wall time: 4min 25s\n"
89 | ]
90 | }
91 | ],
92 | "source": [
93 | "# %%time\n",
94 | "# train_image_features = np.load(open(\"/tmp/vqa/train_image_features\", \"rb\"))\n",
95 | "# train_spatials_features = np.load(open(\"/tmp/vqa/train_spatials_features\", \"rb\"))"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 8,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "from dataset import HMQAFeatureDataset\n",
105 | "\n",
106 | "hmqa_train_dset = HMQAFeatureDataset(\n",
107 | " img_id2hqma_idx = pkl.load(open(\"./data/train36_imgid2idx.pkl\", \"rb\")),\n",
108 | " image_features = train_image_features, \n",
109 | " spatial_features = train_spatials_features, \n",
110 | " qid2count = json.load(open(\"./data/how_many_qa/qid2count.json\", \"rb\")), \n",
111 | " qid2count2score = json.load(open(\"./data/how_many_qa/qid2count2score.json\", \"rb\")), \n",
112 | " name=\"train\", \n",
113 | " dictionary=dictionary\n",
114 | ")\n",
115 | "del HMQAFeatureDataset"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 9,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "data": {
125 | "text/plain": [
126 | "83642"
127 | ]
128 | },
129 | "execution_count": 9,
130 | "metadata": {},
131 | "output_type": "execute_result"
132 | }
133 | ],
134 | "source": [
135 | "len(hmqa_train_dset)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 10,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "data": {
145 | "text/plain": [
146 | "45546"
147 | ]
148 | },
149 | "execution_count": 10,
150 | "metadata": {},
151 | "output_type": "execute_result"
152 | }
153 | ],
154 | "source": [
155 | "len(set([x[\"image_id\"] for x in hmqa_train_dset.entries]))"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 11,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "%%time\n",
165 | "print('loading features from val hdf5 file')\n",
166 | "val_h5_loc = './data/val36.hdf5'\n",
167 | "with h5py.File(val_h5_loc, 'r') as hf:\n",
168 | " val_image_features = np.array(hf.get('image_features'))\n",
169 | " val_spatials_features = np.array(hf.get('spatial_features'))\n",
170 | "# np.save( open(\"/tmp/vqa/val_image_features\", \"wb\"), val_image_features)\n",
171 | "# np.save( open(\"/tmp/vqa/val_spatials_features\", \"wb\"), val_spatials_features)"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 12,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "name": "stdout",
181 | "output_type": "stream",
182 | "text": [
183 | "CPU times: user 44 ms, sys: 22.7 s, total: 22.7 s\n",
184 | "Wall time: 1min 9s\n"
185 | ]
186 | }
187 | ],
188 | "source": [
189 | "# %%time\n",
190 | "# val_image_features = np.load(open(\"/tmp/vqa/val_image_features\", \"rb\"))\n",
191 | "# val_spatials_features = np.load(open(\"/tmp/vqa/val_spatials_features\", \"rb\"))"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 13,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "# len(train_image_features)\n",
201 | "\n",
202 | "from dataset import HMQAFeatureDataset\n",
203 | "\n",
204 | "hmqa_dev_dset = HMQAFeatureDataset(\n",
205 | " img_id2hqma_idx = pkl.load(open(\"./data/val36_imgid2idx.pkl\", \"rb\")),\n",
206 | " image_features = val_image_features, \n",
207 | " spatial_features = val_spatials_features, \n",
208 | " qid2count = json.load(open(\"./data/how_many_qa/qid2count.json\", \"rb\")), \n",
209 | " qid2count2score = json.load(open(\"./data/how_many_qa/qid2count2score.json\", \"rb\")), \n",
210 | " name=\"dev\", \n",
211 | " dictionary=dictionary\n",
212 | ")\n",
213 | "\n",
214 | "hmqa_test_dset = HMQAFeatureDataset(\n",
215 | " img_id2hqma_idx = pkl.load(open(\"./data/val36_imgid2idx.pkl\", \"rb\")),\n",
216 | " image_features = val_image_features, \n",
217 | " spatial_features = val_spatials_features, \n",
218 | " qid2count = json.load(open(\"./data/how_many_qa/qid2count.json\", \"rb\")), \n",
219 | " qid2count2score = json.load(open(\"./data/how_many_qa/qid2count2score.json\", \"rb\")), \n",
220 | " name=\"test\", \n",
221 | " dictionary=dictionary\n",
222 | ")\n",
223 | "del HMQAFeatureDataset"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 14,
229 | "metadata": {},
230 | "outputs": [
231 | {
232 | "data": {
233 | "text/plain": [
234 | "(17714, 5000)"
235 | ]
236 | },
237 | "execution_count": 14,
238 | "metadata": {},
239 | "output_type": "execute_result"
240 | }
241 | ],
242 | "source": [
243 | "len(hmqa_dev_dset), len(hmqa_test_dset)"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 15,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "from torch.utils.data import DataLoader\n",
253 | "\n",
254 | "hmqa_train_loader = DataLoader(hmqa_train_dset, 64, shuffle=True, num_workers=0)\n",
255 | "hmqa_dev_loader = DataLoader(hmqa_dev_dset, 64, shuffle=True, num_workers=0)\n",
256 | "hmqa_test_loader = DataLoader(hmqa_test_dset, 64, shuffle=True, num_workers=0)"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 55,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "def evaluate(model, hmqa_loader):\n",
266 | " \n",
267 | " all_acc = []\n",
268 | " all_se = []\n",
269 | " for i, (v_emb, b, q, c, c2s) in enumerate(hmqa_loader):\n",
270 | " v_emb = Variable(v_emb)\n",
271 | " q = Variable(q)\n",
272 | " c = Variable(c).float()\n",
273 | " \n",
274 | " if USE_CUDA:\n",
275 | " v_emb = v_emb.cuda()\n",
276 | " q = q.cuda()\n",
277 | " c = c.cuda()\n",
278 | "\n",
279 | " pred = model(v_emb, q)\n",
280 | " \n",
281 | " nearest_pred = (pred + 0.5).long().clamp(0, 20)\n",
282 | " for one_c, one_c2s, one_pred in zip(c, c2s, nearest_pred):\n",
283 | " one_c = one_c.cpu().data\n",
284 | " one_pred = one_pred.cpu().data\n",
285 | " \n",
286 | " all_se.append((one_c - one_pred.float()) ** 2)\n",
287 | " all_acc.append(one_c2s[one_pred])\n",
288 | " \n",
289 | " acc = torch.stack(all_acc).mean()\n",
290 | " rmse = torch.stack(all_se).mean() ** 0.5\n",
291 | " \n",
292 | " return acc, rmse"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 67,
298 | "metadata": {},
299 | "outputs": [
300 | {
301 | "name": "stdout",
302 | "output_type": "stream",
303 | "text": [
304 | "initialising with glove embeddings\n",
305 | "done.\n"
306 | ]
307 | }
308 | ],
309 | "source": [
310 | "from model import SoftCount\n",
311 | "model = SoftCount(ques_dim=1024, score_dim=512, dropout=0.2)\n",
312 | "del SoftCount"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 68,
318 | "metadata": {},
319 | "outputs": [
320 | {
321 | "data": {
322 | "text/plain": [
323 | "SoftCount(\n",
324 | " (ques_parser): QuestionParser(\n",
325 | " (embd): Embedding(20159, 300, padding_idx=20158)\n",
326 | " (rnn): GRU(300, 1024)\n",
327 | " (drop): Dropout(p=0.2)\n",
328 | " )\n",
329 | " (f): ScoringFunction(\n",
330 | " (v_drop): Dropout(p=0.2)\n",
331 | " (q_drop): Dropout(p=0.2)\n",
332 | " (v_proj): FCNet(\n",
333 | " (main): Sequential(\n",
334 | " (0): Linear(in_features=2048, out_features=512, bias=True)\n",
335 | " (1): LeakyReLU(negative_slope=0.01)\n",
336 | " )\n",
337 | " )\n",
338 | " (q_proj): FCNet(\n",
339 | " (main): Sequential(\n",
340 | " (0): Linear(in_features=1024, out_features=512, bias=True)\n",
341 | " (1): LeakyReLU(negative_slope=0.01)\n",
342 | " )\n",
343 | " )\n",
344 | " (s_drop): Dropout(p=0.2)\n",
345 | " )\n",
346 | " (W): Linear(in_features=512, out_features=1, bias=True)\n",
347 | ")"
348 | ]
349 | },
350 | "execution_count": 68,
351 | "metadata": {},
352 | "output_type": "execute_result"
353 | }
354 | ],
355 | "source": [
356 | "if USE_CUDA:\n",
357 | " model.cuda()\n",
358 | "model"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 69,
364 | "metadata": {},
365 | "outputs": [
366 | {
367 | "data": {
368 | "text/plain": [
369 | "(tensor(1.00000e-03 *\n",
370 | " 3.0400), tensor(15.4749))"
371 | ]
372 | },
373 | "execution_count": 69,
374 | "metadata": {},
375 | "output_type": "execute_result"
376 | }
377 | ],
378 | "source": [
379 | "test_acc, test_rmse = evaluate(model, hmqa_test_loader)\n",
380 | "test_acc, test_rmse"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": 70,
386 | "metadata": {},
387 | "outputs": [],
388 | "source": [
389 | "opt = torch.optim.Adam(model.parameters(), lr=3e-4)"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 71,
395 | "metadata": {},
396 | "outputs": [],
397 | "source": [
398 | "test_accs = []\n",
399 | "test_rmses = []\n",
400 | "\n",
401 | "dev_accs = []\n",
402 | "dev_rmses = []"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": 72,
408 | "metadata": {
409 | "scrolled": true
410 | },
411 | "outputs": [
412 | {
413 | "name": "stdout",
414 | "output_type": "stream",
415 | "text": [
416 | "epoch = 0, i = 0, loss = 14.628400802612305\n"
417 | ]
418 | },
419 | {
420 | "name": "stderr",
421 | "output_type": "stream",
422 | "text": [
423 | "/home/sanyam/miniconda3/envs/py3t4/lib/python3.6/site-packages/ipykernel_launcher.py:21: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.\n"
424 | ]
425 | },
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 | "epoch = 0, i = 100, loss = 1.1435465812683105\n",
431 | "epoch = 0, i = 200, loss = 1.0605370998382568\n",
432 | "epoch = 0, i = 300, loss = 0.9011174440383911\n",
433 | "epoch = 0, i = 400, loss = 1.6539443731307983\n",
434 | "epoch = 0, i = 500, loss = 1.0184824466705322\n",
435 | "epoch = 0, i = 600, loss = 1.2407264709472656\n",
436 | "epoch = 0, i = 700, loss = 1.3670611381530762\n",
437 | "epoch = 0, i = 800, loss = 1.5058083534240723\n",
438 | "epoch = 0, i = 900, loss = 0.9516408443450928\n",
439 | "epoch = 0, i = 1000, loss = 0.7743334174156189\n",
440 | "epoch = 0, i = 1100, loss = 0.8096193671226501\n",
441 | "epoch = 0, i = 1200, loss = 1.2751123905181885\n",
442 | "epoch = 0, i = 1300, loss = 1.3818869590759277\n",
443 | "evaluating model on dev and test...\n",
444 | "dev_acc: 0.41113242506980896, dev_rmse: 2.9276204109191895\n",
445 | "test_acc: 0.4032999873161316, test_rmse: 2.686856985092163\n",
446 | "epoch = 1, i = 0, loss = 1.0717486143112183\n",
447 | "epoch = 1, i = 100, loss = 0.7833133935928345\n",
448 | "epoch = 1, i = 200, loss = 0.8350750207901001\n",
449 | "epoch = 1, i = 300, loss = 0.9446524381637573\n",
450 | "epoch = 1, i = 400, loss = 0.6972931623458862\n",
451 | "epoch = 1, i = 500, loss = 0.9530639052391052\n",
452 | "epoch = 1, i = 600, loss = 0.7725903987884521\n",
453 | "epoch = 1, i = 700, loss = 0.7252683639526367\n",
454 | "epoch = 1, i = 800, loss = 0.7162841558456421\n",
455 | "epoch = 1, i = 900, loss = 0.8944467306137085\n",
456 | "epoch = 1, i = 1000, loss = 0.5890282392501831\n",
457 | "epoch = 1, i = 1100, loss = 0.7408154606819153\n",
458 | "epoch = 1, i = 1200, loss = 0.7860620617866516\n",
459 | "epoch = 1, i = 1300, loss = 0.9527196884155273\n",
460 | "evaluating model on dev and test...\n",
461 | "dev_acc: 0.438692569732666, dev_rmse: 2.9244272708892822\n",
462 | "test_acc: 0.43577998876571655, test_rmse: 2.673985719680786\n",
463 | "epoch = 2, i = 0, loss = 0.5547971725463867\n",
464 | "epoch = 2, i = 100, loss = 0.8121598958969116\n",
465 | "epoch = 2, i = 200, loss = 1.082768201828003\n",
466 | "epoch = 2, i = 300, loss = 0.9104152917861938\n",
467 | "epoch = 2, i = 400, loss = 0.6321248412132263\n",
468 | "epoch = 2, i = 500, loss = 0.538659930229187\n",
469 | "epoch = 2, i = 600, loss = 0.5090115666389465\n",
470 | "epoch = 2, i = 700, loss = 0.9042094349861145\n",
471 | "epoch = 2, i = 800, loss = 0.6963248252868652\n",
472 | "epoch = 2, i = 900, loss = 0.7098879218101501\n",
473 | "epoch = 2, i = 1000, loss = 0.6962026357650757\n",
474 | "epoch = 2, i = 1100, loss = 0.8791968822479248\n",
475 | "epoch = 2, i = 1200, loss = 1.2076174020767212\n",
476 | "epoch = 2, i = 1300, loss = 0.6140683889389038\n",
477 | "evaluating model on dev and test...\n",
478 | "dev_acc: 0.4398724138736725, dev_rmse: 2.7906925678253174\n",
479 | "test_acc: 0.43893998861312866, test_rmse: 2.562772035598755\n",
480 | "epoch = 3, i = 0, loss = 1.0250123739242554\n",
481 | "epoch = 3, i = 100, loss = 0.5901322364807129\n",
482 | "epoch = 3, i = 200, loss = 0.6190376877784729\n",
483 | "epoch = 3, i = 300, loss = 0.747307300567627\n",
484 | "epoch = 3, i = 400, loss = 0.712824821472168\n",
485 | "epoch = 3, i = 500, loss = 0.5977566242218018\n",
486 | "epoch = 3, i = 600, loss = 0.9443403482437134\n",
487 | "epoch = 3, i = 700, loss = 0.537765383720398\n",
488 | "epoch = 3, i = 800, loss = 0.5435270071029663\n",
489 | "epoch = 3, i = 900, loss = 0.4770953953266144\n",
490 | "epoch = 3, i = 1000, loss = 0.8796859979629517\n",
491 | "epoch = 3, i = 1100, loss = 0.4031006097793579\n",
492 | "epoch = 3, i = 1200, loss = 0.8212300539016724\n",
493 | "epoch = 3, i = 1300, loss = 0.9886336326599121\n",
494 | "evaluating model on dev and test...\n",
495 | "dev_acc: 0.45250648260116577, dev_rmse: 2.733506202697754\n",
496 | "test_acc: 0.46630001068115234, test_rmse: 2.499239921569824\n",
497 | "epoch = 4, i = 0, loss = 0.8000674843788147\n",
498 | "epoch = 4, i = 100, loss = 0.6346131563186646\n",
499 | "epoch = 4, i = 200, loss = 0.8446760177612305\n",
500 | "epoch = 4, i = 300, loss = 1.2564671039581299\n",
501 | "epoch = 4, i = 400, loss = 1.1416804790496826\n",
502 | "epoch = 4, i = 500, loss = 0.8681952953338623\n",
503 | "epoch = 4, i = 600, loss = 0.596240758895874\n",
504 | "epoch = 4, i = 700, loss = 0.760988175868988\n",
505 | "epoch = 4, i = 800, loss = 0.8439410328865051\n",
506 | "epoch = 4, i = 900, loss = 0.6394860148429871\n",
507 | "epoch = 4, i = 1000, loss = 0.589470624923706\n",
508 | "epoch = 4, i = 1100, loss = 0.8739631175994873\n",
509 | "epoch = 4, i = 1200, loss = 0.43961581587791443\n",
510 | "epoch = 4, i = 1300, loss = 0.7726854085922241\n",
511 | "evaluating model on dev and test...\n",
512 | "dev_acc: 0.4580501317977905, dev_rmse: 2.6624057292938232\n",
513 | "test_acc: 0.46700000762939453, test_rmse: 2.432323932647705\n",
514 | "epoch = 5, i = 0, loss = 0.5627894401550293\n",
515 | "epoch = 5, i = 100, loss = 0.6562784910202026\n",
516 | "epoch = 5, i = 200, loss = 0.9611048698425293\n",
517 | "epoch = 5, i = 300, loss = 0.40457683801651\n",
518 | "epoch = 5, i = 400, loss = 0.5332111716270447\n",
519 | "epoch = 5, i = 500, loss = 0.7762923240661621\n",
520 | "epoch = 5, i = 600, loss = 0.6530032157897949\n",
521 | "epoch = 5, i = 700, loss = 0.8856066465377808\n",
522 | "epoch = 5, i = 800, loss = 0.5990040302276611\n",
523 | "epoch = 5, i = 900, loss = 0.8177927136421204\n",
524 | "epoch = 5, i = 1000, loss = 0.5464267730712891\n",
525 | "epoch = 5, i = 1100, loss = 0.577069878578186\n",
526 | "epoch = 5, i = 1200, loss = 0.6961219906806946\n",
527 | "epoch = 5, i = 1300, loss = 0.6551117897033691\n",
528 | "evaluating model on dev and test...\n",
529 | "dev_acc: 0.46489217877388, dev_rmse: 2.658437728881836\n",
530 | "test_acc: 0.4720799922943115, test_rmse: 2.4418435096740723\n",
531 | "epoch = 6, i = 0, loss = 0.9389997720718384\n",
532 | "epoch = 6, i = 100, loss = 0.6106300354003906\n",
533 | "epoch = 6, i = 200, loss = 0.5450400114059448\n",
534 | "epoch = 6, i = 300, loss = 0.9185134172439575\n",
535 | "epoch = 6, i = 400, loss = 0.4569363594055176\n",
536 | "epoch = 6, i = 500, loss = 0.9947926998138428\n",
537 | "epoch = 6, i = 600, loss = 0.8373309373855591\n",
538 | "epoch = 6, i = 700, loss = 0.5167772769927979\n",
539 | "epoch = 6, i = 800, loss = 0.4667099416255951\n",
540 | "epoch = 6, i = 900, loss = 0.7207782864570618\n",
541 | "epoch = 6, i = 1000, loss = 0.7581618428230286\n",
542 | "epoch = 6, i = 1100, loss = 0.5784988403320312\n",
543 | "epoch = 6, i = 1200, loss = 0.5652433633804321\n",
544 | "epoch = 6, i = 1300, loss = 0.49687403440475464\n",
545 | "evaluating model on dev and test...\n",
546 | "dev_acc: 0.47183018922805786, dev_rmse: 2.684798002243042\n",
547 | "test_acc: 0.48076000809669495, test_rmse: 2.4388933181762695\n",
548 | "epoch = 7, i = 0, loss = 0.6185715794563293\n",
549 | "epoch = 7, i = 100, loss = 0.4020466208457947\n",
550 | "epoch = 7, i = 200, loss = 0.5055506825447083\n",
551 | "epoch = 7, i = 300, loss = 0.9028556942939758\n",
552 | "epoch = 7, i = 400, loss = 0.6325691938400269\n",
553 | "epoch = 7, i = 500, loss = 0.5334853529930115\n",
554 | "epoch = 7, i = 600, loss = 0.7755710482597351\n",
555 | "epoch = 7, i = 700, loss = 0.6400124430656433\n",
556 | "epoch = 7, i = 800, loss = 0.8320757746696472\n",
557 | "epoch = 7, i = 900, loss = 0.6655226945877075\n",
558 | "epoch = 7, i = 1000, loss = 0.6563359498977661\n",
559 | "epoch = 7, i = 1100, loss = 0.7202091217041016\n",
560 | "epoch = 7, i = 1200, loss = 0.47889444231987\n",
561 | "epoch = 7, i = 1300, loss = 0.5980694890022278\n",
562 | "evaluating model on dev and test...\n",
563 | "dev_acc: 0.452574223279953, dev_rmse: 2.565908908843994\n",
564 | "test_acc: 0.45767998695373535, test_rmse: 2.3844916820526123\n",
565 | "epoch = 8, i = 0, loss = 0.4203689396381378\n",
566 | "epoch = 8, i = 100, loss = 0.6011685132980347\n",
567 | "epoch = 8, i = 200, loss = 0.4834330081939697\n",
568 | "epoch = 8, i = 300, loss = 0.5145273208618164\n",
569 | "epoch = 8, i = 400, loss = 0.5205063223838806\n",
570 | "epoch = 8, i = 500, loss = 0.46570590138435364\n",
571 | "epoch = 8, i = 600, loss = 0.5869192481040955\n",
572 | "epoch = 8, i = 700, loss = 0.9000004529953003\n",
573 | "epoch = 8, i = 800, loss = 1.029176950454712\n",
574 | "epoch = 8, i = 900, loss = 0.6565335988998413\n",
575 | "epoch = 8, i = 1000, loss = 0.48471736907958984\n",
576 | "epoch = 8, i = 1100, loss = 0.5047010183334351\n",
577 | "epoch = 8, i = 1200, loss = 0.9043950438499451\n",
578 | "epoch = 8, i = 1300, loss = 0.6348180174827576\n",
579 | "evaluating model on dev and test...\n",
580 | "dev_acc: 0.47310036420822144, dev_rmse: 2.5988094806671143\n",
581 | "test_acc: 0.4937399923801422, test_rmse: 2.366136074066162\n",
582 | "epoch = 9, i = 0, loss = 0.3131829798221588\n",
583 | "epoch = 9, i = 100, loss = 0.46003639698028564\n",
584 | "epoch = 9, i = 200, loss = 0.5366580486297607\n",
585 | "epoch = 9, i = 300, loss = 0.45132943987846375\n",
586 | "epoch = 9, i = 400, loss = 0.5376684665679932\n",
587 | "epoch = 9, i = 500, loss = 0.9212081432342529\n",
588 | "epoch = 9, i = 600, loss = 0.4985387921333313\n",
589 | "epoch = 9, i = 700, loss = 0.2654397785663605\n",
590 | "epoch = 9, i = 800, loss = 1.150120496749878\n",
591 | "epoch = 9, i = 900, loss = 0.44342440366744995\n",
592 | "epoch = 9, i = 1000, loss = 0.6016779541969299\n",
593 | "epoch = 9, i = 1100, loss = 0.7659499645233154\n",
594 | "epoch = 9, i = 1200, loss = 0.5102328062057495\n",
595 | "epoch = 9, i = 1300, loss = 0.6641837954521179\n",
596 | "evaluating model on dev and test...\n",
597 | "dev_acc: 0.47911256551742554, dev_rmse: 2.5959513187408447\n",
598 | "test_acc: 0.490339994430542, test_rmse: 2.3581349849700928\n",
599 | "epoch = 10, i = 0, loss = 0.9485877156257629\n",
600 | "epoch = 10, i = 100, loss = 0.4240656793117523\n",
601 | "epoch = 10, i = 200, loss = 0.7539646625518799\n",
602 | "epoch = 10, i = 300, loss = 0.43598222732543945\n",
603 | "epoch = 10, i = 400, loss = 0.5969750285148621\n",
604 | "epoch = 10, i = 500, loss = 0.6358165144920349\n"
605 | ]
606 | },
607 | {
608 | "name": "stdout",
609 | "output_type": "stream",
610 | "text": [
611 | "epoch = 10, i = 600, loss = 0.4486617147922516\n",
612 | "epoch = 10, i = 700, loss = 0.4893093705177307\n",
613 | "epoch = 10, i = 800, loss = 0.8087934851646423\n",
614 | "epoch = 10, i = 900, loss = 0.6663355827331543\n",
615 | "epoch = 10, i = 1000, loss = 0.38856399059295654\n",
616 | "epoch = 10, i = 1100, loss = 0.6736136674880981\n",
617 | "epoch = 10, i = 1200, loss = 0.8147845268249512\n",
618 | "epoch = 10, i = 1300, loss = 0.7092846632003784\n",
619 | "evaluating model on dev and test...\n",
620 | "dev_acc: 0.473275363445282, dev_rmse: 2.5685806274414062\n",
621 | "test_acc: 0.4786800146102905, test_rmse: 2.3469128608703613\n",
622 | "epoch = 11, i = 0, loss = 0.5136302709579468\n",
623 | "epoch = 11, i = 100, loss = 0.6258996725082397\n",
624 | "epoch = 11, i = 200, loss = 0.5074462294578552\n",
625 | "epoch = 11, i = 300, loss = 0.8129042983055115\n",
626 | "epoch = 11, i = 400, loss = 0.5542230606079102\n",
627 | "epoch = 11, i = 500, loss = 0.5161349773406982\n",
628 | "epoch = 11, i = 600, loss = 0.5896819233894348\n",
629 | "epoch = 11, i = 700, loss = 0.6863290667533875\n",
630 | "epoch = 11, i = 800, loss = 0.5083870887756348\n",
631 | "epoch = 11, i = 900, loss = 0.3265340328216553\n",
632 | "epoch = 11, i = 1000, loss = 0.7775629162788391\n",
633 | "epoch = 11, i = 1100, loss = 0.7597150802612305\n",
634 | "epoch = 11, i = 1200, loss = 0.49520641565322876\n",
635 | "epoch = 11, i = 1300, loss = 0.44222038984298706\n",
636 | "evaluating model on dev and test...\n",
637 | "dev_acc: 0.4771198034286499, dev_rmse: 2.5524861812591553\n",
638 | "test_acc: 0.47642001509666443, test_rmse: 2.3526580333709717\n",
639 | "epoch = 12, i = 0, loss = 0.3910914957523346\n",
640 | "epoch = 12, i = 100, loss = 0.33726972341537476\n",
641 | "epoch = 12, i = 200, loss = 0.5295529365539551\n",
642 | "epoch = 12, i = 300, loss = 0.4419896602630615\n",
643 | "epoch = 12, i = 400, loss = 0.3294357657432556\n",
644 | "epoch = 12, i = 500, loss = 0.8175026178359985\n",
645 | "epoch = 12, i = 600, loss = 0.5169392228126526\n",
646 | "epoch = 12, i = 700, loss = 0.37090879678726196\n",
647 | "epoch = 12, i = 800, loss = 0.4401392340660095\n",
648 | "epoch = 12, i = 900, loss = 0.657434344291687\n",
649 | "epoch = 12, i = 1000, loss = 0.47396910190582275\n",
650 | "epoch = 12, i = 1100, loss = 0.6415238380432129\n",
651 | "epoch = 12, i = 1200, loss = 0.49534136056900024\n",
652 | "epoch = 12, i = 1300, loss = 0.5604371428489685\n",
653 | "evaluating model on dev and test...\n",
654 | "dev_acc: 0.482234388589859, dev_rmse: 2.542769432067871\n",
655 | "test_acc: 0.4944800138473511, test_rmse: 2.3373489379882812\n",
656 | "epoch = 13, i = 0, loss = 0.6540343761444092\n",
657 | "epoch = 13, i = 100, loss = 0.5504652261734009\n",
658 | "epoch = 13, i = 200, loss = 0.613745391368866\n",
659 | "epoch = 13, i = 300, loss = 0.48012328147888184\n",
660 | "epoch = 13, i = 400, loss = 0.5549217462539673\n",
661 | "epoch = 13, i = 500, loss = 0.7506428956985474\n",
662 | "epoch = 13, i = 600, loss = 0.6042653322219849\n",
663 | "epoch = 13, i = 700, loss = 0.6180323362350464\n",
664 | "epoch = 13, i = 800, loss = 0.4621243178844452\n",
665 | "epoch = 13, i = 900, loss = 0.29211950302124023\n",
666 | "epoch = 13, i = 1000, loss = 0.6967971324920654\n",
667 | "epoch = 13, i = 1100, loss = 0.5880112648010254\n",
668 | "epoch = 13, i = 1200, loss = 0.7064930200576782\n",
669 | "epoch = 13, i = 1300, loss = 0.3936261534690857\n",
670 | "evaluating model on dev and test...\n",
671 | "dev_acc: 0.4788077175617218, dev_rmse: 2.534341335296631\n",
672 | "test_acc: 0.48833999037742615, test_rmse: 2.3233165740966797\n",
673 | "epoch = 14, i = 0, loss = 0.6755526065826416\n",
674 | "epoch = 14, i = 100, loss = 0.33443135023117065\n",
675 | "epoch = 14, i = 200, loss = 0.36527568101882935\n",
676 | "epoch = 14, i = 300, loss = 1.0132145881652832\n",
677 | "epoch = 14, i = 400, loss = 0.5340710878372192\n",
678 | "epoch = 14, i = 500, loss = 0.7321001887321472\n",
679 | "epoch = 14, i = 600, loss = 0.8975104093551636\n",
680 | "epoch = 14, i = 700, loss = 0.5202504396438599\n",
681 | "epoch = 14, i = 800, loss = 1.0374795198440552\n",
682 | "epoch = 14, i = 900, loss = 0.8503618836402893\n",
683 | "epoch = 14, i = 1000, loss = 0.26997876167297363\n",
684 | "epoch = 14, i = 1100, loss = 0.42291390895843506\n",
685 | "epoch = 14, i = 1200, loss = 0.549490213394165\n",
686 | "epoch = 14, i = 1300, loss = 0.4087107479572296\n",
687 | "evaluating model on dev and test...\n",
688 | "dev_acc: 0.47907304763793945, dev_rmse: 2.5677783489227295\n",
689 | "test_acc: 0.5010799765586853, test_rmse: 2.3345234394073486\n",
690 | "epoch = 15, i = 0, loss = 0.3381015658378601\n",
691 | "epoch = 15, i = 100, loss = 0.47336000204086304\n",
692 | "epoch = 15, i = 200, loss = 0.5756043195724487\n",
693 | "epoch = 15, i = 300, loss = 0.3944014012813568\n",
694 | "epoch = 15, i = 400, loss = 0.6456435918807983\n",
695 | "epoch = 15, i = 500, loss = 0.6376544833183289\n",
696 | "epoch = 15, i = 600, loss = 0.38120079040527344\n",
697 | "epoch = 15, i = 700, loss = 0.39884161949157715\n",
698 | "epoch = 15, i = 800, loss = 0.3604346215724945\n",
699 | "epoch = 15, i = 900, loss = 0.355624258518219\n",
700 | "epoch = 15, i = 1000, loss = 0.4303146004676819\n",
701 | "epoch = 15, i = 1100, loss = 0.49465665221214294\n",
702 | "epoch = 15, i = 1200, loss = 0.610739529132843\n",
703 | "epoch = 15, i = 1300, loss = 0.4250361919403076\n",
704 | "evaluating model on dev and test...\n",
705 | "dev_acc: 0.4612114727497101, dev_rmse: 2.5253374576568604\n",
706 | "test_acc: 0.4729999899864197, test_rmse: 2.3174986839294434\n",
707 | "epoch = 16, i = 0, loss = 0.28534242510795593\n",
708 | "epoch = 16, i = 100, loss = 0.6168354153633118\n",
709 | "epoch = 16, i = 200, loss = 0.43274885416030884\n",
710 | "epoch = 16, i = 300, loss = 0.38376590609550476\n",
711 | "epoch = 16, i = 400, loss = 0.527509868144989\n",
712 | "epoch = 16, i = 500, loss = 0.45143264532089233\n",
713 | "epoch = 16, i = 600, loss = 0.3917695879936218\n",
714 | "epoch = 16, i = 700, loss = 0.5136957168579102\n",
715 | "epoch = 16, i = 800, loss = 0.3321128487586975\n",
716 | "epoch = 16, i = 900, loss = 0.35770952701568604\n",
717 | "epoch = 16, i = 1000, loss = 0.37515610456466675\n",
718 | "epoch = 16, i = 1100, loss = 0.36942195892333984\n",
719 | "epoch = 16, i = 1200, loss = 0.41256609559059143\n",
720 | "epoch = 16, i = 1300, loss = 0.7216428518295288\n",
721 | "evaluating model on dev and test...\n",
722 | "dev_acc: 0.48063114285469055, dev_rmse: 2.539537191390991\n",
723 | "test_acc: 0.48642000555992126, test_rmse: 2.3322949409484863\n",
724 | "epoch = 17, i = 0, loss = 0.1989438235759735\n",
725 | "epoch = 17, i = 100, loss = 0.37077924609184265\n",
726 | "epoch = 17, i = 200, loss = 0.548552393913269\n",
727 | "epoch = 17, i = 300, loss = 0.32453879714012146\n",
728 | "epoch = 17, i = 400, loss = 0.37269145250320435\n",
729 | "epoch = 17, i = 500, loss = 0.49430444836616516\n",
730 | "epoch = 17, i = 600, loss = 0.42291736602783203\n",
731 | "epoch = 17, i = 700, loss = 0.4055926203727722\n",
732 | "epoch = 17, i = 800, loss = 0.6607651710510254\n",
733 | "epoch = 17, i = 900, loss = 0.3112538456916809\n",
734 | "epoch = 17, i = 1000, loss = 0.3921870291233063\n",
735 | "epoch = 17, i = 1100, loss = 0.4733957052230835\n",
736 | "epoch = 17, i = 1200, loss = 0.3893886208534241\n",
737 | "epoch = 17, i = 1300, loss = 0.5213131308555603\n",
738 | "evaluating model on dev and test...\n",
739 | "dev_acc: 0.4829513430595398, dev_rmse: 2.543346643447876\n",
740 | "test_acc: 0.5041000247001648, test_rmse: 2.3136982917785645\n",
741 | "epoch = 18, i = 0, loss = 0.36615651845932007\n",
742 | "epoch = 18, i = 100, loss = 0.4917939305305481\n",
743 | "epoch = 18, i = 200, loss = 0.30025339126586914\n",
744 | "epoch = 18, i = 300, loss = 0.5090234279632568\n",
745 | "epoch = 18, i = 400, loss = 0.3057374060153961\n",
746 | "epoch = 18, i = 500, loss = 0.467905193567276\n",
747 | "epoch = 18, i = 600, loss = 0.31017005443573\n",
748 | "epoch = 18, i = 700, loss = 0.35623863339424133\n",
749 | "epoch = 18, i = 800, loss = 0.5136244297027588\n",
750 | "epoch = 18, i = 900, loss = 0.2820129990577698\n",
751 | "epoch = 18, i = 1000, loss = 0.5056727528572083\n",
752 | "epoch = 18, i = 1100, loss = 0.38559073209762573\n",
753 | "epoch = 18, i = 1200, loss = 0.5067494511604309\n",
754 | "epoch = 18, i = 1300, loss = 0.23265519738197327\n",
755 | "evaluating model on dev and test...\n",
756 | "dev_acc: 0.4571525454521179, dev_rmse: 2.53709077835083\n",
757 | "test_acc: 0.45956000685691833, test_rmse: 2.32684326171875\n",
758 | "epoch = 19, i = 0, loss = 0.4580419361591339\n",
759 | "epoch = 19, i = 100, loss = 0.6830368041992188\n",
760 | "epoch = 19, i = 200, loss = 0.6702297925949097\n",
761 | "epoch = 19, i = 300, loss = 0.3448217809200287\n",
762 | "epoch = 19, i = 400, loss = 0.5003037452697754\n",
763 | "epoch = 19, i = 500, loss = 0.5249284505844116\n",
764 | "epoch = 19, i = 600, loss = 0.5032473206520081\n",
765 | "epoch = 19, i = 700, loss = 0.46184176206588745\n",
766 | "epoch = 19, i = 800, loss = 0.5465124845504761\n",
767 | "epoch = 19, i = 900, loss = 0.2270112931728363\n",
768 | "epoch = 19, i = 1000, loss = 0.2748583257198334\n",
769 | "epoch = 19, i = 1100, loss = 0.5165499448776245\n",
770 | "epoch = 19, i = 1200, loss = 0.43792837858200073\n",
771 | "epoch = 19, i = 1300, loss = 0.5342369079589844\n",
772 | "evaluating model on dev and test...\n",
773 | "dev_acc: 0.4835158586502075, dev_rmse: 2.512709140777588\n",
774 | "test_acc: 0.49654000997543335, test_rmse: 2.3054285049438477\n"
775 | ]
776 | }
777 | ],
778 | "source": [
779 | "for epoch in range(20):\n",
780 | " for i, (v_emb, b, q, c, _) in enumerate(hmqa_train_loader):\n",
781 | " v_emb = Variable(v_emb)\n",
782 | " q = Variable(q)\n",
783 | " c = Variable(c).float().view(-1)\n",
784 | " \n",
785 | " if USE_CUDA:\n",
786 | " v_emb = v_emb.cuda()\n",
787 | " q = q.cuda()\n",
788 | " c = c.cuda()\n",
789 | "\n",
790 | " pred = model(v_emb, q)\n",
791 | " loss = F.smooth_l1_loss(pred, c)\n",
792 | " \n",
793 | " if i % 100 == 0:\n",
794 | " print(\"epoch = {}, i = {}, loss = {}\".format(\n",
795 | " epoch, i, loss.item()))\n",
796 | " \n",
797 | " opt.zero_grad()\n",
798 | " loss.backward()\n",
799 | " torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)\n",
800 | " opt.step()\n",
801 | " \n",
802 | " print(\"evaluating model on dev and test...\")\n",
803 | "\n",
804 | " model.eval()\n",
805 | " dev_acc, dev_rmse = evaluate(model, hmqa_dev_loader)\n",
806 | " print(\"dev_acc: {}, dev_rmse: {}\".format(dev_acc, dev_rmse))\n",
807 | " test_acc, test_rmse = evaluate(model, hmqa_test_loader)\n",
808 | " print(\"test_acc: {}, test_rmse: {}\".format(test_acc, test_rmse))\n",
809 | " model.train()\n",
810 | " \n",
811 | " test_accs.append(test_acc)\n",
812 | " test_rmses.append(test_rmse)\n",
813 | " dev_accs.append(dev_acc)\n",
814 | " dev_rmses.append(dev_rmse)\n",
815 | " "
816 | ]
817 | },
818 | {
819 | "cell_type": "code",
820 | "execution_count": 73,
821 | "metadata": {
822 | "scrolled": true
823 | },
824 | "outputs": [
825 | {
826 | "data": {
827 | "text/plain": [
828 | "[(tensor(0.4835), tensor(0.4965), tensor(2.3054)),\n",
829 | " (tensor(0.4830), tensor(0.5041), tensor(2.3137)),\n",
830 | " (tensor(0.4822), tensor(0.4945), tensor(2.3373)),\n",
831 | " (tensor(0.4806), tensor(0.4864), tensor(2.3323)),\n",
832 | " (tensor(0.4791), tensor(0.4903), tensor(2.3581)),\n",
833 | " (tensor(0.4791), tensor(0.5011), tensor(2.3345)),\n",
834 | " (tensor(0.4788), tensor(0.4883), tensor(2.3233)),\n",
835 | " (tensor(0.4771), tensor(0.4764), tensor(2.3527)),\n",
836 | " (tensor(0.4733), tensor(0.4787), tensor(2.3469)),\n",
837 | " (tensor(0.4731), tensor(0.4937), tensor(2.3661)),\n",
838 | " (tensor(0.4718), tensor(0.4808), tensor(2.4389)),\n",
839 | " (tensor(0.4649), tensor(0.4721), tensor(2.4418)),\n",
840 | " (tensor(0.4612), tensor(0.4730), tensor(2.3175)),\n",
841 | " (tensor(0.4581), tensor(0.4670), tensor(2.4323)),\n",
842 | " (tensor(0.4572), tensor(0.4596), tensor(2.3268)),\n",
843 | " (tensor(0.4526), tensor(0.4577), tensor(2.3845)),\n",
844 | " (tensor(0.4525), tensor(0.4663), tensor(2.4992)),\n",
845 | " (tensor(0.4399), tensor(0.4389), tensor(2.5628)),\n",
846 | " (tensor(0.4387), tensor(0.4358), tensor(2.6740)),\n",
847 | " (tensor(0.4111), tensor(0.4033), tensor(2.6869))]"
848 | ]
849 | },
850 | "execution_count": 73,
851 | "metadata": {},
852 | "output_type": "execute_result"
853 | }
854 | ],
855 | "source": [
856 | "top_dev_accs = sorted(zip(dev_accs, test_accs, test_rmses), reverse=True)\n",
857 | "top_dev_accs"
858 | ]
859 | },
860 | {
861 | "cell_type": "code",
862 | "execution_count": 74,
863 | "metadata": {},
864 | "outputs": [
865 | {
866 | "name": "stdout",
867 | "output_type": "stream",
868 | "text": [
869 | "The best dev accuracy is 0.4835158586502075. The corresponding test accuracy and test RMSE are 0.49654000997543335 and 2.3054285049438477 respectively\n"
870 | ]
871 | }
872 | ],
873 | "source": [
874 | "best_dev_acc, corr_test_acc, corr_test_rmse = top_dev_accs[0]\n",
875 | "print(\"The best dev accuracy is {}. The corresponding test accuracy and test RMSE are {} and {} respectively\".format(\n",
876 | " best_dev_acc, corr_test_acc, corr_test_rmse\n",
877 | "))"
878 | ]
879 | },
880 | {
881 | "cell_type": "code",
882 | "execution_count": null,
883 | "metadata": {},
884 | "outputs": [],
885 | "source": []
886 | }
887 | ],
888 | "metadata": {
889 | "kernelspec": {
890 | "display_name": "Python 3",
891 | "language": "python",
892 | "name": "python3"
893 | },
894 | "language_info": {
895 | "codemirror_mode": {
896 | "name": "ipython",
897 | "version": 3
898 | },
899 | "file_extension": ".py",
900 | "mimetype": "text/x-python",
901 | "name": "python",
902 | "nbconvert_exporter": "python",
903 | "pygments_lexer": "ipython3",
904 | "version": "3.6.5"
905 | }
906 | },
907 | "nbformat": 4,
908 | "nbformat_minor": 2
909 | }
910 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | USE_CUDA = True
2 | VOCAB_SIZE = 20158
3 | DATA_DIR = "./data"
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import json
3 | try:
4 | import _pickle as pkl
5 | except:
6 | import cPickle as pkl
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class Dictionary(object):
13 | def __init__(self, word2idx=None, idx2word=None):
14 | if word2idx is None:
15 | word2idx = {}
16 | if idx2word is None:
17 | idx2word = []
18 | self.word2idx = word2idx
19 | self.idx2word = idx2word
20 |
21 | @property
22 | def ntoken(self):
23 | return len(self.word2idx)
24 |
25 | @property
26 | def padding_idx(self):
27 | return len(self.word2idx)
28 |
29 | def tokenize(self, sentence, add_word):
30 | sentence = sentence.lower()
31 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s')
32 | words = sentence.split()
33 | tokens = []
34 | if add_word:
35 | for w in words:
36 | tokens.append(self.add_word(w))
37 | else:
38 | for w in words:
39 | tokens.append(self.word2idx[w])
40 | return tokens
41 |
42 | def dump_to_file(self, path):
43 | pkl.dump([self.word2idx, self.idx2word], open(path, 'wb'))
44 | print('dictionary dumped to %s' % path)
45 |
46 | @classmethod
47 | def load_from_file(cls, path):
48 | print('loading dictionary from %s' % path)
49 | word2idx, idx2word = pkl.load(open(path, 'rb'))
50 | d = cls(word2idx, idx2word)
51 | return d
52 |
53 | def add_word(self, word):
54 | if word not in self.word2idx:
55 | self.idx2word.append(word)
56 | self.word2idx[word] = len(self.idx2word) - 1
57 | return self.word2idx[word]
58 |
59 | def __len__(self):
60 | return len(self.idx2word)
61 |
62 |
63 | class HMQAFeatureDataset(Dataset):
64 | def __init__(self, img_id2hqma_idx, image_features, spatial_features, qid2count, qid2count2score,
65 | name, dictionary):
66 | super(HMQAFeatureDataset, self).__init__()
67 |
68 | assert name in ["train", "dev", "test"]
69 | self.name = name
70 | self.qid2count = qid2count
71 | self.qid2count2score = qid2count2score
72 | self.qids = None
73 | if self.name == "train":
74 | self.part_qid2count = self.qid2count["train"]
75 | self.part_qid2count2score = self.qid2count2score["train"]
76 | else:
77 | self.part_qid2count = self.qid2count[self.name]
78 | self.part_qid2count2score = self.qid2count2score[self.name]
79 |
80 | self.dictionary = dictionary
81 |
82 | self.img_id2hqma_idx = img_id2hqma_idx
83 | self._features = image_features
84 | self._spatials = spatial_features
85 |
86 | self.entries = self.load_dataset()
87 |
88 | self.tokenize()
89 | self.tensorize()
90 | self.v_dim = self.features.size(2)
91 | self.s_dim = self.spatials.size(2)
92 |
93 | def prepare_entries(self, questions, part_qid2count, part_qid2count2score, qid_prefix=''):
94 | entries = []
95 |
96 | set_qids = set(part_qid2count.keys())
97 | for question in questions:
98 | question_id = str(question["question_id"])
99 | if question_id not in set_qids:
100 | # print("{} is not there".format(question_id))
101 | # break
102 | continue
103 |
104 | image_id = question['image_id']
105 | img_hqma_idx = self.img_id2hqma_idx[image_id]
106 | count = part_qid2count[question_id]
107 |             score2count = part_qid2count2score[question_id]  # maps each count value to its VQA soft score
108 | question = question["question"]
109 |
110 | entries.append({
111 | "question_id": qid_prefix + question_id,
112 | "image_id": image_id,
113 | "img_hqma_idx": img_hqma_idx,
114 | "count": count,
115 | "score2count": score2count,
116 | "question": question,
117 | })
118 |
119 | return entries
120 |
121 | def load_dataset(self):
122 |         """Build the list of question entries for this split.
123 | 
124 |         For "train", entries come from the VQA v2 train questions plus the
125 |         Visual Genome questions; for "dev" and "test", entries come from the
126 |         VQA v2 val questions restricted to the corresponding how-many-QA split.
127 |         """
128 |
129 | if self.name == "train":
130 | fname = "train"
131 | vqa_question_path = './data/v2_OpenEnded_mscoco_{}2014_questions.json'.format(fname)
132 | vqa_questions = sorted(json.load(open(vqa_question_path))['questions'], key=lambda x: x['question_id'])
133 |
134 | vqa_entries = self.prepare_entries(
135 | questions=vqa_questions,
136 | part_qid2count=self.part_qid2count["vqa"],
137 | part_qid2count2score=self.part_qid2count2score["vqa"],
138 | qid_prefix="vqa"
139 | )
140 |
141 | vgn_questions = sorted(json.load(open("./data/how_many_qa/vgn_ques.json")), key=lambda x: x['question_id'])
142 | vgn_entries = self.prepare_entries(
143 | questions=vgn_questions,
144 | part_qid2count=self.part_qid2count["visual_genome"],
145 | part_qid2count2score=self.part_qid2count2score["visual_genome"],
146 | qid_prefix="vgn"
147 | )
148 |
149 | return vqa_entries + vgn_entries
150 |
151 | elif self.name in ["dev", "test"]:
152 | fname = "val"
153 | question_path = './data/v2_OpenEnded_mscoco_{}2014_questions.json'.format(fname)
154 | questions = sorted(json.load(open(question_path))['questions'], key=lambda x: x['question_id'])
155 |
156 | val_entries = self.prepare_entries(
157 | questions=questions,
158 | part_qid2count=self.part_qid2count,
159 | part_qid2count2score=self.part_qid2count2score,
160 | )
161 | return val_entries
162 | else:
163 |             raise Exception("unknown name '{}'".format(self.name))
164 |
165 | def tokenize(self, max_length=14):
166 | """Tokenizes the questions.
167 |
168 | This will add q_token in each entry of the dataset.
169 | -1 represent nil, and should be treated as padding_idx in embedding
170 | """
171 | for entry in self.entries:
172 | tokens = self.dictionary.tokenize(entry['question'], False)
173 | tokens = tokens[:max_length]
174 | if len(tokens) < max_length:
175 | # Note here we pad in front of the sentence
176 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens))
177 | tokens = padding + tokens
178 | assert len(tokens) == max_length
179 | entry['q_token'] = tokens
180 |
181 | def tensorize(self):
182 | # TODO: uncomment later
183 | self.features = torch.from_numpy(self._features)
184 | self.spatials = torch.from_numpy(self._spatials)
185 |
186 | for entry in self.entries:
187 | question = torch.from_numpy(np.array(entry['q_token']))
188 | entry['q_token'] = question
189 | entry["count"] = torch.from_numpy(np.array([entry["count"]]))
190 | entry["score2count"] = torch.from_numpy(np.array(entry["score2count"])).float()
191 |
192 | def __getitem__(self, index):
193 | entry = self.entries[index]
194 | features = self.features[entry['img_hqma_idx']]
195 | spatials = self.spatials[entry['img_hqma_idx']]
196 |
197 | question = entry['q_token']
198 | count = entry["count"]
199 | score2count = entry["score2count"]
200 |
201 | return features, spatials, question, count, score2count
202 |
203 | def __len__(self):
204 | return len(self.entries)
205 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from config import *
6 | from torch.nn.utils.weight_norm import weight_norm
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | from torch.distributions.categorical import Categorical
10 |
11 |
12 | class QuestionParser(nn.Module):
13 | glove_file = DATA_DIR + "/glove6b_init_300d.npy"
14 |
15 | def __init__(self, dropout=0.3, word_dim=300, ques_dim=1024):
16 | super(QuestionParser, self).__init__()
17 |
18 | self.dropout = dropout
19 | self.word_dim = word_dim
20 | self.ques_dim = ques_dim
21 |
22 | self.embd = nn.Embedding(VOCAB_SIZE + 1, self.word_dim, padding_idx=VOCAB_SIZE)
23 | self.rnn = nn.GRU(self.word_dim, self.ques_dim)
24 | self.drop = nn.Dropout(self.dropout)
25 | self.glove_init()
26 |
27 | def glove_init(self):
28 | print("initialising with glove embeddings")
29 | glove_embds = torch.from_numpy(np.load(self.glove_file))
30 | assert glove_embds.size() == (VOCAB_SIZE, self.word_dim)
31 | self.embd.weight.data[:VOCAB_SIZE] = glove_embds
32 | print("done.")
33 |
34 | def forward(self, questions):
35 | # (B, MAXLEN)
36 | # print("question size ", questions.size())
37 | questions = questions.t() # (MAXLEN, B)
38 | questions = self.embd(questions) # (MAXLEN, B, word_size)
39 | _, (q_emb) = self.rnn(questions)
40 | q_emb = q_emb[-1] # (B, ques_size)
41 | q_emb = self.drop(q_emb)
42 |
43 | return q_emb
44 |
45 |
46 | class FCNet(nn.Module):
47 | """Simple class for non-linear fully connect network
48 | """
49 | def __init__(self, dims):
50 | super(FCNet, self).__init__()
51 |
52 | layers = []
53 | for i in range(len(dims)-2):
54 | in_dim = dims[i]
55 | out_dim = dims[i+1]
56 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
57 | layers.append(nn.LeakyReLU())
58 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
59 | layers.append(nn.LeakyReLU())
60 |
61 | self.main = nn.Sequential(*layers)
62 |
63 | def forward(self, x):
64 | return self.main(x)
65 |
66 |
67 | class ScoringFunction(nn.Module):
68 |
69 | def __init__(self, ques_dim, dropout=0.3, v_dim=2048, score_dim=1024):
70 | super(ScoringFunction, self).__init__()
71 |
72 | self.q_dim = ques_dim
73 | self.dropout = dropout
74 | self.v_dim = v_dim
75 | self.score_dim = score_dim
76 |
77 | self.v_drop = nn.Dropout(self.dropout)
78 | self.q_drop = nn.Dropout(self.dropout)
79 | self.v_proj = FCNet([self.v_dim, self.score_dim])
80 | self.q_proj = FCNet([self.q_dim, self.score_dim])
81 | self.s_drop = nn.Dropout(self.dropout)
82 |
83 | def forward(self, v, q):
84 | """
85 | v: [batch, k, vdim]
86 | q: [batch, qdim]
87 | """
88 | batch, k, _ = v.size()
89 | v = self.v_drop(v)
90 | q = self.q_drop(q)
91 | v_proj = self.v_proj(v) # [batch, k, qdim]
92 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) # [batch, k, qdim]
93 | s = v_proj * q_proj
94 | s = self.s_drop(s)
95 | return s # (B, k, score_dim)
96 |
97 |
98 | class GTUScoringFunction(nn.Module):
99 |
100 | def __init__(self, ques_dim, dropout=0.3, v_dim=2048, score_dim=2048):
101 | super(GTUScoringFunction, self).__init__()
102 |
103 | self.q_dim = ques_dim
104 | self.dropout = dropout
105 | self.v_dim = v_dim
106 | self.score_dim = score_dim
107 |
108 | self.predrop = nn.Dropout(self.dropout)
109 | self.dense1 = weight_norm(nn.Linear(self.v_dim + self.q_dim, self.score_dim), dim=None)
110 | self.dense2 = weight_norm(nn.Linear(self.v_dim + self.q_dim, self.score_dim), dim=None)
111 |
112 | self.s_drop = nn.Dropout(self.dropout)
113 |
114 | def forward(self, v, q):
115 | """
116 | v: [batch, k, vdim]
117 | q: [batch, qdim]
118 | """
119 | batch, k, _ = v.size()
120 |
121 | q = q[:, None, :].repeat(1, k, 1) # (B, k, q_dim)
122 | vq = torch.cat([v, q], dim=2) # (B, k, v_dim + q_dim)
123 |
124 | vq = self.predrop(vq)
125 |
126 | y = F.tanh(self.dense1(vq)) # (B, k, score_dim)
127 | g = F.sigmoid(self.dense2(vq)) # (B, k, score_dim)
128 |
129 | s = y * g
130 | s = self.s_drop(s)
131 | return s # (B, k, score_dim)
132 |
133 |
134 | class SoftCount(nn.Module):
135 |
136 | def __init__(self, ques_dim=1024, score_dim=512, dropout=0.1):
137 | super(SoftCount, self).__init__()
138 | self.ques_parser = QuestionParser(ques_dim=ques_dim, dropout=dropout)
139 | self.f = ScoringFunction(ques_dim=ques_dim, score_dim=score_dim, dropout=dropout)
140 | self.W = weight_norm(nn.Linear(score_dim, 1), dim=None)
141 |
142 | def forward(self, v_emb, q):
143 | # v_emb = (B, k, v_dim)
144 | # q = (B, MAXLEN)
145 |
146 | q_emb = self.ques_parser(q) # (B, q_dim)
147 | s = self.f(v_emb, q_emb) # (B, k, score_dim)
148 | soft_counts = F.sigmoid(self.W(s)).squeeze(2) # (B, k)
149 | C = soft_counts.sum(dim=1) # (B,)
150 | return C
151 |
152 |
153 | class RhoScorer(nn.Module):
154 |
155 | def __init__(self, ques_dim):
156 | super(RhoScorer, self).__init__()
157 | self.W = weight_norm(nn.Linear(ques_dim, 1), dim=None)
158 |
159 | inp_dim = 1 + 1 + 6 + 6 + 1 + 1 + 1 # 17
160 | self.f_rho = FCNet([inp_dim, 100])
161 | self.dense = weight_norm(nn.Linear(100, 1), dim=None)
162 |
163 | @staticmethod
164 | def get_spatials(b):
165 | # b = (B, k, 6)
166 |
167 | b = b.float()
168 |
169 | B, k, _ = b.size()
170 |
171 | b_ij = torch.stack([b] * k, dim=1) # (B, k, k, 6)
172 | b_ji = b_ij.transpose(1, 2)
173 |
174 | area_ij = (b_ij[:, :, :, 2] - b_ij[:, :, :, 0]) * (b_ij[:, :, :, 3] - b_ij[:, :, :, 1])
175 | area_ji = (b_ji[:, :, :, 2] - b_ji[:, :, :, 0]) * (b_ji[:, :, :, 3] - b_ji[:, :, :, 1])
176 |
177 |         rightmost_left = torch.max(b_ij[:, :, :, 0], b_ji[:, :, :, 0])
178 | downmost_top = torch.max(b_ij[:, :, :, 1], b_ji[:, :, :, 1])
179 | leftmost_right = torch.min(b_ij[:, :, :, 2], b_ji[:, :, :, 2])
180 | topmost_down = torch.min(b_ij[:, :, :, 3], b_ji[:, :, :, 3])
181 |
182 |         # calculate the separations
183 |         left_right = (leftmost_right - rightmost_left)
184 | up_down = (topmost_down - downmost_top)
185 |
186 |         # clamp negative separations at zero; multiplying two negative
187 |         # separations could otherwise give a positive area for boxes that don't overlap
188 | left_right = torch.max(0*left_right, left_right)
189 | up_down = torch.max(0*up_down, up_down)
190 |
191 | overlap = left_right * up_down
192 |
193 | iou = overlap / (area_ij + area_ji - overlap)
194 | o_ij = overlap / area_ij
195 | o_ji = overlap / area_ji
196 |
197 | iou = iou.unsqueeze(3) # (B, k, k, 1)
198 | o_ij = o_ij.unsqueeze(3) # (B, k, k, 1)
199 | o_ji = o_ji.unsqueeze(3) # (B, k, k, 1)
200 |
201 | return b_ij, b_ji, iou, o_ij, o_ji
202 |
203 | def forward(self, q_emb, v_emb, b):
204 | # q_emb = (B, ques_size)
205 | # v_emb = (B, k, v_dim)
206 | # b = (B, k, 6)
207 |
208 | B, k, _ = v_emb.size()
209 |
210 | features = []
211 |
212 | wq = self.W(q_emb).squeeze(1) # (B,)
213 | wq = wq[:, None, None, None].repeat(1, k, k, 1) # (B, k, k, 1)
214 | assert wq.size() == (B, k, k, 1), "wq size is {}".format(wq.size())
215 | features.append(wq)
216 |
217 | norm_v_emb = F.normalize(v_emb, dim=2) # (B, k, v_dim)
218 | vtv = torch.bmm(norm_v_emb, norm_v_emb.transpose(1, 2)) # (B, k, k)
219 | vtv = vtv[:, :, :, None].repeat(1, 1, 1, 1) # (B, k, k, 1)
220 | assert vtv.size() == (B, k, k, 1)
221 | features.append(vtv)
222 |
223 | b_ij, b_ji, iou, o_ij, o_ji = self.get_spatials(b)
224 |
225 | assert b_ij.size() == (B, k, k, 6)
226 | assert b_ji.size() == (B, k, k, 6)
227 | assert iou.size() == (B, k, k, 1)
228 | assert o_ij.size() == (B, k, k, 1)
229 | assert o_ji.size() == (B, k, k, 1)
230 |
231 | features.append(b_ij) # (B, k, k, 6)
232 | features.append(b_ji) # (B, k, k, 6)
233 | features.append(iou) # (B, k, k, 1)
234 | features.append(o_ij) # (B, k, k, 1)
235 | features.append(o_ji) # (B, k, k, 1)
236 |
237 | features = torch.cat(features, dim=3) # (B, k, k, 17)
238 |
239 | rho = self.f_rho(features) # (B, k, k, 100)
240 | rho = self.dense(rho).squeeze(3) # (B, k, k)
241 |
242 | return rho, features # (B, k, k)
243 |
244 |
245 | class IRLC(nn.Module):
246 |
247 | def __init__(self, ques_dim=1024, score_dim=2048, dropout=0.5):
248 | super(IRLC, self).__init__()
249 | # print("question parser has zero dropout")
250 | # put zero dropout for question, because it will get dropped out in scoring anyways.
251 | self.ques_parser = QuestionParser(ques_dim=ques_dim, dropout=0)
252 | self.f_s = ScoringFunction(ques_dim=ques_dim, score_dim=score_dim, dropout=dropout)
253 | self.W = weight_norm(nn.Linear(score_dim, 1), dim=None)
254 | self.f_rho = RhoScorer(ques_dim=ques_dim)
255 |
256 | # extra custom parameters
257 | self.eps = nn.Parameter(torch.zeros(1))
258 | self.extra_params = nn.ParameterList([self.eps])
259 |
260 | def sample_action(self, probs, already_selected=None, greedy=False):
261 | # probs = (B, k+1)
262 | # already_selected = (num_timesteps, B)
263 |
264 | if already_selected is None:
265 | mask = 1
266 | else:
267 | mask = Variable(torch.ones(probs.size()))
268 | if USE_CUDA:
269 |                     # move the mask to the GPU so it matches probs
270 |                     mask = mask.cuda()
271 | 
272 | mask = mask.scatter_(1, already_selected.t(), 0) # (B, k+1)
273 |
274 | masked_probs = mask * (probs + 1e-20) # (B, k+1), add epsilon to make sure no non-masked value is zero.
275 | dist = Categorical(probs=masked_probs)
276 |
277 | if greedy:
278 | _, a = masked_probs.max(dim=1) # (B)
279 | else:
280 | a = dist.sample() # (B)
281 |
282 | log_prob = dist.log_prob(a) # (B)
283 | entropy = dist.entropy() # (B)
284 | return a, log_prob, entropy
285 |
286 | @staticmethod
287 | def get_interaction(rho, a):
288 | # get the interaction row in rho corresponding to the action a
289 | # rho = (B, num_actions, k)
290 | # a = (B) containing action indices between 0 and num_actions-1
291 |
292 | B, _, k = rho.size()
293 |
294 | # first expand a to the size required output
295 | a = a[:, None].repeat(1, k) # (B, k)
296 |
297 | # print("rho size = {} and a size = {}".format(rho.size(), a.size()))
298 | interaction = rho.gather(dim=1, index=a.unsqueeze(dim=1)).squeeze(dim=1) # (B, k)
299 | assert interaction.size() == (B, k), "interaction size is {}".format(interaction.size())
300 | # print("interaction size = {}".format(interaction.size()))
301 |
302 | return interaction # (B, k)
303 |
304 | def sample_objects(self, kappa_0, rho, batch_eps, greedy=False):
305 | # kappa_0 = (B, k)
306 | # rho = (B, k, k)
307 |
308 | # add an extra row of 0 interaction for the terminal action
309 | rho = torch.cat((rho, 0 * rho[:, :1, :]), dim=1) # (B, k+1, k)
310 |
311 | B, k = kappa_0.size()
312 |
313 | P = None # save un-scaled probabilities of each action at each time-step. mainly for visualization
314 | logPA = None # log prob values for each time-step.
315 | entP = None # distribution entropy value for each timestep.
316 | A = None # will store action values at each timestep.
317 | T = k+1 # num timesteps = different possible actions. +1 for the terminal action
318 | kappa = kappa_0 # (B, k), starting kappa
319 |
320 | for t in range(T):
321 | # calculate probabilities of each action
322 | unscaled_p = F.softmax(torch.cat((kappa, batch_eps), dim=1), dim=1) # (B, k+1)
323 | # print("p = ", p)
324 | # select one object (called "action" in RL terms), avoid already selected objects.
325 | a, log_prob, entropy = self.sample_action(
326 | probs=unscaled_p, already_selected=A, greedy=greedy) # (B,), (B,), (B,)
327 | # update kappa logits with the row in the interaction matrix corresponding to the chosen action.
328 | interaction = self.get_interaction(rho, a)
329 | kappa = kappa + interaction
330 |
331 | # record the prob and action values at each timestep for later use
332 | P = unscaled_p[None] if P is None else torch.cat((P, unscaled_p[None]), dim=0) # (t+1, B, k+1)
333 | logPA = log_prob[None] if logPA is None else torch.cat((logPA, log_prob[None]), dim=0) # (t+1, B)
334 |             entP = entropy[None] if entP is None else torch.cat((entP, entropy[None]), dim=0)  # (t+1, B)
335 | A = a[None] if A is None else torch.cat((A, a[None]), dim=0) # (t+1, B)
336 |
337 | assert logPA.size() == (T, B)
338 | assert entP.size() == (T, B)
339 | assert A.size() == (T, B)
340 |
341 | # calculate count
342 | terminal_action = (A == k) # (T, B) # true for the timestep when terminal action was selected.
343 | _, count = terminal_action.max(dim=0) # (B,) # index of the terminal action is considered the count
344 |
345 | return logPA, entP, A, count, P
346 |
347 | def compute_vars(self, v_emb, b, q):
348 | # v_emb = (B, k, v_dim)
349 | # b = (B, k, 6)
350 | # q = (B, MAXLEN)
351 |
352 | B, k, _ = v_emb.size()
353 |
354 | q_emb = self.ques_parser(q) # (B, q_dim)
355 | s = self.f_s(v_emb, q_emb) # (B, k, score_dim)
356 | kappa_0 = self.W(s).squeeze(2) # (B, k)
357 |
358 | rho, _ = self.f_rho(q_emb, v_emb, b) # (B, k, k)
359 |
360 | return kappa_0, rho
361 |
362 | def take_mc_samples(self, kappa_0, rho, num_mc_samples):
363 | # kappa_0 = (B, k)
364 | # rho = (B, k, k)
365 |
366 | B, k = kappa_0.size()
367 | assert rho.size() == (B, k, k)
368 |
369 | kappa_0 = kappa_0.repeat(num_mc_samples, 1) # (B * samples, k)
370 |         rho = rho.repeat(num_mc_samples, 1, 1)  # (B * samples, k, k)
371 |
372 | batch_eps = torch.cat([self.eps] * B * num_mc_samples)[:, None] # (B * samples, 1)
373 |
374 | logPA, entP, A, count, P = self.sample_objects(kappa_0=kappa_0, rho=rho, batch_eps=batch_eps)
375 | _, _, _, greedy_count, _ = self.sample_objects(kappa_0=kappa_0, rho=rho, batch_eps=batch_eps, greedy=True)
376 |
377 | return count, greedy_count, logPA, entP, A, rho, P
378 |
379 | def get_sc_loss(self, count_gt, count, greedy_count, logPA, valid_A):
380 | # count_gt = (B,)
381 | # count = (B,)
382 | # greedy_count = (B,)
383 | # logPA = (T, B)
384 | # valid_A = (T, B)
385 |
386 | assert count.size() == count_gt.size()
387 | assert greedy_count.size() == count_gt.size()
388 |
389 | count = count.float()
390 | greedy_count = greedy_count.float()
391 | count_gt = count_gt.float()
392 |
393 | # self-critical loss
394 | E = torch.abs(count - count_gt) # (B,)
395 | E_greedy = torch.abs(greedy_count - count_gt) # (B,)
396 |
397 | R = E_greedy - E # (B,)
398 |
399 | assert R.size() == count.size(), "R size is {}".format(R.size())
400 |
401 | mean_log_PA = (logPA * valid_A).sum(dim=0) / valid_A.sum(dim=0) # (B,)
402 |
403 | batch_sc_loss = - R * mean_log_PA # (B,)
404 | sc_loss = batch_sc_loss.mean(dim=0) # (1,)
405 |
406 | return sc_loss
407 |
408 | def get_entropy_loss(self, entP, valid_A):
409 | # entP = (T, B)
410 | # valid_A = (T, B)
411 |
412 | batch_entropy_loss = - (entP * valid_A).sum(dim=0) / valid_A.sum(dim=0) # (B,)
413 | entropy_loss = batch_entropy_loss.mean(dim=0)
414 |
415 | return entropy_loss
416 |
417 | def get_interaction_strength(self, rho, A, pre_terminal_A):
418 | # rho = (B, k, k)
419 | # A = (T, B)
420 |
421 | B, k, _ = rho.size()
422 | T, B = A.size()
423 |
424 | # interaction strength, lower is better, sparse preferred.
425 | interactions = F.smooth_l1_loss(rho, 0*rho.detach(), reduce=False) # (B, k, k)
426 | interactions = interactions.mean(dim=2) # (B, k)
427 |
428 |         # add a dummy interaction for the terminal action, pad zeros (value doesn't matter as it will be masked anyways)
429 | interactions = torch.cat((interactions, 0*interactions[:, :1]), dim=1) # (B, k+1)
430 |
431 | # for each timestep select the interaction corresponding to the performed action
432 |         repeated_interactions = interactions[None].repeat(T, 1, 1) # (T, B, k+1)
433 |         action_interactions = repeated_interactions.gather(dim=2, index=A.unsqueeze(2)).squeeze(2) # (T, B)
434 |
435 | # mask out interactions due to actions done after the terminal action
436 | valid_interactions = (action_interactions * pre_terminal_A).sum(dim=0) / (1e-20 + pre_terminal_A.sum(dim=0)) # (B,)
437 |
438 | interaction_strength = valid_interactions.mean(dim=0)
439 |
440 | return interaction_strength
441 |
442 | def get_loss(self, count_gt, count, greedy_count, logPA, entP, A, rho):
443 | # count_gt = (B,)
444 | # count = (B,)
445 | # greedy_count = (B,)
446 | # logPA = (T, B)
447 | # entP = (T, B)
448 | # A = (T, B)
449 | # rho = (B, k, k)
450 |
451 | assert count.size() == count_gt.size()
452 | assert greedy_count.size() == count_gt.size()
453 |
454 | B, k, _ = rho.size()
455 |
456 | terminal_action = A.max()
457 | assert terminal_action.item() == k
458 |
459 | terminal_A = (A == terminal_action).float() # (T, B)
460 | post_terminal_A = terminal_A.cumsum(dim=0) - terminal_A # (T, B)
461 | pre_terminal_A = 1 - post_terminal_A - terminal_A
462 | valid_A = 1 - post_terminal_A # (T, B)
463 |
464 | sc_loss = self.get_sc_loss(count_gt, count, greedy_count, logPA, valid_A)
465 | entropy_loss = self.get_entropy_loss(entP, valid_A)
466 | interaction_strength = self.get_interaction_strength(rho, A, pre_terminal_A)
467 |
468 | # print("sc_loss", sc_loss, "entropy loss", entropy_loss, "interaction strength", interaction_strength)
469 |
470 | loss = 1.0 * sc_loss + .005 * entropy_loss + .005 * interaction_strength
471 |
472 | return loss
473 |
--------------------------------------------------------------------------------
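Sanity-check sketch (hypothetical, not part of the repository): RhoScorer.get_spatials builds pairwise IoU and one-sided overlap ratios from boxes in the (x1, y1, x2, y2, w, h) layout produced by tools/detection_features_converter.py; two hand-made boxes make the numbers easy to verify. The sketch assumes config.py defines the names model.py star-imports (e.g. DATA_DIR, VOCAB_SIZE, USE_CUDA).

import torch
from model import RhoScorer

# two boxes: box 0 covers the top-left quarter of the image, box 1 is shifted by a quarter in each direction
b = torch.tensor([[[0.00, 0.00, 0.50, 0.50, 0.50, 0.50],
                   [0.25, 0.25, 0.75, 0.75, 0.50, 0.50]]])  # (B=1, k=2, 6)

b_ij, b_ji, iou, o_ij, o_ji = RhoScorer.get_spatials(b)
print(iou[0, 0, 1].item())   # ~0.143 = 0.0625 / (0.25 + 0.25 - 0.0625)
print(o_ij[0, 0, 1].item())  # 0.25  = overlap area / area of the second box
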
/tools/compute_softscore.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import numpy as np
6 | import re
7 | import cPickle
8 |
9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10 | from dataset import Dictionary
11 |
12 |
13 | def utils_create_dir(path):
14 | if not os.path.exists(path):
15 | try:
16 | os.makedirs(path)
17 |         except OSError:
18 |             if not os.path.isdir(path):  # errno was never imported; re-raise unless the dir now exists
19 | raise
20 |
21 |
22 | contractions = {
23 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
24 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
25 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
26 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
27 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
28 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
29 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
30 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
31 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
32 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
33 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
34 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
35 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
36 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
37 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
38 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
39 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
40 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
41 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
42 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
43 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
44 | "someoned've": "someone'd've", "someone'dve": "someone'd've",
45 | "someonell": "someone'll", "someones": "someone's", "somethingd":
46 | "something'd", "somethingd've": "something'd've", "something'dve":
47 | "something'd've", "somethingll": "something'll", "thats":
48 | "that's", "thered": "there'd", "thered've": "there'd've",
49 | "there'dve": "there'd've", "therere": "there're", "theres":
50 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
51 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
52 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
53 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
54 | "weren't", "whatll": "what'll", "whatre": "what're", "whats":
55 | "what's", "whatve": "what've", "whens": "when's", "whered":
56 | "where'd", "wheres": "where's", "whereve": "where've", "whod":
57 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
58 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
59 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
60 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
61 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
62 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
63 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
64 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
65 | "you'll", "youre": "you're", "youve": "you've"
66 | }
67 |
68 | manual_map = { 'none': '0',
69 | 'zero': '0',
70 | 'one': '1',
71 | 'two': '2',
72 | 'three': '3',
73 | 'four': '4',
74 | 'five': '5',
75 | 'six': '6',
76 | 'seven': '7',
77 | 'eight': '8',
78 | 'nine': '9',
79 | 'ten': '10'}
80 | articles = ['a', 'an', 'the']
81 | period_strip = re.compile("(?<!\d)(\.)(?!\d)")
82 | comma_strip = re.compile("(\d)(\,)(\d)")
83 | punct = [';', r"/", '[', ']', '"', '{', '}',
84 | '(', ')', '=', '+', '\\', '_', '-',
85 | '>', '<', '@', '`', ',', '?', '!']
86 |
87 |
88 | def get_score(occurences):
89 | if occurences == 0:
90 | return 0
91 | elif occurences == 1:
92 | return 0.3
93 | elif occurences == 2:
94 | return 0.6
95 | elif occurences == 3:
96 | return 0.9
97 | else:
98 | return 1
99 |
100 |
101 | def process_punctuation(inText):
102 | outText = inText
103 | for p in punct:
104 | if (p + ' ' in inText or ' ' + p in inText) \
105 | or (re.search(comma_strip, inText) != None):
106 | outText = outText.replace(p, '')
107 | else:
108 | outText = outText.replace(p, ' ')
109 | outText = period_strip.sub("", outText, re.UNICODE)
110 | return outText
111 |
112 |
113 | def process_digit_article(inText):
114 | outText = []
115 | tempText = inText.lower().split()
116 | for word in tempText:
117 | word = manual_map.setdefault(word, word)
118 | if word not in articles:
119 | outText.append(word)
120 | else:
121 | pass
122 | for wordId, word in enumerate(outText):
123 | if word in contractions:
124 | outText[wordId] = contractions[word]
125 | outText = ' '.join(outText)
126 | return outText
127 |
128 |
129 | def multiple_replace(text, wordDict):
130 | for key in wordDict:
131 | text = text.replace(key, wordDict[key])
132 | return text
133 |
134 |
135 | def preprocess_answer(answer):
136 | answer = process_digit_article(process_punctuation(answer))
137 | answer = answer.replace(',', '')
138 | return answer
139 |
140 |
141 | def filter_answers(answers_dset, min_occurence):
142 | """This will change the answer to preprocessed version
143 | """
144 | occurence = {}
145 |
146 | for ans_entry in answers_dset:
147 | answers = ans_entry['answers']
148 | gtruth = ans_entry['multiple_choice_answer']
149 | gtruth = preprocess_answer(gtruth)
150 | if gtruth not in occurence:
151 | occurence[gtruth] = set()
152 | occurence[gtruth].add(ans_entry['question_id'])
153 | for answer in occurence.keys():
154 | if len(occurence[answer]) < min_occurence:
155 | occurence.pop(answer)
156 |
157 | print('Num of answers that appear >= %d times: %d' % (
158 | min_occurence, len(occurence)))
159 | return occurence
160 |
161 |
162 | def create_ans2label(occurence, name, cache_root='data/cache'):
163 | """Note that this will also create label2ans.pkl at the same time
164 |
165 | occurence: dict {answer -> whatever}
166 | name: prefix of the output file
167 | cache_root: str
168 | """
169 | ans2label = {}
170 | label2ans = []
171 | label = 0
172 | for answer in occurence:
173 | label2ans.append(answer)
174 | ans2label[answer] = label
175 | label += 1
176 |
177 | utils_create_dir(cache_root)
178 |
179 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
180 | cPickle.dump(ans2label, open(cache_file, 'wb'))
181 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
182 | cPickle.dump(label2ans, open(cache_file, 'wb'))
183 | return ans2label
184 |
185 |
186 | def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
187 | """Augment answers_dset with soft score as label
188 |
189 | ***answers_dset should be preprocessed***
190 |
191 | Write result into a cache file
192 | """
193 | target = []
194 | for ans_entry in answers_dset:
195 | answers = ans_entry['answers']
196 | answer_count = {}
197 | for answer in answers:
198 | answer_ = answer['answer']
199 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
200 |
201 | labels = []
202 | scores = []
203 | counts = []
204 | for answer in answer_count:
205 | if answer not in ans2label:
206 | continue
207 | labels.append(ans2label[answer])
208 | count = answer_count[answer]
209 | counts.append(count)
210 | score = get_score(count)
211 | scores.append(score)
212 |
213 | target.append({
214 | 'question_id': ans_entry['question_id'],
215 | 'image_id': ans_entry['image_id'],
216 | 'labels': labels,
217 | 'scores': scores,
218 | 'counts': counts
219 | })
220 |
221 | utils_create_dir(cache_root)
222 | cache_file = os.path.join(cache_root, name+'_target.pkl')
223 | cPickle.dump(target, open(cache_file, 'wb'))
224 | return target
225 |
226 |
227 | def get_answer(qid, answers):
228 | for ans in answers:
229 | if ans['question_id'] == qid:
230 | return ans
231 |
232 |
233 | def get_question(qid, questions):
234 | for question in questions:
235 | if question['question_id'] == qid:
236 | return question
237 |
238 |
239 | if __name__ == '__main__':
240 | train_answer_file = 'data/v2_mscoco_train2014_annotations.json'
241 | train_answers = json.load(open(train_answer_file))['annotations']
242 |
243 | val_answer_file = 'data/v2_mscoco_val2014_annotations.json'
244 | val_answers = json.load(open(val_answer_file))['annotations']
245 |
246 | train_question_file = 'data/v2_OpenEnded_mscoco_train2014_questions.json'
247 | train_questions = json.load(open(train_question_file))['questions']
248 |
249 | val_question_file = 'data/v2_OpenEnded_mscoco_val2014_questions.json'
250 | val_questions = json.load(open(val_question_file))['questions']
251 |
252 | answers = train_answers + val_answers
253 | occurence = filter_answers(answers, 9)
254 | ans2label = create_ans2label(occurence, 'trainval')
255 | compute_target(train_answers, ans2label, 'train')
256 | compute_target(val_answers, ans2label, 'val')
257 |
--------------------------------------------------------------------------------
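Quick sanity check (hypothetical, Python 2 like the script above; adding tools/ to sys.path is an assumption about how these scripts would be imported from the repo root): get_score maps the number of annotators that gave an answer to the usual VQA soft accuracy, and preprocess_answer normalises answers before occurrences are counted.

import sys
sys.path.append("tools")
from compute_softscore import get_score, preprocess_answer

print([get_score(n) for n in range(5)])  # [0, 0.3, 0.6, 0.9, 1]
print(preprocess_answer("Two dogs."))    # "2 dogs"
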
/tools/create_dictionary.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import numpy as np
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | from dataset import Dictionary
8 |
9 |
10 | 
11 | def create_dictionary(dataroot):
12 | dictionary = Dictionary()
13 | questions = []
14 | files = [
15 | 'v2_OpenEnded_mscoco_train2014_questions.json',
16 | 'v2_OpenEnded_mscoco_val2014_questions.json',
17 | 'v2_OpenEnded_mscoco_test2015_questions.json',
18 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json',
19 | 'how_many_qa/vgn_ques.json'
20 | ]
21 | for path in files:
22 | question_path = os.path.join(dataroot, path)
23 | qs = json.load(open(question_path))
24 | if "vgn" not in path:
25 | qs = qs['questions']
26 | for q in qs:
27 | dictionary.tokenize(q['question'], True)
28 | return dictionary
29 |
30 |
31 | def create_glove_embedding_init(idx2word, glove_file):
32 | word2emb = {}
33 | with open(glove_file, 'r') as f:
34 | entries = f.readlines()
35 | emb_dim = len(entries[0].split(' ')) - 1
36 | print('embedding dim is %d' % emb_dim)
37 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
38 |
39 | for entry in entries:
40 | vals = entry.split(' ')
41 | word = vals[0]
42 | vals = map(float, vals[1:])
43 | word2emb[word] = np.array(vals)
44 | for idx, word in enumerate(idx2word):
45 | if word not in word2emb:
46 | continue
47 | weights[idx] = word2emb[word]
48 | return weights, word2emb
49 |
50 |
51 | if __name__ == '__main__':
52 | d = create_dictionary('data')
53 | d.dump_to_file('data/dictionary.pkl')
54 |
55 | d = Dictionary.load_from_file('data/dictionary.pkl')
56 | emb_dim = 300
57 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim
58 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file)
59 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights)
60 |
--------------------------------------------------------------------------------
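Sketch of how the artifacts written by this script are consumed (hypothetical check, assuming the default paths above): data/dictionary.pkl holds the word-to-index mapping and data/glove6b_init_300d.npy the matching GloVe matrix that QuestionParser.glove_init in model.py copies into its embedding table, so the two must agree on the vocabulary size.

import numpy as np
from dataset import Dictionary

d = Dictionary.load_from_file('data/dictionary.pkl')
weights = np.load('data/glove6b_init_300d.npy')
assert weights.shape == (len(d.idx2word), 300)
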
/tools/create_how_many_qa_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import json
3 | import cPickle as pkl
4 | import os
5 | import re
6 |
7 |
8 | def find_image_ids():
9 |
10 | # read locations
11 | train_targets_loc = "./data/cache/train_target.pkl"
12 | val_targets_loc = "./data/cache/val_target.pkl"
13 | hmq_ids_loc = "./data/how_many_qa/HowMany-QA/question_ids.json"
14 |
15 | # write locations
16 | hmq_image_ids_loc = "./data/how_many_qa/image_ids.json"
17 | if os.path.isfile(hmq_image_ids_loc):
18 | print("The file {} already exists. Skipping finding image ids.".format(hmq_image_ids_loc))
19 | return
20 |
21 | train_targets = pkl.load(open(train_targets_loc, "rb"))
22 | val_targets = pkl.load(open(val_targets_loc, "rb"))
23 |
24 | hmq_ids = json.load(open(hmq_ids_loc, "rb"))
25 | qids = {
26 | "train": set(hmq_ids["train"]["vqa"]),
27 | "test": set(hmq_ids["test"]),
28 | "dev": set(hmq_ids["dev"]),
29 | }
30 |
31 | image_ids = {
32 | "test": [],
33 | "train": [],
34 | "dev": [],
35 | }
36 |
37 | # train
38 | for i, ans in enumerate(train_targets):
39 | if i % 10000 == 0:
40 | print(i)
41 | if ans["question_id"] in qids["train"]:
42 | image_ids["train"].append(ans["image_id"])
43 | if ans["question_id"] in qids["test"] or ans["question_id"] in qids["dev"]:
44 | raise Exception("found train question id {} in qids marked for test and dev")
45 |
46 | # dev and test
47 | for i, ans in enumerate(val_targets):
48 | if i % 10000 == 0:
49 | print(i)
50 | if ans["question_id"] in qids["train"]:
51 | raise Exception("found validation question id {} in qids marked for training")
52 | if ans["question_id"] in qids["test"]:
53 | image_ids["test"].append(ans["image_id"])
54 | if ans["question_id"] in qids["dev"]:
55 | image_ids["dev"].append(ans["image_id"])
56 |
57 | unique_image_ids = {
58 | "test": list(set(image_ids["test"])),
59 | "train": list(set(image_ids["train"])),
60 | "dev": list(set(image_ids["dev"])),
61 | }
62 |
63 | assert len(unique_image_ids["train"]) == 31932
64 | assert len(unique_image_ids["dev"]) == 13119
65 | assert len(unique_image_ids["test"]) == 2483
66 |
67 | print("writing image ids for how many QA to disk..")
68 | json.dump(unique_image_ids, open(hmq_image_ids_loc, "wb"))
69 | print("Done.")
70 | return
71 |
72 |
73 | def prepare_visual_genome():
74 |
75 | # read locations
76 | hmqa_qids_loc = "./data/how_many_qa/HowMany-QA/question_ids.json"
77 | all_vg_loc = "./data/how_many_qa/HowMany-QA/visual_genome_question_answers.json"
78 | vg_image_data_loc = "./data/how_many_qa/HowMany-QA/visual_genome_image_data.json"
79 |
80 | # write locations
81 | vgn_ques_loc = "./data/how_many_qa/vgn_ques.json"
82 | if os.path.isfile(vgn_ques_loc):
83 | print("The file {} already exists. Skipping preparing visual genome.".format(vgn_ques_loc))
84 | return
85 |
86 | hmqa_qids = json.load(open(hmqa_qids_loc, "rb"))
87 | hmqa_vg_qids = set(hmqa_qids["train"]["visual_genome"])
88 | all_vg = json.load(open(all_vg_loc))
89 | vg_image_data = json.load(open(vg_image_data_loc))
90 |
91 | vg_entries = []
92 |
93 | for qaset in all_vg:
94 | # setid = qaset['id']
95 | for entry in qaset['qas']:
96 | if entry["qa_id"] in hmqa_vg_qids:
97 | vg_entries.append(entry)
98 |
99 | vg_image_id2coco_id = {x["image_id"]: x["coco_id"] for x in vg_image_data}
100 |
101 | vgn_ques = [
102 | {'image_id': vg_image_id2coco_id[x['image_id']], 'question': x['question'], 'question_id': x['qa_id']
103 | } for x in vg_entries
104 | ]
105 |
106 | json.dump(vgn_ques, open(vgn_ques_loc, "wb"))
107 |
108 |
109 | def vg_ans2count(s):
110 | # replace all non-alphanums with space
111 | r = re.sub('[^0-9a-zA-Z]+', ' ', s)
112 |
113 | r = r.lower()
114 | r = r.split(' ')
115 |
116 | word2num = {
117 | "zero": 0,
118 | "one": 1,
119 | "two": 2,
120 | "three": 3,
121 | "four": 4,
122 | "five": 5,
123 | "six": 6,
124 | "seven": 7,
125 | "eight": 8,
126 | "nine": 9,
127 | "ten": 10,
128 | "eleven": 11,
129 | "twelve": 12,
130 | "thirteen": 13,
131 | "fourteen": 14,
132 | "fifteen": 15,
133 | "sixteen": 16,
134 | "seventeen": 17,
135 | "eighteen": 18,
136 | "nineteen": 19,
137 | "twenty": 20
138 | }
139 |
140 | cands = []
141 | for word in r:
142 | try:
143 | cands.append(int(word))
144 |         except ValueError:
145 | pass
146 | try:
147 | cands.append(word2num[word])
148 |         except KeyError:
149 | pass
150 |
151 | cands = list(set(cands)) # merging duplicates
152 |
153 | if len(cands) != 1:
154 | print(s, cands)
155 | if (s == "3 or 4." and cands == [3, 4]):
156 | print("manually correcting '{}'".format(s))
157 | cands = cands[:1]
158 |
159 | assert len(cands) == 1
160 | count = cands[0]
161 |
162 | assert 0 <= count <= 20
163 |
164 | return count
165 |
166 |
167 | def find_counts():
168 |
169 | # read locations
170 | _hmq_ids_loc = "./data/how_many_qa/HowMany-QA/question_ids.json"
171 | vqa_train_entries_loc = "./data/cache/train_target.pkl"
172 | test_dev_entries_loc = "./data/cache/val_target.pkl"
173 | label2ans_loc = "./data/cache/trainval_label2ans.pkl"
174 | all_vg_loc = "./data/how_many_qa/HowMany-QA/visual_genome_question_answers.json"
175 |
176 | # write locations
177 | qid2count_loc = "./data/how_many_qa/qid2count.json"
178 | qid2count2score_loc = "./data/how_many_qa/qid2count2score.json"
179 |
180 | if os.path.isfile(qid2count_loc) and os.path.isfile(qid2count2score_loc):
181 | print("The file {} and {} already exists. Skipping finding counts.".format(qid2count_loc, qid2count2score_loc))
182 | return
183 |
184 | _hmq_ids = json.load(open(_hmq_ids_loc, "rb"))
185 | hmq_ids = {
186 | "train": {
187 | "vqa": set(_hmq_ids["train"]["vqa"]),
188 | "visual_genome": set(_hmq_ids["train"]["visual_genome"]),
189 | },
190 | "test": set(_hmq_ids["test"]),
191 | "dev": set(_hmq_ids["dev"]),
192 | }
193 |
194 | vqa_train_entries = pkl.load(open(vqa_train_entries_loc, "rb"))
195 | test_dev_entries = pkl.load(open(test_dev_entries_loc, "rb"))
196 | label2ans = pkl.load(open(label2ans_loc, "rb"))
197 |
198 | qid2count = {
199 | "train": {
200 | "vqa": {},
201 | "visual_genome": {}
202 | },
203 | "test": {},
204 | "dev": {},
205 | }
206 |
207 | qid2count2score = {
208 | "train": {
209 | "vqa": {},
210 | "visual_genome": {}
211 | },
212 | "test": {},
213 | "dev": {},
214 | }
215 |
216 | # vqa train
217 | for entry in vqa_train_entries:
218 | qid = entry['question_id']
219 |
220 | if qid not in hmq_ids["train"]["vqa"]:
221 | continue
222 |
223 | gt_cands = []
224 | max_occurence_count = 0
225 |
226 | for occurence_count, score, label in zip(entry["counts"], entry["scores"], entry["labels"]):
227 | try:
228 | count = int(label2ans[label])
229 |                 assert count <= 20, "count {} is greater than 20 (score: {})".format(count, score)
230 |
231 | if occurence_count > max_occurence_count:
232 | max_occurence_count = occurence_count
233 | gt_cands = [count]
234 | elif occurence_count == max_occurence_count:
235 | gt_cands.append(count)
236 |
237 | if qid2count2score["train"]["vqa"].get(qid) is None:
238 | qid2count2score["train"]["vqa"][qid] = [0] * 21 # count2score list mapping
239 | qid2count2score["train"]["vqa"][qid][count] = score
240 |
241 | except Exception as e:
242 | print(e)
243 | pass
244 |
245 |         # select the answer with the highest occurrence count; in case of a tie, select the minimum
246 | qid2count["train"]["vqa"][qid] = min(gt_cands)
247 |
248 | ##### VISUAL GENOME #######
249 |
250 | hmqa_qids = json.load(open(_hmq_ids_loc, "rb"))
251 | hmqa_vg_qids = set(hmqa_qids["train"]["visual_genome"])
252 | all_vg = json.load(open(all_vg_loc))
253 |
254 | vg_entries = []
255 |
256 | for qaset in all_vg:
257 | # setid = qaset['id']
258 | for entry in qaset['qas']:
259 | if entry["qa_id"] in hmqa_vg_qids:
260 | vg_entries.append(entry)
261 |
262 | for entry in vg_entries:
263 | qid = entry["qa_id"]
264 | count = vg_ans2count(entry["answer"])
265 | assert qid2count["train"]["visual_genome"].get(qid) is None
266 | assert qid2count2score["train"]["visual_genome"].get(qid) is None
267 |
268 | qid2count["train"]["visual_genome"][qid] = count
269 | qid2count2score["train"]["visual_genome"][qid] = [0] * 21
270 | qid2count2score["train"]["visual_genome"][qid][count] = 1
271 |
272 | ##################################
273 |
274 | # test and dev
275 | for entry in test_dev_entries:
276 | qid = entry['question_id']
277 |
278 | test_entry = qid in hmq_ids["test"]
279 | dev_entry = qid in hmq_ids["dev"]
280 |
281 | if not (test_entry or dev_entry):
282 | continue
283 |
284 | if test_entry and dev_entry:
285 | raise Exception("Found qid {} that is marked for both test set and train set!!".format(qid))
286 |
287 | gt_cands = []
288 | max_occurence_count = 0
289 |
290 | for occurence_count, score, label in zip(entry["counts"], entry["scores"], entry["labels"]):
291 | try:
292 | count = int(label2ans[label])
293 |                 assert count <= 20, "count {} is greater than 20 (score: {})".format(count, score)
294 |
295 | if occurence_count > max_occurence_count:
296 | max_occurence_count = occurence_count
297 | gt_cands = [count]
298 | elif occurence_count == max_occurence_count:
299 | gt_cands.append(count)
300 |
301 | if test_entry:
302 |
303 | if qid2count2score["test"].get(qid) is None:
304 | qid2count2score["test"][qid] = [0] * 21 # count2score list mapping
305 | qid2count2score["test"][qid][count] = score
306 |
307 | if dev_entry:
308 |
309 | if qid2count2score["dev"].get(qid) is None:
310 | qid2count2score["dev"][qid] = [0] * 21 # count2score list mapping
311 | qid2count2score["dev"][qid][count] = score
312 |
313 | except Exception as e:
314 | print(e)
315 | pass
316 |
317 |         # select the answer with the highest occurrence count; in case of a tie, select the minimum
318 | if test_entry:
319 | qid2count["test"][qid] = min(gt_cands)
320 | if dev_entry:
321 | assert not test_entry
322 | qid2count["dev"][qid] = min(gt_cands)
323 |
324 | assert len(qid2count["train"]["vqa"]) == 47542
325 | assert len(qid2count["test"]) == 5000
326 | assert len(qid2count["dev"]) == 17714
327 |
328 | json.dump(qid2count, open(qid2count_loc, "w"))
329 | json.dump(qid2count2score, open(qid2count2score_loc, "w"))
330 | return
331 |
332 |
333 | def main():
334 | find_image_ids()
335 | prepare_visual_genome()
336 | find_counts()
337 |
338 |
339 | if __name__ == '__main__':
340 | main()
341 |
342 |
--------------------------------------------------------------------------------
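Expected behaviour of vg_ans2count on a few illustrative answer strings (hypothetical inputs, not taken from the dataset; Python 2, run from the repo root with tools/ added to sys.path as an assumption):

import sys
sys.path.append("tools")
from create_how_many_qa_dataset import vg_ans2count

print(vg_ans2count("Three."))       # 3
print(vg_ans2count("There are 7"))  # 7
print(vg_ans2count("twenty"))       # 20
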
/tools/detection_features_converter.py:
--------------------------------------------------------------------------------
1 | """
2 | Reads in a tsv file with pre-trained bottom up attention features and
3 | stores it in HDF5 format. Also store {image_id: feature_idx}
4 | as a pickle file.
5 |
6 | Hierarchy of HDF5 file:
7 |
8 | { 'image_features': num_images x num_boxes x 2048 array of features
9 | 'image_bb': num_images x num_boxes x 4 array of bounding boxes }
10 | """
11 | from __future__ import print_function
12 |
13 | import os
14 | import sys
15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16 |
17 | import base64
18 | import csv
19 | import h5py
20 | import cPickle
21 | import numpy as np
22 | import utils
23 |
24 |
25 | csv.field_size_limit(sys.maxsize)
26 |
27 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']
28 | infile = 'data/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv'
29 | train_data_file = 'data/train36.hdf5'
30 | val_data_file = 'data/val36.hdf5'
31 | train_indices_file = 'data/train36_imgid2idx.pkl'
32 | val_indices_file = 'data/val36_imgid2idx.pkl'
33 | train_ids_file = 'data/train_ids.pkl'
34 | val_ids_file = 'data/val_ids.pkl'
35 |
36 | feature_length = 2048
37 | num_fixed_boxes = 36
38 |
39 |
40 | if __name__ == '__main__':
41 | h_train = h5py.File(train_data_file, "w")
42 | h_val = h5py.File(val_data_file, "w")
43 |
44 | if os.path.exists(train_ids_file) and os.path.exists(val_ids_file):
45 | train_imgids = cPickle.load(open(train_ids_file))
46 | val_imgids = cPickle.load(open(val_ids_file))
47 | else:
48 | train_imgids = utils.load_imageid('data/train2014')
49 | val_imgids = utils.load_imageid('data/val2014')
50 | cPickle.dump(train_imgids, open(train_ids_file, 'wb'))
51 | cPickle.dump(val_imgids, open(val_ids_file, 'wb'))
52 |
53 | train_indices = {}
54 | val_indices = {}
55 |
56 | train_img_features = h_train.create_dataset(
57 | 'image_features', (len(train_imgids), num_fixed_boxes, feature_length), 'f')
58 | train_img_bb = h_train.create_dataset(
59 | 'image_bb', (len(train_imgids), num_fixed_boxes, 4), 'f')
60 | train_spatial_img_features = h_train.create_dataset(
61 | 'spatial_features', (len(train_imgids), num_fixed_boxes, 6), 'f')
62 |
63 | val_img_bb = h_val.create_dataset(
64 | 'image_bb', (len(val_imgids), num_fixed_boxes, 4), 'f')
65 | val_img_features = h_val.create_dataset(
66 | 'image_features', (len(val_imgids), num_fixed_boxes, feature_length), 'f')
67 | val_spatial_img_features = h_val.create_dataset(
68 | 'spatial_features', (len(val_imgids), num_fixed_boxes, 6), 'f')
69 |
70 | train_counter = 0
71 | val_counter = 0
72 |
73 | print("reading tsv...")
74 | with open(infile, "r+b") as tsv_in_file:
75 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
76 | for item in reader:
77 | item['num_boxes'] = int(item['num_boxes'])
78 | image_id = int(item['image_id'])
79 | image_w = float(item['image_w'])
80 | image_h = float(item['image_h'])
81 | bboxes = np.frombuffer(
82 | base64.decodestring(item['boxes']),
83 | dtype=np.float32).reshape((item['num_boxes'], -1))
84 |
85 | box_width = bboxes[:, 2] - bboxes[:, 0]
86 | box_height = bboxes[:, 3] - bboxes[:, 1]
87 | scaled_width = box_width / image_w
88 | scaled_height = box_height / image_h
89 | scaled_x = bboxes[:, 0] / image_w
90 | scaled_y = bboxes[:, 1] / image_h
91 |
92 | box_width = box_width[..., np.newaxis]
93 | box_height = box_height[..., np.newaxis]
94 | scaled_width = scaled_width[..., np.newaxis]
95 | scaled_height = scaled_height[..., np.newaxis]
96 | scaled_x = scaled_x[..., np.newaxis]
97 | scaled_y = scaled_y[..., np.newaxis]
98 |
99 | spatial_features = np.concatenate(
100 | (scaled_x,
101 | scaled_y,
102 | scaled_x + scaled_width,
103 | scaled_y + scaled_height,
104 | scaled_width,
105 | scaled_height),
106 | axis=1)
107 |
108 | if image_id in train_imgids:
109 | train_imgids.remove(image_id)
110 | train_indices[image_id] = train_counter
111 | train_img_bb[train_counter, :, :] = bboxes
112 | train_img_features[train_counter, :, :] = np.frombuffer(
113 | base64.decodestring(item['features']),
114 | dtype=np.float32).reshape((item['num_boxes'], -1))
115 | train_spatial_img_features[train_counter, :, :] = spatial_features
116 | train_counter += 1
117 | elif image_id in val_imgids:
118 | val_imgids.remove(image_id)
119 | val_indices[image_id] = val_counter
120 | val_img_bb[val_counter, :, :] = bboxes
121 | val_img_features[val_counter, :, :] = np.frombuffer(
122 | base64.decodestring(item['features']),
123 | dtype=np.float32).reshape((item['num_boxes'], -1))
124 | val_spatial_img_features[val_counter, :, :] = spatial_features
125 | val_counter += 1
126 | else:
127 | assert False, 'Unknown image id: %d' % image_id
128 |
129 | if len(train_imgids) != 0:
130 | print('Warning: train_image_ids is not empty')
131 |
132 | if len(val_imgids) != 0:
133 | print('Warning: val_image_ids is not empty')
134 |
135 | cPickle.dump(train_indices, open(train_indices_file, 'wb'))
136 | cPickle.dump(val_indices, open(val_indices_file, 'wb'))
137 | h_train.close()
138 | h_val.close()
139 | print("done!")
140 |
--------------------------------------------------------------------------------
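The 6-d "spatial_features" layout written above, reproduced for a single box with toy numbers (the real coordinates come from the bottom-up-attention tsv); this is the (x1, y1, x2, y2, w, h) layout that dataset.py returns as spatials and that RhoScorer consumes:

import numpy as np

image_w, image_h = 640.0, 480.0
x1, y1, x2, y2 = 64.0, 48.0, 320.0, 240.0  # one detection box in pixels

w = (x2 - x1) / image_w
h = (y2 - y1) / image_h
x = x1 / image_w
y = y1 / image_h
spatial = np.array([x, y, x + w, y + h, w, h], dtype=np.float32)
print(spatial)  # [0.1 0.1 0.5 0.5 0.4 0.4]
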
/tools/download.sh:
--------------------------------------------------------------------------------
1 | ## Script for downloading data
2 |
3 | # GloVe Vectors
4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip
5 | unzip data/glove.6B.zip -d data/glove
6 | # rm data/glove.6B.zip
7 |
8 | # Questions
9 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip
10 | unzip data/v2_Questions_Train_mscoco.zip -d data
11 | # rm data/v2_Questions_Train_mscoco.zip
12 |
13 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip
14 | unzip data/v2_Questions_Val_mscoco.zip -d data
15 | # rm data/v2_Questions_Val_mscoco.zip
16 |
17 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip
18 | unzip data/v2_Questions_Test_mscoco.zip -d data
19 | # rm data/v2_Questions_Test_mscoco.zip
20 |
21 | # Annotations
22 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip
23 | unzip data/v2_Annotations_Train_mscoco.zip -d data
24 | # rm data/v2_Annotations_Train_mscoco.zip
25 |
26 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip
27 | unzip data/v2_Annotations_Val_mscoco.zip -d data
28 | # rm data/v2_Annotations_Val_mscoco.zip
29 |
30 | # Image Features
31 | wget -P data https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip
32 | unzip data/trainval_36.zip -d data
33 | # rm data/trainval_36.zip
34 |
--------------------------------------------------------------------------------
/tools/download_hmqa.sh:
--------------------------------------------------------------------------------
1 | # HowMany-QA
2 | wget -P data https://einstein.ai/research/interpretable-counting-for-visual-question-answering/HowMany-QA.zip
3 | unzip data/HowMany-QA.zip -d data/how_many_qa
4 |
5 | # Visual Genome
6 | wget -P data https://visualgenome.org/static/data/dataset/question_answers.json.zip
7 | unzip data/question_answers.json.zip -d data/how_many_qa
8 |
9 | wget -P data https://visualgenome.org/static/data/dataset/image_data.json.zip
10 | unzip data/image_data.json.zip -d data/how_many_qa
11 |
12 | mv data/how_many_qa/question_answers.json data/how_many_qa/HowMany-QA/visual_genome_question_answers.json
13 | mv data/how_many_qa/image_data.json data/how_many_qa/HowMany-QA/visual_genome_image_data.json
14 |
--------------------------------------------------------------------------------
/tools/process.sh:
--------------------------------------------------------------------------------
1 | # Process data
2 |
3 | python tools/create_dictionary.py
4 | python tools/compute_softscore.py
5 | python tools/detection_features_converter.py
6 |
--------------------------------------------------------------------------------
/tools/process_hmqa.sh:
--------------------------------------------------------------------------------
1 | python tools/create_how_many_qa_dataset.py
--------------------------------------------------------------------------------
/vis/00-selection_image-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/00-selection_image-335.png
--------------------------------------------------------------------------------
/vis/00-selection_image-364.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/00-selection_image-364.png
--------------------------------------------------------------------------------
/vis/01-selection_image-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/01-selection_image-335.png
--------------------------------------------------------------------------------
/vis/01-selection_image-364.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/01-selection_image-364.png
--------------------------------------------------------------------------------
/vis/02-selection_image-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/02-selection_image-335.png
--------------------------------------------------------------------------------
/vis/02-selection_image-364.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/02-selection_image-364.png
--------------------------------------------------------------------------------
/vis/03-selection_image-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/03-selection_image-335.png
--------------------------------------------------------------------------------
/vis/04-selection_image-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/04-selection_image-335.png
--------------------------------------------------------------------------------
/vis/image_candidates-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/image_candidates-335.png
--------------------------------------------------------------------------------
/vis/image_candidates-364.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/image_candidates-364.png
--------------------------------------------------------------------------------
/vis/orig_image-335.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/orig_image-335.png
--------------------------------------------------------------------------------
/vis/orig_image-364.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/orig_image-364.png
--------------------------------------------------------------------------------