├── .gitignore ├── .gitmodules ├── README.md ├── data └── .keep ├── demo_server.py ├── demo_web ├── css │ ├── bootstrap-theme.css │ ├── bootstrap-theme.css.map │ ├── bootstrap-theme.min.css │ ├── bootstrap-theme.min.css.map │ ├── bootstrap.css │ ├── bootstrap.css.map │ ├── bootstrap.min.css │ ├── bootstrap.min.css.map │ └── custom.css ├── fonts │ ├── glyphicons-halflings-regular.eot │ ├── glyphicons-halflings-regular.svg │ ├── glyphicons-halflings-regular.ttf │ ├── glyphicons-halflings-regular.woff │ └── glyphicons-halflings-regular.woff2 ├── images │ └── logos.png ├── index.html └── js │ ├── bootstrap.js │ ├── bootstrap.min.js │ ├── custom.js │ └── npm.js ├── doc ├── mutan.png ├── mutan_noatt.html ├── mutan_noatt.png ├── mutan_noatt_vs_att.html ├── mutan_noatt_vs_att.png └── vqa_task.png ├── eval_res.py ├── extract.py ├── logs └── .keep ├── options ├── vqa │ ├── default.yaml │ ├── mlb_att_trainval.yaml │ ├── mlb_noatt_train.yaml │ ├── mutan_att_trainval.yaml │ └── mutan_noatt_train.yaml └── vqa2 │ ├── default.yaml │ ├── mlb_att_trainval.yaml │ ├── mlb_noatt_train.yaml │ ├── mutan_att_train.yaml │ ├── mutan_att_trainval.yaml │ ├── mutan_att_trainval_vg.yaml │ └── mutan_noatt_train.yaml ├── requirements.txt ├── train.py ├── visu.ipynb ├── visu.py └── vqa ├── __init__.py ├── datasets ├── __init__.py ├── coco.py ├── features.py ├── images.py ├── utils.py ├── vgenome.py ├── vgenome_interim.py ├── vgenome_processed.py ├── vqa.py ├── vqa2_interim.py ├── vqa_interim.py └── vqa_processed.py ├── lib ├── __init__.py ├── criterions.py ├── dataloader.py ├── engine.py ├── logger.py ├── sampler.py └── utils.py └── models ├── __init__.py ├── att.py ├── convnets.py ├── fusion.py ├── noatt.py ├── seq2vec.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/coco 2 | data/porting 3 | data/skip-thoughts 4 | data/vqa 5 | data/vqa2 6 | 7 | logs/porting 8 | logs/vqa_old 9 | logs/vqa 10 | logs/vqa2 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # Mac 18 | .DS_Store 19 | ._.DS_Store 20 | 21 | .ipynb_checkpoints 22 | 23 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "vqa/external/VQA"] 2 | path = vqa/external/VQA 3 | url = https://github.com/Cadene/VQA.git 4 | [submodule "vqa/external/skip-thoughts.torch"] 5 | path = vqa/external/skip-thoughts.torch 6 | url = https://github.com/Cadene/skip-thoughts.torch.git 7 | [submodule "vqa/external/pretrained-models.pytorch"] 8 | path = vqa/external/pretrained-models.pytorch 9 | url = https://github.com/Cadene/pretrained-models.pytorch.git 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Question Answering in pytorch 2 | 3 | **/!\ New version of pytorch for VQA available here:** https://github.com/Cadene/block.bootstrap.pytorch 4 | 5 | This repo was made by [Remi Cadene](http://remicadene.com) (LIP6) and [Hedi Ben-Younes](https://twitter.com/labegne) (LIP6-Heuritech), two PhD Students working on VQA at [UPMC-LIP6](http://lip6.fr) and their professors [Matthieu Cord](http://webia.lip6.fr/~cord) (LIP6) and [Nicolas Thome](http://webia.lip6.fr/~thomen) (LIP6-CNAM). 
We developed this code as part of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676) which is (as far as we know) the current state-of-the-art on the [VQA 1.0 dataset](http://visualqa.org). 6 | 7 | The goal of this repo is twofold: 8 | - to make it easier to reproduce our results, 9 | - to provide an efficient and modular code base to the community for further research on other VQA datasets. 10 | 11 | If you have any questions about our code or model, don't hesitate to contact us or to submit any issues. Pull requests are welcome! 12 | 13 | #### News: 14 | 15 | - 16th January 2018: a pretrained vqa2 model and web demo 16 | - 18th July 2017: VQA2, VisualGenome, FBResnet152 (for pytorch) added [v2.0 commit msg](https://github.com/Cadene/vqa.pytorch/commit/42391fd4a39c31e539eb6cb73ecd370bac0f010a) 17 | - 16th July 2017: paper accepted at ICCV2017 18 | - 30th May 2017: poster accepted at CVPR2017 (VQA Workshop) 19 | 20 | #### Summary: 21 | 22 | * [Introduction](#introduction) 23 | * [What is the task about?](#what-is-the-task-about) 24 | * [Quick insight about our method](#quick-insight-about-our-method) 25 | * [Installation](#installation) 26 | * [Requirements](#requirements) 27 | * [Submodules](#submodules) 28 | * [Data](#data) 29 | * [Reproducing results on VQA 1.0](#reproducing-results-on-vqa-10) 30 | * [Features](#features) 31 | * [Pretrained models](#pretrained-models) 32 | * [Reproducing results on VQA 2.0](#reproducing-results-on-vqa-20) 33 | * [Features](#features-20) 34 | * [Pretrained models](#pretrained-models-20) 35 | * [Documentation](#documentation) 36 | * [Architecture](#architecture) 37 | * [Options](#options) 38 | * [Datasets](#datasets) 39 | * [Models](#models) 40 | * [Quick examples](#quick-examples) 41 | * [Extract features from COCO](#extract-features-from-coco) 42 | * [Extract features from VisualGenome](#extract-features-from-visualgenome) 43 | * [Train models on VQA 1.0](#train-models-on-vqa-10) 44 | * [Train models on VQA 2.0](#train-models-on-vqa-20) 45 | * [Train models on VQA + VisualGenome](#train-models-on-vqa-10-or-20--visualgenome) 46 | * [Monitor training](#monitor-training) 47 | * [Restart training](#restart-training) 48 | * [Evaluate models on VQA](#evaluate-models-on-vqa) 49 | * [Web demo](#web-demo) 50 | * [Citation](#citation) 51 | * [Acknowledgment](#acknowledgment) 52 | 53 | ## Introduction 54 | 55 | ### What is the task about? 56 | 57 | The task consists in training models in an end-to-end fashion on a multimodal dataset made of triplets: 58 | 59 | - an **image** with no other information than the raw pixels, 60 | - a **question** about visual content(s) in the associated image, 61 | - a short **answer** to the question (one or a few words). 62 | 63 | As you can see in the illustration below, two different triplets (sharing the same image) from the VQA dataset are represented. The models need to learn rich multimodal representations to be able to give the right answers. 64 | 65 |

66 | ![Two VQA triplets sharing the same image](doc/vqa_task.png) 67 |

68 | 69 | The VQA task is still an active research topic. However, once it is solved, it could be very useful to improve human-to-machine interfaces (especially for blind people). 70 | 71 | ### Quick insight about our method 72 | 73 | The VQA community developed an approach based on four learnable components: 74 | 75 | - a question model which can be an LSTM, a GRU, or pretrained Skipthoughts, 76 | - an image model which can be a pretrained VGG16 or ResNet-152, 77 | - a fusion scheme which can be an element-wise sum, concatenation, [MCB](https://arxiv.org/abs/1606.01847), [MLB](https://arxiv.org/abs/1610.04325), or [Mutan](https://arxiv.org/abs/1705.06676), 78 | - optionally, an attention scheme which may have several "glimpses". 79 | 80 |
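To make the fusion component concrete, here is a minimal, self-contained sketch of the rank-R fusion idea behind Mutan (an element-wise product of projected question and image vectors, summed over R rank-1 terms). It only illustrates the principle; it is not the repo's `vqa/models/fusion.py` module, and the dimensions simply mirror the ones used in the Mutan attention yaml options:

```
import torch
import torch.nn as nn

class ToyMutanFusion(nn.Module):
    """Toy rank-R bilinear fusion: z = sum_r (W_q^r q) * (W_v^r v)."""
    def __init__(self, dim_q=2400, dim_v=2048, dim_mm=510, R=5):
        super().__init__()
        # One linear projection per rank and per modality.
        self.q_proj = nn.ModuleList([nn.Linear(dim_q, dim_mm) for _ in range(R)])
        self.v_proj = nn.ModuleList([nn.Linear(dim_v, dim_mm) for _ in range(R)])

    def forward(self, q, v):
        # Element-wise product of the projected modalities, summed over the R ranks.
        z = sum(wq(q) * wv(v) for wq, wv in zip(self.q_proj, self.v_proj))
        return torch.tanh(z)

q = torch.randn(4, 2400)    # question embeddings (e.g. from skip-thoughts)
v = torch.randn(4, 2048)    # image features (e.g. from ResNet-152)
z = ToyMutanFusion()(q, v)  # fused representation of shape (4, 510)
```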

81 | ![Overview of the MUTAN model](doc/mutan.png) 82 |

83 | 84 | One of our claims is that the multimodal fusion between the image and the question representations is a critical component. Thus, our proposed model uses a Tucker Decomposition of the correlation Tensor to model richer multimodal interactions in order to provide proper answers. Our best model is based on: 85 | 86 | - a pretrained Skipthoughts for the question model, 87 | - features from a pretrained Resnet-152 (with images of size 3x448x448) for the image model, 88 | - our proposed Mutan (based on a Tucker Decomposition) for the fusion scheme, 89 | - an attention scheme with two "glimpses". 90 | 91 | ## Installation 92 | 93 | ### Requirements 94 | 95 | First install python 3 (we don't provide support for python 2). We advise you to install python 3 and pytorch with Anaconda: 96 | 97 | - [python with anaconda](https://www.continuum.io/downloads) 98 | - [pytorch with CUDA](http://pytorch.org) 99 | 100 | ``` 101 | conda create --name vqa python=3 102 | source activate vqa 103 | conda install pytorch torchvision cuda80 -c soumith 104 | ``` 105 | 106 | Then clone the repo (with the `--recursive` flag for submodules) and install the complementary requirements: 107 | 108 | ``` 109 | cd $HOME 110 | git clone --recursive https://github.com/Cadene/vqa.pytorch.git 111 | cd vqa.pytorch 112 | pip install -r requirements.txt 113 | ``` 114 | 115 | ### Submodules 116 | 117 | Our code has three external dependencies: 118 | 119 | - [VQA](https://github.com/Cadene/VQA) is used to evaluate results files on the valset with the OpenEnded accuracy, 120 | - [skip-thoughts.torch](https://github.com/Cadene/skip-thoughts.torch) is used to import pretrained GRUs and embeddings, 121 | - [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) is used to load pretrained convnets. 122 | 123 | ### Data 124 | 125 | Data will be automatically downloaded and preprocessed when needed. Links to data are stored in `vqa/datasets/vqa.py`, `vqa/datasets/coco.py` and `vqa/datasets/vgenome.py`. 126 | 127 | 128 | ## Reproducing results on VQA 1.0 129 | 130 | ### Features 131 | 132 | As we first developed this project in Lua/Torch7, we used the features of [ResNet-152 pretrained with Torch7](https://github.com/facebook/fb.resnet.torch). We ported this pretrained resnet152 to pytorch in the v2.0 release. We will provide all the extracted features soon. Meanwhile, you can download the coco features as follows: 133 | 134 | ``` 135 | mkdir -p data/coco/extract/arch,fbresnet152torch 136 | cd data/coco/extract/arch,fbresnet152torch 137 | wget https://data.lip6.fr/coco/trainset.hdf5 138 | wget https://data.lip6.fr/coco/trainset.txt 139 | wget https://data.lip6.fr/coco/valset.hdf5 140 | wget https://data.lip6.fr/coco/valset.txt 141 | wget https://data.lip6.fr/coco/testset.hdf5 142 | wget https://data.lip6.fr/coco/testset.txt 143 | ``` 144 | 145 | /!\ There are currently 3 versions of ResNet152: 146 | 147 | - fbresnet152torch, which is the original torch7 model, 148 | - fbresnet152, which is the torch7 model ported to pytorch, 149 | - [resnet152](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), which is the pretrained model from torchvision (we get lower results with it). 
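Whether you download them or extract them yourself with `extract.py`, the features end up as plain hdf5 files plus a text file listing the image names in the same row order. Below is a minimal sketch of how to look up the feature of a single image, assuming the layout written by `extract.py` (a `noatt` dataset of shape (N, 2048) and/or an `att` dataset of shape (N, 2048, 14, 14)); the helper function and the image name are only examples:

```
import h5py

def load_noatt_feature(split_prefix, image_name):
    # The .txt file lists the image names in the same order as the hdf5 rows.
    with open(split_prefix + '.txt') as handle:
        names = [line.strip() for line in handle]
    index = names.index(image_name)
    with h5py.File(split_prefix + '.hdf5', 'r') as hdf5_file:
        return hdf5_file['noatt'][index]  # numpy array of shape (2048,)

feat = load_noatt_feature('data/coco/extract/arch,fbresnet152torch/trainset',
                          'COCO_train2014_000000000009.jpg')
print(feat.shape)
```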
150 | 151 | ### Pretrained VQA models 152 | 153 | We currently provide three models trained with our old Torch7 code and ported to Pytorch: 154 | 155 | - MutanNoAtt trained on the VQA 1.0 trainset, 156 | - MLBAtt trained on the VQA 1.0 trainvalset and VisualGenome, 157 | - MutanAtt trained on the VQA 1.0 trainvalset and VisualGenome. 158 | 159 | ``` 160 | mkdir -p logs/vqa 161 | cd logs/vqa 162 | wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mutan_noatt_train.zip 163 | wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mlb_att_trainval.zip 164 | wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mutan_att_trainval.zip 165 | ``` 166 | 167 | Although we provide the results files associated with our pretrained models, you can evaluate them once again on the valset, testset and testdevset using a single command: 168 | 169 | ``` 170 | python train.py -e --path_opt options/vqa/mutan_noatt_train.yaml --resume ckpt 171 | python train.py -e --path_opt options/vqa/mlb_att_trainval.yaml --resume ckpt 172 | python train.py -e --path_opt options/vqa/mutan_att_trainval.yaml --resume ckpt 173 | ``` 174 | 175 | To obtain test and testdev results on VQA 1.0, you will need to zip your results json file (name it `results.zip`) and to submit it on the [evaluation server](https://competitions.codalab.org/competitions/6961). 176 | 177 | 178 | ## Reproducing results on VQA 2.0 179 | 180 | ### Features 2.0 181 | 182 | You must download the coco dataset (and visual genome if needed) and then extract the features with a convolutional neural network. 183 | 184 | ### Pretrained VQA models 2.0 185 | 186 | We currently provide two models trained with our current pytorch code on VQA 2.0: 187 | 188 | - MutanAtt trained on the trainset with the fbresnet152 features, 189 | - MutanAtt trained on the trainvalset with the fbresnet152 features. 190 | 191 | ``` 192 | cd $VQAPYTORCH 193 | mkdir -p logs/vqa2 194 | cd logs/vqa2 195 | wget http://data.lip6.fr/cadene/vqa.pytorch/vqa2/mutan_att_train.zip 196 | wget http://data.lip6.fr/cadene/vqa.pytorch/vqa2/mutan_att_trainval.zip 197 | ``` 198 | 199 | ## Documentation 200 | 201 | ### Architecture 202 | 203 | ``` 204 | . 205 | ├── options # default options dir containing yaml files 206 | ├── logs # experiments dir containing directories of logs (one per experiment) 207 | ├── data # datasets directories 208 | | ├── coco # images and features 209 | | ├── vqa # raw, interim and processed data 210 | | ├── vgenome # raw, interim, processed data + images and features 211 | | └── ... 212 | ├── vqa # vqa package dir 213 | | ├── datasets # datasets classes & functions dir (vqa, coco, vgenome, images, features, etc.) 214 | | ├── external # submodules dir (VQA, skip-thoughts.torch, pretrained-models.pytorch) 215 | | ├── lib # misc classes & func dir (engine, logger, dataloader, etc.) 216 | | └── models # models classes & func dir (att, fusion, noatt, seq2vec, convnets) 217 | | 218 | ├── train.py # train & eval models 219 | ├── eval_res.py # eval results files with OpenEnded metric 220 | ├── extract.py # extract features from coco with CNNs 221 | └── visu.py # visualize logs and monitor training 222 | ``` 223 | 224 | ### Options 225 | 226 | There are three kinds of options: 227 | 228 | - options from the yaml options files stored in the `options` directory which are used as default (path to directory, logs, model, features, etc.) 
229 | - options from the ArgumentParser in the `train.py` file which are set to None and can overwrite the default options (learning rate, batch size, etc.) 230 | - options from the ArgumentParser in the `train.py` file which are set to default values (print frequency, number of threads, resume model, evaluate model, etc.) 231 | 232 | You can easily add new options in your custom yaml file if needed. Also, if you want to grid search a parameter, you can add an ArgumentParser option and modify the dictionary in `train.py:L80`. 233 | 234 | ### Datasets 235 | 236 | We currently provide four datasets: 237 | 238 | - [COCOImages](http://mscoco.org/) currently used to extract features, it comes with three splits: trainset, valset and testset 239 | - [VisualGenomeImages]() currently used to extract features, it comes with one split: trainset 240 | - [VQA 1.0](http://www.visualqa.org/vqa_v1_download.html) comes with four splits: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (concatenation of trainset and valset) 241 | - [VQA 2.0](http://www.visualqa.org) same as VQA 1.0, but twice as big (it uses the same images) 242 | 243 | We plan to add: 244 | 245 | - [CLEVR](http://cs.stanford.edu/people/jcjohns/clevr/) 246 | 247 | ### Models 248 | 249 | We currently provide four models: 250 | 251 | - MLBNoAtt: a strong baseline (BayesianGRU + Element-wise product) 252 | - [MLBAtt](https://arxiv.org/abs/1610.04325): the previous state-of-the-art, which adds an attention strategy 253 | - MutanNoAtt: our proof of concept (BayesianGRU + Mutan Fusion) 254 | - MutanAtt: the current state-of-the-art 255 | 256 | We plan to add several other strategies in the future. 257 | 258 | ## Quick examples 259 | 260 | ### Extract features from COCO 261 | 262 | The needed images will be automatically downloaded to `dir_data` and the features will be extracted with a resnet152 by default. 263 | 264 | There are three options for `mode`: 265 | 266 | - `att`: features will be of size 2048x14x14, 267 | - `noatt`: features will be of size 2048, 268 | - `both`: default option. 269 | 270 | Beware, you will need some space on your SSD: 271 | 272 | - 32GB for the images, 273 | - 125GB for the train features, 274 | - 123GB for the test features, 275 | - 61GB for the val features. 276 | 277 | ``` 278 | python extract.py -h 279 | python extract.py --dir_data data/coco --data_split train 280 | python extract.py --dir_data data/coco --data_split val 281 | python extract.py --dir_data data/coco --data_split test 282 | ``` 283 | 284 | Note: by default our code will share computations over all available GPUs. If you want to select only one or a few of them, use the following prefix: 285 | 286 | ``` 287 | CUDA_VISIBLE_DEVICES=0 python extract.py 288 | CUDA_VISIBLE_DEVICES=1,2 python extract.py 289 | ``` 290 | 291 | ### Extract features from VisualGenome 292 | 293 | Same here, but only the train split is available: 294 | 295 | ``` 296 | python extract.py --dataset vgenome --dir_data data/vgenome --data_split train 297 | ``` 298 | 299 | 300 | ### Train models on VQA 1.0 301 | 302 | Display the help message, display the selected options, and run with the default options. The needed data will be automatically downloaded and processed using the options in `options/vqa/default.yaml`. 303 | 304 | ``` 305 | python train.py -h 306 | python train.py --help_opt 307 | python train.py 308 | ``` 309 | 310 | Run a MutanNoAtt model with default options. 
311 | 312 | ``` 313 | python train.py --path_opt options/vqa/mutan_noatt_train.yaml --dir_logs logs/vqa/mutan_noatt_train 314 | ``` 315 | 316 | Run a MutanAtt model on the trainset and evaluate it on the valset after each epoch. 317 | 318 | ``` 319 | python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att_trainval.yaml 320 | ``` 321 | 322 | Run a MutanAtt model on the trainset and valset (by default) and run through the testset after each epoch (this produces a results file that you can submit to the evaluation server). 323 | 324 | ``` 325 | python train.py --vqa_trainsplit trainval --path_opt options/vqa/mutan_att_trainval.yaml 326 | ``` 327 | 328 | ### Train models on VQA 2.0 329 | 330 | See the options in [vqa2/mutan_att_trainval](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval.yaml): 331 | 332 | ``` 333 | python train.py --path_opt options/vqa2/mutan_att_trainval.yaml 334 | ``` 335 | 336 | ### Train models on VQA (1.0 or 2.0) + VisualGenome 337 | 338 | See the options in [vqa2/mutan_att_trainval_vg](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval_vg.yaml): 339 | 340 | ``` 341 | python train.py --path_opt options/vqa2/mutan_att_trainval_vg.yaml 342 | ``` 343 | 344 | ### Monitor training 345 | 346 | Create a visualization of an experiment using `plotly` to monitor the training, just like the picture below (**click the image to access the html/js file**): 347 | 348 |

349 | [![Training curves of a MutanNoAtt experiment](doc/mutan_noatt.png)](doc/mutan_noatt.html) 350 | 351 | 352 |

353 | 354 | Note that you have to wait until the first open-ended accuracy has been computed; the html file is then created and pops up in your default browser. The html is refreshed every 60 seconds, but you currently need to press F5 in your browser to see the change. 355 | 356 | ``` 357 | python visu.py --dir_logs logs/vqa/mutan_noatt 358 | ``` 359 | 360 | Create a visualization of multiple experiments to compare or monitor them, like the picture below (**click the image to access the html/js file**): 361 | 362 |

363 | [![Comparison of MutanNoAtt and MutanAtt experiments](doc/mutan_noatt_vs_att.png)](doc/mutan_noatt_vs_att.html) 364 | 365 | 366 |

367 | 368 | ``` 369 | python visu.py --dir_logs logs/vqa/mutan_noatt,logs/vqa/mutan_att 370 | ``` 371 | 372 | 373 | 374 | ### Restart training 375 | 376 | Restart training from the last checkpoint. 377 | 378 | ``` 379 | python train.py --path_opt options/vqa/mutan_noatt_train.yaml --dir_logs logs/vqa/mutan_noatt --resume ckpt 380 | ``` 381 | 382 | Restart training from the best checkpoint. 383 | 384 | ``` 385 | python train.py --path_opt options/vqa/mutan_noatt_train.yaml --dir_logs logs/vqa/mutan_noatt --resume best 386 | ``` 387 | 388 | ### Evaluate models on VQA 389 | 390 | Evaluate the model from the best checkpoint. If your model has been trained on the training set only (`vqa_trainsplit=train`), the model will be evaluated on the valset and will run through the testset. If it was trained on the trainset + valset (`vqa_trainsplit=trainval`), it will not be evaluated on the valset. 391 | 392 | ``` 393 | python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att_trainval.yaml --dir_logs logs/vqa/mutan_att --resume best -e 394 | ``` 395 | 396 | ### Web demo 397 | 398 | You must set your local ip address and port in `demo_server.py` line 169 and your global ip address and port in `demo_web/js/custom.js` line 51. 399 | The port associated with the global ip address must redirect to your local ip address. 400 | 401 | Launch your API (an example request is given at the end of this README): 402 | ``` 403 | CUDA_VISIBLE_DEVICES=0 python demo_server.py 404 | ``` 405 | 406 | Open `demo_web/index.html` in your browser to access the API through a human interface. 407 | 408 | ## Citation 409 | 410 | Please cite the arXiv paper if you use Mutan in your work: 411 | 412 | ``` 413 | @article{benyounescadene2017mutan, 414 | author = {Hedi Ben-Younes and 415 | R{\'{e}}mi Cad{\`{e}}ne and 416 | Nicolas Thome and 417 | Matthieu Cord}, 418 | title = {MUTAN: Multimodal Tucker Fusion for Visual Question Answering}, 419 | journal = {ICCV}, 420 | year = {2017}, 421 | url = {http://arxiv.org/abs/1705.06676} 422 | } 423 | ``` 424 | 425 | ## Acknowledgment 426 | 427 | Special thanks to the authors of [MLB](https://arxiv.org/abs/1610.04325) for providing some [Torch7 code](https://github.com/jnhwkim/MulLowBiVQA), [MCB](https://arxiv.org/abs/1606.01847) for providing some [Caffe code](https://github.com/akirafukui/vqa-mcb), and our professors and friends from LIP6 for the perfect working atmosphere. 
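#### Appendix: querying the demo API

The demo server (`demo_server.py`) expects a POST request with two form fields, `visual` (a base64 data-URL of the image) and `question`, and answers with a JSON object containing the top answers (`ans`), their scores (`val`) and the attention maps (`att`) as base64 PNG strings. Below is a minimal request sketch using the `requests` library (not listed in `requirements.txt`), assuming the default local address and port set in `demo_server.py`; the image path is only an example:

```
import base64
import requests  # assumed to be installed; any HTTP client works

with open('my_image.jpg', 'rb') as handle:  # any image file on disk
    visual = 'data:image/jpeg;base64,' + base64.b64encode(handle.read()).decode()

# The address and port must match the ones set in demo_server.py (line 169).
response = requests.post('http://192.168.0.32:3456',
                         data={'visual': visual, 'question': 'What is on the table?'})
answer = response.json()
print(answer['ans'])  # top-5 answers
print(answer['val'])  # associated scores
```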
428 | -------------------------------------------------------------------------------- /data/.keep: -------------------------------------------------------------------------------- 1 | .empty -------------------------------------------------------------------------------- /demo_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import yaml 4 | import json 5 | import argparse 6 | import re 7 | import base64 8 | import torch 9 | from torch.autograd import Variable 10 | from PIL import Image 11 | from io import BytesIO 12 | from pprint import pprint 13 | 14 | from werkzeug.wrappers import Request, Response 15 | from werkzeug.serving import run_simple 16 | 17 | import torchvision.transforms as transforms 18 | import vqa.lib.utils as utils 19 | import vqa.datasets as datasets 20 | import vqa.models as models 21 | import vqa.models.convnets as convnets 22 | from vqa.datasets.vqa_processed import tokenize_mcb 23 | from train import load_checkpoint 24 | 25 | parser = argparse.ArgumentParser( 26 | description='Demo server', 27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument('--dir_logs', type=str, 29 | #default='logs/vqa2/blocmutan_noatt_fbresnet152torchported_save_all', 30 | default='logs/vqa2/mutan_att_train', 31 | help='dir logs') 32 | parser.add_argument('--path_opt', type=str, 33 | #default='logs/vqa2/blocmutan_noatt_fbresnet152torchported_save_all/blocmutan_noatt.yaml', 34 | default='logs/vqa2/mutan_att_train/mutan_att_train.yaml', 35 | help='path to a yaml options file') 36 | parser.add_argument('--resume', type=str, 37 | default='best', 38 | help='path to latest checkpoint') 39 | parser.add_argument('--cuda', type=bool, 40 | const=True, 41 | nargs='?', 42 | help='path to latest checkpoint') 43 | 44 | @Request.application 45 | def application(request): 46 | print('') 47 | if 'visual' in request.form and 'question' in request.form: 48 | visual = process_visual(request.form['visual']) 49 | question = process_question(request.form['question']) 50 | answer = process_answer(model(visual, question)) 51 | response = Response(answer) 52 | 53 | elif 'question' not in request.form: 54 | response = Response('Question missing') 55 | 56 | elif 'visual' not in request.form: 57 | response = Response('Image missing') 58 | 59 | else: 60 | response = Response('what?') 61 | 62 | response.headers.add('Access-Control-Allow-Origin', '*') 63 | response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,PATCH') 64 | response.headers.add('Access-Control-Allow-Headers', 'Content-Type, Authorization') 65 | response.headers.add('X-XSS-Protection', '0') 66 | return response 67 | 68 | def process_visual(visual_strb64): 69 | visual_strb64 = re.sub('^data:image/.+;base64,', '', visual_strb64) 70 | visual_PIL = Image.open(BytesIO(base64.b64decode(visual_strb64))) 71 | visual_tensor = transform(visual_PIL) 72 | visual_data = torch.FloatTensor(1, 3, 73 | visual_tensor.size(1), 74 | visual_tensor.size(2)) 75 | visual_data[0][0] = visual_tensor[0] 76 | visual_data[0][1] = visual_tensor[1] 77 | visual_data[0][2] = visual_tensor[2] 78 | print('visual', visual_data.size(), visual_data.mean()) 79 | if args.cuda: 80 | visual_data = visual_data.cuda(async=True) 81 | visual_input = Variable(visual_data, volatile=True) 82 | visual_features = cnn(visual_input) 83 | if 'NoAtt' in options['model']['arch']: 84 | nb_regions = visual_features.size(2) * visual_features.size(3) 85 | visual_features = 
visual_features.sum(3).sum(2).div(nb_regions).view(-1, 2048) 86 | return visual_features 87 | 88 | def process_question(question_str): 89 | question_tokens = tokenize_mcb(question_str) 90 | question_data = torch.LongTensor(1, len(question_tokens)) 91 | for i, word in enumerate(question_tokens): 92 | if word in trainset.word_to_wid: 93 | question_data[0][i] = trainset.word_to_wid[word] 94 | else: 95 | question_data[0][i] = trainset.word_to_wid['UNK'] 96 | if args.cuda: 97 | question_data = question_data.cuda(async=True) 98 | question_input = Variable(question_data, volatile=True) 99 | print('question', question_str, question_tokens, question_data) 100 | 101 | return question_input 102 | 103 | def process_answer(answer_var): 104 | answer_sm = torch.nn.functional.softmax(answer_var.data[0].cpu()) 105 | max_, aid = answer_sm.topk(5, 0, True, True) 106 | ans = [] 107 | val = [] 108 | for i in range(5): 109 | ans.append(trainset.aid_to_ans[aid.data[i]]) 110 | val.append(max_.data[i]) 111 | 112 | att = [] 113 | for x_att in model.list_att: 114 | img = x_att.view(1,14,14).cpu() 115 | img = transforms.ToPILImage()(img) 116 | buffer_ = BytesIO() 117 | img.save(buffer_, format="PNG") 118 | img_str = base64.b64encode(buffer_.getvalue()).decode() 119 | img_str = 'data:image/png;base64,'+img_str 120 | att.append(img_str) 121 | 122 | answer = {'ans':ans,'val':val,'att':att} 123 | answer_str = json.dumps(answer) 124 | 125 | return answer_str 126 | 127 | def main(): 128 | global args, options, model, cnn, transform, trainset 129 | args = parser.parse_args() 130 | 131 | options = { 132 | 'logs': { 133 | 'dir_logs': args.dir_logs 134 | } 135 | } 136 | if args.path_opt is not None: 137 | with open(args.path_opt, 'r') as handle: 138 | options_yaml = yaml.load(handle) 139 | options = utils.update_values(options, options_yaml) 140 | print('## args'); pprint(vars(args)) 141 | print('## options'); pprint(options) 142 | 143 | trainset = datasets.factory_VQA(options['vqa']['trainsplit'], 144 | options['vqa']) 145 | #options['coco']) 146 | 147 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 148 | std=[0.229, 0.224, 0.225]) 149 | transform = transforms.Compose([ 150 | transforms.Scale(options['coco']['size']), 151 | transforms.CenterCrop(options['coco']['size']), 152 | transforms.ToTensor(), 153 | normalize, 154 | ]) 155 | 156 | opt_factory_cnn = { 157 | 'arch': options['coco']['arch'] 158 | } 159 | cnn = convnets.factory(opt_factory_cnn, cuda=args.cuda, data_parallel=False) 160 | model = models.factory(options['model'], 161 | trainset.vocab_words(), 162 | trainset.vocab_answers(), 163 | cuda=args.cuda, 164 | data_parallel=False) 165 | model.eval() 166 | start_epoch, best_acc1, _ = load_checkpoint(model, None, 167 | os.path.join(options['logs']['dir_logs'], args.resume)) 168 | 169 | my_local_ip = '192.168.0.32' 170 | my_local_port = 3456 171 | run_simple(my_local_ip, my_local_port, application) 172 | 173 | if __name__ == '__main__': 174 | main() 175 | 176 | -------------------------------------------------------------------------------- /demo_web/css/custom.css: -------------------------------------------------------------------------------- 1 | .btn-file { 2 | position: relative; 3 | overflow: hidden; 4 | } 5 | .btn-file input[type=file] { 6 | position: absolute; 7 | top: 0; 8 | right: 0; 9 | min-width: 100%; 10 | min-height: 100%; 11 | font-size: 100px; 12 | text-align: right; 13 | filter: alpha(opacity=0); 14 | opacity: 0; 15 | outline: none; 16 | background: white; 17 | cursor: inherit; 18 | 
display: block; 19 | } 20 | 21 | #vqa-visual{ 22 | width: 300px; 23 | } -------------------------------------------------------------------------------- /demo_web/fonts/glyphicons-halflings-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.eot -------------------------------------------------------------------------------- /demo_web/fonts/glyphicons-halflings-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.ttf -------------------------------------------------------------------------------- /demo_web/fonts/glyphicons-halflings-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.woff -------------------------------------------------------------------------------- /demo_web/fonts/glyphicons-halflings-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.woff2 -------------------------------------------------------------------------------- /demo_web/images/logos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/images/logos.png -------------------------------------------------------------------------------- /demo_web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | VQA Demo MUTAN 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 24 | 25 | 26 | 27 | 54 | 55 | 56 |
57 |
58 |

Welcome to the demo page

59 |

The VQA task is about training models in an end-to-end fashion on a multimodal dataset made of triplets: (image, question, answer). We recently proposed a new model called MUTAN, based on Neural Networks and the Tucker Decomposition, to address this machine learning problem. Please try out our latest model :)

60 | 61 |

Best regards. Remi Cadene, Hedi Ben-Younes, Matthieu Cord and Nicolas Thome.

62 |

63 | Github » 64 | Paper » 65 |

66 |
67 |
68 | 69 | 90 | 91 | 92 | 93 |
94 |

API

95 | 96 |
97 |

98 |
99 |
100 | 101 |
102 | 103 |
104 | 105 | Browse... 106 | 107 | 108 | 109 |
110 | 111 |
112 |
113 |
114 |
115 | 116 | 117 |
118 | 121 |
122 |
123 |
124 | 125 |
MUTAN is waiting for your question.
126 |
127 |
128 |
129 | 130 |
131 | 159 | 160 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /demo_web/js/custom.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function () { 2 | 3 | // Image 4 | 5 | $(document).on('change', '.btn-file :file', function() { 6 | var input = $(this), 7 | label = input.val().replace(/\\/g, '/').replace(/.*\//, ''); 8 | input.trigger('fileselect', [label]); 9 | }); 10 | 11 | $('.btn-file :file').on('fileselect', function(event, label) { 12 | 13 | var input = $(this).parents('.input-group').find(':text'), 14 | log = label; 15 | 16 | if( input.length ) { 17 | input.val(log); 18 | } else { 19 | if( log ) alert(log); 20 | } 21 | 22 | }); 23 | function readURL(input) { 24 | if (input.files && input.files[0]) { 25 | var reader = new FileReader(); 26 | 27 | reader.onload = function (e) { 28 | $('#vqa-visual').attr('src', e.target.result); 29 | } 30 | 31 | reader.readAsDataURL(input.files[0]); 32 | } 33 | } 34 | 35 | $("#imgInp").change(function(){ 36 | readURL(this); 37 | }); 38 | 39 | // Send Image + Question 40 | 41 | var formBasic = function () { 42 | var formData = $("#formBasic").serialize(); 43 | var data = { visual : $('#vqa-visual').attr('src'), 44 | question : $('#vqa-question').val()} 45 | console.log(data); 46 | $.ajax({ 47 | 48 | type: 'post', 49 | data: data, 50 | dataType: 'json', 51 | url: 'http://', // your global ip address and port 52 | 53 | // error: function () { 54 | // alert("There was an error processing this page."); 55 | // return false; 56 | // }, 57 | 58 | complete: function (output) { 59 | //console.log(output); 60 | //console.log(output.responseText); 61 | var ul = $(''); 62 | for (i=0; i < output.responseJSON.ans.length; i++) 63 | { 64 | var li = $('
  • '); 65 | var span = $(''); 66 | 67 | span.text(output.responseJSON.ans[i]+' ('+output.responseJSON.val[i]+')'); 68 | 69 | li.append(span); 70 | ul.append(li); 71 | } 72 | 73 | for (i=0; i < output.responseJSON.att.length; i++) 74 | { 75 | var img = $(''); 76 | 77 | ul.append(img); 78 | } 79 | 80 | $('#vqa-answer').append(ul); 81 | } 82 | }); 83 | return false; 84 | }; 85 | 86 | $("#basic-submit").on("click", function (e) { 87 | e.preventDefault(); 88 | formBasic(); 89 | }); 90 | }); -------------------------------------------------------------------------------- /demo_web/js/npm.js: -------------------------------------------------------------------------------- 1 | // This file is autogenerated via the `commonjs` Grunt task. You can require() this file in a CommonJS environment. 2 | require('../../js/transition.js') 3 | require('../../js/alert.js') 4 | require('../../js/button.js') 5 | require('../../js/carousel.js') 6 | require('../../js/collapse.js') 7 | require('../../js/dropdown.js') 8 | require('../../js/modal.js') 9 | require('../../js/tooltip.js') 10 | require('../../js/popover.js') 11 | require('../../js/scrollspy.js') 12 | require('../../js/tab.js') 13 | require('../../js/affix.js') -------------------------------------------------------------------------------- /doc/mutan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/mutan.png -------------------------------------------------------------------------------- /doc/mutan_noatt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/mutan_noatt.png -------------------------------------------------------------------------------- /doc/mutan_noatt_vs_att.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/mutan_noatt_vs_att.png -------------------------------------------------------------------------------- /doc/vqa_task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/vqa_task.png -------------------------------------------------------------------------------- /eval_res.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import os 5 | from os.path import join 6 | import sys 7 | #import pickle 8 | helperDir = 'vqa/external/VQA/' 9 | sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(helperDir)) 10 | sys.path.insert(0, '%s/PythonEvaluationTools/vqaEvaluation' %(helperDir)) 11 | from vqa import VQA 12 | from vqaEval import VQAEval 13 | 14 | 15 | if __name__=="__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dir_vqa', type=str, default='/local/cadene/data/vqa') 18 | parser.add_argument('--dir_epoch', type=str, default='logs/16_12_13_20:39:55/epoch,1') 19 | parser.add_argument('--subtype', type=str, default='train2014') 20 | args = parser.parse_args() 21 | 22 | diranno = join(args.dir_vqa, 'raw', 'annotations') 23 | annFile = join(diranno, 'mscoco_%s_annotations.json' % (args.subtype)) 24 | quesFile = join(diranno, 'OpenEnded_mscoco_%s_questions.json' % (args.subtype)) 25 | vqa = VQA(annFile, 
quesFile) 26 | 27 | taskType = 'OpenEnded' 28 | dataType = 'mscoco' 29 | dataSubType = args.subtype 30 | resultType = 'model' 31 | fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType'] 32 | 33 | [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = \ 34 | ['%s/%s_%s_%s_%s_%s.json' % (args.dir_epoch, taskType, dataType, 35 | dataSubType, resultType, fileType) for fileType in fileTypes] 36 | vqaRes = vqa.loadRes(resFile, quesFile) 37 | vqaEval = VQAEval(vqa, vqaRes, n=2) 38 | 39 | quesIds = [int(d['question_id']) for d in json.loads(open(resFile).read())] 40 | vqaEval.evaluate(quesIds=quesIds) 41 | 42 | json.dump(vqaEval.accuracy, open(accuracyFile, 'w')) 43 | #json.dump(vqaEval.evalQA, open(evalQAFile, 'w')) 44 | #json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w')) 45 | #json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w')) -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import h5py 5 | import numpy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | 16 | import vqa.models.convnets as convnets 17 | import vqa.datasets as datasets 18 | from vqa.lib.dataloader import DataLoader 19 | from vqa.lib.logger import AvgMeter 20 | 21 | parser = argparse.ArgumentParser(description='Extract') 22 | parser.add_argument('--dataset', default='coco', 23 | choices=['coco', 'vgenome'], 24 | help='dataset type: coco (default) | vgenome') 25 | parser.add_argument('--dir_data', default='data/coco', 26 | help='dir dataset to download or/and load images') 27 | parser.add_argument('--data_split', default='train', type=str, 28 | help='Options: (default) train | val | test') 29 | parser.add_argument('--arch', '-a', default='fbresnet152', 30 | choices=convnets.model_names, 31 | help='model architecture: ' + 32 | ' | '.join(convnets.model_names) + 33 | ' (default: fbresnet152)') 34 | parser.add_argument('--workers', default=4, type=int, 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--batch_size', '-b', default=80, type=int, 37 | help='mini-batch size (default: 80)') 38 | parser.add_argument('--mode', default='both', type=str, 39 | help='Options: att | noatt | (default) both') 40 | parser.add_argument('--size', default=448, type=int, 41 | help='Image size (448 for noatt := avg pooling to get 224) (default:448)') 42 | 43 | 44 | def main(): 45 | global args 46 | args = parser.parse_args() 47 | 48 | print("=> using pre-trained model '{}'".format(args.arch)) 49 | model = convnets.factory({'arch':args.arch}, cuda=True, data_parallel=True) 50 | 51 | extract_name = 'arch,{}_size,{}'.format(args.arch, args.size) 52 | 53 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 54 | std=[0.229, 0.224, 0.225]) 55 | 56 | if args.dataset == 'coco': 57 | if 'coco' not in args.dir_data: 58 | raise ValueError('"coco" string not in dir_data') 59 | dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data), 60 | transform=transforms.Compose([ 61 | transforms.Scale(args.size), 62 | transforms.CenterCrop(args.size), 63 | transforms.ToTensor(), 64 | normalize, 65 | ])) 66 | elif args.dataset == 'vgenome': 67 | if args.data_split != 'train': 68 | 
raise ValueError('train split is required for vgenome') 69 | if 'vgenome' not in args.dir_data: 70 | raise ValueError('"vgenome" string not in dir_data') 71 | dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data), 72 | transform=transforms.Compose([ 73 | transforms.Scale(args.size), 74 | transforms.CenterCrop(args.size), 75 | transforms.ToTensor(), 76 | normalize, 77 | ])) 78 | 79 | data_loader = DataLoader(dataset, 80 | batch_size=args.batch_size, shuffle=False, 81 | num_workers=args.workers, pin_memory=True) 82 | 83 | dir_extract = os.path.join(args.dir_data, 'extract', extract_name) 84 | path_file = os.path.join(dir_extract, args.data_split + 'set') 85 | os.system('mkdir -p ' + dir_extract) 86 | 87 | extract(data_loader, model, path_file, args.mode) 88 | 89 | 90 | def extract(data_loader, model, path_file, mode): 91 | path_hdf5 = path_file + '.hdf5' 92 | path_txt = path_file + '.txt' 93 | hdf5_file = h5py.File(path_hdf5, 'w') 94 | 95 | # estimate output shapes 96 | output = model(Variable(torch.ones(1, 3, args.size, args.size), 97 | volatile=True)) 98 | 99 | nb_images = len(data_loader.dataset) 100 | if mode == 'both' or mode == 'att': 101 | shape_att = (nb_images, output.size(1), output.size(2), output.size(3)) 102 | print('Warning: shape_att={}'.format(shape_att)) 103 | hdf5_att = hdf5_file.create_dataset('att', shape_att, 104 | dtype='f')#, compression='gzip') 105 | if mode == 'both' or mode == 'noatt': 106 | shape_noatt = (nb_images, output.size(1)) 107 | print('Warning: shape_noatt={}'.format(shape_noatt)) 108 | hdf5_noatt = hdf5_file.create_dataset('noatt', shape_noatt, 109 | dtype='f')#, compression='gzip') 110 | 111 | model.eval() 112 | 113 | batch_time = AvgMeter() 114 | data_time = AvgMeter() 115 | begin = time.time() 116 | end = time.time() 117 | 118 | idx = 0 119 | for i, input in enumerate(data_loader): 120 | input_var = Variable(input['visual'], volatile=True) 121 | output_att = model(input_var) 122 | 123 | nb_regions = output_att.size(2) * output_att.size(3) 124 | output_noatt = output_att.sum(3).sum(2).div(nb_regions).view(-1, 2048) 125 | 126 | batch_size = output_att.size(0) 127 | if mode == 'both' or mode == 'att': 128 | hdf5_att[idx:idx+batch_size] = output_att.data.cpu().numpy() 129 | if mode == 'both' or mode == 'noatt': 130 | hdf5_noatt[idx:idx+batch_size] = output_noatt.data.cpu().numpy() 131 | idx += batch_size 132 | 133 | torch.cuda.synchronize() 134 | batch_time.update(time.time() - end) 135 | end = time.time() 136 | 137 | if i % 1 == 0: 138 | print('Extract: [{0}/{1}]\t' 139 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 140 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format( 141 | i, len(data_loader), 142 | batch_time=batch_time, 143 | data_time=data_time,)) 144 | 145 | hdf5_file.close() 146 | 147 | # Saving image names in the same order than extraction 148 | with open(path_txt, 'w') as handle: 149 | for name in data_loader.dataset.dataset.imgs: 150 | handle.write(name + '\n') 151 | 152 | end = time.time() - begin 153 | print('Finished in {}m and {}s'.format(int(end/60), int(end%60))) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /logs/.keep: -------------------------------------------------------------------------------- 1 | .empty -------------------------------------------------------------------------------- /options/vqa/default.yaml: -------------------------------------------------------------------------------- 1 
| logs: 2 | dir_logs: logs/vqa/default 3 | vqa: 4 | dataset: VQA 5 | dir: data/vqa 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152torch 16 | mode: noatt 17 | model: 18 | arch: MLBNoAtt 19 | seq2vec: 20 | arch: skipthoughts 21 | dir_st: data/skip-thoughts 22 | type: UniSkip 23 | dropout: 0.25 24 | fixed_emb: False 25 | fusion: 26 | dim_v: 2048 27 | dim_q: 2400 28 | dim_h: 1200 29 | dropout_v: 0.5 30 | dropout_q: 0.5 31 | activation_v: tanh 32 | activation_q: tanh 33 | classif: 34 | activation: tanh 35 | dropout: 0.5 36 | optim: 37 | lr: 0.0001 38 | batch_size: 512 39 | epochs: 100 40 | -------------------------------------------------------------------------------- /options/vqa/mlb_att_trainval.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa/mlb_att_trainval 3 | vqa: 4 | dataset: VQA 5 | dir: data/vqa 6 | trainsplit: trainval 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152torch 16 | mode: att 17 | model: 18 | arch: MLBAtt 19 | dim_v: 2048 20 | dim_q: 2400 21 | seq2vec: 22 | arch: skipthoughts 23 | dir_st: data/skip-thoughts 24 | type: BayesianUniSkip 25 | dropout: 0.25 26 | fixed_emb: False 27 | attention: 28 | nb_glimpses: 4 29 | dim_h: 1200 30 | dropout_v: 0.5 31 | dropout_q: 0.5 32 | dropout_mm: 0.5 33 | activation_v: tanh 34 | activation_q: tanh 35 | activation_mm: tanh 36 | fusion: 37 | dim_h: 1200 38 | dropout_v: 0.5 39 | dropout_q: 0.5 40 | activation_v: tanh 41 | activation_q: tanh 42 | classif: 43 | activation: tanh 44 | dropout: 0.5 45 | optim: 46 | lr: 0.0001 47 | batch_size: 128 48 | epochs: 100 49 | -------------------------------------------------------------------------------- /options/vqa/mlb_noatt_train.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa/mlb_noatt_train 3 | vqa: 4 | dataset: VQA 5 | dir: data/vqa 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152torch 16 | mode: noatt 17 | model: 18 | arch: MLBNoAtt 19 | seq2vec: 20 | arch: skipthoughts 21 | dir_st: data/skip-thoughts 22 | type: BayesianUniSkip 23 | dropout: 0.25 24 | fixed_emb: False 25 | fusion: 26 | dim_v: 2048 27 | dim_q: 2400 28 | dim_h: 1200 29 | dropout_v: 0.5 30 | dropout_q: 0.5 31 | activation_v: tanh 32 | activation_q: tanh 33 | classif: 34 | activation: tanh 35 | dropout: 0.5 36 | optim: 37 | lr: 0.0001 38 | batch_size: 512 39 | epochs: 100 40 | -------------------------------------------------------------------------------- /options/vqa/mutan_att_trainval.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa/mutan_att_trainval 3 | vqa: 4 | dataset: VQA 5 | dir: data/vqa 6 | trainsplit: trainval 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152 16 | mode: att 17 | size: 448 18 | model: 19 | arch: MutanAtt 20 | dim_v: 2048 21 | dim_q: 2400 22 | seq2vec: 23 | arch: skipthoughts 24 | dir_st: data/skip-thoughts 25 | type: BayesianUniSkip 26 | dropout: 0.25 27 | fixed_emb: False 28 | attention: 29 | nb_glimpses: 2 30 | dim_hv: 310 31 
| dim_hq: 310 32 | dim_mm: 510 33 | R: 5 34 | dropout_v: 0.5 35 | dropout_q: 0.5 36 | dropout_mm: 0.5 37 | activation_v: tanh 38 | activation_q: tanh 39 | dropout_hv: 0 40 | dropout_hq: 0 41 | fusion: 42 | dim_hv: 620 43 | dim_hq: 310 44 | dim_mm: 510 45 | R: 5 46 | dropout_v: 0.5 47 | dropout_q: 0.5 48 | activation_v: tanh 49 | activation_q: tanh 50 | dropout_hv: 0 51 | dropout_hq: 0 52 | classif: 53 | dropout: 0.5 54 | optim: 55 | lr: 0.0001 56 | batch_size: 128 57 | epochs: 100 58 | -------------------------------------------------------------------------------- /options/vqa/mutan_noatt_train.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa/mutan_noatt_train 3 | vqa: 4 | dataset: VQA 5 | dir: data/vqa 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152torch 16 | mode: noatt 17 | model: 18 | arch: MutanNoAtt 19 | seq2vec: 20 | arch: skipthoughts 21 | dir_st: data/skip-thoughts 22 | type: BayesianUniSkip 23 | dropout: 0.25 24 | fixed_emb: False 25 | fusion: 26 | dim_v: 2048 27 | dim_q: 2400 28 | dim_hv: 360 29 | dim_hq: 360 30 | dim_mm: 360 31 | R: 10 32 | dropout_v: 0.5 33 | dropout_q: 0.5 34 | activation_v: tanh 35 | activation_q: tanh 36 | dropout_hv: 0 37 | dropout_hq: 0 38 | classif: 39 | dropout: 0.5 40 | optim: 41 | lr: 0.0001 42 | batch_size: 512 43 | epochs: 100 44 | -------------------------------------------------------------------------------- /options/vqa2/default.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/default 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152 16 | mode: noatt 17 | size: 448 18 | model: 19 | arch: MLBNoAtt 20 | seq2vec: 21 | arch: skipthoughts 22 | dir_st: data/skip-thoughts 23 | type: UniSkip 24 | dropout: 0.25 25 | fixed_emb: False 26 | fusion: 27 | dim_v: 2048 28 | dim_q: 2400 29 | dim_h: 1200 30 | dropout_v: 0.5 31 | dropout_q: 0.5 32 | activation_v: tanh 33 | activation_q: tanh 34 | classif: 35 | activation: tanh 36 | dropout: 0.5 37 | optim: 38 | lr: 0.0001 39 | batch_size: 512 40 | epochs: 100 41 | -------------------------------------------------------------------------------- /options/vqa2/mlb_att_trainval.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/mlb_att_trainval 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: trainval 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152 16 | mode: att 17 | size: 448 18 | model: 19 | arch: MLBAtt 20 | dim_v: 2048 21 | dim_q: 2400 22 | seq2vec: 23 | arch: skipthoughts 24 | dir_st: data/skip-thoughts 25 | type: BayesianUniSkip 26 | dropout: 0.25 27 | fixed_emb: False 28 | attention: 29 | nb_glimpses: 4 30 | dim_h: 1200 31 | dropout_v: 0.5 32 | dropout_q: 0.5 33 | dropout_mm: 0.5 34 | activation_v: tanh 35 | activation_q: tanh 36 | activation_mm: tanh 37 | fusion: 38 | dim_h: 1200 39 | dropout_v: 0.5 40 | dropout_q: 0.5 41 | activation_v: tanh 42 | activation_q: tanh 43 | classif: 44 | activation: tanh 45 | dropout: 0.5 46 | optim: 47 | lr: 0.0001 48 | batch_size: 128 49 | epochs: 
100 50 | -------------------------------------------------------------------------------- /options/vqa2/mlb_noatt_train.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/mlb_noatt_train 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152 16 | mode: noatt 17 | size: 448 18 | model: 19 | arch: MLBNoAtt 20 | seq2vec: 21 | arch: skipthoughts 22 | dir_st: data/skip-thoughts 23 | type: BayesianUniSkip 24 | dropout: 0.25 25 | fixed_emb: False 26 | fusion: 27 | dim_v: 2048 28 | dim_q: 2400 29 | dim_h: 1200 30 | dropout_v: 0.5 31 | dropout_q: 0.5 32 | activation_v: tanh 33 | activation_q: tanh 34 | classif: 35 | activation: tanh 36 | dropout: 0.5 37 | optim: 38 | lr: 0.0001 39 | batch_size: 512 40 | epochs: 100 41 | -------------------------------------------------------------------------------- /options/vqa2/mutan_att_train.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/mutan_att_train 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: /local/cadene/data/coco 15 | arch: fbresnet152 16 | mode: att 17 | size: 448 18 | model: 19 | arch: MutanAtt 20 | dim_v: 2048 21 | dim_q: 2400 22 | seq2vec: 23 | arch: skipthoughts 24 | dir_st: /local/cadene/data/skip-thoughts 25 | type: BayesianUniSkip 26 | dropout: 0.25 27 | fixed_emb: False 28 | attention: 29 | nb_glimpses: 2 30 | dim_hv: 310 31 | dim_hq: 310 32 | dim_mm: 510 33 | R: 5 34 | dropout_v: 0.5 35 | dropout_q: 0.5 36 | dropout_mm: 0.5 37 | activation_v: tanh 38 | activation_q: tanh 39 | dropout_hv: 0 40 | dropout_hq: 0 41 | fusion: 42 | dim_hv: 620 43 | dim_hq: 310 44 | dim_mm: 510 45 | R: 5 46 | dropout_v: 0.5 47 | dropout_q: 0.5 48 | activation_v: tanh 49 | activation_q: tanh 50 | dropout_hv: 0 51 | dropout_hq: 0 52 | classif: 53 | dropout: 0.5 54 | optim: 55 | lr: 0.0001 56 | batch_size: 128 57 | epochs: 100 58 | -------------------------------------------------------------------------------- /options/vqa2/mutan_att_trainval.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/mutan_att_trainval 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: trainval 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: /local/cadene/data/coco 15 | arch: fbresnet152 16 | mode: att 17 | size: 448 18 | model: 19 | arch: MutanAtt 20 | dim_v: 2048 21 | dim_q: 2400 22 | seq2vec: 23 | arch: skipthoughts 24 | dir_st: /local/cadene/data/skip-thoughts 25 | type: BayesianUniSkip 26 | dropout: 0.25 27 | fixed_emb: False 28 | attention: 29 | nb_glimpses: 2 30 | dim_hv: 310 31 | dim_hq: 310 32 | dim_mm: 510 33 | R: 5 34 | dropout_v: 0.5 35 | dropout_q: 0.5 36 | dropout_mm: 0.5 37 | activation_v: tanh 38 | activation_q: tanh 39 | dropout_hv: 0 40 | dropout_hq: 0 41 | fusion: 42 | dim_hv: 620 43 | dim_hq: 310 44 | dim_mm: 510 45 | R: 5 46 | dropout_v: 0.5 47 | dropout_q: 0.5 48 | activation_v: tanh 49 | activation_q: tanh 50 | dropout_hv: 0 51 | dropout_hq: 0 52 | classif: 53 | dropout: 0.5 54 | optim: 55 | lr: 0.0001 56 | batch_size: 128 57 | epochs: 100 58 | 
-------------------------------------------------------------------------------- /options/vqa2/mutan_att_trainval_vg.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/mutan_att_trainval 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: trainval 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152torchported 16 | mode: att 17 | size: 448 18 | vgenome: 19 | trainsplit: train 20 | dir: data/vgenome 21 | arch: fbresnet152 22 | mode: att 23 | size: 448 24 | nans: 2000 25 | maxlength: 26 26 | minwcount: 0 27 | nlp: mcb 28 | pad: right 29 | model: 30 | arch: MutanAtt 31 | dim_v: 2048 32 | dim_q: 2400 33 | seq2vec: 34 | arch: skipthoughts 35 | dir_st: data/skip-thoughts 36 | type: BayesianUniSkip 37 | dropout: 0.25 38 | fixed_emb: False 39 | attention: 40 | nb_glimpses: 2 41 | dim_hv: 310 42 | dim_hq: 310 43 | dim_mm: 510 44 | R: 5 45 | dropout_v: 0.5 46 | dropout_q: 0.5 47 | dropout_mm: 0.5 48 | activation_v: tanh 49 | activation_q: tanh 50 | dropout_hv: 0 51 | dropout_hq: 0 52 | fusion: 53 | dim_hv: 620 54 | dim_hq: 310 55 | dim_mm: 510 56 | R: 5 57 | dropout_v: 0.5 58 | dropout_q: 0.5 59 | activation_v: tanh 60 | activation_q: tanh 61 | dropout_hv: 0 62 | dropout_hq: 0 63 | classif: 64 | dropout: 0.5 65 | optim: 66 | lr: 0.0001 67 | batch_size: 128 68 | epochs: 100 69 | -------------------------------------------------------------------------------- /options/vqa2/mutan_noatt_train.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vqa2/mutan_noatt_train 3 | vqa: 4 | dataset: VQA2 5 | dir: data/vqa2 6 | trainsplit: train 7 | nans: 2000 8 | maxlength: 26 9 | minwcount: 0 10 | nlp: mcb 11 | pad: right 12 | samplingans: True 13 | coco: 14 | dir: data/coco 15 | arch: fbresnet152 16 | mode: noatt 17 | size: 448 18 | model: 19 | arch: MutanNoAtt 20 | seq2vec: 21 | arch: skipthoughts 22 | dir_st: data/skip-thoughts 23 | type: BayesianUniSkip 24 | dropout: 0.25 25 | fixed_emb: False 26 | fusion: 27 | dim_v: 2048 28 | dim_q: 2400 29 | dim_hv: 360 30 | dim_hq: 360 31 | dim_mm: 360 32 | R: 10 33 | dropout_v: 0.5 34 | dropout_q: 0.5 35 | activation_v: tanh 36 | activation_q: tanh 37 | dropout_hv: 0 38 | dropout_hq: 0 39 | classif: 40 | dropout: 0.5 41 | optim: 42 | lr: 0.0001 43 | batch_size: 512 44 | epochs: 100 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | h5py 4 | pyyaml 5 | colorlover 6 | plotly 7 | scipy 8 | nltk 9 | click 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import yaml 5 | import json 6 | import click 7 | from pprint import pprint 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.backends.cudnn as cudnn 12 | 13 | import vqa.lib.engine as engine 14 | import vqa.lib.utils as utils 15 | import vqa.lib.logger as logger 16 | import vqa.lib.criterions as criterions 17 | import vqa.datasets as datasets 18 | import vqa.models as models 19 | 20 | parser = argparse.ArgumentParser( 21 | description='Train/Evaluate models', 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | 
################################################## 24 | # yaml options file contains all default choices # 25 | parser.add_argument('--path_opt', default='options/vqa/default.yaml', type=str, 26 | help='path to a yaml options file') 27 | ################################################ 28 | # change cli options to modify default choices # 29 | # logs options 30 | parser.add_argument('--dir_logs', type=str, help='dir logs') 31 | # data options 32 | parser.add_argument('--vqa_trainsplit', type=str, choices=['train','trainval']) 33 | # model options 34 | parser.add_argument('--arch', choices=models.model_names, 35 | help='vqa model architecture: ' + 36 | ' | '.join(models.model_names)) 37 | parser.add_argument('--st_type', 38 | help='skipthoughts type') 39 | parser.add_argument('--st_dropout', type=float) 40 | parser.add_argument('--st_fixed_emb', default=None, type=utils.str2bool, 41 | help='backprop on embedding') 42 | # optim options 43 | parser.add_argument('-lr', '--learning_rate', type=float, 44 | help='initial learning rate') 45 | parser.add_argument('-b', '--batch_size', type=int, 46 | help='mini-batch size') 47 | parser.add_argument('--epochs', type=int, 48 | help='number of total epochs to run') 49 | # options not in yaml file 50 | parser.add_argument('--start_epoch', default=0, type=int, 51 | help='manual epoch number (useful on restarts)') 52 | parser.add_argument('--resume', default='', type=str, 53 | help='path to latest checkpoint') 54 | parser.add_argument('--save_model', default=True, type=utils.str2bool, 55 | help='able or disable save model and optim state') 56 | parser.add_argument('--save_all_from', type=int, 57 | help='''delete the preceding checkpoint until an epoch,''' 58 | ''' then keep all (useful to save disk space)')''') 59 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 60 | help='evaluate model on validation and test set') 61 | parser.add_argument('-j', '--workers', default=2, type=int, 62 | help='number of data loading workers') 63 | parser.add_argument('--print_freq', '-p', default=10, type=int, 64 | help='print frequency') 65 | ################################################ 66 | parser.add_argument('-ho', '--help_opt', dest='help_opt', action='store_true', 67 | help='show selected options before running') 68 | 69 | best_acc1 = 0 70 | 71 | def main(): 72 | global args, best_acc1 73 | args = parser.parse_args() 74 | 75 | ######################################################################################### 76 | # Create options 77 | ######################################################################################### 78 | 79 | options = { 80 | 'vqa' : { 81 | 'trainsplit': args.vqa_trainsplit 82 | }, 83 | 'logs': { 84 | 'dir_logs': args.dir_logs 85 | }, 86 | 'model': { 87 | 'arch': args.arch, 88 | 'seq2vec': { 89 | 'type': args.st_type, 90 | 'dropout': args.st_dropout, 91 | 'fixed_emb': args.st_fixed_emb 92 | } 93 | }, 94 | 'optim': { 95 | 'lr': args.learning_rate, 96 | 'batch_size': args.batch_size, 97 | 'epochs': args.epochs 98 | } 99 | } 100 | if args.path_opt is not None: 101 | with open(args.path_opt, 'r') as handle: 102 | options_yaml = yaml.load(handle) 103 | options = utils.update_values(options, options_yaml) 104 | print('## args'); pprint(vars(args)) 105 | print('## options'); pprint(options) 106 | if args.help_opt: 107 | return 108 | 109 | # Set datasets options 110 | if 'vgenome' not in options: 111 | options['vgenome'] = None 112 | 113 | 
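    # Next section: datasets.factory_VQA builds the training split (and can also fold in
    # Visual Genome question/answer pairs when the options contain a 'vgenome' section,
    # as in mutan_att_trainval_vg.yaml). With trainsplit 'train' a val loader is also
    # created for per-epoch evaluation; with trainsplit 'trainval' or --evaluate, a test
    # loader is created for inference only (OpenEnded accuracy is not computed locally
    # on the test split).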
######################################################################################### 114 | # Create needed datasets 115 | ######################################################################################### 116 | 117 | trainset = datasets.factory_VQA(options['vqa']['trainsplit'], 118 | options['vqa'], 119 | options['coco'], 120 | options['vgenome']) 121 | train_loader = trainset.data_loader(batch_size=options['optim']['batch_size'], 122 | num_workers=args.workers, 123 | shuffle=True) 124 | 125 | if options['vqa']['trainsplit'] == 'train': 126 | valset = datasets.factory_VQA('val', options['vqa'], options['coco']) 127 | val_loader = valset.data_loader(batch_size=options['optim']['batch_size'], 128 | num_workers=args.workers) 129 | 130 | if options['vqa']['trainsplit'] == 'trainval' or args.evaluate: 131 | testset = datasets.factory_VQA('test', options['vqa'], options['coco']) 132 | test_loader = testset.data_loader(batch_size=options['optim']['batch_size'], 133 | num_workers=args.workers) 134 | 135 | ######################################################################################### 136 | # Create model, criterion and optimizer 137 | ######################################################################################### 138 | 139 | model = models.factory(options['model'], 140 | trainset.vocab_words(), trainset.vocab_answers(), 141 | cuda=True, data_parallel=True) 142 | criterion = criterions.factory(options['vqa'], cuda=True) 143 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 144 | options['optim']['lr']) 145 | 146 | ######################################################################################### 147 | # args.resume: resume from a checkpoint OR create logs directory 148 | ######################################################################################### 149 | 150 | exp_logger = None 151 | if args.resume: 152 | args.start_epoch, best_acc1, exp_logger = load_checkpoint(model.module, optimizer, 153 | os.path.join(options['logs']['dir_logs'], args.resume)) 154 | else: 155 | # Or create logs directory 156 | if os.path.isdir(options['logs']['dir_logs']): 157 | if click.confirm('Logs directory already exists in {}. Erase?' 
158 | .format(options['logs']['dir_logs'], default=False)): 159 | os.system('rm -r ' + options['logs']['dir_logs']) 160 | else: 161 | return 162 | os.system('mkdir -p ' + options['logs']['dir_logs']) 163 | path_new_opt = os.path.join(options['logs']['dir_logs'], 164 | os.path.basename(args.path_opt)) 165 | path_args = os.path.join(options['logs']['dir_logs'], 'args.yaml') 166 | with open(path_new_opt, 'w') as f: 167 | yaml.dump(options, f, default_flow_style=False) 168 | with open(path_args, 'w') as f: 169 | yaml.dump(vars(args), f, default_flow_style=False) 170 | 171 | if exp_logger is None: 172 | # Set loggers 173 | exp_name = os.path.basename(options['logs']['dir_logs']) # add timestamp 174 | exp_logger = logger.Experiment(exp_name, options) 175 | exp_logger.add_meters('train', make_meters()) 176 | exp_logger.add_meters('test', make_meters()) 177 | if options['vqa']['trainsplit'] == 'train': 178 | exp_logger.add_meters('val', make_meters()) 179 | exp_logger.info['model_params'] = utils.params_count(model) 180 | print('Model has {} parameters'.format(exp_logger.info['model_params'])) 181 | 182 | ######################################################################################### 183 | # args.evaluate: on valset OR/AND on testset 184 | ######################################################################################### 185 | 186 | if args.evaluate: 187 | path_logger_json = os.path.join(options['logs']['dir_logs'], 'logger.json') 188 | 189 | if options['vqa']['trainsplit'] == 'train': 190 | acc1, val_results = engine.validate(val_loader, model, criterion, 191 | exp_logger, args.start_epoch, args.print_freq) 192 | # save results and compute OpenEnd accuracy 193 | exp_logger.to_json(path_logger_json) 194 | save_results(val_results, args.start_epoch, valset.split_name(), 195 | options['logs']['dir_logs'], options['vqa']['dir']) 196 | 197 | test_results, testdev_results = engine.test(test_loader, model, exp_logger, 198 | args.start_epoch, args.print_freq) 199 | # save results and DOES NOT compute OpenEnd accuracy 200 | exp_logger.to_json(path_logger_json) 201 | save_results(test_results, args.start_epoch, testset.split_name(), 202 | options['logs']['dir_logs'], options['vqa']['dir']) 203 | save_results(testdev_results, args.start_epoch, testset.split_name(testdev=True), 204 | options['logs']['dir_logs'], options['vqa']['dir']) 205 | return 206 | 207 | ######################################################################################### 208 | # Begin training on train/val or trainval/test 209 | ######################################################################################### 210 | 211 | for epoch in range(args.start_epoch+1, options['optim']['epochs']): 212 | #adjust_learning_rate(optimizer, epoch) 213 | 214 | # train for one epoch 215 | engine.train(train_loader, model, criterion, optimizer, 216 | exp_logger, epoch, args.print_freq) 217 | 218 | if options['vqa']['trainsplit'] == 'train': 219 | # evaluate on validation set 220 | acc1, val_results = engine.validate(val_loader, model, criterion, 221 | exp_logger, epoch, args.print_freq) 222 | # remember best prec@1 and save checkpoint 223 | is_best = acc1 > best_acc1 224 | best_acc1 = max(acc1, best_acc1) 225 | save_checkpoint({ 226 | 'epoch': epoch, 227 | 'arch': options['model']['arch'], 228 | 'best_acc1': best_acc1, 229 | 'exp_logger': exp_logger 230 | }, 231 | model.module.state_dict(), 232 | optimizer.state_dict(), 233 | options['logs']['dir_logs'], 234 | args.save_model, 235 | args.save_all_from, 236 | is_best) 237 | 
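            # save_checkpoint above writes ckpt_info/ckpt_model/ckpt_optim .pth.tar files
            # into dir_logs and copies them to best_*.pth.tar whenever val accuracy improves;
            # with --save_all_from N it writes epoch-stamped ckpt_epoch,<e>_* files instead,
            # deleting the previous epoch's files until epoch N and keeping every checkpoint
            # from epoch N onwards.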
238 | # save results and compute OpenEnd accuracy 239 | save_results(val_results, epoch, valset.split_name(), 240 | options['logs']['dir_logs'], options['vqa']['dir']) 241 | else: 242 | test_results, testdev_results = engine.test(test_loader, model, exp_logger, 243 | epoch, args.print_freq) 244 | 245 | # save checkpoint at every timestep 246 | save_checkpoint({ 247 | 'epoch': epoch, 248 | 'arch': options['model']['arch'], 249 | 'best_acc1': best_acc1, 250 | 'exp_logger': exp_logger 251 | }, 252 | model.module.state_dict(), 253 | optimizer.state_dict(), 254 | options['logs']['dir_logs'], 255 | args.save_model, 256 | args.save_all_from) 257 | 258 | # save results and DOES NOT compute OpenEnd accuracy 259 | save_results(test_results, epoch, testset.split_name(), 260 | options['logs']['dir_logs'], options['vqa']['dir']) 261 | save_results(testdev_results, epoch, testset.split_name(testdev=True), 262 | options['logs']['dir_logs'], options['vqa']['dir']) 263 | 264 | 265 | def make_meters(): 266 | meters_dict = { 267 | 'loss': logger.AvgMeter(), 268 | 'acc1': logger.AvgMeter(), 269 | 'acc5': logger.AvgMeter(), 270 | 'batch_time': logger.AvgMeter(), 271 | 'data_time': logger.AvgMeter(), 272 | 'epoch_time': logger.SumMeter() 273 | } 274 | return meters_dict 275 | 276 | def save_results(results, epoch, split_name, dir_logs, dir_vqa): 277 | dir_epoch = os.path.join(dir_logs, 'epoch_' + str(epoch)) 278 | name_json = 'OpenEnded_mscoco_{}_model_results.json'.format(split_name) 279 | # TODO: simplify formating 280 | if 'test' in split_name: 281 | name_json = 'vqa_' + name_json 282 | path_rslt = os.path.join(dir_epoch, name_json) 283 | os.system('mkdir -p ' + dir_epoch) 284 | with open(path_rslt, 'w') as handle: 285 | json.dump(results, handle) 286 | if not 'test' in split_name: 287 | os.system('python2 eval_res.py --dir_vqa {} --dir_epoch {} --subtype {} &' 288 | .format(dir_vqa, dir_epoch, split_name)) 289 | 290 | def save_checkpoint(info, model, optim, dir_logs, save_model, save_all_from=None, is_best=True): 291 | os.system('mkdir -p ' + dir_logs) 292 | if save_all_from is None: 293 | path_ckpt_info = os.path.join(dir_logs, 'ckpt_info.pth.tar') 294 | path_ckpt_model = os.path.join(dir_logs, 'ckpt_model.pth.tar') 295 | path_ckpt_optim = os.path.join(dir_logs, 'ckpt_optim.pth.tar') 296 | path_best_info = os.path.join(dir_logs, 'best_info.pth.tar') 297 | path_best_model = os.path.join(dir_logs, 'best_model.pth.tar') 298 | path_best_optim = os.path.join(dir_logs, 'best_optim.pth.tar') 299 | # save info & logger 300 | path_logger = os.path.join(dir_logs, 'logger.json') 301 | info['exp_logger'].to_json(path_logger) 302 | torch.save(info, path_ckpt_info) 303 | if is_best: 304 | shutil.copyfile(path_ckpt_info, path_best_info) 305 | # save model state & optim state 306 | if save_model: 307 | torch.save(model, path_ckpt_model) 308 | torch.save(optim, path_ckpt_optim) 309 | if is_best: 310 | shutil.copyfile(path_ckpt_model, path_best_model) 311 | shutil.copyfile(path_ckpt_optim, path_best_optim) 312 | else: 313 | is_best = False # because we don't know the test accuracy 314 | path_ckpt_info = os.path.join(dir_logs, 'ckpt_epoch,{}_info.pth.tar') 315 | path_ckpt_model = os.path.join(dir_logs, 'ckpt_epoch,{}_model.pth.tar') 316 | path_ckpt_optim = os.path.join(dir_logs, 'ckpt_epoch,{}_optim.pth.tar') 317 | # save info & logger 318 | path_logger = os.path.join(dir_logs, 'logger.json') 319 | info['exp_logger'].to_json(path_logger) 320 | torch.save(info, path_ckpt_info.format(info['epoch'])) 321 | # save model state & 
optim state 322 | if save_model: 323 | torch.save(model, path_ckpt_model.format(info['epoch'])) 324 | torch.save(optim, path_ckpt_optim.format(info['epoch'])) 325 | if info['epoch'] > 1 and info['epoch'] < save_all_from + 1: 326 | os.system('rm ' + path_ckpt_info.format(info['epoch'] - 1)) 327 | os.system('rm ' + path_ckpt_model.format(info['epoch'] - 1)) 328 | os.system('rm ' + path_ckpt_optim.format(info['epoch'] - 1)) 329 | if not save_model: 330 | print('Warning train.py: checkpoint not saved') 331 | 332 | def load_checkpoint(model, optimizer, path_ckpt): 333 | path_ckpt_info = path_ckpt + '_info.pth.tar' 334 | path_ckpt_model = path_ckpt + '_model.pth.tar' 335 | path_ckpt_optim = path_ckpt + '_optim.pth.tar' 336 | if os.path.isfile(path_ckpt_info): 337 | info = torch.load(path_ckpt_info) 338 | start_epoch = 0 339 | best_acc1 = 0 340 | exp_logger = None 341 | if 'epoch' in info: 342 | start_epoch = info['epoch'] 343 | else: 344 | print('Warning train.py: no epoch to resume') 345 | if 'best_acc1' in info: 346 | best_acc1 = info['best_acc1'] 347 | else: 348 | print('Warning train.py: no best_acc1 to resume') 349 | if 'exp_logger' in info: 350 | exp_logger = info['exp_logger'] 351 | else: 352 | print('Warning train.py: no exp_logger to resume') 353 | else: 354 | print("Warning train.py: no info checkpoint found at '{}'".format(path_ckpt_info)) 355 | if os.path.isfile(path_ckpt_model): 356 | model_state = torch.load(path_ckpt_model) 357 | model.load_state_dict(model_state) 358 | else: 359 | print("Warning train.py: no model checkpoint found at '{}'".format(path_ckpt_model)) 360 | if optimizer is not None and os.path.isfile(path_ckpt_optim): 361 | optim_state = torch.load(path_ckpt_optim) 362 | optimizer.load_state_dict(optim_state) 363 | else: 364 | print("Warning train.py: no optim checkpoint found at '{}'".format(path_ckpt_optim)) 365 | print("=> loaded checkpoint '{}' (epoch {}, best_acc1 {})" 366 | .format(path_ckpt, start_epoch, best_acc1)) 367 | return start_epoch, best_acc1, exp_logger 368 | 369 | if __name__ == '__main__': 370 | main() 371 | -------------------------------------------------------------------------------- /visu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import time\n", 14 | "import os\n", 15 | "import matplotlib.pyplot as plt, mpld3\n", 16 | "import seaborn as sns\n", 17 | "from vqa.lib.logger import Experiment\n", 18 | "import json\n", 19 | "import random\n", 20 | "#%matplotlib notebook\n", 21 | "#mpld3.enable_notebook()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": { 28 | "collapsed": false, 29 | "scrolled": false 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "def plot_line(data_x, data_y, xlabel=\"epochs\", ylabel=\"accuracy %\", title=\"Name\"):\n", 34 | " fig, ax = plt.subplots()\n", 35 | " ax.set_xlabel(\"epochs\")\n", 36 | " ax.set_ylabel(\"accuracy\")\n", 37 | " ax.set_title(title)\n", 38 | "\n", 39 | " lines = ax.plot(data_y, data_x, '--', marker=\"o\", label=\"lol\")\n", 40 | "\n", 41 | " mpld3.plugins.connect(fig, mpld3.plugins.LineLabelTooltip(lines[0], label=xp_name))\n", 42 | " mpld3.plugins.connect(fig, mpld3.plugins.PointLabelTooltip(lines[0]))#data_y)\n", 43 | " mpld3.plugins.connect(fig, 
mpld3.plugins.MousePosition(fontsize=14))\n", 44 | " #ax.grid(color='lightgrey', linestyle='solid')\n", 45 | " mpld3.show()\n", 46 | " \n", 47 | "def plot_acc(train_accs, val_accs, fig=None, \\\n", 48 | " xlabel=\"epochs\", ylabel=\"accuracy %\", title=\"Name\"):\n", 49 | " if fig is None:\n", 50 | " fig = plt.figure()\n", 51 | " n = 0\n", 52 | " else:\n", 53 | " n = len(fig.axes)\n", 54 | " for i in range(n):\n", 55 | " fig.axes[i].change_geometry(n+1, 1, i+1)\n", 56 | " \n", 57 | " ax = fig.add_subplot(n+1, 1, n+1)\n", 58 | " \n", 59 | " ax.set_xlabel(xlabel)\n", 60 | " ax.set_ylabel(ylabel)\n", 61 | " ax.set_title(title)\n", 62 | "\n", 63 | " data_x = range(1, len(train_accs)+1)\n", 64 | " ax.plot(data_x, train_accs, '--', marker=\"o\", label=\"train\")\n", 65 | " ax.plot(data_x, val_accs, '--', marker=\"o\", label=\"val\")\n", 66 | "\n", 67 | " line_labels = ['train', 'val']\n", 68 | " point_labels = [train_accs, val_accs]\n", 69 | " for i, line in enumerate(ax.lines):\n", 70 | " mpld3.plugins.connect(fig, \\\n", 71 | " mpld3.plugins.LineLabelTooltip(line, label=line_labels[i]))\n", 72 | " mpld3.plugins.connect(fig, \\\n", 73 | " mpld3.plugins.PointLabelTooltip(line, labels=point_labels[i]))\n", 74 | " \n", 75 | " mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fontsize=14))\n", 76 | " #ax.grid(color='lightgrey', linestyle='solid')\n", 77 | " #mpld3.show(open_browser=False)\n", 78 | " return fig\n", 79 | " \n", 80 | "def plot_accs(train_accs=[], val_accs=[], names=[], fig=None, \\\n", 81 | " xlabel=\"epochs\", ylabel=\"accuracy %\", title=\"Name\"):\n", 82 | " if fig is None:\n", 83 | " fig = plt.figure()\n", 84 | " n = 0\n", 85 | " else:\n", 86 | " n = len(fig.axes)\n", 87 | " for i in range(n):\n", 88 | " fig.axes[i].change_geometry(n+1, 1, i+1)\n", 89 | " \n", 90 | " ax = fig.add_subplot(n+1, 1, n+1)\n", 91 | " \n", 92 | " ax.set_xlabel(xlabel)\n", 93 | " ax.set_ylabel(ylabel)\n", 94 | " ax.set_title(title)\n", 95 | "\n", 96 | " max_length = max([len(acc) for acc in val_accs])\n", 97 | " data_x = range(1, max_length+1)\n", 98 | " for i in range(len(val_accs)):\n", 99 | " color = (random.random(), random.random(), random.random())\n", 100 | " data_x_train = range(1, len(train_accs[i])+1)\n", 101 | " data_x_val = range(1, len(val_accs[i])+1)\n", 102 | " ax.plot(data_x_train, train_accs[i],\n", 103 | " '--', marker=\"o\", color=color,\n", 104 | " label=names[i]+\"train\")\n", 105 | " ax.plot(data_x_val, val_accs[i],\n", 106 | " '-', marker=\"o\", color=color,\n", 107 | " label=names[i]+\"val\")\n", 108 | "\n", 109 | " point_labels = {\n", 110 | " 'train': train_accs,\n", 111 | " 'val': val_accs\n", 112 | " }\n", 113 | " for line_id in range(len(ax.lines)):\n", 114 | " label = 'train' if line_id%2==0 else 'val'\n", 115 | " #mpld3.plugins.connect(fig, \\\n", 116 | " # mpld3.plugins.LineLabelTooltip(ax.lines[line_id], label=label))\n", 117 | " mpld3.plugins.connect(fig, \\\n", 118 | " mpld3.plugins.PointLabelTooltip(ax.lines[line_id], labels=point_labels[label][int(i/2)]))\n", 119 | " \n", 120 | " ax.legend()\n", 121 | " labels = []\n", 122 | " for name in names:\n", 123 | " labels.append(name + '_train')\n", 124 | " labels.append(name + '_val')\n", 125 | " mpld3.plugins.connect(fig, mpld3.plugins.InteractiveLegendPlugin(ax.lines, labels))\n", 126 | " mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fontsize=14))\n", 127 | " #ax.grid(color='lightgrey', linestyle='solid')\n", 128 | " #mpld3.show(open_browser=False)\n", 129 | " return fig \n", 130 | "\n", 131 | "def 
load_accs_oe(path_logger):\n", 132 | " dir_xp = os.path.dirname(path_logger)\n", 133 | " epochs = []\n", 134 | " for name in os.listdir(dir_xp):\n", 135 | " if name.startswith('epoch'):\n", 136 | " epochs.append(name)\n", 137 | " epochs = sorted(epochs, key=lambda x: float(x.split('_')[1]))\n", 138 | " accs = {}\n", 139 | " for i, epoch in enumerate(epochs):\n", 140 | " epoch_id = i+1\n", 141 | " path_acc = os.path.join(dir_xp, epoch, 'OpenEnded_mscoco_val2014_model_accuracy.json')\n", 142 | " if os.path.exists(path_acc):\n", 143 | " with open(path_acc, 'r') as f:\n", 144 | " data = json.load(f)\n", 145 | " accs[epoch_id] = data['overall']\n", 146 | " return accs\n", 147 | "\n", 148 | "def sort(dict_):\n", 149 | " return [v for k,v in sorted(dict_.items(), \\\n", 150 | " key=lambda x: float(x[0]))]\n", 151 | "\n", 152 | "def reduce(list_, num=15):\n", 153 | " tmp = []\n", 154 | " for i, val in enumerate(list_):\n", 155 | " if i < num:\n", 156 | " tmp.append(val)\n", 157 | " return tmp" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": { 164 | "collapsed": false, 165 | "scrolled": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "dir_logs = 'logs/vqa'\n", 170 | "xps = []\n", 171 | "for p in sorted(os.listdir(dir_logs)):\n", 172 | " if not p.startswith('.'): print(p)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "collapsed": false 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "path_logger = os.path.join(dir_logs, 'mlb_att_samplingans,True', 'logger.json')\n", 184 | "xp = Experiment.from_json(path_logger)\n", 185 | "xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)\n", 186 | "#print(xp.logged['val']['acc1_oe'])\n", 187 | "print(xp.options)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "collapsed": false 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "# Display accuracy & loss of one exp\n", 199 | "def display_xp(path_logger):\n", 200 | " xp = Experiment.from_json(path_logger)\n", 201 | " xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)\n", 202 | "\n", 203 | " path_visu = os.path.join(os.path.dirname(path_logger),\n", 204 | " 'visu.html')\n", 205 | "\n", 206 | " val_acc1 = sort(xp.logged['val']['acc1_oe'])\n", 207 | " train_acc1 = sort(xp.logged['train']['acc1'])\n", 208 | " loss_val = sort(xp.logged['val']['loss'])\n", 209 | " loss_train = sort(xp.logged['train']['loss'])\n", 210 | "\n", 211 | " max_length = min([len(val_acc1), len(train_acc1)])\n", 212 | " val_acc1 = reduce(val_acc1, max_length)\n", 213 | " train_acc1= reduce(train_acc1, max_length)\n", 214 | " loss_val = reduce(loss_val, max_length)\n", 215 | " loss_train=reduce(loss_train, max_length)\n", 216 | " \n", 217 | " print(max(val_acc1))\n", 218 | " \n", 219 | " fig = plt.figure(figsize=(10,13))\n", 220 | " fig = plot_acc(train_acc1, val_acc1, fig=fig, \\\n", 221 | " title='max: '+str(max(val_acc1)))\n", 222 | " fig = plot_acc(loss_train, loss_val, fig=fig, \\\n", 223 | " ylabel='loss', title='min: '+str(min(loss_val)))\n", 224 | " fig.tight_layout()\n", 225 | " with open(path_visu, 'w') as f:\n", 226 | " mpld3.save_html(fig, f)\n", 227 | " \n", 228 | "#for name in ['blocmutan']:#['seq2vec/bgru', 'mutan_noatt_2', 'mlb_att_2']:\n", 229 | "for name in os.listdir(dir_logs):\n", 230 | " if not name.startswith('.'):\n", 231 | " path_logger = os.path.join(dir_logs, name, 'logger.json')\n", 232 | " print(name)\n", 233 | " 
display_xp(path_logger)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": { 240 | "collapsed": false 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "# Display accuracy & loss of multiple exp\n", 245 | "accs_train = []\n", 246 | "accs_val = []\n", 247 | "names = []\n", 248 | "acc_max = ('xp name', -1)\n", 249 | "fig = plt.figure(figsize=(20,13))\n", 250 | "#for name in os.listdir(dir_logs):#['bgu_normal', 'mutan_noatt', 'mutan_noatt_2', 'mlb_att_2', 'mlb_att_3']:\n", 251 | "#for name in ['seq2vec/bgru_dropout,25%', 'mutan_noatt', 'blocmutan_noatt', 'mlb_noatt_2lstm', 'mlb_noatt_2lstm_emb']:\n", 252 | "#for name in ['blocmutan_noatt', 'blocmutan_noatt_samplingans,True']:\n", 253 | "dir_logs = 'logs/vqa'\n", 254 | "exps = []\n", 255 | "exps.append('blocmutan_att_2')\n", 256 | "exps.append('mutan_att_2')\n", 257 | "exps.append('mlb_att_2')\n", 258 | "\n", 259 | "for name in exps:\n", 260 | " if not name.startswith('.'):\n", 261 | " path_logger = os.path.join(dir_logs, name, 'logger.json')\n", 262 | " xp = Experiment.from_json(path_logger)\n", 263 | " xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)\n", 264 | " train_acc1 = sort(xp.logged['train']['acc1'])\n", 265 | " val_acc1 = sort(xp.logged['val']['acc1_oe'])\n", 266 | "\n", 267 | " max_length = min([len(val_acc1), len(train_acc1)])\n", 268 | " val_acc1 = reduce(val_acc1, max_length)\n", 269 | " train_acc1= reduce(train_acc1, max_length)\n", 270 | "\n", 271 | " accs_val.append(val_acc1)\n", 272 | " accs_train.append(train_acc1)\n", 273 | " print(name, max(val_acc1))\n", 274 | " if max(val_acc1) > acc_max[1]:\n", 275 | " acc_max = (name, max(val_acc1))\n", 276 | " names.append(name)\n", 277 | "fig = plot_accs(accs_train, accs_val, names,\n", 278 | " title=acc_max[0]+' '+str(acc_max[1]), fig=fig)\n", 279 | "path_visu = dir_logs+'/test.html'\n", 280 | "with open(path_visu, 'w') as f:\n", 281 | " mpld3.save_html(fig, f)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": { 288 | "collapsed": false 289 | }, 290 | "outputs": [], 291 | "source": [ 292 | "# Display accuracy & loss of multiple exp\n", 293 | "accs_train = []\n", 294 | "accs_val = []\n", 295 | "names = []\n", 296 | "acc_max = ('xp name', -1)\n", 297 | "fig = plt.figure(figsize=(20,13))\n", 298 | "#for name in os.listdir(dir_logs):#['bgu_normal', 'mutan_noatt', 'mutan_noatt_2', 'mlb_att_2', 'mlb_att_3']:\n", 299 | "#for name in ['seq2vec/bgru_dropout,25%', 'mutan_noatt', 'blocmutan_noatt', 'mlb_noatt_2lstm', 'mlb_noatt_2lstm_emb']:\n", 300 | "#for name in ['blocmutan_noatt', 'blocmutan_noatt_samplingans,True']:\n", 301 | "\n", 302 | "dir_logs = 'logs/vqa'\n", 303 | "list_dir = []\n", 304 | "for name in os.listdir(dir_logs):\n", 305 | " if not name.startswith('.') and not name.endswith('.html'):\n", 306 | " list_dir.append(name)\n", 307 | " \n", 308 | "for name in list_dir:\n", 309 | " path_logger = os.path.join(dir_logs, name, 'logger.json')\n", 310 | " #xp = Experiment.from_json(path_logger)\n", 311 | " val_acc1 = load_accs_oe(path_logger)\n", 312 | " #train_acc1 = sort(xp.logged['train']['acc1'])\n", 313 | " val_acc1 = sort(val_acc1)\n", 314 | "\n", 315 | " max_length = len(val_acc1)\n", 316 | " val_acc1 = reduce(val_acc1, max_length)\n", 317 | " #train_acc1= reduce(train_acc1, max_length)\n", 318 | "\n", 319 | " accs_val.append(val_acc1)\n", 320 | " #accs_train.append(train_acc1)\n", 321 | " print(name, max(val_acc1))\n", 322 | " if max(val_acc1) > acc_max[1]:\n", 323 | 
" acc_max = (name, max(val_acc1))\n", 324 | " names.append(name)\n", 325 | "\n", 326 | "fig = plot_accs(accs_val, accs_val, names,\n", 327 | " title=acc_max[0]+' '+str(acc_max[1]), fig=fig)\n", 328 | "path_visu = dir_logs+'/test.html'\n", 329 | "with open(path_visu, 'w') as f:\n", 330 | " mpld3.save_html(fig, f)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": { 337 | "collapsed": true 338 | }, 339 | "outputs": [], 340 | "source": [] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": { 346 | "collapsed": false 347 | }, 348 | "outputs": [], 349 | "source": [ 350 | "# Test displaying accuracy of one exp\n", 351 | "val_acc1 = sort(xp.logged['val']['acc1'])\n", 352 | "train_acc1 = sort(xp.logged['train']['acc1'])\n", 353 | "loss_val = sort(xp.logged['val']['loss'])\n", 354 | "loss_train = sort(xp.logged['train']['loss'])\n", 355 | "\n", 356 | "#plot_line(val_acc1, range(1,len(val_acc1)+1))\n", 357 | "for i in range(1,40):\n", 358 | " val_acc1_tmp = reduce(val_acc1, i)\n", 359 | " train_acc1_tmp = reduce(train_acc1, i)\n", 360 | " loss_val_tmp = reduce(loss_val, i)\n", 361 | " loss_train_tmp = reduce(loss_train, i)\n", 362 | " fig = plt.figure(figsize=(10,13))\n", 363 | " fig = plot_acc(train_acc1_tmp, val_acc1_tmp, fig=fig, \\\n", 364 | " title='max: '+str(max(val_acc1_tmp)))\n", 365 | " fig = plot_acc(loss_train_tmp, loss_val_tmp, fig=fig, \\\n", 366 | " ylabel='loss', title='min: '+str(min(loss_val_tmp)))\n", 367 | " fig.tight_layout()\n", 368 | " with open('test.html', 'w') as f:\n", 369 | " mpld3.save_html(fig, f)\n", 370 | " time.sleep(3)\n", 371 | " " 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": { 378 | "collapsed": true 379 | }, 380 | "outputs": [], 381 | "source": [] 382 | } 383 | ], 384 | "metadata": { 385 | "anaconda-cloud": {}, 386 | "kernelspec": { 387 | "display_name": "Python [conda env:vqa]", 388 | "language": "python", 389 | "name": "conda-env-vqa-py" 390 | }, 391 | "language_info": { 392 | "codemirror_mode": { 393 | "name": "ipython", 394 | "version": 3 395 | }, 396 | "file_extension": ".py", 397 | "mimetype": "text/x-python", 398 | "name": "python", 399 | "nbconvert_exporter": "python", 400 | "pygments_lexer": "ipython3", 401 | "version": "3.6.0" 402 | } 403 | }, 404 | "nbformat": 4, 405 | "nbformat_minor": 1 406 | } 407 | -------------------------------------------------------------------------------- /visu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import json 5 | import random 6 | import numpy as np 7 | import colorlover as cl 8 | 9 | import plotly.plotly as py 10 | import plotly.graph_objs as go 11 | from plotly.offline import download_plotlyjs, plot 12 | from plotly import tools 13 | 14 | from vqa.lib.logger import Experiment 15 | 16 | def load_accs_oe(path_logger): 17 | dir_xp = os.path.dirname(path_logger) 18 | epochs = [] 19 | for name in os.listdir(dir_xp): 20 | if name.startswith('epoch'): 21 | epochs.append(name) 22 | epochs = sorted(epochs, key=lambda x: float(x.split('_')[1])) 23 | accs = {} 24 | for i, epoch in enumerate(epochs): 25 | epoch_id = i+1 26 | path_acc = os.path.join(dir_xp, epoch, 'OpenEnded_mscoco_val2014_model_accuracy.json') 27 | if os.path.exists(path_acc): 28 | with open(path_acc, 'r') as f: 29 | data = json.load(f) 30 | accs[epoch_id] = data['overall'] 31 | return accs 32 | 33 | def sort(dict_): 34 | 
return [v for k,v in sorted(dict_.items(), \ 35 | key=lambda x: float(x[0]))] 36 | 37 | def reduce(list_, num=15): 38 | tmp = [] 39 | for i, val in enumerate(list_): 40 | if i < num: 41 | tmp.append(val) 42 | return tmp 43 | 44 | # Display accuracy & loss of one exp 45 | def visu_one_exp(path_logger, path_visu, auto_open=True): 46 | xp = Experiment.from_json(path_logger) 47 | xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger) 48 | train_acc1 = sort(xp.logged['train']['acc1']) 49 | val_acc1 = sort(xp.logged['val']['acc1_oe']) 50 | train_loss = sort(xp.logged['train']['loss']) 51 | val_loss = sort(xp.logged['val']['loss']) 52 | train_data_x = list(range(1, len(train_acc1)+1)) 53 | val_data_x = list(range(1, len(val_acc1)+1)) 54 | 55 | fig = tools.make_subplots(rows=1, 56 | cols=2, 57 | subplot_titles=('Accuracy top1', 'Loss')) 58 | # blue rgb(31, 119, 180) 59 | # orange rgb(255, 127, 14) 60 | 61 | train_acc1_trace = go.Scatter( 62 | x=train_data_x, 63 | y=train_acc1, 64 | name='train accuracy top1' 65 | ) 66 | val_acc1_trace = go.Scatter( 67 | x=val_data_x, 68 | y=val_acc1, 69 | name='val accuracy top1', 70 | line = dict( 71 | color = ('rgb(255, 127, 14)'), 72 | ) 73 | ) 74 | best_val_acc1_trace = go.Scatter( 75 | x=[np.argmax(val_acc1)+1], 76 | y=[max(val_acc1)], 77 | mode='markers', 78 | name='best val accuracy top1', 79 | marker = dict( 80 | color = 'rgb(255, 127, 14)', 81 | size = 10 82 | ) 83 | ) 84 | 85 | val_loss_trace = go.Scatter( 86 | x=val_data_x, 87 | y=val_loss, 88 | name='val loss' 89 | ) 90 | train_loss_trace = go.Scatter( 91 | x=train_data_x, 92 | y=train_loss, 93 | name='train loss' 94 | ) 95 | 96 | fig.append_trace(train_acc1_trace, 1, 1) 97 | fig.append_trace(val_acc1_trace, 1, 1) 98 | fig.append_trace(best_val_acc1_trace, 1, 1) 99 | 100 | fig.append_trace(train_loss_trace, 1, 2) 101 | fig.append_trace(val_loss_trace, 1, 2) 102 | 103 | plot(fig, filename=path_visu, auto_open=auto_open) 104 | 105 | return train_acc1, val_acc1 106 | 107 | # Display accuracy & loss of one exp 108 | def visu_exps(list_path_logger, path_visu, auto_open=True): 109 | fig = tools.make_subplots(rows=2, 110 | cols=2, 111 | subplot_titles=('Val accuracy top1', 112 | 'Val loss', 113 | 'Train accuracy top1', 114 | 'Train loss')) 115 | num_xp = len(list_path_logger) 116 | if num_xp < 3: # cl.scales not accept 117 | num_xp = 3 118 | list_color = cl.scales[str(num_xp)]['qual']['Paired'] 119 | 120 | for i, path_logger in enumerate(list_path_logger): 121 | name = path_logger.split('/')[-2] 122 | 123 | xp = Experiment.from_json(path_logger) 124 | xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger) 125 | train_acc1 = sort(xp.logged['train']['acc1']) 126 | val_acc1 = sort(xp.logged['val']['acc1_oe']) 127 | train_loss = sort(xp.logged['train']['loss']) 128 | val_loss = sort(xp.logged['val']['loss']) 129 | train_data_x = list(range(1, len(train_acc1)+1)) 130 | val_data_x = list(range(1, len(val_acc1)+1)) 131 | 132 | train_acc1_trace = go.Scatter( 133 | x=train_data_x, 134 | y=train_acc1, 135 | name='train acc: '+name, 136 | line=dict( 137 | color=list_color[i] 138 | ) 139 | ) 140 | val_acc1_trace = go.Scatter( 141 | x=val_data_x, 142 | y=val_acc1, 143 | name='val acc: '+name, 144 | line=dict( 145 | color=list_color[i] 146 | ) 147 | ) 148 | best_val_acc1_trace = go.Scatter( 149 | x=[np.argmax(val_acc1)+1], 150 | y=[max(val_acc1)], 151 | mode='markers', 152 | name='best val acc: '+name, 153 | marker = dict( 154 | color = list_color[i], 155 | size = 10 156 | ) 157 | ) 158 | 159 | val_loss_trace = 
go.Scatter( 160 | x=val_data_x, 161 | y=val_loss, 162 | name='val loss: '+name, 163 | line=dict( 164 | color=list_color[i] 165 | ) 166 | ) 167 | train_loss_trace = go.Scatter( 168 | x=train_data_x, 169 | y=train_loss, 170 | name='train loss: '+name, 171 | line=dict( 172 | color=list_color[i] 173 | ) 174 | ) 175 | 176 | fig.append_trace(val_acc1_trace, 1, 1) 177 | fig.append_trace(best_val_acc1_trace, 1, 1) 178 | fig.append_trace(train_acc1_trace, 2, 1) 179 | 180 | fig.append_trace(val_loss_trace, 1, 2) 181 | fig.append_trace(train_loss_trace, 2, 2) 182 | 183 | plot(fig, filename=path_visu, auto_open=auto_open) 184 | 185 | def main_one_exp(dir_logs, path_visu=None, refresh_freq=60): 186 | if path_visu is None: 187 | path_visu = os.path.join(dir_logs, 'visu.html') 188 | 189 | path_logger = os.path.join(dir_logs, 'logger.json') 190 | 191 | i = 1 192 | print('Create visu to ' + path_visu) 193 | while True: 194 | train_acc1, val_acc1 = visu_one_exp(path_logger, path_visu, auto_open=(i==1)) 195 | print('# Visu iteration (refresh every {} sec): {}'.format(refresh_freq, i)) 196 | print('Max Val OpenEnded-Accuracy Top1: {}'.format(max(val_acc1))) 197 | print('Max Train Accuracy Top1: {}'.format(max(train_acc1))) 198 | i += 1 199 | time.sleep(refresh_freq) 200 | 201 | def main_exps(list_dir_logs, path_visu=None, refresh_freq=60): 202 | if path_visu is None: 203 | path_visu = os.path.join(os.path.dirname(list_dir_logs[0]), 'visu.html') 204 | 205 | list_path_logger = [] 206 | for dir_logs in list_dir_logs: 207 | list_path_logger.append(os.path.join(dir_logs, 'logger.json')) 208 | 209 | i = 1 210 | print('Create visu to ' + path_visu) 211 | while True: 212 | visu_exps(list_path_logger, path_visu, auto_open=(i==1)) 213 | print('# Visu iteration (refresh every {} sec): {}'.format(refresh_freq, i)) 214 | i += 1 215 | time.sleep(refresh_freq) 216 | 217 | 218 | ########################################################################## 219 | # Main 220 | ########################################################################## 221 | 222 | parser = argparse.ArgumentParser( 223 | description='Create html visu files', 224 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 225 | 226 | parser.add_argument('--dir_logs', type=str, 227 | help='''First mode: dir to logs of an experiment (ex: logs/vqa/mutan)''' 228 | '''Second mode: add several dirs to create a comparativ visualisation (ex: logs/vqa/mutan,logs/vqa/mlb)''') 229 | parser.add_argument('--refresh_freq', '-f', default=60, type=int, 230 | help='refresh frequency in seconds') 231 | parser.add_argument('--path_visu', default=None, 232 | help='path to the html file (default: visu.html in dir_logs)') 233 | 234 | def main(): 235 | global args 236 | args = parser.parse_args() 237 | 238 | list_dir_logs = args.dir_logs.split(',') 239 | 240 | if len(list_dir_logs) == 1: 241 | main_one_exp(args.dir_logs, 242 | path_visu=args.path_visu, 243 | refresh_freq=args.refresh_freq) 244 | else: 245 | main_exps(list_dir_logs, 246 | path_visu=args.path_visu, 247 | refresh_freq=args.refresh_freq) 248 | 249 | if __name__ == '__main__': 250 | main() 251 | -------------------------------------------------------------------------------- /vqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/vqa/__init__.py -------------------------------------------------------------------------------- /vqa/datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .vqa import factory as factory_VQA 2 | from .coco import COCOImages 3 | from .vgenome import VisualGenomeImages -------------------------------------------------------------------------------- /vqa/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | # from mpi4py import MPI 3 | import numpy as np 4 | import h5py 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | 8 | from ..lib import utils 9 | from .images import ImagesFolder, AbstractImagesDataset, default_loader 10 | from .features import FeaturesDataset 11 | 12 | def split_name(data_split): 13 | if data_split in ['train', 'val']: 14 | return data_split + '2014' 15 | elif data_split == 'test': 16 | return data_split + '2015' 17 | else: 18 | assert False, 'data_split {} not exists'.format(data_split) 19 | 20 | 21 | class COCOImages(AbstractImagesDataset): 22 | 23 | def __init__(self, data_split, opt, transform=None, loader=default_loader): 24 | self.split_name = split_name(data_split) 25 | super(COCOImages, self).__init__(data_split, opt, transform, loader) 26 | self.dir_split = self.get_dir_data() 27 | self.dataset = ImagesFolder(self.dir_split, transform=self.transform, loader=self.loader) 28 | self.name_to_index = self._load_name_to_index() 29 | 30 | def get_dir_data(self): 31 | return os.path.join(self.dir_raw, self.split_name) 32 | 33 | def _raw(self): 34 | if self.data_split in ['train', 'val']: 35 | os.system('wget http://msvocds.blob.core.windows.net/coco2014/{}.zip -P {}'.format(self.split_name, self.dir_raw)) 36 | elif self.data_split == 'test': 37 | os.system('wget http://msvocds.blob.core.windows.net/coco2015/test2015.zip -P '+self.dir_raw) 38 | else: 39 | assert False, 'data_split {} not exists'.format(self.data_split) 40 | os.system('unzip '+os.path.join(self.dir_raw, self.split_name+'.zip')+' -d '+self.dir_raw) 41 | 42 | def _load_name_to_index(self): 43 | self.name_to_index = {name:index for index, name in enumerate(self.dataset.imgs)} 44 | return self.name_to_index 45 | 46 | def __getitem__(self, index): 47 | item = self.dataset[index] 48 | item['name'] = os.path.join(self.split_name, item['name']) 49 | return item 50 | 51 | def __len__(self): 52 | return len(self.dataset) 53 | 54 | 55 | class COCOTrainval(data.Dataset): 56 | 57 | def __init__(self, trainset, valset): 58 | self.trainset = trainset 59 | self.valset = valset 60 | 61 | def __getitem__(self, index): 62 | if index < len(self.trainset): 63 | item = self.trainset[index] 64 | else: 65 | item = self.valset[index - len(self.trainset)] 66 | return item 67 | 68 | def get_by_name(self, image_name): 69 | if image_name in self.trainset.name_to_index: 70 | index = self.trainset.name_to_index[image_name] 71 | item = self.trainset[index] 72 | return item 73 | elif image_name in self.valset.name_to_index: 74 | index = self.valset.name_to_index[image_name] 75 | item = self.valset[index] 76 | return item 77 | else: 78 | raise ValueError 79 | 80 | def __len__(self): 81 | return len(self.trainset) + len(self.valset) 82 | 83 | 84 | def default_transform(size): 85 | transform = transforms.Compose([ 86 | transforms.Scale(size), 87 | transforms.CenterCrop(size), 88 | transforms.ToTensor(), 89 | transforms.Normalize(mean=[0.485, 0.456, 0.406], # resnet imagnet 90 | std=[0.229, 0.224, 0.225]) 91 | ]) 92 | return transform 93 | 94 | def factory(data_split, opt, transform=None): 95 | if data_split 
== 'trainval': 96 | trainset = factory('train', opt, transform) 97 | valset = factory('val', opt, transform) 98 | return COCOTrainval(trainset, valset) 99 | elif data_split in ['train', 'val', 'test']: 100 | if opt['mode'] == 'img': 101 | if transform is None: 102 | transform = default_transform(opt['size']) 103 | return COCOImages(data_split, opt, transform) 104 | elif opt['mode'] in ['noatt', 'att']: 105 | return FeaturesDataset(data_split, opt) 106 | else: 107 | raise ValueError 108 | else: 109 | raise ValueError 110 | -------------------------------------------------------------------------------- /vqa/datasets/features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import torch 5 | import torch.utils.data as data 6 | 7 | class FeaturesDataset(data.Dataset): 8 | 9 | def __init__(self, data_split, opt): 10 | self.data_split = data_split 11 | self.opt = opt 12 | self.dir_extract = os.path.join(self.opt['dir'], 13 | 'extract', 14 | 'arch,' + self.opt['arch']) 15 | if 'size' in opt: 16 | self.dir_extract += '_size,' + str(opt['size']) 17 | self.path_hdf5 = os.path.join(self.dir_extract, 18 | data_split + 'set.hdf5') 19 | assert os.path.isfile(self.path_hdf5), \ 20 | 'File not found in {}, you must extract the features first with extract.py'.format(self.path_hdf5) 21 | self.hdf5_file = h5py.File(self.path_hdf5, 'r')#, driver='mpio', comm=MPI.COMM_WORLD) 22 | self.dataset_features = self.hdf5_file[self.opt['mode']] 23 | self.index_to_name, self.name_to_index = self._load_dicts() 24 | 25 | def _load_dicts(self): 26 | self.path_fname = os.path.join(self.dir_extract, 27 | self.data_split + 'set.txt') 28 | with open(self.path_fname, 'r') as handle: 29 | self.index_to_name = handle.readlines() 30 | self.index_to_name = [name[:-1] for name in self.index_to_name] # remove char '\n' 31 | self.name_to_index = {name:index for index,name in enumerate(self.index_to_name)} 32 | return self.index_to_name, self.name_to_index 33 | 34 | def __getitem__(self, index): 35 | item = {} 36 | item['name'] = self.index_to_name[index] 37 | item['visual'] = self.get_features(index) 38 | #item = torch.Tensor(self.get_features(index)) 39 | return item 40 | 41 | def get_features(self, index): 42 | return torch.Tensor(self.dataset_features[index]) 43 | 44 | def get_features_old(self, index): 45 | try: 46 | self.features_array 47 | except AttributeError: 48 | if self.opt['mode'] == 'att': 49 | self.features_array = np.zeros((2048,14,14), dtype='f') 50 | elif self.opt['mode'] == 'noatt': 51 | self.features_array = np.zeros((2048), dtype='f') 52 | 53 | if self.opt['mode'] == 'att': 54 | self.dataset_features.read_direct(self.features_array, 55 | np.s_[index,:2048,:14,:14], 56 | np.s_[:2048,:14,:14]) 57 | elif self.opt['mode'] == 'noatt': 58 | self.dataset_features.read_direct(self.features_array, 59 | np.s_[index,:2048], 60 | np.s_[:2048]) 61 | return self.features_array 62 | 63 | 64 | def get_by_name(self, image_name): 65 | index = self.name_to_index[image_name] 66 | return self[index] 67 | 68 | def __len__(self): 69 | return self.dataset_features.shape[0] 70 | -------------------------------------------------------------------------------- /vqa/datasets/images.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', 
'.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | def make_dataset(dir): 16 | images = [] 17 | for fname in os.listdir(dir): 18 | if is_image_file(fname): 19 | images.append(fname) 20 | return images 21 | 22 | 23 | def default_loader(path): 24 | return Image.open(path).convert('RGB') 25 | 26 | 27 | class ImagesFolder(data.Dataset): 28 | 29 | def __init__(self, root, transform=None, loader=default_loader): 30 | imgs = make_dataset(root) 31 | if len(imgs) == 0: 32 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 33 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 34 | self.root = root 35 | self.imgs = imgs 36 | self.transform = transform 37 | self.loader = loader 38 | 39 | def __getitem__(self, index): 40 | item = {} 41 | item['name'] = self.imgs[index] 42 | item['path'] = os.path.join(self.root, item['name']) 43 | if self.loader is not None: 44 | item['visual'] = self.loader(item['path']) 45 | if self.transform is not None: 46 | item['visual'] = self.transform(item['visual']) 47 | return item 48 | 49 | def __len__(self): 50 | return len(self.imgs) 51 | 52 | 53 | class AbstractImagesDataset(data.Dataset): 54 | 55 | def __init__(self, data_split, opt, transform=None, loader=default_loader): 56 | self.data_split = data_split 57 | self.opt = opt 58 | self.transform = transform 59 | self.loader = loader 60 | self.dir_raw = os.path.join(self.opt['dir'], 'raw') 61 | 62 | if not os.path.exists(self.get_dir_data()): 63 | self._raw() 64 | 65 | def get_dir_data(self): 66 | return self.dir_raw 67 | 68 | def get_by_name(self, image_name): 69 | index = self.name_to_index[image_name] 70 | return self[index] 71 | 72 | def _raw(self): 73 | raise NotImplementedError 74 | 75 | def __getitem__(self, index): 76 | raise NotImplementedError 77 | 78 | def __len__(self): 79 | raise NotImplementedError 80 | -------------------------------------------------------------------------------- /vqa/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | import torch.utils.data as data 5 | import copy 6 | 7 | class AbstractVQADataset(data.Dataset): 8 | 9 | def __init__(self, data_split, opt, dataset_img=None): 10 | self.data_split = data_split 11 | self.opt = copy.copy(opt) 12 | self.dataset_img = dataset_img 13 | 14 | self.dir_raw = os.path.join(self.opt['dir'], 'raw') 15 | if not os.path.exists(self.dir_raw): 16 | self._raw() 17 | 18 | self.dir_interim = os.path.join(self.opt['dir'], 'interim') 19 | if not os.path.exists(self.dir_interim): 20 | self._interim() 21 | 22 | self.dir_processed = os.path.join(self.opt['dir'], 'processed') 23 | self.subdir_processed = self.subdir_processed() 24 | if not os.path.exists(self.subdir_processed): 25 | self._processed() 26 | 27 | path_wid_to_word = os.path.join(self.subdir_processed, 'wid_to_word.pickle') 28 | path_word_to_wid = os.path.join(self.subdir_processed, 'word_to_wid.pickle') 29 | path_aid_to_ans = os.path.join(self.subdir_processed, 'aid_to_ans.pickle') 30 | path_ans_to_aid = os.path.join(self.subdir_processed, 'ans_to_aid.pickle') 31 | path_dataset = os.path.join(self.subdir_processed, self.data_split+'set.pickle') 32 | 33 | with open(path_wid_to_word, 'rb') as handle: 34 | self.wid_to_word = pickle.load(handle) 35 | 36 | with open(path_word_to_wid, 'rb') as handle: 37 | self.word_to_wid = pickle.load(handle) 38 | 39 | with 
open(path_aid_to_ans, 'rb') as handle: 40 | self.aid_to_ans = pickle.load(handle) 41 | 42 | with open(path_ans_to_aid, 'rb') as handle: 43 | self.ans_to_aid = pickle.load(handle) 44 | 45 | with open(path_dataset, 'rb') as handle: 46 | self.dataset = pickle.load(handle) 47 | 48 | def _raw(self): 49 | raise NotImplementedError 50 | 51 | def _interim(self): 52 | raise NotImplementedError 53 | 54 | def _processed(self): 55 | raise NotImplementedError 56 | 57 | def __getitem__(self, index): 58 | raise NotImplementedError 59 | 60 | def subdir_processed(self): 61 | subdir = 'nans,' + str(self.opt['nans']) \ 62 | + '_maxlength,' + str(self.opt['maxlength']) \ 63 | + '_minwcount,' + str(self.opt['minwcount']) \ 64 | + '_nlp,' + self.opt['nlp'] \ 65 | + '_pad,' + self.opt['pad'] \ 66 | + '_trainsplit,' + self.opt['trainsplit'] 67 | subdir = os.path.join(self.dir_processed, subdir) 68 | return subdir -------------------------------------------------------------------------------- /vqa/datasets/vgenome.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import copy 5 | 6 | from .images import ImagesFolder, AbstractImagesDataset, default_loader 7 | from .features import FeaturesDataset 8 | from .vgenome_interim import vgenome_interim 9 | from .vgenome_processed import vgenome_processed 10 | from .coco import default_transform 11 | from .utils import AbstractVQADataset 12 | 13 | def raw(dir_raw): 14 | dir_img = os.path.join(dir_raw, 'images') 15 | os.system('wget http://visualgenome.org/static/data/dataset/image_data.json.zip -P '+dir_raw) 16 | os.system('wget http://visualgenome.org/static/data/dataset/question_answers.json.zip -P '+dir_raw) 17 | os.system('wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -P '+dir_raw) 18 | os.system('wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -P '+dir_raw) 19 | 20 | os.system('unzip '+os.path.join(dir_raw, 'image_data.json.zip')+' -d '+dir_raw) 21 | os.system('unzip '+os.path.join(dir_raw, 'question_answers.json.zip')+' -d '+dir_raw) 22 | os.system('unzip '+os.path.join(dir_raw, 'images.zip')+' -d '+dir_raw) 23 | os.system('unzip '+os.path.join(dir_raw, 'images2.zip')+' -d '+dir_raw) 24 | 25 | os.system('mv '+os.path.join(dir_raw, 'VG_100K')+' '+dir_img) 26 | 27 | #os.system('mv '+os.path.join(dir_raw, 'VG_100K_2', '*.jpg')+' '+dir_img) 28 | os.system('find '+os.path.join(dir_raw, 'VG_100K_2')+' -type f -name \'*\' -exec mv {} '+dir_img+' \\;') 29 | os.system('rm -rf '+os.path.join(dir_raw, 'VG_100K_2')) 30 | 31 | # remove images with 0 octet in a ugly but efficient way :') 32 | #print('for f in $(ls -lh '+dir_img+' | grep " 0 " | cut -s -f14 --delimiter=" "); do rm '+dir_img+'/${f}; done;') 33 | os.system('for f in $(ls -lh '+dir_img+' | grep " 0 " | cut -s -f14 --delimiter=" "); do echo '+dir_img+'/${f}; done;') 34 | os.system('for f in $(ls -lh '+dir_img+' | grep " 0 " | cut -s -f14 --delimiter=" "); do rm '+dir_img+'/${f}; done;') 35 | 36 | 37 | class VisualGenome(AbstractVQADataset): 38 | 39 | def __init__(self, data_split, opt, dataset_img=None): 40 | super(VisualGenome, self).__init__(data_split, opt, dataset_img) 41 | 42 | def __getitem__(self, index): 43 | item_qa = self.dataset[index] 44 | item = {} 45 | if self.dataset_img is not None: 46 | item_img = self.dataset_img.get_by_name(item_qa['image_name']) 47 | item['visual'] = item_img['visual'] 48 | # DEBUG 49 | #item['visual_debug'] = item_qa['image_name'] 50 | 
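        # The item dict mirrors AbstractVQA.__getitem__ in vqa.py (visual / question /
        # question_id / answer), so Visual Genome question/answer pairs can be consumed
        # by the same training loader as the VQA examples (see the options['vgenome']
        # argument passed to factory_VQA in train.py).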
item['question'] = torch.LongTensor(item_qa['question_wids']) 51 | # DEBUG 52 | #item['question_debug'] = item_qa['question'] 53 | item['question_id'] = item_qa['question_id'] 54 | item['answer'] = item_qa['answer_aid'] 55 | # DEBUG 56 | #item['answer_debug'] = item_qa['answer'] 57 | return item 58 | 59 | def _raw(self): 60 | raw(self.dir_raw) 61 | 62 | def _interim(self): 63 | vgenome_interim(self.opt) 64 | 65 | def _processed(self): 66 | vgenome_processed(self.opt) 67 | 68 | def __len__(self): 69 | return len(self.dataset) 70 | 71 | 72 | class VisualGenomeImages(AbstractImagesDataset): 73 | 74 | def __init__(self, data_split, opt, transform=None, loader=default_loader): 75 | super(VisualGenomeImages, self).__init__(data_split, opt, transform, loader) 76 | self.dir_img = os.path.join(self.dir_raw, 'images') 77 | self.dataset = ImagesFolder(self.dir_img, transform=self.transform, loader=self.loader) 78 | self.name_to_index = self._load_name_to_index() 79 | 80 | def _raw(self): 81 | raw(self.dir_raw) 82 | 83 | def _load_name_to_index(self): 84 | self.name_to_index = {name:index for index, name in enumerate(self.dataset.imgs)} 85 | return self.name_to_index 86 | 87 | def __getitem__(self, index): 88 | item = self.dataset[index] 89 | return item 90 | 91 | def __len__(self): 92 | return len(self.dataset) 93 | 94 | 95 | def factory(opt, vqa=False, transform=None): 96 | 97 | if vqa: 98 | dataset_img = factory(opt, vqa=False, transform=transform) 99 | return VisualGenome('train', opt, dataset_img) 100 | 101 | if opt['mode'] == 'img': 102 | if transform is None: 103 | transform = default_transform(opt['size']) 104 | 105 | elif opt['mode'] in ['noatt', 'att']: 106 | return FeaturesDataset('train', opt) 107 | 108 | else: 109 | raise ValueError 110 | 111 | 112 | -------------------------------------------------------------------------------- /vqa/datasets/vgenome_interim.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | 5 | # def get_image_path(subtype='train2014', image_id='1', format='%s/COCO_%s_%012d.jpg'): 6 | # return format%(subtype, subtype, image_id) 7 | 8 | def interim(questions_annotations): 9 | data = [] 10 | for i in range(len(questions_annotations)): 11 | qa_img = questions_annotations[i] 12 | qa_img_id = qa_img['id'] 13 | for j in range(len(qa_img['qas'])): 14 | qa = qa_img['qas'][j] 15 | row = {} 16 | row['question_id'] = qa['qa_id'] 17 | row['image_id'] = qa_img_id 18 | row['image_name'] = str(qa_img_id) + '.jpg' 19 | row['question'] = qa['question'] 20 | row['answer'] = qa['answer'] 21 | data.append(row) 22 | return data 23 | 24 | def vgenome_interim(params): 25 | ''' 26 | Put the VisualGenomme VQA data into single json file in data/interim 27 | or train, val, trainval : [[question_id, image_id, question, answer] ... 
] 28 | ''' 29 | path_qa = os.path.join(params['dir'], 'interim', 'questions_annotations.json') 30 | os.system('mkdir -p ' + os.path.join(params['dir'], 'interim')) 31 | 32 | print('Loading annotations and questions...') 33 | questions_annotations = json.load(open(os.path.join(params['dir'], 'raw', 'question_answers.json'), 'r')) 34 | 35 | data = interim(questions_annotations) 36 | print('Questions number %d'%len(data)) 37 | print('Write', path_qa) 38 | json.dump(data, open(path_qa, 'w')) 39 | 40 | if __name__ == "__main__": 41 | 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--dir_vg', default='data/visualgenome', type=str, help='Path to visual genome data directory') 44 | args = parser.parse_args() 45 | params = vars(args) 46 | vgenome_interim(params) 47 | -------------------------------------------------------------------------------- /vqa/datasets/vgenome_processed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess an interim json data files 3 | into one preprocess hdf5/json data files. 4 | Caption: Use nltk, or mcb, or split function to get tokens. 5 | """ 6 | from random import shuffle, seed 7 | import sys 8 | import os.path 9 | import argparse 10 | import numpy as np 11 | import scipy.io 12 | import pdb 13 | import h5py 14 | from nltk.tokenize import word_tokenize 15 | import json 16 | import csv 17 | import re 18 | import math 19 | import pickle 20 | 21 | from .vqa_processed import get_top_answers, remove_examples, tokenize, tokenize_mcb, \ 22 | preprocess_questions, remove_long_tail_train, \ 23 | encode_question, encode_answer 24 | 25 | def preprocess_answers(examples, nlp='nltk'): 26 | print('Example of modified answers after preprocessing:') 27 | for i, ex in enumerate(examples): 28 | s = ex['answer'] 29 | if nlp == 'nltk': 30 | ex['answer'] = " ".join(word_tokenize(str(s).lower())) 31 | elif nlp == 'mcb': 32 | ex['answer'] = " ".join(tokenize_mcb(s)) 33 | else: 34 | ex['answer'] = " ".join(tokenize(s)) 35 | if i < 10: print(s, 'became', "->"+ex['answer']+"<-") 36 | if i>0 and i % 1000 == 0: 37 | sys.stdout.write("processing %d/%d (%.2f%% done) \r" % (i, len(examples), i*100.0/len(examples)) ) 38 | sys.stdout.flush() 39 | return examples 40 | 41 | def build_csv(path, examples, split='train', delimiter_col='~', delimiter_number='|'): 42 | with open(path, 'wb') as f: 43 | writer = csv.writer(f, delimiter=delimiter_col) 44 | for ex in examples: 45 | import ipdb; ipdb.set_trace() 46 | row = [] 47 | row.append(ex['question_id']) 48 | row.append(ex['question']) 49 | row.append(delimiter_number.join(ex['question_words_UNK'])) 50 | row.append(delimiter_number.join(ex['question_wids'])) 51 | 52 | row.append(ex['image_id']) 53 | 54 | if split in ['train','val','trainval']: 55 | row.append(ex['answer_aid']) 56 | row.append(ex['answer']) 57 | writer.writerow(row) 58 | 59 | def vgenome_processed(params): 60 | 61 | ##################################################### 62 | ## Read input files 63 | ##################################################### 64 | 65 | path_train = os.path.join(params['dir'], 'interim', 'questions_annotations.json') 66 | 67 | # An example is a tuple (question, image, answer) 68 | # /!\ test and test-dev have no answer 69 | trainset = json.load(open(path_train, 'r')) 70 | 71 | ##################################################### 72 | ## Preprocess examples (questions and answers) 73 | ##################################################### 74 | 75 | trainset = preprocess_answers(trainset, 
params['nlp']) 76 | 77 | top_answers = get_top_answers(trainset, params['nans']) 78 | aid_to_ans = {i+1:w for i,w in enumerate(top_answers)} 79 | ans_to_aid = {w:i+1 for i,w in enumerate(top_answers)} 80 | 81 | # Remove examples if answer is not in top answers 82 | #trainset = remove_examples(trainset, ans_to_aid) 83 | 84 | # Add 'question_words' to the initial tuple 85 | trainset = preprocess_questions(trainset, params['nlp']) 86 | 87 | # Also process top_words which contains a UNK char 88 | trainset, top_words = remove_long_tail_train(trainset, params['minwcount']) 89 | wid_to_word = {i+1:w for i,w in enumerate(top_words)} 90 | word_to_wid = {w:i+1 for i,w in enumerate(top_words)} 91 | 92 | #examples_test = remove_long_tail_test(examples_test, word_to_wid) 93 | 94 | trainset = encode_question(trainset, word_to_wid, params['maxlength'], params['pad']) 95 | 96 | trainset = encode_answer(trainset, ans_to_aid) 97 | 98 | ##################################################### 99 | ## Write output files 100 | ##################################################### 101 | 102 | # Paths to output files 103 | # Ex: data/vqa/preprocess/nans,3000_maxlength,15_..._trainsplit,train_testsplit,val/id_to_word.json 104 | subdirname = 'nans,'+str(params['nans']) 105 | for param in ['maxlength', 'minwcount', 'nlp', 'pad', 'trainsplit']: 106 | subdirname += '_' + param + ',' + str(params[param]) 107 | os.system('mkdir -p ' + os.path.join(params['dir'], 'processed', subdirname)) 108 | 109 | path_wid_to_word = os.path.join(params['dir'], 'processed', subdirname, 'wid_to_word.pickle') 110 | path_word_to_wid = os.path.join(params['dir'], 'processed', subdirname, 'word_to_wid.pickle') 111 | path_aid_to_ans = os.path.join(params['dir'], 'processed', subdirname, 'aid_to_ans.pickle') 112 | path_ans_to_aid = os.path.join(params['dir'], 'processed', subdirname, 'ans_to_aid.pickle') 113 | #path_csv_train = os.path.join(params['dir'], 'processed', subdirname, 'train.csv') 114 | path_trainset = os.path.join(params['dir'], 'processed', subdirname, 'trainset.pickle') 115 | 116 | print('Write wid_to_word to', path_wid_to_word) 117 | with open(path_wid_to_word, 'wb') as handle: 118 | pickle.dump(wid_to_word, handle) 119 | 120 | print('Write word_to_wid to', path_word_to_wid) 121 | with open(path_word_to_wid, 'wb') as handle: 122 | pickle.dump(word_to_wid, handle) 123 | 124 | print('Write aid_to_ans to', path_aid_to_ans) 125 | with open(path_aid_to_ans, 'wb') as handle: 126 | pickle.dump(aid_to_ans, handle) 127 | 128 | print('Write ans_to_aid to', path_ans_to_aid) 129 | with open(path_ans_to_aid, 'wb') as handle: 130 | pickle.dump(ans_to_aid, handle) 131 | 132 | print('Write trainset to', path_trainset) 133 | with open(path_trainset, 'wb') as handle: 134 | pickle.dump(trainset, handle) 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('--dir_vg', 140 | default='data/visualgenome', 141 | type=str, 142 | help='Root directory containing raw, interim and processed directories' 143 | ) 144 | parser.add_argument('--nans', 145 | default=10000, 146 | type=int, 147 | help='Number of top answers for the final classifications' 148 | ) 149 | parser.add_argument('--maxlength', 150 | default=26, 151 | type=int, 152 | help='Max number of words in a caption. 
Captions longer get clipped' 153 | ) 154 | parser.add_argument('--minwcount', 155 | default=0, 156 | type=int, 157 | help='Words that occur less than that are removed from vocab' 158 | ) 159 | parser.add_argument('--nlp', 160 | default='mcb', 161 | type=str, 162 | help='Token method ; Options: nltk | mcb | naive' 163 | ) 164 | parser.add_argument('--pad', 165 | default='left', 166 | type=str, 167 | help='Padding ; Options: right (finish by zeros) | left (begin by zeros)' 168 | ) 169 | args = parser.parse_args() 170 | params = vars(args) 171 | vgenome_processed(params) -------------------------------------------------------------------------------- /vqa/datasets/vqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | import torch.utils.data as data 5 | import copy 6 | import numpy as np 7 | 8 | from ..lib import utils 9 | from ..lib.dataloader import DataLoader 10 | from .utils import AbstractVQADataset 11 | from .vqa_interim import vqa_interim 12 | from .vqa2_interim import vqa_interim as vqa2_interim 13 | from .vqa_processed import vqa_processed 14 | from . import coco 15 | from . import vgenome 16 | 17 | 18 | class AbstractVQA(AbstractVQADataset): 19 | 20 | def __init__(self, data_split, opt, dataset_img=None): 21 | super(AbstractVQA, self).__init__(data_split, opt, dataset_img) 22 | 23 | if 'train' not in self.data_split: # means self.data_split is 'val' or 'test' 24 | self.opt['samplingans'] = False 25 | assert 'samplingans' in self.opt, \ 26 | "opt['vqa'] does not have 'samplingans' "\ 27 | "entry. Set it to True or False." 28 | 29 | if self.data_split == 'test': 30 | path_testdevset = os.path.join(self.subdir_processed, 'testdevset.pickle') 31 | with open(path_testdevset, 'rb') as handle: 32 | self.testdevset_vqa = pickle.load(handle) 33 | self.is_qid_testdev = {} 34 | for i in range(len(self.testdevset_vqa)): 35 | qid = self.testdevset_vqa[i]['question_id'] 36 | self.is_qid_testdev[qid] = True 37 | 38 | def _raw(self): 39 | raise NotImplementedError 40 | 41 | def _interim(self): 42 | raise NotImplementedError 43 | 44 | def _processed(self): 45 | raise NotImplementedError 46 | 47 | def __getitem__(self, index): 48 | item = {} 49 | # TODO: better handle cascade of dict items 50 | item_vqa = self.dataset[index] 51 | 52 | # Process Visual (image or features) 53 | if self.dataset_img is not None: 54 | item_img = self.dataset_img.get_by_name(item_vqa['image_name']) 55 | item['visual'] = item_img['visual'] 56 | 57 | # Process Question (word token) 58 | item['question_id'] = item_vqa['question_id'] 59 | item['question'] = torch.LongTensor(item_vqa['question_wids']) 60 | 61 | if self.data_split == 'test': 62 | if item['question_id'] in self.is_qid_testdev: 63 | item['is_testdev'] = True 64 | else: 65 | item['is_testdev'] = False 66 | else: 67 | ## Process Answer if exists 68 | if self.opt['samplingans']: 69 | proba = item_vqa['answers_count'] 70 | proba = proba / np.sum(proba) 71 | item['answer'] = int(np.random.choice(item_vqa['answers_aid'], p=proba)) 72 | else: 73 | item['answer'] = item_vqa['answer_aid'] 74 | 75 | return item 76 | 77 | def __len__(self): 78 | return len(self.dataset) 79 | 80 | def num_classes(self): 81 | return len(self.aid_to_ans) 82 | 83 | def vocab_words(self): 84 | return list(self.wid_to_word.values()) 85 | 86 | def vocab_answers(self): 87 | return self.aid_to_ans 88 | 89 | def data_loader(self, batch_size=10, num_workers=4, shuffle=False): 90 | return DataLoader(self, 91 | 
batch_size=batch_size, shuffle=shuffle, 92 | num_workers=num_workers, pin_memory=True) 93 | 94 | def split_name(self, testdev=False): 95 | if testdev: 96 | return 'test-dev2015' 97 | if self.data_split in ['train', 'val']: 98 | return self.data_split+'2014' 99 | elif self.data_split == 'test': 100 | return self.data_split+'2015' 101 | elif self.data_split == 'testdev': 102 | return 'test-dev2015' 103 | else: 104 | assert False, 'Wrong data_split: {}'.format(self.data_split) 105 | 106 | def subdir_processed(self): 107 | subdir = 'nans,' + str(self.opt['nans']) \ 108 | + '_maxlength,' + str(self.opt['maxlength']) \ 109 | + '_minwcount,' + str(self.opt['minwcount']) \ 110 | + '_nlp,' + self.opt['nlp'] \ 111 | + '_pad,' + self.opt['pad'] \ 112 | + '_trainsplit,' + self.opt['trainsplit'] 113 | subdir = os.path.join(self.dir_processed, subdir) 114 | return subdir 115 | 116 | 117 | class VQA(AbstractVQA): 118 | 119 | def __init__(self, data_split, opt, dataset_img=None): 120 | super(VQA, self).__init__(data_split, opt, dataset_img) 121 | 122 | def _raw(self): 123 | dir_zip = os.path.join(self.dir_raw, 'zip') 124 | dir_ann = os.path.join(self.dir_raw, 'annotations') 125 | os.system('mkdir -p '+dir_zip) 126 | os.system('mkdir -p '+dir_ann) 127 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Train_mscoco.zip -P '+dir_zip) 128 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Val_mscoco.zip -P '+dir_zip) 129 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Test_mscoco.zip -P '+dir_zip) 130 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Annotations_Train_mscoco.zip -P '+dir_zip) 131 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Annotations_Val_mscoco.zip -P '+dir_zip) 132 | os.system('unzip '+os.path.join(dir_zip, 'Questions_Train_mscoco.zip')+' -d '+dir_ann) 133 | os.system('unzip '+os.path.join(dir_zip, 'Questions_Val_mscoco.zip')+' -d '+dir_ann) 134 | os.system('unzip '+os.path.join(dir_zip, 'Questions_Test_mscoco.zip')+' -d '+dir_ann) 135 | os.system('unzip '+os.path.join(dir_zip, 'Annotations_Train_mscoco.zip')+' -d '+dir_ann) 136 | os.system('unzip '+os.path.join(dir_zip, 'Annotations_Val_mscoco.zip')+' -d '+dir_ann) 137 | 138 | def _interim(self): 139 | vqa_interim(self.opt['dir']) 140 | 141 | def _processed(self): 142 | vqa_processed(self.opt) 143 | 144 | 145 | class VQA2(AbstractVQA): 146 | 147 | def __init__(self, data_split, opt, dataset_img=None): 148 | super(VQA2, self).__init__(data_split, opt, dataset_img) 149 | 150 | def _raw(self): 151 | dir_zip = os.path.join(self.dir_raw, 'zip') 152 | dir_ann = os.path.join(self.dir_raw, 'annotations') 153 | os.system('mkdir -p '+dir_zip) 154 | os.system('mkdir -p '+dir_ann) 155 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip -P '+dir_zip) 156 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip -P '+dir_zip) 157 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip -P '+dir_zip) 158 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P '+dir_zip) 159 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P '+dir_zip) 160 | os.system('unzip '+os.path.join(dir_zip, 'v2_Questions_Train_mscoco.zip')+' -d '+dir_ann) 161 | os.system('unzip '+os.path.join(dir_zip, 'v2_Questions_Val_mscoco.zip')+' -d '+dir_ann) 162 | os.system('unzip 
'+os.path.join(dir_zip, 'v2_Questions_Test_mscoco.zip')+' -d '+dir_ann) 163 | os.system('unzip '+os.path.join(dir_zip, 'v2_Annotations_Train_mscoco.zip')+' -d '+dir_ann) 164 | os.system('unzip '+os.path.join(dir_zip, 'v2_Annotations_Val_mscoco.zip')+' -d '+dir_ann) 165 | os.system('mv '+os.path.join(dir_ann, 'v2_mscoco_train2014_annotations.json')+' ' 166 | +os.path.join(dir_ann, 'mscoco_train2014_annotations.json')) 167 | os.system('mv '+os.path.join(dir_ann, 'v2_mscoco_val2014_annotations.json')+' ' 168 | +os.path.join(dir_ann, 'mscoco_val2014_annotations.json')) 169 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_train2014_questions.json')+' ' 170 | +os.path.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')) 171 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_val2014_questions.json')+' ' 172 | +os.path.join(dir_ann, 'OpenEnded_mscoco_val2014_questions.json')) 173 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_test2015_questions.json')+' ' 174 | +os.path.join(dir_ann, 'OpenEnded_mscoco_test2015_questions.json')) 175 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_test-dev2015_questions.json')+' ' 176 | +os.path.join(dir_ann, 'OpenEnded_mscoco_test-dev2015_questions.json')) 177 | 178 | def _interim(self): 179 | vqa2_interim(self.opt['dir']) 180 | 181 | def _processed(self): 182 | vqa_processed(self.opt) 183 | 184 | 185 | class VQAVisualGenome(data.Dataset): 186 | 187 | def __init__(self, dataset_vqa, dataset_vgenome): 188 | self.dataset_vqa = dataset_vqa 189 | self.dataset_vgenome = dataset_vgenome 190 | self._filter_dataset_vgenome() 191 | 192 | def _filter_dataset_vgenome(self): 193 | print('-> Filtering dataset vgenome') 194 | data_vg = self.dataset_vgenome.dataset 195 | ans_to_aid = self.dataset_vqa.ans_to_aid 196 | word_to_wid = self.dataset_vqa.word_to_wid 197 | data_vg_new = [] 198 | not_in = 0 199 | for i in range(len(data_vg)): 200 | if data_vg[i]['answer'] not in ans_to_aid: 201 | not_in += 1 202 | else: 203 | data_vg[i]['answer_aid'] = ans_to_aid[data_vg[i]['answer']] 204 | for j in range(data_vg[i]['seq_length']): 205 | word = data_vg[i]['question_words_UNK'][j] 206 | if word in word_to_wid: 207 | wid = word_to_wid[word] 208 | else: 209 | wid = word_to_wid['UNK'] 210 | data_vg[i]['question_wids'][j] = wid 211 | data_vg_new.append(data_vg[i]) 212 | print('-> {} / {} items removed'.format(not_in, len(data_vg))) 213 | self.dataset_vgenome.dataset = data_vg_new 214 | print('-> {} items left in visual genome'.format(len(self.dataset_vgenome))) 215 | print('-> {} items total in vqa+vg'.format(len(self))) 216 | 217 | 218 | def __getitem__(self, index): 219 | if index < len(self.dataset_vqa): 220 | item = self.dataset_vqa[index] 221 | #print('vqa') 222 | else: 223 | item = self.dataset_vgenome[index - len(self.dataset_vqa)] 224 | #print('vg') 225 | #import ipdb; ipdb.set_trace() 226 | return item 227 | 228 | def __len__(self): 229 | return len(self.dataset_vqa) + len(self.dataset_vgenome) 230 | 231 | def num_classes(self): 232 | return self.dataset_vqa.num_classes() 233 | 234 | def vocab_words(self): 235 | return self.dataset_vqa.vocab_words() 236 | 237 | def vocab_answers(self): 238 | return self.dataset_vqa.vocab_answers() 239 | 240 | def data_loader(self, batch_size=10, num_workers=4, shuffle=False): 241 | return DataLoader(self, 242 | batch_size=batch_size, shuffle=shuffle, 243 | num_workers=num_workers, pin_memory=True) 244 | 245 | def split_name(self, testdev=False): 246 | return 
self.dataset_vqa.split_name(testdev=testdev) 247 | 248 | 249 | def factory(data_split, opt, opt_coco=None, opt_vgenome=None): 250 | dataset_img = None 251 | 252 | if opt_coco is not None: 253 | dataset_img = coco.factory(data_split, opt_coco) 254 | 255 | if opt['dataset'] == 'VQA' and '2' not in opt['dir']: # sanity check 256 | dataset_vqa = VQA(data_split, opt, dataset_img) 257 | elif opt['dataset'] == 'VQA2' and '2' in opt['dir']: # sanity check 258 | dataset_vqa = VQA2(data_split, opt, dataset_img) 259 | else: 260 | raise ValueError 261 | 262 | if opt_vgenome is not None: 263 | dataset_vgenome = vgenome.factory(opt_vgenome, vqa=True) 264 | return VQAVisualGenome(dataset_vqa, dataset_vgenome) 265 | else: 266 | return dataset_vqa 267 | 268 | -------------------------------------------------------------------------------- /vqa/datasets/vqa2_interim.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from collections import Counter 5 | 6 | def get_subtype(split='train'): 7 | if split in ['train', 'val']: 8 | return split + '2014' 9 | else: 10 | return 'test2015' 11 | 12 | def get_image_name_old(subtype='train2014', image_id='1', format='%s/COCO_%s_%012d.jpg'): 13 | return format%(subtype, subtype, image_id) 14 | 15 | def get_image_name(subtype='train2014', image_id='1', format='COCO_%s_%012d.jpg'): 16 | return format%(subtype, image_id) 17 | 18 | def interim(questions, split='train', annotations=[]): 19 | print('Interim', split) 20 | data = [] 21 | for i in range(len(questions)): 22 | row = {} 23 | row['question_id'] = questions[i]['question_id'] 24 | row['image_name'] = get_image_name(get_subtype(split), questions[i]['image_id']) 25 | row['question'] = questions[i]['question'] 26 | #row['MC_answer'] = questions[i]['multiple_choices'] 27 | if split in ['train', 'val', 'trainval']: 28 | row['answer'] = annotations[i]['multiple_choice_answer'] 29 | answers = [] 30 | for ans in annotations[i]['answers']: 31 | answers.append(ans['answer']) 32 | row['answers_occurence'] = Counter(answers).most_common() 33 | data.append(row) 34 | return data 35 | 36 | def vqa_interim(dir_vqa): 37 | ''' 38 | Put the VQA data into single json file in data/interim 39 | or train, val, trainval : [[question_id, image_name, question, MC_answer, answer] ... ] 40 | or test, test-dev : [[question_id, image_name, question, MC_answer] ... 
] 41 | ''' 42 | 43 | path_train_qa = os.path.join(dir_vqa, 'interim', 'train_questions_annotations.json') 44 | path_val_qa = os.path.join(dir_vqa, 'interim', 'val_questions_annotations.json') 45 | path_trainval_qa = os.path.join(dir_vqa, 'interim', 'trainval_questions_annotations.json') 46 | path_test_q = os.path.join(dir_vqa, 'interim', 'test_questions.json') 47 | path_testdev_q = os.path.join(dir_vqa, 'interim', 'testdev_questions.json') 48 | 49 | os.system('mkdir -p ' + os.path.join(dir_vqa, 'interim')) 50 | 51 | print('Loading annotations and questions...') 52 | annotations_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_train2014_annotations.json'), 'r')) 53 | annotations_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_val2014_annotations.json'), 'r')) 54 | questions_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_train2014_questions.json'), 'r')) 55 | questions_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_val2014_questions.json'), 'r')) 56 | questions_test = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_test2015_questions.json'), 'r')) 57 | questions_testdev = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_test-dev2015_questions.json'), 'r')) 58 | 59 | data_train = interim(questions_train['questions'], 'train', annotations_train['annotations']) 60 | print('Train size %d'%len(data_train)) 61 | print('Write', path_train_qa) 62 | json.dump(data_train, open(path_train_qa, 'w')) 63 | 64 | data_val = interim(questions_val['questions'], 'val', annotations_val['annotations']) 65 | print('Val size %d'%len(data_val)) 66 | print('Write', path_val_qa) 67 | json.dump(data_val, open(path_val_qa, 'w')) 68 | 69 | print('Concat. 
train and val') 70 | data_trainval = data_train + data_val 71 | print('Trainval size %d'%len(data_trainval)) 72 | print('Write', path_trainval_qa) 73 | json.dump(data_trainval, open(path_trainval_qa, 'w')) 74 | 75 | data_testdev = interim(questions_testdev['questions'], 'testdev') 76 | print('Testdev size %d'%len(data_testdev)) 77 | print('Write', path_testdev_q) 78 | json.dump(data_testdev, open(path_testdev_q, 'w')) 79 | 80 | data_test = interim(questions_test['questions'], 'test') 81 | print('Test size %d'%len(data_test)) 82 | print('Write', path_test_q) 83 | json.dump(data_test, open(path_test_q, 'w')) 84 | 85 | if __name__ == "__main__": 86 | 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--dir_vqa', default='data/vqa', type=str, help='Path to vqa data directory') 89 | args = parser.parse_args() 90 | vqa_interim(args.dir_vqa) 91 | -------------------------------------------------------------------------------- /vqa/datasets/vqa_interim.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from collections import Counter 5 | 6 | def get_subtype(split='train'): 7 | if split in ['train', 'val']: 8 | return split + '2014' 9 | else: 10 | return 'test2015' 11 | 12 | def get_image_name_old(subtype='train2014', image_id='1', format='%s/COCO_%s_%012d.jpg'): 13 | return format%(subtype, subtype, image_id) 14 | 15 | def get_image_name(subtype='train2014', image_id='1', format='COCO_%s_%012d.jpg'): 16 | return format%(subtype, image_id) 17 | 18 | def interim(questions, split='train', annotations=[]): 19 | print('Interim', split) 20 | data = [] 21 | for i in range(len(questions)): 22 | row = {} 23 | row['question_id'] = questions[i]['question_id'] 24 | row['image_name'] = get_image_name(get_subtype(split), questions[i]['image_id']) 25 | row['question'] = questions[i]['question'] 26 | row['MC_answer'] = questions[i]['multiple_choices'] 27 | if split in ['train', 'val', 'trainval']: 28 | row['answer'] = annotations[i]['multiple_choice_answer'] 29 | answers = [] 30 | for ans in annotations[i]['answers']: 31 | answers.append(ans['answer']) 32 | row['answers_occurence'] = Counter(answers).most_common() 33 | data.append(row) 34 | return data 35 | 36 | def vqa_interim(dir_vqa): 37 | ''' 38 | Put the VQA data into single json file in data/interim 39 | or train, val, trainval : [[question_id, image_name, question, MC_answer, answer] ... ] 40 | or test, test-dev : [[question_id, image_name, question, MC_answer] ... 
] 41 | ''' 42 | 43 | path_train_qa = os.path.join(dir_vqa, 'interim', 'train_questions_annotations.json') 44 | path_val_qa = os.path.join(dir_vqa, 'interim', 'val_questions_annotations.json') 45 | path_trainval_qa = os.path.join(dir_vqa, 'interim', 'trainval_questions_annotations.json') 46 | path_test_q = os.path.join(dir_vqa, 'interim', 'test_questions.json') 47 | path_testdev_q = os.path.join(dir_vqa, 'interim', 'testdev_questions.json') 48 | 49 | os.system('mkdir -p ' + os.path.join(dir_vqa, 'interim')) 50 | 51 | print('Loading annotations and questions...') 52 | annotations_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_train2014_annotations.json'), 'r')) 53 | annotations_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_val2014_annotations.json'), 'r')) 54 | questions_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_train2014_questions.json'), 'r')) 55 | questions_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_val2014_questions.json'), 'r')) 56 | questions_test = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_test2015_questions.json'), 'r')) 57 | questions_testdev = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_test-dev2015_questions.json'), 'r')) 58 | 59 | data_train = interim(questions_train['questions'], 'train', annotations_train['annotations']) 60 | print('Train size %d'%len(data_train)) 61 | print('Write', path_train_qa) 62 | json.dump(data_train, open(path_train_qa, 'w')) 63 | 64 | data_val = interim(questions_val['questions'], 'val', annotations_val['annotations']) 65 | print('Val size %d'%len(data_val)) 66 | print('Write', path_val_qa) 67 | json.dump(data_val, open(path_val_qa, 'w')) 68 | 69 | print('Concat. train and val') 70 | data_trainval = data_train + data_val 71 | print('Trainval size %d'%len(data_trainval)) 72 | print('Write', path_trainval_qa) 73 | json.dump(data_trainval, open(path_trainval_qa, 'w')) 74 | 75 | data_testdev = interim(questions_testdev['questions'], 'testdev') 76 | print('Testdev size %d'%len(data_testdev)) 77 | print('Write', path_testdev_q) 78 | json.dump(data_testdev, open(path_testdev_q, 'w')) 79 | 80 | data_test = interim(questions_test['questions'], 'test') 81 | print('Test size %d'%len(data_test)) 82 | print('Write', path_test_q) 83 | json.dump(data_test, open(path_test_q, 'w')) 84 | 85 | if __name__ == "__main__": 86 | 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--dir_vqa', default='data/vqa', type=str, help='Path to vqa data directory') 89 | args = parser.parse_args() 90 | vqa_interim(args.dir_vqa) 91 | -------------------------------------------------------------------------------- /vqa/datasets/vqa_processed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a train/test pair of interim json data files. 3 | Caption: Use NLTK or split function to get tokens. 
4 | """ 5 | from random import shuffle, seed 6 | import sys 7 | import os.path 8 | import argparse 9 | import numpy as np 10 | import scipy.io 11 | import pdb 12 | import json 13 | import csv 14 | import re 15 | import math 16 | import pickle 17 | #import pprint 18 | 19 | def get_top_answers(examples, nans=3000): 20 | counts = {} 21 | for ex in examples: 22 | ans = ex['answer'] 23 | counts[ans] = counts.get(ans, 0) + 1 24 | 25 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 26 | print('Top answer and their counts:' ) 27 | print('\n'.join(map(str,cw[:20]))) 28 | 29 | vocab = [] 30 | for i in range(nans): 31 | vocab.append(cw[i][1]) 32 | return vocab[:nans] 33 | 34 | def remove_examples(examples, ans_to_aid): 35 | new_examples = [] 36 | for i, ex in enumerate(examples): 37 | if ex['answer'] in ans_to_aid: 38 | new_examples.append(ex) 39 | print('Number of examples reduced from %d to %d '%(len(examples), len(new_examples))) 40 | return new_examples 41 | 42 | def tokenize(sentence): 43 | return [i for i in re.split(r"([-.\"',:? !\$#@~()*&\^%;\[\]/\\\+<>\n=])", sentence) if i!='' and i!=' ' and i!='\n']; 44 | 45 | def tokenize_mcb(s): 46 | t_str = s.lower() 47 | for i in [r'\?',r'\!',r'\'',r'\"',r'\$',r'\:',r'\@',r'\(',r'\)',r'\,',r'\.',r'\;']: 48 | t_str = re.sub( i, '', t_str) 49 | for i in [r'\-',r'\/']: 50 | t_str = re.sub( i, ' ', t_str) 51 | q_list = re.sub(r'\?','',t_str.lower()).split(' ') 52 | q_list = list(filter(lambda x: len(x) > 0, q_list)) 53 | return q_list 54 | 55 | def preprocess_questions(examples, nlp='nltk'): 56 | if nlp == 'nltk': 57 | from nltk.tokenize import word_tokenize 58 | print('Example of generated tokens after preprocessing some questions:') 59 | for i, ex in enumerate(examples): 60 | s = ex['question'] 61 | if nlp == 'nltk': 62 | ex['question_words'] = word_tokenize(str(s).lower()) 63 | elif nlp == 'mcb': 64 | ex['question_words'] = tokenize_mcb(s) 65 | else: 66 | ex['question_words'] = tokenize(s) 67 | if i < 10: 68 | print(ex['question_words']) 69 | if i % 1000 == 0: 70 | sys.stdout.write("processing %d/%d (%.2f%% done) \r" % (i, len(examples), i*100.0/len(examples)) ) 71 | sys.stdout.flush() 72 | return examples 73 | 74 | def remove_long_tail_train(examples, minwcount=0): 75 | # Replace words which are in the long tail (counted less than 'minwcount' times) by the UNK token. 76 | # Also create vocab, a list of the final words. 
77 | 78 | # count up the number of words 79 | counts = {} 80 | for ex in examples: 81 | for w in ex['question_words']: 82 | counts[w] = counts.get(w, 0) + 1 83 | cw = sorted([(count,w) for w, count in counts.items()], reverse=True) 84 | print('Top words and their counts:') 85 | print('\n'.join(map(str,cw[:20]))) 86 | 87 | total_words = sum(counts.values()) 88 | print('Total words:', total_words) 89 | bad_words = [w for w,n in counts.items() if n <= minwcount] 90 | vocab = [w for w,n in counts.items() if n > minwcount] 91 | bad_count = sum(counts[w] for w in bad_words) 92 | print('Number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 93 | print('Number of words in vocab would be %d' % (len(vocab), )) 94 | print('Number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 95 | 96 | print('Insert the special UNK token') 97 | vocab.append('UNK') 98 | for ex in examples: 99 | words = ex['question_words'] 100 | question = [w if counts.get(w,0) > minwcount else 'UNK' for w in words] 101 | ex['question_words_UNK'] = question 102 | 103 | return examples, vocab 104 | 105 | def remove_long_tail_test(examples, word_to_wid): 106 | for ex in examples: 107 | ex['question_words_UNK'] = [w if w in word_to_wid else 'UNK' for w in ex['question_words']] 108 | return examples 109 | 110 | def encode_question(examples, word_to_wid, maxlength=15, pad='left'): 111 | # Add to tuple question_wids and question_length 112 | for i, ex in enumerate(examples): 113 | ex['question_length'] = min(maxlength, len(ex['question_words_UNK'])) # record the length of this sequence 114 | ex['question_wids'] = [0]*maxlength 115 | for k, w in enumerate(ex['question_words_UNK']): 116 | if k < maxlength: 117 | if pad == 'right': 118 | ex['question_wids'][k] = word_to_wid[w] 119 | else: #['pad'] == 'left' 120 | new_k = k + maxlength - len(ex['question_words_UNK']) 121 | ex['question_wids'][new_k] = word_to_wid[w] 122 | ex['seq_length'] = len(ex['question_words_UNK']) 123 | return examples 124 | 125 | def encode_answer(examples, ans_to_aid): 126 | print('Warning: aid of answer not in vocab is 1999') 127 | for i, ex in enumerate(examples): 128 | ex['answer_aid'] = ans_to_aid.get(ex['answer'], 1999) # -1 means answer not in vocab 129 | return examples 130 | 131 | def encode_answers_occurence(examples, ans_to_aid): 132 | for i, ex in enumerate(examples): 133 | answers = [] 134 | answers_aid = [] 135 | answers_count = [] 136 | for ans in ex['answers_occurence']: 137 | aid = ans_to_aid.get(ans[0], -1) # -1 means answer not in vocab 138 | if aid != -1: 139 | answers.append(ans[0]) 140 | answers_aid.append(aid) 141 | answers_count.append(ans[1]) 142 | ex['answers'] = answers 143 | ex['answers_aid'] = answers_aid 144 | ex['answers_count'] = answers_count 145 | return examples 146 | 147 | def vqa_processed(params): 148 | 149 | ##################################################### 150 | ## Read input files 151 | ##################################################### 152 | 153 | path_train = os.path.join(params['dir'], 'interim', params['trainsplit']+'_questions_annotations.json') 154 | if params['trainsplit'] == 'train': 155 | path_val = os.path.join(params['dir'], 'interim', 'val_questions_annotations.json') 156 | path_test = os.path.join(params['dir'], 'interim', 'test_questions.json') 157 | path_testdev = os.path.join(params['dir'], 'interim', 'testdev_questions.json') 158 | 159 | # An example is a tuple (question, image, answer) 160 | # /!\ test and test-dev have no 
answer 161 | trainset = json.load(open(path_train, 'r')) 162 | if params['trainsplit'] == 'train': 163 | valset = json.load(open(path_val, 'r')) 164 | testset = json.load(open(path_test, 'r')) 165 | testdevset = json.load(open(path_testdev, 'r')) 166 | 167 | ##################################################### 168 | ## Preprocess examples (questions and answers) 169 | ##################################################### 170 | 171 | top_answers = get_top_answers(trainset, params['nans']) 172 | aid_to_ans = [a for i,a in enumerate(top_answers)] 173 | ans_to_aid = {a:i for i,a in enumerate(top_answers)} 174 | # Remove examples if answer is not in top answers 175 | trainset = remove_examples(trainset, ans_to_aid) 176 | 177 | # Add 'question_words' to the initial tuple 178 | trainset = preprocess_questions(trainset, params['nlp']) 179 | if params['trainsplit'] == 'train': 180 | valset = preprocess_questions(valset, params['nlp']) 181 | testset = preprocess_questions(testset, params['nlp']) 182 | testdevset = preprocess_questions(testdevset, params['nlp']) 183 | 184 | # Also process top_words which contains a UNK char 185 | trainset, top_words = remove_long_tail_train(trainset, params['minwcount']) 186 | wid_to_word = {i+1:w for i,w in enumerate(top_words)} 187 | word_to_wid = {w:i+1 for i,w in enumerate(top_words)} 188 | 189 | if params['trainsplit'] == 'train': 190 | valset = remove_long_tail_test(valset, word_to_wid) 191 | testset = remove_long_tail_test(testset, word_to_wid) 192 | testdevset = remove_long_tail_test(testdevset, word_to_wid) 193 | 194 | trainset = encode_question(trainset, word_to_wid, params['maxlength'], params['pad']) 195 | if params['trainsplit'] == 'train': 196 | valset = encode_question(valset, word_to_wid, params['maxlength'], params['pad']) 197 | testset = encode_question(testset, word_to_wid, params['maxlength'], params['pad']) 198 | testdevset = encode_question(testdevset, word_to_wid, params['maxlength'], params['pad']) 199 | 200 | trainset = encode_answer(trainset, ans_to_aid) 201 | trainset = encode_answers_occurence(trainset, ans_to_aid) 202 | if params['trainsplit'] == 'train': 203 | valset = encode_answer(valset, ans_to_aid) 204 | valset = encode_answers_occurence(valset, ans_to_aid) 205 | 206 | ##################################################### 207 | ## Write output files 208 | ##################################################### 209 | 210 | # Paths to output files 211 | # Ex: data/vqa/processed/nans,3000_maxlength,15_..._trainsplit,train_testsplit,val/id_to_word.json 212 | subdirname = 'nans,'+str(params['nans']) 213 | for param in ['maxlength', 'minwcount', 'nlp', 'pad', 'trainsplit']: 214 | subdirname += '_' + param + ',' + str(params[param]) 215 | os.system('mkdir -p ' + os.path.join(params['dir'], 'processed', subdirname)) 216 | 217 | path_wid_to_word = os.path.join(params['dir'], 'processed', subdirname, 'wid_to_word.pickle') 218 | path_word_to_wid = os.path.join(params['dir'], 'processed', subdirname, 'word_to_wid.pickle') 219 | path_aid_to_ans = os.path.join(params['dir'], 'processed', subdirname, 'aid_to_ans.pickle') 220 | path_ans_to_aid = os.path.join(params['dir'], 'processed', subdirname, 'ans_to_aid.pickle') 221 | if params['trainsplit'] == 'train': 222 | path_trainset = os.path.join(params['dir'], 'processed', subdirname, 'trainset.pickle') 223 | path_valset = os.path.join(params['dir'], 'processed', subdirname, 'valset.pickle') 224 | elif params['trainsplit'] == 'trainval': 225 | path_trainset = os.path.join(params['dir'], 'processed', 
subdirname, 'trainvalset.pickle') 226 | path_testset = os.path.join(params['dir'], 'processed', subdirname, 'testset.pickle') 227 | path_testdevset = os.path.join(params['dir'], 'processed', subdirname, 'testdevset.pickle') 228 | 229 | print('Write wid_to_word to', path_wid_to_word) 230 | with open(path_wid_to_word, 'wb') as handle: 231 | pickle.dump(wid_to_word, handle) 232 | 233 | print('Write word_to_wid to', path_word_to_wid) 234 | with open(path_word_to_wid, 'wb') as handle: 235 | pickle.dump(word_to_wid, handle) 236 | 237 | print('Write aid_to_ans to', path_aid_to_ans) 238 | with open(path_aid_to_ans, 'wb') as handle: 239 | pickle.dump(aid_to_ans, handle) 240 | 241 | print('Write ans_to_aid to', path_ans_to_aid) 242 | with open(path_ans_to_aid, 'wb') as handle: 243 | pickle.dump(ans_to_aid, handle) 244 | 245 | print('Write trainset to', path_trainset) 246 | with open(path_trainset, 'wb') as handle: 247 | pickle.dump(trainset, handle) 248 | 249 | if params['trainsplit'] == 'train': 250 | print('Write valset to', path_valset) 251 | with open(path_valset, 'wb') as handle: 252 | pickle.dump(valset, handle) 253 | 254 | print('Write testset to', path_testset) 255 | with open(path_testset, 'wb') as handle: 256 | pickle.dump(testset, handle) 257 | 258 | print('Write testdevset to', path_testdevset) 259 | with open(path_testdevset, 'wb') as handle: 260 | pickle.dump(testdevset, handle) 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser() 265 | parser.add_argument('--dirname', 266 | default='data/vqa', 267 | type=str, 268 | help='Root directory containing raw, interim and processed directories' 269 | ) 270 | parser.add_argument('--trainsplit', 271 | default='train', 272 | type=str, 273 | help='Options: train | trainval' 274 | ) 275 | parser.add_argument('--nans', 276 | default=2000, 277 | type=int, 278 | help='Number of top answers for the final classifications' 279 | ) 280 | parser.add_argument('--maxlength', 281 | default=26, 282 | type=int, 283 | help='Max number of words in a caption. 
Captions longer get clipped' 284 | ) 285 | parser.add_argument('--minwcount', 286 | default=0, 287 | type=int, 288 | help='Words that occur less than that are removed from vocab' 289 | ) 290 | parser.add_argument('--nlp', 291 | default='mcb', 292 | type=str, 293 | help='Token method ; Options: nltk | mcb | naive' 294 | ) 295 | parser.add_argument('--pad', 296 | default='left', 297 | type=str, 298 | help='Padding ; Options: right (finish by zeros) | left (begin by zeros)' 299 | ) 300 | args = parser.parse_args() 301 | opt_vqa = vars(args) 302 | vqa_processed(opt_vqa) -------------------------------------------------------------------------------- /vqa/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/vqa/lib/__init__.py -------------------------------------------------------------------------------- /vqa/lib/criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def factory(opt, cuda=True): 5 | criterion = nn.CrossEntropyLoss() 6 | if cuda: 7 | criterion = criterion.cuda() 8 | return criterion -------------------------------------------------------------------------------- /vqa/lib/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from .sampler import SequentialSampler, RandomSampler 4 | import collections 5 | import math 6 | import sys 7 | import traceback 8 | import threading 9 | if sys.version_info[0] == 2: 10 | import Queue as queue 11 | else: 12 | import queue 13 | 14 | 15 | class ExceptionWrapper(object): 16 | "Wraps an exception plus traceback to communicate across threads" 17 | 18 | def __init__(self, exc_info): 19 | self.exc_type = exc_info[0] 20 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 21 | 22 | 23 | def _worker_loop(dataset, index_queue, data_queue, collate_fn): 24 | torch.set_num_threads(1) 25 | while True: 26 | r = index_queue.get() 27 | if r is None: 28 | data_queue.put(None) 29 | break 30 | idx, batch_indices = r 31 | try: 32 | samples = collate_fn([dataset[i] for i in batch_indices]) 33 | except Exception: 34 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 35 | else: 36 | data_queue.put((idx, samples)) 37 | 38 | 39 | def _pin_memory_loop(in_queue, out_queue, done_event): 40 | while True: 41 | try: 42 | r = in_queue.get() 43 | except: 44 | if done_event.is_set(): 45 | return 46 | raise 47 | if r is None: 48 | break 49 | if isinstance(r[1], ExceptionWrapper): 50 | out_queue.put(r) 51 | continue 52 | idx, batch = r 53 | try: 54 | batch = pin_memory_batch(batch) 55 | except Exception: 56 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 57 | else: 58 | out_queue.put((idx, batch)) 59 | 60 | 61 | def default_collate(batch): 62 | string_classes = (str, bytes) 63 | "Puts each data field into a tensor with outer dimension batch size" 64 | if torch.is_tensor(batch[0]): 65 | return torch.stack(batch, 0) 66 | elif type(batch[0]).__module__ == 'numpy' and type(batch[0]).__name__ == 'ndarray': 67 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 68 | elif isinstance(batch[0], int): 69 | return torch.LongTensor(batch) 70 | elif isinstance(batch[0], float): 71 | return torch.DoubleTensor(batch) 72 | elif isinstance(batch[0], string_classes): 73 | return batch 74 | elif isinstance(batch[0], 
dict): 75 | # ~ added by Cadene ~ 76 | # if each batch element is a dict with same keys, 77 | # then it should be a dict of collated elements 78 | keys = batch[0].keys() 79 | new_dict = {} 80 | for key in keys: 81 | new_dict[key] = [] 82 | for sample in batch: 83 | for key in keys: 84 | new_dict[key].append(sample[key]) 85 | return {key:default_collate(samples) for key, samples in new_dict.items()} 86 | elif isinstance(batch[0], collections.Iterable): 87 | # if each batch element is not a tensor, then it should be a tuple 88 | # of tensors; in that case we collate each element in the tuple 89 | transposed = zip(*batch) 90 | return [default_collate(samples) for samples in transposed] 91 | raise TypeError(("batch must contain tensors, numbers, or lists; found {}" 92 | .format(type(batch[0])))) 93 | 94 | 95 | def pin_memory_batch(batch): 96 | if torch.is_tensor(batch): 97 | return batch.pin_memory() 98 | elif isinstance(batch, dict): 99 | # ~ added by Cadene ~ 100 | return {key:pin_memory_batch(sample) for key,sample in batch.items()} 101 | elif isinstance(batch[0], str): 102 | # ~ added by Cadene ~ 103 | return batch 104 | elif isinstance(batch, collections.Iterable): 105 | return [pin_memory_batch(sample) for sample in batch] 106 | else: 107 | return batch 108 | 109 | 110 | class DataLoaderIter(object): 111 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 112 | 113 | def __init__(self, loader): 114 | self.dataset = loader.dataset 115 | self.batch_size = loader.batch_size 116 | self.collate_fn = loader.collate_fn 117 | self.sampler = loader.sampler 118 | self.num_workers = loader.num_workers 119 | self.pin_memory = loader.pin_memory 120 | self.done_event = threading.Event() 121 | 122 | self.samples_remaining = len(self.sampler) 123 | self.sample_iter = iter(self.sampler) 124 | 125 | if self.num_workers > 0: 126 | self.index_queue = multiprocessing.SimpleQueue() 127 | self.data_queue = multiprocessing.SimpleQueue() 128 | self.batches_outstanding = 0 129 | self.shutdown = False 130 | self.send_idx = 0 131 | self.rcvd_idx = 0 132 | self.reorder_dict = {} 133 | 134 | self.workers = [ 135 | multiprocessing.Process( 136 | target=_worker_loop, 137 | args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn)) 138 | for _ in range(self.num_workers)] 139 | 140 | for w in self.workers: 141 | w.daemon = True # ensure that the worker exits on process exit 142 | w.start() 143 | 144 | if self.pin_memory: 145 | in_data = self.data_queue 146 | self.data_queue = queue.Queue() 147 | self.pin_thread = threading.Thread( 148 | target=_pin_memory_loop, 149 | args=(in_data, self.data_queue, self.done_event)) 150 | self.pin_thread.daemon = True 151 | self.pin_thread.start() 152 | 153 | # prime the prefetch loop 154 | for _ in range(2 * self.num_workers): 155 | self._put_indices() 156 | 157 | def __len__(self): 158 | return int(math.ceil(len(self.sampler) / float(self.batch_size))) 159 | 160 | def __next__(self): 161 | if self.num_workers == 0: 162 | # same-process loading 163 | if self.samples_remaining == 0: 164 | raise StopIteration 165 | indices = self._next_indices() 166 | batch = self.collate_fn([self.dataset[i] for i in indices]) 167 | if self.pin_memory: 168 | batch = pin_memory_batch(batch) 169 | return batch 170 | 171 | # check if the next sample has already been generated 172 | if self.rcvd_idx in self.reorder_dict: 173 | batch = self.reorder_dict.pop(self.rcvd_idx) 174 | return self._process_next_batch(batch) 175 | 176 | if self.batches_outstanding == 0: 177 | 
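# No batches are in flight and no indices remain to dispatch, so the epoch
# is complete: stop the worker processes and signal the end of iteration.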
self._shutdown_workers() 178 | raise StopIteration 179 | 180 | while True: 181 | assert (not self.shutdown and self.batches_outstanding > 0) 182 | idx, batch = self.data_queue.get() 183 | self.batches_outstanding -= 1 184 | if idx != self.rcvd_idx: 185 | # store out-of-order samples 186 | self.reorder_dict[idx] = batch 187 | continue 188 | return self._process_next_batch(batch) 189 | 190 | next = __next__ # Python 2 compatibility 191 | 192 | def __iter__(self): 193 | return self 194 | 195 | def _next_indices(self): 196 | batch_size = min(self.samples_remaining, self.batch_size) 197 | batch = [next(self.sample_iter) for _ in range(batch_size)] 198 | self.samples_remaining -= len(batch) 199 | return batch 200 | 201 | def _put_indices(self): 202 | assert self.batches_outstanding < 2 * self.num_workers 203 | if self.samples_remaining > 0: 204 | self.index_queue.put((self.send_idx, self._next_indices())) 205 | self.batches_outstanding += 1 206 | self.send_idx += 1 207 | 208 | def _process_next_batch(self, batch): 209 | self.rcvd_idx += 1 210 | self._put_indices() 211 | if isinstance(batch, ExceptionWrapper): 212 | raise batch.exc_type(batch.exc_msg) 213 | return batch 214 | 215 | def __getstate__(self): 216 | # TODO: add limited pickling support for sharing an iterator 217 | # across multiple threads for HOGWILD. 218 | # Probably the best way to do this is by moving the sample pushing 219 | # to a separate thread and then just sharing the data queue 220 | # but signalling the end is tricky without a non-blocking API 221 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 222 | 223 | def _shutdown_workers(self): 224 | if not self.shutdown: 225 | self.shutdown = True 226 | self.done_event.set() 227 | for _ in self.workers: 228 | self.index_queue.put(None) 229 | 230 | def __del__(self): 231 | if self.num_workers > 0: 232 | self._shutdown_workers() 233 | 234 | 235 | class DataLoader(object): 236 | """ 237 | Data loader. Combines a dataset and a sampler, and provides 238 | single- or multi-process iterators over the dataset. 239 | 240 | Arguments: 241 | dataset (Dataset): dataset from which to load the data. 242 | batch_size (int, optional): how many samples per batch to load 243 | (default: 1). 244 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 245 | at every epoch (default: False). 246 | sampler (Sampler, optional): defines the strategy to draw samples from 247 | the dataset. If specified, the ``shuffle`` argument is ignored. 248 | num_workers (int, optional): how many subprocesses to use for data 249 | loading. 
0 means that the data will be loaded in the main process 250 | (default: 0) 251 | collate_fn (callable, optional) 252 | pin_memory (bool, optional) 253 | """ 254 | 255 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 256 | num_workers=0, collate_fn=default_collate, pin_memory=False): 257 | self.dataset = dataset 258 | self.batch_size = batch_size 259 | self.num_workers = num_workers 260 | self.collate_fn = collate_fn 261 | self.pin_memory = pin_memory 262 | 263 | if sampler is not None: 264 | self.sampler = sampler 265 | elif shuffle: 266 | self.sampler = RandomSampler(dataset) 267 | elif not shuffle: 268 | self.sampler = SequentialSampler(dataset) 269 | 270 | def __iter__(self): 271 | return DataLoaderIter(self) 272 | 273 | def __len__(self): 274 | return int(math.ceil(len(self.sampler) / float(self.batch_size))) 275 | -------------------------------------------------------------------------------- /vqa/lib/engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch.autograd import Variable 4 | import vqa.lib.utils as utils 5 | 6 | def train(loader, model, criterion, optimizer, logger, epoch, print_freq=10): 7 | # switch to train mode 8 | model.train() 9 | meters = logger.reset_meters('train') 10 | 11 | end = time.time() 12 | for i, sample in enumerate(loader): 13 | batch_size = sample['visual'].size(0) 14 | 15 | # measure data loading time 16 | meters['data_time'].update(time.time() - end, n=batch_size) 17 | 18 | input_visual = Variable(sample['visual']) 19 | input_question = Variable(sample['question']) 20 | target_answer = Variable(sample['answer'].cuda(async=True)) 21 | 22 | # compute output 23 | output = model(input_visual, input_question) 24 | torch.cuda.synchronize() 25 | loss = criterion(output, target_answer) 26 | meters['loss'].update(loss.data[0], n=batch_size) 27 | 28 | # measure accuracy 29 | acc1, acc5 = utils.accuracy(output.data, target_answer.data, topk=(1, 5)) 30 | meters['acc1'].update(acc1[0], n=batch_size) 31 | meters['acc5'].update(acc5[0], n=batch_size) 32 | 33 | # compute gradient and do SGD step 34 | optimizer.zero_grad() 35 | loss.backward() 36 | torch.cuda.synchronize() 37 | optimizer.step() 38 | torch.cuda.synchronize() 39 | 40 | # measure elapsed time 41 | meters['batch_time'].update(time.time() - end, n=batch_size) 42 | end = time.time() 43 | 44 | if i % print_freq == 0: 45 | print('Epoch: [{0}][{1}/{2}]\t' 46 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 47 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 48 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 49 | 'Acc@1 {acc1.val:.3f} ({acc1.avg:.3f})\t' 50 | 'Acc@5 {acc5.val:.3f} ({acc5.avg:.3f})'.format( 51 | epoch, i, len(loader), 52 | batch_time=meters['batch_time'], data_time=meters['data_time'], 53 | loss=meters['loss'], acc1=meters['acc1'], acc5=meters['acc5'])) 54 | 55 | logger.log_meters('train', n=epoch) 56 | 57 | # def adjust_learning_rate(optimizer, epoch): 58 | # """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 59 | # lr = args.lr * (0.1 ** (epoch // 30)) 60 | # for param_group in optimizer.param_groups: 61 | # param_group['lr'] = lr 62 | 63 | 64 | def validate(loader, model, criterion, logger, epoch=0, print_freq=10): 65 | results = [] 66 | 67 | # switch to evaluate mode 68 | model.eval() 69 | meters = logger.reset_meters('val') 70 | 71 | end = time.time() 72 | for i, sample in enumerate(loader): 73 | batch_size = sample['visual'].size(0) 74 | input_visual = 
Variable(sample['visual'].cuda(async=True), volatile=True) 75 | input_question = Variable(sample['question'].cuda(async=True), volatile=True) 76 | target_answer = Variable(sample['answer'].cuda(async=True), volatile=True) 77 | 78 | # compute output 79 | output = model(input_visual, input_question) 80 | loss = criterion(output, target_answer) 81 | meters['loss'].update(loss.data[0], n=batch_size) 82 | 83 | # measure accuracy and record loss 84 | acc1, acc5 = utils.accuracy(output.data, target_answer.data, topk=(1, 5)) 85 | meters['acc1'].update(acc1[0], n=batch_size) 86 | meters['acc5'].update(acc5[0], n=batch_size) 87 | 88 | # compute predictions for OpenEnded accuracy 89 | _, pred = output.data.cpu().max(1) 90 | pred.squeeze_() 91 | for j in range(batch_size): 92 | results.append({'question_id': sample['question_id'][j], 93 | 'answer': loader.dataset.aid_to_ans[pred[j]]}) 94 | 95 | # measure elapsed time 96 | meters['batch_time'].update(time.time() - end, n=batch_size) 97 | end = time.time() 98 | 99 | if i % print_freq == 0: 100 | print('Val: [{0}/{1}]\t' 101 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 102 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 103 | 'Acc@1 {acc1.val:.3f} ({acc1.avg:.3f})\t' 104 | 'Acc@5 {acc5.val:.3f} ({acc5.avg:.3f})'.format( 105 | i, len(loader), batch_time=meters['batch_time'], 106 | data_time=meters['data_time'], loss=meters['loss'], 107 | acc1=meters['acc1'], acc5=meters['acc5'])) 108 | 109 | print(' * Acc@1 {acc1.avg:.3f} Acc@5 {acc5.avg:.3f}' 110 | .format(acc1=meters['acc1'], acc5=meters['acc1'])) 111 | 112 | logger.log_meters('val', n=epoch) 113 | return meters['acc1'].avg, results 114 | 115 | 116 | def test(loader, model, logger, epoch=0, print_freq=10): 117 | results = [] 118 | testdev_results = [] 119 | 120 | model.eval() 121 | meters = logger.reset_meters('test') 122 | 123 | end = time.time() 124 | for i, sample in enumerate(loader): 125 | batch_size = sample['visual'].size(0) 126 | input_visual = Variable(sample['visual'].cuda(async=True), volatile=True) 127 | input_question = Variable(sample['question'].cuda(async=True), volatile=True) 128 | 129 | # compute output 130 | output = model(input_visual, input_question) 131 | 132 | # compute predictions for OpenEnded accuracy 133 | _, pred = output.data.cpu().max(1) 134 | pred.squeeze_() 135 | for j in range(batch_size): 136 | item = {'question_id': sample['question_id'][j], 137 | 'answer': loader.dataset.aid_to_ans[pred[j]]} 138 | results.append(item) 139 | if sample['is_testdev'][j]: 140 | testdev_results.append(item) 141 | 142 | # measure elapsed time 143 | meters['batch_time'].update(time.time() - end, n=batch_size) 144 | end = time.time() 145 | 146 | if i % print_freq == 0: 147 | print('Test: [{0}/{1}]\t' 148 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 149 | i, len(loader), batch_time=meters['batch_time'])) 150 | 151 | logger.log_meters('test', n=epoch) 152 | return results, testdev_results 153 | -------------------------------------------------------------------------------- /vqa/lib/logger.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import time 3 | import json 4 | import numpy as np 5 | import os 6 | from collections import defaultdict 7 | 8 | class Experiment(object): 9 | 10 | def __init__(self, name, options=dict()): 11 | """ Create an experiment 12 | """ 13 | super(Experiment, self).__init__() 14 | 15 | self.name = name 16 | self.options = options 17 | self.date_and_time = time.strftime('%d-%m-%Y--%H-%M-%S') 18 
| 19 | self.info = defaultdict(dict) 20 | self.logged = defaultdict(dict) 21 | self.meters = defaultdict(dict) 22 | 23 | def add_meters(self, tag, meters_dict): 24 | assert tag not in (self.meters.keys()) 25 | for name, meter in meters_dict.items(): 26 | self.add_meter(tag, name, meter) 27 | 28 | def add_meter(self, tag, name, meter): 29 | assert name not in list(self.meters[tag].keys()), \ 30 | "meter with tag {} and name {} already exists".format(tag, name) 31 | self.meters[tag][name] = meter 32 | 33 | def update_options(self, options_dict): 34 | self.options.update(options_dict) 35 | 36 | def log_meter(self, tag, name, n=1): 37 | meter = self.get_meter(tag, name) 38 | if name not in self.logged[tag]: 39 | self.logged[tag][name] = {} 40 | self.logged[tag][name][n] = meter.value() 41 | 42 | def log_meters(self, tag, n=1): 43 | for name, meter in self.get_meters(tag).items(): 44 | self.log_meter(tag, name, n=n) 45 | 46 | def reset_meters(self, tag): 47 | meters = self.get_meters(tag) 48 | for name, meter in meters.items(): 49 | meter.reset() 50 | return meters 51 | 52 | def get_meters(self, tag): 53 | assert tag in list(self.meters.keys()) 54 | return self.meters[tag] 55 | 56 | def get_meter(self, tag, name): 57 | assert tag in list(self.meters.keys()) 58 | assert name in list(self.meters[tag].keys()) 59 | return self.meters[tag][name] 60 | 61 | def to_json(self, filename): 62 | os.system('mkdir -p ' + os.path.dirname(filename)) 63 | var_dict = copy.copy(vars(self)) 64 | var_dict.pop('meters') 65 | for key in ('viz', 'viz_dict'): 66 | if key in list(var_dict.keys()): 67 | var_dict.pop(key) 68 | with open(filename, 'w') as f: 69 | json.dump(var_dict, f) 70 | 71 | def from_json(filename): 72 | with open(filename, 'r') as f: 73 | var_dict = json.load(f) 74 | xp = Experiment('') 75 | xp.date_and_time = var_dict['date_and_time'] 76 | xp.logged = var_dict['logged'] 77 | # TODO: Remove 78 | if 'info' in var_dict: 79 | xp.info = var_dict['info'] 80 | xp.options = var_dict['options'] 81 | xp.name = var_dict['name'] 82 | return xp 83 | 84 | 85 | class AvgMeter(object): 86 | """Computes and stores the average and current value""" 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | self.val = 0 92 | self.avg = 0 93 | self.sum = 0 94 | self.count = 0 95 | 96 | def update(self, val, n=1): 97 | self.val = val 98 | self.sum += val * n 99 | self.count += n 100 | self.avg = self.sum / self.count 101 | 102 | def value(self): 103 | return self.avg 104 | 105 | 106 | class SumMeter(object): 107 | """Computes and stores the sum and current value""" 108 | def __init__(self): 109 | self.reset() 110 | 111 | def reset(self): 112 | self.val = 0 113 | self.sum = 0 114 | self.count = 0 115 | 116 | def update(self, val, n=1): 117 | self.val = val 118 | self.sum += val * n 119 | self.count += n 120 | 121 | def value(self): 122 | return self.sum 123 | 124 | 125 | class ValueMeter(object): 126 | """Computes and stores the average and current value""" 127 | def __init__(self): 128 | self.reset() 129 | 130 | def reset(self): 131 | self.val = 0 132 | 133 | def update(self, val): 134 | self.val = val 135 | 136 | def value(self): 137 | return self.val -------------------------------------------------------------------------------- /vqa/lib/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 
6 | Every Sampler subclass has to provide an __iter__ method, providing a way 7 | to iterate over indices of dataset elements, and a __len__ method that 8 | returns the length of the returned iterators. 9 | """ 10 | 11 | def __init__(self, data_source): 12 | pass 13 | 14 | def __iter__(self): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | raise NotImplementedError 19 | 20 | 21 | class SequentialSampler(Sampler): 22 | """Samples elements sequentially, always in the same order. 23 | Arguments: 24 | data_source (Dataset): dataset to sample from 25 | """ 26 | 27 | def __init__(self, data_source): 28 | self.num_samples = len(data_source) 29 | 30 | def __iter__(self): 31 | return iter(range(self.num_samples)) 32 | 33 | def __len__(self): 34 | return self.num_samples 35 | 36 | 37 | class RandomSampler(Sampler): 38 | """Samples elements randomly, without replacement. 39 | Arguments: 40 | data_source (Dataset): dataset to sample from 41 | """ 42 | 43 | def __init__(self, data_source): 44 | self.num_samples = len(data_source) 45 | 46 | def __iter__(self): 47 | return iter(torch.randperm(self.num_samples).long()) 48 | 49 | def __len__(self): 50 | return self.num_samples -------------------------------------------------------------------------------- /vqa/lib/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import collections 3 | import torch 4 | import numpy as np 5 | 6 | def update_values(dict_from, dict_to): 7 | for key, value in dict_from.items(): 8 | if isinstance(value, dict): 9 | update_values(dict_from[key], dict_to[key]) 10 | elif value is not None: 11 | dict_to[key] = dict_from[key] 12 | return dict_to 13 | 14 | def merge_dict(a, b): 15 | if isinstance(a, dict) and isinstance(b, dict): 16 | d = dict(a) 17 | d.update({k: merge_dict(a.get(k, None), b[k]) for k in b}) 18 | if isinstance(a, list) and isinstance(b, list): 19 | return b 20 | #return [merge_dict(x, y) for x, y in itertools.zip_longest(a, b)] 21 | return a if b is None else b 22 | 23 | def accuracy(output, target, topk=(1,)): 24 | """Computes the precision@k for the specified values of k""" 25 | maxk = max(topk) 26 | batch_size = target.size(0) 27 | 28 | _, pred = output.topk(maxk, 1, True, True) 29 | pred = pred.t() 30 | if target.dim() == 2: # multians option 31 | _, target = torch.max(target, 1) 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | 34 | res = [] 35 | for k in topk: 36 | correct_k = correct[:k].view(-1).float().sum(0) 37 | res.append(correct_k.mul_(100.0 / batch_size)) 38 | return res 39 | 40 | def params_count(model): 41 | count = 0 42 | for p in model.parameters(): 43 | c = 1 44 | for i in range(p.dim()): 45 | c *= p.size(i) 46 | count += c 47 | return count 48 | 49 | def str2bool(v): 50 | if v is None: 51 | return v 52 | elif type(v) == bool: 53 | return v 54 | elif type(v) == str: 55 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 56 | return True 57 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 58 | return False 59 | raise argparse.ArgumentTypeError('Boolean value expected.') 60 | 61 | def create_n_hot(idxs, N): 62 | out = np.zeros(N) 63 | for i in idxs: 64 | out[i] += 1 65 | return torch.Tensor(out/out.sum()) 66 | 67 | -------------------------------------------------------------------------------- /vqa/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .noatt import MLBNoAtt, MutanNoAtt 2 | from .att import MLBAtt, MutanAtt 3 | from .utils import 
factory 4 | from .utils import model_names -------------------------------------------------------------------------------- /vqa/models/att.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import copy 6 | 7 | from vqa.lib import utils 8 | from vqa.models import seq2vec 9 | from vqa.models import fusion 10 | 11 | class AbstractAtt(nn.Module): 12 | 13 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]): 14 | super(AbstractAtt, self).__init__() 15 | self.opt = opt 16 | self.vocab_words = vocab_words 17 | self.vocab_answers = vocab_answers 18 | self.num_classes = len(self.vocab_answers) 19 | # Modules 20 | self.seq2vec = seq2vec.factory(self.vocab_words, self.opt['seq2vec']) 21 | # Modules for attention 22 | self.conv_v_att = nn.Conv2d(self.opt['dim_v'], 23 | self.opt['attention']['dim_v'], 1, 1) 24 | self.linear_q_att = nn.Linear(self.opt['dim_q'], 25 | self.opt['attention']['dim_q']) 26 | self.conv_att = nn.Conv2d(self.opt['attention']['dim_mm'], 27 | self.opt['attention']['nb_glimpses'], 1, 1) 28 | # Modules for classification 29 | self.list_linear_v_fusion = None 30 | self.linear_q_fusion = None 31 | self.linear_classif = None 32 | 33 | def _fusion_att(self, x_v, x_q): 34 | raise NotImplementedError 35 | 36 | def _fusion_classif(self, x_v, x_q): 37 | raise NotImplementedError 38 | 39 | def _attention(self, input_v, x_q_vec): 40 | batch_size = input_v.size(0) 41 | width = input_v.size(2) 42 | height = input_v.size(3) 43 | 44 | # Process visual before fusion 45 | #x_v = input_v.view(batch_size*width*height, dim_features) 46 | x_v = input_v 47 | x_v = F.dropout(x_v, 48 | p=self.opt['attention']['dropout_v'], 49 | training=self.training) 50 | x_v = self.conv_v_att(x_v) 51 | if 'activation_v' in self.opt['attention']: 52 | x_v = getattr(F, self.opt['attention']['activation_v'])(x_v) 53 | x_v = x_v.view(batch_size, 54 | self.opt['attention']['dim_v'], 55 | width * height) 56 | x_v = x_v.transpose(1,2) 57 | 58 | # Process question before fusion 59 | x_q = F.dropout(x_q_vec, p=self.opt['attention']['dropout_q'], 60 | training=self.training) 61 | x_q = self.linear_q_att(x_q) 62 | if 'activation_q' in self.opt['attention']: 63 | x_q = getattr(F, self.opt['attention']['activation_q'])(x_q) 64 | x_q = x_q.view(batch_size, 65 | 1, 66 | self.opt['attention']['dim_q']) 67 | x_q = x_q.expand(batch_size, 68 | width * height, 69 | self.opt['attention']['dim_q']) 70 | 71 | # First multimodal fusion 72 | x_att = self._fusion_att(x_v, x_q) 73 | 74 | if 'activation_mm' in self.opt['attention']: 75 | x_att = getattr(F, self.opt['attention']['activation_mm'])(x_att) 76 | 77 | # Process attention vectors 78 | x_att = F.dropout(x_att, 79 | p=self.opt['attention']['dropout_mm'], 80 | training=self.training) 81 | # can be optim to avoid two views and transposes 82 | x_att = x_att.view(batch_size, 83 | width, 84 | height, 85 | self.opt['attention']['dim_mm']) 86 | x_att = x_att.transpose(2,3).transpose(1,2) 87 | x_att = self.conv_att(x_att) 88 | x_att = x_att.view(batch_size, 89 | self.opt['attention']['nb_glimpses'], 90 | width * height) 91 | list_att_split = torch.split(x_att, 1, dim=1) 92 | list_att = [] 93 | for x_att in list_att_split: 94 | x_att = x_att.contiguous() 95 | x_att = x_att.view(batch_size, width*height) 96 | x_att = F.softmax(x_att) 97 | list_att.append(x_att) 98 | 99 | self.list_att = [x_att.data for x_att in list_att] 100 | 101 | # Apply attention vectors to input_v 
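        # Each glimpse in list_att is a softmax over the width*height spatial
        # positions; below, input_v is reshaped so that every glimpse yields one
        # attention-weighted sum of the dim_v feature vectors.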
102 | x_v = input_v.view(batch_size, self.opt['dim_v'], width * height) 103 | x_v = x_v.transpose(1,2) 104 | 105 | list_v_att = [] 106 | for i, x_att in enumerate(list_att): 107 | x_att = x_att.view(batch_size, 108 | width * height, 109 | 1) 110 | x_att = x_att.expand(batch_size, 111 | width * height, 112 | self.opt['dim_v']) 113 | x_v_att = torch.mul(x_att, x_v) 114 | x_v_att = x_v_att.sum(1) 115 | x_v_att = x_v_att.view(batch_size, self.opt['dim_v']) 116 | list_v_att.append(x_v_att) 117 | 118 | return list_v_att 119 | 120 | def _fusion_glimpses(self, list_v_att, x_q_vec): 121 | # Process visual for each glimpses 122 | list_v = [] 123 | for glimpse_id, x_v_att in enumerate(list_v_att): 124 | x_v = F.dropout(x_v_att, 125 | p=self.opt['fusion']['dropout_v'], 126 | training=self.training) 127 | x_v = self.list_linear_v_fusion[glimpse_id](x_v) 128 | if 'activation_v' in self.opt['fusion']: 129 | x_v = getattr(F, self.opt['fusion']['activation_v'])(x_v) 130 | list_v.append(x_v) 131 | x_v = torch.cat(list_v, 1) 132 | 133 | # Process question 134 | x_q = F.dropout(x_q_vec, 135 | p=self.opt['fusion']['dropout_q'], 136 | training=self.training) 137 | x_q = self.linear_q_fusion(x_q) 138 | if 'activation_q' in self.opt['fusion']: 139 | x_q = getattr(F, self.opt['fusion']['activation_q'])(x_q) 140 | 141 | # Second multimodal fusion 142 | x = self._fusion_classif(x_v, x_q) 143 | return x 144 | 145 | def _classif(self, x): 146 | 147 | if 'activation' in self.opt['classif']: 148 | x = getattr(F, self.opt['classif']['activation'])(x) 149 | x = F.dropout(x, 150 | p=self.opt['classif']['dropout'], 151 | training=self.training) 152 | x = self.linear_classif(x) 153 | return x 154 | 155 | def forward(self, input_v, input_q): 156 | if input_v.dim() != 4 and input_q.dim() != 2: 157 | raise ValueError 158 | 159 | x_q_vec = self.seq2vec(input_q) 160 | list_v_att = self._attention(input_v, x_q_vec) 161 | x = self._fusion_glimpses(list_v_att, x_q_vec) 162 | x = self._classif(x) 163 | return x 164 | 165 | 166 | class MLBAtt(AbstractAtt): 167 | 168 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]): 169 | # TODO: deep copy ? 170 | opt['attention']['dim_v'] = opt['attention']['dim_h'] 171 | opt['attention']['dim_q'] = opt['attention']['dim_h'] 172 | opt['attention']['dim_mm'] = opt['attention']['dim_h'] 173 | super(MLBAtt, self).__init__(opt, vocab_words, vocab_answers) 174 | # Modules for classification 175 | self.list_linear_v_fusion = nn.ModuleList([ 176 | nn.Linear(self.opt['dim_v'], 177 | self.opt['fusion']['dim_h']) 178 | for i in range(self.opt['attention']['nb_glimpses'])]) 179 | self.linear_q_fusion = nn.Linear(self.opt['dim_q'], 180 | self.opt['fusion']['dim_h'] 181 | * self.opt['attention']['nb_glimpses']) 182 | self.linear_classif = nn.Linear(self.opt['fusion']['dim_h'] 183 | * self.opt['attention']['nb_glimpses'], 184 | self.num_classes) 185 | 186 | def _fusion_att(self, x_v, x_q): 187 | x_att = torch.mul(x_v, x_q) 188 | return x_att 189 | 190 | def _fusion_classif(self, x_v, x_q): 191 | x_mm = torch.mul(x_v, x_q) 192 | return x_mm 193 | 194 | 195 | class MutanAtt(AbstractAtt): 196 | 197 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]): 198 | # TODO: deep copy ? 
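        # dim_v/dim_q of the shared attention layers are mapped onto the Mutan
        # hidden sizes below, so MutanFusion2d can be built with its own
        # visual/question embeddings disabled.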
199 | opt['attention']['dim_v'] = opt['attention']['dim_hv'] 200 | opt['attention']['dim_q'] = opt['attention']['dim_hq'] 201 | super(MutanAtt, self).__init__(opt, vocab_words, vocab_answers) 202 | # Modules for classification 203 | self.fusion_att = fusion.MutanFusion2d(self.opt['attention'], 204 | visual_embedding=False, 205 | question_embedding=False) 206 | self.list_linear_v_fusion = nn.ModuleList([ 207 | nn.Linear(self.opt['dim_v'], 208 | int(self.opt['fusion']['dim_hv'] 209 | / opt['attention']['nb_glimpses'])) 210 | for i in range(self.opt['attention']['nb_glimpses'])]) 211 | self.linear_q_fusion = nn.Linear(self.opt['dim_q'], 212 | self.opt['fusion']['dim_hq']) 213 | self.linear_classif = nn.Linear(self.opt['fusion']['dim_mm'], 214 | self.num_classes) 215 | self.fusion_classif = fusion.MutanFusion(self.opt['fusion'], 216 | visual_embedding=False, 217 | question_embedding=False) 218 | 219 | def _fusion_att(self, x_v, x_q): 220 | return self.fusion_att(x_v, x_q) 221 | 222 | def _fusion_classif(self, x_v, x_q): 223 | return self.fusion_classif(x_v, x_q) 224 | -------------------------------------------------------------------------------- /vqa/models/convnets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as pytorch_models 5 | import sys 6 | sys.path.append('vqa/external/pretrained-models.pytorch') 7 | import pretrainedmodels as torch7_models 8 | 9 | pytorch_resnet_names = sorted(name for name in pytorch_models.__dict__ 10 | if name.islower() 11 | and name.startswith("resnet") 12 | and callable(pytorch_models.__dict__[name])) 13 | 14 | torch7_resnet_names = sorted(name for name in torch7_models.__dict__ 15 | if name.islower() 16 | and callable(torch7_models.__dict__[name])) 17 | 18 | model_names = pytorch_resnet_names + torch7_resnet_names 19 | 20 | def factory(opt, cuda=True, data_parallel=True): 21 | opt = copy.copy(opt) 22 | 23 | class WrapperModule(nn.Module): 24 | def __init__(self, net, forward_fn): 25 | super(WrapperModule, self).__init__() 26 | self.net = net 27 | self.forward_fn = forward_fn 28 | 29 | def forward(self, x): 30 | return self.forward_fn(self.net, x) 31 | 32 | def __getattr__(self, attr): 33 | try: 34 | return super(WrapperModule, self).__getattr__(attr) 35 | except AttributeError: 36 | return getattr(self.net, attr) 37 | 38 | def forward_resnet(self, x): 39 | x = self.conv1(x) 40 | x = self.bn1(x) 41 | x = self.relu(x) 42 | x = self.maxpool(x) 43 | x = self.layer1(x) 44 | x = self.layer2(x) 45 | x = self.layer3(x) 46 | x = self.layer4(x) 47 | 48 | if 'pooling' in opt and opt['pooling']: 49 | x = self.avgpool(x) 50 | div = x.size(3) + x.size(2) 51 | x = x.sum(3) 52 | x = x.sum(2) 53 | x = x.view(x.size(0), -1) 54 | x = x.div(div) 55 | 56 | return x 57 | 58 | def forward_resnext(self, x): 59 | x = self.features(x) 60 | 61 | if 'pooling' in opt and opt['pooling']: 62 | x = self.avgpool(x) 63 | div = x.size(3) + x.size(2) 64 | x = x.sum(3) 65 | x = x.sum(2) 66 | x = x.view(x.size(0), -1) 67 | x = x.div(div) 68 | 69 | return x 70 | 71 | if opt['arch'] in pytorch_resnet_names: 72 | model = pytorch_models.__dict__[opt['arch']](pretrained=True) 73 | 74 | model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping 75 | 76 | elif opt['arch'] == 'fbresnet152': 77 | model = torch7_models.__dict__[opt['arch']](num_classes=1000, 78 | pretrained='imagenet') 79 | 80 | model = WrapperModule(model, forward_resnet) # ugly hack 
in case of DataParallel wrapping 81 | 82 | elif opt['arch'] in torch7_resnet_names: 83 | model = torch7_models.__dict__[opt['arch']](num_classes=1000, 84 | pretrained='imagenet') 85 | 86 | model = WrapperModule(model, forward_resnext) # ugly hack in case of DataParallel wrapping 87 | 88 | else: 89 | raise ValueError 90 | 91 | if data_parallel: 92 | model = nn.DataParallel(model).cuda() 93 | if not cuda: 94 | raise ValueError 95 | 96 | if cuda: 97 | model.cuda() 98 | 99 | return model 100 | -------------------------------------------------------------------------------- /vqa/models/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class AbstractFusion(nn.Module): 7 | 8 | def __init__(self, opt={}): 9 | super(AbstractFusion, self).__init__() 10 | self.opt = opt 11 | 12 | def forward(self, input_v, input_q): 13 | raise NotImplementedError 14 | 15 | 16 | class MLBFusion(AbstractFusion): 17 | 18 | def __init__(self, opt): 19 | super(MLBFusion, self).__init__(opt) 20 | # Modules 21 | if 'dim_v' in self.opt: 22 | self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_h']) 23 | else: 24 | print('Warning fusion.py: no visual embedding before fusion') 25 | 26 | if 'dim_q' in self.opt: 27 | self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_h']) 28 | else: 29 | print('Warning fusion.py: no question embedding before fusion') 30 | 31 | def forward(self, input_v, input_q): 32 | # visual (cnn features) 33 | if 'dim_v' in self.opt: 34 | x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training) 35 | x_v = self.linear_v(x_v) 36 | if 'activation_v' in self.opt: 37 | x_v = getattr(F, self.opt['activation_v'])(x_v) 38 | else: 39 | x_v = input_v 40 | # question (rnn features) 41 | if 'dim_q' in self.opt: 42 | x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training) 43 | x_q = self.linear_q(x_q) 44 | if 'activation_q' in self.opt: 45 | x_q = getattr(F, self.opt['activation_q'])(x_q) 46 | else: 47 | x_q = input_q 48 | # hadamard product 49 | x_mm = torch.mul(x_q, x_v) 50 | return x_mm 51 | 52 | 53 | class MutanFusion(AbstractFusion): 54 | 55 | def __init__(self, opt, visual_embedding=True, question_embedding=True): 56 | super(MutanFusion, self).__init__(opt) 57 | self.visual_embedding = visual_embedding 58 | self.question_embedding = question_embedding 59 | # Modules 60 | if self.visual_embedding: 61 | self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_hv']) 62 | else: 63 | print('Warning fusion.py: no visual embedding before fusion') 64 | 65 | if self.question_embedding: 66 | self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_hq']) 67 | else: 68 | print('Warning fusion.py: no question embedding before fusion') 69 | 70 | self.list_linear_hv = nn.ModuleList([ 71 | nn.Linear(self.opt['dim_hv'], self.opt['dim_mm']) 72 | for i in range(self.opt['R'])]) 73 | 74 | self.list_linear_hq = nn.ModuleList([ 75 | nn.Linear(self.opt['dim_hq'], self.opt['dim_mm']) 76 | for i in range(self.opt['R'])]) 77 | 78 | def forward(self, input_v, input_q): 79 | if input_v.dim() != input_q.dim() and input_v.dim() != 2: 80 | raise ValueError 81 | batch_size = input_v.size(0) 82 | 83 | if self.visual_embedding: 84 | x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training) 85 | x_v = self.linear_v(x_v) 86 | if 'activation_v' in self.opt: 87 | x_v = getattr(F, self.opt['activation_v'])(x_v) 88 | 
else: 89 | x_v = input_v 90 | 91 | if self.question_embedding: 92 | x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training) 93 | x_q = self.linear_q(x_q) 94 | if 'activation_q' in self.opt: 95 | x_q = getattr(F, self.opt['activation_q'])(x_q) 96 | else: 97 | x_q = input_q 98 | 99 | x_mm = [] 100 | for i in range(self.opt['R']): 101 | 102 | x_hv = F.dropout(x_v, p=self.opt['dropout_hv'], training=self.training) 103 | x_hv = self.list_linear_hv[i](x_hv) 104 | if 'activation_hv' in self.opt: 105 | x_hv = getattr(F, self.opt['activation_hv'])(x_hv) 106 | 107 | x_hq = F.dropout(x_q, p=self.opt['dropout_hq'], training=self.training) 108 | x_hq = self.list_linear_hq[i](x_hq) 109 | if 'activation_hq' in self.opt: 110 | x_hq = getattr(F, self.opt['activation_hq'])(x_hq) 111 | 112 | x_mm.append(torch.mul(x_hq, x_hv)) 113 | 114 | x_mm = torch.stack(x_mm, dim=1) 115 | x_mm = x_mm.sum(1).view(batch_size, self.opt['dim_mm']) 116 | 117 | if 'activation_mm' in self.opt: 118 | x_mm = getattr(F, self.opt['activation_mm'])(x_mm) 119 | 120 | return x_mm 121 | 122 | 123 | class MutanFusion2d(MutanFusion): 124 | 125 | def __init__(self, opt, visual_embedding=True, question_embedding=True): 126 | super(MutanFusion2d, self).__init__(opt, 127 | visual_embedding, 128 | question_embedding) 129 | 130 | def forward(self, input_v, input_q): 131 | if input_v.dim() != input_q.dim() and input_v.dim() != 3: 132 | raise ValueError 133 | batch_size = input_v.size(0) 134 | weight_height = input_v.size(1) 135 | dim_hv = input_v.size(2) 136 | dim_hq = input_q.size(2) 137 | if not input_v.is_contiguous(): 138 | input_v = input_v.contiguous() 139 | if not input_q.is_contiguous(): 140 | input_q = input_q.contiguous() 141 | x_v = input_v.view(batch_size * weight_height, self.opt['dim_hv']) 142 | x_q = input_q.view(batch_size * weight_height, self.opt['dim_hq']) 143 | x_mm = super().forward(x_v, x_q) 144 | x_mm = x_mm.view(batch_size, weight_height, self.opt['dim_mm']) 145 | return x_mm 146 | 147 | -------------------------------------------------------------------------------- /vqa/models/noatt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from vqa.lib import utils 6 | from vqa.models import fusion 7 | from vqa.models import seq2vec 8 | 9 | class AbstractNoAtt(nn.Module): 10 | 11 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]): 12 | super(AbstractNoAtt, self).__init__() 13 | self.opt = opt 14 | self.vocab_words = vocab_words 15 | self.vocab_answers = vocab_answers 16 | self.num_classes = len(self.vocab_answers) 17 | # Modules 18 | self.seq2vec = seq2vec.factory(self.vocab_words, self.opt['seq2vec']) 19 | self.linear_classif = nn.Linear(self.opt['fusion']['dim_h'], self.num_classes) 20 | 21 | def _fusion(self, input_v, input_q): 22 | raise NotImplementedError 23 | 24 | def _classif(self, x): 25 | if 'activation' in self.opt['classif']: 26 | x = getattr(F, self.opt['classif']['activation'])(x) 27 | x = F.dropout(x, p=self.opt['classif']['dropout'], training=self.training) 28 | x = self.linear_classif(x) 29 | return x 30 | 31 | def forward(self, input_v, input_q): 32 | x_q = self.seq2vec(input_q) 33 | x = self._fusion(input_v, x_q) 34 | x = self._classif(x) 35 | return x 36 | 37 | 38 | class MLBNoAtt(AbstractNoAtt): 39 | 40 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]): 41 | super(MLBNoAtt, self).__init__(opt, vocab_words, vocab_answers) 42 | self.fusion = 
fusion.MLBFusion(self.opt['fusion']) 43 | 44 | def _fusion(self, input_v, input_q): 45 | x = self.fusion(input_v, input_q) 46 | return x 47 | 48 | 49 | class MutanNoAtt(AbstractNoAtt): 50 | 51 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]): 52 | opt['fusion']['dim_h'] = opt['fusion']['dim_mm'] 53 | super(MutanNoAtt, self).__init__(opt, vocab_words, vocab_answers) 54 | self.fusion = fusion.MutanFusion(self.opt['fusion']) 55 | 56 | def _fusion(self, input_v, input_q): 57 | x = self.fusion(input_v, input_q) 58 | return x 59 | 60 | -------------------------------------------------------------------------------- /vqa/models/seq2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | import sys 7 | sys.path.append('vqa/external/skip-thoughts.torch/pytorch') 8 | import skipthoughts 9 | 10 | 11 | def process_lengths(input): 12 | max_length = input.size(1) 13 | lengths = list(max_length - input.data.eq(0).sum(1).squeeze()) 14 | return lengths 15 | 16 | def select_last(x, lengths): 17 | batch_size = x.size(0) 18 | seq_length = x.size(1) 19 | mask = x.data.new().resize_as_(x.data).fill_(0) 20 | for i in range(batch_size): 21 | mask[i][lengths[i]-1].fill_(1) 22 | mask = Variable(mask) 23 | x = x.mul(mask) 24 | x = x.sum(1).view(batch_size, x.size(2)) 25 | return x 26 | 27 | class LSTM(nn.Module): 28 | 29 | def __init__(self, vocab, emb_size, hidden_size, num_layers): 30 | super(LSTM, self).__init__() 31 | self.vocab = vocab 32 | self.emb_size = emb_size 33 | self.hidden_size = hidden_size 34 | self.num_layers = num_layers 35 | self.embedding = nn.Embedding(num_embeddings=len(self.vocab)+1, 36 | embedding_dim=emb_size, 37 | padding_idx=0) 38 | self.rnn = nn.LSTM(input_size=emb_size, hidden_size=hidden_size, num_layers=num_layers) 39 | 40 | def forward(self, input): 41 | lengths = process_lengths(input) 42 | x = self.embedding(input) # seq2seq 43 | output, hn = self.rnn(x) 44 | output = select_last(output, lengths) 45 | return output 46 | 47 | 48 | class TwoLSTM(nn.Module): 49 | 50 | def __init__(self, vocab, emb_size, hidden_size): 51 | super(TwoLSTM, self).__init__() 52 | self.vocab = vocab 53 | self.emb_size = emb_size 54 | self.hidden_size = hidden_size 55 | self.embedding = nn.Embedding(num_embeddings=len(self.vocab)+1, 56 | embedding_dim=emb_size, 57 | padding_idx=0) 58 | self.rnn_0 = nn.LSTM(input_size=emb_size, hidden_size=hidden_size, num_layers=1) 59 | self.rnn_1 = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=1) 60 | 61 | def forward(self, input): 62 | lengths = process_lengths(input) 63 | x = self.embedding(input) # seq2seq 64 | x = getattr(F, 'tanh')(x) 65 | x_0, hn = self.rnn_0(x) 66 | vec_0 = select_last(x_0, lengths) 67 | 68 | # x_1 = F.dropout(x_0, p=0.3, training=self.training) 69 | # print(x_1.size()) 70 | x_1, hn = self.rnn_1(x_0) 71 | vec_1 = select_last(x_1, lengths) 72 | 73 | vec_0 = F.dropout(vec_0, p=0.3, training=self.training) 74 | vec_1 = F.dropout(vec_1, p=0.3, training=self.training) 75 | output = torch.cat((vec_0, vec_1), 1) 76 | return output 77 | 78 | 79 | def factory(vocab_words, opt): 80 | if opt['arch'] == 'skipthoughts': 81 | st_class = getattr(skipthoughts, opt['type']) 82 | seq2vec = st_class(opt['dir_st'], 83 | vocab_words, 84 | dropout=opt['dropout'], 85 | fixed_emb=opt['fixed_emb']) 86 | elif opt['arch'] == '2-lstm': 87 | seq2vec = TwoLSTM(vocab_words, 88 | 
opt['emb_size'], 89 | opt['hidden_size']) 90 | elif opt['arch'] == 'lstm': 91 | seq2vec = LSTM(vocab_words, 92 | opt['emb_size'], 93 | opt['hidden_size'], 94 | opt['num_layers']) 95 | else: 96 | raise NotImplementedError 97 | return seq2vec 98 | 99 | 100 | if __name__ == '__main__': 101 | 102 | vocab = ['robots', 'are', 'very', 'cool', '', 'BiDiBu'] 103 | lstm = TwoLSTM(vocab, 300, 1024) 104 | 105 | input = Variable(torch.LongTensor([ 106 | [1,2,3,4,5,0,0], 107 | [6,1,2,3,3,4,5], 108 | [6,1,2,3,3,4,5] 109 | ])) 110 | output = lstm(input) 111 | -------------------------------------------------------------------------------- /vqa/models/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | 7 | from .noatt import MLBNoAtt, MutanNoAtt 8 | from .att import MLBAtt, MutanAtt 9 | 10 | model_names = sorted(name for name in sys.modules[__name__].__dict__ 11 | if not name.startswith("__"))# and 'Att' in name) 12 | 13 | def factory(opt, vocab_words, vocab_answers, cuda=True, data_parallel=True): 14 | opt = copy.copy(opt) 15 | 16 | if opt['arch'] in model_names: 17 | model = getattr(sys.modules[__name__], opt['arch'])(opt, vocab_words, vocab_answers) 18 | else: 19 | raise ValueError 20 | 21 | if data_parallel: 22 | model = nn.DataParallel(model).cuda() 23 | if not cuda: 24 | raise ValueError 25 | 26 | if cuda: 27 | model.cuda() 28 | 29 | return model --------------------------------------------------------------------------------
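
Usage note (not part of the repository): a minimal sketch of how the model factory above can be driven end to end. The option values, vocabularies and random visual features are illustrative placeholders, not the hyperparameters shipped in the YAML files under options/, and the snippet assumes it is run from the repo root with the git submodules initialized (seq2vec.py imports skipthoughts from vqa/external).

import torch
from torch.autograd import Variable
import vqa.models as models

# Illustrative options only; the real values live in the YAML files under options/.
opt = {
    'arch': 'MutanNoAtt',
    'seq2vec': {'arch': '2-lstm', 'emb_size': 300, 'hidden_size': 1024},
    'fusion': {'dim_v': 2048,                  # e.g. pooled ResNet-152 features
               'dim_q': 2048,                  # TwoLSTM output is 2 * hidden_size
               'dim_hv': 310, 'dim_hq': 310, 'dim_mm': 510, 'R': 5,
               'dropout_v': 0.5, 'dropout_q': 0.5,
               'dropout_hv': 0, 'dropout_hq': 0},
    'classif': {'dropout': 0.5},
}
vocab_words = ['is', 'there', 'a', 'cat']      # toy vocabularies
vocab_answers = ['yes', 'no']

model = models.factory(opt, vocab_words, vocab_answers,
                       cuda=False, data_parallel=False)

input_v = Variable(torch.randn(2, 2048))       # pooled visual features (no-attention setting)
input_q = Variable(torch.LongTensor([[1, 2, 3, 4],
                                     [1, 2, 3, 0]]))   # 0 is the padding index
scores = model(input_v, input_q)               # (2, len(vocab_answers)) answer scores

The attention variants (MLBAtt, MutanAtt) are built the same way but expect a 4-d visual input (batch, dim_v, width, height) and the corresponding 'attention' options.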