├── .gitignore
├── .gitmodules
├── README.md
├── data
│   └── .keep
├── demo_server.py
├── demo_web
│   ├── css
│   │   ├── bootstrap-theme.css
│   │   ├── bootstrap-theme.css.map
│   │   ├── bootstrap-theme.min.css
│   │   ├── bootstrap-theme.min.css.map
│   │   ├── bootstrap.css
│   │   ├── bootstrap.css.map
│   │   ├── bootstrap.min.css
│   │   ├── bootstrap.min.css.map
│   │   └── custom.css
│   ├── fonts
│   │   ├── glyphicons-halflings-regular.eot
│   │   ├── glyphicons-halflings-regular.svg
│   │   ├── glyphicons-halflings-regular.ttf
│   │   ├── glyphicons-halflings-regular.woff
│   │   └── glyphicons-halflings-regular.woff2
│   ├── images
│   │   └── logos.png
│   ├── index.html
│   └── js
│       ├── bootstrap.js
│       ├── bootstrap.min.js
│       ├── custom.js
│       └── npm.js
├── doc
│   ├── mutan.png
│   ├── mutan_noatt.html
│   ├── mutan_noatt.png
│   ├── mutan_noatt_vs_att.html
│   ├── mutan_noatt_vs_att.png
│   └── vqa_task.png
├── eval_res.py
├── extract.py
├── logs
│   └── .keep
├── options
│   ├── vqa
│   │   ├── default.yaml
│   │   ├── mlb_att_trainval.yaml
│   │   ├── mlb_noatt_train.yaml
│   │   ├── mutan_att_trainval.yaml
│   │   └── mutan_noatt_train.yaml
│   └── vqa2
│       ├── default.yaml
│       ├── mlb_att_trainval.yaml
│       ├── mlb_noatt_train.yaml
│       ├── mutan_att_train.yaml
│       ├── mutan_att_trainval.yaml
│       ├── mutan_att_trainval_vg.yaml
│       └── mutan_noatt_train.yaml
├── requirements.txt
├── train.py
├── visu.ipynb
├── visu.py
└── vqa
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── coco.py
    │   ├── features.py
    │   ├── images.py
    │   ├── utils.py
    │   ├── vgenome.py
    │   ├── vgenome_interim.py
    │   ├── vgenome_processed.py
    │   ├── vqa.py
    │   ├── vqa2_interim.py
    │   ├── vqa_interim.py
    │   └── vqa_processed.py
    ├── lib
    │   ├── __init__.py
    │   ├── criterions.py
    │   ├── dataloader.py
    │   ├── engine.py
    │   ├── logger.py
    │   ├── sampler.py
    │   └── utils.py
    └── models
        ├── __init__.py
        ├── att.py
        ├── convnets.py
        ├── fusion.py
        ├── noatt.py
        ├── seq2vec.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/coco
2 | data/porting
3 | data/skip-thoughts
4 | data/vqa
5 | data/vqa2
6 |
7 | logs/porting
8 | logs/vqa_old
9 | logs/vqa
10 | logs/vqa2
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # Mac
18 | .DS_Store
19 | ._.DS_Store
20 |
21 | .ipynb_checkpoints
22 |
23 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "vqa/external/VQA"]
2 | path = vqa/external/VQA
3 | url = https://github.com/Cadene/VQA.git
4 | [submodule "vqa/external/skip-thoughts.torch"]
5 | path = vqa/external/skip-thoughts.torch
6 | url = https://github.com/Cadene/skip-thoughts.torch.git
7 | [submodule "vqa/external/pretrained-models.pytorch"]
8 | path = vqa/external/pretrained-models.pytorch
9 | url = https://github.com/Cadene/pretrained-models.pytorch.git
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Visual Question Answering in pytorch
2 |
 3 | **/!\ A new PyTorch version of this VQA code is available here:** https://github.com/Cadene/block.bootstrap.pytorch
4 |
 5 | This repo was made by [Remi Cadene](http://remicadene.com) (LIP6) and [Hedi Ben-Younes](https://twitter.com/labegne) (LIP6-Heuritech), two PhD Students working on VQA at [UPMC-LIP6](http://lip6.fr), and their professors [Matthieu Cord](http://webia.lip6.fr/~cord) (LIP6) and [Nicolas Thome](http://webia.lip6.fr/~thomen) (LIP6-CNAM). We developed this code as part of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676) which is (as far as we know) the current state-of-the-art on the [VQA 1.0 dataset](http://visualqa.org).
6 |
 7 | The goal of this repo is twofold:
8 | - to make it easier to reproduce our results,
9 | - to provide an efficient and modular code base to the community for further research on other VQA datasets.
10 |
 11 | If you have any questions about our code or model, don't hesitate to contact us or to submit any issues. Pull requests are welcome!
12 |
13 | #### News:
14 |
15 | - 16th january 2018: a pretrained vqa2 model and web demo
16 | - 18th july 2017: VQA2, VisualGenome, FBResnet152 (for pytorch) added [v2.0 commit msg](https://github.com/Cadene/vqa.pytorch/commit/42391fd4a39c31e539eb6cb73ecd370bac0f010a)
17 | - 16th july 2017: paper accepted at ICCV2017
18 | - 30th may 2017: poster accepted at CVPR2017 (VQA Workshop)
19 |
20 | #### Summary:
21 |
22 | * [Introduction](#introduction)
23 | * [What is the task about?](#what-is-the-task-about)
24 | * [Quick insight about our method](#quick-insight-about-our-method)
25 | * [Installation](#installation)
26 | * [Requirements](#requirements)
27 | * [Submodules](#submodules)
28 | * [Data](#data)
29 | * [Reproducing results on VQA 1.0](#reproducing-results-on-vqa-10)
30 | * [Features](#features)
31 | * [Pretrained models](#pretrained-models)
32 | * [Reproducing results on VQA 2.0](#reproducing-results-on-vqa-20)
33 | * [Features](#features-20)
34 | * [Pretrained models](#pretrained-models-20)
35 | * [Documentation](#documentation)
36 | * [Architecture](#architecture)
37 | * [Options](#options)
38 | * [Datasets](#datasets)
39 | * [Models](#models)
40 | * [Quick examples](#quick-examples)
41 | * [Extract features from COCO](#extract-features-from-coco)
42 | * [Extract features from VisualGenome](#extract-features-from-visualgenome)
43 | * [Train models on VQA 1.0](#train-models-on-vqa-10)
44 | * [Train models on VQA 2.0](#train-models-on-vqa-20)
45 | * [Train models on VQA + VisualGenome](#train-models-on-vqa-10-or-20--visualgenome)
46 | * [Monitor training](#monitor-training)
47 | * [Restart training](#restart-training)
48 | * [Evaluate models on VQA](#evaluate-models-on-vqa)
49 | * [Web demo](#web-demo)
50 | * [Citation](#citation)
51 | * [Acknowledgment](#acknowledgment)
52 |
53 | ## Introduction
54 |
55 | ### What is the task about?
56 |
 57 | The task is about training models in an end-to-end fashion on a multimodal dataset made of triplets:
58 |
59 | - an **image** with no other information than the raw pixels,
60 | - a **question** about visual content(s) on the associated image,
61 | - a short **answer** to the question (one or a few words).
62 |
 63 | As you can see in the illustration below, two different triplets (sharing the same image) from the VQA dataset are represented. The models need to learn rich multimodal representations to be able to give the right answers.
64 |
65 |
 66 | [image: doc/vqa_task.png]
67 |
68 |
 69 | The VQA task is still an active area of research. However, once it is solved, it could be very useful to improve human-to-machine interfaces (especially for visually impaired people).
70 |
71 | ### Quick insight about our method
72 |
 73 | The VQA community developed an approach based on four learnable components:
74 |
 75 | - a question model which can be an LSTM, a GRU, or pretrained Skipthoughts,
76 | - an image model which can be a pretrained VGG16 or ResNet-152,
77 | - a fusion scheme which can be an element-wise sum, concatenation, [MCB](https://arxiv.org/abs/1606.01847), [MLB](https://arxiv.org/abs/1610.04325), or [Mutan](https://arxiv.org/abs/1705.06676),
78 | - optionally, an attention scheme which may have several "glimpses".
79 |
80 |
 81 | [image: doc/mutan.png]
82 |
83 |
 84 | One of our claims is that the multimodal fusion between the image and the question representations is a critical component. Thus, our proposed model uses a Tucker Decomposition of the correlation tensor to model richer multimodal interactions in order to provide proper answers (a minimal fusion sketch is given after the list below). Our best model is based on:
85 |
86 | - a pretrained Skipthoughts for the question model,
87 | - features from a pretrained Resnet-152 (with images of size 3x448x448) for the image model,
88 | - our proposed Mutan (based on a Tucker Decomposition) for the fusion scheme,
89 | - an attention scheme with two "glimpses".
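
A minimal sketch of the Mutan fusion idea, for illustration only (this is not the repo's `vqa/models/fusion.py` implementation; the dimensions mirror `options/vqa/mutan_noatt_train.yaml`):

```python
import torch
import torch.nn as nn

class MutanFusionSketch(nn.Module):
    """Rank-R (Tucker-style) fusion of a question vector and an image vector."""
    def __init__(self, dim_q=2400, dim_v=2048, dim_hq=360, dim_hv=360,
                 dim_mm=360, R=10):
        super().__init__()
        self.linear_q = nn.Linear(dim_q, dim_hq)  # project the question
        self.linear_v = nn.Linear(dim_v, dim_hv)  # project the image
        # R pairs of projections whose element-wise products are summed
        self.list_q = nn.ModuleList([nn.Linear(dim_hq, dim_mm) for _ in range(R)])
        self.list_v = nn.ModuleList([nn.Linear(dim_hv, dim_mm) for _ in range(R)])

    def forward(self, q, v):
        q = torch.tanh(self.linear_q(q))
        v = torch.tanh(self.linear_v(v))
        # rank-R approximation of the full bilinear question/image interaction
        z = sum(wq(q) * wv(v) for wq, wv in zip(self.list_q, self.list_v))
        return torch.tanh(z)

fusion = MutanFusionSketch()
z = fusion(torch.randn(4, 2400), torch.randn(4, 2048))  # batch of 4
print(z.size())  # torch.Size([4, 360])
```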
90 |
91 | ## Installation
92 |
93 | ### Requirements
94 |
95 | First install python 3 (we don't provide support for python 2). We advise you to install python 3 and pytorch with Anaconda:
96 |
97 | - [python with anaconda](https://www.continuum.io/downloads)
98 | - [pytorch with CUDA](http://pytorch.org)
99 |
100 | ```
101 | conda create --name vqa python=3
102 | source activate vqa
103 | conda install pytorch torchvision cuda80 -c soumith
104 | ```
105 |
106 | Then clone the repo (with the `--recursive` flag for submodules) and install the complementary requirements:
107 |
108 | ```
109 | cd $HOME
110 | git clone --recursive https://github.com/Cadene/vqa.pytorch.git
111 | cd vqa.pytorch
112 | pip install -r requirements.txt
113 | ```
114 |
115 | ### Submodules
116 |
 117 | Our code has three external dependencies:
118 |
 119 | - [VQA](https://github.com/Cadene/VQA) is used to evaluate results files on the valset with the OpenEnded accuracy,
120 | - [skip-thoughts.torch](https://github.com/Cadene/skip-thoughts.torch) is used to import pretrained GRUs and embeddings,
121 | - [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) is used to load pretrained convnets.
122 |
123 | ### Data
124 |
 125 | Data will be automatically downloaded and preprocessed when needed. Links to data are stored in `vqa/datasets/vqa.py`, `vqa/datasets/coco.py` and `vqa/datasets/vgenome.py`.
126 |
127 |
128 | ## Reproducing results on VQA 1.0
129 |
130 | ### Features
131 |
 132 | As we first developed this project on Lua/Torch7, we used the features of [ResNet-152 pretrained with Torch7](https://github.com/facebook/fb.resnet.torch). We ported this pretrained resnet152 to pytorch in the v2.0 release. We will provide all the extracted features soon. Meanwhile, you can download the coco features as follows:
133 |
134 | ```
135 | mkdir -p data/coco/extract/arch,fbresnet152torch
136 | cd data/coco/extract/arch,fbresnet152torch
137 | wget https://data.lip6.fr/coco/trainset.hdf5
138 | wget https://data.lip6.fr/coco/trainset.txt
139 | wget https://data.lip6.fr/coco/valset.hdf5
140 | wget https://data.lip6.fr/coco/valset.txt
141 | wget https://data.lip6.fr/coco/testset.hdf5
142 | wget https://data.lip6.fr/coco/testset.txt
143 | ```
144 |
145 | /!\ There are currently 3 versions of ResNet152:
146 |
147 | - fbresnet152torch which is the torch7 model,
 148 | - fbresnet152 which is the port of the torch7 model to pytorch,
149 | - [resnet152](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) which is the pretrained model from torchvision (we've got lower results with it).
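
For a quick sanity check, here is a minimal sketch of instantiating one of these convnets through the repo's factory, mirroring `extract.py` (illustrative only; it assumes the repo is installed and the pretrained weights can be downloaded):

```python
import torch
from torch.autograd import Variable
import vqa.models.convnets as convnets

# 'fbresnet152' is the pytorch port; set cuda=True if a GPU is available
cnn = convnets.factory({'arch': 'fbresnet152'}, cuda=False, data_parallel=False)
cnn.eval()

x = Variable(torch.ones(1, 3, 448, 448), volatile=True)  # dummy 448x448 image
out = cnn(x)
print(out.size())  # expected: 1 x 2048 x 14 x 14 (the "att" feature map)
```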
150 |
151 | ### Pretrained VQA models
152 |
153 | We currently provide three models trained with our old Torch7 code and ported to Pytorch:
154 |
155 | - MutanNoAtt trained on the VQA 1.0 trainset,
156 | - MLBAtt trained on the VQA 1.0 trainvalset and VisualGenome,
157 | - MutanAtt trained on the VQA 1.0 trainvalset and VisualGenome.
158 |
159 | ```
160 | mkdir -p logs/vqa
161 | cd logs/vqa
162 | wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mutan_noatt_train.zip
163 | wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mlb_att_trainval.zip
164 | wget http://webia.lip6.fr/~cadene/Downloads/vqa.pytorch/logs/vqa/mutan_att_trainval.zip
165 | ```
166 |
 167 | Even though we provide the results files associated with our pretrained models, you can evaluate them once again on the valset, testset and testdevset using a single command:
168 |
169 | ```
170 | python train.py -e --path_opt options/vqa/mutan_noatt_train.yaml --resume ckpt
 171 | python train.py -e --path_opt options/vqa/mlb_att_trainval.yaml --resume ckpt
172 | python train.py -e --path_opt options/vqa/mutan_att_trainval.yaml --resume ckpt
173 | ```
174 |
175 | To obtain test and testdev results on VQA 1.0, you will need to zip your result json file (name it as `results.zip`) and to submit it on the [evaluation server](https://competitions.codalab.org/competitions/6961).
176 |
177 |
178 | ## Reproducing results on VQA 2.0
179 |
180 | ### Features 2.0
181 |
182 | You must download the coco dataset (and visual genome if needed) and then extract the features with a convolutional neural network.
183 |
184 | ### Pretrained VQA models 2.0
185 |
 186 | We currently provide two models trained with our current pytorch code on VQA 2.0:
187 |
188 | - MutanAtt trained on the trainset with the fbresnet152 features,
 189 | - MutanAtt trained on the trainvalset with the fbresnet152 features.
190 |
191 | ```
192 | cd $VQAPYTORCH
193 | mkdir -p logs/vqa2
194 | cd logs/vqa2
195 | wget http://data.lip6.fr/cadene/vqa.pytorch/vqa2/mutan_att_train.zip
196 | wget http://data.lip6.fr/cadene/vqa.pytorch/vqa2/mutan_att_trainval.zip
197 | ```
198 |
199 | ## Documentation
200 |
201 | ### Architecture
202 |
203 | ```
204 | .
205 | ├── options # default options dir containing yaml files
206 | ├── logs # experiments dir containing directories of logs (one by experiment)
207 | ├── data # datasets directories
208 | | ├── coco # images and features
209 | | ├── vqa # raw, interim and processed data
210 | | ├── vgenome # raw, interim, processed data + images and features
211 | | └── ...
212 | ├── vqa # vqa package dir
213 | | ├── datasets # datasets classes & functions dir (vqa, coco, vgenome, images, features, etc.)
214 | | ├── external # submodules dir (VQA, skip-thoughts.torch, pretrained-models.pytorch)
215 | | ├── lib # misc classes & func dir (engine, logger, dataloader, etc.)
 216 | |   └── models           # models classes & func dir (att, fusion, noatt, seq2vec, convnets)
217 | |
218 | ├── train.py # train & eval models
219 | ├── eval_res.py # eval results files with OpenEnded metric
220 | ├── extract.py # extract features from coco with CNNs
221 | └── visu.py # visualize logs and monitor training
222 | ```
223 |
224 | ### Options
225 |
 226 | There are three kinds of options:
227 |
228 | - options from the yaml options files stored in the `options` directory which are used as default (path to directory, logs, model, features, etc.)
229 | - options from the ArgumentParser in the `train.py` file which are set to None and can overwrite default options (learning rate, batch size, etc.)
230 | - options from the ArgumentParser in the `train.py` file which are set to default values (print frequency, number of threads, resume model, evaluate model, etc.)
231 |
 232 | You can easily add new options in your custom yaml file if needed. Also, if you want to grid search a parameter, you can add an ArgumentParser option and modify the dictionary in `train.py:L80`. A minimal sketch of how the yaml defaults and the CLI overrides interact is shown below.
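
For illustration, a minimal sketch of merging yaml defaults with CLI overrides (loosely inspired by the `utils.update_values` call used in `train.py`; this is not the repo's exact logic):

```python
import argparse
import yaml

# load the yaml defaults
with open('options/vqa/mutan_noatt_train.yaml') as handle:
    options = yaml.safe_load(handle)

# CLI options default to None so that only explicitly set values overwrite yaml
parser = argparse.ArgumentParser()
parser.add_argument('-lr', '--learning_rate', type=float)
parser.add_argument('-b', '--batch_size', type=int)
args = parser.parse_args()

if args.learning_rate is not None:
    options['optim']['lr'] = args.learning_rate
if args.batch_size is not None:
    options['optim']['batch_size'] = args.batch_size

print(options['optim'])  # {'lr': 0.0001, 'batch_size': 512, 'epochs': 100} by default
```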
233 |
234 | ### Datasets
235 |
236 | We currently provide four datasets:
237 |
 238 | - [COCOImages](http://mscoco.org/) currently used to extract features; it comes with three datasets: trainset, valset and testset
 239 | - [VisualGenomeImages]() currently used to extract features; it comes with one split: trainset
240 | - [VQA 1.0](http://www.visualqa.org/vqa_v1_download.html) comes with four datasets: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (concatenation of trainset and valset)
 241 | - [VQA 2.0](http://www.visualqa.org) same as VQA 1.0 but twice as big (with the same images as VQA 1.0)
242 |
243 | We plan to add:
244 |
245 | - [CLEVR](http://cs.stanford.edu/people/jcjohns/clevr/)
246 |
247 | ### Models
248 |
249 | We currently provide four models:
250 |
251 | - MLBNoAtt: a strong baseline (BayesianGRU + Element-wise product)
252 | - [MLBAtt](https://arxiv.org/abs/1610.04325): the previous state-of-the-art which adds an attention strategy
253 | - MutanNoAtt: our proof of concept (BayesianGRU + Mutan Fusion)
254 | - MutanAtt: the current state-of-the-art
255 |
 256 | We plan to add several other strategies in the future. The sketch below shows how a model can be instantiated from an options file.
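
For illustration, a minimal sketch of instantiating one of these models through the repo's factories, mirroring the calls in `demo_server.py` (not a documented API; the first call to `factory_VQA` downloads and processes the VQA data):

```python
import yaml
import vqa.datasets as datasets
import vqa.models as models

with open('options/vqa/mutan_noatt_train.yaml') as handle:
    options = yaml.safe_load(handle)

# build the trainset to get the question and answer vocabularies
trainset = datasets.factory_VQA(options['vqa']['trainsplit'], options['vqa'])

model = models.factory(options['model'],
                       trainset.vocab_words(),
                       trainset.vocab_answers(),
                       cuda=False,
                       data_parallel=False)
print(type(model).__name__)  # expected: a MutanNoAtt model
```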
257 |
258 | ## Quick examples
259 |
260 | ### Extract features from COCO
261 |
 262 | The needed images will be automatically downloaded to `dir_data` and the features will be extracted with a resnet152 by default.
263 |
 264 | There are three options for `mode`:
265 |
266 | - `att`: features will be of size 2048x14x14,
267 | - `noatt`: features will be of size 2048,
268 | - `both`: default option.
269 |
270 | Beware, you will need some space on your SSD:
271 |
272 | - 32GB for the images,
273 | - 125GB for the train features,
274 | - 123GB for the test features,
275 | - 61GB for the val features.
276 |
277 | ```
278 | python extract.py -h
279 | python extract.py --dir_data data/coco --data_split train
280 | python extract.py --dir_data data/coco --data_split val
281 | python extract.py --dir_data data/coco --data_split test
282 | ```
283 |
284 | Note: By default our code will share computations over all available GPUs. If you want to select only one or a few, use the following prefix:
285 |
286 | ```
287 | CUDA_VISIBLE_DEVICES=0 python extract.py
288 | CUDA_VISIBLE_DEVICES=1,2 python extract.py
289 | ```
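
To sanity-check the extracted features, here is a minimal sketch of reading them back (the path is an assumption: it follows the `arch,{arch}_size,{size}` naming used by `extract.py`, so adjust it to your own `--arch`, `--size` and `--dir_data`):

```python
import h5py

dir_extract = 'data/coco/extract/arch,fbresnet152_size,448'  # hypothetical path

# image names, saved in the same order as the extracted features
with open(dir_extract + '/trainset.txt') as handle:
    image_names = [line.strip() for line in handle]

with h5py.File(dir_extract + '/trainset.hdf5', 'r') as handle:
    att = handle['att']      # (nb_images, 2048, 14, 14) with the default size 448
    noatt = handle['noatt']  # (nb_images, 2048)
    print(len(image_names), att.shape, noatt.shape)
    print(image_names[0], noatt[0][:5])  # features of the first image
```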
290 |
291 | ### Extract features from VisualGenome
292 |
293 | Same here, but only train is available:
294 |
295 | ```
296 | python extract.py --dataset vgenome --dir_data data/vgenome --data_split train
297 | ```
298 |
299 |
300 | ### Train models on VQA 1.0
301 |
 302 | Display the help message and the selected options, or run with the default options. The needed data will be automatically downloaded and processed using the options in `options/vqa/default.yaml`.
303 |
304 | ```
305 | python train.py -h
306 | python train.py --help_opt
307 | python train.py
308 | ```
309 |
310 | Run a MutanNoAtt model with default options.
311 |
312 | ```
313 | python train.py --path_opt options/vqa/mutan_noatt_train.yaml --dir_logs logs/vqa/mutan_noatt_train
314 | ```
315 |
316 | Run a MutanAtt model on the trainset and evaluate on the valset after each epoch.
317 |
318 | ```
319 | python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att_trainval.yaml
320 | ```
321 |
 322 | Run a MutanAtt model on the trainset and valset (by default) and run through the testset after each epoch (producing a results file that you can submit to the evaluation server).
323 |
324 | ```
325 | python train.py --vqa_trainsplit trainval --path_opt options/vqa/mutan_att_trainval.yaml
326 | ```
327 |
328 | ### Train models on VQA 2.0
329 |
330 | See options of [vqa2/mutan_att_trainval](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval.yaml):
331 |
332 | ```
333 | python train.py --path_opt options/vqa2/mutan_att_trainval.yaml
334 | ```
335 |
336 | ### Train models on VQA (1.0 or 2.0) + VisualGenome
337 |
338 | See options of [vqa2/mutan_att_trainval_vg](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval_vg.yaml):
339 |
340 | ```
341 | python train.py --path_opt options/vqa2/mutan_att_trainval_vg.yaml
342 | ```
343 |
344 | ### Monitor training
345 |
 346 | Create a visualization of an experiment using `plotly` to monitor the training, just like the picture below (**click the image to access the html/js file**):
347 |
348 |
 349 | [image: doc/mutan_noatt.png (click for doc/mutan_noatt.html)]
350 |
351 |
352 |
353 |
 354 | Note that you have to wait until the first OpenEnded accuracy has been computed; the html file is then created and opened in your default browser. The html is refreshed every 60 seconds, but you currently need to press F5 in your browser to see the changes.
355 |
356 | ```
357 | python visu.py --dir_logs logs/vqa/mutan_noatt
358 | ```
359 |
 360 | Create a visualization of multiple experiments to compare them or monitor them, like the picture below (**click the image to access the html/js file**):
361 |
362 |
 363 | [image: doc/mutan_noatt_vs_att.png (click for doc/mutan_noatt_vs_att.html)]
364 |
365 |
366 |
367 |
368 | ```
369 | python visu.py --dir_logs logs/vqa/mutan_noatt,logs/vqa/mutan_att
370 | ```
371 |
372 |
373 |
374 | ### Restart training
375 |
376 | Restart the model from the last checkpoint.
377 |
378 | ```
379 | python train.py --path_opt options/vqa/mutan_noatt.yaml --dir_logs logs/vqa/mutan_noatt --resume ckpt
380 | ```
381 |
382 | Restart the model from the best checkpoint.
383 |
384 | ```
385 | python train.py --path_opt options/vqa/mutan_noatt.yaml --dir_logs logs/vqa/mutan_noatt --resume best
386 | ```
387 |
388 | ### Evaluate models on VQA
389 |
 390 | Evaluate the model from the best checkpoint. If your model has been trained on the training set only (`vqa_trainsplit=train`), the model will be evaluated on the valset and will run through the testset. If it was trained on the trainset + valset (`vqa_trainsplit=trainval`), it will not be evaluated on the valset.
391 |
392 | ```
393 | python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att.yaml --dir_logs logs/vqa/mutan_att --resume best -e
394 | ```
395 |
396 | ### Web demo
397 |
398 | You must set your local ip address and port in `demo_server.py` line 169 and your global ip address and port in `demo_web/js/custom.js` line 51.
399 | The port associated to the global ip address must redirect to your local ip address.
400 |
401 | Launch your API:
402 | ```
403 | CUDA_VISIBLE_DEVICES=0 python demo_server.py
404 | ```
405 |
 406 | Open `demo_web/index.html` in your browser to access the API through a human interface, or query the API directly as sketched below.
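
For illustration, a minimal sketch of querying the API without the web page (it assumes the server runs on `192.168.0.32:3456` as in `demo_server.py`, that `image.jpg` exists, and that the `requests` package is installed; it is not listed in `requirements.txt`):

```python
import base64
import requests

# encode the image as a data URL, like the web page does
with open('image.jpg', 'rb') as handle:  # hypothetical image path
    visual = 'data:image/jpeg;base64,' + base64.b64encode(handle.read()).decode()

r = requests.post('http://192.168.0.32:3456',
                  data={'visual': visual, 'question': 'What color is the cat?'})
answer = r.json()
print(answer['ans'])  # top-5 answers
print(answer['val'])  # associated scores
```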
407 |
408 | ## Citation
409 |
410 | Please cite the arXiv paper if you use Mutan in your work:
411 |
412 | ```
413 | @article{benyounescadene2017mutan,
414 | author = {Hedi Ben-Younes and
415 | R{\'{e}}mi Cad{\`{e}}ne and
416 | Nicolas Thome and
417 | Matthieu Cord},
418 | title = {MUTAN: Multimodal Tucker Fusion for Visual Question Answering},
419 | journal = {ICCV},
420 | year = {2017},
421 | url = {http://arxiv.org/abs/1705.06676}
422 | }
423 | ```
424 |
425 | ## Acknowledgment
426 |
427 | Special thanks to the authors of [MLB](https://arxiv.org/abs/1610.04325) for providing some [Torch7 code](https://github.com/jnhwkim/MulLowBiVQA), [MCB](https://arxiv.org/abs/1606.01847) for providing some [Caffe code](https://github.com/akirafukui/vqa-mcb), and our professors and friends from LIP6 for the perfect working atmosphere.
428 |
--------------------------------------------------------------------------------
/data/.keep:
--------------------------------------------------------------------------------
1 | .empty
--------------------------------------------------------------------------------
/demo_server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import yaml
4 | import json
5 | import argparse
6 | import re
7 | import base64
8 | import torch
9 | from torch.autograd import Variable
10 | from PIL import Image
11 | from io import BytesIO
12 | from pprint import pprint
13 |
14 | from werkzeug.wrappers import Request, Response
15 | from werkzeug.serving import run_simple
16 |
17 | import torchvision.transforms as transforms
18 | import vqa.lib.utils as utils
19 | import vqa.datasets as datasets
20 | import vqa.models as models
21 | import vqa.models.convnets as convnets
22 | from vqa.datasets.vqa_processed import tokenize_mcb
23 | from train import load_checkpoint
24 |
25 | parser = argparse.ArgumentParser(
26 | description='Demo server',
27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28 | parser.add_argument('--dir_logs', type=str,
29 | #default='logs/vqa2/blocmutan_noatt_fbresnet152torchported_save_all',
30 | default='logs/vqa2/mutan_att_train',
31 | help='dir logs')
32 | parser.add_argument('--path_opt', type=str,
33 | #default='logs/vqa2/blocmutan_noatt_fbresnet152torchported_save_all/blocmutan_noatt.yaml',
34 | default='logs/vqa2/mutan_att_train/mutan_att_train.yaml',
35 | help='path to a yaml options file')
36 | parser.add_argument('--resume', type=str,
37 | default='best',
38 | help='path to latest checkpoint')
39 | parser.add_argument('--cuda', type=bool,
40 | const=True,
41 | nargs='?',
 42 |                     help='use cuda (GPU) for inference')
43 |
44 | @Request.application
45 | def application(request):
46 | print('')
47 | if 'visual' in request.form and 'question' in request.form:
48 | visual = process_visual(request.form['visual'])
49 | question = process_question(request.form['question'])
50 | answer = process_answer(model(visual, question))
51 | response = Response(answer)
52 |
53 | elif 'question' not in request.form:
54 | response = Response('Question missing')
55 |
56 | elif 'visual' not in request.form:
57 | response = Response('Image missing')
58 |
59 | else:
60 | response = Response('what?')
61 |
62 | response.headers.add('Access-Control-Allow-Origin', '*')
63 | response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,PATCH')
64 | response.headers.add('Access-Control-Allow-Headers', 'Content-Type, Authorization')
65 | response.headers.add('X-XSS-Protection', '0')
66 | return response
67 |
68 | def process_visual(visual_strb64):
69 | visual_strb64 = re.sub('^data:image/.+;base64,', '', visual_strb64)
70 | visual_PIL = Image.open(BytesIO(base64.b64decode(visual_strb64)))
71 | visual_tensor = transform(visual_PIL)
72 | visual_data = torch.FloatTensor(1, 3,
73 | visual_tensor.size(1),
74 | visual_tensor.size(2))
75 | visual_data[0][0] = visual_tensor[0]
76 | visual_data[0][1] = visual_tensor[1]
77 | visual_data[0][2] = visual_tensor[2]
78 | print('visual', visual_data.size(), visual_data.mean())
79 | if args.cuda:
80 | visual_data = visual_data.cuda(async=True)
81 | visual_input = Variable(visual_data, volatile=True)
82 | visual_features = cnn(visual_input)
83 | if 'NoAtt' in options['model']['arch']:
84 | nb_regions = visual_features.size(2) * visual_features.size(3)
85 | visual_features = visual_features.sum(3).sum(2).div(nb_regions).view(-1, 2048)
86 | return visual_features
87 |
88 | def process_question(question_str):
89 | question_tokens = tokenize_mcb(question_str)
90 | question_data = torch.LongTensor(1, len(question_tokens))
91 | for i, word in enumerate(question_tokens):
92 | if word in trainset.word_to_wid:
93 | question_data[0][i] = trainset.word_to_wid[word]
94 | else:
95 | question_data[0][i] = trainset.word_to_wid['UNK']
96 | if args.cuda:
97 | question_data = question_data.cuda(async=True)
98 | question_input = Variable(question_data, volatile=True)
99 | print('question', question_str, question_tokens, question_data)
100 |
101 | return question_input
102 |
103 | def process_answer(answer_var):
104 | answer_sm = torch.nn.functional.softmax(answer_var.data[0].cpu())
105 | max_, aid = answer_sm.topk(5, 0, True, True)
106 | ans = []
107 | val = []
108 | for i in range(5):
109 | ans.append(trainset.aid_to_ans[aid.data[i]])
110 | val.append(max_.data[i])
111 |
112 | att = []
113 | for x_att in model.list_att:
114 | img = x_att.view(1,14,14).cpu()
115 | img = transforms.ToPILImage()(img)
116 | buffer_ = BytesIO()
117 | img.save(buffer_, format="PNG")
118 | img_str = base64.b64encode(buffer_.getvalue()).decode()
119 | img_str = 'data:image/png;base64,'+img_str
120 | att.append(img_str)
121 |
122 | answer = {'ans':ans,'val':val,'att':att}
123 | answer_str = json.dumps(answer)
124 |
125 | return answer_str
126 |
127 | def main():
128 | global args, options, model, cnn, transform, trainset
129 | args = parser.parse_args()
130 |
131 | options = {
132 | 'logs': {
133 | 'dir_logs': args.dir_logs
134 | }
135 | }
136 | if args.path_opt is not None:
137 | with open(args.path_opt, 'r') as handle:
138 | options_yaml = yaml.load(handle)
139 | options = utils.update_values(options, options_yaml)
140 | print('## args'); pprint(vars(args))
141 | print('## options'); pprint(options)
142 |
143 | trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
144 | options['vqa'])
145 | #options['coco'])
146 |
147 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
148 | std=[0.229, 0.224, 0.225])
149 | transform = transforms.Compose([
150 | transforms.Scale(options['coco']['size']),
151 | transforms.CenterCrop(options['coco']['size']),
152 | transforms.ToTensor(),
153 | normalize,
154 | ])
155 |
156 | opt_factory_cnn = {
157 | 'arch': options['coco']['arch']
158 | }
159 | cnn = convnets.factory(opt_factory_cnn, cuda=args.cuda, data_parallel=False)
160 | model = models.factory(options['model'],
161 | trainset.vocab_words(),
162 | trainset.vocab_answers(),
163 | cuda=args.cuda,
164 | data_parallel=False)
165 | model.eval()
166 | start_epoch, best_acc1, _ = load_checkpoint(model, None,
167 | os.path.join(options['logs']['dir_logs'], args.resume))
168 |
169 | my_local_ip = '192.168.0.32'
170 | my_local_port = 3456
171 | run_simple(my_local_ip, my_local_port, application)
172 |
173 | if __name__ == '__main__':
174 | main()
175 |
176 |
--------------------------------------------------------------------------------
/demo_web/css/custom.css:
--------------------------------------------------------------------------------
1 | .btn-file {
2 | position: relative;
3 | overflow: hidden;
4 | }
5 | .btn-file input[type=file] {
6 | position: absolute;
7 | top: 0;
8 | right: 0;
9 | min-width: 100%;
10 | min-height: 100%;
11 | font-size: 100px;
12 | text-align: right;
13 | filter: alpha(opacity=0);
14 | opacity: 0;
15 | outline: none;
16 | background: white;
17 | cursor: inherit;
18 | display: block;
19 | }
20 |
21 | #vqa-visual{
22 | width: 300px;
23 | }
--------------------------------------------------------------------------------
/demo_web/fonts/glyphicons-halflings-regular.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.eot
--------------------------------------------------------------------------------
/demo_web/fonts/glyphicons-halflings-regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.ttf
--------------------------------------------------------------------------------
/demo_web/fonts/glyphicons-halflings-regular.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.woff
--------------------------------------------------------------------------------
/demo_web/fonts/glyphicons-halflings-regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/fonts/glyphicons-halflings-regular.woff2
--------------------------------------------------------------------------------
/demo_web/images/logos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/demo_web/images/logos.png
--------------------------------------------------------------------------------
/demo_web/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | VQA Demo MUTAN
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
24 |
25 |
26 |
27 |
28 |
29 |
40 |
53 |
54 |
55 |
56 |
57 |
58 |
Welcome to the demo page
59 |
 The VQA task is about training models in an end-to-end fashion on a multimodal dataset made of triplets: (image, question, answer). We recently proposed a new model called MUTAN based on Neural Networks and the Tucker Decomposition to address this machine learning problem. Please try out our latest model :)
60 |
61 |
 Best regards. Remi Cadene, Hedi Ben-Younes, Matthieu Cord and Nicolas Thome.
62 |
63 | Github »
64 | Paper »
65 |
66 |
67 |
68 |
69 |
90 |
91 |
92 |
93 |
94 |
API
95 |
96 |
97 |
98 |
123 |
124 |
3. Receive the answer
125 |
MUTAN is waiting for your question.
126 |
127 |
128 |
129 |
130 |
131 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
--------------------------------------------------------------------------------
/demo_web/js/custom.js:
--------------------------------------------------------------------------------
1 | $(document).ready(function () {
2 |
3 | // Image
4 |
5 | $(document).on('change', '.btn-file :file', function() {
6 | var input = $(this),
7 | label = input.val().replace(/\\/g, '/').replace(/.*\//, '');
8 | input.trigger('fileselect', [label]);
9 | });
10 |
11 | $('.btn-file :file').on('fileselect', function(event, label) {
12 |
13 | var input = $(this).parents('.input-group').find(':text'),
14 | log = label;
15 |
16 | if( input.length ) {
17 | input.val(log);
18 | } else {
19 | if( log ) alert(log);
20 | }
21 |
22 | });
23 | function readURL(input) {
24 | if (input.files && input.files[0]) {
25 | var reader = new FileReader();
26 |
27 | reader.onload = function (e) {
28 | $('#vqa-visual').attr('src', e.target.result);
29 | }
30 |
31 | reader.readAsDataURL(input.files[0]);
32 | }
33 | }
34 |
35 | $("#imgInp").change(function(){
36 | readURL(this);
37 | });
38 |
39 | // Send Image + Question
40 |
41 | var formBasic = function () {
42 | var formData = $("#formBasic").serialize();
43 | var data = { visual : $('#vqa-visual').attr('src'),
44 | question : $('#vqa-question').val()}
45 | console.log(data);
46 | $.ajax({
47 |
48 | type: 'post',
49 | data: data,
50 | dataType: 'json',
51 | url: 'http://', // your global ip address and port
52 |
53 | // error: function () {
54 | // alert("There was an error processing this page.");
55 | // return false;
56 | // },
57 |
58 | complete: function (output) {
59 | //console.log(output);
60 | //console.log(output.responseText);
 61 |             var ul = $('<ul></ul>');
62 | for (i=0; i < output.responseJSON.ans.length; i++)
63 | {
 64 |                 var li = $('<li></li>');
 65 |                 var span = $('<span></span>');
66 |
67 | span.text(output.responseJSON.ans[i]+' ('+output.responseJSON.val[i]+')');
68 |
69 | li.append(span);
70 | ul.append(li);
71 | }
72 |
73 | for (i=0; i < output.responseJSON.att.length; i++)
74 | {
 75 |                 var img = $('<img src="' + output.responseJSON.att[i] + '">');
76 |
77 | ul.append(img);
78 | }
79 |
80 | $('#vqa-answer').append(ul);
81 | }
82 | });
83 | return false;
84 | };
85 |
86 | $("#basic-submit").on("click", function (e) {
87 | e.preventDefault();
88 | formBasic();
89 | });
90 | });
--------------------------------------------------------------------------------
/demo_web/js/npm.js:
--------------------------------------------------------------------------------
1 | // This file is autogenerated via the `commonjs` Grunt task. You can require() this file in a CommonJS environment.
2 | require('../../js/transition.js')
3 | require('../../js/alert.js')
4 | require('../../js/button.js')
5 | require('../../js/carousel.js')
6 | require('../../js/collapse.js')
7 | require('../../js/dropdown.js')
8 | require('../../js/modal.js')
9 | require('../../js/tooltip.js')
10 | require('../../js/popover.js')
11 | require('../../js/scrollspy.js')
12 | require('../../js/tab.js')
13 | require('../../js/affix.js')
--------------------------------------------------------------------------------
/doc/mutan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/mutan.png
--------------------------------------------------------------------------------
/doc/mutan_noatt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/mutan_noatt.png
--------------------------------------------------------------------------------
/doc/mutan_noatt_vs_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/mutan_noatt_vs_att.png
--------------------------------------------------------------------------------
/doc/vqa_task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/doc/vqa_task.png
--------------------------------------------------------------------------------
/eval_res.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import random
4 | import os
5 | from os.path import join
6 | import sys
7 | #import pickle
8 | helperDir = 'vqa/external/VQA/'
9 | sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(helperDir))
10 | sys.path.insert(0, '%s/PythonEvaluationTools/vqaEvaluation' %(helperDir))
11 | from vqa import VQA
12 | from vqaEval import VQAEval
13 |
14 |
15 | if __name__=="__main__":
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--dir_vqa', type=str, default='/local/cadene/data/vqa')
18 | parser.add_argument('--dir_epoch', type=str, default='logs/16_12_13_20:39:55/epoch,1')
19 | parser.add_argument('--subtype', type=str, default='train2014')
20 | args = parser.parse_args()
21 |
22 | diranno = join(args.dir_vqa, 'raw', 'annotations')
23 | annFile = join(diranno, 'mscoco_%s_annotations.json' % (args.subtype))
24 | quesFile = join(diranno, 'OpenEnded_mscoco_%s_questions.json' % (args.subtype))
25 | vqa = VQA(annFile, quesFile)
26 |
27 | taskType = 'OpenEnded'
28 | dataType = 'mscoco'
29 | dataSubType = args.subtype
30 | resultType = 'model'
31 | fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
32 |
33 | [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = \
34 | ['%s/%s_%s_%s_%s_%s.json' % (args.dir_epoch, taskType, dataType,
35 | dataSubType, resultType, fileType) for fileType in fileTypes]
36 | vqaRes = vqa.loadRes(resFile, quesFile)
37 | vqaEval = VQAEval(vqa, vqaRes, n=2)
38 |
39 | quesIds = [int(d['question_id']) for d in json.loads(open(resFile).read())]
40 | vqaEval.evaluate(quesIds=quesIds)
41 |
42 | json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
43 | #json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
44 | #json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
45 | #json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import h5py
5 | import numpy
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | from torch.autograd import Variable
12 |
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 |
16 | import vqa.models.convnets as convnets
17 | import vqa.datasets as datasets
18 | from vqa.lib.dataloader import DataLoader
19 | from vqa.lib.logger import AvgMeter
20 |
21 | parser = argparse.ArgumentParser(description='Extract')
22 | parser.add_argument('--dataset', default='coco',
23 | choices=['coco', 'vgenome'],
24 | help='dataset type: coco (default) | vgenome')
25 | parser.add_argument('--dir_data', default='data/coco',
26 | help='dir dataset to download or/and load images')
27 | parser.add_argument('--data_split', default='train', type=str,
28 | help='Options: (default) train | val | test')
29 | parser.add_argument('--arch', '-a', default='fbresnet152',
30 | choices=convnets.model_names,
31 | help='model architecture: ' +
32 | ' | '.join(convnets.model_names) +
33 | ' (default: fbresnet152)')
34 | parser.add_argument('--workers', default=4, type=int,
35 | help='number of data loading workers (default: 4)')
36 | parser.add_argument('--batch_size', '-b', default=80, type=int,
37 | help='mini-batch size (default: 80)')
38 | parser.add_argument('--mode', default='both', type=str,
39 | help='Options: att | noatt | (default) both')
40 | parser.add_argument('--size', default=448, type=int,
41 | help='Image size (448 for noatt := avg pooling to get 224) (default:448)')
42 |
43 |
44 | def main():
45 | global args
46 | args = parser.parse_args()
47 |
48 | print("=> using pre-trained model '{}'".format(args.arch))
49 | model = convnets.factory({'arch':args.arch}, cuda=True, data_parallel=True)
50 |
51 | extract_name = 'arch,{}_size,{}'.format(args.arch, args.size)
52 |
53 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
54 | std=[0.229, 0.224, 0.225])
55 |
56 | if args.dataset == 'coco':
57 | if 'coco' not in args.dir_data:
58 | raise ValueError('"coco" string not in dir_data')
59 | dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
60 | transform=transforms.Compose([
61 | transforms.Scale(args.size),
62 | transforms.CenterCrop(args.size),
63 | transforms.ToTensor(),
64 | normalize,
65 | ]))
66 | elif args.dataset == 'vgenome':
67 | if args.data_split != 'train':
68 | raise ValueError('train split is required for vgenome')
69 | if 'vgenome' not in args.dir_data:
70 | raise ValueError('"vgenome" string not in dir_data')
71 | dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
72 | transform=transforms.Compose([
73 | transforms.Scale(args.size),
74 | transforms.CenterCrop(args.size),
75 | transforms.ToTensor(),
76 | normalize,
77 | ]))
78 |
79 | data_loader = DataLoader(dataset,
80 | batch_size=args.batch_size, shuffle=False,
81 | num_workers=args.workers, pin_memory=True)
82 |
83 | dir_extract = os.path.join(args.dir_data, 'extract', extract_name)
84 | path_file = os.path.join(dir_extract, args.data_split + 'set')
85 | os.system('mkdir -p ' + dir_extract)
86 |
87 | extract(data_loader, model, path_file, args.mode)
88 |
89 |
90 | def extract(data_loader, model, path_file, mode):
91 | path_hdf5 = path_file + '.hdf5'
92 | path_txt = path_file + '.txt'
93 | hdf5_file = h5py.File(path_hdf5, 'w')
94 |
95 | # estimate output shapes
96 | output = model(Variable(torch.ones(1, 3, args.size, args.size),
97 | volatile=True))
98 |
99 | nb_images = len(data_loader.dataset)
100 | if mode == 'both' or mode == 'att':
101 | shape_att = (nb_images, output.size(1), output.size(2), output.size(3))
102 | print('Warning: shape_att={}'.format(shape_att))
103 | hdf5_att = hdf5_file.create_dataset('att', shape_att,
104 | dtype='f')#, compression='gzip')
105 | if mode == 'both' or mode == 'noatt':
106 | shape_noatt = (nb_images, output.size(1))
107 | print('Warning: shape_noatt={}'.format(shape_noatt))
108 | hdf5_noatt = hdf5_file.create_dataset('noatt', shape_noatt,
109 | dtype='f')#, compression='gzip')
110 |
111 | model.eval()
112 |
113 | batch_time = AvgMeter()
114 | data_time = AvgMeter()
115 | begin = time.time()
116 | end = time.time()
117 |
118 | idx = 0
119 | for i, input in enumerate(data_loader):
120 | input_var = Variable(input['visual'], volatile=True)
121 | output_att = model(input_var)
122 |
123 | nb_regions = output_att.size(2) * output_att.size(3)
124 | output_noatt = output_att.sum(3).sum(2).div(nb_regions).view(-1, 2048)
125 |
126 | batch_size = output_att.size(0)
127 | if mode == 'both' or mode == 'att':
128 | hdf5_att[idx:idx+batch_size] = output_att.data.cpu().numpy()
129 | if mode == 'both' or mode == 'noatt':
130 | hdf5_noatt[idx:idx+batch_size] = output_noatt.data.cpu().numpy()
131 | idx += batch_size
132 |
133 | torch.cuda.synchronize()
134 | batch_time.update(time.time() - end)
135 | end = time.time()
136 |
137 | if i % 1 == 0:
138 | print('Extract: [{0}/{1}]\t'
139 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
140 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
141 | i, len(data_loader),
142 | batch_time=batch_time,
143 | data_time=data_time,))
144 |
145 | hdf5_file.close()
146 |
 147 |     # Saving image names in the same order as the extraction
148 | with open(path_txt, 'w') as handle:
149 | for name in data_loader.dataset.dataset.imgs:
150 | handle.write(name + '\n')
151 |
152 | end = time.time() - begin
153 | print('Finished in {}m and {}s'.format(int(end/60), int(end%60)))
154 |
155 |
156 | if __name__ == '__main__':
157 | main()
158 |
--------------------------------------------------------------------------------
/logs/.keep:
--------------------------------------------------------------------------------
1 | .empty
--------------------------------------------------------------------------------
/options/vqa/default.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa/default
3 | vqa:
4 | dataset: VQA
5 | dir: data/vqa
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152torch
16 | mode: noatt
17 | model:
18 | arch: MLBNoAtt
19 | seq2vec:
20 | arch: skipthoughts
21 | dir_st: data/skip-thoughts
22 | type: UniSkip
23 | dropout: 0.25
24 | fixed_emb: False
25 | fusion:
26 | dim_v: 2048
27 | dim_q: 2400
28 | dim_h: 1200
29 | dropout_v: 0.5
30 | dropout_q: 0.5
31 | activation_v: tanh
32 | activation_q: tanh
33 | classif:
34 | activation: tanh
35 | dropout: 0.5
36 | optim:
37 | lr: 0.0001
38 | batch_size: 512
39 | epochs: 100
40 |
--------------------------------------------------------------------------------
/options/vqa/mlb_att_trainval.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa/mlb_att_trainval
3 | vqa:
4 | dataset: VQA
5 | dir: data/vqa
6 | trainsplit: trainval
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152torch
16 | mode: att
17 | model:
18 | arch: MLBAtt
19 | dim_v: 2048
20 | dim_q: 2400
21 | seq2vec:
22 | arch: skipthoughts
23 | dir_st: data/skip-thoughts
24 | type: BayesianUniSkip
25 | dropout: 0.25
26 | fixed_emb: False
27 | attention:
28 | nb_glimpses: 4
29 | dim_h: 1200
30 | dropout_v: 0.5
31 | dropout_q: 0.5
32 | dropout_mm: 0.5
33 | activation_v: tanh
34 | activation_q: tanh
35 | activation_mm: tanh
36 | fusion:
37 | dim_h: 1200
38 | dropout_v: 0.5
39 | dropout_q: 0.5
40 | activation_v: tanh
41 | activation_q: tanh
42 | classif:
43 | activation: tanh
44 | dropout: 0.5
45 | optim:
46 | lr: 0.0001
47 | batch_size: 128
48 | epochs: 100
49 |
--------------------------------------------------------------------------------
/options/vqa/mlb_noatt_train.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa/mlb_noatt_train
3 | vqa:
4 | dataset: VQA
5 | dir: data/vqa
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152torch
16 | mode: noatt
17 | model:
18 | arch: MLBNoAtt
19 | seq2vec:
20 | arch: skipthoughts
21 | dir_st: data/skip-thoughts
22 | type: BayesianUniSkip
23 | dropout: 0.25
24 | fixed_emb: False
25 | fusion:
26 | dim_v: 2048
27 | dim_q: 2400
28 | dim_h: 1200
29 | dropout_v: 0.5
30 | dropout_q: 0.5
31 | activation_v: tanh
32 | activation_q: tanh
33 | classif:
34 | activation: tanh
35 | dropout: 0.5
36 | optim:
37 | lr: 0.0001
38 | batch_size: 512
39 | epochs: 100
40 |
--------------------------------------------------------------------------------
/options/vqa/mutan_att_trainval.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa/mutan_att_trainval
3 | vqa:
4 | dataset: VQA
5 | dir: data/vqa
6 | trainsplit: trainval
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152
16 | mode: att
17 | size: 448
18 | model:
19 | arch: MutanAtt
20 | dim_v: 2048
21 | dim_q: 2400
22 | seq2vec:
23 | arch: skipthoughts
24 | dir_st: data/skip-thoughts
25 | type: BayesianUniSkip
26 | dropout: 0.25
27 | fixed_emb: False
28 | attention:
29 | nb_glimpses: 2
30 | dim_hv: 310
31 | dim_hq: 310
32 | dim_mm: 510
33 | R: 5
34 | dropout_v: 0.5
35 | dropout_q: 0.5
36 | dropout_mm: 0.5
37 | activation_v: tanh
38 | activation_q: tanh
39 | dropout_hv: 0
40 | dropout_hq: 0
41 | fusion:
42 | dim_hv: 620
43 | dim_hq: 310
44 | dim_mm: 510
45 | R: 5
46 | dropout_v: 0.5
47 | dropout_q: 0.5
48 | activation_v: tanh
49 | activation_q: tanh
50 | dropout_hv: 0
51 | dropout_hq: 0
52 | classif:
53 | dropout: 0.5
54 | optim:
55 | lr: 0.0001
56 | batch_size: 128
57 | epochs: 100
58 |
--------------------------------------------------------------------------------
/options/vqa/mutan_noatt_train.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa/mutan_noatt_train
3 | vqa:
4 | dataset: VQA
5 | dir: data/vqa
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152torch
16 | mode: noatt
17 | model:
18 | arch: MutanNoAtt
19 | seq2vec:
20 | arch: skipthoughts
21 | dir_st: data/skip-thoughts
22 | type: BayesianUniSkip
23 | dropout: 0.25
24 | fixed_emb: False
25 | fusion:
26 | dim_v: 2048
27 | dim_q: 2400
28 | dim_hv: 360
29 | dim_hq: 360
30 | dim_mm: 360
31 | R: 10
32 | dropout_v: 0.5
33 | dropout_q: 0.5
34 | activation_v: tanh
35 | activation_q: tanh
36 | dropout_hv: 0
37 | dropout_hq: 0
38 | classif:
39 | dropout: 0.5
40 | optim:
41 | lr: 0.0001
42 | batch_size: 512
43 | epochs: 100
44 |
--------------------------------------------------------------------------------
/options/vqa2/default.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/default
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152
16 | mode: noatt
17 | size: 448
18 | model:
19 | arch: MLBNoAtt
20 | seq2vec:
21 | arch: skipthoughts
22 | dir_st: data/skip-thoughts
23 | type: UniSkip
24 | dropout: 0.25
25 | fixed_emb: False
26 | fusion:
27 | dim_v: 2048
28 | dim_q: 2400
29 | dim_h: 1200
30 | dropout_v: 0.5
31 | dropout_q: 0.5
32 | activation_v: tanh
33 | activation_q: tanh
34 | classif:
35 | activation: tanh
36 | dropout: 0.5
37 | optim:
38 | lr: 0.0001
39 | batch_size: 512
40 | epochs: 100
41 |
--------------------------------------------------------------------------------
/options/vqa2/mlb_att_trainval.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/mlb_att_trainval
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: trainval
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152
16 | mode: att
17 | size: 448
18 | model:
19 | arch: MLBAtt
20 | dim_v: 2048
21 | dim_q: 2400
22 | seq2vec:
23 | arch: skipthoughts
24 | dir_st: data/skip-thoughts
25 | type: BayesianUniSkip
26 | dropout: 0.25
27 | fixed_emb: False
28 | attention:
29 | nb_glimpses: 4
30 | dim_h: 1200
31 | dropout_v: 0.5
32 | dropout_q: 0.5
33 | dropout_mm: 0.5
34 | activation_v: tanh
35 | activation_q: tanh
36 | activation_mm: tanh
37 | fusion:
38 | dim_h: 1200
39 | dropout_v: 0.5
40 | dropout_q: 0.5
41 | activation_v: tanh
42 | activation_q: tanh
43 | classif:
44 | activation: tanh
45 | dropout: 0.5
46 | optim:
47 | lr: 0.0001
48 | batch_size: 128
49 | epochs: 100
50 |
--------------------------------------------------------------------------------
/options/vqa2/mlb_noatt_train.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/mlb_noatt_train
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152
16 | mode: noatt
17 | size: 448
18 | model:
19 | arch: MLBNoAtt
20 | seq2vec:
21 | arch: skipthoughts
22 | dir_st: data/skip-thoughts
23 | type: BayesianUniSkip
24 | dropout: 0.25
25 | fixed_emb: False
26 | fusion:
27 | dim_v: 2048
28 | dim_q: 2400
29 | dim_h: 1200
30 | dropout_v: 0.5
31 | dropout_q: 0.5
32 | activation_v: tanh
33 | activation_q: tanh
34 | classif:
35 | activation: tanh
36 | dropout: 0.5
37 | optim:
38 | lr: 0.0001
39 | batch_size: 512
40 | epochs: 100
41 |
--------------------------------------------------------------------------------
/options/vqa2/mutan_att_train.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/mutan_att_train
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: /local/cadene/data/coco
15 | arch: fbresnet152
16 | mode: att
17 | size: 448
18 | model:
19 | arch: MutanAtt
20 | dim_v: 2048
21 | dim_q: 2400
22 | seq2vec:
23 | arch: skipthoughts
24 | dir_st: /local/cadene/data/skip-thoughts
25 | type: BayesianUniSkip
26 | dropout: 0.25
27 | fixed_emb: False
28 | attention:
29 | nb_glimpses: 2
30 | dim_hv: 310
31 | dim_hq: 310
32 | dim_mm: 510
33 | R: 5
34 | dropout_v: 0.5
35 | dropout_q: 0.5
36 | dropout_mm: 0.5
37 | activation_v: tanh
38 | activation_q: tanh
39 | dropout_hv: 0
40 | dropout_hq: 0
41 | fusion:
42 | dim_hv: 620
43 | dim_hq: 310
44 | dim_mm: 510
45 | R: 5
46 | dropout_v: 0.5
47 | dropout_q: 0.5
48 | activation_v: tanh
49 | activation_q: tanh
50 | dropout_hv: 0
51 | dropout_hq: 0
52 | classif:
53 | dropout: 0.5
54 | optim:
55 | lr: 0.0001
56 | batch_size: 128
57 | epochs: 100
58 |
--------------------------------------------------------------------------------
/options/vqa2/mutan_att_trainval.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/mutan_att_trainval
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: trainval
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: /local/cadene/data/coco
15 | arch: fbresnet152
16 | mode: att
17 | size: 448
18 | model:
19 | arch: MutanAtt
20 | dim_v: 2048
21 | dim_q: 2400
22 | seq2vec:
23 | arch: skipthoughts
24 | dir_st: /local/cadene/data/skip-thoughts
25 | type: BayesianUniSkip
26 | dropout: 0.25
27 | fixed_emb: False
28 | attention:
29 | nb_glimpses: 2
30 | dim_hv: 310
31 | dim_hq: 310
32 | dim_mm: 510
33 | R: 5
34 | dropout_v: 0.5
35 | dropout_q: 0.5
36 | dropout_mm: 0.5
37 | activation_v: tanh
38 | activation_q: tanh
39 | dropout_hv: 0
40 | dropout_hq: 0
41 | fusion:
42 | dim_hv: 620
43 | dim_hq: 310
44 | dim_mm: 510
45 | R: 5
46 | dropout_v: 0.5
47 | dropout_q: 0.5
48 | activation_v: tanh
49 | activation_q: tanh
50 | dropout_hv: 0
51 | dropout_hq: 0
52 | classif:
53 | dropout: 0.5
54 | optim:
55 | lr: 0.0001
56 | batch_size: 128
57 | epochs: 100
58 |
--------------------------------------------------------------------------------
/options/vqa2/mutan_att_trainval_vg.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/mutan_att_trainval
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: trainval
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152torchported
16 | mode: att
17 | size: 448
18 | vgenome:
19 | trainsplit: train
20 | dir: data/vgenome
21 | arch: fbresnet152
22 | mode: att
23 | size: 448
24 | nans: 2000
25 | maxlength: 26
26 | minwcount: 0
27 | nlp: mcb
28 | pad: right
29 | model:
30 | arch: MutanAtt
31 | dim_v: 2048
32 | dim_q: 2400
33 | seq2vec:
34 | arch: skipthoughts
35 | dir_st: data/skip-thoughts
36 | type: BayesianUniSkip
37 | dropout: 0.25
38 | fixed_emb: False
39 | attention:
40 | nb_glimpses: 2
41 | dim_hv: 310
42 | dim_hq: 310
43 | dim_mm: 510
44 | R: 5
45 | dropout_v: 0.5
46 | dropout_q: 0.5
47 | dropout_mm: 0.5
48 | activation_v: tanh
49 | activation_q: tanh
50 | dropout_hv: 0
51 | dropout_hq: 0
52 | fusion:
53 | dim_hv: 620
54 | dim_hq: 310
55 | dim_mm: 510
56 | R: 5
57 | dropout_v: 0.5
58 | dropout_q: 0.5
59 | activation_v: tanh
60 | activation_q: tanh
61 | dropout_hv: 0
62 | dropout_hq: 0
63 | classif:
64 | dropout: 0.5
65 | optim:
66 | lr: 0.0001
67 | batch_size: 128
68 | epochs: 100
69 |
--------------------------------------------------------------------------------
/options/vqa2/mutan_noatt_train.yaml:
--------------------------------------------------------------------------------
1 | logs:
2 | dir_logs: logs/vqa2/mutan_noatt_train
3 | vqa:
4 | dataset: VQA2
5 | dir: data/vqa2
6 | trainsplit: train
7 | nans: 2000
8 | maxlength: 26
9 | minwcount: 0
10 | nlp: mcb
11 | pad: right
12 | samplingans: True
13 | coco:
14 | dir: data/coco
15 | arch: fbresnet152
16 | mode: noatt
17 | size: 448
18 | model:
19 | arch: MutanNoAtt
20 | seq2vec:
21 | arch: skipthoughts
22 | dir_st: data/skip-thoughts
23 | type: BayesianUniSkip
24 | dropout: 0.25
25 | fixed_emb: False
26 | fusion:
27 | dim_v: 2048
28 | dim_q: 2400
29 | dim_hv: 360
30 | dim_hq: 360
31 | dim_mm: 360
32 | R: 10
33 | dropout_v: 0.5
34 | dropout_q: 0.5
35 | activation_v: tanh
36 | activation_q: tanh
37 | dropout_hv: 0
38 | dropout_hq: 0
39 | classif:
40 | dropout: 0.5
41 | optim:
42 | lr: 0.0001
43 | batch_size: 512
44 | epochs: 100
45 |
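46 | # This file is consumed by train.py via --path_opt, e.g.:
47 | #   python train.py --path_opt options/vqa2/mutan_noatt_train.yaml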
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | h5py
4 | pyyaml
5 | colorlover
6 | plotly
7 | scipy
8 | nltk
9 | click
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import yaml
5 | import json
6 | import click
7 | from pprint import pprint
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.backends.cudnn as cudnn
12 |
13 | import vqa.lib.engine as engine
14 | import vqa.lib.utils as utils
15 | import vqa.lib.logger as logger
16 | import vqa.lib.criterions as criterions
17 | import vqa.datasets as datasets
18 | import vqa.models as models
19 |
20 | parser = argparse.ArgumentParser(
21 | description='Train/Evaluate models',
22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23 | ##################################################
24 | # yaml options file contains all default choices #
25 | parser.add_argument('--path_opt', default='options/vqa/default.yaml', type=str,
26 | help='path to a yaml options file')
27 | ################################################
28 | # change cli options to modify default choices #
29 | # logs options
30 | parser.add_argument('--dir_logs', type=str, help='dir logs')
31 | # data options
32 | parser.add_argument('--vqa_trainsplit', type=str, choices=['train','trainval'])
33 | # model options
34 | parser.add_argument('--arch', choices=models.model_names,
35 | help='vqa model architecture: ' +
36 | ' | '.join(models.model_names))
37 | parser.add_argument('--st_type',
38 | help='skipthoughts type')
39 | parser.add_argument('--st_dropout', type=float)
40 | parser.add_argument('--st_fixed_emb', default=None, type=utils.str2bool,
41 | help='backprop on embedding')
42 | # optim options
43 | parser.add_argument('-lr', '--learning_rate', type=float,
44 | help='initial learning rate')
45 | parser.add_argument('-b', '--batch_size', type=int,
46 | help='mini-batch size')
47 | parser.add_argument('--epochs', type=int,
48 | help='number of total epochs to run')
49 | # options not in yaml file
50 | parser.add_argument('--start_epoch', default=0, type=int,
51 | help='manual epoch number (useful on restarts)')
52 | parser.add_argument('--resume', default='', type=str,
53 | help='path to latest checkpoint')
54 | parser.add_argument('--save_model', default=True, type=utils.str2bool,
55 |                     help='enable or disable saving the model and optim state')
56 | parser.add_argument('--save_all_from', type=int,
57 | help='''delete the preceding checkpoint until an epoch,'''
58 |                     ''' then keep all (useful to save disk space)''')
59 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
60 | help='evaluate model on validation and test set')
61 | parser.add_argument('-j', '--workers', default=2, type=int,
62 | help='number of data loading workers')
63 | parser.add_argument('--print_freq', '-p', default=10, type=int,
64 | help='print frequency')
65 | ################################################
66 | parser.add_argument('-ho', '--help_opt', dest='help_opt', action='store_true',
67 | help='show selected options before running')
68 |
69 | best_acc1 = 0
70 |
71 | def main():
72 | global args, best_acc1
73 | args = parser.parse_args()
74 |
75 | #########################################################################################
76 | # Create options
77 | #########################################################################################
78 |
79 | options = {
80 | 'vqa' : {
81 | 'trainsplit': args.vqa_trainsplit
82 | },
83 | 'logs': {
84 | 'dir_logs': args.dir_logs
85 | },
86 | 'model': {
87 | 'arch': args.arch,
88 | 'seq2vec': {
89 | 'type': args.st_type,
90 | 'dropout': args.st_dropout,
91 | 'fixed_emb': args.st_fixed_emb
92 | }
93 | },
94 | 'optim': {
95 | 'lr': args.learning_rate,
96 | 'batch_size': args.batch_size,
97 | 'epochs': args.epochs
98 | }
99 | }
100 | if args.path_opt is not None:
101 | with open(args.path_opt, 'r') as handle:
102 |             options_yaml = yaml.safe_load(handle)
103 | options = utils.update_values(options, options_yaml)
104 | print('## args'); pprint(vars(args))
105 | print('## options'); pprint(options)
106 | if args.help_opt:
107 | return
108 |
109 | # Set datasets options
110 | if 'vgenome' not in options:
111 | options['vgenome'] = None
112 |
113 | #########################################################################################
114 | # Create needed datasets
115 | #########################################################################################
116 |
117 | trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
118 | options['vqa'],
119 | options['coco'],
120 | options['vgenome'])
121 | train_loader = trainset.data_loader(batch_size=options['optim']['batch_size'],
122 | num_workers=args.workers,
123 | shuffle=True)
124 |
125 | if options['vqa']['trainsplit'] == 'train':
126 | valset = datasets.factory_VQA('val', options['vqa'], options['coco'])
127 | val_loader = valset.data_loader(batch_size=options['optim']['batch_size'],
128 | num_workers=args.workers)
129 |
130 | if options['vqa']['trainsplit'] == 'trainval' or args.evaluate:
131 | testset = datasets.factory_VQA('test', options['vqa'], options['coco'])
132 | test_loader = testset.data_loader(batch_size=options['optim']['batch_size'],
133 | num_workers=args.workers)
134 |
135 | #########################################################################################
136 | # Create model, criterion and optimizer
137 | #########################################################################################
138 |
139 | model = models.factory(options['model'],
140 | trainset.vocab_words(), trainset.vocab_answers(),
141 | cuda=True, data_parallel=True)
142 | criterion = criterions.factory(options['vqa'], cuda=True)
143 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
144 | options['optim']['lr'])
145 |
146 | #########################################################################################
147 | # args.resume: resume from a checkpoint OR create logs directory
148 | #########################################################################################
149 |
150 | exp_logger = None
151 | if args.resume:
152 | args.start_epoch, best_acc1, exp_logger = load_checkpoint(model.module, optimizer,
153 | os.path.join(options['logs']['dir_logs'], args.resume))
154 | else:
155 | # Or create logs directory
156 | if os.path.isdir(options['logs']['dir_logs']):
157 | if click.confirm('Logs directory already exists in {}. Erase?'
158 |                          .format(options['logs']['dir_logs']), default=False):
159 | os.system('rm -r ' + options['logs']['dir_logs'])
160 | else:
161 | return
162 | os.system('mkdir -p ' + options['logs']['dir_logs'])
163 | path_new_opt = os.path.join(options['logs']['dir_logs'],
164 | os.path.basename(args.path_opt))
165 | path_args = os.path.join(options['logs']['dir_logs'], 'args.yaml')
166 | with open(path_new_opt, 'w') as f:
167 | yaml.dump(options, f, default_flow_style=False)
168 | with open(path_args, 'w') as f:
169 | yaml.dump(vars(args), f, default_flow_style=False)
170 |
171 | if exp_logger is None:
172 | # Set loggers
173 | exp_name = os.path.basename(options['logs']['dir_logs']) # add timestamp
174 | exp_logger = logger.Experiment(exp_name, options)
175 | exp_logger.add_meters('train', make_meters())
176 | exp_logger.add_meters('test', make_meters())
177 | if options['vqa']['trainsplit'] == 'train':
178 | exp_logger.add_meters('val', make_meters())
179 | exp_logger.info['model_params'] = utils.params_count(model)
180 | print('Model has {} parameters'.format(exp_logger.info['model_params']))
181 |
182 | #########################################################################################
183 | # args.evaluate: on valset OR/AND on testset
184 | #########################################################################################
185 |
186 | if args.evaluate:
187 | path_logger_json = os.path.join(options['logs']['dir_logs'], 'logger.json')
188 |
189 | if options['vqa']['trainsplit'] == 'train':
190 | acc1, val_results = engine.validate(val_loader, model, criterion,
191 | exp_logger, args.start_epoch, args.print_freq)
192 |             # save results and compute OpenEnded accuracy
193 | exp_logger.to_json(path_logger_json)
194 | save_results(val_results, args.start_epoch, valset.split_name(),
195 | options['logs']['dir_logs'], options['vqa']['dir'])
196 |
197 | test_results, testdev_results = engine.test(test_loader, model, exp_logger,
198 | args.start_epoch, args.print_freq)
199 |         # save results but do not compute OpenEnded accuracy
200 | exp_logger.to_json(path_logger_json)
201 | save_results(test_results, args.start_epoch, testset.split_name(),
202 | options['logs']['dir_logs'], options['vqa']['dir'])
203 | save_results(testdev_results, args.start_epoch, testset.split_name(testdev=True),
204 | options['logs']['dir_logs'], options['vqa']['dir'])
205 | return
206 |
207 | #########################################################################################
208 | # Begin training on train/val or trainval/test
209 | #########################################################################################
210 |
211 | for epoch in range(args.start_epoch+1, options['optim']['epochs']):
212 | #adjust_learning_rate(optimizer, epoch)
213 |
214 | # train for one epoch
215 | engine.train(train_loader, model, criterion, optimizer,
216 | exp_logger, epoch, args.print_freq)
217 |
218 | if options['vqa']['trainsplit'] == 'train':
219 | # evaluate on validation set
220 | acc1, val_results = engine.validate(val_loader, model, criterion,
221 | exp_logger, epoch, args.print_freq)
222 |             # remember best acc@1 and save checkpoint
223 | is_best = acc1 > best_acc1
224 | best_acc1 = max(acc1, best_acc1)
225 | save_checkpoint({
226 | 'epoch': epoch,
227 | 'arch': options['model']['arch'],
228 | 'best_acc1': best_acc1,
229 | 'exp_logger': exp_logger
230 | },
231 | model.module.state_dict(),
232 | optimizer.state_dict(),
233 | options['logs']['dir_logs'],
234 | args.save_model,
235 | args.save_all_from,
236 | is_best)
237 |
238 |             # save results and compute OpenEnded accuracy
239 | save_results(val_results, epoch, valset.split_name(),
240 | options['logs']['dir_logs'], options['vqa']['dir'])
241 | else:
242 | test_results, testdev_results = engine.test(test_loader, model, exp_logger,
243 | epoch, args.print_freq)
244 |
245 | # save checkpoint at every timestep
246 | save_checkpoint({
247 | 'epoch': epoch,
248 | 'arch': options['model']['arch'],
249 | 'best_acc1': best_acc1,
250 | 'exp_logger': exp_logger
251 | },
252 | model.module.state_dict(),
253 | optimizer.state_dict(),
254 | options['logs']['dir_logs'],
255 | args.save_model,
256 | args.save_all_from)
257 |
258 |             # save results but do not compute OpenEnded accuracy
259 | save_results(test_results, epoch, testset.split_name(),
260 | options['logs']['dir_logs'], options['vqa']['dir'])
261 | save_results(testdev_results, epoch, testset.split_name(testdev=True),
262 | options['logs']['dir_logs'], options['vqa']['dir'])
263 |
264 |
265 | def make_meters():
266 | meters_dict = {
267 | 'loss': logger.AvgMeter(),
268 | 'acc1': logger.AvgMeter(),
269 | 'acc5': logger.AvgMeter(),
270 | 'batch_time': logger.AvgMeter(),
271 | 'data_time': logger.AvgMeter(),
272 | 'epoch_time': logger.SumMeter()
273 | }
274 | return meters_dict
275 |
276 | def save_results(results, epoch, split_name, dir_logs, dir_vqa):
277 | dir_epoch = os.path.join(dir_logs, 'epoch_' + str(epoch))
278 | name_json = 'OpenEnded_mscoco_{}_model_results.json'.format(split_name)
279 |     # TODO: simplify formatting
280 | if 'test' in split_name:
281 | name_json = 'vqa_' + name_json
282 | path_rslt = os.path.join(dir_epoch, name_json)
283 | os.system('mkdir -p ' + dir_epoch)
284 | with open(path_rslt, 'w') as handle:
285 | json.dump(results, handle)
286 |     if 'test' not in split_name:
287 | os.system('python2 eval_res.py --dir_vqa {} --dir_epoch {} --subtype {} &'
288 | .format(dir_vqa, dir_epoch, split_name))
289 |
290 | def save_checkpoint(info, model, optim, dir_logs, save_model, save_all_from=None, is_best=True):
291 | os.system('mkdir -p ' + dir_logs)
292 | if save_all_from is None:
293 | path_ckpt_info = os.path.join(dir_logs, 'ckpt_info.pth.tar')
294 | path_ckpt_model = os.path.join(dir_logs, 'ckpt_model.pth.tar')
295 | path_ckpt_optim = os.path.join(dir_logs, 'ckpt_optim.pth.tar')
296 | path_best_info = os.path.join(dir_logs, 'best_info.pth.tar')
297 | path_best_model = os.path.join(dir_logs, 'best_model.pth.tar')
298 | path_best_optim = os.path.join(dir_logs, 'best_optim.pth.tar')
299 | # save info & logger
300 | path_logger = os.path.join(dir_logs, 'logger.json')
301 | info['exp_logger'].to_json(path_logger)
302 | torch.save(info, path_ckpt_info)
303 | if is_best:
304 | shutil.copyfile(path_ckpt_info, path_best_info)
305 | # save model state & optim state
306 | if save_model:
307 | torch.save(model, path_ckpt_model)
308 | torch.save(optim, path_ckpt_optim)
309 | if is_best:
310 | shutil.copyfile(path_ckpt_model, path_best_model)
311 | shutil.copyfile(path_ckpt_optim, path_best_optim)
312 | else:
313 | is_best = False # because we don't know the test accuracy
314 | path_ckpt_info = os.path.join(dir_logs, 'ckpt_epoch,{}_info.pth.tar')
315 | path_ckpt_model = os.path.join(dir_logs, 'ckpt_epoch,{}_model.pth.tar')
316 | path_ckpt_optim = os.path.join(dir_logs, 'ckpt_epoch,{}_optim.pth.tar')
317 | # save info & logger
318 | path_logger = os.path.join(dir_logs, 'logger.json')
319 | info['exp_logger'].to_json(path_logger)
320 | torch.save(info, path_ckpt_info.format(info['epoch']))
321 | # save model state & optim state
322 | if save_model:
323 | torch.save(model, path_ckpt_model.format(info['epoch']))
324 | torch.save(optim, path_ckpt_optim.format(info['epoch']))
325 | if info['epoch'] > 1 and info['epoch'] < save_all_from + 1:
326 | os.system('rm ' + path_ckpt_info.format(info['epoch'] - 1))
327 | os.system('rm ' + path_ckpt_model.format(info['epoch'] - 1))
328 | os.system('rm ' + path_ckpt_optim.format(info['epoch'] - 1))
329 | if not save_model:
330 | print('Warning train.py: checkpoint not saved')
331 |
332 | def load_checkpoint(model, optimizer, path_ckpt):
333 | path_ckpt_info = path_ckpt + '_info.pth.tar'
334 | path_ckpt_model = path_ckpt + '_model.pth.tar'
335 | path_ckpt_optim = path_ckpt + '_optim.pth.tar'
336 | if os.path.isfile(path_ckpt_info):
337 | info = torch.load(path_ckpt_info)
338 | start_epoch = 0
339 | best_acc1 = 0
340 | exp_logger = None
341 | if 'epoch' in info:
342 | start_epoch = info['epoch']
343 | else:
344 | print('Warning train.py: no epoch to resume')
345 | if 'best_acc1' in info:
346 | best_acc1 = info['best_acc1']
347 | else:
348 | print('Warning train.py: no best_acc1 to resume')
349 | if 'exp_logger' in info:
350 | exp_logger = info['exp_logger']
351 | else:
352 | print('Warning train.py: no exp_logger to resume')
353 | else:
354 | print("Warning train.py: no info checkpoint found at '{}'".format(path_ckpt_info))
355 | if os.path.isfile(path_ckpt_model):
356 | model_state = torch.load(path_ckpt_model)
357 | model.load_state_dict(model_state)
358 | else:
359 | print("Warning train.py: no model checkpoint found at '{}'".format(path_ckpt_model))
360 | if optimizer is not None and os.path.isfile(path_ckpt_optim):
361 | optim_state = torch.load(path_ckpt_optim)
362 | optimizer.load_state_dict(optim_state)
363 | else:
364 | print("Warning train.py: no optim checkpoint found at '{}'".format(path_ckpt_optim))
365 | print("=> loaded checkpoint '{}' (epoch {}, best_acc1 {})"
366 | .format(path_ckpt, start_epoch, best_acc1))
367 | return start_epoch, best_acc1, exp_logger
368 |
369 | if __name__ == '__main__':
370 | main()
371 |
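372 | # Checkpoint layout written by save_checkpoint() in dir_logs (when --save_all_from is
373 | # not set): ckpt_{info,model,optim}.pth.tar for the latest epoch and
374 | # best_{info,model,optim}.pth.tar for the best val accuracy (when training with a val
375 | # split). A run can then be resumed or evaluated with, e.g.:
376 | #   python train.py --path_opt options/vqa2/mutan_noatt_train.yaml --resume ckpt
377 | #   python train.py --path_opt options/vqa2/mutan_noatt_train.yaml --resume best -e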
--------------------------------------------------------------------------------
/visu.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import pandas as pd\n",
12 | "import numpy as np\n",
13 | "import time\n",
14 | "import os\n",
15 | "import matplotlib.pyplot as plt, mpld3\n",
16 | "import seaborn as sns\n",
17 | "from vqa.lib.logger import Experiment\n",
18 | "import json\n",
19 | "import random\n",
20 | "#%matplotlib notebook\n",
21 | "#mpld3.enable_notebook()"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {
28 | "collapsed": false,
29 | "scrolled": false
30 | },
31 | "outputs": [],
32 | "source": [
33 | "def plot_line(data_x, data_y, xlabel=\"epochs\", ylabel=\"accuracy %\", title=\"Name\"):\n",
34 | " fig, ax = plt.subplots()\n",
35 |     "    ax.set_xlabel(xlabel)\n",
36 |     "    ax.set_ylabel(ylabel)\n",
37 |     "    ax.set_title(title)\n",
38 |     "\n",
39 |     "    lines = ax.plot(data_y, data_x, '--', marker=\"o\", label=title)\n",
40 |     "\n",
41 |     "    mpld3.plugins.connect(fig, mpld3.plugins.LineLabelTooltip(lines[0], label=title))\n",
42 | " mpld3.plugins.connect(fig, mpld3.plugins.PointLabelTooltip(lines[0]))#data_y)\n",
43 | " mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fontsize=14))\n",
44 | " #ax.grid(color='lightgrey', linestyle='solid')\n",
45 | " mpld3.show()\n",
46 | " \n",
47 | "def plot_acc(train_accs, val_accs, fig=None, \\\n",
48 | " xlabel=\"epochs\", ylabel=\"accuracy %\", title=\"Name\"):\n",
49 | " if fig is None:\n",
50 | " fig = plt.figure()\n",
51 | " n = 0\n",
52 | " else:\n",
53 | " n = len(fig.axes)\n",
54 | " for i in range(n):\n",
55 | " fig.axes[i].change_geometry(n+1, 1, i+1)\n",
56 | " \n",
57 | " ax = fig.add_subplot(n+1, 1, n+1)\n",
58 | " \n",
59 | " ax.set_xlabel(xlabel)\n",
60 | " ax.set_ylabel(ylabel)\n",
61 | " ax.set_title(title)\n",
62 | "\n",
63 | " data_x = range(1, len(train_accs)+1)\n",
64 | " ax.plot(data_x, train_accs, '--', marker=\"o\", label=\"train\")\n",
65 | " ax.plot(data_x, val_accs, '--', marker=\"o\", label=\"val\")\n",
66 | "\n",
67 | " line_labels = ['train', 'val']\n",
68 | " point_labels = [train_accs, val_accs]\n",
69 | " for i, line in enumerate(ax.lines):\n",
70 | " mpld3.plugins.connect(fig, \\\n",
71 | " mpld3.plugins.LineLabelTooltip(line, label=line_labels[i]))\n",
72 | " mpld3.plugins.connect(fig, \\\n",
73 | " mpld3.plugins.PointLabelTooltip(line, labels=point_labels[i]))\n",
74 | " \n",
75 | " mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fontsize=14))\n",
76 | " #ax.grid(color='lightgrey', linestyle='solid')\n",
77 | " #mpld3.show(open_browser=False)\n",
78 | " return fig\n",
79 | " \n",
80 | "def plot_accs(train_accs=[], val_accs=[], names=[], fig=None, \\\n",
81 | " xlabel=\"epochs\", ylabel=\"accuracy %\", title=\"Name\"):\n",
82 | " if fig is None:\n",
83 | " fig = plt.figure()\n",
84 | " n = 0\n",
85 | " else:\n",
86 | " n = len(fig.axes)\n",
87 | " for i in range(n):\n",
88 | " fig.axes[i].change_geometry(n+1, 1, i+1)\n",
89 | " \n",
90 | " ax = fig.add_subplot(n+1, 1, n+1)\n",
91 | " \n",
92 | " ax.set_xlabel(xlabel)\n",
93 | " ax.set_ylabel(ylabel)\n",
94 | " ax.set_title(title)\n",
95 | "\n",
96 | " max_length = max([len(acc) for acc in val_accs])\n",
97 | " data_x = range(1, max_length+1)\n",
98 | " for i in range(len(val_accs)):\n",
99 | " color = (random.random(), random.random(), random.random())\n",
100 | " data_x_train = range(1, len(train_accs[i])+1)\n",
101 | " data_x_val = range(1, len(val_accs[i])+1)\n",
102 | " ax.plot(data_x_train, train_accs[i],\n",
103 | " '--', marker=\"o\", color=color,\n",
104 | " label=names[i]+\"train\")\n",
105 | " ax.plot(data_x_val, val_accs[i],\n",
106 | " '-', marker=\"o\", color=color,\n",
107 | " label=names[i]+\"val\")\n",
108 | "\n",
109 | " point_labels = {\n",
110 | " 'train': train_accs,\n",
111 | " 'val': val_accs\n",
112 | " }\n",
113 | " for line_id in range(len(ax.lines)):\n",
114 | " label = 'train' if line_id%2==0 else 'val'\n",
115 | " #mpld3.plugins.connect(fig, \\\n",
116 | " # mpld3.plugins.LineLabelTooltip(ax.lines[line_id], label=label))\n",
117 | " mpld3.plugins.connect(fig, \\\n",
118 |     "        mpld3.plugins.PointLabelTooltip(ax.lines[line_id], labels=point_labels[label][line_id//2]))\n",
119 | " \n",
120 | " ax.legend()\n",
121 | " labels = []\n",
122 | " for name in names:\n",
123 | " labels.append(name + '_train')\n",
124 | " labels.append(name + '_val')\n",
125 | " mpld3.plugins.connect(fig, mpld3.plugins.InteractiveLegendPlugin(ax.lines, labels))\n",
126 | " mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fontsize=14))\n",
127 | " #ax.grid(color='lightgrey', linestyle='solid')\n",
128 | " #mpld3.show(open_browser=False)\n",
129 | " return fig \n",
130 | "\n",
131 | "def load_accs_oe(path_logger):\n",
132 | " dir_xp = os.path.dirname(path_logger)\n",
133 | " epochs = []\n",
134 | " for name in os.listdir(dir_xp):\n",
135 | " if name.startswith('epoch'):\n",
136 | " epochs.append(name)\n",
137 | " epochs = sorted(epochs, key=lambda x: float(x.split('_')[1]))\n",
138 | " accs = {}\n",
139 | " for i, epoch in enumerate(epochs):\n",
140 | " epoch_id = i+1\n",
141 | " path_acc = os.path.join(dir_xp, epoch, 'OpenEnded_mscoco_val2014_model_accuracy.json')\n",
142 | " if os.path.exists(path_acc):\n",
143 | " with open(path_acc, 'r') as f:\n",
144 | " data = json.load(f)\n",
145 | " accs[epoch_id] = data['overall']\n",
146 | " return accs\n",
147 | "\n",
148 | "def sort(dict_):\n",
149 | " return [v for k,v in sorted(dict_.items(), \\\n",
150 | " key=lambda x: float(x[0]))]\n",
151 | "\n",
152 | "def reduce(list_, num=15):\n",
153 | " tmp = []\n",
154 | " for i, val in enumerate(list_):\n",
155 | " if i < num:\n",
156 | " tmp.append(val)\n",
157 | " return tmp"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "metadata": {
164 | "collapsed": false,
165 | "scrolled": true
166 | },
167 | "outputs": [],
168 | "source": [
169 | "dir_logs = 'logs/vqa'\n",
170 | "xps = []\n",
171 | "for p in sorted(os.listdir(dir_logs)):\n",
172 | " if not p.startswith('.'): print(p)"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {
179 | "collapsed": false
180 | },
181 | "outputs": [],
182 | "source": [
183 | "path_logger = os.path.join(dir_logs, 'mlb_att_samplingans,True', 'logger.json')\n",
184 | "xp = Experiment.from_json(path_logger)\n",
185 | "xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)\n",
186 | "#print(xp.logged['val']['acc1_oe'])\n",
187 | "print(xp.options)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {
194 | "collapsed": false
195 | },
196 | "outputs": [],
197 | "source": [
198 | "# Display accuracy & loss of one exp\n",
199 | "def display_xp(path_logger):\n",
200 | " xp = Experiment.from_json(path_logger)\n",
201 | " xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)\n",
202 | "\n",
203 | " path_visu = os.path.join(os.path.dirname(path_logger),\n",
204 | " 'visu.html')\n",
205 | "\n",
206 | " val_acc1 = sort(xp.logged['val']['acc1_oe'])\n",
207 | " train_acc1 = sort(xp.logged['train']['acc1'])\n",
208 | " loss_val = sort(xp.logged['val']['loss'])\n",
209 | " loss_train = sort(xp.logged['train']['loss'])\n",
210 | "\n",
211 | " max_length = min([len(val_acc1), len(train_acc1)])\n",
212 | " val_acc1 = reduce(val_acc1, max_length)\n",
213 | " train_acc1= reduce(train_acc1, max_length)\n",
214 | " loss_val = reduce(loss_val, max_length)\n",
215 | " loss_train=reduce(loss_train, max_length)\n",
216 | " \n",
217 | " print(max(val_acc1))\n",
218 | " \n",
219 | " fig = plt.figure(figsize=(10,13))\n",
220 | " fig = plot_acc(train_acc1, val_acc1, fig=fig, \\\n",
221 | " title='max: '+str(max(val_acc1)))\n",
222 | " fig = plot_acc(loss_train, loss_val, fig=fig, \\\n",
223 | " ylabel='loss', title='min: '+str(min(loss_val)))\n",
224 | " fig.tight_layout()\n",
225 | " with open(path_visu, 'w') as f:\n",
226 | " mpld3.save_html(fig, f)\n",
227 | " \n",
228 | "#for name in ['blocmutan']:#['seq2vec/bgru', 'mutan_noatt_2', 'mlb_att_2']:\n",
229 | "for name in os.listdir(dir_logs):\n",
230 | " if not name.startswith('.'):\n",
231 | " path_logger = os.path.join(dir_logs, name, 'logger.json')\n",
232 | " print(name)\n",
233 | " display_xp(path_logger)"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": null,
239 | "metadata": {
240 | "collapsed": false
241 | },
242 | "outputs": [],
243 | "source": [
244 | "# Display accuracy & loss of multiple exp\n",
245 | "accs_train = []\n",
246 | "accs_val = []\n",
247 | "names = []\n",
248 | "acc_max = ('xp name', -1)\n",
249 | "fig = plt.figure(figsize=(20,13))\n",
250 | "#for name in os.listdir(dir_logs):#['bgu_normal', 'mutan_noatt', 'mutan_noatt_2', 'mlb_att_2', 'mlb_att_3']:\n",
251 | "#for name in ['seq2vec/bgru_dropout,25%', 'mutan_noatt', 'blocmutan_noatt', 'mlb_noatt_2lstm', 'mlb_noatt_2lstm_emb']:\n",
252 | "#for name in ['blocmutan_noatt', 'blocmutan_noatt_samplingans,True']:\n",
253 | "dir_logs = 'logs/vqa'\n",
254 | "exps = []\n",
255 | "exps.append('blocmutan_att_2')\n",
256 | "exps.append('mutan_att_2')\n",
257 | "exps.append('mlb_att_2')\n",
258 | "\n",
259 | "for name in exps:\n",
260 | " if not name.startswith('.'):\n",
261 | " path_logger = os.path.join(dir_logs, name, 'logger.json')\n",
262 | " xp = Experiment.from_json(path_logger)\n",
263 | " xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)\n",
264 | " train_acc1 = sort(xp.logged['train']['acc1'])\n",
265 | " val_acc1 = sort(xp.logged['val']['acc1_oe'])\n",
266 | "\n",
267 | " max_length = min([len(val_acc1), len(train_acc1)])\n",
268 | " val_acc1 = reduce(val_acc1, max_length)\n",
269 | " train_acc1= reduce(train_acc1, max_length)\n",
270 | "\n",
271 | " accs_val.append(val_acc1)\n",
272 | " accs_train.append(train_acc1)\n",
273 | " print(name, max(val_acc1))\n",
274 | " if max(val_acc1) > acc_max[1]:\n",
275 | " acc_max = (name, max(val_acc1))\n",
276 | " names.append(name)\n",
277 | "fig = plot_accs(accs_train, accs_val, names,\n",
278 | " title=acc_max[0]+' '+str(acc_max[1]), fig=fig)\n",
279 | "path_visu = dir_logs+'/test.html'\n",
280 | "with open(path_visu, 'w') as f:\n",
281 | " mpld3.save_html(fig, f)"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {
288 | "collapsed": false
289 | },
290 | "outputs": [],
291 | "source": [
292 | "# Display accuracy & loss of multiple exp\n",
293 | "accs_train = []\n",
294 | "accs_val = []\n",
295 | "names = []\n",
296 | "acc_max = ('xp name', -1)\n",
297 | "fig = plt.figure(figsize=(20,13))\n",
298 | "#for name in os.listdir(dir_logs):#['bgu_normal', 'mutan_noatt', 'mutan_noatt_2', 'mlb_att_2', 'mlb_att_3']:\n",
299 | "#for name in ['seq2vec/bgru_dropout,25%', 'mutan_noatt', 'blocmutan_noatt', 'mlb_noatt_2lstm', 'mlb_noatt_2lstm_emb']:\n",
300 | "#for name in ['blocmutan_noatt', 'blocmutan_noatt_samplingans,True']:\n",
301 | "\n",
302 | "dir_logs = 'logs/vqa'\n",
303 | "list_dir = []\n",
304 | "for name in os.listdir(dir_logs):\n",
305 | " if not name.startswith('.') and not name.endswith('.html'):\n",
306 | " list_dir.append(name)\n",
307 | " \n",
308 | "for name in list_dir:\n",
309 | " path_logger = os.path.join(dir_logs, name, 'logger.json')\n",
310 | " #xp = Experiment.from_json(path_logger)\n",
311 | " val_acc1 = load_accs_oe(path_logger)\n",
312 | " #train_acc1 = sort(xp.logged['train']['acc1'])\n",
313 | " val_acc1 = sort(val_acc1)\n",
314 | "\n",
315 | " max_length = len(val_acc1)\n",
316 | " val_acc1 = reduce(val_acc1, max_length)\n",
317 | " #train_acc1= reduce(train_acc1, max_length)\n",
318 | "\n",
319 | " accs_val.append(val_acc1)\n",
320 | " #accs_train.append(train_acc1)\n",
321 | " print(name, max(val_acc1))\n",
322 | " if max(val_acc1) > acc_max[1]:\n",
323 | " acc_max = (name, max(val_acc1))\n",
324 | " names.append(name)\n",
325 | "\n",
326 | "fig = plot_accs(accs_val, accs_val, names,\n",
327 | " title=acc_max[0]+' '+str(acc_max[1]), fig=fig)\n",
328 | "path_visu = dir_logs+'/test.html'\n",
329 | "with open(path_visu, 'w') as f:\n",
330 | " mpld3.save_html(fig, f)"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "metadata": {
337 | "collapsed": true
338 | },
339 | "outputs": [],
340 | "source": []
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {
346 | "collapsed": false
347 | },
348 | "outputs": [],
349 | "source": [
350 | "# Test displaying accuracy of one exp\n",
351 | "val_acc1 = sort(xp.logged['val']['acc1'])\n",
352 | "train_acc1 = sort(xp.logged['train']['acc1'])\n",
353 | "loss_val = sort(xp.logged['val']['loss'])\n",
354 | "loss_train = sort(xp.logged['train']['loss'])\n",
355 | "\n",
356 | "#plot_line(val_acc1, range(1,len(val_acc1)+1))\n",
357 | "for i in range(1,40):\n",
358 | " val_acc1_tmp = reduce(val_acc1, i)\n",
359 | " train_acc1_tmp = reduce(train_acc1, i)\n",
360 | " loss_val_tmp = reduce(loss_val, i)\n",
361 | " loss_train_tmp = reduce(loss_train, i)\n",
362 | " fig = plt.figure(figsize=(10,13))\n",
363 | " fig = plot_acc(train_acc1_tmp, val_acc1_tmp, fig=fig, \\\n",
364 | " title='max: '+str(max(val_acc1_tmp)))\n",
365 | " fig = plot_acc(loss_train_tmp, loss_val_tmp, fig=fig, \\\n",
366 | " ylabel='loss', title='min: '+str(min(loss_val_tmp)))\n",
367 | " fig.tight_layout()\n",
368 | " with open('test.html', 'w') as f:\n",
369 | " mpld3.save_html(fig, f)\n",
370 | " time.sleep(3)\n",
371 | " "
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {
378 | "collapsed": true
379 | },
380 | "outputs": [],
381 | "source": []
382 | }
383 | ],
384 | "metadata": {
385 | "anaconda-cloud": {},
386 | "kernelspec": {
387 | "display_name": "Python [conda env:vqa]",
388 | "language": "python",
389 | "name": "conda-env-vqa-py"
390 | },
391 | "language_info": {
392 | "codemirror_mode": {
393 | "name": "ipython",
394 | "version": 3
395 | },
396 | "file_extension": ".py",
397 | "mimetype": "text/x-python",
398 | "name": "python",
399 | "nbconvert_exporter": "python",
400 | "pygments_lexer": "ipython3",
401 | "version": "3.6.0"
402 | }
403 | },
404 | "nbformat": 4,
405 | "nbformat_minor": 1
406 | }
407 |
--------------------------------------------------------------------------------
/visu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import json
5 | import random
6 | import numpy as np
7 | import colorlover as cl
8 |
9 | import plotly.plotly as py
10 | import plotly.graph_objs as go
11 | from plotly.offline import download_plotlyjs, plot
12 | from plotly import tools
13 |
14 | from vqa.lib.logger import Experiment
15 |
16 | def load_accs_oe(path_logger):
17 | dir_xp = os.path.dirname(path_logger)
18 | epochs = []
19 | for name in os.listdir(dir_xp):
20 | if name.startswith('epoch'):
21 | epochs.append(name)
22 | epochs = sorted(epochs, key=lambda x: float(x.split('_')[1]))
23 | accs = {}
24 | for i, epoch in enumerate(epochs):
25 | epoch_id = i+1
26 | path_acc = os.path.join(dir_xp, epoch, 'OpenEnded_mscoco_val2014_model_accuracy.json')
27 | if os.path.exists(path_acc):
28 | with open(path_acc, 'r') as f:
29 | data = json.load(f)
30 | accs[epoch_id] = data['overall']
31 | return accs
32 |
33 | def sort(dict_):
34 | return [v for k,v in sorted(dict_.items(), \
35 | key=lambda x: float(x[0]))]
36 |
37 | def reduce(list_, num=15):
38 | tmp = []
39 | for i, val in enumerate(list_):
40 | if i < num:
41 | tmp.append(val)
42 | return tmp
43 |
44 | # Display accuracy & loss of one exp
45 | def visu_one_exp(path_logger, path_visu, auto_open=True):
46 | xp = Experiment.from_json(path_logger)
47 | xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)
48 | train_acc1 = sort(xp.logged['train']['acc1'])
49 | val_acc1 = sort(xp.logged['val']['acc1_oe'])
50 | train_loss = sort(xp.logged['train']['loss'])
51 | val_loss = sort(xp.logged['val']['loss'])
52 | train_data_x = list(range(1, len(train_acc1)+1))
53 | val_data_x = list(range(1, len(val_acc1)+1))
54 |
55 | fig = tools.make_subplots(rows=1,
56 | cols=2,
57 | subplot_titles=('Accuracy top1', 'Loss'))
58 | # blue rgb(31, 119, 180)
59 | # orange rgb(255, 127, 14)
60 |
61 | train_acc1_trace = go.Scatter(
62 | x=train_data_x,
63 | y=train_acc1,
64 | name='train accuracy top1'
65 | )
66 | val_acc1_trace = go.Scatter(
67 | x=val_data_x,
68 | y=val_acc1,
69 | name='val accuracy top1',
70 | line = dict(
71 | color = ('rgb(255, 127, 14)'),
72 | )
73 | )
74 | best_val_acc1_trace = go.Scatter(
75 | x=[np.argmax(val_acc1)+1],
76 | y=[max(val_acc1)],
77 | mode='markers',
78 | name='best val accuracy top1',
79 | marker = dict(
80 | color = 'rgb(255, 127, 14)',
81 | size = 10
82 | )
83 | )
84 |
85 | val_loss_trace = go.Scatter(
86 | x=val_data_x,
87 | y=val_loss,
88 | name='val loss'
89 | )
90 | train_loss_trace = go.Scatter(
91 | x=train_data_x,
92 | y=train_loss,
93 | name='train loss'
94 | )
95 |
96 | fig.append_trace(train_acc1_trace, 1, 1)
97 | fig.append_trace(val_acc1_trace, 1, 1)
98 | fig.append_trace(best_val_acc1_trace, 1, 1)
99 |
100 | fig.append_trace(train_loss_trace, 1, 2)
101 | fig.append_trace(val_loss_trace, 1, 2)
102 |
103 | plot(fig, filename=path_visu, auto_open=auto_open)
104 |
105 | return train_acc1, val_acc1
106 |
107 | # Display accuracy & loss of several exps
108 | def visu_exps(list_path_logger, path_visu, auto_open=True):
109 | fig = tools.make_subplots(rows=2,
110 | cols=2,
111 | subplot_titles=('Val accuracy top1',
112 | 'Val loss',
113 | 'Train accuracy top1',
114 | 'Train loss'))
115 | num_xp = len(list_path_logger)
116 |     if num_xp < 3: # cl.scales does not accept fewer than 3 colors
117 | num_xp = 3
118 | list_color = cl.scales[str(num_xp)]['qual']['Paired']
119 |
120 | for i, path_logger in enumerate(list_path_logger):
121 | name = path_logger.split('/')[-2]
122 |
123 | xp = Experiment.from_json(path_logger)
124 | xp.logged['val']['acc1_oe'] = load_accs_oe(path_logger)
125 | train_acc1 = sort(xp.logged['train']['acc1'])
126 | val_acc1 = sort(xp.logged['val']['acc1_oe'])
127 | train_loss = sort(xp.logged['train']['loss'])
128 | val_loss = sort(xp.logged['val']['loss'])
129 | train_data_x = list(range(1, len(train_acc1)+1))
130 | val_data_x = list(range(1, len(val_acc1)+1))
131 |
132 | train_acc1_trace = go.Scatter(
133 | x=train_data_x,
134 | y=train_acc1,
135 | name='train acc: '+name,
136 | line=dict(
137 | color=list_color[i]
138 | )
139 | )
140 | val_acc1_trace = go.Scatter(
141 | x=val_data_x,
142 | y=val_acc1,
143 | name='val acc: '+name,
144 | line=dict(
145 | color=list_color[i]
146 | )
147 | )
148 | best_val_acc1_trace = go.Scatter(
149 | x=[np.argmax(val_acc1)+1],
150 | y=[max(val_acc1)],
151 | mode='markers',
152 | name='best val acc: '+name,
153 | marker = dict(
154 | color = list_color[i],
155 | size = 10
156 | )
157 | )
158 |
159 | val_loss_trace = go.Scatter(
160 | x=val_data_x,
161 | y=val_loss,
162 | name='val loss: '+name,
163 | line=dict(
164 | color=list_color[i]
165 | )
166 | )
167 | train_loss_trace = go.Scatter(
168 | x=train_data_x,
169 | y=train_loss,
170 | name='train loss: '+name,
171 | line=dict(
172 | color=list_color[i]
173 | )
174 | )
175 |
176 | fig.append_trace(val_acc1_trace, 1, 1)
177 | fig.append_trace(best_val_acc1_trace, 1, 1)
178 | fig.append_trace(train_acc1_trace, 2, 1)
179 |
180 | fig.append_trace(val_loss_trace, 1, 2)
181 | fig.append_trace(train_loss_trace, 2, 2)
182 |
183 | plot(fig, filename=path_visu, auto_open=auto_open)
184 |
185 | def main_one_exp(dir_logs, path_visu=None, refresh_freq=60):
186 | if path_visu is None:
187 | path_visu = os.path.join(dir_logs, 'visu.html')
188 |
189 | path_logger = os.path.join(dir_logs, 'logger.json')
190 |
191 | i = 1
192 | print('Create visu to ' + path_visu)
193 | while True:
194 | train_acc1, val_acc1 = visu_one_exp(path_logger, path_visu, auto_open=(i==1))
195 | print('# Visu iteration (refresh every {} sec): {}'.format(refresh_freq, i))
196 | print('Max Val OpenEnded-Accuracy Top1: {}'.format(max(val_acc1)))
197 | print('Max Train Accuracy Top1: {}'.format(max(train_acc1)))
198 | i += 1
199 | time.sleep(refresh_freq)
200 |
201 | def main_exps(list_dir_logs, path_visu=None, refresh_freq=60):
202 | if path_visu is None:
203 | path_visu = os.path.join(os.path.dirname(list_dir_logs[0]), 'visu.html')
204 |
205 | list_path_logger = []
206 | for dir_logs in list_dir_logs:
207 | list_path_logger.append(os.path.join(dir_logs, 'logger.json'))
208 |
209 | i = 1
210 | print('Create visu to ' + path_visu)
211 | while True:
212 | visu_exps(list_path_logger, path_visu, auto_open=(i==1))
213 | print('# Visu iteration (refresh every {} sec): {}'.format(refresh_freq, i))
214 | i += 1
215 | time.sleep(refresh_freq)
216 |
217 |
218 | ##########################################################################
219 | # Main
220 | ##########################################################################
221 |
222 | parser = argparse.ArgumentParser(
223 | description='Create html visu files',
224 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
225 |
226 | parser.add_argument('--dir_logs', type=str,
227 |                     help='''First mode: dir to the logs of an experiment (ex: logs/vqa/mutan). '''
228 |                     '''Second mode: several comma-separated dirs to create a comparative visualisation (ex: logs/vqa/mutan,logs/vqa/mlb)''')
229 | parser.add_argument('--refresh_freq', '-f', default=60, type=int,
230 | help='refresh frequency in seconds')
231 | parser.add_argument('--path_visu', default=None,
232 | help='path to the html file (default: visu.html in dir_logs)')
233 |
234 | def main():
235 | global args
236 | args = parser.parse_args()
237 |
238 | list_dir_logs = args.dir_logs.split(',')
239 |
240 | if len(list_dir_logs) == 1:
241 | main_one_exp(args.dir_logs,
242 | path_visu=args.path_visu,
243 | refresh_freq=args.refresh_freq)
244 | else:
245 | main_exps(list_dir_logs,
246 | path_visu=args.path_visu,
247 | refresh_freq=args.refresh_freq)
248 |
249 | if __name__ == '__main__':
250 | main()
251 |
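252 | # Example invocations (paths follow the --dir_logs help above):
253 | #   python visu.py --dir_logs logs/vqa/mutan                  # monitor one experiment
254 | #   python visu.py --dir_logs logs/vqa/mutan,logs/vqa/mlb     # compare several experiments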
--------------------------------------------------------------------------------
/vqa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/vqa/__init__.py
--------------------------------------------------------------------------------
/vqa/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .vqa import factory as factory_VQA
2 | from .coco import COCOImages
3 | from .vgenome import VisualGenomeImages
--------------------------------------------------------------------------------
/vqa/datasets/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | # from mpi4py import MPI
3 | import numpy as np
4 | import h5py
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 |
8 | from ..lib import utils
9 | from .images import ImagesFolder, AbstractImagesDataset, default_loader
10 | from .features import FeaturesDataset
11 |
12 | def split_name(data_split):
13 | if data_split in ['train', 'val']:
14 | return data_split + '2014'
15 | elif data_split == 'test':
16 | return data_split + '2015'
17 | else:
18 |         assert False, 'data_split {} does not exist'.format(data_split)
19 |
20 |
21 | class COCOImages(AbstractImagesDataset):
22 |
23 | def __init__(self, data_split, opt, transform=None, loader=default_loader):
24 | self.split_name = split_name(data_split)
25 | super(COCOImages, self).__init__(data_split, opt, transform, loader)
26 | self.dir_split = self.get_dir_data()
27 | self.dataset = ImagesFolder(self.dir_split, transform=self.transform, loader=self.loader)
28 | self.name_to_index = self._load_name_to_index()
29 |
30 | def get_dir_data(self):
31 | return os.path.join(self.dir_raw, self.split_name)
32 |
33 | def _raw(self):
34 | if self.data_split in ['train', 'val']:
35 | os.system('wget http://msvocds.blob.core.windows.net/coco2014/{}.zip -P {}'.format(self.split_name, self.dir_raw))
36 | elif self.data_split == 'test':
37 | os.system('wget http://msvocds.blob.core.windows.net/coco2015/test2015.zip -P '+self.dir_raw)
38 | else:
39 |             assert False, 'data_split {} does not exist'.format(self.data_split)
40 | os.system('unzip '+os.path.join(self.dir_raw, self.split_name+'.zip')+' -d '+self.dir_raw)
41 |
42 | def _load_name_to_index(self):
43 | self.name_to_index = {name:index for index, name in enumerate(self.dataset.imgs)}
44 | return self.name_to_index
45 |
46 | def __getitem__(self, index):
47 | item = self.dataset[index]
48 | item['name'] = os.path.join(self.split_name, item['name'])
49 | return item
50 |
51 | def __len__(self):
52 | return len(self.dataset)
53 |
54 |
55 | class COCOTrainval(data.Dataset):
56 |
57 | def __init__(self, trainset, valset):
58 | self.trainset = trainset
59 | self.valset = valset
60 |
61 | def __getitem__(self, index):
62 | if index < len(self.trainset):
63 | item = self.trainset[index]
64 | else:
65 | item = self.valset[index - len(self.trainset)]
66 | return item
67 |
68 | def get_by_name(self, image_name):
69 | if image_name in self.trainset.name_to_index:
70 | index = self.trainset.name_to_index[image_name]
71 | item = self.trainset[index]
72 | return item
73 | elif image_name in self.valset.name_to_index:
74 | index = self.valset.name_to_index[image_name]
75 | item = self.valset[index]
76 | return item
77 | else:
78 | raise ValueError
79 |
80 | def __len__(self):
81 | return len(self.trainset) + len(self.valset)
82 |
83 |
84 | def default_transform(size):
85 | transform = transforms.Compose([
86 | transforms.Scale(size),
87 | transforms.CenterCrop(size),
88 | transforms.ToTensor(),
89 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], # imagenet stats used by resnet
90 | std=[0.229, 0.224, 0.225])
91 | ])
92 | return transform
93 |
94 | def factory(data_split, opt, transform=None):
95 | if data_split == 'trainval':
96 | trainset = factory('train', opt, transform)
97 | valset = factory('val', opt, transform)
98 | return COCOTrainval(trainset, valset)
99 | elif data_split in ['train', 'val', 'test']:
100 | if opt['mode'] == 'img':
101 | if transform is None:
102 | transform = default_transform(opt['size'])
103 | return COCOImages(data_split, opt, transform)
104 | elif opt['mode'] in ['noatt', 'att']:
105 | return FeaturesDataset(data_split, opt)
106 | else:
107 | raise ValueError
108 | else:
109 | raise ValueError
110 |
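111 | # Summary of the factory above:
112 | #   data_split == 'trainval'     -> COCOTrainval wrapping the train and val sets
113 | #   opt['mode'] == 'img'         -> COCOImages serving raw COCO images (resized/cropped to opt['size'])
114 | #   opt['mode'] in noatt / att   -> FeaturesDataset serving features pre-extracted with extract.py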
--------------------------------------------------------------------------------
/vqa/datasets/features.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import h5py
4 | import torch
5 | import torch.utils.data as data
6 |
7 | class FeaturesDataset(data.Dataset):
8 |
9 | def __init__(self, data_split, opt):
10 | self.data_split = data_split
11 | self.opt = opt
12 | self.dir_extract = os.path.join(self.opt['dir'],
13 | 'extract',
14 | 'arch,' + self.opt['arch'])
15 | if 'size' in opt:
16 | self.dir_extract += '_size,' + str(opt['size'])
17 | self.path_hdf5 = os.path.join(self.dir_extract,
18 | data_split + 'set.hdf5')
19 | assert os.path.isfile(self.path_hdf5), \
20 | 'File not found in {}, you must extract the features first with extract.py'.format(self.path_hdf5)
21 | self.hdf5_file = h5py.File(self.path_hdf5, 'r')#, driver='mpio', comm=MPI.COMM_WORLD)
22 | self.dataset_features = self.hdf5_file[self.opt['mode']]
23 | self.index_to_name, self.name_to_index = self._load_dicts()
24 |
25 | def _load_dicts(self):
26 | self.path_fname = os.path.join(self.dir_extract,
27 | self.data_split + 'set.txt')
28 | with open(self.path_fname, 'r') as handle:
29 | self.index_to_name = handle.readlines()
30 | self.index_to_name = [name[:-1] for name in self.index_to_name] # remove char '\n'
31 | self.name_to_index = {name:index for index,name in enumerate(self.index_to_name)}
32 | return self.index_to_name, self.name_to_index
33 |
34 | def __getitem__(self, index):
35 | item = {}
36 | item['name'] = self.index_to_name[index]
37 | item['visual'] = self.get_features(index)
38 | #item = torch.Tensor(self.get_features(index))
39 | return item
40 |
41 | def get_features(self, index):
42 | return torch.Tensor(self.dataset_features[index])
43 |
44 | def get_features_old(self, index):
45 | try:
46 | self.features_array
47 | except AttributeError:
48 | if self.opt['mode'] == 'att':
49 | self.features_array = np.zeros((2048,14,14), dtype='f')
50 | elif self.opt['mode'] == 'noatt':
51 | self.features_array = np.zeros((2048), dtype='f')
52 |
53 | if self.opt['mode'] == 'att':
54 | self.dataset_features.read_direct(self.features_array,
55 | np.s_[index,:2048,:14,:14],
56 | np.s_[:2048,:14,:14])
57 | elif self.opt['mode'] == 'noatt':
58 | self.dataset_features.read_direct(self.features_array,
59 | np.s_[index,:2048],
60 | np.s_[:2048])
61 | return self.features_array
62 |
63 |
64 | def get_by_name(self, image_name):
65 | index = self.name_to_index[image_name]
66 | return self[index]
67 |
68 | def __len__(self):
69 | return self.dataset_features.shape[0]
70 |
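71 | # A minimal usage sketch (assuming extract.py has already been run with the settings
72 | # used in the yaml options, i.e. arch=fbresnet152 and size=448):
73 | #
74 | #   opt = {'dir': 'data/coco', 'arch': 'fbresnet152', 'size': 448, 'mode': 'att'}
75 | #   valset = FeaturesDataset('val', opt)  # reads data/coco/extract/arch,fbresnet152_size,448/valset.hdf5 and valset.txt
76 | #   item = valset[0]                      # {'name': <image name>, 'visual': FloatTensor of size 2048x14x14 in 'att' mode}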
--------------------------------------------------------------------------------
/vqa/datasets/images.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | ]
11 |
12 | def is_image_file(filename):
13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
14 |
15 | def make_dataset(dir):
16 | images = []
17 | for fname in os.listdir(dir):
18 | if is_image_file(fname):
19 | images.append(fname)
20 | return images
21 |
22 |
23 | def default_loader(path):
24 | return Image.open(path).convert('RGB')
25 |
26 |
27 | class ImagesFolder(data.Dataset):
28 |
29 | def __init__(self, root, transform=None, loader=default_loader):
30 | imgs = make_dataset(root)
31 | if len(imgs) == 0:
32 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
33 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
34 | self.root = root
35 | self.imgs = imgs
36 | self.transform = transform
37 | self.loader = loader
38 |
39 | def __getitem__(self, index):
40 | item = {}
41 | item['name'] = self.imgs[index]
42 | item['path'] = os.path.join(self.root, item['name'])
43 | if self.loader is not None:
44 | item['visual'] = self.loader(item['path'])
45 | if self.transform is not None:
46 | item['visual'] = self.transform(item['visual'])
47 | return item
48 |
49 | def __len__(self):
50 | return len(self.imgs)
51 |
52 |
53 | class AbstractImagesDataset(data.Dataset):
54 |
55 | def __init__(self, data_split, opt, transform=None, loader=default_loader):
56 | self.data_split = data_split
57 | self.opt = opt
58 | self.transform = transform
59 | self.loader = loader
60 | self.dir_raw = os.path.join(self.opt['dir'], 'raw')
61 |
62 | if not os.path.exists(self.get_dir_data()):
63 | self._raw()
64 |
65 | def get_dir_data(self):
66 | return self.dir_raw
67 |
68 | def get_by_name(self, image_name):
69 | index = self.name_to_index[image_name]
70 | return self[index]
71 |
72 | def _raw(self):
73 | raise NotImplementedError
74 |
75 | def __getitem__(self, index):
76 | raise NotImplementedError
77 |
78 | def __len__(self):
79 | raise NotImplementedError
80 |
--------------------------------------------------------------------------------
/vqa/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import torch
4 | import torch.utils.data as data
5 | import copy
6 |
7 | class AbstractVQADataset(data.Dataset):
8 |
9 | def __init__(self, data_split, opt, dataset_img=None):
10 | self.data_split = data_split
11 | self.opt = copy.copy(opt)
12 | self.dataset_img = dataset_img
13 |
14 | self.dir_raw = os.path.join(self.opt['dir'], 'raw')
15 | if not os.path.exists(self.dir_raw):
16 | self._raw()
17 |
18 | self.dir_interim = os.path.join(self.opt['dir'], 'interim')
19 | if not os.path.exists(self.dir_interim):
20 | self._interim()
21 |
22 | self.dir_processed = os.path.join(self.opt['dir'], 'processed')
23 | self.subdir_processed = self.subdir_processed()
24 | if not os.path.exists(self.subdir_processed):
25 | self._processed()
26 |
27 | path_wid_to_word = os.path.join(self.subdir_processed, 'wid_to_word.pickle')
28 | path_word_to_wid = os.path.join(self.subdir_processed, 'word_to_wid.pickle')
29 | path_aid_to_ans = os.path.join(self.subdir_processed, 'aid_to_ans.pickle')
30 | path_ans_to_aid = os.path.join(self.subdir_processed, 'ans_to_aid.pickle')
31 | path_dataset = os.path.join(self.subdir_processed, self.data_split+'set.pickle')
32 |
33 | with open(path_wid_to_word, 'rb') as handle:
34 | self.wid_to_word = pickle.load(handle)
35 |
36 | with open(path_word_to_wid, 'rb') as handle:
37 | self.word_to_wid = pickle.load(handle)
38 |
39 | with open(path_aid_to_ans, 'rb') as handle:
40 | self.aid_to_ans = pickle.load(handle)
41 |
42 | with open(path_ans_to_aid, 'rb') as handle:
43 | self.ans_to_aid = pickle.load(handle)
44 |
45 | with open(path_dataset, 'rb') as handle:
46 | self.dataset = pickle.load(handle)
47 |
48 | def _raw(self):
49 | raise NotImplementedError
50 |
51 | def _interim(self):
52 | raise NotImplementedError
53 |
54 | def _processed(self):
55 | raise NotImplementedError
56 |
57 | def __getitem__(self, index):
58 | raise NotImplementedError
59 |
60 | def subdir_processed(self):
61 | subdir = 'nans,' + str(self.opt['nans']) \
62 | + '_maxlength,' + str(self.opt['maxlength']) \
63 | + '_minwcount,' + str(self.opt['minwcount']) \
64 | + '_nlp,' + self.opt['nlp'] \
65 | + '_pad,' + self.opt['pad'] \
66 | + '_trainsplit,' + self.opt['trainsplit']
67 | subdir = os.path.join(self.dir_processed, subdir)
68 | return subdir
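69 |
70 |     # With the options used throughout options/vqa2 (nans=2000, maxlength=26,
71 |     # minwcount=0, nlp=mcb, pad=right, trainsplit=train), subdir_processed()
72 |     # resolves to:
73 |     #   <dir>/processed/nans,2000_maxlength,26_minwcount,0_nlp,mcb_pad,right_trainsplit,train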
--------------------------------------------------------------------------------
/vqa/datasets/vgenome.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | import copy
5 |
6 | from .images import ImagesFolder, AbstractImagesDataset, default_loader
7 | from .features import FeaturesDataset
8 | from .vgenome_interim import vgenome_interim
9 | from .vgenome_processed import vgenome_processed
10 | from .coco import default_transform
11 | from .utils import AbstractVQADataset
12 |
13 | def raw(dir_raw):
14 | dir_img = os.path.join(dir_raw, 'images')
15 | os.system('wget http://visualgenome.org/static/data/dataset/image_data.json.zip -P '+dir_raw)
16 | os.system('wget http://visualgenome.org/static/data/dataset/question_answers.json.zip -P '+dir_raw)
17 | os.system('wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -P '+dir_raw)
18 | os.system('wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -P '+dir_raw)
19 |
20 | os.system('unzip '+os.path.join(dir_raw, 'image_data.json.zip')+' -d '+dir_raw)
21 | os.system('unzip '+os.path.join(dir_raw, 'question_answers.json.zip')+' -d '+dir_raw)
22 | os.system('unzip '+os.path.join(dir_raw, 'images.zip')+' -d '+dir_raw)
23 | os.system('unzip '+os.path.join(dir_raw, 'images2.zip')+' -d '+dir_raw)
24 |
25 | os.system('mv '+os.path.join(dir_raw, 'VG_100K')+' '+dir_img)
26 |
27 | #os.system('mv '+os.path.join(dir_raw, 'VG_100K_2', '*.jpg')+' '+dir_img)
28 | os.system('find '+os.path.join(dir_raw, 'VG_100K_2')+' -type f -name \'*\' -exec mv {} '+dir_img+' \\;')
29 | os.system('rm -rf '+os.path.join(dir_raw, 'VG_100K_2'))
30 |
31 |     # remove zero-byte images in an ugly but efficient way :')
32 | #print('for f in $(ls -lh '+dir_img+' | grep " 0 " | cut -s -f14 --delimiter=" "); do rm '+dir_img+'/${f}; done;')
33 | os.system('for f in $(ls -lh '+dir_img+' | grep " 0 " | cut -s -f14 --delimiter=" "); do echo '+dir_img+'/${f}; done;')
34 | os.system('for f in $(ls -lh '+dir_img+' | grep " 0 " | cut -s -f14 --delimiter=" "); do rm '+dir_img+'/${f}; done;')
35 |
36 |
37 | class VisualGenome(AbstractVQADataset):
38 |
39 | def __init__(self, data_split, opt, dataset_img=None):
40 | super(VisualGenome, self).__init__(data_split, opt, dataset_img)
41 |
42 | def __getitem__(self, index):
43 | item_qa = self.dataset[index]
44 | item = {}
45 | if self.dataset_img is not None:
46 | item_img = self.dataset_img.get_by_name(item_qa['image_name'])
47 | item['visual'] = item_img['visual']
48 | # DEBUG
49 | #item['visual_debug'] = item_qa['image_name']
50 | item['question'] = torch.LongTensor(item_qa['question_wids'])
51 | # DEBUG
52 | #item['question_debug'] = item_qa['question']
53 | item['question_id'] = item_qa['question_id']
54 | item['answer'] = item_qa['answer_aid']
55 | # DEBUG
56 | #item['answer_debug'] = item_qa['answer']
57 | return item
58 |
59 | def _raw(self):
60 | raw(self.dir_raw)
61 |
62 | def _interim(self):
63 | vgenome_interim(self.opt)
64 |
65 | def _processed(self):
66 | vgenome_processed(self.opt)
67 |
68 | def __len__(self):
69 | return len(self.dataset)
70 |
71 |
72 | class VisualGenomeImages(AbstractImagesDataset):
73 |
74 | def __init__(self, data_split, opt, transform=None, loader=default_loader):
75 | super(VisualGenomeImages, self).__init__(data_split, opt, transform, loader)
76 | self.dir_img = os.path.join(self.dir_raw, 'images')
77 | self.dataset = ImagesFolder(self.dir_img, transform=self.transform, loader=self.loader)
78 | self.name_to_index = self._load_name_to_index()
79 |
80 | def _raw(self):
81 | raw(self.dir_raw)
82 |
83 | def _load_name_to_index(self):
84 | self.name_to_index = {name:index for index, name in enumerate(self.dataset.imgs)}
85 | return self.name_to_index
86 |
87 | def __getitem__(self, index):
88 | item = self.dataset[index]
89 | return item
90 |
91 | def __len__(self):
92 | return len(self.dataset)
93 |
94 |
95 | def factory(opt, vqa=False, transform=None):
96 |
97 | if vqa:
98 | dataset_img = factory(opt, vqa=False, transform=transform)
99 | return VisualGenome('train', opt, dataset_img)
100 |
101 |     if opt['mode'] == 'img':
102 |         if transform is None:
103 |             transform = default_transform(opt['size'])
104 |         return VisualGenomeImages('train', opt, transform)
105 |     elif opt['mode'] in ['noatt', 'att']:
106 | return FeaturesDataset('train', opt)
107 |
108 | else:
109 | raise ValueError
110 |
111 |
112 |
--------------------------------------------------------------------------------
/vqa/datasets/vgenome_interim.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 |
5 | # def get_image_path(subtype='train2014', image_id='1', format='%s/COCO_%s_%012d.jpg'):
6 | # return format%(subtype, subtype, image_id)
7 |
8 | def interim(questions_annotations):
9 | data = []
10 | for i in range(len(questions_annotations)):
11 | qa_img = questions_annotations[i]
12 | qa_img_id = qa_img['id']
13 | for j in range(len(qa_img['qas'])):
14 | qa = qa_img['qas'][j]
15 | row = {}
16 | row['question_id'] = qa['qa_id']
17 | row['image_id'] = qa_img_id
18 | row['image_name'] = str(qa_img_id) + '.jpg'
19 | row['question'] = qa['question']
20 | row['answer'] = qa['answer']
21 | data.append(row)
22 | return data
23 |
24 | def vgenome_interim(params):
25 | '''
26 |     Put the Visual Genome VQA data into a single json file in data/interim:
27 |     [{question_id, image_id, image_name, question, answer}, ...]
28 | '''
29 | path_qa = os.path.join(params['dir'], 'interim', 'questions_annotations.json')
30 | os.system('mkdir -p ' + os.path.join(params['dir'], 'interim'))
31 |
32 | print('Loading annotations and questions...')
33 | questions_annotations = json.load(open(os.path.join(params['dir'], 'raw', 'question_answers.json'), 'r'))
34 |
35 | data = interim(questions_annotations)
36 |     print('Number of questions: %d' % len(data))
37 | print('Write', path_qa)
38 | json.dump(data, open(path_qa, 'w'))
39 |
40 | if __name__ == "__main__":
41 |
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--dir_vg', default='data/visualgenome', type=str, help='Path to visual genome data directory')
44 | args = parser.parse_args()
45 |     params = {'dir': args.dir_vg}
46 | vgenome_interim(params)
47 |
--------------------------------------------------------------------------------
/vqa/datasets/vgenome_processed.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the interim json data file
3 | into processed pickle data files.
4 | Questions and answers: use nltk, mcb, or a naive split function to get tokens.
5 | """
6 | from random import shuffle, seed
7 | import sys
8 | import os.path
9 | import argparse
10 | import numpy as np
11 | import scipy.io
12 | import pdb
13 | import h5py
14 | from nltk.tokenize import word_tokenize
15 | import json
16 | import csv
17 | import re
18 | import math
19 | import pickle
20 |
21 | from .vqa_processed import get_top_answers, remove_examples, tokenize, tokenize_mcb, \
22 | preprocess_questions, remove_long_tail_train, \
23 | encode_question, encode_answer
24 |
25 | def preprocess_answers(examples, nlp='nltk'):
26 | print('Example of modified answers after preprocessing:')
27 | for i, ex in enumerate(examples):
28 | s = ex['answer']
29 | if nlp == 'nltk':
30 | ex['answer'] = " ".join(word_tokenize(str(s).lower()))
31 | elif nlp == 'mcb':
32 | ex['answer'] = " ".join(tokenize_mcb(s))
33 | else:
34 | ex['answer'] = " ".join(tokenize(s))
35 | if i < 10: print(s, 'became', "->"+ex['answer']+"<-")
36 | if i>0 and i % 1000 == 0:
37 | sys.stdout.write("processing %d/%d (%.2f%% done) \r" % (i, len(examples), i*100.0/len(examples)) )
38 | sys.stdout.flush()
39 | return examples
40 |
41 | def build_csv(path, examples, split='train', delimiter_col='~', delimiter_number='|'):
42 |     with open(path, 'w', newline='') as f:  # text mode for the csv writer (Python 3)
43 | writer = csv.writer(f, delimiter=delimiter_col)
44 | for ex in examples:
45 |             # build one csv row per example
46 | row = []
47 | row.append(ex['question_id'])
48 | row.append(ex['question'])
49 | row.append(delimiter_number.join(ex['question_words_UNK']))
50 | row.append(delimiter_number.join(ex['question_wids']))
51 |
52 | row.append(ex['image_id'])
53 |
54 | if split in ['train','val','trainval']:
55 | row.append(ex['answer_aid'])
56 | row.append(ex['answer'])
57 | writer.writerow(row)
58 |
59 | def vgenome_processed(params):
60 |
61 | #####################################################
62 | ## Read input files
63 | #####################################################
64 |
65 | path_train = os.path.join(params['dir'], 'interim', 'questions_annotations.json')
66 |
67 | # An example is a tuple (question, image, answer)
68 | # /!\ test and test-dev have no answer
69 | trainset = json.load(open(path_train, 'r'))
70 |
71 | #####################################################
72 | ## Preprocess examples (questions and answers)
73 | #####################################################
74 |
75 | trainset = preprocess_answers(trainset, params['nlp'])
76 |
77 | top_answers = get_top_answers(trainset, params['nans'])
78 | aid_to_ans = {i+1:w for i,w in enumerate(top_answers)}
79 | ans_to_aid = {w:i+1 for i,w in enumerate(top_answers)}
80 |
81 | # Remove examples if answer is not in top answers
82 | #trainset = remove_examples(trainset, ans_to_aid)
83 |
84 | # Add 'question_words' to the initial tuple
85 | trainset = preprocess_questions(trainset, params['nlp'])
86 |
87 | # Also process top_words which contains a UNK char
88 | trainset, top_words = remove_long_tail_train(trainset, params['minwcount'])
89 | wid_to_word = {i+1:w for i,w in enumerate(top_words)}
90 | word_to_wid = {w:i+1 for i,w in enumerate(top_words)}
91 |
92 | #examples_test = remove_long_tail_test(examples_test, word_to_wid)
93 |
94 | trainset = encode_question(trainset, word_to_wid, params['maxlength'], params['pad'])
95 |
96 | trainset = encode_answer(trainset, ans_to_aid)
97 |
98 | #####################################################
99 | ## Write output files
100 | #####################################################
101 |
102 | # Paths to output files
103 | # Ex: data/vqa/preprocess/nans,3000_maxlength,15_..._trainsplit,train_testsplit,val/id_to_word.json
104 | subdirname = 'nans,'+str(params['nans'])
105 | for param in ['maxlength', 'minwcount', 'nlp', 'pad', 'trainsplit']:
106 | subdirname += '_' + param + ',' + str(params[param])
107 | os.system('mkdir -p ' + os.path.join(params['dir'], 'processed', subdirname))
108 |
109 | path_wid_to_word = os.path.join(params['dir'], 'processed', subdirname, 'wid_to_word.pickle')
110 | path_word_to_wid = os.path.join(params['dir'], 'processed', subdirname, 'word_to_wid.pickle')
111 | path_aid_to_ans = os.path.join(params['dir'], 'processed', subdirname, 'aid_to_ans.pickle')
112 | path_ans_to_aid = os.path.join(params['dir'], 'processed', subdirname, 'ans_to_aid.pickle')
113 | #path_csv_train = os.path.join(params['dir'], 'processed', subdirname, 'train.csv')
114 | path_trainset = os.path.join(params['dir'], 'processed', subdirname, 'trainset.pickle')
115 |
116 | print('Write wid_to_word to', path_wid_to_word)
117 | with open(path_wid_to_word, 'wb') as handle:
118 | pickle.dump(wid_to_word, handle)
119 |
120 | print('Write word_to_wid to', path_word_to_wid)
121 | with open(path_word_to_wid, 'wb') as handle:
122 | pickle.dump(word_to_wid, handle)
123 |
124 | print('Write aid_to_ans to', path_aid_to_ans)
125 | with open(path_aid_to_ans, 'wb') as handle:
126 | pickle.dump(aid_to_ans, handle)
127 |
128 | print('Write ans_to_aid to', path_ans_to_aid)
129 | with open(path_ans_to_aid, 'wb') as handle:
130 | pickle.dump(ans_to_aid, handle)
131 |
132 | print('Write trainset to', path_trainset)
133 | with open(path_trainset, 'wb') as handle:
134 | pickle.dump(trainset, handle)
135 |
136 |
137 | if __name__ == "__main__":
138 | parser = argparse.ArgumentParser()
139 | parser.add_argument('--dir_vg',
140 |         dest='dir', default='data/visualgenome',
141 | type=str,
142 | help='Root directory containing raw, interim and processed directories'
143 | )
144 | parser.add_argument('--nans',
145 | default=10000,
146 | type=int,
147 | help='Number of top answers for the final classifications'
148 | )
149 | parser.add_argument('--maxlength',
150 | default=26,
151 | type=int,
152 | help='Max number of words in a caption. Captions longer get clipped'
153 | )
154 | parser.add_argument('--minwcount',
155 | default=0,
156 | type=int,
157 | help='Words that occur less than that are removed from vocab'
158 | )
159 | parser.add_argument('--nlp',
160 | default='mcb',
161 | type=str,
162 | help='Token method ; Options: nltk | mcb | naive'
163 | )
164 | parser.add_argument('--pad',
165 | default='left',
166 | type=str,
167 | help='Padding ; Options: right (finish by zeros) | left (begin by zeros)'
168 | )
169 | args = parser.parse_args()
170 |     params = vars(args)
171 |     params['trainsplit'] = 'train'  # key expected when building the processed subdir name
172 |     vgenome_processed(params)
--------------------------------------------------------------------------------
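A hedged sketch of calling the processing step above from Python; the keys mirror the argparse block plus the 'dir' and 'trainsplit' entries read inside vgenome_processed, and the values are illustrative.

# Sketch: process the Visual Genome interim file into pickles (illustrative values).
from vqa.datasets.vgenome_processed import vgenome_processed

params = {
    'dir': 'data/visualgenome',
    'nans': 10000, 'maxlength': 26, 'minwcount': 0,
    'nlp': 'mcb', 'pad': 'left', 'trainsplit': 'train',
}
vgenome_processed(params)
# -> writes wid_to_word / word_to_wid / aid_to_ans / ans_to_aid / trainset pickles
#    under data/visualgenome/processed/nans,10000_maxlength,26_.../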
/vqa/datasets/vqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import torch
4 | import torch.utils.data as data
5 | import copy
6 | import numpy as np
7 |
8 | from ..lib import utils
9 | from ..lib.dataloader import DataLoader
10 | from .utils import AbstractVQADataset
11 | from .vqa_interim import vqa_interim
12 | from .vqa2_interim import vqa_interim as vqa2_interim
13 | from .vqa_processed import vqa_processed
14 | from . import coco
15 | from . import vgenome
16 |
17 |
18 | class AbstractVQA(AbstractVQADataset):
19 |
20 | def __init__(self, data_split, opt, dataset_img=None):
21 | super(AbstractVQA, self).__init__(data_split, opt, dataset_img)
22 |
23 | if 'train' not in self.data_split: # means self.data_split is 'val' or 'test'
24 | self.opt['samplingans'] = False
25 | assert 'samplingans' in self.opt, \
26 | "opt['vqa'] does not have 'samplingans' "\
27 | "entry. Set it to True or False."
28 |
29 | if self.data_split == 'test':
30 | path_testdevset = os.path.join(self.subdir_processed, 'testdevset.pickle')
31 | with open(path_testdevset, 'rb') as handle:
32 | self.testdevset_vqa = pickle.load(handle)
33 | self.is_qid_testdev = {}
34 | for i in range(len(self.testdevset_vqa)):
35 | qid = self.testdevset_vqa[i]['question_id']
36 | self.is_qid_testdev[qid] = True
37 |
38 | def _raw(self):
39 | raise NotImplementedError
40 |
41 | def _interim(self):
42 | raise NotImplementedError
43 |
44 | def _processed(self):
45 | raise NotImplementedError
46 |
47 | def __getitem__(self, index):
48 | item = {}
49 | # TODO: better handle cascade of dict items
50 | item_vqa = self.dataset[index]
51 |
52 | # Process Visual (image or features)
53 | if self.dataset_img is not None:
54 | item_img = self.dataset_img.get_by_name(item_vqa['image_name'])
55 | item['visual'] = item_img['visual']
56 |
57 | # Process Question (word token)
58 | item['question_id'] = item_vqa['question_id']
59 | item['question'] = torch.LongTensor(item_vqa['question_wids'])
60 |
61 | if self.data_split == 'test':
62 | if item['question_id'] in self.is_qid_testdev:
63 | item['is_testdev'] = True
64 | else:
65 | item['is_testdev'] = False
66 | else:
67 | ## Process Answer if exists
68 | if self.opt['samplingans']:
69 | proba = item_vqa['answers_count']
70 | proba = proba / np.sum(proba)
71 | item['answer'] = int(np.random.choice(item_vqa['answers_aid'], p=proba))
72 | else:
73 | item['answer'] = item_vqa['answer_aid']
74 |
75 | return item
76 |
77 | def __len__(self):
78 | return len(self.dataset)
79 |
80 | def num_classes(self):
81 | return len(self.aid_to_ans)
82 |
83 | def vocab_words(self):
84 | return list(self.wid_to_word.values())
85 |
86 | def vocab_answers(self):
87 | return self.aid_to_ans
88 |
89 | def data_loader(self, batch_size=10, num_workers=4, shuffle=False):
90 | return DataLoader(self,
91 | batch_size=batch_size, shuffle=shuffle,
92 | num_workers=num_workers, pin_memory=True)
93 |
94 | def split_name(self, testdev=False):
95 | if testdev:
96 | return 'test-dev2015'
97 | if self.data_split in ['train', 'val']:
98 | return self.data_split+'2014'
99 | elif self.data_split == 'test':
100 | return self.data_split+'2015'
101 | elif self.data_split == 'testdev':
102 | return 'test-dev2015'
103 | else:
104 | assert False, 'Wrong data_split: {}'.format(self.data_split)
105 |
106 | def subdir_processed(self):
107 | subdir = 'nans,' + str(self.opt['nans']) \
108 | + '_maxlength,' + str(self.opt['maxlength']) \
109 | + '_minwcount,' + str(self.opt['minwcount']) \
110 | + '_nlp,' + self.opt['nlp'] \
111 | + '_pad,' + self.opt['pad'] \
112 | + '_trainsplit,' + self.opt['trainsplit']
113 | subdir = os.path.join(self.dir_processed, subdir)
114 | return subdir
115 |
116 |
117 | class VQA(AbstractVQA):
118 |
119 | def __init__(self, data_split, opt, dataset_img=None):
120 | super(VQA, self).__init__(data_split, opt, dataset_img)
121 |
122 | def _raw(self):
123 | dir_zip = os.path.join(self.dir_raw, 'zip')
124 | dir_ann = os.path.join(self.dir_raw, 'annotations')
125 | os.system('mkdir -p '+dir_zip)
126 | os.system('mkdir -p '+dir_ann)
127 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Train_mscoco.zip -P '+dir_zip)
128 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Val_mscoco.zip -P '+dir_zip)
129 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Test_mscoco.zip -P '+dir_zip)
130 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Annotations_Train_mscoco.zip -P '+dir_zip)
131 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Annotations_Val_mscoco.zip -P '+dir_zip)
132 | os.system('unzip '+os.path.join(dir_zip, 'Questions_Train_mscoco.zip')+' -d '+dir_ann)
133 | os.system('unzip '+os.path.join(dir_zip, 'Questions_Val_mscoco.zip')+' -d '+dir_ann)
134 | os.system('unzip '+os.path.join(dir_zip, 'Questions_Test_mscoco.zip')+' -d '+dir_ann)
135 | os.system('unzip '+os.path.join(dir_zip, 'Annotations_Train_mscoco.zip')+' -d '+dir_ann)
136 | os.system('unzip '+os.path.join(dir_zip, 'Annotations_Val_mscoco.zip')+' -d '+dir_ann)
137 |
138 | def _interim(self):
139 | vqa_interim(self.opt['dir'])
140 |
141 | def _processed(self):
142 | vqa_processed(self.opt)
143 |
144 |
145 | class VQA2(AbstractVQA):
146 |
147 | def __init__(self, data_split, opt, dataset_img=None):
148 | super(VQA2, self).__init__(data_split, opt, dataset_img)
149 |
150 | def _raw(self):
151 | dir_zip = os.path.join(self.dir_raw, 'zip')
152 | dir_ann = os.path.join(self.dir_raw, 'annotations')
153 | os.system('mkdir -p '+dir_zip)
154 | os.system('mkdir -p '+dir_ann)
155 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip -P '+dir_zip)
156 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip -P '+dir_zip)
157 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip -P '+dir_zip)
158 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P '+dir_zip)
159 | os.system('wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P '+dir_zip)
160 | os.system('unzip '+os.path.join(dir_zip, 'v2_Questions_Train_mscoco.zip')+' -d '+dir_ann)
161 | os.system('unzip '+os.path.join(dir_zip, 'v2_Questions_Val_mscoco.zip')+' -d '+dir_ann)
162 | os.system('unzip '+os.path.join(dir_zip, 'v2_Questions_Test_mscoco.zip')+' -d '+dir_ann)
163 | os.system('unzip '+os.path.join(dir_zip, 'v2_Annotations_Train_mscoco.zip')+' -d '+dir_ann)
164 | os.system('unzip '+os.path.join(dir_zip, 'v2_Annotations_Val_mscoco.zip')+' -d '+dir_ann)
165 | os.system('mv '+os.path.join(dir_ann, 'v2_mscoco_train2014_annotations.json')+' '
166 | +os.path.join(dir_ann, 'mscoco_train2014_annotations.json'))
167 | os.system('mv '+os.path.join(dir_ann, 'v2_mscoco_val2014_annotations.json')+' '
168 | +os.path.join(dir_ann, 'mscoco_val2014_annotations.json'))
169 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_train2014_questions.json')+' '
170 | +os.path.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json'))
171 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_val2014_questions.json')+' '
172 | +os.path.join(dir_ann, 'OpenEnded_mscoco_val2014_questions.json'))
173 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_test2015_questions.json')+' '
174 | +os.path.join(dir_ann, 'OpenEnded_mscoco_test2015_questions.json'))
175 | os.system('mv '+os.path.join(dir_ann, 'v2_OpenEnded_mscoco_test-dev2015_questions.json')+' '
176 | +os.path.join(dir_ann, 'OpenEnded_mscoco_test-dev2015_questions.json'))
177 |
178 | def _interim(self):
179 | vqa2_interim(self.opt['dir'])
180 |
181 | def _processed(self):
182 | vqa_processed(self.opt)
183 |
184 |
185 | class VQAVisualGenome(data.Dataset):
186 |
187 | def __init__(self, dataset_vqa, dataset_vgenome):
188 | self.dataset_vqa = dataset_vqa
189 | self.dataset_vgenome = dataset_vgenome
190 | self._filter_dataset_vgenome()
191 |
192 | def _filter_dataset_vgenome(self):
193 | print('-> Filtering dataset vgenome')
194 | data_vg = self.dataset_vgenome.dataset
195 | ans_to_aid = self.dataset_vqa.ans_to_aid
196 | word_to_wid = self.dataset_vqa.word_to_wid
197 | data_vg_new = []
198 | not_in = 0
199 | for i in range(len(data_vg)):
200 | if data_vg[i]['answer'] not in ans_to_aid:
201 | not_in += 1
202 | else:
203 | data_vg[i]['answer_aid'] = ans_to_aid[data_vg[i]['answer']]
204 | for j in range(data_vg[i]['seq_length']):
205 | word = data_vg[i]['question_words_UNK'][j]
206 | if word in word_to_wid:
207 | wid = word_to_wid[word]
208 | else:
209 | wid = word_to_wid['UNK']
210 | data_vg[i]['question_wids'][j] = wid
211 | data_vg_new.append(data_vg[i])
212 | print('-> {} / {} items removed'.format(not_in, len(data_vg)))
213 | self.dataset_vgenome.dataset = data_vg_new
214 | print('-> {} items left in visual genome'.format(len(self.dataset_vgenome)))
215 | print('-> {} items total in vqa+vg'.format(len(self)))
216 |
217 |
218 | def __getitem__(self, index):
219 | if index < len(self.dataset_vqa):
220 | item = self.dataset_vqa[index]
221 | #print('vqa')
222 | else:
223 | item = self.dataset_vgenome[index - len(self.dataset_vqa)]
224 | #print('vg')
225 | #import ipdb; ipdb.set_trace()
226 | return item
227 |
228 | def __len__(self):
229 | return len(self.dataset_vqa) + len(self.dataset_vgenome)
230 |
231 | def num_classes(self):
232 | return self.dataset_vqa.num_classes()
233 |
234 | def vocab_words(self):
235 | return self.dataset_vqa.vocab_words()
236 |
237 | def vocab_answers(self):
238 | return self.dataset_vqa.vocab_answers()
239 |
240 | def data_loader(self, batch_size=10, num_workers=4, shuffle=False):
241 | return DataLoader(self,
242 | batch_size=batch_size, shuffle=shuffle,
243 | num_workers=num_workers, pin_memory=True)
244 |
245 | def split_name(self, testdev=False):
246 | return self.dataset_vqa.split_name(testdev=testdev)
247 |
248 |
249 | def factory(data_split, opt, opt_coco=None, opt_vgenome=None):
250 | dataset_img = None
251 |
252 | if opt_coco is not None:
253 | dataset_img = coco.factory(data_split, opt_coco)
254 |
255 | if opt['dataset'] == 'VQA' and '2' not in opt['dir']: # sanity check
256 | dataset_vqa = VQA(data_split, opt, dataset_img)
257 | elif opt['dataset'] == 'VQA2' and '2' in opt['dir']: # sanity check
258 | dataset_vqa = VQA2(data_split, opt, dataset_img)
259 | else:
260 |         raise ValueError("opt['dataset']={} does not match opt['dir']={}".format(opt['dataset'], opt['dir']))
261 |
262 | if opt_vgenome is not None:
263 | dataset_vgenome = vgenome.factory(opt_vgenome, vqa=True)
264 | return VQAVisualGenome(dataset_vqa, dataset_vgenome)
265 | else:
266 | return dataset_vqa
267 |
268 |
--------------------------------------------------------------------------------
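A hedged sketch of how the factory above is typically used to build a train loader. Only option keys read by the code above are shown; the opt_coco keys are an assumption about the COCO features dataset (coco.py is not shown here), and all values are illustrative rather than the repo's defaults from options/.

# Hypothetical sketch (values are illustrative, not the repo's defaults).
from vqa.datasets import vqa as vqa_datasets

opt_vqa = {
    'dataset': 'VQA2', 'dir': 'data/vqa2', 'trainsplit': 'train',
    'nans': 2000, 'maxlength': 26, 'minwcount': 0,
    'nlp': 'mcb', 'pad': 'left', 'samplingans': True,
}
# assumed keys for the COCO features dataset (see vqa/datasets/coco.py)
opt_coco = {'dir': 'data/coco', 'mode': 'att', 'arch': 'fbresnet152', 'size': 448}

trainset = vqa_datasets.factory('train', opt_vqa, opt_coco=opt_coco)
train_loader = trainset.data_loader(batch_size=128, num_workers=4, shuffle=True)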
/vqa/datasets/vqa2_interim.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 | from collections import Counter
5 |
6 | def get_subtype(split='train'):
7 | if split in ['train', 'val']:
8 | return split + '2014'
9 | else:
10 | return 'test2015'
11 |
12 | def get_image_name_old(subtype='train2014', image_id='1', format='%s/COCO_%s_%012d.jpg'):
13 | return format%(subtype, subtype, image_id)
14 |
15 | def get_image_name(subtype='train2014', image_id='1', format='COCO_%s_%012d.jpg'):
16 | return format%(subtype, image_id)
17 |
18 | def interim(questions, split='train', annotations=[]):
19 | print('Interim', split)
20 | data = []
21 | for i in range(len(questions)):
22 | row = {}
23 | row['question_id'] = questions[i]['question_id']
24 | row['image_name'] = get_image_name(get_subtype(split), questions[i]['image_id'])
25 | row['question'] = questions[i]['question']
26 | #row['MC_answer'] = questions[i]['multiple_choices']
27 | if split in ['train', 'val', 'trainval']:
28 | row['answer'] = annotations[i]['multiple_choice_answer']
29 | answers = []
30 | for ans in annotations[i]['answers']:
31 | answers.append(ans['answer'])
32 | row['answers_occurence'] = Counter(answers).most_common()
33 | data.append(row)
34 | return data
35 |
36 | def vqa_interim(dir_vqa):
37 | '''
38 |     Put the VQA 2.0 data into single json files in data/interim
39 |     train, val, trainval : [[question_id, image_name, question, answer, answers_occurence] ... ]
40 |     test, test-dev : [[question_id, image_name, question] ... ]
41 | '''
42 |
43 | path_train_qa = os.path.join(dir_vqa, 'interim', 'train_questions_annotations.json')
44 | path_val_qa = os.path.join(dir_vqa, 'interim', 'val_questions_annotations.json')
45 | path_trainval_qa = os.path.join(dir_vqa, 'interim', 'trainval_questions_annotations.json')
46 | path_test_q = os.path.join(dir_vqa, 'interim', 'test_questions.json')
47 | path_testdev_q = os.path.join(dir_vqa, 'interim', 'testdev_questions.json')
48 |
49 | os.system('mkdir -p ' + os.path.join(dir_vqa, 'interim'))
50 |
51 | print('Loading annotations and questions...')
52 | annotations_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_train2014_annotations.json'), 'r'))
53 | annotations_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_val2014_annotations.json'), 'r'))
54 | questions_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_train2014_questions.json'), 'r'))
55 | questions_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_val2014_questions.json'), 'r'))
56 | questions_test = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_test2015_questions.json'), 'r'))
57 | questions_testdev = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'OpenEnded_mscoco_test-dev2015_questions.json'), 'r'))
58 |
59 | data_train = interim(questions_train['questions'], 'train', annotations_train['annotations'])
60 | print('Train size %d'%len(data_train))
61 | print('Write', path_train_qa)
62 | json.dump(data_train, open(path_train_qa, 'w'))
63 |
64 | data_val = interim(questions_val['questions'], 'val', annotations_val['annotations'])
65 | print('Val size %d'%len(data_val))
66 | print('Write', path_val_qa)
67 | json.dump(data_val, open(path_val_qa, 'w'))
68 |
69 | print('Concat. train and val')
70 | data_trainval = data_train + data_val
71 | print('Trainval size %d'%len(data_trainval))
72 | print('Write', path_trainval_qa)
73 | json.dump(data_trainval, open(path_trainval_qa, 'w'))
74 |
75 | data_testdev = interim(questions_testdev['questions'], 'testdev')
76 | print('Testdev size %d'%len(data_testdev))
77 | print('Write', path_testdev_q)
78 | json.dump(data_testdev, open(path_testdev_q, 'w'))
79 |
80 | data_test = interim(questions_test['questions'], 'test')
81 | print('Test size %d'%len(data_test))
82 | print('Write', path_test_q)
83 | json.dump(data_test, open(path_test_q, 'w'))
84 |
85 | if __name__ == "__main__":
86 |
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('--dir_vqa', default='data/vqa', type=str, help='Path to vqa data directory')
89 | args = parser.parse_args()
90 | vqa_interim(args.dir_vqa)
91 |
--------------------------------------------------------------------------------
/vqa/datasets/vqa_interim.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 | from collections import Counter
5 |
6 | def get_subtype(split='train'):
7 | if split in ['train', 'val']:
8 | return split + '2014'
9 | else:
10 | return 'test2015'
11 |
12 | def get_image_name_old(subtype='train2014', image_id='1', format='%s/COCO_%s_%012d.jpg'):
13 | return format%(subtype, subtype, image_id)
14 |
15 | def get_image_name(subtype='train2014', image_id='1', format='COCO_%s_%012d.jpg'):
16 | return format%(subtype, image_id)
17 |
18 | def interim(questions, split='train', annotations=[]):
19 | print('Interim', split)
20 | data = []
21 | for i in range(len(questions)):
22 | row = {}
23 | row['question_id'] = questions[i]['question_id']
24 | row['image_name'] = get_image_name(get_subtype(split), questions[i]['image_id'])
25 | row['question'] = questions[i]['question']
26 | row['MC_answer'] = questions[i]['multiple_choices']
27 | if split in ['train', 'val', 'trainval']:
28 | row['answer'] = annotations[i]['multiple_choice_answer']
29 | answers = []
30 | for ans in annotations[i]['answers']:
31 | answers.append(ans['answer'])
32 | row['answers_occurence'] = Counter(answers).most_common()
33 | data.append(row)
34 | return data
35 |
36 | def vqa_interim(dir_vqa):
37 | '''
38 |     Put the VQA 1.0 data into single json files in data/interim
39 |     train, val, trainval : [[question_id, image_name, question, MC_answer, answer, answers_occurence] ... ]
40 |     test, test-dev : [[question_id, image_name, question, MC_answer] ... ]
41 | '''
42 |
43 | path_train_qa = os.path.join(dir_vqa, 'interim', 'train_questions_annotations.json')
44 | path_val_qa = os.path.join(dir_vqa, 'interim', 'val_questions_annotations.json')
45 | path_trainval_qa = os.path.join(dir_vqa, 'interim', 'trainval_questions_annotations.json')
46 | path_test_q = os.path.join(dir_vqa, 'interim', 'test_questions.json')
47 | path_testdev_q = os.path.join(dir_vqa, 'interim', 'testdev_questions.json')
48 |
49 | os.system('mkdir -p ' + os.path.join(dir_vqa, 'interim'))
50 |
51 | print('Loading annotations and questions...')
52 | annotations_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_train2014_annotations.json'), 'r'))
53 | annotations_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'mscoco_val2014_annotations.json'), 'r'))
54 | questions_train = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_train2014_questions.json'), 'r'))
55 | questions_val = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_val2014_questions.json'), 'r'))
56 | questions_test = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_test2015_questions.json'), 'r'))
57 | questions_testdev = json.load(open(os.path.join(dir_vqa, 'raw', 'annotations', 'MultipleChoice_mscoco_test-dev2015_questions.json'), 'r'))
58 |
59 | data_train = interim(questions_train['questions'], 'train', annotations_train['annotations'])
60 | print('Train size %d'%len(data_train))
61 | print('Write', path_train_qa)
62 | json.dump(data_train, open(path_train_qa, 'w'))
63 |
64 | data_val = interim(questions_val['questions'], 'val', annotations_val['annotations'])
65 | print('Val size %d'%len(data_val))
66 | print('Write', path_val_qa)
67 | json.dump(data_val, open(path_val_qa, 'w'))
68 |
69 | print('Concat. train and val')
70 | data_trainval = data_train + data_val
71 | print('Trainval size %d'%len(data_trainval))
72 | print('Write', path_trainval_qa)
73 | json.dump(data_trainval, open(path_trainval_qa, 'w'))
74 |
75 | data_testdev = interim(questions_testdev['questions'], 'testdev')
76 | print('Testdev size %d'%len(data_testdev))
77 | print('Write', path_testdev_q)
78 | json.dump(data_testdev, open(path_testdev_q, 'w'))
79 |
80 | data_test = interim(questions_test['questions'], 'test')
81 | print('Test size %d'%len(data_test))
82 | print('Write', path_test_q)
83 | json.dump(data_test, open(path_test_q, 'w'))
84 |
85 | if __name__ == "__main__":
86 |
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('--dir_vqa', default='data/vqa', type=str, help='Path to vqa data directory')
89 | args = parser.parse_args()
90 | vqa_interim(args.dir_vqa)
91 |
--------------------------------------------------------------------------------
/vqa/datasets/vqa_processed.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the interim json data files (train/val/test)
3 | into processed pickle data files. Questions: use nltk, mcb, or a naive split function to get tokens.
4 | """
5 | from random import shuffle, seed
6 | import sys
7 | import os.path
8 | import argparse
9 | import numpy as np
10 | import scipy.io
11 | import pdb
12 | import json
13 | import csv
14 | import re
15 | import math
16 | import pickle
17 | #import pprint
18 |
19 | def get_top_answers(examples, nans=3000):
20 | counts = {}
21 | for ex in examples:
22 | ans = ex['answer']
23 | counts[ans] = counts.get(ans, 0) + 1
24 |
25 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True)
26 |     print('Top answers and their counts:')
27 | print('\n'.join(map(str,cw[:20])))
28 |
29 | vocab = []
30 | for i in range(nans):
31 | vocab.append(cw[i][1])
32 | return vocab[:nans]
33 |
34 | def remove_examples(examples, ans_to_aid):
35 | new_examples = []
36 | for i, ex in enumerate(examples):
37 | if ex['answer'] in ans_to_aid:
38 | new_examples.append(ex)
39 | print('Number of examples reduced from %d to %d '%(len(examples), len(new_examples)))
40 | return new_examples
41 |
42 | def tokenize(sentence):
43 | return [i for i in re.split(r"([-.\"',:? !\$#@~()*&\^%;\[\]/\\\+<>\n=])", sentence) if i!='' and i!=' ' and i!='\n'];
44 |
45 | def tokenize_mcb(s):
46 | t_str = s.lower()
47 | for i in [r'\?',r'\!',r'\'',r'\"',r'\$',r'\:',r'\@',r'\(',r'\)',r'\,',r'\.',r'\;']:
48 | t_str = re.sub( i, '', t_str)
49 | for i in [r'\-',r'\/']:
50 | t_str = re.sub( i, ' ', t_str)
51 | q_list = re.sub(r'\?','',t_str.lower()).split(' ')
52 | q_list = list(filter(lambda x: len(x) > 0, q_list))
53 | return q_list
54 |
55 | def preprocess_questions(examples, nlp='nltk'):
56 | if nlp == 'nltk':
57 | from nltk.tokenize import word_tokenize
58 | print('Example of generated tokens after preprocessing some questions:')
59 | for i, ex in enumerate(examples):
60 | s = ex['question']
61 | if nlp == 'nltk':
62 | ex['question_words'] = word_tokenize(str(s).lower())
63 | elif nlp == 'mcb':
64 | ex['question_words'] = tokenize_mcb(s)
65 | else:
66 | ex['question_words'] = tokenize(s)
67 | if i < 10:
68 | print(ex['question_words'])
69 | if i % 1000 == 0:
70 | sys.stdout.write("processing %d/%d (%.2f%% done) \r" % (i, len(examples), i*100.0/len(examples)) )
71 | sys.stdout.flush()
72 | return examples
73 |
74 | def remove_long_tail_train(examples, minwcount=0):
75 | # Replace words which are in the long tail (counted less than 'minwcount' times) by the UNK token.
76 | # Also create vocab, a list of the final words.
77 |
78 | # count up the number of words
79 | counts = {}
80 | for ex in examples:
81 | for w in ex['question_words']:
82 | counts[w] = counts.get(w, 0) + 1
83 | cw = sorted([(count,w) for w, count in counts.items()], reverse=True)
84 | print('Top words and their counts:')
85 | print('\n'.join(map(str,cw[:20])))
86 |
87 | total_words = sum(counts.values())
88 | print('Total words:', total_words)
89 | bad_words = [w for w,n in counts.items() if n <= minwcount]
90 | vocab = [w for w,n in counts.items() if n > minwcount]
91 | bad_count = sum(counts[w] for w in bad_words)
92 | print('Number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
93 | print('Number of words in vocab would be %d' % (len(vocab), ))
94 | print('Number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
95 |
96 | print('Insert the special UNK token')
97 | vocab.append('UNK')
98 | for ex in examples:
99 | words = ex['question_words']
100 | question = [w if counts.get(w,0) > minwcount else 'UNK' for w in words]
101 | ex['question_words_UNK'] = question
102 |
103 | return examples, vocab
104 |
105 | def remove_long_tail_test(examples, word_to_wid):
106 | for ex in examples:
107 | ex['question_words_UNK'] = [w if w in word_to_wid else 'UNK' for w in ex['question_words']]
108 | return examples
109 |
110 | def encode_question(examples, word_to_wid, maxlength=15, pad='left'):
111 | # Add to tuple question_wids and question_length
112 | for i, ex in enumerate(examples):
113 | ex['question_length'] = min(maxlength, len(ex['question_words_UNK'])) # record the length of this sequence
114 | ex['question_wids'] = [0]*maxlength
115 | for k, w in enumerate(ex['question_words_UNK']):
116 | if k < maxlength:
117 | if pad == 'right':
118 | ex['question_wids'][k] = word_to_wid[w]
119 |             else: # pad == 'left'
120 | new_k = k + maxlength - len(ex['question_words_UNK'])
121 | ex['question_wids'][new_k] = word_to_wid[w]
122 | ex['seq_length'] = len(ex['question_words_UNK'])
123 | return examples
124 |
125 | def encode_answer(examples, ans_to_aid):
126 |     print('Warning: answers not in the answer vocab get aid 1999')
127 |     for i, ex in enumerate(examples):
128 |         ex['answer_aid'] = ans_to_aid.get(ex['answer'], 1999) # fallback aid for answers not in vocab
129 | return examples
130 |
131 | def encode_answers_occurence(examples, ans_to_aid):
132 | for i, ex in enumerate(examples):
133 | answers = []
134 | answers_aid = []
135 | answers_count = []
136 | for ans in ex['answers_occurence']:
137 | aid = ans_to_aid.get(ans[0], -1) # -1 means answer not in vocab
138 | if aid != -1:
139 | answers.append(ans[0])
140 | answers_aid.append(aid)
141 | answers_count.append(ans[1])
142 | ex['answers'] = answers
143 | ex['answers_aid'] = answers_aid
144 | ex['answers_count'] = answers_count
145 | return examples
146 |
147 | def vqa_processed(params):
148 |
149 | #####################################################
150 | ## Read input files
151 | #####################################################
152 |
153 | path_train = os.path.join(params['dir'], 'interim', params['trainsplit']+'_questions_annotations.json')
154 | if params['trainsplit'] == 'train':
155 | path_val = os.path.join(params['dir'], 'interim', 'val_questions_annotations.json')
156 | path_test = os.path.join(params['dir'], 'interim', 'test_questions.json')
157 | path_testdev = os.path.join(params['dir'], 'interim', 'testdev_questions.json')
158 |
159 | # An example is a tuple (question, image, answer)
160 | # /!\ test and test-dev have no answer
161 | trainset = json.load(open(path_train, 'r'))
162 | if params['trainsplit'] == 'train':
163 | valset = json.load(open(path_val, 'r'))
164 | testset = json.load(open(path_test, 'r'))
165 | testdevset = json.load(open(path_testdev, 'r'))
166 |
167 | #####################################################
168 | ## Preprocess examples (questions and answers)
169 | #####################################################
170 |
171 | top_answers = get_top_answers(trainset, params['nans'])
172 | aid_to_ans = [a for i,a in enumerate(top_answers)]
173 | ans_to_aid = {a:i for i,a in enumerate(top_answers)}
174 | # Remove examples if answer is not in top answers
175 | trainset = remove_examples(trainset, ans_to_aid)
176 |
177 | # Add 'question_words' to the initial tuple
178 | trainset = preprocess_questions(trainset, params['nlp'])
179 | if params['trainsplit'] == 'train':
180 | valset = preprocess_questions(valset, params['nlp'])
181 | testset = preprocess_questions(testset, params['nlp'])
182 | testdevset = preprocess_questions(testdevset, params['nlp'])
183 |
184 | # Also process top_words which contains a UNK char
185 | trainset, top_words = remove_long_tail_train(trainset, params['minwcount'])
186 | wid_to_word = {i+1:w for i,w in enumerate(top_words)}
187 | word_to_wid = {w:i+1 for i,w in enumerate(top_words)}
188 |
189 | if params['trainsplit'] == 'train':
190 | valset = remove_long_tail_test(valset, word_to_wid)
191 | testset = remove_long_tail_test(testset, word_to_wid)
192 | testdevset = remove_long_tail_test(testdevset, word_to_wid)
193 |
194 | trainset = encode_question(trainset, word_to_wid, params['maxlength'], params['pad'])
195 | if params['trainsplit'] == 'train':
196 | valset = encode_question(valset, word_to_wid, params['maxlength'], params['pad'])
197 | testset = encode_question(testset, word_to_wid, params['maxlength'], params['pad'])
198 | testdevset = encode_question(testdevset, word_to_wid, params['maxlength'], params['pad'])
199 |
200 | trainset = encode_answer(trainset, ans_to_aid)
201 | trainset = encode_answers_occurence(trainset, ans_to_aid)
202 | if params['trainsplit'] == 'train':
203 | valset = encode_answer(valset, ans_to_aid)
204 | valset = encode_answers_occurence(valset, ans_to_aid)
205 |
206 | #####################################################
207 | ## Write output files
208 | #####################################################
209 |
210 | # Paths to output files
211 | # Ex: data/vqa/processed/nans,3000_maxlength,15_..._trainsplit,train_testsplit,val/id_to_word.json
212 | subdirname = 'nans,'+str(params['nans'])
213 | for param in ['maxlength', 'minwcount', 'nlp', 'pad', 'trainsplit']:
214 | subdirname += '_' + param + ',' + str(params[param])
215 | os.system('mkdir -p ' + os.path.join(params['dir'], 'processed', subdirname))
216 |
217 | path_wid_to_word = os.path.join(params['dir'], 'processed', subdirname, 'wid_to_word.pickle')
218 | path_word_to_wid = os.path.join(params['dir'], 'processed', subdirname, 'word_to_wid.pickle')
219 | path_aid_to_ans = os.path.join(params['dir'], 'processed', subdirname, 'aid_to_ans.pickle')
220 | path_ans_to_aid = os.path.join(params['dir'], 'processed', subdirname, 'ans_to_aid.pickle')
221 | if params['trainsplit'] == 'train':
222 | path_trainset = os.path.join(params['dir'], 'processed', subdirname, 'trainset.pickle')
223 | path_valset = os.path.join(params['dir'], 'processed', subdirname, 'valset.pickle')
224 | elif params['trainsplit'] == 'trainval':
225 | path_trainset = os.path.join(params['dir'], 'processed', subdirname, 'trainvalset.pickle')
226 | path_testset = os.path.join(params['dir'], 'processed', subdirname, 'testset.pickle')
227 | path_testdevset = os.path.join(params['dir'], 'processed', subdirname, 'testdevset.pickle')
228 |
229 | print('Write wid_to_word to', path_wid_to_word)
230 | with open(path_wid_to_word, 'wb') as handle:
231 | pickle.dump(wid_to_word, handle)
232 |
233 | print('Write word_to_wid to', path_word_to_wid)
234 | with open(path_word_to_wid, 'wb') as handle:
235 | pickle.dump(word_to_wid, handle)
236 |
237 | print('Write aid_to_ans to', path_aid_to_ans)
238 | with open(path_aid_to_ans, 'wb') as handle:
239 | pickle.dump(aid_to_ans, handle)
240 |
241 | print('Write ans_to_aid to', path_ans_to_aid)
242 | with open(path_ans_to_aid, 'wb') as handle:
243 | pickle.dump(ans_to_aid, handle)
244 |
245 | print('Write trainset to', path_trainset)
246 | with open(path_trainset, 'wb') as handle:
247 | pickle.dump(trainset, handle)
248 |
249 | if params['trainsplit'] == 'train':
250 | print('Write valset to', path_valset)
251 | with open(path_valset, 'wb') as handle:
252 | pickle.dump(valset, handle)
253 |
254 | print('Write testset to', path_testset)
255 | with open(path_testset, 'wb') as handle:
256 | pickle.dump(testset, handle)
257 |
258 | print('Write testdevset to', path_testdevset)
259 | with open(path_testdevset, 'wb') as handle:
260 | pickle.dump(testdevset, handle)
261 |
262 |
263 | if __name__ == "__main__":
264 | parser = argparse.ArgumentParser()
265 | parser.add_argument('--dirname',
266 | default='data/vqa',
267 | type=str,
268 | help='Root directory containing raw, interim and processed directories'
269 | )
270 | parser.add_argument('--trainsplit',
271 | default='train',
272 | type=str,
273 | help='Options: train | trainval'
274 | )
275 | parser.add_argument('--nans',
276 | default=2000,
277 | type=int,
278 | help='Number of top answers for the final classifications'
279 | )
280 | parser.add_argument('--maxlength',
281 | default=26,
282 | type=int,
283 | help='Max number of words in a caption. Captions longer get clipped'
284 | )
285 | parser.add_argument('--minwcount',
286 | default=0,
287 | type=int,
288 | help='Words that occur less than that are removed from vocab'
289 | )
290 | parser.add_argument('--nlp',
291 | default='mcb',
292 | type=str,
293 | help='Token method ; Options: nltk | mcb | naive'
294 | )
295 | parser.add_argument('--pad',
296 | default='left',
297 | type=str,
298 | help='Padding ; Options: right (finish by zeros) | left (begin by zeros)'
299 | )
300 | args = parser.parse_args()
301 | opt_vqa = vars(args)
302 | vqa_processed(opt_vqa)
--------------------------------------------------------------------------------
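A small worked example of the left/right padding performed by encode_question above, using a toy, hypothetical vocabulary.

# Toy illustration of encode_question's padding behaviour (hypothetical vocab).
from vqa.datasets.vqa_processed import encode_question

word_to_wid = {'is': 1, 'the': 2, 'cat': 3, 'black': 4, 'UNK': 5}
ex = {'question_words_UNK': ['is', 'the', 'cat', 'black']}

left = encode_question([dict(ex)], word_to_wid, maxlength=6, pad='left')
print(left[0]['question_wids'])   # [0, 0, 1, 2, 3, 4]  -> zeros prepended

right = encode_question([dict(ex)], word_to_wid, maxlength=6, pad='right')
print(right[0]['question_wids'])  # [1, 2, 3, 4, 0, 0]  -> zeros appended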
/vqa/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cadene/vqa.pytorch/69d7a43fb02c0332176915a4a23fc47cab08b1d2/vqa/lib/__init__.py
--------------------------------------------------------------------------------
/vqa/lib/criterions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def factory(opt, cuda=True):
5 | criterion = nn.CrossEntropyLoss()
6 | if cuda:
7 | criterion = criterion.cuda()
8 | return criterion
--------------------------------------------------------------------------------
/vqa/lib/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as multiprocessing
3 | from .sampler import SequentialSampler, RandomSampler
4 | import collections
5 | import math
6 | import sys
7 | import traceback
8 | import threading
9 | if sys.version_info[0] == 2:
10 | import Queue as queue
11 | else:
12 | import queue
13 |
14 |
15 | class ExceptionWrapper(object):
16 | "Wraps an exception plus traceback to communicate across threads"
17 |
18 | def __init__(self, exc_info):
19 | self.exc_type = exc_info[0]
20 | self.exc_msg = "".join(traceback.format_exception(*exc_info))
21 |
22 |
23 | def _worker_loop(dataset, index_queue, data_queue, collate_fn):
24 | torch.set_num_threads(1)
25 | while True:
26 | r = index_queue.get()
27 | if r is None:
28 | data_queue.put(None)
29 | break
30 | idx, batch_indices = r
31 | try:
32 | samples = collate_fn([dataset[i] for i in batch_indices])
33 | except Exception:
34 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
35 | else:
36 | data_queue.put((idx, samples))
37 |
38 |
39 | def _pin_memory_loop(in_queue, out_queue, done_event):
40 | while True:
41 | try:
42 | r = in_queue.get()
43 | except:
44 | if done_event.is_set():
45 | return
46 | raise
47 | if r is None:
48 | break
49 | if isinstance(r[1], ExceptionWrapper):
50 | out_queue.put(r)
51 | continue
52 | idx, batch = r
53 | try:
54 | batch = pin_memory_batch(batch)
55 | except Exception:
56 | out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
57 | else:
58 | out_queue.put((idx, batch))
59 |
60 |
61 | def default_collate(batch):
62 |     "Puts each data field into a tensor with outer dimension batch size"
63 |     string_classes = (str, bytes)
64 | if torch.is_tensor(batch[0]):
65 | return torch.stack(batch, 0)
66 | elif type(batch[0]).__module__ == 'numpy' and type(batch[0]).__name__ == 'ndarray':
67 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
68 | elif isinstance(batch[0], int):
69 | return torch.LongTensor(batch)
70 | elif isinstance(batch[0], float):
71 | return torch.DoubleTensor(batch)
72 | elif isinstance(batch[0], string_classes):
73 | return batch
74 | elif isinstance(batch[0], dict):
75 | # ~ added by Cadene ~
76 | # if each batch element is a dict with same keys,
77 | # then it should be a dict of collated elements
78 | keys = batch[0].keys()
79 | new_dict = {}
80 | for key in keys:
81 | new_dict[key] = []
82 | for sample in batch:
83 | for key in keys:
84 | new_dict[key].append(sample[key])
85 | return {key:default_collate(samples) for key, samples in new_dict.items()}
86 | elif isinstance(batch[0], collections.Iterable):
87 | # if each batch element is not a tensor, then it should be a tuple
88 | # of tensors; in that case we collate each element in the tuple
89 | transposed = zip(*batch)
90 | return [default_collate(samples) for samples in transposed]
91 | raise TypeError(("batch must contain tensors, numbers, or lists; found {}"
92 | .format(type(batch[0]))))
93 |
94 |
95 | def pin_memory_batch(batch):
96 | if torch.is_tensor(batch):
97 | return batch.pin_memory()
98 | elif isinstance(batch, dict):
99 | # ~ added by Cadene ~
100 | return {key:pin_memory_batch(sample) for key,sample in batch.items()}
101 | elif isinstance(batch[0], str):
102 | # ~ added by Cadene ~
103 | return batch
104 | elif isinstance(batch, collections.Iterable):
105 | return [pin_memory_batch(sample) for sample in batch]
106 | else:
107 | return batch
108 |
109 |
110 | class DataLoaderIter(object):
111 | "Iterates once over the DataLoader's dataset, as specified by the sampler"
112 |
113 | def __init__(self, loader):
114 | self.dataset = loader.dataset
115 | self.batch_size = loader.batch_size
116 | self.collate_fn = loader.collate_fn
117 | self.sampler = loader.sampler
118 | self.num_workers = loader.num_workers
119 | self.pin_memory = loader.pin_memory
120 | self.done_event = threading.Event()
121 |
122 | self.samples_remaining = len(self.sampler)
123 | self.sample_iter = iter(self.sampler)
124 |
125 | if self.num_workers > 0:
126 | self.index_queue = multiprocessing.SimpleQueue()
127 | self.data_queue = multiprocessing.SimpleQueue()
128 | self.batches_outstanding = 0
129 | self.shutdown = False
130 | self.send_idx = 0
131 | self.rcvd_idx = 0
132 | self.reorder_dict = {}
133 |
134 | self.workers = [
135 | multiprocessing.Process(
136 | target=_worker_loop,
137 | args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
138 | for _ in range(self.num_workers)]
139 |
140 | for w in self.workers:
141 | w.daemon = True # ensure that the worker exits on process exit
142 | w.start()
143 |
144 | if self.pin_memory:
145 | in_data = self.data_queue
146 | self.data_queue = queue.Queue()
147 | self.pin_thread = threading.Thread(
148 | target=_pin_memory_loop,
149 | args=(in_data, self.data_queue, self.done_event))
150 | self.pin_thread.daemon = True
151 | self.pin_thread.start()
152 |
153 | # prime the prefetch loop
154 | for _ in range(2 * self.num_workers):
155 | self._put_indices()
156 |
157 | def __len__(self):
158 | return int(math.ceil(len(self.sampler) / float(self.batch_size)))
159 |
160 | def __next__(self):
161 | if self.num_workers == 0:
162 | # same-process loading
163 | if self.samples_remaining == 0:
164 | raise StopIteration
165 | indices = self._next_indices()
166 | batch = self.collate_fn([self.dataset[i] for i in indices])
167 | if self.pin_memory:
168 | batch = pin_memory_batch(batch)
169 | return batch
170 |
171 | # check if the next sample has already been generated
172 | if self.rcvd_idx in self.reorder_dict:
173 | batch = self.reorder_dict.pop(self.rcvd_idx)
174 | return self._process_next_batch(batch)
175 |
176 | if self.batches_outstanding == 0:
177 | self._shutdown_workers()
178 | raise StopIteration
179 |
180 | while True:
181 | assert (not self.shutdown and self.batches_outstanding > 0)
182 | idx, batch = self.data_queue.get()
183 | self.batches_outstanding -= 1
184 | if idx != self.rcvd_idx:
185 | # store out-of-order samples
186 | self.reorder_dict[idx] = batch
187 | continue
188 | return self._process_next_batch(batch)
189 |
190 | next = __next__ # Python 2 compatibility
191 |
192 | def __iter__(self):
193 | return self
194 |
195 | def _next_indices(self):
196 | batch_size = min(self.samples_remaining, self.batch_size)
197 | batch = [next(self.sample_iter) for _ in range(batch_size)]
198 | self.samples_remaining -= len(batch)
199 | return batch
200 |
201 | def _put_indices(self):
202 | assert self.batches_outstanding < 2 * self.num_workers
203 | if self.samples_remaining > 0:
204 | self.index_queue.put((self.send_idx, self._next_indices()))
205 | self.batches_outstanding += 1
206 | self.send_idx += 1
207 |
208 | def _process_next_batch(self, batch):
209 | self.rcvd_idx += 1
210 | self._put_indices()
211 | if isinstance(batch, ExceptionWrapper):
212 | raise batch.exc_type(batch.exc_msg)
213 | return batch
214 |
215 | def __getstate__(self):
216 | # TODO: add limited pickling support for sharing an iterator
217 | # across multiple threads for HOGWILD.
218 | # Probably the best way to do this is by moving the sample pushing
219 | # to a separate thread and then just sharing the data queue
220 | # but signalling the end is tricky without a non-blocking API
221 | raise NotImplementedError("DataLoaderIterator cannot be pickled")
222 |
223 | def _shutdown_workers(self):
224 | if not self.shutdown:
225 | self.shutdown = True
226 | self.done_event.set()
227 | for _ in self.workers:
228 | self.index_queue.put(None)
229 |
230 | def __del__(self):
231 | if self.num_workers > 0:
232 | self._shutdown_workers()
233 |
234 |
235 | class DataLoader(object):
236 | """
237 | Data loader. Combines a dataset and a sampler, and provides
238 | single- or multi-process iterators over the dataset.
239 |
240 | Arguments:
241 | dataset (Dataset): dataset from which to load the data.
242 | batch_size (int, optional): how many samples per batch to load
243 | (default: 1).
244 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
245 | at every epoch (default: False).
246 | sampler (Sampler, optional): defines the strategy to draw samples from
247 | the dataset. If specified, the ``shuffle`` argument is ignored.
248 | num_workers (int, optional): how many subprocesses to use for data
249 | loading. 0 means that the data will be loaded in the main process
250 | (default: 0)
251 | collate_fn (callable, optional)
252 | pin_memory (bool, optional)
253 | """
254 |
255 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
256 | num_workers=0, collate_fn=default_collate, pin_memory=False):
257 | self.dataset = dataset
258 | self.batch_size = batch_size
259 | self.num_workers = num_workers
260 | self.collate_fn = collate_fn
261 | self.pin_memory = pin_memory
262 |
263 | if sampler is not None:
264 | self.sampler = sampler
265 | elif shuffle:
266 | self.sampler = RandomSampler(dataset)
267 | elif not shuffle:
268 | self.sampler = SequentialSampler(dataset)
269 |
270 | def __iter__(self):
271 | return DataLoaderIter(self)
272 |
273 | def __len__(self):
274 | return int(math.ceil(len(self.sampler) / float(self.batch_size)))
275 |
--------------------------------------------------------------------------------
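A small sketch of the dict handling added to default_collate above: a batch of sample dicts is collated into a single dict of batched tensors. The keys and shapes here are illustrative, not tied to a specific dataset.

# Toy illustration of the dict branch in default_collate above.
import torch
from vqa.lib.dataloader import default_collate

batch = [
    {'visual': torch.randn(2048), 'question': torch.LongTensor([1, 2, 3]), 'answer': 7},
    {'visual': torch.randn(2048), 'question': torch.LongTensor([4, 5, 6]), 'answer': 2},
]
collated = default_collate(batch)
print(collated['visual'].size())    # torch.Size([2, 2048])
print(collated['question'].size())  # torch.Size([2, 3])
print(collated['answer'])           # LongTensor with the two answer ids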
/vqa/lib/engine.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch.autograd import Variable
4 | import vqa.lib.utils as utils
5 |
6 | def train(loader, model, criterion, optimizer, logger, epoch, print_freq=10):
7 | # switch to train mode
8 | model.train()
9 | meters = logger.reset_meters('train')
10 |
11 | end = time.time()
12 | for i, sample in enumerate(loader):
13 | batch_size = sample['visual'].size(0)
14 |
15 | # measure data loading time
16 | meters['data_time'].update(time.time() - end, n=batch_size)
17 |
18 | input_visual = Variable(sample['visual'])
19 | input_question = Variable(sample['question'])
20 | target_answer = Variable(sample['answer'].cuda(async=True))
21 |
22 | # compute output
23 | output = model(input_visual, input_question)
24 | torch.cuda.synchronize()
25 | loss = criterion(output, target_answer)
26 | meters['loss'].update(loss.data[0], n=batch_size)
27 |
28 | # measure accuracy
29 | acc1, acc5 = utils.accuracy(output.data, target_answer.data, topk=(1, 5))
30 | meters['acc1'].update(acc1[0], n=batch_size)
31 | meters['acc5'].update(acc5[0], n=batch_size)
32 |
33 | # compute gradient and do SGD step
34 | optimizer.zero_grad()
35 | loss.backward()
36 | torch.cuda.synchronize()
37 | optimizer.step()
38 | torch.cuda.synchronize()
39 |
40 | # measure elapsed time
41 | meters['batch_time'].update(time.time() - end, n=batch_size)
42 | end = time.time()
43 |
44 | if i % print_freq == 0:
45 | print('Epoch: [{0}][{1}/{2}]\t'
46 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
47 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
48 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
49 | 'Acc@1 {acc1.val:.3f} ({acc1.avg:.3f})\t'
50 | 'Acc@5 {acc5.val:.3f} ({acc5.avg:.3f})'.format(
51 | epoch, i, len(loader),
52 | batch_time=meters['batch_time'], data_time=meters['data_time'],
53 | loss=meters['loss'], acc1=meters['acc1'], acc5=meters['acc5']))
54 |
55 | logger.log_meters('train', n=epoch)
56 |
57 | # def adjust_learning_rate(optimizer, epoch):
58 | # """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
59 | # lr = args.lr * (0.1 ** (epoch // 30))
60 | # for param_group in optimizer.param_groups:
61 | # param_group['lr'] = lr
62 |
63 |
64 | def validate(loader, model, criterion, logger, epoch=0, print_freq=10):
65 | results = []
66 |
67 | # switch to evaluate mode
68 | model.eval()
69 | meters = logger.reset_meters('val')
70 |
71 | end = time.time()
72 | for i, sample in enumerate(loader):
73 | batch_size = sample['visual'].size(0)
74 | input_visual = Variable(sample['visual'].cuda(async=True), volatile=True)
75 | input_question = Variable(sample['question'].cuda(async=True), volatile=True)
76 | target_answer = Variable(sample['answer'].cuda(async=True), volatile=True)
77 |
78 | # compute output
79 | output = model(input_visual, input_question)
80 | loss = criterion(output, target_answer)
81 | meters['loss'].update(loss.data[0], n=batch_size)
82 |
83 | # measure accuracy and record loss
84 | acc1, acc5 = utils.accuracy(output.data, target_answer.data, topk=(1, 5))
85 | meters['acc1'].update(acc1[0], n=batch_size)
86 | meters['acc5'].update(acc5[0], n=batch_size)
87 |
88 | # compute predictions for OpenEnded accuracy
89 | _, pred = output.data.cpu().max(1)
90 | pred.squeeze_()
91 | for j in range(batch_size):
92 | results.append({'question_id': sample['question_id'][j],
93 | 'answer': loader.dataset.aid_to_ans[pred[j]]})
94 |
95 | # measure elapsed time
96 | meters['batch_time'].update(time.time() - end, n=batch_size)
97 | end = time.time()
98 |
99 | if i % print_freq == 0:
100 | print('Val: [{0}/{1}]\t'
101 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
102 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
103 | 'Acc@1 {acc1.val:.3f} ({acc1.avg:.3f})\t'
104 | 'Acc@5 {acc5.val:.3f} ({acc5.avg:.3f})'.format(
105 | i, len(loader), batch_time=meters['batch_time'],
106 | data_time=meters['data_time'], loss=meters['loss'],
107 | acc1=meters['acc1'], acc5=meters['acc5']))
108 |
109 |     print(' * Acc@1 {acc1.avg:.3f} Acc@5 {acc5.avg:.3f}'
110 |           .format(acc1=meters['acc1'], acc5=meters['acc5']))
111 |
112 | logger.log_meters('val', n=epoch)
113 | return meters['acc1'].avg, results
114 |
115 |
116 | def test(loader, model, logger, epoch=0, print_freq=10):
117 | results = []
118 | testdev_results = []
119 |
120 | model.eval()
121 | meters = logger.reset_meters('test')
122 |
123 | end = time.time()
124 | for i, sample in enumerate(loader):
125 | batch_size = sample['visual'].size(0)
126 | input_visual = Variable(sample['visual'].cuda(async=True), volatile=True)
127 | input_question = Variable(sample['question'].cuda(async=True), volatile=True)
128 |
129 | # compute output
130 | output = model(input_visual, input_question)
131 |
132 | # compute predictions for OpenEnded accuracy
133 | _, pred = output.data.cpu().max(1)
134 | pred.squeeze_()
135 | for j in range(batch_size):
136 | item = {'question_id': sample['question_id'][j],
137 | 'answer': loader.dataset.aid_to_ans[pred[j]]}
138 | results.append(item)
139 | if sample['is_testdev'][j]:
140 | testdev_results.append(item)
141 |
142 | # measure elapsed time
143 | meters['batch_time'].update(time.time() - end, n=batch_size)
144 | end = time.time()
145 |
146 | if i % print_freq == 0:
147 | print('Test: [{0}/{1}]\t'
148 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
149 | i, len(loader), batch_time=meters['batch_time']))
150 |
151 | logger.log_meters('test', n=epoch)
152 | return results, testdev_results
153 |
--------------------------------------------------------------------------------
/vqa/lib/logger.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import time
3 | import json
4 | import numpy as np
5 | import os
6 | from collections import defaultdict
7 |
8 | class Experiment(object):
9 |
10 | def __init__(self, name, options=dict()):
11 | """ Create an experiment
12 | """
13 | super(Experiment, self).__init__()
14 |
15 | self.name = name
16 | self.options = options
17 | self.date_and_time = time.strftime('%d-%m-%Y--%H-%M-%S')
18 |
19 | self.info = defaultdict(dict)
20 | self.logged = defaultdict(dict)
21 | self.meters = defaultdict(dict)
22 |
23 | def add_meters(self, tag, meters_dict):
24 | assert tag not in (self.meters.keys())
25 | for name, meter in meters_dict.items():
26 | self.add_meter(tag, name, meter)
27 |
28 | def add_meter(self, tag, name, meter):
29 | assert name not in list(self.meters[tag].keys()), \
30 | "meter with tag {} and name {} already exists".format(tag, name)
31 | self.meters[tag][name] = meter
32 |
33 | def update_options(self, options_dict):
34 | self.options.update(options_dict)
35 |
36 | def log_meter(self, tag, name, n=1):
37 | meter = self.get_meter(tag, name)
38 | if name not in self.logged[tag]:
39 | self.logged[tag][name] = {}
40 | self.logged[tag][name][n] = meter.value()
41 |
42 | def log_meters(self, tag, n=1):
43 | for name, meter in self.get_meters(tag).items():
44 | self.log_meter(tag, name, n=n)
45 |
46 | def reset_meters(self, tag):
47 | meters = self.get_meters(tag)
48 | for name, meter in meters.items():
49 | meter.reset()
50 | return meters
51 |
52 | def get_meters(self, tag):
53 | assert tag in list(self.meters.keys())
54 | return self.meters[tag]
55 |
56 | def get_meter(self, tag, name):
57 | assert tag in list(self.meters.keys())
58 | assert name in list(self.meters[tag].keys())
59 | return self.meters[tag][name]
60 |
61 | def to_json(self, filename):
62 | os.system('mkdir -p ' + os.path.dirname(filename))
63 | var_dict = copy.copy(vars(self))
64 | var_dict.pop('meters')
65 | for key in ('viz', 'viz_dict'):
66 | if key in list(var_dict.keys()):
67 | var_dict.pop(key)
68 | with open(filename, 'w') as f:
69 | json.dump(var_dict, f)
70 |
71 | def from_json(filename):
72 | with open(filename, 'r') as f:
73 | var_dict = json.load(f)
74 | xp = Experiment('')
75 | xp.date_and_time = var_dict['date_and_time']
76 | xp.logged = var_dict['logged']
77 | # TODO: Remove
78 | if 'info' in var_dict:
79 | xp.info = var_dict['info']
80 | xp.options = var_dict['options']
81 | xp.name = var_dict['name']
82 | return xp
83 |
84 |
85 | class AvgMeter(object):
86 | """Computes and stores the average and current value"""
87 | def __init__(self):
88 | self.reset()
89 |
90 | def reset(self):
91 | self.val = 0
92 | self.avg = 0
93 | self.sum = 0
94 | self.count = 0
95 |
96 | def update(self, val, n=1):
97 | self.val = val
98 | self.sum += val * n
99 | self.count += n
100 | self.avg = self.sum / self.count
101 |
102 | def value(self):
103 | return self.avg
104 |
105 |
106 | class SumMeter(object):
107 | """Computes and stores the sum and current value"""
108 | def __init__(self):
109 | self.reset()
110 |
111 | def reset(self):
112 | self.val = 0
113 | self.sum = 0
114 | self.count = 0
115 |
116 | def update(self, val, n=1):
117 | self.val = val
118 | self.sum += val * n
119 | self.count += n
120 |
121 | def value(self):
122 | return self.sum
123 |
124 |
125 | class ValueMeter(object):
126 |     """Stores the current value"""
127 | def __init__(self):
128 | self.reset()
129 |
130 | def reset(self):
131 | self.val = 0
132 |
133 | def update(self, val):
134 | self.val = val
135 |
136 | def value(self):
137 | return self.val
--------------------------------------------------------------------------------
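A minimal usage sketch of the Experiment/meter API defined above; the experiment name, tag, meter names and values are illustrative, since in the repo train.py wires these up from the options yaml.

from vqa.lib.logger import Experiment, AvgMeter, ValueMeter

xp = Experiment('mutan_noatt_example', options={'optim': {'lr': 1e-4}})
xp.add_meters('train', {'loss': AvgMeter(), 'acc1': AvgMeter(), 'lr': ValueMeter()})

for epoch in range(1, 3):
    xp.reset_meters('train')
    for batch in range(10):
        xp.get_meter('train', 'loss').update(0.5 / (batch + 1), n=128)
        xp.get_meter('train', 'acc1').update(60.0, n=128)
    xp.get_meter('train', 'lr').update(1e-4)
    xp.log_meters('train', n=epoch)   # snapshot the per-epoch averaged values

xp.to_json('logs/example/logger.json')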
/vqa/lib/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |     Every Sampler subclass has to provide an __iter__ method that iterates
7 |     over the indices of dataset elements, and a __len__ method that returns
8 |     the length of the returned iterator.
9 | """
10 |
11 | def __init__(self, data_source):
12 | pass
13 |
14 | def __iter__(self):
15 | raise NotImplementedError
16 |
17 | def __len__(self):
18 | raise NotImplementedError
19 |
20 |
21 | class SequentialSampler(Sampler):
22 | """Samples elements sequentially, always in the same order.
23 | Arguments:
24 | data_source (Dataset): dataset to sample from
25 | """
26 |
27 | def __init__(self, data_source):
28 | self.num_samples = len(data_source)
29 |
30 | def __iter__(self):
31 | return iter(range(self.num_samples))
32 |
33 | def __len__(self):
34 | return self.num_samples
35 |
36 |
37 | class RandomSampler(Sampler):
38 | """Samples elements randomly, without replacement.
39 | Arguments:
40 | data_source (Dataset): dataset to sample from
41 | """
42 |
43 | def __init__(self, data_source):
44 | self.num_samples = len(data_source)
45 |
46 | def __iter__(self):
47 | return iter(torch.randperm(self.num_samples).long())
48 |
49 | def __len__(self):
50 | return self.num_samples
--------------------------------------------------------------------------------
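A short sketch of how these samplers are consumed; the manual batching below is only for illustration, since in the repo the sampler is handed to vqa/lib/dataloader.py.

from vqa.lib.sampler import RandomSampler, SequentialSampler

data = list(range(10))                  # stands in for a Dataset
sampler = RandomSampler(data)           # or SequentialSampler(data)

batch_size = 4
indices = [int(i) for i in sampler]     # a random permutation of 0..9
batches = [indices[i:i + batch_size]
           for i in range(0, len(indices), batch_size)]
print(batches)                          # e.g. [[7, 2, 9, 0], [4, 1, 3, 8], [6, 5]]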
/vqa/lib/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import torch
4 | import numpy as np
5 |
6 | def update_values(dict_from, dict_to):
7 | for key, value in dict_from.items():
8 | if isinstance(value, dict):
9 | update_values(dict_from[key], dict_to[key])
10 | elif value is not None:
11 | dict_to[key] = dict_from[key]
12 | return dict_to
13 |
14 | def merge_dict(a, b):
15 |     if isinstance(a, dict) and isinstance(b, dict):
16 |         d = dict(a)
17 |         d.update({k: merge_dict(a.get(k, None), b[k]) for k in b})
18 |         return d
19 |     if isinstance(a, list) and isinstance(b, list):
20 |         return b
21 |     return a if b is None else b
22 |
23 | def accuracy(output, target, topk=(1,)):
24 | """Computes the precision@k for the specified values of k"""
25 | maxk = max(topk)
26 | batch_size = target.size(0)
27 |
28 | _, pred = output.topk(maxk, 1, True, True)
29 | pred = pred.t()
30 | if target.dim() == 2: # multians option
31 | _, target = torch.max(target, 1)
32 | correct = pred.eq(target.view(1, -1).expand_as(pred))
33 |
34 | res = []
35 | for k in topk:
36 | correct_k = correct[:k].view(-1).float().sum(0)
37 | res.append(correct_k.mul_(100.0 / batch_size))
38 | return res
39 |
40 | def params_count(model):
41 | count = 0
42 | for p in model.parameters():
43 | c = 1
44 | for i in range(p.dim()):
45 | c *= p.size(i)
46 | count += c
47 | return count
48 |
49 | def str2bool(v):
50 | if v is None:
51 | return v
52 | elif type(v) == bool:
53 | return v
54 | elif type(v) == str:
55 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
56 | return True
57 | if v.lower() in ('no', 'false', 'f', 'n', '0'):
58 | return False
59 | raise argparse.ArgumentTypeError('Boolean value expected.')
60 |
61 | def create_n_hot(idxs, N):
62 | out = np.zeros(N)
63 | for i in idxs:
64 | out[i] += 1
65 | return torch.Tensor(out/out.sum())
66 |
67 |
--------------------------------------------------------------------------------
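A small sketch exercising accuracy() and params_count() on dummy tensors; the shapes and the answer-vocabulary size are illustrative.

import torch
import torch.nn as nn
from vqa.lib.utils import accuracy, params_count

output = torch.randn(32, 2000)              # scores over 2000 candidate answers
target = torch.randint(0, 2000, (32,))      # ground-truth answer ids
top1, top5 = accuracy(output, target, topk=(1, 5))   # precision@1 / @5, in percent

model = nn.Linear(2048, 2000)
print(params_count(model))                  # 2048 * 2000 weights + 2000 biases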
/vqa/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .noatt import MLBNoAtt, MutanNoAtt
2 | from .att import MLBAtt, MutanAtt
3 | from .utils import factory
4 | from .utils import model_names
--------------------------------------------------------------------------------
/vqa/models/att.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import copy
6 |
7 | from vqa.lib import utils
8 | from vqa.models import seq2vec
9 | from vqa.models import fusion
10 |
11 | class AbstractAtt(nn.Module):
12 |
13 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]):
14 | super(AbstractAtt, self).__init__()
15 | self.opt = opt
16 | self.vocab_words = vocab_words
17 | self.vocab_answers = vocab_answers
18 | self.num_classes = len(self.vocab_answers)
19 | # Modules
20 | self.seq2vec = seq2vec.factory(self.vocab_words, self.opt['seq2vec'])
21 | # Modules for attention
22 | self.conv_v_att = nn.Conv2d(self.opt['dim_v'],
23 | self.opt['attention']['dim_v'], 1, 1)
24 | self.linear_q_att = nn.Linear(self.opt['dim_q'],
25 | self.opt['attention']['dim_q'])
26 | self.conv_att = nn.Conv2d(self.opt['attention']['dim_mm'],
27 | self.opt['attention']['nb_glimpses'], 1, 1)
28 | # Modules for classification
29 | self.list_linear_v_fusion = None
30 | self.linear_q_fusion = None
31 | self.linear_classif = None
32 |
33 | def _fusion_att(self, x_v, x_q):
34 | raise NotImplementedError
35 |
36 | def _fusion_classif(self, x_v, x_q):
37 | raise NotImplementedError
38 |
39 | def _attention(self, input_v, x_q_vec):
40 | batch_size = input_v.size(0)
41 | width = input_v.size(2)
42 | height = input_v.size(3)
43 |
44 | # Process visual before fusion
45 | #x_v = input_v.view(batch_size*width*height, dim_features)
46 | x_v = input_v
47 | x_v = F.dropout(x_v,
48 | p=self.opt['attention']['dropout_v'],
49 | training=self.training)
50 | x_v = self.conv_v_att(x_v)
51 | if 'activation_v' in self.opt['attention']:
52 | x_v = getattr(F, self.opt['attention']['activation_v'])(x_v)
53 | x_v = x_v.view(batch_size,
54 | self.opt['attention']['dim_v'],
55 | width * height)
56 | x_v = x_v.transpose(1,2)
57 |
58 | # Process question before fusion
59 | x_q = F.dropout(x_q_vec, p=self.opt['attention']['dropout_q'],
60 | training=self.training)
61 | x_q = self.linear_q_att(x_q)
62 | if 'activation_q' in self.opt['attention']:
63 | x_q = getattr(F, self.opt['attention']['activation_q'])(x_q)
64 | x_q = x_q.view(batch_size,
65 | 1,
66 | self.opt['attention']['dim_q'])
67 | x_q = x_q.expand(batch_size,
68 | width * height,
69 | self.opt['attention']['dim_q'])
70 |
71 | # First multimodal fusion
72 | x_att = self._fusion_att(x_v, x_q)
73 |
74 | if 'activation_mm' in self.opt['attention']:
75 | x_att = getattr(F, self.opt['attention']['activation_mm'])(x_att)
76 |
77 | # Process attention vectors
78 | x_att = F.dropout(x_att,
79 | p=self.opt['attention']['dropout_mm'],
80 | training=self.training)
81 |         # could be optimized to avoid the two views and transposes
82 | x_att = x_att.view(batch_size,
83 | width,
84 | height,
85 | self.opt['attention']['dim_mm'])
86 | x_att = x_att.transpose(2,3).transpose(1,2)
87 | x_att = self.conv_att(x_att)
88 | x_att = x_att.view(batch_size,
89 | self.opt['attention']['nb_glimpses'],
90 | width * height)
91 | list_att_split = torch.split(x_att, 1, dim=1)
92 | list_att = []
93 | for x_att in list_att_split:
94 | x_att = x_att.contiguous()
95 | x_att = x_att.view(batch_size, width*height)
96 | x_att = F.softmax(x_att)
97 | list_att.append(x_att)
98 |
99 | self.list_att = [x_att.data for x_att in list_att]
100 |
101 | # Apply attention vectors to input_v
102 | x_v = input_v.view(batch_size, self.opt['dim_v'], width * height)
103 | x_v = x_v.transpose(1,2)
104 |
105 | list_v_att = []
106 | for i, x_att in enumerate(list_att):
107 | x_att = x_att.view(batch_size,
108 | width * height,
109 | 1)
110 | x_att = x_att.expand(batch_size,
111 | width * height,
112 | self.opt['dim_v'])
113 | x_v_att = torch.mul(x_att, x_v)
114 | x_v_att = x_v_att.sum(1)
115 | x_v_att = x_v_att.view(batch_size, self.opt['dim_v'])
116 | list_v_att.append(x_v_att)
117 |
118 | return list_v_att
119 |
120 | def _fusion_glimpses(self, list_v_att, x_q_vec):
121 |         # Process visual features for each glimpse
122 | list_v = []
123 | for glimpse_id, x_v_att in enumerate(list_v_att):
124 | x_v = F.dropout(x_v_att,
125 | p=self.opt['fusion']['dropout_v'],
126 | training=self.training)
127 | x_v = self.list_linear_v_fusion[glimpse_id](x_v)
128 | if 'activation_v' in self.opt['fusion']:
129 | x_v = getattr(F, self.opt['fusion']['activation_v'])(x_v)
130 | list_v.append(x_v)
131 | x_v = torch.cat(list_v, 1)
132 |
133 | # Process question
134 | x_q = F.dropout(x_q_vec,
135 | p=self.opt['fusion']['dropout_q'],
136 | training=self.training)
137 | x_q = self.linear_q_fusion(x_q)
138 | if 'activation_q' in self.opt['fusion']:
139 | x_q = getattr(F, self.opt['fusion']['activation_q'])(x_q)
140 |
141 | # Second multimodal fusion
142 | x = self._fusion_classif(x_v, x_q)
143 | return x
144 |
145 | def _classif(self, x):
146 |
147 | if 'activation' in self.opt['classif']:
148 | x = getattr(F, self.opt['classif']['activation'])(x)
149 | x = F.dropout(x,
150 | p=self.opt['classif']['dropout'],
151 | training=self.training)
152 | x = self.linear_classif(x)
153 | return x
154 |
155 | def forward(self, input_v, input_q):
156 |         if input_v.dim() != 4 or input_q.dim() != 2:
157 | raise ValueError
158 |
159 | x_q_vec = self.seq2vec(input_q)
160 | list_v_att = self._attention(input_v, x_q_vec)
161 | x = self._fusion_glimpses(list_v_att, x_q_vec)
162 | x = self._classif(x)
163 | return x
164 |
165 |
166 | class MLBAtt(AbstractAtt):
167 |
168 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]):
169 | # TODO: deep copy ?
170 | opt['attention']['dim_v'] = opt['attention']['dim_h']
171 | opt['attention']['dim_q'] = opt['attention']['dim_h']
172 | opt['attention']['dim_mm'] = opt['attention']['dim_h']
173 | super(MLBAtt, self).__init__(opt, vocab_words, vocab_answers)
174 | # Modules for classification
175 | self.list_linear_v_fusion = nn.ModuleList([
176 | nn.Linear(self.opt['dim_v'],
177 | self.opt['fusion']['dim_h'])
178 | for i in range(self.opt['attention']['nb_glimpses'])])
179 | self.linear_q_fusion = nn.Linear(self.opt['dim_q'],
180 | self.opt['fusion']['dim_h']
181 | * self.opt['attention']['nb_glimpses'])
182 | self.linear_classif = nn.Linear(self.opt['fusion']['dim_h']
183 | * self.opt['attention']['nb_glimpses'],
184 | self.num_classes)
185 |
186 | def _fusion_att(self, x_v, x_q):
187 | x_att = torch.mul(x_v, x_q)
188 | return x_att
189 |
190 | def _fusion_classif(self, x_v, x_q):
191 | x_mm = torch.mul(x_v, x_q)
192 | return x_mm
193 |
194 |
195 | class MutanAtt(AbstractAtt):
196 |
197 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]):
198 | # TODO: deep copy ?
199 | opt['attention']['dim_v'] = opt['attention']['dim_hv']
200 | opt['attention']['dim_q'] = opt['attention']['dim_hq']
201 | super(MutanAtt, self).__init__(opt, vocab_words, vocab_answers)
202 | # Modules for classification
203 | self.fusion_att = fusion.MutanFusion2d(self.opt['attention'],
204 | visual_embedding=False,
205 | question_embedding=False)
206 | self.list_linear_v_fusion = nn.ModuleList([
207 | nn.Linear(self.opt['dim_v'],
208 | int(self.opt['fusion']['dim_hv']
209 | / opt['attention']['nb_glimpses']))
210 | for i in range(self.opt['attention']['nb_glimpses'])])
211 | self.linear_q_fusion = nn.Linear(self.opt['dim_q'],
212 | self.opt['fusion']['dim_hq'])
213 | self.linear_classif = nn.Linear(self.opt['fusion']['dim_mm'],
214 | self.num_classes)
215 | self.fusion_classif = fusion.MutanFusion(self.opt['fusion'],
216 | visual_embedding=False,
217 | question_embedding=False)
218 |
219 | def _fusion_att(self, x_v, x_q):
220 | return self.fusion_att(x_v, x_q)
221 |
222 | def _fusion_classif(self, x_v, x_q):
223 | return self.fusion_classif(x_v, x_q)
224 |
--------------------------------------------------------------------------------
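The core of _attention above is, per glimpse, a softmax over the width*height image regions followed by a weighted sum of the visual features. A stand-alone shape sketch of that step (written with an explicit softmax dim, as recent PyTorch requires; all dimensions are illustrative):

import torch
import torch.nn.functional as F

B, dim_v, W, H = 2, 2048, 14, 14
input_v = torch.randn(B, dim_v, W, H)        # conv feature map from the CNN
att_logits = torch.randn(B, W * H)           # one glimpse of the conv_att output, flattened

alpha = F.softmax(att_logits, dim=1)                     # weights over the W*H regions
x_v = input_v.view(B, dim_v, W * H).transpose(1, 2)      # [B, W*H, dim_v]
glimpse = (alpha.unsqueeze(2) * x_v).sum(dim=1)          # [B, dim_v] attended feature

# MLBAtt/MutanAtt compute nb_glimpses such vectors, project each one,
# and concatenate the projections before the second fusion.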
/vqa/models/convnets.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models as pytorch_models
5 | import sys
6 | sys.path.append('vqa/external/pretrained-models.pytorch')
7 | import pretrainedmodels as torch7_models
8 |
9 | pytorch_resnet_names = sorted(name for name in pytorch_models.__dict__
10 | if name.islower()
11 | and name.startswith("resnet")
12 | and callable(pytorch_models.__dict__[name]))
13 |
14 | torch7_resnet_names = sorted(name for name in torch7_models.__dict__
15 | if name.islower()
16 | and callable(torch7_models.__dict__[name]))
17 |
18 | model_names = pytorch_resnet_names + torch7_resnet_names
19 |
20 | def factory(opt, cuda=True, data_parallel=True):
21 | opt = copy.copy(opt)
22 |
23 | class WrapperModule(nn.Module):
24 | def __init__(self, net, forward_fn):
25 | super(WrapperModule, self).__init__()
26 | self.net = net
27 | self.forward_fn = forward_fn
28 |
29 | def forward(self, x):
30 | return self.forward_fn(self.net, x)
31 |
32 | def __getattr__(self, attr):
33 | try:
34 | return super(WrapperModule, self).__getattr__(attr)
35 | except AttributeError:
36 | return getattr(self.net, attr)
37 |
38 | def forward_resnet(self, x):
39 | x = self.conv1(x)
40 | x = self.bn1(x)
41 | x = self.relu(x)
42 | x = self.maxpool(x)
43 | x = self.layer1(x)
44 | x = self.layer2(x)
45 | x = self.layer3(x)
46 | x = self.layer4(x)
47 |
48 | if 'pooling' in opt and opt['pooling']:
49 | x = self.avgpool(x)
50 | div = x.size(3) + x.size(2)
51 | x = x.sum(3)
52 | x = x.sum(2)
53 | x = x.view(x.size(0), -1)
54 | x = x.div(div)
55 |
56 | return x
57 |
58 | def forward_resnext(self, x):
59 | x = self.features(x)
60 |
61 | if 'pooling' in opt and opt['pooling']:
62 | x = self.avgpool(x)
63 | div = x.size(3) + x.size(2)
64 | x = x.sum(3)
65 | x = x.sum(2)
66 | x = x.view(x.size(0), -1)
67 | x = x.div(div)
68 |
69 | return x
70 |
71 | if opt['arch'] in pytorch_resnet_names:
72 | model = pytorch_models.__dict__[opt['arch']](pretrained=True)
73 |
74 | model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping
75 |
76 | elif opt['arch'] == 'fbresnet152':
77 | model = torch7_models.__dict__[opt['arch']](num_classes=1000,
78 | pretrained='imagenet')
79 |
80 | model = WrapperModule(model, forward_resnet) # ugly hack in case of DataParallel wrapping
81 |
82 | elif opt['arch'] in torch7_resnet_names:
83 | model = torch7_models.__dict__[opt['arch']](num_classes=1000,
84 | pretrained='imagenet')
85 |
86 | model = WrapperModule(model, forward_resnext) # ugly hack in case of DataParallel wrapping
87 |
88 | else:
89 | raise ValueError
90 |
91 | if data_parallel:
92 | model = nn.DataParallel(model).cuda()
93 | if not cuda:
94 | raise ValueError
95 |
96 | if cuda:
97 | model.cuda()
98 |
99 | return model
100 |
--------------------------------------------------------------------------------
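A minimal sketch of building the visual feature extractor with this factory; running it downloads ImageNet weights, and while 'resnet152' with pooling disabled mirrors the attention setups, treat the exact values as assumptions read from the options yaml.

import torch
from vqa.models.convnets import factory

opt = {'arch': 'resnet152', 'pooling': False}
cnn = factory(opt, cuda=False, data_parallel=False)   # CPU-only sketch
cnn.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 448, 448)       # one (unnormalized) 448x448 image
    features = cnn(x)                     # [1, 2048, 14, 14] conv feature map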
/vqa/models/fusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | class AbstractFusion(nn.Module):
7 |
8 | def __init__(self, opt={}):
9 | super(AbstractFusion, self).__init__()
10 | self.opt = opt
11 |
12 | def forward(self, input_v, input_q):
13 | raise NotImplementedError
14 |
15 |
16 | class MLBFusion(AbstractFusion):
17 |
18 | def __init__(self, opt):
19 | super(MLBFusion, self).__init__(opt)
20 | # Modules
21 | if 'dim_v' in self.opt:
22 | self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_h'])
23 | else:
24 | print('Warning fusion.py: no visual embedding before fusion')
25 |
26 | if 'dim_q' in self.opt:
27 | self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_h'])
28 | else:
29 | print('Warning fusion.py: no question embedding before fusion')
30 |
31 | def forward(self, input_v, input_q):
32 | # visual (cnn features)
33 | if 'dim_v' in self.opt:
34 | x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
35 | x_v = self.linear_v(x_v)
36 | if 'activation_v' in self.opt:
37 | x_v = getattr(F, self.opt['activation_v'])(x_v)
38 | else:
39 | x_v = input_v
40 | # question (rnn features)
41 | if 'dim_q' in self.opt:
42 | x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
43 | x_q = self.linear_q(x_q)
44 | if 'activation_q' in self.opt:
45 | x_q = getattr(F, self.opt['activation_q'])(x_q)
46 | else:
47 | x_q = input_q
48 | # hadamard product
49 | x_mm = torch.mul(x_q, x_v)
50 | return x_mm
51 |
52 |
53 | class MutanFusion(AbstractFusion):
54 |
55 | def __init__(self, opt, visual_embedding=True, question_embedding=True):
56 | super(MutanFusion, self).__init__(opt)
57 | self.visual_embedding = visual_embedding
58 | self.question_embedding = question_embedding
59 | # Modules
60 | if self.visual_embedding:
61 | self.linear_v = nn.Linear(self.opt['dim_v'], self.opt['dim_hv'])
62 | else:
63 | print('Warning fusion.py: no visual embedding before fusion')
64 |
65 | if self.question_embedding:
66 | self.linear_q = nn.Linear(self.opt['dim_q'], self.opt['dim_hq'])
67 | else:
68 | print('Warning fusion.py: no question embedding before fusion')
69 |
70 | self.list_linear_hv = nn.ModuleList([
71 | nn.Linear(self.opt['dim_hv'], self.opt['dim_mm'])
72 | for i in range(self.opt['R'])])
73 |
74 | self.list_linear_hq = nn.ModuleList([
75 | nn.Linear(self.opt['dim_hq'], self.opt['dim_mm'])
76 | for i in range(self.opt['R'])])
77 |
78 | def forward(self, input_v, input_q):
79 |         if input_v.dim() != input_q.dim() or input_v.dim() != 2:
80 | raise ValueError
81 | batch_size = input_v.size(0)
82 |
83 | if self.visual_embedding:
84 | x_v = F.dropout(input_v, p=self.opt['dropout_v'], training=self.training)
85 | x_v = self.linear_v(x_v)
86 | if 'activation_v' in self.opt:
87 | x_v = getattr(F, self.opt['activation_v'])(x_v)
88 | else:
89 | x_v = input_v
90 |
91 | if self.question_embedding:
92 | x_q = F.dropout(input_q, p=self.opt['dropout_q'], training=self.training)
93 | x_q = self.linear_q(x_q)
94 | if 'activation_q' in self.opt:
95 | x_q = getattr(F, self.opt['activation_q'])(x_q)
96 | else:
97 | x_q = input_q
98 |
99 | x_mm = []
100 | for i in range(self.opt['R']):
101 |
102 | x_hv = F.dropout(x_v, p=self.opt['dropout_hv'], training=self.training)
103 | x_hv = self.list_linear_hv[i](x_hv)
104 | if 'activation_hv' in self.opt:
105 | x_hv = getattr(F, self.opt['activation_hv'])(x_hv)
106 |
107 | x_hq = F.dropout(x_q, p=self.opt['dropout_hq'], training=self.training)
108 | x_hq = self.list_linear_hq[i](x_hq)
109 | if 'activation_hq' in self.opt:
110 | x_hq = getattr(F, self.opt['activation_hq'])(x_hq)
111 |
112 | x_mm.append(torch.mul(x_hq, x_hv))
113 |
114 | x_mm = torch.stack(x_mm, dim=1)
115 | x_mm = x_mm.sum(1).view(batch_size, self.opt['dim_mm'])
116 |
117 | if 'activation_mm' in self.opt:
118 | x_mm = getattr(F, self.opt['activation_mm'])(x_mm)
119 |
120 | return x_mm
121 |
122 |
123 | class MutanFusion2d(MutanFusion):
124 |
125 | def __init__(self, opt, visual_embedding=True, question_embedding=True):
126 | super(MutanFusion2d, self).__init__(opt,
127 | visual_embedding,
128 | question_embedding)
129 |
130 | def forward(self, input_v, input_q):
131 |         if input_v.dim() != input_q.dim() or input_v.dim() != 3:
132 |             raise ValueError
133 |         batch_size = input_v.size(0)
134 |         width_height = input_v.size(1)
135 |         dim_hv = input_v.size(2)
136 |         dim_hq = input_q.size(2)
137 |         if not input_v.is_contiguous():
138 |             input_v = input_v.contiguous()
139 |         if not input_q.is_contiguous():
140 |             input_q = input_q.contiguous()
141 |         x_v = input_v.view(batch_size * width_height, self.opt['dim_hv'])
142 |         x_q = input_q.view(batch_size * width_height, self.opt['dim_hq'])
143 |         x_mm = super().forward(x_v, x_q)
144 |         x_mm = x_mm.view(batch_size, width_height, self.opt['dim_mm'])
145 |         return x_mm
146 |
147 |
--------------------------------------------------------------------------------
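A self-contained sketch of MutanFusion on its own, with an illustrative option dict; the real dimensions, dropout rates and activations come from the yaml files under options/.

import torch
from vqa.models.fusion import MutanFusion

opt = {
    'dim_v': 2048, 'dim_q': 2400,        # visual / question input sizes
    'dim_hv': 310, 'dim_hq': 310,        # projected sizes fed to the R rank-one terms
    'dim_mm': 510, 'R': 5,               # fused size and number of rank-one terms
    'dropout_v': 0.5, 'dropout_q': 0.5,
    'dropout_hv': 0.0, 'dropout_hq': 0.0,
    'activation_v': 'tanh', 'activation_q': 'tanh',
}
fusion = MutanFusion(opt)
fusion.eval()                            # disable the dropouts for the sketch

x_v = torch.randn(16, 2048)              # pooled CNN features
x_q = torch.randn(16, 2400)              # question embedding (e.g. skip-thoughts)
x_mm = fusion(x_v, x_q)                  # [16, 510] multimodal vector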
/vqa/models/noatt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from vqa.lib import utils
6 | from vqa.models import fusion
7 | from vqa.models import seq2vec
8 |
9 | class AbstractNoAtt(nn.Module):
10 |
11 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]):
12 | super(AbstractNoAtt, self).__init__()
13 | self.opt = opt
14 | self.vocab_words = vocab_words
15 | self.vocab_answers = vocab_answers
16 | self.num_classes = len(self.vocab_answers)
17 | # Modules
18 | self.seq2vec = seq2vec.factory(self.vocab_words, self.opt['seq2vec'])
19 | self.linear_classif = nn.Linear(self.opt['fusion']['dim_h'], self.num_classes)
20 |
21 | def _fusion(self, input_v, input_q):
22 | raise NotImplementedError
23 |
24 | def _classif(self, x):
25 | if 'activation' in self.opt['classif']:
26 | x = getattr(F, self.opt['classif']['activation'])(x)
27 | x = F.dropout(x, p=self.opt['classif']['dropout'], training=self.training)
28 | x = self.linear_classif(x)
29 | return x
30 |
31 | def forward(self, input_v, input_q):
32 | x_q = self.seq2vec(input_q)
33 | x = self._fusion(input_v, x_q)
34 | x = self._classif(x)
35 | return x
36 |
37 |
38 | class MLBNoAtt(AbstractNoAtt):
39 |
40 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]):
41 | super(MLBNoAtt, self).__init__(opt, vocab_words, vocab_answers)
42 | self.fusion = fusion.MLBFusion(self.opt['fusion'])
43 |
44 | def _fusion(self, input_v, input_q):
45 | x = self.fusion(input_v, input_q)
46 | return x
47 |
48 |
49 | class MutanNoAtt(AbstractNoAtt):
50 |
51 | def __init__(self, opt={}, vocab_words=[], vocab_answers=[]):
52 | opt['fusion']['dim_h'] = opt['fusion']['dim_mm']
53 | super(MutanNoAtt, self).__init__(opt, vocab_words, vocab_answers)
54 | self.fusion = fusion.MutanFusion(self.opt['fusion'])
55 |
56 | def _fusion(self, input_v, input_q):
57 | x = self.fusion(input_v, input_q)
58 | return x
59 |
60 |
--------------------------------------------------------------------------------
/vqa/models/seq2vec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 | import sys
7 | sys.path.append('vqa/external/skip-thoughts.torch/pytorch')
8 | import skipthoughts
9 |
10 |
11 | def process_lengths(input):
12 | max_length = input.size(1)
13 | lengths = list(max_length - input.data.eq(0).sum(1).squeeze())
14 | return lengths
15 |
16 | def select_last(x, lengths):
17 | batch_size = x.size(0)
18 | seq_length = x.size(1)
19 | mask = x.data.new().resize_as_(x.data).fill_(0)
20 | for i in range(batch_size):
21 | mask[i][lengths[i]-1].fill_(1)
22 | mask = Variable(mask)
23 | x = x.mul(mask)
24 | x = x.sum(1).view(batch_size, x.size(2))
25 | return x
26 |
27 | class LSTM(nn.Module):
28 |
29 | def __init__(self, vocab, emb_size, hidden_size, num_layers):
30 | super(LSTM, self).__init__()
31 | self.vocab = vocab
32 | self.emb_size = emb_size
33 | self.hidden_size = hidden_size
34 | self.num_layers = num_layers
35 | self.embedding = nn.Embedding(num_embeddings=len(self.vocab)+1,
36 | embedding_dim=emb_size,
37 | padding_idx=0)
38 | self.rnn = nn.LSTM(input_size=emb_size, hidden_size=hidden_size, num_layers=num_layers)
39 |
40 | def forward(self, input):
41 | lengths = process_lengths(input)
42 | x = self.embedding(input) # seq2seq
43 | output, hn = self.rnn(x)
44 | output = select_last(output, lengths)
45 | return output
46 |
47 |
48 | class TwoLSTM(nn.Module):
49 |
50 | def __init__(self, vocab, emb_size, hidden_size):
51 | super(TwoLSTM, self).__init__()
52 | self.vocab = vocab
53 | self.emb_size = emb_size
54 | self.hidden_size = hidden_size
55 | self.embedding = nn.Embedding(num_embeddings=len(self.vocab)+1,
56 | embedding_dim=emb_size,
57 | padding_idx=0)
58 | self.rnn_0 = nn.LSTM(input_size=emb_size, hidden_size=hidden_size, num_layers=1)
59 | self.rnn_1 = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=1)
60 |
61 | def forward(self, input):
62 | lengths = process_lengths(input)
63 | x = self.embedding(input) # seq2seq
64 | x = getattr(F, 'tanh')(x)
65 | x_0, hn = self.rnn_0(x)
66 | vec_0 = select_last(x_0, lengths)
67 |
68 | # x_1 = F.dropout(x_0, p=0.3, training=self.training)
69 | # print(x_1.size())
70 | x_1, hn = self.rnn_1(x_0)
71 | vec_1 = select_last(x_1, lengths)
72 |
73 | vec_0 = F.dropout(vec_0, p=0.3, training=self.training)
74 | vec_1 = F.dropout(vec_1, p=0.3, training=self.training)
75 | output = torch.cat((vec_0, vec_1), 1)
76 | return output
77 |
78 |
79 | def factory(vocab_words, opt):
80 | if opt['arch'] == 'skipthoughts':
81 | st_class = getattr(skipthoughts, opt['type'])
82 | seq2vec = st_class(opt['dir_st'],
83 | vocab_words,
84 | dropout=opt['dropout'],
85 | fixed_emb=opt['fixed_emb'])
86 | elif opt['arch'] == '2-lstm':
87 | seq2vec = TwoLSTM(vocab_words,
88 | opt['emb_size'],
89 | opt['hidden_size'])
90 | elif opt['arch'] == 'lstm':
91 |         seq2vec = LSTM(vocab_words,
92 |                        opt['emb_size'],
93 |                        opt['hidden_size'],
94 |                        opt['num_layers'])
95 | else:
96 | raise NotImplementedError
97 | return seq2vec
98 |
99 |
100 | if __name__ == '__main__':
101 |
102 | vocab = ['robots', 'are', 'very', 'cool', '', 'BiDiBu']
103 | lstm = TwoLSTM(vocab, 300, 1024)
104 |
105 | input = Variable(torch.LongTensor([
106 | [1,2,3,4,5,0,0],
107 | [6,1,2,3,3,4,5],
108 | [6,1,2,3,3,4,5]
109 | ]))
110 | output = lstm(input)
111 |
--------------------------------------------------------------------------------
/vqa/models/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models as models
6 |
7 | from .noatt import MLBNoAtt, MutanNoAtt
8 | from .att import MLBAtt, MutanAtt
9 |
10 | model_names = sorted(name for name in sys.modules[__name__].__dict__
11 | if not name.startswith("__"))# and 'Att' in name)
12 |
13 | def factory(opt, vocab_words, vocab_answers, cuda=True, data_parallel=True):
14 | opt = copy.copy(opt)
15 |
16 | if opt['arch'] in model_names:
17 | model = getattr(sys.modules[__name__], opt['arch'])(opt, vocab_words, vocab_answers)
18 | else:
19 | raise ValueError
20 |
21 | if data_parallel:
22 | model = nn.DataParallel(model).cuda()
23 | if not cuda:
24 | raise ValueError
25 |
26 | if cuda:
27 | model.cuda()
28 |
29 | return model
--------------------------------------------------------------------------------
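An end-to-end sketch of the model factory; the two vocabularies and every number below are placeholders, since in the repo they come from the processed VQA dataset and the yaml option files, and cuda/data_parallel default to True.

import torch
from vqa.models import factory

vocab_words = ['what', 'color', 'is', 'the', 'cat']
vocab_answers = ['yes', 'no', 'black', 'white']

opt = {
    'arch': 'MutanNoAtt',
    'seq2vec': {'arch': '2-lstm', 'emb_size': 300, 'hidden_size': 1200},
    'fusion': {
        'dim_v': 2048, 'dim_q': 2400,
        'dim_hv': 310, 'dim_hq': 310,
        'dim_mm': 510, 'R': 5,
        'dropout_v': 0.5, 'dropout_q': 0.5,
        'dropout_hv': 0.0, 'dropout_hq': 0.0,
    },
    'classif': {'dropout': 0.5},
}

model = factory(opt, vocab_words, vocab_answers, cuda=False, data_parallel=False)
model.eval()                               # deterministic forward for the sketch

input_v = torch.randn(8, 2048)             # pooled visual features (no attention)
input_q = torch.randint(1, 6, (8, 12))     # padded question token ids
scores = model(input_v, input_q)           # [8, len(vocab_answers)] answer scores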