├── .gitignore ├── LICENSE ├── README.md ├── Training IRLC.ipynb ├── Training SoftCount.ipynb ├── Visualize IRLC.ipynb ├── config.py ├── dataset.py ├── model.py ├── tools ├── compute_softscore.py ├── create_dictionary.py ├── create_how_many_qa_dataset.py ├── detection_features_converter.py ├── download.sh ├── download_hmqa.sh ├── process.sh └── process_hmqa.sh └── vis ├── 00-selection_image-335.png ├── 00-selection_image-364.png ├── 01-selection_image-335.png ├── 01-selection_image-364.png ├── 02-selection_image-335.png ├── 02-selection_image-364.png ├── 03-selection_image-335.png ├── 04-selection_image-335.png ├── image_candidates-335.png ├── image_candidates-364.png ├── orig_image-335.png └── orig_image-364.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | saved_models/ 3 | .idea 4 | data 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanyam Agarwal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # irlc-vqa 2 | Code for **[Interpretable Counting for Visual Question Answering](https://arxiv.org/pdf/1712.08697.pdf)** for the ICLR 2018 reproducibility challenge. 3 | 4 | ## About 5 | The paper improves upon the state-of-the-art accuracy for counting-based questions in VQA. It does so by enforcing the prior that each count corresponds to a well-defined region in the image and is not diffused all over it. The model hard-attends over a fixed set of candidate regions (taken from a pre-trained Faster-R-CNN network) in the image by fusing them with information from the question. Training uses a variant of REINFORCE, Self-Critical Training, which is well suited for generating sequences. 6 | 7 | I found the paper quite interesting. Since I could not find any publicly available implementation of this paper, I decided to implement it as an exercise. 8 | 9 | 10 | ## Results (without caption grounding) 11 | 12 | #### SoftCount 13 | | Model | Test Accuracy | Test RMSE | Training Time | 14 | | --- | --- | --- | --- | 15 | | Reported | 49.2 | 2.45 | Unknown | 16 | | This implementation | **49.7** | **2.31** | ~12 minutes (Nvidia-1080 Ti) | 17 | 18 | #### IRLC 19 | | Model | Test Accuracy | Test RMSE | Training Time | 20 | | --- | --- | --- | --- | 21 | | Reported | **56.1** | 2.45 | Unknown | 22 | | This implementation | 55.7* | **2.41** | ~6 hours (Nvidia-1080 Ti) | 23 | 24 | *= Still improving. Work in Progress. 25 | 26 | The **accuracy** was calculated using the [VQA evaluation metric](http://www.visualqa.org/evaluation.html). I used the exact same script for calculating the "soft score" as in https://github.com/hengyuan-hu/bottom-up-attention-vqa. 27 | 28 | **RMSE** = root mean squared error from the ground truth (see below for how the ground truth was chosen for VQA). 29 | 30 | **Note**: These numbers correspond to the test accuracy and RMSE at the epoch where the accuracy on the development set was highest. The peak test accuracy is usually higher by about a percent. 31 | 32 | 33 | ## Key differences from the paper 34 | - A GRU was used instead of an LSTM for generating question embeddings. Experiments with an LSTM led to slower learning and more over-fitting. More hyper-parameter search is required to fix this. 35 | 36 | - The Gated Tanh Unit is not used. Instead, a 2-layer Leaky ReLU based network inspired by https://github.com/hengyuan-hu/bottom-up-attention-vqa (with slight modifications) is used. 37 | 38 | 39 | ## Filling in missing details in the paper 40 | 41 | #### VQA Ground Truth 42 | I couldn't find any annotations for a "single ground truth", which is required to calculate the REINFORCE reward in IRLC, and I could not find any details in the paper relating to this issue. So I took as ground truth the label that was reported as the answer the most times. If there is more than one such label, the one with the smallest numerical value is picked (this might explain a lower RMSE).
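For illustration, here is a minimal sketch (not the repo's actual preprocessing, which lives in `tools/compute_softscore.py` and `tools/create_how_many_qa_dataset.py`; the helper name is made up, and counts are assumed to be capped at 20 as in the evaluation notebooks) of how a single question's ground-truth count and per-count soft scores can be derived from the human answers under the rule above:

```python
# Illustrative only: derive the "single ground truth" count and the VQA-style
# soft score for every candidate count from one question's human answers.
from collections import Counter

def ground_truth_and_scores(human_counts, max_count=20):
    """human_counts: list of integer answers (typically 10 per VQA question)."""
    votes = Counter(human_counts)

    # Simplified VQA soft score for each count 0..max_count:
    # an answer scores min(#humans who gave it / 3, 1).
    count2score = [min(votes.get(c, 0) / 3.0, 1.0) for c in range(max_count + 1)]

    # Single ground truth used for the REINFORCE reward and RMSE: the most
    # frequent answer; ties are broken by taking the smallest numerical value.
    best_votes = max(votes.values())
    ground_truth = min(c for c, n in votes.items() if n == best_votes)

    return ground_truth, count2score

gt, scores = ground_truth_and_scores([2, 2, 3, 2, 3, 3, 2, 4, 2, 3])
# gt == 2 (five votes); scores[2] == 1.0, scores[3] == 1.0, scores[4] == 1/3
```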
43 | 44 | #### Number of epochs 45 | The authors mention that they use early stopping based on the development set accuracy, but I couldn't find an exact method for deciding when to stop. So I ran the training for 100 epochs for IRLC and 20 epochs for SoftCount. 46 | 47 | #### Number of candidate objects 48 | I could not find the value of N, the number of candidate objects taken from Faster-R-CNN, so following https://github.com/hengyuan-hu/bottom-up-attention-vqa I took N=36. 49 | 50 | ## Minor discrepancies 51 | 52 | #### Number of images due to Visual Genome 53 | From Table 1 in the paper, it would seem that adding the extra data from Visual Genome doesn't change the number of training images (31932). However, while writing the dataloaders I counted around 45k training images after including the Visual Genome data. This is not a big issue, but I note it here so that other people can avoid wasting time investigating it. 54 | 55 | ## Other Implementation Details 56 | 57 | - This implementation borrows most of its pre-processing and data loading code from https://github.com/hengyuan-hu/bottom-up-attention-vqa. 58 | 59 | - The optional "caption grounding" step was skipped since it only improved the accuracy by a percent or so and was not the main focus of the paper. 60 | 61 | - The authors don't mention the amount of dropout they used. After trying a few values, 0.5 was chosen. 62 | 63 | - The number of samples was kept at 32 (instead of 5 as mentioned in the paper). The value 32 was chosen because it was the largest value for which the training time did not suffer significantly. The effect of the sample size on accuracy was not tested. 64 | 65 | - All other parameters were kept the same. The optimizer, learning rate, learning schedule, etc. are exactly as described in the paper. 66 | 67 | ## Usage 68 | #### Prerequisites 69 | Make sure you are on a machine with an NVIDIA GPU and Python 3, with about 100 GB of disk space. Python 2 might be required for running some scripts in `./tools` (I will try to fix this soon). 70 | 71 | #### Installation 72 | - Install PyTorch v0.4 with CUDA and Python 3.5. 73 | - Install h5py. 74 | 75 | #### Data Setup 76 | All data should be downloaded to a `data/` directory in the root directory of this repository. 77 | 78 | The easiest way to download the data is to run the provided scripts `tools/download.sh` and then `tools/download_hmqa.sh` from the repository root. If the scripts do not work, it should be easy to examine them and adapt the steps outlined in them to your needs. Then run `tools/process.sh` and `tools/process_hmqa.sh` from the repository root to process the data into the correct format. Some scripts in `tools/process.sh` might require Python 2 (I am working on fixing this). 79 | 80 | #### Training 81 | Simply execute the cells in the IPython notebook `Training IRLC.ipynb` to start training. The development and test scores are printed every epoch. The model is saved every 10 epochs under the `saved_models` directory. 82 | 83 | #### Visualization 84 | Simply follow the `Visualize IRLC.ipynb` notebook. For this you will need to download the MS-COCO images. 85 | 86 | Here are sample visualizations: 87 | 88 | Ques: How many sheepskin are grazing? 89 | 90 | Ans: 4 91 | 92 | Pred: 4 93 | 94 |
 Original Image                                        Candidate Objects in the image 
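The candidate objects shown on the right are the N=36 region proposals whose features the model consumes. As a minimal sketch (not part of the repo; it simply mirrors the loading code at the top of the training notebooks, with shapes as implied by the model code: 2048-d visual features and 6-d box geometry per candidate), this is how the candidates for a single image can be pulled from the processed features:

```python
# Illustrative only: fetch the 36 candidate-object features for one training image.
import h5py
import _pickle as pkl
import numpy as np

# maps a COCO image id to its row in the hdf5 feature arrays
img_id2idx = pkl.load(open("./data/train36_imgid2idx.pkl", "rb"))
img_id = next(iter(img_id2idx))          # any image id present in the training split
idx = img_id2idx[img_id]

with h5py.File("./data/train36.hdf5", "r") as hf:
    box_features = np.array(hf["image_features"][idx])    # (36, 2048) Faster-R-CNN features
    box_spatials = np.array(hf["spatial_features"][idx])  # (36, 6) box-geometry features

print(img_id, box_features.shape, box_spatials.shape)
```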
95 | 96 | 97 | 98 | IRLC assigns a different probability to each candidate. Darker blue boxes are more probable, fainter blue boxes are less probable. IRLC then picks the most probable ones first. The picked objects are shown in bright red boxes. 99 | 100 |
 timestep=0                                              timestep=1 
101 | 102 |
 timestep=2                                              timestep=3 
103 | 104 |
 timestep=4   
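The selection process sketched in these timestep figures is implemented by `IRLC.sample_objects` in `model.py`. As a rough, batch-free illustration only (it assumes the interaction update is simply added to the candidate logits, as the comments in `model.py` describe, and it ignores sampling, entropy, and the learned terminal logit), the greedy loop looks roughly like this:

```python
# Toy sketch of the greedy selection loop: kappa holds one logit per candidate box,
# eps is the logit of the "stop counting" action, and rho[i] is the interaction row
# assumed to be added to kappa once box i is picked. Not the repo's implementation.
import torch
import torch.nn.functional as F

def greedy_count(kappa, rho, eps):
    """kappa: (k,) box logits, rho: (k, k) interaction rows, eps: scalar terminal logit."""
    k = kappa.size(0)
    selected = []
    for _ in range(k + 1):                        # at most k picks plus the terminal action
        logits = torch.cat([kappa, eps.view(1)])  # (k + 1,), index k means "stop counting"
        probs = F.softmax(logits, dim=0)
        if selected:
            probs[selected] = 0                   # never re-pick an already chosen box
        a = int(probs.argmax())
        if a == k:                                # terminal action chosen: stop
            break
        selected.append(a)
        kappa = kappa + rho[a]                    # assumed additive interaction update
    return len(selected), selected

count, picked = greedy_count(torch.randn(36), torch.randn(36, 36), torch.tensor(0.0))
# the predicted count is simply the number of boxes picked before stopping
```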
105 | 106 | 107 | 108 | ## Acknowledgements 109 | The repository https://github.com/hengyuan-hu/bottom-up-attention-vqa was a huge help. It would have taken me a week at the least to write all code for pre-processing the data myself. A big thanks to the authors of this repository! 110 | -------------------------------------------------------------------------------- /Training SoftCount.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", 21 | "# os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "from torch.autograd import Variable\n", 32 | "\n", 33 | "from dataset import Dictionary, HMQAFeatureDataset\n", 34 | "from model import SoftCount\n", 35 | "from config import *\n", 36 | "from datetime import datetime, timedelta\n", 37 | "\n", 38 | "import h5py\n", 39 | "import numpy as np\n", 40 | "import _pickle as pkl\n", 41 | "import json\n", 42 | "import torch.nn.functional as F" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "loading dictionary from data/dictionary.pkl\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "dictionary = Dictionary.load_from_file('data/dictionary.pkl')" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 6, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "%%time\n", 69 | "print('loading features from train hdf5 file')\n", 70 | "train_h5_loc = './data/train36.hdf5'\n", 71 | "with h5py.File(train_h5_loc, 'r') as hf:\n", 72 | " train_image_features = np.array(hf.get('image_features'))\n", 73 | " train_spatials_features = np.array(hf.get('spatial_features'))\n", 74 | "# np.save( open(\"/tmp/vqa/train_image_features\", \"wb\"), train_image_features)\n", 75 | "# np.save( open(\"/tmp/vqa/train_spatials_features\", \"wb\"), train_spatials_features)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 7, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "CPU times: user 80 ms, sys: 11.7 s, total: 11.8 s\n", 88 | "Wall time: 4min 25s\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "# %%time\n", 94 | "# train_image_features = np.load(open(\"/tmp/vqa/train_image_features\", \"rb\"))\n", 95 | "# train_spatials_features = np.load(open(\"/tmp/vqa/train_spatials_features\", \"rb\"))" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 8, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "from dataset import HMQAFeatureDataset\n", 105 | "\n", 106 | "hmqa_train_dset = HMQAFeatureDataset(\n", 107 | " img_id2hqma_idx = pkl.load(open(\"./data/train36_imgid2idx.pkl\", \"rb\")),\n", 108 | " image_features = train_image_features, \n", 109 | " spatial_features = train_spatials_features, \n", 110 | " qid2count = json.load(open(\"./data/how_many_qa/qid2count.json\", \"rb\")), \n", 111 | " qid2count2score = 
json.load(open(\"./data/how_many_qa/qid2count2score.json\", \"rb\")), \n", 112 | " name=\"train\", \n", 113 | " dictionary=dictionary\n", 114 | ")\n", 115 | "del HMQAFeatureDataset" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 9, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "83642" 127 | ] 128 | }, 129 | "execution_count": 9, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "len(hmqa_train_dset)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 10, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "45546" 147 | ] 148 | }, 149 | "execution_count": 10, 150 | "metadata": {}, 151 | "output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "len(set([x[\"image_id\"] for x in hmqa_train_dset.entries]))" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 11, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "%%time\n", 165 | "print('loading features from val hdf5 file')\n", 166 | "val_h5_loc = './data/val36.hdf5'\n", 167 | "with h5py.File(val_h5_loc, 'r') as hf:\n", 168 | " val_image_features = np.array(hf.get('image_features'))\n", 169 | " val_spatials_features = np.array(hf.get('spatial_features'))\n", 170 | "# np.save( open(\"/tmp/vqa/val_image_features\", \"wb\"), val_image_features)\n", 171 | "# np.save( open(\"/tmp/vqa/val_spatials_features\", \"wb\"), val_spatials_features)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 12, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "CPU times: user 44 ms, sys: 22.7 s, total: 22.7 s\n", 184 | "Wall time: 1min 9s\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "# %%time\n", 190 | "# val_image_features = np.load(open(\"/tmp/vqa/val_image_features\", \"rb\"))\n", 191 | "# val_spatials_features = np.load(open(\"/tmp/vqa/val_spatials_features\", \"rb\"))" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 13, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# len(train_image_features)\n", 201 | "\n", 202 | "from dataset import HMQAFeatureDataset\n", 203 | "\n", 204 | "hmqa_dev_dset = HMQAFeatureDataset(\n", 205 | " img_id2hqma_idx = pkl.load(open(\"./data/val36_imgid2idx.pkl\", \"rb\")),\n", 206 | " image_features = val_image_features, \n", 207 | " spatial_features = val_spatials_features, \n", 208 | " qid2count = json.load(open(\"./data/how_many_qa/qid2count.json\", \"rb\")), \n", 209 | " qid2count2score = json.load(open(\"./data/how_many_qa/qid2count2score.json\", \"rb\")), \n", 210 | " name=\"dev\", \n", 211 | " dictionary=dictionary\n", 212 | ")\n", 213 | "\n", 214 | "hmqa_test_dset = HMQAFeatureDataset(\n", 215 | " img_id2hqma_idx = pkl.load(open(\"./data/val36_imgid2idx.pkl\", \"rb\")),\n", 216 | " image_features = val_image_features, \n", 217 | " spatial_features = val_spatials_features, \n", 218 | " qid2count = json.load(open(\"./data/how_many_qa/qid2count.json\", \"rb\")), \n", 219 | " qid2count2score = json.load(open(\"./data/how_many_qa/qid2count2score.json\", \"rb\")), \n", 220 | " name=\"test\", \n", 221 | " dictionary=dictionary\n", 222 | ")\n", 223 | "del HMQAFeatureDataset" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 14, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | 
"text/plain": [ 234 | "(17714, 5000)" 235 | ] 236 | }, 237 | "execution_count": 14, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "len(hmqa_dev_dset), len(hmqa_test_dset)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 15, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "from torch.utils.data import DataLoader\n", 253 | "\n", 254 | "hmqa_train_loader = DataLoader(hmqa_train_dset, 64, shuffle=True, num_workers=0)\n", 255 | "hmqa_dev_loader = DataLoader(hmqa_dev_dset, 64, shuffle=True, num_workers=0)\n", 256 | "hmqa_test_loader = DataLoader(hmqa_test_dset, 64, shuffle=True, num_workers=0)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 55, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "def evaluate(model, hmqa_loader):\n", 266 | " \n", 267 | " all_acc = []\n", 268 | " all_se = []\n", 269 | " for i, (v_emb, b, q, c, c2s) in enumerate(hmqa_loader):\n", 270 | " v_emb = Variable(v_emb)\n", 271 | " q = Variable(q)\n", 272 | " c = Variable(c).float()\n", 273 | " \n", 274 | " if USE_CUDA:\n", 275 | " v_emb = v_emb.cuda()\n", 276 | " q = q.cuda()\n", 277 | " c = c.cuda()\n", 278 | "\n", 279 | " pred = model(v_emb, q)\n", 280 | " \n", 281 | " nearest_pred = (pred + 0.5).long().clamp(0, 20)\n", 282 | " for one_c, one_c2s, one_pred in zip(c, c2s, nearest_pred):\n", 283 | " one_c = one_c.cpu().data\n", 284 | " one_pred = one_pred.cpu().data\n", 285 | " \n", 286 | " all_se.append((one_c - one_pred.float()) ** 2)\n", 287 | " all_acc.append(one_c2s[one_pred])\n", 288 | " \n", 289 | " acc = torch.stack(all_acc).mean()\n", 290 | " rmse = torch.stack(all_se).mean() ** 0.5\n", 291 | " \n", 292 | " return acc, rmse" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 67, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "initialising with glove embeddings\n", 305 | "done.\n" 306 | ] 307 | } 308 | ], 309 | "source": [ 310 | "from model import SoftCount\n", 311 | "model = SoftCount(ques_dim=1024, score_dim=512, dropout=0.2)\n", 312 | "del SoftCount" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 68, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "data": { 322 | "text/plain": [ 323 | "SoftCount(\n", 324 | " (ques_parser): QuestionParser(\n", 325 | " (embd): Embedding(20159, 300, padding_idx=20158)\n", 326 | " (rnn): GRU(300, 1024)\n", 327 | " (drop): Dropout(p=0.2)\n", 328 | " )\n", 329 | " (f): ScoringFunction(\n", 330 | " (v_drop): Dropout(p=0.2)\n", 331 | " (q_drop): Dropout(p=0.2)\n", 332 | " (v_proj): FCNet(\n", 333 | " (main): Sequential(\n", 334 | " (0): Linear(in_features=2048, out_features=512, bias=True)\n", 335 | " (1): LeakyReLU(negative_slope=0.01)\n", 336 | " )\n", 337 | " )\n", 338 | " (q_proj): FCNet(\n", 339 | " (main): Sequential(\n", 340 | " (0): Linear(in_features=1024, out_features=512, bias=True)\n", 341 | " (1): LeakyReLU(negative_slope=0.01)\n", 342 | " )\n", 343 | " )\n", 344 | " (s_drop): Dropout(p=0.2)\n", 345 | " )\n", 346 | " (W): Linear(in_features=512, out_features=1, bias=True)\n", 347 | ")" 348 | ] 349 | }, 350 | "execution_count": 68, 351 | "metadata": {}, 352 | "output_type": "execute_result" 353 | } 354 | ], 355 | "source": [ 356 | "if USE_CUDA:\n", 357 | " model.cuda()\n", 358 | "model" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 69, 364 | "metadata": {}, 
365 | "outputs": [ 366 | { 367 | "data": { 368 | "text/plain": [ 369 | "(tensor(1.00000e-03 *\n", 370 | " 3.0400), tensor(15.4749))" 371 | ] 372 | }, 373 | "execution_count": 69, 374 | "metadata": {}, 375 | "output_type": "execute_result" 376 | } 377 | ], 378 | "source": [ 379 | "test_acc, test_rmse = evaluate(model, hmqa_test_loader)\n", 380 | "test_acc, test_rmse" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 70, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "opt = torch.optim.Adam(model.parameters(), lr=3e-4)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 71, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "test_accs = []\n", 399 | "test_rmses = []\n", 400 | "\n", 401 | "dev_accs = []\n", 402 | "dev_rmses = []" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 72, 408 | "metadata": { 409 | "scrolled": true 410 | }, 411 | "outputs": [ 412 | { 413 | "name": "stdout", 414 | "output_type": "stream", 415 | "text": [ 416 | "epoch = 0, i = 0, loss = 14.628400802612305\n" 417 | ] 418 | }, 419 | { 420 | "name": "stderr", 421 | "output_type": "stream", 422 | "text": [ 423 | "/home/sanyam/miniconda3/envs/py3t4/lib/python3.6/site-packages/ipykernel_launcher.py:21: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.\n" 424 | ] 425 | }, 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "epoch = 0, i = 100, loss = 1.1435465812683105\n", 431 | "epoch = 0, i = 200, loss = 1.0605370998382568\n", 432 | "epoch = 0, i = 300, loss = 0.9011174440383911\n", 433 | "epoch = 0, i = 400, loss = 1.6539443731307983\n", 434 | "epoch = 0, i = 500, loss = 1.0184824466705322\n", 435 | "epoch = 0, i = 600, loss = 1.2407264709472656\n", 436 | "epoch = 0, i = 700, loss = 1.3670611381530762\n", 437 | "epoch = 0, i = 800, loss = 1.5058083534240723\n", 438 | "epoch = 0, i = 900, loss = 0.9516408443450928\n", 439 | "epoch = 0, i = 1000, loss = 0.7743334174156189\n", 440 | "epoch = 0, i = 1100, loss = 0.8096193671226501\n", 441 | "epoch = 0, i = 1200, loss = 1.2751123905181885\n", 442 | "epoch = 0, i = 1300, loss = 1.3818869590759277\n", 443 | "evaluating model on dev and test...\n", 444 | "dev_acc: 0.41113242506980896, dev_rmse: 2.9276204109191895\n", 445 | "test_acc: 0.4032999873161316, test_rmse: 2.686856985092163\n", 446 | "epoch = 1, i = 0, loss = 1.0717486143112183\n", 447 | "epoch = 1, i = 100, loss = 0.7833133935928345\n", 448 | "epoch = 1, i = 200, loss = 0.8350750207901001\n", 449 | "epoch = 1, i = 300, loss = 0.9446524381637573\n", 450 | "epoch = 1, i = 400, loss = 0.6972931623458862\n", 451 | "epoch = 1, i = 500, loss = 0.9530639052391052\n", 452 | "epoch = 1, i = 600, loss = 0.7725903987884521\n", 453 | "epoch = 1, i = 700, loss = 0.7252683639526367\n", 454 | "epoch = 1, i = 800, loss = 0.7162841558456421\n", 455 | "epoch = 1, i = 900, loss = 0.8944467306137085\n", 456 | "epoch = 1, i = 1000, loss = 0.5890282392501831\n", 457 | "epoch = 1, i = 1100, loss = 0.7408154606819153\n", 458 | "epoch = 1, i = 1200, loss = 0.7860620617866516\n", 459 | "epoch = 1, i = 1300, loss = 0.9527196884155273\n", 460 | "evaluating model on dev and test...\n", 461 | "dev_acc: 0.438692569732666, dev_rmse: 2.9244272708892822\n", 462 | "test_acc: 0.43577998876571655, test_rmse: 2.673985719680786\n", 463 | "epoch = 2, i = 0, loss = 0.5547971725463867\n", 464 | "epoch = 2, i = 100, loss = 0.8121598958969116\n", 465 | "epoch = 2, 
i = 200, loss = 1.082768201828003\n", 466 | "epoch = 2, i = 300, loss = 0.9104152917861938\n", 467 | "epoch = 2, i = 400, loss = 0.6321248412132263\n", 468 | "epoch = 2, i = 500, loss = 0.538659930229187\n", 469 | "epoch = 2, i = 600, loss = 0.5090115666389465\n", 470 | "epoch = 2, i = 700, loss = 0.9042094349861145\n", 471 | "epoch = 2, i = 800, loss = 0.6963248252868652\n", 472 | "epoch = 2, i = 900, loss = 0.7098879218101501\n", 473 | "epoch = 2, i = 1000, loss = 0.6962026357650757\n", 474 | "epoch = 2, i = 1100, loss = 0.8791968822479248\n", 475 | "epoch = 2, i = 1200, loss = 1.2076174020767212\n", 476 | "epoch = 2, i = 1300, loss = 0.6140683889389038\n", 477 | "evaluating model on dev and test...\n", 478 | "dev_acc: 0.4398724138736725, dev_rmse: 2.7906925678253174\n", 479 | "test_acc: 0.43893998861312866, test_rmse: 2.562772035598755\n", 480 | "epoch = 3, i = 0, loss = 1.0250123739242554\n", 481 | "epoch = 3, i = 100, loss = 0.5901322364807129\n", 482 | "epoch = 3, i = 200, loss = 0.6190376877784729\n", 483 | "epoch = 3, i = 300, loss = 0.747307300567627\n", 484 | "epoch = 3, i = 400, loss = 0.712824821472168\n", 485 | "epoch = 3, i = 500, loss = 0.5977566242218018\n", 486 | "epoch = 3, i = 600, loss = 0.9443403482437134\n", 487 | "epoch = 3, i = 700, loss = 0.537765383720398\n", 488 | "epoch = 3, i = 800, loss = 0.5435270071029663\n", 489 | "epoch = 3, i = 900, loss = 0.4770953953266144\n", 490 | "epoch = 3, i = 1000, loss = 0.8796859979629517\n", 491 | "epoch = 3, i = 1100, loss = 0.4031006097793579\n", 492 | "epoch = 3, i = 1200, loss = 0.8212300539016724\n", 493 | "epoch = 3, i = 1300, loss = 0.9886336326599121\n", 494 | "evaluating model on dev and test...\n", 495 | "dev_acc: 0.45250648260116577, dev_rmse: 2.733506202697754\n", 496 | "test_acc: 0.46630001068115234, test_rmse: 2.499239921569824\n", 497 | "epoch = 4, i = 0, loss = 0.8000674843788147\n", 498 | "epoch = 4, i = 100, loss = 0.6346131563186646\n", 499 | "epoch = 4, i = 200, loss = 0.8446760177612305\n", 500 | "epoch = 4, i = 300, loss = 1.2564671039581299\n", 501 | "epoch = 4, i = 400, loss = 1.1416804790496826\n", 502 | "epoch = 4, i = 500, loss = 0.8681952953338623\n", 503 | "epoch = 4, i = 600, loss = 0.596240758895874\n", 504 | "epoch = 4, i = 700, loss = 0.760988175868988\n", 505 | "epoch = 4, i = 800, loss = 0.8439410328865051\n", 506 | "epoch = 4, i = 900, loss = 0.6394860148429871\n", 507 | "epoch = 4, i = 1000, loss = 0.589470624923706\n", 508 | "epoch = 4, i = 1100, loss = 0.8739631175994873\n", 509 | "epoch = 4, i = 1200, loss = 0.43961581587791443\n", 510 | "epoch = 4, i = 1300, loss = 0.7726854085922241\n", 511 | "evaluating model on dev and test...\n", 512 | "dev_acc: 0.4580501317977905, dev_rmse: 2.6624057292938232\n", 513 | "test_acc: 0.46700000762939453, test_rmse: 2.432323932647705\n", 514 | "epoch = 5, i = 0, loss = 0.5627894401550293\n", 515 | "epoch = 5, i = 100, loss = 0.6562784910202026\n", 516 | "epoch = 5, i = 200, loss = 0.9611048698425293\n", 517 | "epoch = 5, i = 300, loss = 0.40457683801651\n", 518 | "epoch = 5, i = 400, loss = 0.5332111716270447\n", 519 | "epoch = 5, i = 500, loss = 0.7762923240661621\n", 520 | "epoch = 5, i = 600, loss = 0.6530032157897949\n", 521 | "epoch = 5, i = 700, loss = 0.8856066465377808\n", 522 | "epoch = 5, i = 800, loss = 0.5990040302276611\n", 523 | "epoch = 5, i = 900, loss = 0.8177927136421204\n", 524 | "epoch = 5, i = 1000, loss = 0.5464267730712891\n", 525 | "epoch = 5, i = 1100, loss = 0.577069878578186\n", 526 | "epoch = 5, i = 1200, loss = 
0.6961219906806946\n", 527 | "epoch = 5, i = 1300, loss = 0.6551117897033691\n", 528 | "evaluating model on dev and test...\n", 529 | "dev_acc: 0.46489217877388, dev_rmse: 2.658437728881836\n", 530 | "test_acc: 0.4720799922943115, test_rmse: 2.4418435096740723\n", 531 | "epoch = 6, i = 0, loss = 0.9389997720718384\n", 532 | "epoch = 6, i = 100, loss = 0.6106300354003906\n", 533 | "epoch = 6, i = 200, loss = 0.5450400114059448\n", 534 | "epoch = 6, i = 300, loss = 0.9185134172439575\n", 535 | "epoch = 6, i = 400, loss = 0.4569363594055176\n", 536 | "epoch = 6, i = 500, loss = 0.9947926998138428\n", 537 | "epoch = 6, i = 600, loss = 0.8373309373855591\n", 538 | "epoch = 6, i = 700, loss = 0.5167772769927979\n", 539 | "epoch = 6, i = 800, loss = 0.4667099416255951\n", 540 | "epoch = 6, i = 900, loss = 0.7207782864570618\n", 541 | "epoch = 6, i = 1000, loss = 0.7581618428230286\n", 542 | "epoch = 6, i = 1100, loss = 0.5784988403320312\n", 543 | "epoch = 6, i = 1200, loss = 0.5652433633804321\n", 544 | "epoch = 6, i = 1300, loss = 0.49687403440475464\n", 545 | "evaluating model on dev and test...\n", 546 | "dev_acc: 0.47183018922805786, dev_rmse: 2.684798002243042\n", 547 | "test_acc: 0.48076000809669495, test_rmse: 2.4388933181762695\n", 548 | "epoch = 7, i = 0, loss = 0.6185715794563293\n", 549 | "epoch = 7, i = 100, loss = 0.4020466208457947\n", 550 | "epoch = 7, i = 200, loss = 0.5055506825447083\n", 551 | "epoch = 7, i = 300, loss = 0.9028556942939758\n", 552 | "epoch = 7, i = 400, loss = 0.6325691938400269\n", 553 | "epoch = 7, i = 500, loss = 0.5334853529930115\n", 554 | "epoch = 7, i = 600, loss = 0.7755710482597351\n", 555 | "epoch = 7, i = 700, loss = 0.6400124430656433\n", 556 | "epoch = 7, i = 800, loss = 0.8320757746696472\n", 557 | "epoch = 7, i = 900, loss = 0.6655226945877075\n", 558 | "epoch = 7, i = 1000, loss = 0.6563359498977661\n", 559 | "epoch = 7, i = 1100, loss = 0.7202091217041016\n", 560 | "epoch = 7, i = 1200, loss = 0.47889444231987\n", 561 | "epoch = 7, i = 1300, loss = 0.5980694890022278\n", 562 | "evaluating model on dev and test...\n", 563 | "dev_acc: 0.452574223279953, dev_rmse: 2.565908908843994\n", 564 | "test_acc: 0.45767998695373535, test_rmse: 2.3844916820526123\n", 565 | "epoch = 8, i = 0, loss = 0.4203689396381378\n", 566 | "epoch = 8, i = 100, loss = 0.6011685132980347\n", 567 | "epoch = 8, i = 200, loss = 0.4834330081939697\n", 568 | "epoch = 8, i = 300, loss = 0.5145273208618164\n", 569 | "epoch = 8, i = 400, loss = 0.5205063223838806\n", 570 | "epoch = 8, i = 500, loss = 0.46570590138435364\n", 571 | "epoch = 8, i = 600, loss = 0.5869192481040955\n", 572 | "epoch = 8, i = 700, loss = 0.9000004529953003\n", 573 | "epoch = 8, i = 800, loss = 1.029176950454712\n", 574 | "epoch = 8, i = 900, loss = 0.6565335988998413\n", 575 | "epoch = 8, i = 1000, loss = 0.48471736907958984\n", 576 | "epoch = 8, i = 1100, loss = 0.5047010183334351\n", 577 | "epoch = 8, i = 1200, loss = 0.9043950438499451\n", 578 | "epoch = 8, i = 1300, loss = 0.6348180174827576\n", 579 | "evaluating model on dev and test...\n", 580 | "dev_acc: 0.47310036420822144, dev_rmse: 2.5988094806671143\n", 581 | "test_acc: 0.4937399923801422, test_rmse: 2.366136074066162\n", 582 | "epoch = 9, i = 0, loss = 0.3131829798221588\n", 583 | "epoch = 9, i = 100, loss = 0.46003639698028564\n", 584 | "epoch = 9, i = 200, loss = 0.5366580486297607\n", 585 | "epoch = 9, i = 300, loss = 0.45132943987846375\n", 586 | "epoch = 9, i = 400, loss = 0.5376684665679932\n", 587 | "epoch = 9, i = 500, loss = 
0.9212081432342529\n", 588 | "epoch = 9, i = 600, loss = 0.4985387921333313\n", 589 | "epoch = 9, i = 700, loss = 0.2654397785663605\n", 590 | "epoch = 9, i = 800, loss = 1.150120496749878\n", 591 | "epoch = 9, i = 900, loss = 0.44342440366744995\n", 592 | "epoch = 9, i = 1000, loss = 0.6016779541969299\n", 593 | "epoch = 9, i = 1100, loss = 0.7659499645233154\n", 594 | "epoch = 9, i = 1200, loss = 0.5102328062057495\n", 595 | "epoch = 9, i = 1300, loss = 0.6641837954521179\n", 596 | "evaluating model on dev and test...\n", 597 | "dev_acc: 0.47911256551742554, dev_rmse: 2.5959513187408447\n", 598 | "test_acc: 0.490339994430542, test_rmse: 2.3581349849700928\n", 599 | "epoch = 10, i = 0, loss = 0.9485877156257629\n", 600 | "epoch = 10, i = 100, loss = 0.4240656793117523\n", 601 | "epoch = 10, i = 200, loss = 0.7539646625518799\n", 602 | "epoch = 10, i = 300, loss = 0.43598222732543945\n", 603 | "epoch = 10, i = 400, loss = 0.5969750285148621\n", 604 | "epoch = 10, i = 500, loss = 0.6358165144920349\n" 605 | ] 606 | }, 607 | { 608 | "name": "stdout", 609 | "output_type": "stream", 610 | "text": [ 611 | "epoch = 10, i = 600, loss = 0.4486617147922516\n", 612 | "epoch = 10, i = 700, loss = 0.4893093705177307\n", 613 | "epoch = 10, i = 800, loss = 0.8087934851646423\n", 614 | "epoch = 10, i = 900, loss = 0.6663355827331543\n", 615 | "epoch = 10, i = 1000, loss = 0.38856399059295654\n", 616 | "epoch = 10, i = 1100, loss = 0.6736136674880981\n", 617 | "epoch = 10, i = 1200, loss = 0.8147845268249512\n", 618 | "epoch = 10, i = 1300, loss = 0.7092846632003784\n", 619 | "evaluating model on dev and test...\n", 620 | "dev_acc: 0.473275363445282, dev_rmse: 2.5685806274414062\n", 621 | "test_acc: 0.4786800146102905, test_rmse: 2.3469128608703613\n", 622 | "epoch = 11, i = 0, loss = 0.5136302709579468\n", 623 | "epoch = 11, i = 100, loss = 0.6258996725082397\n", 624 | "epoch = 11, i = 200, loss = 0.5074462294578552\n", 625 | "epoch = 11, i = 300, loss = 0.8129042983055115\n", 626 | "epoch = 11, i = 400, loss = 0.5542230606079102\n", 627 | "epoch = 11, i = 500, loss = 0.5161349773406982\n", 628 | "epoch = 11, i = 600, loss = 0.5896819233894348\n", 629 | "epoch = 11, i = 700, loss = 0.6863290667533875\n", 630 | "epoch = 11, i = 800, loss = 0.5083870887756348\n", 631 | "epoch = 11, i = 900, loss = 0.3265340328216553\n", 632 | "epoch = 11, i = 1000, loss = 0.7775629162788391\n", 633 | "epoch = 11, i = 1100, loss = 0.7597150802612305\n", 634 | "epoch = 11, i = 1200, loss = 0.49520641565322876\n", 635 | "epoch = 11, i = 1300, loss = 0.44222038984298706\n", 636 | "evaluating model on dev and test...\n", 637 | "dev_acc: 0.4771198034286499, dev_rmse: 2.5524861812591553\n", 638 | "test_acc: 0.47642001509666443, test_rmse: 2.3526580333709717\n", 639 | "epoch = 12, i = 0, loss = 0.3910914957523346\n", 640 | "epoch = 12, i = 100, loss = 0.33726972341537476\n", 641 | "epoch = 12, i = 200, loss = 0.5295529365539551\n", 642 | "epoch = 12, i = 300, loss = 0.4419896602630615\n", 643 | "epoch = 12, i = 400, loss = 0.3294357657432556\n", 644 | "epoch = 12, i = 500, loss = 0.8175026178359985\n", 645 | "epoch = 12, i = 600, loss = 0.5169392228126526\n", 646 | "epoch = 12, i = 700, loss = 0.37090879678726196\n", 647 | "epoch = 12, i = 800, loss = 0.4401392340660095\n", 648 | "epoch = 12, i = 900, loss = 0.657434344291687\n", 649 | "epoch = 12, i = 1000, loss = 0.47396910190582275\n", 650 | "epoch = 12, i = 1100, loss = 0.6415238380432129\n", 651 | "epoch = 12, i = 1200, loss = 0.49534136056900024\n", 652 | "epoch = 12, i = 
1300, loss = 0.5604371428489685\n", 653 | "evaluating model on dev and test...\n", 654 | "dev_acc: 0.482234388589859, dev_rmse: 2.542769432067871\n", 655 | "test_acc: 0.4944800138473511, test_rmse: 2.3373489379882812\n", 656 | "epoch = 13, i = 0, loss = 0.6540343761444092\n", 657 | "epoch = 13, i = 100, loss = 0.5504652261734009\n", 658 | "epoch = 13, i = 200, loss = 0.613745391368866\n", 659 | "epoch = 13, i = 300, loss = 0.48012328147888184\n", 660 | "epoch = 13, i = 400, loss = 0.5549217462539673\n", 661 | "epoch = 13, i = 500, loss = 0.7506428956985474\n", 662 | "epoch = 13, i = 600, loss = 0.6042653322219849\n", 663 | "epoch = 13, i = 700, loss = 0.6180323362350464\n", 664 | "epoch = 13, i = 800, loss = 0.4621243178844452\n", 665 | "epoch = 13, i = 900, loss = 0.29211950302124023\n", 666 | "epoch = 13, i = 1000, loss = 0.6967971324920654\n", 667 | "epoch = 13, i = 1100, loss = 0.5880112648010254\n", 668 | "epoch = 13, i = 1200, loss = 0.7064930200576782\n", 669 | "epoch = 13, i = 1300, loss = 0.3936261534690857\n", 670 | "evaluating model on dev and test...\n", 671 | "dev_acc: 0.4788077175617218, dev_rmse: 2.534341335296631\n", 672 | "test_acc: 0.48833999037742615, test_rmse: 2.3233165740966797\n", 673 | "epoch = 14, i = 0, loss = 0.6755526065826416\n", 674 | "epoch = 14, i = 100, loss = 0.33443135023117065\n", 675 | "epoch = 14, i = 200, loss = 0.36527568101882935\n", 676 | "epoch = 14, i = 300, loss = 1.0132145881652832\n", 677 | "epoch = 14, i = 400, loss = 0.5340710878372192\n", 678 | "epoch = 14, i = 500, loss = 0.7321001887321472\n", 679 | "epoch = 14, i = 600, loss = 0.8975104093551636\n", 680 | "epoch = 14, i = 700, loss = 0.5202504396438599\n", 681 | "epoch = 14, i = 800, loss = 1.0374795198440552\n", 682 | "epoch = 14, i = 900, loss = 0.8503618836402893\n", 683 | "epoch = 14, i = 1000, loss = 0.26997876167297363\n", 684 | "epoch = 14, i = 1100, loss = 0.42291390895843506\n", 685 | "epoch = 14, i = 1200, loss = 0.549490213394165\n", 686 | "epoch = 14, i = 1300, loss = 0.4087107479572296\n", 687 | "evaluating model on dev and test...\n", 688 | "dev_acc: 0.47907304763793945, dev_rmse: 2.5677783489227295\n", 689 | "test_acc: 0.5010799765586853, test_rmse: 2.3345234394073486\n", 690 | "epoch = 15, i = 0, loss = 0.3381015658378601\n", 691 | "epoch = 15, i = 100, loss = 0.47336000204086304\n", 692 | "epoch = 15, i = 200, loss = 0.5756043195724487\n", 693 | "epoch = 15, i = 300, loss = 0.3944014012813568\n", 694 | "epoch = 15, i = 400, loss = 0.6456435918807983\n", 695 | "epoch = 15, i = 500, loss = 0.6376544833183289\n", 696 | "epoch = 15, i = 600, loss = 0.38120079040527344\n", 697 | "epoch = 15, i = 700, loss = 0.39884161949157715\n", 698 | "epoch = 15, i = 800, loss = 0.3604346215724945\n", 699 | "epoch = 15, i = 900, loss = 0.355624258518219\n", 700 | "epoch = 15, i = 1000, loss = 0.4303146004676819\n", 701 | "epoch = 15, i = 1100, loss = 0.49465665221214294\n", 702 | "epoch = 15, i = 1200, loss = 0.610739529132843\n", 703 | "epoch = 15, i = 1300, loss = 0.4250361919403076\n", 704 | "evaluating model on dev and test...\n", 705 | "dev_acc: 0.4612114727497101, dev_rmse: 2.5253374576568604\n", 706 | "test_acc: 0.4729999899864197, test_rmse: 2.3174986839294434\n", 707 | "epoch = 16, i = 0, loss = 0.28534242510795593\n", 708 | "epoch = 16, i = 100, loss = 0.6168354153633118\n", 709 | "epoch = 16, i = 200, loss = 0.43274885416030884\n", 710 | "epoch = 16, i = 300, loss = 0.38376590609550476\n", 711 | "epoch = 16, i = 400, loss = 0.527509868144989\n", 712 | "epoch = 16, i = 500, loss 
= 0.45143264532089233\n", 713 | "epoch = 16, i = 600, loss = 0.3917695879936218\n", 714 | "epoch = 16, i = 700, loss = 0.5136957168579102\n", 715 | "epoch = 16, i = 800, loss = 0.3321128487586975\n", 716 | "epoch = 16, i = 900, loss = 0.35770952701568604\n", 717 | "epoch = 16, i = 1000, loss = 0.37515610456466675\n", 718 | "epoch = 16, i = 1100, loss = 0.36942195892333984\n", 719 | "epoch = 16, i = 1200, loss = 0.41256609559059143\n", 720 | "epoch = 16, i = 1300, loss = 0.7216428518295288\n", 721 | "evaluating model on dev and test...\n", 722 | "dev_acc: 0.48063114285469055, dev_rmse: 2.539537191390991\n", 723 | "test_acc: 0.48642000555992126, test_rmse: 2.3322949409484863\n", 724 | "epoch = 17, i = 0, loss = 0.1989438235759735\n", 725 | "epoch = 17, i = 100, loss = 0.37077924609184265\n", 726 | "epoch = 17, i = 200, loss = 0.548552393913269\n", 727 | "epoch = 17, i = 300, loss = 0.32453879714012146\n", 728 | "epoch = 17, i = 400, loss = 0.37269145250320435\n", 729 | "epoch = 17, i = 500, loss = 0.49430444836616516\n", 730 | "epoch = 17, i = 600, loss = 0.42291736602783203\n", 731 | "epoch = 17, i = 700, loss = 0.4055926203727722\n", 732 | "epoch = 17, i = 800, loss = 0.6607651710510254\n", 733 | "epoch = 17, i = 900, loss = 0.3112538456916809\n", 734 | "epoch = 17, i = 1000, loss = 0.3921870291233063\n", 735 | "epoch = 17, i = 1100, loss = 0.4733957052230835\n", 736 | "epoch = 17, i = 1200, loss = 0.3893886208534241\n", 737 | "epoch = 17, i = 1300, loss = 0.5213131308555603\n", 738 | "evaluating model on dev and test...\n", 739 | "dev_acc: 0.4829513430595398, dev_rmse: 2.543346643447876\n", 740 | "test_acc: 0.5041000247001648, test_rmse: 2.3136982917785645\n", 741 | "epoch = 18, i = 0, loss = 0.36615651845932007\n", 742 | "epoch = 18, i = 100, loss = 0.4917939305305481\n", 743 | "epoch = 18, i = 200, loss = 0.30025339126586914\n", 744 | "epoch = 18, i = 300, loss = 0.5090234279632568\n", 745 | "epoch = 18, i = 400, loss = 0.3057374060153961\n", 746 | "epoch = 18, i = 500, loss = 0.467905193567276\n", 747 | "epoch = 18, i = 600, loss = 0.31017005443573\n", 748 | "epoch = 18, i = 700, loss = 0.35623863339424133\n", 749 | "epoch = 18, i = 800, loss = 0.5136244297027588\n", 750 | "epoch = 18, i = 900, loss = 0.2820129990577698\n", 751 | "epoch = 18, i = 1000, loss = 0.5056727528572083\n", 752 | "epoch = 18, i = 1100, loss = 0.38559073209762573\n", 753 | "epoch = 18, i = 1200, loss = 0.5067494511604309\n", 754 | "epoch = 18, i = 1300, loss = 0.23265519738197327\n", 755 | "evaluating model on dev and test...\n", 756 | "dev_acc: 0.4571525454521179, dev_rmse: 2.53709077835083\n", 757 | "test_acc: 0.45956000685691833, test_rmse: 2.32684326171875\n", 758 | "epoch = 19, i = 0, loss = 0.4580419361591339\n", 759 | "epoch = 19, i = 100, loss = 0.6830368041992188\n", 760 | "epoch = 19, i = 200, loss = 0.6702297925949097\n", 761 | "epoch = 19, i = 300, loss = 0.3448217809200287\n", 762 | "epoch = 19, i = 400, loss = 0.5003037452697754\n", 763 | "epoch = 19, i = 500, loss = 0.5249284505844116\n", 764 | "epoch = 19, i = 600, loss = 0.5032473206520081\n", 765 | "epoch = 19, i = 700, loss = 0.46184176206588745\n", 766 | "epoch = 19, i = 800, loss = 0.5465124845504761\n", 767 | "epoch = 19, i = 900, loss = 0.2270112931728363\n", 768 | "epoch = 19, i = 1000, loss = 0.2748583257198334\n", 769 | "epoch = 19, i = 1100, loss = 0.5165499448776245\n", 770 | "epoch = 19, i = 1200, loss = 0.43792837858200073\n", 771 | "epoch = 19, i = 1300, loss = 0.5342369079589844\n", 772 | "evaluating model on dev and test...\n", 
773 | "dev_acc: 0.4835158586502075, dev_rmse: 2.512709140777588\n", 774 | "test_acc: 0.49654000997543335, test_rmse: 2.3054285049438477\n" 775 | ] 776 | } 777 | ], 778 | "source": [ 779 | "for epoch in range(20):\n", 780 | " for i, (v_emb, b, q, c, _) in enumerate(hmqa_train_loader):\n", 781 | " v_emb = Variable(v_emb)\n", 782 | " q = Variable(q)\n", 783 | " c = Variable(c).float().view(-1)\n", 784 | " \n", 785 | " if USE_CUDA:\n", 786 | " v_emb = v_emb.cuda()\n", 787 | " q = q.cuda()\n", 788 | " c = c.cuda()\n", 789 | "\n", 790 | " pred = model(v_emb, q)\n", 791 | " loss = F.smooth_l1_loss(pred, c)\n", 792 | " \n", 793 | " if i % 100 == 0:\n", 794 | " print(\"epoch = {}, i = {}, loss = {}\".format(\n", 795 | " epoch, i, loss.item()))\n", 796 | " \n", 797 | " opt.zero_grad()\n", 798 | " loss.backward()\n", 799 | " torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)\n", 800 | " opt.step()\n", 801 | " \n", 802 | " print(\"evaluating model on dev and test...\")\n", 803 | "\n", 804 | " model.eval()\n", 805 | " dev_acc, dev_rmse = evaluate(model, hmqa_dev_loader)\n", 806 | " print(\"dev_acc: {}, dev_rmse: {}\".format(dev_acc, dev_rmse))\n", 807 | " test_acc, test_rmse = evaluate(model, hmqa_test_loader)\n", 808 | " print(\"test_acc: {}, test_rmse: {}\".format(test_acc, test_rmse))\n", 809 | " model.train()\n", 810 | " \n", 811 | " test_accs.append(test_acc)\n", 812 | " test_rmses.append(test_rmse)\n", 813 | " dev_accs.append(dev_acc)\n", 814 | " dev_rmses.append(dev_rmse)\n", 815 | " " 816 | ] 817 | }, 818 | { 819 | "cell_type": "code", 820 | "execution_count": 73, 821 | "metadata": { 822 | "scrolled": true 823 | }, 824 | "outputs": [ 825 | { 826 | "data": { 827 | "text/plain": [ 828 | "[(tensor(0.4835), tensor(0.4965), tensor(2.3054)),\n", 829 | " (tensor(0.4830), tensor(0.5041), tensor(2.3137)),\n", 830 | " (tensor(0.4822), tensor(0.4945), tensor(2.3373)),\n", 831 | " (tensor(0.4806), tensor(0.4864), tensor(2.3323)),\n", 832 | " (tensor(0.4791), tensor(0.4903), tensor(2.3581)),\n", 833 | " (tensor(0.4791), tensor(0.5011), tensor(2.3345)),\n", 834 | " (tensor(0.4788), tensor(0.4883), tensor(2.3233)),\n", 835 | " (tensor(0.4771), tensor(0.4764), tensor(2.3527)),\n", 836 | " (tensor(0.4733), tensor(0.4787), tensor(2.3469)),\n", 837 | " (tensor(0.4731), tensor(0.4937), tensor(2.3661)),\n", 838 | " (tensor(0.4718), tensor(0.4808), tensor(2.4389)),\n", 839 | " (tensor(0.4649), tensor(0.4721), tensor(2.4418)),\n", 840 | " (tensor(0.4612), tensor(0.4730), tensor(2.3175)),\n", 841 | " (tensor(0.4581), tensor(0.4670), tensor(2.4323)),\n", 842 | " (tensor(0.4572), tensor(0.4596), tensor(2.3268)),\n", 843 | " (tensor(0.4526), tensor(0.4577), tensor(2.3845)),\n", 844 | " (tensor(0.4525), tensor(0.4663), tensor(2.4992)),\n", 845 | " (tensor(0.4399), tensor(0.4389), tensor(2.5628)),\n", 846 | " (tensor(0.4387), tensor(0.4358), tensor(2.6740)),\n", 847 | " (tensor(0.4111), tensor(0.4033), tensor(2.6869))]" 848 | ] 849 | }, 850 | "execution_count": 73, 851 | "metadata": {}, 852 | "output_type": "execute_result" 853 | } 854 | ], 855 | "source": [ 856 | "top_dev_accs = sorted(zip(dev_accs, test_accs, test_rmses), reverse=True)\n", 857 | "top_dev_accs" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": 74, 863 | "metadata": {}, 864 | "outputs": [ 865 | { 866 | "name": "stdout", 867 | "output_type": "stream", 868 | "text": [ 869 | "The best dev accuracy is 0.4835158586502075. 
The corresponding test accuracy and test RMSE are 0.49654000997543335 and 2.3054285049438477 respectively\n" 870 | ] 871 | } 872 | ], 873 | "source": [ 874 | "best_dev_acc, corr_test_acc, corr_test_rmse = top_dev_accs[0]\n", 875 | "print(\"The best dev accuracy is {}. The corresponding test accuracy and test RMSE are {} and {} respectively\".format(\n", 876 | " best_dev_acc, corr_test_acc, corr_test_rmse\n", 877 | "))" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": null, 883 | "metadata": {}, 884 | "outputs": [], 885 | "source": [] 886 | } 887 | ], 888 | "metadata": { 889 | "kernelspec": { 890 | "display_name": "Python 3", 891 | "language": "python", 892 | "name": "python3" 893 | }, 894 | "language_info": { 895 | "codemirror_mode": { 896 | "name": "ipython", 897 | "version": 3 898 | }, 899 | "file_extension": ".py", 900 | "mimetype": "text/x-python", 901 | "name": "python", 902 | "nbconvert_exporter": "python", 903 | "pygments_lexer": "ipython3", 904 | "version": "3.6.5" 905 | } 906 | }, 907 | "nbformat": 4, 908 | "nbformat_minor": 2 909 | } 910 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | USE_CUDA=True 2 | VOCAB_SIZE = 20158 3 | DATA_DIR = "./data" -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import json 3 | try: 4 | import _pickle as pkl 5 | except: 6 | import cPickle as pkl 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class Dictionary(object): 13 | def __init__(self, word2idx=None, idx2word=None): 14 | if word2idx is None: 15 | word2idx = {} 16 | if idx2word is None: 17 | idx2word = [] 18 | self.word2idx = word2idx 19 | self.idx2word = idx2word 20 | 21 | @property 22 | def ntoken(self): 23 | return len(self.word2idx) 24 | 25 | @property 26 | def padding_idx(self): 27 | return len(self.word2idx) 28 | 29 | def tokenize(self, sentence, add_word): 30 | sentence = sentence.lower() 31 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 32 | words = sentence.split() 33 | tokens = [] 34 | if add_word: 35 | for w in words: 36 | tokens.append(self.add_word(w)) 37 | else: 38 | for w in words: 39 | tokens.append(self.word2idx[w]) 40 | return tokens 41 | 42 | def dump_to_file(self, path): 43 | pkl.dump([self.word2idx, self.idx2word], open(path, 'wb')) 44 | print('dictionary dumped to %s' % path) 45 | 46 | @classmethod 47 | def load_from_file(cls, path): 48 | print('loading dictionary from %s' % path) 49 | word2idx, idx2word = pkl.load(open(path, 'rb')) 50 | d = cls(word2idx, idx2word) 51 | return d 52 | 53 | def add_word(self, word): 54 | if word not in self.word2idx: 55 | self.idx2word.append(word) 56 | self.word2idx[word] = len(self.idx2word) - 1 57 | return self.word2idx[word] 58 | 59 | def __len__(self): 60 | return len(self.idx2word) 61 | 62 | 63 | class HMQAFeatureDataset(Dataset): 64 | def __init__(self, img_id2hqma_idx, image_features, spatial_features, qid2count, qid2count2score, 65 | name, dictionary): 66 | super(HMQAFeatureDataset, self).__init__() 67 | 68 | assert name in ["train", "dev", "test"] 69 | self.name = name 70 | self.qid2count = qid2count 71 | self.qid2count2score = qid2count2score 72 | self.qids = None 73 | if self.name == "train": 74 | self.part_qid2count = 
self.qid2count["train"] 75 | self.part_qid2count2score = self.qid2count2score["train"] 76 | else: 77 | self.part_qid2count = self.qid2count[self.name] 78 | self.part_qid2count2score = self.qid2count2score[self.name] 79 | 80 | self.dictionary = dictionary 81 | 82 | self.img_id2hqma_idx = img_id2hqma_idx 83 | self._features = image_features 84 | self._spatials = spatial_features 85 | 86 | self.entries = self.load_dataset() 87 | 88 | self.tokenize() 89 | self.tensorize() 90 | self.v_dim = self.features.size(2) 91 | self.s_dim = self.spatials.size(2) 92 | 93 | def prepare_entries(self, questions, part_qid2count, part_qid2count2score, qid_prefix=''): 94 | entries = [] 95 | 96 | set_qids = set(part_qid2count.keys()) 97 | for question in questions: 98 | question_id = str(question["question_id"]) 99 | if question_id not in set_qids: 100 | # print("{} is not there".format(question_id)) 101 | # break 102 | continue 103 | 104 | image_id = question['image_id'] 105 | img_hqma_idx = self.img_id2hqma_idx[image_id] 106 | count = part_qid2count[question_id] 107 | score2count = part_qid2count2score[question_id] 108 | question = question["question"] 109 | 110 | entries.append({ 111 | "question_id": qid_prefix + question_id, 112 | "image_id": image_id, 113 | "img_hqma_idx": img_hqma_idx, 114 | "count": count, 115 | "score2count": score2count, 116 | "question": question, 117 | }) 118 | 119 | return entries 120 | 121 | def load_dataset(self): 122 | """Load entries 123 | 124 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 125 | dataroot: root path of dataset 126 | name: 'train', 'val' 127 | """ 128 | 129 | if self.name == "train": 130 | fname = "train" 131 | vqa_question_path = './data/v2_OpenEnded_mscoco_{}2014_questions.json'.format(fname) 132 | vqa_questions = sorted(json.load(open(vqa_question_path))['questions'], key=lambda x: x['question_id']) 133 | 134 | vqa_entries = self.prepare_entries( 135 | questions=vqa_questions, 136 | part_qid2count=self.part_qid2count["vqa"], 137 | part_qid2count2score=self.part_qid2count2score["vqa"], 138 | qid_prefix="vqa" 139 | ) 140 | 141 | vgn_questions = sorted(json.load(open("./data/how_many_qa/vgn_ques.json")), key=lambda x: x['question_id']) 142 | vgn_entries = self.prepare_entries( 143 | questions=vgn_questions, 144 | part_qid2count=self.part_qid2count["visual_genome"], 145 | part_qid2count2score=self.part_qid2count2score["visual_genome"], 146 | qid_prefix="vgn" 147 | ) 148 | 149 | return vqa_entries + vgn_entries 150 | 151 | elif self.name in ["dev", "test"]: 152 | fname = "val" 153 | question_path = './data/v2_OpenEnded_mscoco_{}2014_questions.json'.format(fname) 154 | questions = sorted(json.load(open(question_path))['questions'], key=lambda x: x['question_id']) 155 | 156 | val_entries = self.prepare_entries( 157 | questions=questions, 158 | part_qid2count=self.part_qid2count, 159 | part_qid2count2score=self.part_qid2count2score, 160 | ) 161 | return val_entries 162 | else: 163 | raise Exception("uknown name '{}'".format(self.name)) 164 | 165 | def tokenize(self, max_length=14): 166 | """Tokenizes the questions. 167 | 168 | This will add q_token in each entry of the dataset. 
169 | -1 represent nil, and should be treated as padding_idx in embedding 170 | """ 171 | for entry in self.entries: 172 | tokens = self.dictionary.tokenize(entry['question'], False) 173 | tokens = tokens[:max_length] 174 | if len(tokens) < max_length: 175 | # Note here we pad in front of the sentence 176 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 177 | tokens = padding + tokens 178 | assert len(tokens) == max_length 179 | entry['q_token'] = tokens 180 | 181 | def tensorize(self): 182 | # TODO: uncomment later 183 | self.features = torch.from_numpy(self._features) 184 | self.spatials = torch.from_numpy(self._spatials) 185 | 186 | for entry in self.entries: 187 | question = torch.from_numpy(np.array(entry['q_token'])) 188 | entry['q_token'] = question 189 | entry["count"] = torch.from_numpy(np.array([entry["count"]])) 190 | entry["score2count"] = torch.from_numpy(np.array(entry["score2count"])).float() 191 | 192 | def __getitem__(self, index): 193 | entry = self.entries[index] 194 | features = self.features[entry['img_hqma_idx']] 195 | spatials = self.spatials[entry['img_hqma_idx']] 196 | 197 | question = entry['q_token'] 198 | count = entry["count"] 199 | score2count = entry["score2count"] 200 | 201 | return features, spatials, question, count, score2count 202 | 203 | def __len__(self): 204 | return len(self.entries) 205 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from config import * 6 | from torch.nn.utils.weight_norm import weight_norm 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from torch.distributions.categorical import Categorical 10 | 11 | 12 | class QuestionParser(nn.Module): 13 | glove_file = DATA_DIR + "/glove6b_init_300d.npy" 14 | 15 | def __init__(self, dropout=0.3, word_dim=300, ques_dim=1024): 16 | super(QuestionParser, self).__init__() 17 | 18 | self.dropout = dropout 19 | self.word_dim = word_dim 20 | self.ques_dim = ques_dim 21 | 22 | self.embd = nn.Embedding(VOCAB_SIZE + 1, self.word_dim, padding_idx=VOCAB_SIZE) 23 | self.rnn = nn.GRU(self.word_dim, self.ques_dim) 24 | self.drop = nn.Dropout(self.dropout) 25 | self.glove_init() 26 | 27 | def glove_init(self): 28 | print("initialising with glove embeddings") 29 | glove_embds = torch.from_numpy(np.load(self.glove_file)) 30 | assert glove_embds.size() == (VOCAB_SIZE, self.word_dim) 31 | self.embd.weight.data[:VOCAB_SIZE] = glove_embds 32 | print("done.") 33 | 34 | def forward(self, questions): 35 | # (B, MAXLEN) 36 | # print("question size ", questions.size()) 37 | questions = questions.t() # (MAXLEN, B) 38 | questions = self.embd(questions) # (MAXLEN, B, word_size) 39 | _, (q_emb) = self.rnn(questions) 40 | q_emb = q_emb[-1] # (B, ques_size) 41 | q_emb = self.drop(q_emb) 42 | 43 | return q_emb 44 | 45 | 46 | class FCNet(nn.Module): 47 | """Simple class for non-linear fully connect network 48 | """ 49 | def __init__(self, dims): 50 | super(FCNet, self).__init__() 51 | 52 | layers = [] 53 | for i in range(len(dims)-2): 54 | in_dim = dims[i] 55 | out_dim = dims[i+1] 56 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 57 | layers.append(nn.LeakyReLU()) 58 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 59 | layers.append(nn.LeakyReLU()) 60 | 61 | self.main = nn.Sequential(*layers) 62 | 
63 | def forward(self, x): 64 | return self.main(x) 65 | 66 | 67 | class ScoringFunction(nn.Module): 68 | 69 | def __init__(self, ques_dim, dropout=0.3, v_dim=2048, score_dim=1024): 70 | super(ScoringFunction, self).__init__() 71 | 72 | self.q_dim = ques_dim 73 | self.dropout = dropout 74 | self.v_dim = v_dim 75 | self.score_dim = score_dim 76 | 77 | self.v_drop = nn.Dropout(self.dropout) 78 | self.q_drop = nn.Dropout(self.dropout) 79 | self.v_proj = FCNet([self.v_dim, self.score_dim]) 80 | self.q_proj = FCNet([self.q_dim, self.score_dim]) 81 | self.s_drop = nn.Dropout(self.dropout) 82 | 83 | def forward(self, v, q): 84 | """ 85 | v: [batch, k, vdim] 86 | q: [batch, qdim] 87 | """ 88 | batch, k, _ = v.size() 89 | v = self.v_drop(v) 90 | q = self.q_drop(q) 91 | v_proj = self.v_proj(v) # [batch, k, qdim] 92 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) # [batch, k, qdim] 93 | s = v_proj * q_proj 94 | s = self.s_drop(s) 95 | return s # (B, k, score_dim) 96 | 97 | 98 | class GTUScoringFunction(nn.Module): 99 | 100 | def __init__(self, ques_dim, dropout=0.3, v_dim=2048, score_dim=2048): 101 | super(GTUScoringFunction, self).__init__() 102 | 103 | self.q_dim = ques_dim 104 | self.dropout = dropout 105 | self.v_dim = v_dim 106 | self.score_dim = score_dim 107 | 108 | self.predrop = nn.Dropout(self.dropout) 109 | self.dense1 = weight_norm(nn.Linear(self.v_dim + self.q_dim, self.score_dim), dim=None) 110 | self.dense2 = weight_norm(nn.Linear(self.v_dim + self.q_dim, self.score_dim), dim=None) 111 | 112 | self.s_drop = nn.Dropout(self.dropout) 113 | 114 | def forward(self, v, q): 115 | """ 116 | v: [batch, k, vdim] 117 | q: [batch, qdim] 118 | """ 119 | batch, k, _ = v.size() 120 | 121 | q = q[:, None, :].repeat(1, k, 1) # (B, k, q_dim) 122 | vq = torch.cat([v, q], dim=2) # (B, k, v_dim + q_dim) 123 | 124 | vq = self.predrop(vq) 125 | 126 | y = F.tanh(self.dense1(vq)) # (B, k, score_dim) 127 | g = F.sigmoid(self.dense2(vq)) # (B, k, score_dim) 128 | 129 | s = y * g 130 | s = self.s_drop(s) 131 | return s # (B, k, score_dim) 132 | 133 | 134 | class SoftCount(nn.Module): 135 | 136 | def __init__(self, ques_dim=1024, score_dim=512, dropout=0.1): 137 | super(SoftCount, self).__init__() 138 | self.ques_parser = QuestionParser(ques_dim=ques_dim, dropout=dropout) 139 | self.f = ScoringFunction(ques_dim=ques_dim, score_dim=score_dim, dropout=dropout) 140 | self.W = weight_norm(nn.Linear(score_dim, 1), dim=None) 141 | 142 | def forward(self, v_emb, q): 143 | # v_emb = (B, k, v_dim) 144 | # q = (B, MAXLEN) 145 | 146 | q_emb = self.ques_parser(q) # (B, q_dim) 147 | s = self.f(v_emb, q_emb) # (B, k, score_dim) 148 | soft_counts = F.sigmoid(self.W(s)).squeeze(2) # (B, k) 149 | C = soft_counts.sum(dim=1) # (B,) 150 | return C 151 | 152 | 153 | class RhoScorer(nn.Module): 154 | 155 | def __init__(self, ques_dim): 156 | super(RhoScorer, self).__init__() 157 | self.W = weight_norm(nn.Linear(ques_dim, 1), dim=None) 158 | 159 | inp_dim = 1 + 1 + 6 + 6 + 1 + 1 + 1 # 17 160 | self.f_rho = FCNet([inp_dim, 100]) 161 | self.dense = weight_norm(nn.Linear(100, 1), dim=None) 162 | 163 | @staticmethod 164 | def get_spatials(b): 165 | # b = (B, k, 6) 166 | 167 | b = b.float() 168 | 169 | B, k, _ = b.size() 170 | 171 | b_ij = torch.stack([b] * k, dim=1) # (B, k, k, 6) 172 | b_ji = b_ij.transpose(1, 2) 173 | 174 | area_ij = (b_ij[:, :, :, 2] - b_ij[:, :, :, 0]) * (b_ij[:, :, :, 3] - b_ij[:, :, :, 1]) 175 | area_ji = (b_ji[:, :, :, 2] - b_ji[:, :, :, 0]) * (b_ji[:, :, :, 3] - b_ji[:, :, :, 1]) 176 | 177 | righmost_left = 
torch.max(b_ij[:, :, :, 0], b_ji[:, :, :, 0]) 178 | downmost_top = torch.max(b_ij[:, :, :, 1], b_ji[:, :, :, 1]) 179 | leftmost_right = torch.min(b_ij[:, :, :, 2], b_ji[:, :, :, 2]) 180 | topmost_down = torch.min(b_ij[:, :, :, 3], b_ji[:, :, :, 3]) 181 | 182 | # calculate the separations 183 | left_right = (leftmost_right - righmost_left) 184 | up_down = (topmost_down - downmost_top) 185 | 186 | # don't multiply negative separations, 187 | # might actually give a positive area that doesn't exist! 188 | left_right = torch.max(0*left_right, left_right) 189 | up_down = torch.max(0*up_down, up_down) 190 | 191 | overlap = left_right * up_down 192 | 193 | iou = overlap / (area_ij + area_ji - overlap) 194 | o_ij = overlap / area_ij 195 | o_ji = overlap / area_ji 196 | 197 | iou = iou.unsqueeze(3) # (B, k, k, 1) 198 | o_ij = o_ij.unsqueeze(3) # (B, k, k, 1) 199 | o_ji = o_ji.unsqueeze(3) # (B, k, k, 1) 200 | 201 | return b_ij, b_ji, iou, o_ij, o_ji 202 | 203 | def forward(self, q_emb, v_emb, b): 204 | # q_emb = (B, ques_size) 205 | # v_emb = (B, k, v_dim) 206 | # b = (B, k, 6) 207 | 208 | B, k, _ = v_emb.size() 209 | 210 | features = [] 211 | 212 | wq = self.W(q_emb).squeeze(1) # (B,) 213 | wq = wq[:, None, None, None].repeat(1, k, k, 1) # (B, k, k, 1) 214 | assert wq.size() == (B, k, k, 1), "wq size is {}".format(wq.size()) 215 | features.append(wq) 216 | 217 | norm_v_emb = F.normalize(v_emb, dim=2) # (B, k, v_dim) 218 | vtv = torch.bmm(norm_v_emb, norm_v_emb.transpose(1, 2)) # (B, k, k) 219 | vtv = vtv[:, :, :, None].repeat(1, 1, 1, 1) # (B, k, k, 1) 220 | assert vtv.size() == (B, k, k, 1) 221 | features.append(vtv) 222 | 223 | b_ij, b_ji, iou, o_ij, o_ji = self.get_spatials(b) 224 | 225 | assert b_ij.size() == (B, k, k, 6) 226 | assert b_ji.size() == (B, k, k, 6) 227 | assert iou.size() == (B, k, k, 1) 228 | assert o_ij.size() == (B, k, k, 1) 229 | assert o_ji.size() == (B, k, k, 1) 230 | 231 | features.append(b_ij) # (B, k, k, 6) 232 | features.append(b_ji) # (B, k, k, 6) 233 | features.append(iou) # (B, k, k, 1) 234 | features.append(o_ij) # (B, k, k, 1) 235 | features.append(o_ji) # (B, k, k, 1) 236 | 237 | features = torch.cat(features, dim=3) # (B, k, k, 17) 238 | 239 | rho = self.f_rho(features) # (B, k, k, 100) 240 | rho = self.dense(rho).squeeze(3) # (B, k, k) 241 | 242 | return rho, features # (B, k, k) 243 | 244 | 245 | class IRLC(nn.Module): 246 | 247 | def __init__(self, ques_dim=1024, score_dim=2048, dropout=0.5): 248 | super(IRLC, self).__init__() 249 | # print("question parser has zero dropout") 250 | # put zero dropout for question, because it will get dropped out in scoring anyways. 
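# note (added commentary, not part of the original source): ScoringFunction keeps its own nn.Dropout on the question embedding (q_drop), so passing dropout=0 to QuestionParser here does not remove dropout from the question pathway; it only avoids applying it twice.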
251 | self.ques_parser = QuestionParser(ques_dim=ques_dim, dropout=0) 252 | self.f_s = ScoringFunction(ques_dim=ques_dim, score_dim=score_dim, dropout=dropout) 253 | self.W = weight_norm(nn.Linear(score_dim, 1), dim=None) 254 | self.f_rho = RhoScorer(ques_dim=ques_dim) 255 | 256 | # extra custom parameters 257 | self.eps = nn.Parameter(torch.zeros(1)) 258 | self.extra_params = nn.ParameterList([self.eps]) 259 | 260 | def sample_action(self, probs, already_selected=None, greedy=False): 261 | # probs = (B, k+1) 262 | # already_selected = (num_timesteps, B) 263 | 264 | if already_selected is None: 265 | mask = 1 266 | else: 267 | mask = Variable(torch.ones(probs.size())) 268 | if USE_CUDA: 269 | # TODO: uncomment this, when this model works 270 | mask = mask.cuda() 271 | pass 272 | mask = mask.scatter_(1, already_selected.t(), 0) # (B, k+1) 273 | 274 | masked_probs = mask * (probs + 1e-20) # (B, k+1), add epsilon to make sure no non-masked value is zero. 275 | dist = Categorical(probs=masked_probs) 276 | 277 | if greedy: 278 | _, a = masked_probs.max(dim=1) # (B) 279 | else: 280 | a = dist.sample() # (B) 281 | 282 | log_prob = dist.log_prob(a) # (B) 283 | entropy = dist.entropy() # (B) 284 | return a, log_prob, entropy 285 | 286 | @staticmethod 287 | def get_interaction(rho, a): 288 | # get the interaction row in rho corresponding to the action a 289 | # rho = (B, num_actions, k) 290 | # a = (B) containing action indices between 0 and num_actions-1 291 | 292 | B, _, k = rho.size() 293 | 294 | # first expand a to the size required output 295 | a = a[:, None].repeat(1, k) # (B, k) 296 | 297 | # print("rho size = {} and a size = {}".format(rho.size(), a.size())) 298 | interaction = rho.gather(dim=1, index=a.unsqueeze(dim=1)).squeeze(dim=1) # (B, k) 299 | assert interaction.size() == (B, k), "interaction size is {}".format(interaction.size()) 300 | # print("interaction size = {}".format(interaction.size())) 301 | 302 | return interaction # (B, k) 303 | 304 | def sample_objects(self, kappa_0, rho, batch_eps, greedy=False): 305 | # kappa_0 = (B, k) 306 | # rho = (B, k, k) 307 | 308 | # add an extra row of 0 interaction for the terminal action 309 | rho = torch.cat((rho, 0 * rho[:, :1, :]), dim=1) # (B, k+1, k) 310 | 311 | B, k = kappa_0.size() 312 | 313 | P = None # save un-scaled probabilities of each action at each time-step. mainly for visualization 314 | logPA = None # log prob values for each time-step. 315 | entP = None # distribution entropy value for each timestep. 316 | A = None # will store action values at each timestep. 317 | T = k+1 # num timesteps = different possible actions. +1 for the terminal action 318 | kappa = kappa_0 # (B, k), starting kappa 319 | 320 | for t in range(T): 321 | # calculate probabilities of each action 322 | unscaled_p = F.softmax(torch.cat((kappa, batch_eps), dim=1), dim=1) # (B, k+1) 323 | # print("p = ", p) 324 | # select one object (called "action" in RL terms), avoid already selected objects. 325 | a, log_prob, entropy = self.sample_action( 326 | probs=unscaled_p, already_selected=A, greedy=greedy) # (B,), (B,), (B,) 327 | # update kappa logits with the row in the interaction matrix corresponding to the chosen action. 
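# note (added commentary, not part of the original source): with a_b the object index sampled above for batch element b, the update below is kappa[b] <- kappa[b] + rho[b, a_b, :], i.e. every remaining candidate's logit is shifted by its learned pairwise interaction with the object just counted. The terminal action's row of rho was padded with zeros, so selecting "stop" leaves kappa unchanged.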
328 | interaction = self.get_interaction(rho, a) 329 | kappa = kappa + interaction 330 | 331 | # record the prob and action values at each timestep for later use 332 | P = unscaled_p[None] if P is None else torch.cat((P, unscaled_p[None]), dim=0) # (t+1, B, k+1) 333 | logPA = log_prob[None] if logPA is None else torch.cat((logPA, log_prob[None]), dim=0) # (t+1, B) 334 | entP = entropy[None] if entP is None else torch.cat((entP, entropy[None]), dim=0) # (t+1, B) 335 | A = a[None] if A is None else torch.cat((A, a[None]), dim=0) # (t+1, B) 336 | 337 | assert logPA.size() == (T, B) 338 | assert entP.size() == (T, B) 339 | assert A.size() == (T, B) 340 | 341 | # calculate count 342 | terminal_action = (A == k) # (T, B) # true for the timestep when terminal action was selected. 343 | _, count = terminal_action.max(dim=0) # (B,) # index of the terminal action is considered the count 344 | 345 | return logPA, entP, A, count, P 346 | 347 | def compute_vars(self, v_emb, b, q): 348 | # v_emb = (B, k, v_dim) 349 | # b = (B, k, 6) 350 | # q = (B, MAXLEN) 351 | 352 | B, k, _ = v_emb.size() 353 | 354 | q_emb = self.ques_parser(q) # (B, q_dim) 355 | s = self.f_s(v_emb, q_emb) # (B, k, score_dim) 356 | kappa_0 = self.W(s).squeeze(2) # (B, k) 357 | 358 | rho, _ = self.f_rho(q_emb, v_emb, b) # (B, k, k) 359 | 360 | return kappa_0, rho 361 | 362 | def take_mc_samples(self, kappa_0, rho, num_mc_samples): 363 | # kappa_0 = (B, k) 364 | # rho = (B, k, k) 365 | 366 | B, k = kappa_0.size() 367 | assert rho.size() == (B, k, k) 368 | 369 | kappa_0 = kappa_0.repeat(num_mc_samples, 1) # (B * samples, k) 370 | rho = rho.repeat(num_mc_samples, 1, 1) # (B * samples, k, k) 371 | 372 | batch_eps = torch.cat([self.eps] * B * num_mc_samples)[:, None] # (B * samples, 1) 373 | 374 | logPA, entP, A, count, P = self.sample_objects(kappa_0=kappa_0, rho=rho, batch_eps=batch_eps) 375 | _, _, _, greedy_count, _ = self.sample_objects(kappa_0=kappa_0, rho=rho, batch_eps=batch_eps, greedy=True) 376 | 377 | return count, greedy_count, logPA, entP, A, rho, P 378 | 379 | def get_sc_loss(self, count_gt, count, greedy_count, logPA, valid_A): 380 | # count_gt = (B,) 381 | # count = (B,) 382 | # greedy_count = (B,) 383 | # logPA = (T, B) 384 | # valid_A = (T, B) 385 | 386 | assert count.size() == count_gt.size() 387 | assert greedy_count.size() == count_gt.size() 388 | 389 | count = count.float() 390 | greedy_count = greedy_count.float() 391 | count_gt = count_gt.float() 392 | 393 | # self-critical loss 394 | E = torch.abs(count - count_gt) # (B,) 395 | E_greedy = torch.abs(greedy_count - count_gt) # (B,) 396 | 397 | R = E_greedy - E # (B,) 398 | 399 | assert R.size() == count.size(), "R size is {}".format(R.size()) 400 | 401 | mean_log_PA = (logPA * valid_A).sum(dim=0) / valid_A.sum(dim=0) # (B,) 402 | 403 | batch_sc_loss = - R * mean_log_PA # (B,) 404 | sc_loss = batch_sc_loss.mean(dim=0) # (1,) 405 | 406 | return sc_loss 407 | 408 | def get_entropy_loss(self, entP, valid_A): 409 | # entP = (T, B) 410 | # valid_A = (T, B) 411 | 412 | batch_entropy_loss = - (entP * valid_A).sum(dim=0) / valid_A.sum(dim=0) # (B,) 413 | entropy_loss = batch_entropy_loss.mean(dim=0) 414 | 415 | return entropy_loss 416 | 417 | def get_interaction_strength(self, rho, A, pre_terminal_A): 418 | # rho = (B, k, k) 419 | # A = (T, B) 420 | 421 | B, k, _ = rho.size() 422 | T, B = A.size() 423 | 424 | # interaction strength, lower is better, sparse preferred. 
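# note (added commentary, not part of the original source): the penalty computed below is an element-wise smooth-L1 (Huber) norm of rho against zero, averaged over each row of the interaction matrix and then, per sample, over the rows of the actions actually taken before the terminal action; this pushes the learned pairwise interactions toward sparsity. reduce=False is the legacy PyTorch spelling of what newer releases call reduction='none'.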
425 | interactions = F.smooth_l1_loss(rho, 0*rho.detach(), reduce=False) # (B, k, k) 426 | interactions = interactions.mean(dim=2) # (B, k) 427 | 428 | # add a dummy interaction for the terminal action, pad zeros (value doesn't matter as it will be masked anyways. 429 | interactions = torch.cat((interactions, 0*interactions[:, :1]), dim=1) # (B, k+1) 430 | 431 | # for each timestep select the interaction corresponding to the performed action 432 | repeated_interactions = interactions[None].repeat(T, 1, 1) # (T, B, k) 433 | action_interactions = repeated_interactions.gather(dim=2, index=A.unsqueeze(2)).squeeze(2) # (T ,B) 434 | 435 | # mask out interactions due to actions done after the terminal action 436 | valid_interactions = (action_interactions * pre_terminal_A).sum(dim=0) / (1e-20 + pre_terminal_A.sum(dim=0)) # (B,) 437 | 438 | interaction_strength = valid_interactions.mean(dim=0) 439 | 440 | return interaction_strength 441 | 442 | def get_loss(self, count_gt, count, greedy_count, logPA, entP, A, rho): 443 | # count_gt = (B,) 444 | # count = (B,) 445 | # greedy_count = (B,) 446 | # logPA = (T, B) 447 | # entP = (T, B) 448 | # A = (T, B) 449 | # rho = (B, k, k) 450 | 451 | assert count.size() == count_gt.size() 452 | assert greedy_count.size() == count_gt.size() 453 | 454 | B, k, _ = rho.size() 455 | 456 | terminal_action = A.max() 457 | assert terminal_action.item() == k 458 | 459 | terminal_A = (A == terminal_action).float() # (T, B) 460 | post_terminal_A = terminal_A.cumsum(dim=0) - terminal_A # (T, B) 461 | pre_terminal_A = 1 - post_terminal_A - terminal_A 462 | valid_A = 1 - post_terminal_A # (T, B) 463 | 464 | sc_loss = self.get_sc_loss(count_gt, count, greedy_count, logPA, valid_A) 465 | entropy_loss = self.get_entropy_loss(entP, valid_A) 466 | interaction_strength = self.get_interaction_strength(rho, A, pre_terminal_A) 467 | 468 | # print("sc_loss", sc_loss, "entropy loss", entropy_loss, "interaction strength", interaction_strength) 469 | 470 | loss = 1.0 * sc_loss + .005 * entropy_loss + .005 * interaction_strength 471 | 472 | return loss 473 | -------------------------------------------------------------------------------- /tools/compute_softscore.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | import re 7 | import cPickle 8 | 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | from dataset import Dictionary 11 | 12 | 13 | def utils_create_dir(path): 14 | if not os.path.exists(path): 15 | try: 16 | os.makedirs(path) 17 | except OSError as exc: 18 | if exc.errno != errno.EEXIST: 19 | raise 20 | 21 | 22 | contractions = { 23 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 24 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 25 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 26 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 27 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 28 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 29 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 30 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 31 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 32 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 33 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 34 | "mightn't've", "mightn'tve": "mightn't've", 
"mightve": "might've", 35 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 36 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 37 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 38 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 39 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 40 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 41 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 42 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 43 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 44 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 45 | "someonell": "someone'll", "someones": "someone's", "somethingd": 46 | "something'd", "somethingd've": "something'd've", "something'dve": 47 | "something'd've", "somethingll": "something'll", "thats": 48 | "that's", "thered": "there'd", "thered've": "there'd've", 49 | "there'dve": "there'd've", "therere": "there're", "theres": 50 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 51 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 52 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 53 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 54 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 55 | "what's", "whatve": "what've", "whens": "when's", "whered": 56 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 57 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 58 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 59 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 60 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 61 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 62 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 63 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 64 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 65 | "you'll", "youre": "you're", "youve": "you've" 66 | } 67 | 68 | manual_map = { 'none': '0', 69 | 'zero': '0', 70 | 'one': '1', 71 | 'two': '2', 72 | 'three': '3', 73 | 'four': '4', 74 | 'five': '5', 75 | 'six': '6', 76 | 'seven': '7', 77 | 'eight': '8', 78 | 'nine': '9', 79 | 'ten': '10'} 80 | articles = ['a', 'an', 'the'] 81 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 82 | comma_strip = re.compile("(\d)(\,)(\d)") 83 | punct = [';', r"/", '[', ']', '"', '{', '}', 84 | '(', ')', '=', '+', '\\', '_', '-', 85 | '>', '<', '@', '`', ',', '?', '!'] 86 | 87 | 88 | def get_score(occurences): 89 | if occurences == 0: 90 | return 0 91 | elif occurences == 1: 92 | return 0.3 93 | elif occurences == 2: 94 | return 0.6 95 | elif occurences == 3: 96 | return 0.9 97 | else: 98 | return 1 99 | 100 | 101 | def process_punctuation(inText): 102 | outText = inText 103 | for p in punct: 104 | if (p + ' ' in inText or ' ' + p in inText) \ 105 | or (re.search(comma_strip, inText) != None): 106 | outText = outText.replace(p, '') 107 | else: 108 | outText = outText.replace(p, ' ') 109 | outText = period_strip.sub("", outText, re.UNICODE) 110 | return outText 111 | 112 | 113 | def process_digit_article(inText): 114 | outText = [] 115 | tempText = inText.lower().split() 116 | for word in tempText: 117 | word = manual_map.setdefault(word, word) 118 | if word not in articles: 119 | outText.append(word) 120 | else: 121 | pass 122 | for wordId, word in enumerate(outText): 123 | if word in contractions: 124 
| outText[wordId] = contractions[word] 125 | outText = ' '.join(outText) 126 | return outText 127 | 128 | 129 | def multiple_replace(text, wordDict): 130 | for key in wordDict: 131 | text = text.replace(key, wordDict[key]) 132 | return text 133 | 134 | 135 | def preprocess_answer(answer): 136 | answer = process_digit_article(process_punctuation(answer)) 137 | answer = answer.replace(',', '') 138 | return answer 139 | 140 | 141 | def filter_answers(answers_dset, min_occurence): 142 | """This will change the answer to preprocessed version 143 | """ 144 | occurence = {} 145 | 146 | for ans_entry in answers_dset: 147 | answers = ans_entry['answers'] 148 | gtruth = ans_entry['multiple_choice_answer'] 149 | gtruth = preprocess_answer(gtruth) 150 | if gtruth not in occurence: 151 | occurence[gtruth] = set() 152 | occurence[gtruth].add(ans_entry['question_id']) 153 | for answer in occurence.keys(): 154 | if len(occurence[answer]) < min_occurence: 155 | occurence.pop(answer) 156 | 157 | print('Num of answers that appear >= %d times: %d' % ( 158 | min_occurence, len(occurence))) 159 | return occurence 160 | 161 | 162 | def create_ans2label(occurence, name, cache_root='data/cache'): 163 | """Note that this will also create label2ans.pkl at the same time 164 | 165 | occurence: dict {answer -> whatever} 166 | name: prefix of the output file 167 | cache_root: str 168 | """ 169 | ans2label = {} 170 | label2ans = [] 171 | label = 0 172 | for answer in occurence: 173 | label2ans.append(answer) 174 | ans2label[answer] = label 175 | label += 1 176 | 177 | utils_create_dir(cache_root) 178 | 179 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 180 | cPickle.dump(ans2label, open(cache_file, 'wb')) 181 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 182 | cPickle.dump(label2ans, open(cache_file, 'wb')) 183 | return ans2label 184 | 185 | 186 | def compute_target(answers_dset, ans2label, name, cache_root='data/cache'): 187 | """Augment answers_dset with soft score as label 188 | 189 | ***answers_dset should be preprocessed*** 190 | 191 | Write result into a cache file 192 | """ 193 | target = [] 194 | for ans_entry in answers_dset: 195 | answers = ans_entry['answers'] 196 | answer_count = {} 197 | for answer in answers: 198 | answer_ = answer['answer'] 199 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 200 | 201 | labels = [] 202 | scores = [] 203 | counts = [] 204 | for answer in answer_count: 205 | if answer not in ans2label: 206 | continue 207 | labels.append(ans2label[answer]) 208 | count = answer_count[answer] 209 | counts.append(count) 210 | score = get_score(count) 211 | scores.append(score) 212 | 213 | target.append({ 214 | 'question_id': ans_entry['question_id'], 215 | 'image_id': ans_entry['image_id'], 216 | 'labels': labels, 217 | 'scores': scores, 218 | 'counts': counts 219 | }) 220 | 221 | utils_create_dir(cache_root) 222 | cache_file = os.path.join(cache_root, name+'_target.pkl') 223 | cPickle.dump(target, open(cache_file, 'wb')) 224 | return target 225 | 226 | 227 | def get_answer(qid, answers): 228 | for ans in answers: 229 | if ans['question_id'] == qid: 230 | return ans 231 | 232 | 233 | def get_question(qid, questions): 234 | for question in questions: 235 | if question['question_id'] == qid: 236 | return question 237 | 238 | 239 | if __name__ == '__main__': 240 | train_answer_file = 'data/v2_mscoco_train2014_annotations.json' 241 | train_answers = json.load(open(train_answer_file))['annotations'] 242 | 243 | val_answer_file = 
'data/v2_mscoco_val2014_annotations.json' 244 | val_answers = json.load(open(val_answer_file))['annotations'] 245 | 246 | train_question_file = 'data/v2_OpenEnded_mscoco_train2014_questions.json' 247 | train_questions = json.load(open(train_question_file))['questions'] 248 | 249 | val_question_file = 'data/v2_OpenEnded_mscoco_val2014_questions.json' 250 | val_questions = json.load(open(val_question_file))['questions'] 251 | 252 | answers = train_answers + val_answers 253 | occurence = filter_answers(answers, 9) 254 | ans2label = create_ans2label(occurence, 'trainval') 255 | compute_target(train_answers, ans2label, 'train') 256 | compute_target(val_answers, ans2label, 'val') 257 | -------------------------------------------------------------------------------- /tools/create_dictionary.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | from dataset import Dictionary 11 | def create_dictionary(dataroot): 12 | dictionary = Dictionary() 13 | questions = [] 14 | files = [ 15 | 'v2_OpenEnded_mscoco_train2014_questions.json', 16 | 'v2_OpenEnded_mscoco_val2014_questions.json', 17 | 'v2_OpenEnded_mscoco_test2015_questions.json', 18 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json', 19 | 'how_many_qa/vgn_ques.json' 20 | ] 21 | for path in files: 22 | question_path = os.path.join(dataroot, path) 23 | qs = json.load(open(question_path)) 24 | if "vgn" not in path: 25 | qs = qs['questions'] 26 | for q in qs: 27 | dictionary.tokenize(q['question'], True) 28 | return dictionary 29 | 30 | 31 | def create_glove_embedding_init(idx2word, glove_file): 32 | word2emb = {} 33 | with open(glove_file, 'r') as f: 34 | entries = f.readlines() 35 | emb_dim = len(entries[0].split(' ')) - 1 36 | print('embedding dim is %d' % emb_dim) 37 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 38 | 39 | for entry in entries: 40 | vals = entry.split(' ') 41 | word = vals[0] 42 | vals = map(float, vals[1:]) 43 | word2emb[word] = np.array(vals) 44 | for idx, word in enumerate(idx2word): 45 | if word not in word2emb: 46 | continue 47 | weights[idx] = word2emb[word] 48 | return weights, word2emb 49 | 50 | 51 | if __name__ == '__main__': 52 | d = create_dictionary('data') 53 | d.dump_to_file('data/dictionary.pkl') 54 | 55 | d = Dictionary.load_from_file('data/dictionary.pkl') 56 | emb_dim = 300 57 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 58 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 59 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights) 60 | -------------------------------------------------------------------------------- /tools/create_how_many_qa_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import json 3 | import cPickle as pkl 4 | import os 5 | import re 6 | 7 | 8 | def find_image_ids(): 9 | 10 | # read locations 11 | train_targets_loc = "./data/cache/train_target.pkl" 12 | val_targets_loc = "./data/cache/val_target.pkl" 13 | hmq_ids_loc = "./data/how_many_qa/HowMany-QA/question_ids.json" 14 | 15 | # write locations 16 | hmq_image_ids_loc = "./data/how_many_qa/image_ids.json" 17 | if os.path.isfile(hmq_image_ids_loc): 18 | print("The file {} already exists. 
Skipping finding image ids.".format(hmq_image_ids_loc)) 19 | return 20 | 21 | train_targets = pkl.load(open(train_targets_loc, "rb")) 22 | val_targets = pkl.load(open(val_targets_loc, "rb")) 23 | 24 | hmq_ids = json.load(open(hmq_ids_loc, "rb")) 25 | qids = { 26 | "train": set(hmq_ids["train"]["vqa"]), 27 | "test": set(hmq_ids["test"]), 28 | "dev": set(hmq_ids["dev"]), 29 | } 30 | 31 | image_ids = { 32 | "test": [], 33 | "train": [], 34 | "dev": [], 35 | } 36 | 37 | # train 38 | for i, ans in enumerate(train_targets): 39 | if i % 10000 == 0: 40 | print(i) 41 | if ans["question_id"] in qids["train"]: 42 | image_ids["train"].append(ans["image_id"]) 43 | if ans["question_id"] in qids["test"] or ans["question_id"] in qids["dev"]: 44 | raise Exception("found train question id {} in qids marked for test and dev") 45 | 46 | # dev and test 47 | for i, ans in enumerate(val_targets): 48 | if i % 10000 == 0: 49 | print(i) 50 | if ans["question_id"] in qids["train"]: 51 | raise Exception("found validation question id {} in qids marked for training") 52 | if ans["question_id"] in qids["test"]: 53 | image_ids["test"].append(ans["image_id"]) 54 | if ans["question_id"] in qids["dev"]: 55 | image_ids["dev"].append(ans["image_id"]) 56 | 57 | unique_image_ids = { 58 | "test": list(set(image_ids["test"])), 59 | "train": list(set(image_ids["train"])), 60 | "dev": list(set(image_ids["dev"])), 61 | } 62 | 63 | assert len(unique_image_ids["train"]) == 31932 64 | assert len(unique_image_ids["dev"]) == 13119 65 | assert len(unique_image_ids["test"]) == 2483 66 | 67 | print("writing image ids for how many QA to disk..") 68 | json.dump(unique_image_ids, open(hmq_image_ids_loc, "wb")) 69 | print("Done.") 70 | return 71 | 72 | 73 | def prepare_visual_genome(): 74 | 75 | # read locations 76 | hmqa_qids_loc = "./data/how_many_qa/HowMany-QA/question_ids.json" 77 | all_vg_loc = "./data/how_many_qa/HowMany-QA/visual_genome_question_answers.json" 78 | vg_image_data_loc = "./data/how_many_qa/HowMany-QA/visual_genome_image_data.json" 79 | 80 | # write locations 81 | vgn_ques_loc = "./data/how_many_qa/vgn_ques.json" 82 | if os.path.isfile(vgn_ques_loc): 83 | print("The file {} already exists. 
Skipping preparing visual genome.".format(vgn_ques_loc)) 84 | return 85 | 86 | hmqa_qids = json.load(open(hmqa_qids_loc, "rb")) 87 | hmqa_vg_qids = set(hmqa_qids["train"]["visual_genome"]) 88 | all_vg = json.load(open(all_vg_loc)) 89 | vg_image_data = json.load(open(vg_image_data_loc)) 90 | 91 | vg_entries = [] 92 | 93 | for qaset in all_vg: 94 | # setid = qaset['id'] 95 | for entry in qaset['qas']: 96 | if entry["qa_id"] in hmqa_vg_qids: 97 | vg_entries.append(entry) 98 | 99 | vg_image_id2coco_id = {x["image_id"]: x["coco_id"] for x in vg_image_data} 100 | 101 | vgn_ques = [ 102 | {'image_id': vg_image_id2coco_id[x['image_id']], 'question': x['question'], 'question_id': x['qa_id'] 103 | } for x in vg_entries 104 | ] 105 | 106 | json.dump(vgn_ques, open(vgn_ques_loc, "wb")) 107 | 108 | 109 | def vg_ans2count(s): 110 | # replace all non-alphanums with space 111 | r = re.sub('[^0-9a-zA-Z]+', ' ', s) 112 | 113 | r = r.lower() 114 | r = r.split(' ') 115 | 116 | word2num = { 117 | "zero": 0, 118 | "one": 1, 119 | "two": 2, 120 | "three": 3, 121 | "four": 4, 122 | "five": 5, 123 | "six": 6, 124 | "seven": 7, 125 | "eight": 8, 126 | "nine": 9, 127 | "ten": 10, 128 | "eleven": 11, 129 | "twelve": 12, 130 | "thirteen": 13, 131 | "fourteen": 14, 132 | "fifteen": 15, 133 | "sixteen": 16, 134 | "seventeen": 17, 135 | "eighteen": 18, 136 | "nineteen": 19, 137 | "twenty": 20 138 | } 139 | 140 | cands = [] 141 | for word in r: 142 | try: 143 | cands.append(int(word)) 144 | except: 145 | pass 146 | try: 147 | cands.append(word2num[word]) 148 | except: 149 | pass 150 | 151 | cands = list(set(cands)) # merging duplicates 152 | 153 | if len(cands) != 1: 154 | print(s, cands) 155 | if (s == "3 or 4." and cands == [3, 4]): 156 | print("manually correcting '{}'".format(s)) 157 | cands = cands[:1] 158 | 159 | assert len(cands) == 1 160 | count = cands[0] 161 | 162 | assert 0 <= count <= 20 163 | 164 | return count 165 | 166 | 167 | def find_counts(): 168 | 169 | # read locations 170 | _hmq_ids_loc = "./data/how_many_qa/HowMany-QA/question_ids.json" 171 | vqa_train_entries_loc = "./data/cache/train_target.pkl" 172 | test_dev_entries_loc = "./data/cache/val_target.pkl" 173 | label2ans_loc = "./data/cache/trainval_label2ans.pkl" 174 | all_vg_loc = "./data/how_many_qa/HowMany-QA/visual_genome_question_answers.json" 175 | 176 | # write locations 177 | qid2count_loc = "./data/how_many_qa/qid2count.json" 178 | qid2count2score_loc = "./data/how_many_qa/qid2count2score.json" 179 | 180 | if os.path.isfile(qid2count_loc) and os.path.isfile(qid2count2score_loc): 181 | print("The file {} and {} already exists. 
Skipping finding counts.".format(qid2count_loc, qid2count2score_loc)) 182 | return 183 | 184 | _hmq_ids = json.load(open(_hmq_ids_loc, "rb")) 185 | hmq_ids = { 186 | "train": { 187 | "vqa": set(_hmq_ids["train"]["vqa"]), 188 | "visual_genome": set(_hmq_ids["train"]["visual_genome"]), 189 | }, 190 | "test": set(_hmq_ids["test"]), 191 | "dev": set(_hmq_ids["dev"]), 192 | } 193 | 194 | vqa_train_entries = pkl.load(open(vqa_train_entries_loc, "rb")) 195 | test_dev_entries = pkl.load(open(test_dev_entries_loc, "rb")) 196 | label2ans = pkl.load(open(label2ans_loc, "rb")) 197 | 198 | qid2count = { 199 | "train": { 200 | "vqa": {}, 201 | "visual_genome": {} 202 | }, 203 | "test": {}, 204 | "dev": {}, 205 | } 206 | 207 | qid2count2score = { 208 | "train": { 209 | "vqa": {}, 210 | "visual_genome": {} 211 | }, 212 | "test": {}, 213 | "dev": {}, 214 | } 215 | 216 | # vqa train 217 | for entry in vqa_train_entries: 218 | qid = entry['question_id'] 219 | 220 | if qid not in hmq_ids["train"]["vqa"]: 221 | continue 222 | 223 | gt_cands = [] 224 | max_occurence_count = 0 225 | 226 | for occurence_count, score, label in zip(entry["counts"], entry["scores"], entry["labels"]): 227 | try: 228 | count = int(label2ans[label]) 229 | assert count <= 20, "No {} is more (score: {})".format(count, score) 230 | 231 | if occurence_count > max_occurence_count: 232 | max_occurence_count = occurence_count 233 | gt_cands = [count] 234 | elif occurence_count == max_occurence_count: 235 | gt_cands.append(count) 236 | 237 | if qid2count2score["train"]["vqa"].get(qid) is None: 238 | qid2count2score["train"]["vqa"][qid] = [0] * 21 # count2score list mapping 239 | qid2count2score["train"]["vqa"][qid][count] = score 240 | 241 | except Exception as e: 242 | print(e) 243 | pass 244 | 245 | # select the answer with highest occurence count, in case of a tie select the minimum 246 | qid2count["train"]["vqa"][qid] = min(gt_cands) 247 | 248 | ##### VISUAL GENOME ####### 249 | 250 | hmqa_qids = json.load(open(_hmq_ids_loc, "rb")) 251 | hmqa_vg_qids = set(hmqa_qids["train"]["visual_genome"]) 252 | all_vg = json.load(open(all_vg_loc)) 253 | 254 | vg_entries = [] 255 | 256 | for qaset in all_vg: 257 | # setid = qaset['id'] 258 | for entry in qaset['qas']: 259 | if entry["qa_id"] in hmqa_vg_qids: 260 | vg_entries.append(entry) 261 | 262 | for entry in vg_entries: 263 | qid = entry["qa_id"] 264 | count = vg_ans2count(entry["answer"]) 265 | assert qid2count["train"]["visual_genome"].get(qid) is None 266 | assert qid2count2score["train"]["visual_genome"].get(qid) is None 267 | 268 | qid2count["train"]["visual_genome"][qid] = count 269 | qid2count2score["train"]["visual_genome"][qid] = [0] * 21 270 | qid2count2score["train"]["visual_genome"][qid][count] = 1 271 | 272 | ################################## 273 | 274 | # test and dev 275 | for entry in test_dev_entries: 276 | qid = entry['question_id'] 277 | 278 | test_entry = qid in hmq_ids["test"] 279 | dev_entry = qid in hmq_ids["dev"] 280 | 281 | if not (test_entry or dev_entry): 282 | continue 283 | 284 | if test_entry and dev_entry: 285 | raise Exception("Found qid {} that is marked for both test set and train set!!".format(qid)) 286 | 287 | gt_cands = [] 288 | max_occurence_count = 0 289 | 290 | for occurence_count, score, label in zip(entry["counts"], entry["scores"], entry["labels"]): 291 | try: 292 | count = int(label2ans[label]) 293 | assert count <= 20, "No {} is more (score: {})".format(count, score) 294 | 295 | if occurence_count > max_occurence_count: 296 | max_occurence_count = 
occurence_count 297 | gt_cands = [count] 298 | elif occurence_count == max_occurence_count: 299 | gt_cands.append(count) 300 | 301 | if test_entry: 302 | 303 | if qid2count2score["test"].get(qid) is None: 304 | qid2count2score["test"][qid] = [0] * 21 # count2score list mapping 305 | qid2count2score["test"][qid][count] = score 306 | 307 | if dev_entry: 308 | 309 | if qid2count2score["dev"].get(qid) is None: 310 | qid2count2score["dev"][qid] = [0] * 21 # count2score list mapping 311 | qid2count2score["dev"][qid][count] = score 312 | 313 | except Exception as e: 314 | print(e) 315 | pass 316 | 317 | # select the answer with highest occurence count, in case of a tie select the minimum 318 | if test_entry: 319 | qid2count["test"][qid] = min(gt_cands) 320 | if dev_entry: 321 | assert not test_entry 322 | qid2count["dev"][qid] = min(gt_cands) 323 | 324 | assert len(qid2count["train"]["vqa"]) == 47542 325 | assert len(qid2count["test"]) == 5000 326 | assert len(qid2count["dev"]) == 17714 327 | 328 | json.dump(qid2count, open(qid2count_loc, "w")) 329 | json.dump(qid2count2score, open(qid2count2score_loc, "w")) 330 | return 331 | 332 | 333 | def main(): 334 | find_image_ids() 335 | prepare_visual_genome() 336 | find_counts() 337 | 338 | 339 | if __name__ == '__main__': 340 | main() 341 | 342 | -------------------------------------------------------------------------------- /tools/detection_features_converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reads in a tsv file with pre-trained bottom up attention features and 3 | stores it in HDF5 format. Also store {image_id: feature_idx} 4 | as a pickle file. 5 | 6 | Hierarchy of HDF5 file: 7 | 8 | { 'image_features': num_images x num_boxes x 2048 array of features 9 | 'image_bb': num_images x num_boxes x 4 array of bounding boxes } 10 | """ 11 | from __future__ import print_function 12 | 13 | import os 14 | import sys 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | 17 | import base64 18 | import csv 19 | import h5py 20 | import cPickle 21 | import numpy as np 22 | import utils 23 | 24 | 25 | csv.field_size_limit(sys.maxsize) 26 | 27 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features'] 28 | infile = 'data/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv' 29 | train_data_file = 'data/train36.hdf5' 30 | val_data_file = 'data/val36.hdf5' 31 | train_indices_file = 'data/train36_imgid2idx.pkl' 32 | val_indices_file = 'data/val36_imgid2idx.pkl' 33 | train_ids_file = 'data/train_ids.pkl' 34 | val_ids_file = 'data/val_ids.pkl' 35 | 36 | feature_length = 2048 37 | num_fixed_boxes = 36 38 | 39 | 40 | if __name__ == '__main__': 41 | h_train = h5py.File(train_data_file, "w") 42 | h_val = h5py.File(val_data_file, "w") 43 | 44 | if os.path.exists(train_ids_file) and os.path.exists(val_ids_file): 45 | train_imgids = cPickle.load(open(train_ids_file)) 46 | val_imgids = cPickle.load(open(val_ids_file)) 47 | else: 48 | train_imgids = utils.load_imageid('data/train2014') 49 | val_imgids = utils.load_imageid('data/val2014') 50 | cPickle.dump(train_imgids, open(train_ids_file, 'wb')) 51 | cPickle.dump(val_imgids, open(val_ids_file, 'wb')) 52 | 53 | train_indices = {} 54 | val_indices = {} 55 | 56 | train_img_features = h_train.create_dataset( 57 | 'image_features', (len(train_imgids), num_fixed_boxes, feature_length), 'f') 58 | train_img_bb = h_train.create_dataset( 59 | 'image_bb', (len(train_imgids), num_fixed_boxes, 4), 'f') 60 | 
train_spatial_img_features = h_train.create_dataset( 61 | 'spatial_features', (len(train_imgids), num_fixed_boxes, 6), 'f') 62 | 63 | val_img_bb = h_val.create_dataset( 64 | 'image_bb', (len(val_imgids), num_fixed_boxes, 4), 'f') 65 | val_img_features = h_val.create_dataset( 66 | 'image_features', (len(val_imgids), num_fixed_boxes, feature_length), 'f') 67 | val_spatial_img_features = h_val.create_dataset( 68 | 'spatial_features', (len(val_imgids), num_fixed_boxes, 6), 'f') 69 | 70 | train_counter = 0 71 | val_counter = 0 72 | 73 | print("reading tsv...") 74 | with open(infile, "r+b") as tsv_in_file: 75 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) 76 | for item in reader: 77 | item['num_boxes'] = int(item['num_boxes']) 78 | image_id = int(item['image_id']) 79 | image_w = float(item['image_w']) 80 | image_h = float(item['image_h']) 81 | bboxes = np.frombuffer( 82 | base64.decodestring(item['boxes']), 83 | dtype=np.float32).reshape((item['num_boxes'], -1)) 84 | 85 | box_width = bboxes[:, 2] - bboxes[:, 0] 86 | box_height = bboxes[:, 3] - bboxes[:, 1] 87 | scaled_width = box_width / image_w 88 | scaled_height = box_height / image_h 89 | scaled_x = bboxes[:, 0] / image_w 90 | scaled_y = bboxes[:, 1] / image_h 91 | 92 | box_width = box_width[..., np.newaxis] 93 | box_height = box_height[..., np.newaxis] 94 | scaled_width = scaled_width[..., np.newaxis] 95 | scaled_height = scaled_height[..., np.newaxis] 96 | scaled_x = scaled_x[..., np.newaxis] 97 | scaled_y = scaled_y[..., np.newaxis] 98 | 99 | spatial_features = np.concatenate( 100 | (scaled_x, 101 | scaled_y, 102 | scaled_x + scaled_width, 103 | scaled_y + scaled_height, 104 | scaled_width, 105 | scaled_height), 106 | axis=1) 107 | 108 | if image_id in train_imgids: 109 | train_imgids.remove(image_id) 110 | train_indices[image_id] = train_counter 111 | train_img_bb[train_counter, :, :] = bboxes 112 | train_img_features[train_counter, :, :] = np.frombuffer( 113 | base64.decodestring(item['features']), 114 | dtype=np.float32).reshape((item['num_boxes'], -1)) 115 | train_spatial_img_features[train_counter, :, :] = spatial_features 116 | train_counter += 1 117 | elif image_id in val_imgids: 118 | val_imgids.remove(image_id) 119 | val_indices[image_id] = val_counter 120 | val_img_bb[val_counter, :, :] = bboxes 121 | val_img_features[val_counter, :, :] = np.frombuffer( 122 | base64.decodestring(item['features']), 123 | dtype=np.float32).reshape((item['num_boxes'], -1)) 124 | val_spatial_img_features[val_counter, :, :] = spatial_features 125 | val_counter += 1 126 | else: 127 | assert False, 'Unknown image id: %d' % image_id 128 | 129 | if len(train_imgids) != 0: 130 | print('Warning: train_image_ids is not empty') 131 | 132 | if len(val_imgids) != 0: 133 | print('Warning: val_image_ids is not empty') 134 | 135 | cPickle.dump(train_indices, open(train_indices_file, 'wb')) 136 | cPickle.dump(val_indices, open(val_indices_file, 'wb')) 137 | h_train.close() 138 | h_val.close() 139 | print("done!") 140 | -------------------------------------------------------------------------------- /tools/download.sh: -------------------------------------------------------------------------------- 1 | ## Script for downloading data 2 | 3 | # GloVe Vectors 4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip 5 | unzip data/glove.6B.zip -d data/glove 6 | # rm data/glove.6B.zip 7 | 8 | # Questions 9 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip 10 | unzip data/v2_Questions_Train_mscoco.zip -d data 
11 | # rm data/v2_Questions_Train_mscoco.zip 12 | 13 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip 14 | unzip data/v2_Questions_Val_mscoco.zip -d data 15 | # rm data/v2_Questions_Val_mscoco.zip 16 | 17 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip 18 | unzip data/v2_Questions_Test_mscoco.zip -d data 19 | # rm data/v2_Questions_Test_mscoco.zip 20 | 21 | # Annotations 22 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip 23 | unzip data/v2_Annotations_Train_mscoco.zip -d data 24 | # rm data/v2_Annotations_Train_mscoco.zip 25 | 26 | wget -P data http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip 27 | unzip data/v2_Annotations_Val_mscoco.zip -d data 28 | # rm data/v2_Annotations_Val_mscoco.zip 29 | 30 | # Image Features 31 | wget -P data https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip 32 | unzip data/trainval_36.zip -d data 33 | # rm data/trainval_36.zip 34 | -------------------------------------------------------------------------------- /tools/download_hmqa.sh: -------------------------------------------------------------------------------- 1 | # HowManQA 2 | wget -P data https://einstein.ai/research/interpretable-counting-for-visual-question-answering/HowMany-QA.zip 3 | unzip data/HowMany-QA.zip -d data/how_many_qa 4 | 5 | # Visual Genome 6 | wget -P data https://visualgenome.org/static/data/dataset/question_answers.json.zip 7 | unzip data/question_answers.json.zip -d data/how_many_qa 8 | 9 | wget -P data https://visualgenome.org/static/data/dataset/image_data.json.zip 10 | unzip data/image_data.json.zip -d data/how_many_qa 11 | 12 | mv data/how_many_qa/question_answers.json data/how_many_qa/HowMany-QA/visual_genome_question_answers.json 13 | mv data/how_many_qa/image_data.json data/how_many_qa/HowMany-QA/visual_genome_image_data.json 14 | -------------------------------------------------------------------------------- /tools/process.sh: -------------------------------------------------------------------------------- 1 | # Process data 2 | 3 | python tools/create_dictionary.py 4 | python tools/compute_softscore.py 5 | python tools/detection_features_converter.py 6 | -------------------------------------------------------------------------------- /tools/process_hmqa.sh: -------------------------------------------------------------------------------- 1 | python tools/create_how_many_qa_dataset.py -------------------------------------------------------------------------------- /vis/00-selection_image-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/00-selection_image-335.png -------------------------------------------------------------------------------- /vis/00-selection_image-364.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/00-selection_image-364.png -------------------------------------------------------------------------------- /vis/01-selection_image-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/01-selection_image-335.png -------------------------------------------------------------------------------- 
/vis/01-selection_image-364.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/01-selection_image-364.png -------------------------------------------------------------------------------- /vis/02-selection_image-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/02-selection_image-335.png -------------------------------------------------------------------------------- /vis/02-selection_image-364.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/02-selection_image-364.png -------------------------------------------------------------------------------- /vis/03-selection_image-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/03-selection_image-335.png -------------------------------------------------------------------------------- /vis/04-selection_image-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/04-selection_image-335.png -------------------------------------------------------------------------------- /vis/image_candidates-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/image_candidates-335.png -------------------------------------------------------------------------------- /vis/image_candidates-364.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/image_candidates-364.png -------------------------------------------------------------------------------- /vis/orig_image-335.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/orig_image-335.png -------------------------------------------------------------------------------- /vis/orig_image-364.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanyam5/irlc-vqa-counting/a00bac7a5e2df12ed695e5128437299ea678f0d3/vis/orig_image-364.png --------------------------------------------------------------------------------