├── LICENSE.txt
├── README.md
├── assets
│   ├── model_classic.png
│   └── model_rubi.png
├── requirements.txt
├── rubi
│   ├── __init__.py
│   ├── __version__.py
│   ├── compare_vqa2_rubi_val.py
│   ├── compare_vqa2_val.py
│   ├── compare_vqacp2_rubi.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── factory.py
│   │   ├── scripts
│   │   │   ├── download_vqa2.sh
│   │   │   └── download_vqacp2.sh
│   │   ├── vqa2.py
│   │   └── vqacp2.py
│   ├── models
│   │   ├── criterions
│   │   │   ├── __init__.py
│   │   │   ├── factory.py
│   │   │   └── rubi_criterion.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── factory.py
│   │   │   └── vqa_rubi_metrics.py
│   │   └── networks
│   │       ├── __init__.py
│   │       ├── baseline_net.py
│   │       ├── factory.py
│   │       ├── rubi.py
│   │       └── utils.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   └── factory.py
│   └── options
│       ├── vqa2
│       │   ├── baseline.yaml
│       │   └── rubi.yaml
│       └── vqacp2
│           ├── baseline.yaml
│           └── rubi.yaml
├── setup.cfg
└── setup.py

/LICENSE.txt:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2019+, Remi Cadene, Corentin Dancette
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RUBi: Reducing Unimodal Biases for Visual Question Answering

This is the code for the NeurIPS 2019 paper, available here: https://arxiv.org/abs/1906.10169.

This paper was written by [Rémi Cadene](http://www.remicadene.com/), [Corentin Dancette](https://cdancette.fr), Hedi Ben Younes, [Matthieu Cord](http://webia.lip6.fr/~cord/) and [Devi Parikh](https://www.cc.gatech.edu/~parikh/).

**RUBi** is a learning strategy to reduce biases in VQA models.
It relies on a question-only branch plugged at the end of a VQA model.
![Classic VQA model](assets/model_classic.png)

![RUBi model](assets/model_rubi.png)
#### Summary

* [Installation](#installation)
    * [As a standalone project](#1-as-a-standalone-project)
    * [As a python library](#2-as-a-python-library)
    * [Download datasets](#3-download-datasets)
* [Quick start](#quick-start)
    * [Train a model](#train-a-model)
    * [Evaluate a model](#evaluate-a-model)
* [Reproduce results](#reproduce-results)
    * [VQA-CP v2](#vqa-cp-v2-dataset)
    * [VQA v2](#vqa-v2-dataset)
* [Useful commands](#useful-commands)
* [Authors](#authors)
* [Acknowledgment](#acknowledgment)


## Installation

Python 2 is not supported. We advise you to install Python 3 with [Anaconda](https://www.continuum.io/downloads) and then create an environment.

### 1. As a standalone project

```
conda create --name rubi python=3.7
source activate rubi
git clone --recursive https://github.com/cdancette/rubi.bootstrap.pytorch.git
cd rubi.bootstrap.pytorch
pip install -r requirements.txt
```

### 2. As a python library

To install the library:
```
git clone https://github.com/cdancette/rubi.bootstrap.pytorch.git
python setup.py install
```

You can then import the `rubi` python module to access the datasets and models directly.

```python
from rubi.models.networks.rubi import RUBiNet
```

**Note:** This repo is built on top of [block.bootstrap.pytorch](https://github.com/Cadene/block.bootstrap.pytorch). We import VQA2, TDIUC and VGenome from this library.

### 3. Download datasets

Download the annotations, images and features for the VQA experiments:
```
bash rubi/datasets/scripts/download_vqa2.sh
bash rubi/datasets/scripts/download_vqacp2.sh
```


## Quick start

### The RUBi model

The main model is RUBi.

```python
from rubi.models.networks.rubi import RUBiNet
```

RUBi takes another VQA model as input and adds a question-only branch around it. The question-only predictions are merged with the original predictions, and RUBi returns the new predictions that are used to train the VQA model.

For an example base model, you can check the [baseline model](https://github.com/cdancette/rubi.bootstrap.pytorch/blob/master/rubi/models/networks/baseline_net.py). The model must return the raw predictions (before softmax) in a dictionary, with the key `logits`.
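The wrapper also reads the question embedding under the key `q_emb` (see `rubi/models/networks/rubi.py`). Here is a minimal sketch of this contract; `TinyVQANet` and its dimensions are hypothetical and only illustrate the expected interface, they are not the repository's baseline:

```python
import torch
import torch.nn as nn

class TinyVQANet(nn.Module):
    """Toy VQA model exposing the dict interface that RUBiNet expects."""
    def __init__(self, vocab_size=100, q_dim=64, v_dim=32, n_answers=10):
        super().__init__()
        self.q_enc = nn.EmbeddingBag(vocab_size, q_dim)   # question encoder
        self.classif = nn.Linear(q_dim + v_dim, n_answers)

    def forward(self, batch):
        q_emb = self.q_enc(batch['question'])             # (batch, q_dim)
        mm = torch.cat([q_emb, batch['visual']], dim=1)   # crude multimodal fusion
        return {
            'logits': self.classif(mm),  # raw scores, before softmax (required)
            'q_emb': q_emb,              # reused by RUBi's question-only branch (required)
        }
```

From these two outputs, RUBi computes `logits_rubi` (the original logits masked by the sigmoid of the question-only predictions) and `logits_q`, which both feed the RUBi criterion during training.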
### Train a model

The [bootstrap/run.py](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/run.py) file loads the options contained in a yaml file, creates the corresponding experiment directory and starts the training procedure. For instance, you can train our best model on VQA-CP v2 by running:
```
python -m bootstrap.run -o rubi/options/vqacp2/rubi.yaml
```
Then, several files are going to be created in `logs/vqacp2/rubi`:
- [options.yaml](https://github.com/Cadene/block.bootstrap.pytorch/blob/master/assets/logs/vrd/block/options.yaml) (copy of options)
- [logs.txt](https://github.com/Cadene/block.bootstrap.pytorch/blob/master/assets/logs/vrd/block/logs.txt) (history of printed messages)
- [logs.json](https://github.com/Cadene/block.bootstrap.pytorch/blob/master/assets/logs/vrd/block/logs.json) (batch and epoch statistics)
- [view.html](http://htmlpreview.github.io/?https://raw.githubusercontent.com/Cadene/block.bootstrap.pytorch/master/assets/logs/vrd/block/view.html?token=AEdvLlDSYaSn3Hsr7gO5sDBxeyuKNQhEks5cTF6-wA%3D%3D) (learning curves)
- ckpt_last_engine.pth.tar (checkpoints of last epoch)
- ckpt_last_model.pth.tar
- ckpt_last_optimizer.pth.tar
- ckpt_best_eval_epoch.accuracy_top1_engine.pth.tar (checkpoints of best epoch)
- ckpt_best_eval_epoch.accuracy_top1_model.pth.tar
- ckpt_best_eval_epoch.accuracy_top1_optimizer.pth.tar

Many options are available in the [options directory](https://github.com/cdancette/rubi.bootstrap.pytorch/blob/master/rubi/options).

### Evaluate a model

There is no testing set on VQA-CP v2, our main dataset, so the evaluation is done on the validation set.

For a model trained on VQA v2, you can evaluate it on the testing set. In this example, [bootstrap/run.py](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/run.py) loads the options from your experiment directory, resumes the best checkpoint on the validation set and starts an evaluation on the testing set instead of the validation set, while skipping the training set (`train_split` is empty). Thanks to `--misc.logs_name`, the logs will be written in new `logs_test.txt` and `logs_test.json` files, instead of being appended to the `logs.txt` and `logs.json` files.
```
python -m bootstrap.run \
-o logs/vqa2/rubi/options.yaml \
--exp.resume best_eval_epoch.accuracy_top1 \
--dataset.train_split '' \
--dataset.eval_split test \
--misc.logs_name test
```


## Reproduce results

### VQA-CP v2 dataset

Use this simple setup to reproduce our results on the valset of VQA-CP v2.

Baseline:

```bash
python -m bootstrap.run \
-o rubi/options/vqacp2/baseline.yaml \
--exp.dir logs/vqacp2/baseline
```

RUBi:

```bash
python -m bootstrap.run \
-o rubi/options/vqacp2/rubi.yaml \
--exp.dir logs/vqacp2/rubi
```

#### Compare experiments on valset

You can compare experiments by displaying their best metrics on the valset:

```
python -m rubi.compare_vqacp2_rubi -d logs/vqacp2/rubi logs/vqacp2/baseline
```

### VQA v2 dataset

Baseline:

```bash
python -m bootstrap.run \
-o rubi/options/vqa2/baseline.yaml \
--exp.dir logs/vqa2/baseline
```

RUBi:

```bash
python -m bootstrap.run \
-o rubi/options/vqa2/rubi.yaml \
--exp.dir logs/vqa2/rubi
```
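If you only want to check a single run, you can also read its metrics directly from `logs.json`. This is a minimal sketch, assuming bootstrap's `logs.json` stores each logged value as a list under its dotted key (the key names below are the ones used by the compare scripts):

```python
import json

# Report the best validation accuracy of one experiment.
with open('logs/vqa2/rubi/logs.json') as f:
    logs = json.load(f)

accuracies = logs['eval_epoch.accuracy_top1']  # one value per evaluated epoch
best_epoch = max(range(len(accuracies)), key=accuracies.__getitem__)
print(f'best eval_epoch.accuracy_top1: {accuracies[best_epoch]:.2f} (epoch {best_epoch})')
```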
You can compare experiments by displaying their best metrics on the valset:

```
python -m rubi.compare_vqa2_rubi_val -d logs/vqa2/rubi logs/vqa2/baseline
```

#### Evaluation on test set

```bash
python -m bootstrap.run \
-o logs/vqa2/rubi/options.yaml \
--exp.resume best_eval_epoch.accuracy_top1 \
--dataset.train_split '' \
--dataset.eval_split test \
--misc.logs_name test
```

## Weights of best model

The weights for the model trained on VQA-CP v2 can be downloaded here: http://webia.lip6.fr/~cadene/rubi/ckpt_last_model.pth.tar

To use them:
* Run this command once to create the experiment folder, and cancel it when the training starts:

```bash
python -m bootstrap.run \
-o rubi/options/vqacp2/rubi.yaml \
--exp.dir logs/vqacp2/rubi
```

* Move the downloaded file to the experiment folder, and use the flag `--exp.resume last` to load this checkpoint:

```bash
python -m bootstrap.run \
-o logs/vqacp2/rubi/options.yaml \
--exp.resume last
```


## Useful commands

### Use tensorboard instead of plotly

Instead of creating a `view.html` file, a tensorboard file will be created:
```
python -m bootstrap.run -o rubi/options/vqacp2/rubi.yaml \
--view.name tensorboard
```

```
tensorboard --logdir=logs/vqa2
```

You can use plotly and tensorboard at the same time by updating the yaml file like [this one](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/options/mnist_plotly_tensorboard.yaml#L38).

### Use a specific GPU

For a specific experiment:
```
CUDA_VISIBLE_DEVICES=0 python -m bootstrap.run -o rubi/options/vqacp2/rubi.yaml
```

For the current terminal session:
```
export CUDA_VISIBLE_DEVICES=0
```

### Overwrite an option

The bootstrap.pytorch framework makes it easy to overwrite a hyperparameter. In this example, we run an experiment with a non-default learning rate, so we also overwrite the experiment directory path:
```
python -m bootstrap.run -o rubi/options/vqacp2/rubi.yaml \
--optimizer.lr 0.0003 \
--exp.dir logs/vqacp2/rubi_lr,0.0003
```

### Resume training

If a problem occurs, it is easy to resume the last epoch by specifying the options file from the experiment directory while overwriting the `exp.resume` option (default is None):
```
python -m bootstrap.run -o logs/vqacp2/rubi/options.yaml \
--exp.resume last
```

## Cite

```
@article{cadene2019rubi,
  title={RUBi: Reducing Unimodal Biases for Visual Question Answering},
  author={Cadene, Remi and Dancette, Corentin and Cord, Matthieu and Parikh, Devi and others},
  journal={Advances in Neural Information Processing Systems},
  volume={32},
  pages={841--852},
  year={2019}
}
```

## Authors

This code was made available by [Corentin Dancette](https://cdancette.fr) and [Rémi Cadene](http://www.remicadene.com/).

## Acknowledgment

Special thanks to the authors of [VQA2](TODO), [TDIUC](TODO), [VisualGenome](TODO) and [VQACP2](TODO), the datasets used in this research project.
280 | 281 | 282 | -------------------------------------------------------------------------------- /assets/model_classic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/assets/model_classic.png -------------------------------------------------------------------------------- /assets/model_rubi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/assets/model_rubi.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | block.bootstrap.pytorch 2 | h5py 3 | -------------------------------------------------------------------------------- /rubi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/rubi/__init__.py -------------------------------------------------------------------------------- /rubi/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.0' 2 | -------------------------------------------------------------------------------- /rubi/compare_vqa2_rubi_val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from bootstrap.compare import main 3 | 4 | if __name__ == '__main__': 5 | parser = argparse.ArgumentParser(description='') 6 | parser.add_argument('-n', '--nb_epochs', default=-1, type=int) 7 | parser.add_argument('-d', '--dir_logs', default='', type=str, nargs='*') 8 | parser.add_argument('-m', '--metrics', type=str, action='append', nargs=3, 9 | metavar=('json', 'name', 'order'), 10 | default=[['logs', 'eval_epoch.accuracy_top1', 'max'], 11 | ['logs', 'eval_epoch.loss', 'min'], 12 | ['logs', 'eval_epoch.loss', 'min'], 13 | ['logs', 'eval_epoch.loss', 'min'], 14 | ['logs', 'eval_epoch.loss', 'min'], 15 | ['logs', 'eval_epoch.loss', 'min'], 16 | ['logs', 'eval_epoch.loss', 'min'], 17 | ['logs_val_oe', 'eval_epoch.overall', 'max'], 18 | ['logs_val_oe', 'eval_epoch.overall', 'max'], 19 | ['logs_q_val_oe', 'eval_epoch.overall', 'max'], 20 | ['logs_v_val_oe', 'eval_epoch.overall', 'max'], 21 | ['logs_mm_val_oe', 'eval_epoch.overall', 'max'], 22 | ['logs_mm_v_val_oe', 'eval_epoch.overall', 'max'], 23 | ['logs_mm_q_val_oe', 'eval_epoch.overall', 'max'], 24 | ['logs_mm_v_q_val_oe', 'eval_epoch.overall', 'max'], 25 | ]) 26 | parser.add_argument('-b', '--best', type=str, nargs=3, 27 | metavar=('json', 'name', 'order'), 28 | default=None) 29 | args = parser.parse_args() 30 | main(args) 31 | -------------------------------------------------------------------------------- /rubi/compare_vqa2_val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from bootstrap.compare import main 3 | 4 | if __name__ == '__main__': 5 | parser = argparse.ArgumentParser(description='') 6 | parser.add_argument('-n', '--nb_epochs', default=-1, type=int) 7 | parser.add_argument('-d', '--dir_logs', default='', type=str, nargs='*') 8 | parser.add_argument('-m', '--metrics', type=str, action='append', nargs=3, 9 | metavar=('json', 'name', 'order'), 10 | default=[['logs', 
'eval_epoch.accuracy_top1', 'max'], 11 | ['logs_val_oe', 'eval_epoch.overall', 'max'], 12 | ['logs', 'eval_epoch.loss', 'min']]) 13 | parser.add_argument('-b', '--best', type=str, nargs=3, 14 | metavar=('json', 'name', 'order'), 15 | default=['logs', 'eval_epoch.accuracy_top1', 'max']) 16 | args = parser.parse_args() 17 | main(args) 18 | -------------------------------------------------------------------------------- /rubi/compare_vqacp2_rubi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from bootstrap.compare import main 3 | 4 | if __name__ == '__main__': 5 | parser = argparse.ArgumentParser(description='') 6 | parser.add_argument('-n', '--nb_epochs', default=-1, type=int) 7 | parser.add_argument('-d', '--dir_logs', default='', type=str, nargs='*') 8 | parser.add_argument('-m', '--metrics', type=str, action='append', nargs=3, 9 | metavar=('json', 'name', 'order'), 10 | default=[ 11 | ['logs', 'eval_epoch.accuracy_top1', 'max'], 12 | # overall 13 | ['logs_val_oe', 'eval_epoch.overall', 'max'], 14 | ['logs_rubi_val_oe', 'eval_epoch.overall', 'max'], 15 | ['logs_q_val_oe', 'eval_epoch.overall', 'max'], 16 | # question type 17 | ['logs_val_oe', 'eval_epoch.perAnswerType.yes/no', 'max'], 18 | ['logs_val_oe', 'eval_epoch.perAnswerType.number', 'max'], 19 | ['logs_val_oe', 'eval_epoch.perAnswerType.other', 'max'], 20 | ['logs_rubi_val_oe', 'eval_epoch.perAnswerType.yes/no', 'max'], 21 | ['logs_rubi_val_oe', 'eval_epoch.perAnswerType.number', 'max'], 22 | ['logs_rubi_val_oe', 'eval_epoch.perAnswerType.other', 'max'], 23 | ['logs_q_val_oe', 'eval_epoch.perAnswerType.yes/no', 'max'], 24 | ['logs_q_val_oe', 'eval_epoch.perAnswerType.number', 'max'], 25 | ['logs_q_val_oe', 'eval_epoch.perAnswerType.other', 'max'], 26 | ]) 27 | parser.add_argument('-b', '--best', type=str, nargs=3, 28 | metavar=('json', 'name', 'order'), 29 | default=['logs_val_oe', 'eval_epoch.overall', 'max']) 30 | args = parser.parse_args() 31 | main(args) 32 | -------------------------------------------------------------------------------- /rubi/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/rubi/datasets/__init__.py -------------------------------------------------------------------------------- /rubi/datasets/factory.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.options import Options 2 | from block.datasets.tdiuc import TDIUC 3 | from block.datasets.vrd import VRD 4 | from block.datasets.vg import VG 5 | from block.datasets.vqa_utils import ListVQADatasets 6 | from .vqa2 import VQA2 7 | from .vqacp2 import VQACP2 8 | 9 | def factory(engine=None): 10 | opt = Options()['dataset'] 11 | 12 | dataset = {} 13 | if opt.get('train_split', None): 14 | dataset['train'] = factory_split(opt['train_split']) 15 | if opt.get('eval_split', None): 16 | dataset['eval'] = factory_split(opt['eval_split']) 17 | 18 | return dataset 19 | 20 | def factory_split(split): 21 | opt = Options()['dataset'] 22 | shuffle = ('train' in split) 23 | 24 | if opt['name'] == 'vqacp2': 25 | assert(split in ['train', 'val', 'test']) 26 | samplingans = (opt['samplingans'] and split == 'train') 27 | 28 | dataset = VQACP2( 29 | dir_data=opt['dir'], 30 | split=split, 31 | batch_size=opt['batch_size'], 32 | nb_threads=opt['nb_threads'], 33 | pin_memory=Options()['misc']['cuda'], 34 | 
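# pin_memory follows misc.cuda to speed up host-to-GPU batch transfers; shuffle is True only for training splits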
shuffle=shuffle, 35 | nans=opt['nans'], 36 | minwcount=opt['minwcount'], 37 | nlp=opt['nlp'], 38 | proc_split=opt['proc_split'], 39 | samplingans=samplingans, 40 | dir_rcnn=opt['dir_rcnn'], 41 | dir_cnn=opt.get('dir_cnn', None), 42 | dir_vgg16=opt.get('dir_vgg16', None), 43 | ) 44 | 45 | elif opt['name'] == 'vqacpv2-with-testdev': 46 | assert(split in ['train', 'val', 'test']) 47 | samplingans = (opt['samplingans'] and split == 'train') 48 | dataset = VQACP2( 49 | dir_data=opt['dir'], 50 | split=split, 51 | batch_size=opt['batch_size'], 52 | nb_threads=opt['nb_threads'], 53 | pin_memory=Options()['misc']['cuda'], 54 | shuffle=shuffle, 55 | nans=opt['nans'], 56 | minwcount=opt['minwcount'], 57 | nlp=opt['nlp'], 58 | proc_split=opt['proc_split'], 59 | samplingans=samplingans, 60 | dir_rcnn=opt['dir_rcnn'], 61 | dir_cnn=opt.get('dir_cnn', None), 62 | dir_vgg16=opt.get('dir_vgg16', None), 63 | has_testdevset=True, 64 | ) 65 | 66 | elif opt['name'] == 'vqa2': 67 | assert(split in ['train', 'val', 'test']) 68 | samplingans = (opt['samplingans'] and split == 'train') 69 | 70 | if opt['vg']: 71 | assert(opt['proc_split'] == 'trainval') 72 | 73 | # trainvalset 74 | vqa2 = VQA2( 75 | dir_data=opt['dir'], 76 | split='train', 77 | nans=opt['nans'], 78 | minwcount=opt['minwcount'], 79 | nlp=opt['nlp'], 80 | proc_split=opt['proc_split'], 81 | samplingans=samplingans, 82 | dir_rcnn=opt['dir_rcnn']) 83 | 84 | vg = VG( 85 | dir_data=opt['dir_vg'], 86 | split='train', 87 | nans=10000, 88 | minwcount=0, 89 | nlp=opt['nlp'], 90 | dir_rcnn=opt['dir_rcnn_vg']) 91 | 92 | vqa2vg = ListVQADatasets( 93 | [vqa2,vg], 94 | split='train', 95 | batch_size=opt['batch_size'], 96 | nb_threads=opt['nb_threads'], 97 | pin_memory=Options()['misc.cuda'], 98 | shuffle=shuffle) 99 | 100 | if split == 'train': 101 | dataset = vqa2vg 102 | else: 103 | dataset = VQA2( 104 | dir_data=opt['dir'], 105 | split=split, 106 | batch_size=opt['batch_size'], 107 | nb_threads=opt['nb_threads'], 108 | pin_memory=Options()['misc.cuda'], 109 | shuffle=False, 110 | nans=opt['nans'], 111 | minwcount=opt['minwcount'], 112 | nlp=opt['nlp'], 113 | proc_split=opt['proc_split'], 114 | samplingans=samplingans, 115 | dir_rcnn=opt['dir_rcnn']) 116 | dataset.sync_from(vqa2vg) 117 | 118 | else: 119 | dataset = VQA2( 120 | dir_data=opt['dir'], 121 | split=split, 122 | batch_size=opt['batch_size'], 123 | nb_threads=opt['nb_threads'], 124 | pin_memory=Options()['misc.cuda'], 125 | shuffle=shuffle, 126 | nans=opt['nans'], 127 | minwcount=opt['minwcount'], 128 | nlp=opt['nlp'], 129 | proc_split=opt['proc_split'], 130 | samplingans=samplingans, 131 | dir_rcnn=opt['dir_rcnn'], 132 | dir_cnn=opt.get('dir_cnn', None), 133 | ) 134 | 135 | return dataset 136 | -------------------------------------------------------------------------------- /rubi/datasets/scripts/download_vqa2.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/vqa 2 | cd data/vqa 3 | wget http://data.lip6.fr/cadene/block/vqa2.tar.gz 4 | wget http://data.lip6.fr/cadene/block/coco.tar.gz 5 | tar -xzvf vqa2.tar.gz 6 | tar -xzvf coco.tar.gz 7 | 8 | mkdir -p data/vqa/coco/extract_rcnn 9 | cd data/vqa/coco/extract_rcnn 10 | wget http://data.lip6.fr/cadene/block/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36.tar 11 | tar -xvf 2018-04-27_bottom-up-attention_fixed_36.tar 12 | -------------------------------------------------------------------------------- /rubi/datasets/scripts/download_vqacp2.sh: 
-------------------------------------------------------------------------------- 1 | mkdir -p data/vqa 2 | cd data/vqa 3 | wget http://data.lip6.fr/cadene/murel/vqacp2.tar.gz 4 | tar -xzvf vqacp2.tar.gz 5 | -------------------------------------------------------------------------------- /rubi/datasets/vqa2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import json 5 | import torch 6 | import numpy as np 7 | from os import path as osp 8 | from bootstrap.lib.logger import Logger 9 | from bootstrap.lib.options import Options 10 | from block.datasets.vqa_utils import AbstractVQA 11 | from copy import deepcopy 12 | import random 13 | import tqdm 14 | import h5py 15 | 16 | class VQA2(AbstractVQA): 17 | 18 | def __init__(self, 19 | dir_data='data/vqa2', 20 | split='train', 21 | batch_size=10, 22 | nb_threads=4, 23 | pin_memory=False, 24 | shuffle=False, 25 | nans=1000, 26 | minwcount=10, 27 | nlp='mcb', 28 | proc_split='train', 29 | samplingans=False, 30 | dir_rcnn='data/coco/extract_rcnn', 31 | adversarial=False, 32 | dir_cnn=None 33 | ): 34 | 35 | super(VQA2, self).__init__( 36 | dir_data=dir_data, 37 | split=split, 38 | batch_size=batch_size, 39 | nb_threads=nb_threads, 40 | pin_memory=pin_memory, 41 | shuffle=shuffle, 42 | nans=nans, 43 | minwcount=minwcount, 44 | nlp=nlp, 45 | proc_split=proc_split, 46 | samplingans=samplingans, 47 | has_valset=True, 48 | has_testset=True, 49 | has_answers_occurence=True, 50 | do_tokenize_answers=False) 51 | 52 | self.dir_rcnn = dir_rcnn 53 | self.dir_cnn = dir_cnn 54 | self.load_image_features() 55 | # to activate manually in visualization context (notebo# to activate manually in visualization context (notebook) 56 | self.load_original_annotation = False 57 | 58 | def add_rcnn_to_item(self, item): 59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name'])) 60 | item_rcnn = torch.load(path_rcnn) 61 | item['visual'] = item_rcnn['pooled_feat'] 62 | item['coord'] = item_rcnn['rois'] 63 | item['norm_coord'] = item_rcnn.get('norm_rois', None) 64 | item['nb_regions'] = item['visual'].size(0) 65 | return item 66 | 67 | def add_cnn_to_item(self, item): 68 | image_name = item['image_name'] 69 | if image_name in self.image_names_to_index_train: 70 | index = self.image_names_to_index_train[image_name] 71 | image = torch.tensor(self.image_features_train['att'][index]) 72 | elif image_name in self.image_names_to_index_val: 73 | index = self.image_names_to_index_val[image_name] 74 | image = torch.tensor(self.image_features_val['att'][index]) 75 | image = image.permute(1, 2, 0).view(196, 2048) 76 | item['visual'] = image 77 | return item 78 | 79 | def load_image_features(self): 80 | if self.dir_cnn: 81 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5') 82 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5') 83 | Logger()(f"Opening file {filename_train}, {filename_val}") 84 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True) 85 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True) 86 | # load txt 87 | with open(os.path.join(self.dir_cnn, 'trainset.txt'.format(self.split)), 'r') as f: 88 | self.image_names_to_index_train = {} 89 | for i, line in enumerate(f): 90 | self.image_names_to_index_train[line.strip()] = i 91 | with open(os.path.join(self.dir_cnn, 'valset.txt'.format(self.split)), 'r') as f: 92 | self.image_names_to_index_val = {} 93 | for i, line in enumerate(f): 94 | 
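# each line of valset.txt is an image name; its line index is the image's row in the hdf5 'att' features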
self.image_names_to_index_val[line.strip()] = i 95 | 96 | def __getitem__(self, index): 97 | item = {} 98 | item['index'] = index 99 | 100 | # Process Question (word token) 101 | question = self.dataset['questions'][index] 102 | if self.load_original_annotation: 103 | item['original_question'] = question 104 | 105 | item['question_id'] = question['question_id'] 106 | 107 | item['question'] = torch.tensor(question['question_wids'], dtype=torch.long) 108 | item['lengths'] = torch.tensor([len(question['question_wids'])], dtype=torch.long) 109 | item['image_name'] = question['image_name'] 110 | 111 | # Process Object, Attribut and Relational features 112 | # Process Object, Attribut and Relational features 113 | if self.dir_rcnn: 114 | item = self.add_rcnn_to_item(item) 115 | elif self.dir_cnn: 116 | item = self.add_cnn_to_item(item) 117 | 118 | # Process Answer if exists 119 | if 'annotations' in self.dataset: 120 | annotation = self.dataset['annotations'][index] 121 | if self.load_original_annotation: 122 | item['original_annotation'] = annotation 123 | if 'train' in self.split and self.samplingans: 124 | proba = annotation['answers_count'] 125 | proba = proba / np.sum(proba) 126 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba)) 127 | else: 128 | item['answer_id'] = annotation['answer_id'] 129 | item['class_id'] = torch.tensor([item['answer_id']], dtype=torch.long) 130 | item['answer'] = annotation['answer'] 131 | item['question_type'] = annotation['question_type'] 132 | else: 133 | if item['question_id'] in self.is_qid_testdev: 134 | item['is_testdev'] = True 135 | else: 136 | item['is_testdev'] = False 137 | 138 | # if Options()['model.network.name'] == 'xmn_net': 139 | # num_feat = 36 140 | # relation_mask = np.zeros((num_feat, num_feat)) 141 | # boxes = item['coord'] 142 | # for i in range(num_feat): 143 | # for j in range(i+1, num_feat): 144 | # # if there is no overlap between two bounding box 145 | # if boxes[0,i]>boxes[2,j] or boxes[0,j]>boxes[2,i] or boxes[1,i]>boxes[3,j] or boxes[1,j]>boxes[3,i]: 146 | # pass 147 | # else: 148 | # relation_mask[i,j] = relation_mask[j,i] = 1 149 | # relation_mask = torch.from_numpy(relation_mask).byte() 150 | # item['relation_mask'] = relation_mask 151 | 152 | return item 153 | 154 | def download(self): 155 | dir_zip = osp.join(self.dir_raw, 'zip') 156 | os.system('mkdir -p '+dir_zip) 157 | dir_ann = osp.join(self.dir_raw, 'annotations') 158 | os.system('mkdir -p '+dir_ann) 159 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip -P '+dir_zip) 160 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip -P '+dir_zip) 161 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip -P '+dir_zip) 162 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P '+dir_zip) 163 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P '+dir_zip) 164 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Train_mscoco.zip')+' -d '+dir_ann) 165 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Val_mscoco.zip')+' -d '+dir_ann) 166 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Test_mscoco.zip')+' -d '+dir_ann) 167 | os.system('unzip '+osp.join(dir_zip, 'v2_Annotations_Train_mscoco.zip')+' -d '+dir_ann) 168 | os.system('unzip '+osp.join(dir_zip, 'v2_Annotations_Val_mscoco.zip')+' -d '+dir_ann) 169 | os.system('mv '+osp.join(dir_ann, 'v2_mscoco_train2014_annotations.json')+' ' 170 | 
+osp.join(dir_ann, 'mscoco_train2014_annotations.json')) 171 | os.system('mv '+osp.join(dir_ann, 'v2_mscoco_val2014_annotations.json')+' ' 172 | +osp.join(dir_ann, 'mscoco_val2014_annotations.json')) 173 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_train2014_questions.json')+' ' 174 | +osp.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')) 175 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_val2014_questions.json')+' ' 176 | +osp.join(dir_ann, 'OpenEnded_mscoco_val2014_questions.json')) 177 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_test2015_questions.json')+' ' 178 | +osp.join(dir_ann, 'OpenEnded_mscoco_test2015_questions.json')) 179 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_test-dev2015_questions.json')+' ' 180 | +osp.join(dir_ann, 'OpenEnded_mscoco_test-dev2015_questions.json')) 181 | -------------------------------------------------------------------------------- /rubi/datasets/vqacp2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import json 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | from os import path as osp 9 | from bootstrap.lib.logger import Logger 10 | from block.datasets.vqa_utils import AbstractVQA 11 | from copy import deepcopy 12 | import random 13 | import h5py 14 | 15 | class VQACP2(AbstractVQA): 16 | 17 | def __init__(self, 18 | dir_data='data/vqa/vqacp2', 19 | split='train', 20 | batch_size=80, 21 | nb_threads=4, 22 | pin_memory=False, 23 | shuffle=False, 24 | nans=1000, 25 | minwcount=10, 26 | nlp='mcb', 27 | proc_split='train', 28 | samplingans=False, 29 | dir_rcnn='data/coco/extract_rcnn', 30 | dir_cnn=None, 31 | dir_vgg16=None, 32 | has_testdevset=False, 33 | ): 34 | super(VQACP2, self).__init__( 35 | dir_data=dir_data, 36 | split=split, 37 | batch_size=batch_size, 38 | nb_threads=nb_threads, 39 | pin_memory=pin_memory, 40 | shuffle=shuffle, 41 | nans=nans, 42 | minwcount=minwcount, 43 | nlp=nlp, 44 | proc_split=proc_split, 45 | samplingans=samplingans, 46 | has_valset=True, 47 | has_testset=False, 48 | has_testdevset=has_testdevset, 49 | has_testset_anno=False, 50 | has_answers_occurence=True, 51 | do_tokenize_answers=False) 52 | self.dir_rcnn = dir_rcnn 53 | self.dir_cnn = dir_cnn 54 | self.dir_vgg16 = dir_vgg16 55 | self.load_image_features() 56 | self.load_original_annotation = False 57 | 58 | def add_rcnn_to_item(self, item): 59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name'])) 60 | item_rcnn = torch.load(path_rcnn) 61 | item['visual'] = item_rcnn['pooled_feat'] 62 | item['coord'] = item_rcnn['rois'] 63 | item['norm_coord'] = item_rcnn['norm_rois'] 64 | item['nb_regions'] = item['visual'].size(0) 65 | return item 66 | 67 | def load_image_features(self): 68 | if self.dir_cnn: 69 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5') 70 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5') 71 | Logger()(f"Opening file {filename_train}, {filename_val}") 72 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True) 73 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True) 74 | # load txt 75 | with open(os.path.join(self.dir_cnn, 'trainset.txt'.format(self.split)), 'r') as f: 76 | self.image_names_to_index_train = {} 77 | for i, line in enumerate(f): 78 | self.image_names_to_index_train[line.strip()] = i 79 | with open(os.path.join(self.dir_cnn, 'valset.txt'.format(self.split)), 'r') as f: 80 | self.image_names_to_index_val = {} 
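# same mapping as for the train split: image name -> row index in the val hdf5 features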
81 | for i, line in enumerate(f): 82 | self.image_names_to_index_val[line.strip()] = i 83 | elif self.dir_vgg16: 84 | # list filenames 85 | self.filenames_train = os.listdir(os.path.join(self.dir_vgg16, 'train')) 86 | self.filenames_val = os.listdir(os.path.join(self.dir_vgg16, 'val')) 87 | 88 | 89 | def add_vgg_to_item(self, item): 90 | image_name = item['image_name'] 91 | filename = image_name + '.pth' 92 | if filename in self.filenames_train: 93 | path = os.path.join(self.dir_vgg16, 'train', filename) 94 | elif filename in self.filenames_val: 95 | path = os.path.join(self.dir_vgg16, 'val', filename) 96 | visual = torch.load(path) 97 | visual = visual.permute(1, 2, 0).view(14*14, 512) 98 | item['visual'] = visual 99 | return item 100 | 101 | def add_cnn_to_item(self, item): 102 | image_name = item['image_name'] 103 | if image_name in self.image_names_to_index_train: 104 | index = self.image_names_to_index_train[image_name] 105 | image = torch.tensor(self.image_features_train['att'][index]) 106 | elif image_name in self.image_names_to_index_val: 107 | index = self.image_names_to_index_val[image_name] 108 | image = torch.tensor(self.image_features_val['att'][index]) 109 | image = image.permute(1, 2, 0).view(196, 2048) 110 | item['visual'] = image 111 | return item 112 | 113 | def __getitem__(self, index): 114 | item = {} 115 | item['index'] = index 116 | 117 | # Process Question (word token) 118 | question = self.dataset['questions'][index] 119 | if self.load_original_annotation: 120 | item['original_question'] = question 121 | item['question_id'] = question['question_id'] 122 | item['question'] = torch.LongTensor(question['question_wids']) 123 | item['lengths'] = torch.LongTensor([len(question['question_wids'])]) 124 | item['image_name'] = question['image_name'] 125 | 126 | # Process Object, Attribut and Relational features 127 | if self.dir_rcnn: 128 | item = self.add_rcnn_to_item(item) 129 | elif self.dir_cnn: 130 | item = self.add_cnn_to_item(item) 131 | elif self.dir_vgg16: 132 | item = self.add_vgg_to_item(item) 133 | 134 | # Process Answer if exists 135 | if 'annotations' in self.dataset: 136 | annotation = self.dataset['annotations'][index] 137 | if self.load_original_annotation: 138 | item['original_annotation'] = annotation 139 | if 'train' in self.split and self.samplingans: 140 | proba = annotation['answers_count'] 141 | proba = proba / np.sum(proba) 142 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba)) 143 | else: 144 | item['answer_id'] = annotation['answer_id'] 145 | item['class_id'] = torch.LongTensor([item['answer_id']]) 146 | item['answer'] = annotation['answer'] 147 | item['question_type'] = annotation['question_type'] 148 | 149 | return item 150 | 151 | def download(self): 152 | dir_ann = osp.join(self.dir_raw, 'annotations') 153 | os.system('mkdir -p '+dir_ann) 154 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json -P' + dir_ann) 155 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json -P' + dir_ann) 156 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json -P' + dir_ann) 157 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json -P' + dir_ann) 158 | train_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v2_train_questions.json")))} 159 | val_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v2_test_questions.json")))} 160 | train_ann = {"annotations":json.load(open(osp.join(dir_ann, 
"vqacp_v2_train_annotations.json")))} 161 | val_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v2_test_annotations.json")))} 162 | train_q['info'] = {} 163 | train_q['data_type'] = 'mscoco' 164 | train_q['data_subtype'] = "train2014cp" 165 | train_q['task_type'] = "Open-Ended" 166 | train_q['license'] = {} 167 | val_q['info'] = {} 168 | val_q['data_type'] = 'mscoco' 169 | val_q['data_subtype'] = "val2014cp" 170 | val_q['task_type'] = "Open-Ended" 171 | val_q['license'] = {} 172 | for k in ["info", 'data_type','data_subtype', 'license']: 173 | train_ann[k] = train_q[k] 174 | val_ann[k] = val_q[k] 175 | with open(osp.join(dir_ann, "OpenEnded_mscoco_train2014_questions.json"), 'w') as F: 176 | F.write(json.dumps(train_q)) 177 | with open(osp.join(dir_ann, "OpenEnded_mscoco_val2014_questions.json"), 'w') as F: 178 | F.write(json.dumps(val_q)) 179 | with open(osp.join(dir_ann, "mscoco_train2014_annotations.json"), 'w') as F: 180 | F.write(json.dumps(train_ann)) 181 | with open(osp.join(dir_ann, "mscoco_val2014_annotations.json"), 'w') as F: 182 | F.write(json.dumps(val_ann)) 183 | 184 | def add_image_names(self, dataset): 185 | for q in dataset['questions']: 186 | q['image_name'] = 'COCO_%s_%012d.jpg'%(q['coco_split'],q['image_id']) 187 | return dataset 188 | 189 | -------------------------------------------------------------------------------- /rubi/models/criterions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/rubi/models/criterions/__init__.py -------------------------------------------------------------------------------- /rubi/models/criterions/factory.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.options import Options 2 | from block.models.criterions.vqa_cross_entropy import VQACrossEntropyLoss 3 | from .rubi_criterion import RUBiCriterion 4 | 5 | def factory(engine, mode): 6 | name = Options()['model.criterion.name'] 7 | split = engine.dataset[mode].split 8 | eval_only = 'train' not in engine.dataset 9 | 10 | opt = Options()['model.criterion'] 11 | if split == "test" and 'tdiuc' not in Options()['dataset.name']: 12 | return None 13 | if name == 'vqa_cross_entropy': 14 | criterion = VQACrossEntropyLoss() 15 | elif name == "rubi_criterion": 16 | criterion = RUBiCriterion( 17 | question_loss_weight=opt['question_loss_weight'] 18 | ) 19 | else: 20 | raise ValueError(name) 21 | return criterion 22 | -------------------------------------------------------------------------------- /rubi/models/criterions/rubi_criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from bootstrap.lib.logger import Logger 5 | from bootstrap.lib.options import Options 6 | 7 | class RUBiCriterion(nn.Module): 8 | 9 | def __init__(self, question_loss_weight=1.0): 10 | super().__init__() 11 | 12 | Logger()(f'RUBiCriterion, with question_loss_weight = ({question_loss_weight})') 13 | 14 | self.question_loss_weight = question_loss_weight 15 | self.fusion_loss = nn.CrossEntropyLoss() 16 | self.question_loss = nn.CrossEntropyLoss() 17 | 18 | def forward(self, net_out, batch): 19 | out = {} 20 | # logits = net_out['logits'] 21 | logits_q = net_out['logits_q'] 22 | logits_rubi = net_out['logits_rubi'] 23 | class_id = batch['class_id'].squeeze(1) 24 | fusion_loss = 
self.fusion_loss(logits_rubi, class_id) 25 | question_loss = self.question_loss(logits_q, class_id) 26 | loss = fusion_loss + self.question_loss_weight * question_loss 27 | 28 | out['loss'] = loss 29 | out['loss_mm_q'] = fusion_loss 30 | out['loss_q'] = question_loss 31 | return out 32 | -------------------------------------------------------------------------------- /rubi/models/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/rubi/models/metrics/__init__.py -------------------------------------------------------------------------------- /rubi/models/metrics/factory.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.options import Options 2 | from block.models.metrics.vqa_accuracies import VQAAccuracies 3 | from .vqa_rubi_metrics import VQARUBiMetrics 4 | 5 | def factory(engine, mode): 6 | name = Options()['model.metric.name'] 7 | metric = None 8 | 9 | if name == 'vqa_accuracies': 10 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 11 | if mode == 'train': 12 | split = engine.dataset['train'].split 13 | if split == 'train': 14 | metric = VQAAccuracies(engine, 15 | mode='train', 16 | open_ended=open_ended, 17 | tdiuc=True, 18 | dir_exp=Options()['exp.dir'], 19 | dir_vqa=Options()['dataset.dir']) 20 | elif split == 'trainval': 21 | metric = None 22 | else: 23 | raise ValueError(split) 24 | elif mode == 'eval': 25 | metric = VQAAccuracies(engine, 26 | mode='eval', 27 | open_ended=open_ended, 28 | tdiuc=('tdiuc' in Options()['dataset.name'] or Options()['dataset.eval_split'] != 'test'), 29 | dir_exp=Options()['exp.dir'], 30 | dir_vqa=Options()['dataset.dir']) 31 | else: 32 | metric = None 33 | 34 | elif name == "vqa_rubi_metrics": 35 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 36 | metric = VQARUBiMetrics(engine, 37 | mode=mode, 38 | open_ended=open_ended, 39 | tdiuc=True, 40 | dir_exp=Options()['exp.dir'], 41 | dir_vqa=Options()['dataset.dir'] 42 | ) 43 | 44 | else: 45 | raise ValueError(name) 46 | return metric 47 | -------------------------------------------------------------------------------- /rubi/models/metrics/vqa_rubi_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import json 5 | from scipy import stats 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | from bootstrap.models.metrics.accuracy import accuracy 10 | from block.models.metrics.vqa_accuracies import VQAAccuracies 11 | from bootstrap.lib.logger import Logger 12 | from bootstrap.lib.options import Options 13 | from bootstrap.lib.logger import Logger 14 | 15 | class VQAAccuracy(nn.Module): 16 | 17 | def __init__(self, topk=[1,5]): 18 | super().__init__() 19 | self.topk = topk 20 | 21 | def forward(self, cri_out, net_out, batch): 22 | out = {} 23 | class_id = batch['class_id'].data.cpu() 24 | for key in ['', '_rubi', '_q']: 25 | logits = net_out[f'logits{key}'].data.cpu() 26 | acc_out = accuracy(logits, class_id, topk=self.topk) 27 | for i, k in enumerate(self.topk): 28 | out[f'accuracy{key}_top{k}'] = acc_out[i] 29 | return out 30 | 31 | 32 | class VQARUBiMetrics(VQAAccuracies): 33 | 34 | def __init__(self, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.accuracy 
= VQAAccuracy() 37 | self.rm_dir_rslt = 1 if Options()['dataset.train_split'] is not None else 0 38 | 39 | def forward(self, cri_out, net_out, batch): 40 | out = {} 41 | if self.accuracy is not None: 42 | out = self.accuracy(cri_out, net_out, batch) 43 | 44 | # add answers and answer_ids keys to net_out 45 | net_out = self.engine.model.network.process_answers(net_out) 46 | 47 | batch_size = len(batch['index']) 48 | for i in range(batch_size): 49 | 50 | # Open Ended Accuracy (VQA-VQA2) 51 | if self.open_ended: 52 | for key in ['', '_rubi', '_q']: 53 | pred_item = { 54 | 'question_id': batch['question_id'][i], 55 | 'answer': net_out[f'answers{key}'][i] 56 | } 57 | self.results[key].append(pred_item) 58 | 59 | if self.dataset.split == 'test': 60 | pred_item = { 61 | 'question_id': batch['question_id'][i], 62 | 'answer': net_out[f'answers'][i] 63 | } 64 | if 'is_testdev' in batch and batch['is_testdev'][i]: 65 | self.results_testdev.append(pred_item) 66 | 67 | if self.logits['tensor'] is None: 68 | self.logits['tensor'] = torch.FloatTensor(len(self.dataset), logits.size(1)) 69 | 70 | self.logits['tensor'][self.idx] = logits[i] 71 | self.logits['qid_to_idx'][batch['question_id'][i]] = self.idx 72 | 73 | self.idx += 1 74 | 75 | # TDIUC metrics 76 | if self.tdiuc: 77 | gt_aid = batch['answer_id'][i] 78 | gt_ans = batch['answer'][i] 79 | gt_type = batch['question_type'][i] 80 | self.gt_types.append(gt_type) 81 | if gt_ans in self.ans_to_aid: 82 | self.gt_aids.append(gt_aid) 83 | else: 84 | self.gt_aids.append(-1) 85 | self.gt_aid_not_found += 1 86 | 87 | for key in ['', '_rubi', '_q']: 88 | qid = batch['question_id'][i] 89 | pred_aid = net_out[f'answer_ids{key}'][i] 90 | self.pred_aids[key].append(pred_aid) 91 | 92 | self.res_by_type[key][gt_type+'_pred'].append(pred_aid) 93 | 94 | if gt_ans in self.ans_to_aid: 95 | self.res_by_type[key][gt_type+'_gt'].append(gt_aid) 96 | if gt_aid == pred_aid: 97 | self.res_by_type[key][gt_type+'_t'].append(pred_aid) 98 | else: 99 | self.res_by_type[key][gt_type+'_f'].append(pred_aid) 100 | else: 101 | self.res_by_type[key][gt_type+'_gt'].append(-1) 102 | self.res_by_type[key][gt_type+'_f'].append(pred_aid) 103 | return out 104 | 105 | def reset_oe(self): 106 | self.results = dict() 107 | self.dir_rslt = dict() 108 | self.path_rslt = dict() 109 | for key in ['', '_q', '_rubi']: 110 | self.results[key] = [] 111 | self.dir_rslt[key] = os.path.join( 112 | self.dir_exp, 113 | f'results{key}', 114 | self.dataset.split, 115 | 'epoch,{}'.format(self.engine.epoch)) 116 | os.system('mkdir -p '+self.dir_rslt[key]) 117 | self.path_rslt[key] = os.path.join( 118 | self.dir_rslt[key], 119 | 'OpenEnded_mscoco_{}_model_results.json'.format( 120 | self.dataset.get_subtype())) 121 | 122 | if self.dataset.split == 'test': 123 | pass 124 | # self.results_testdev = [] 125 | # self.path_rslt_testdev = os.path.join( 126 | # self.dir_rslt, 127 | # 'OpenEnded_mscoco_{}_model_results.json'.format( 128 | # self.dataset.get_subtype(testdev=True))) 129 | 130 | # self.path_logits = os.path.join(self.dir_rslt, 'logits.pth') 131 | # os.system('mkdir -p '+os.path.dirname(self.path_logits)) 132 | 133 | # self.logits = {} 134 | # self.logits['aid_to_ans'] = self.engine.model.network.aid_to_ans 135 | # self.logits['qid_to_idx'] = {} 136 | # self.logits['tensor'] = None 137 | 138 | # self.idx = 0 139 | 140 | # path_aid_to_ans = os.path.join(self.dir_rslt, 'aid_to_ans.json') 141 | # with open(path_aid_to_ans, 'w') as f: 142 | # json.dump(self.engine.model.network.aid_to_ans, f) 143 | 144 | 145 | def 
reset_tdiuc(self): 146 | self.pred_aids = defaultdict(list) 147 | self.gt_aids = [] 148 | self.gt_types = [] 149 | self.gt_aid_not_found = 0 150 | self.res_by_type = {key: defaultdict(list) for key in ['', '_rubi', '_q']} 151 | 152 | 153 | def compute_oe_accuracy(self): 154 | logs_name_prefix = Options()['misc'].get('logs_name', '') or '' 155 | 156 | for key in ['', '_rubi', '_q']: 157 | logs_name = (logs_name_prefix + key) or "logs" 158 | with open(self.path_rslt[key], 'w') as f: 159 | json.dump(self.results[key], f) 160 | 161 | # if self.dataset.split == 'test': 162 | # with open(self.path_rslt_testdev, 'w') as f: 163 | # json.dump(self.results_testdev, f) 164 | 165 | if 'test' not in self.dataset.split: 166 | call_to_prog = 'python -m block.models.metrics.compute_oe_accuracy '\ 167 | + '--dir_vqa {} --dir_exp {} --dir_rslt {} --epoch {} --split {} --logs_name {} --rm {} &'\ 168 | .format(self.dir_vqa, self.dir_exp, self.dir_rslt[key], self.engine.epoch, self.dataset.split, logs_name, self.rm_dir_rslt) 169 | Logger()('`'+call_to_prog+'`') 170 | os.system(call_to_prog) 171 | 172 | 173 | def compute_tdiuc_metrics(self): 174 | Logger()('{} of validation answers were not found in ans_to_aid'.format(self.gt_aid_not_found)) 175 | 176 | for key in ['', '_rubi', '_q']: 177 | Logger()(f'Computing TDIUC metrics for logits{key}') 178 | accuracy = float(100*np.mean(np.array(self.pred_aids[key])==np.array(self.gt_aids))) 179 | Logger()('Overall Traditional Accuracy is {:.2f}'.format(accuracy)) 180 | Logger().log_value('{}_epoch.tdiuc.accuracy{}'.format(self.mode, key), accuracy, should_print=False) 181 | 182 | types = list(set(self.gt_types)) 183 | sum_acc = [] 184 | eps = 1e-10 185 | 186 | Logger()('---------------------------------------') 187 | Logger()('Not using per-answer normalization...') 188 | for tp in types: 189 | acc = 100*(len(self.res_by_type[key][tp+'_t'])/len(self.res_by_type[key][tp+'_t']+self.res_by_type[key][tp+'_f'])) 190 | sum_acc.append(acc+eps) 191 | Logger()(f"Accuracy {key} for class '{tp}' is {acc:.2f}") 192 | Logger().log_value('{}_epoch.tdiuc{}.perQuestionType.{}'.format(self.mode, key, tp), acc, should_print=False) 193 | 194 | acc_mpt_a = float(np.mean(np.array(sum_acc))) 195 | Logger()('Arithmetic MPT Accuracy {} is {:.2f}'.format(key, acc_mpt_a)) 196 | Logger().log_value('{}_epoch.tdiuc{}.acc_mpt_a'.format(self.mode, key), acc_mpt_a, should_print=False) 197 | 198 | acc_mpt_h = float(stats.hmean(sum_acc)) 199 | Logger()('Harmonic MPT Accuracy {} is {:.2f}'.format(key, acc_mpt_h)) 200 | Logger().log_value('{}_epoch.tdiuc{}.acc_mpt_h'.format(self.mode, key), acc_mpt_h, should_print=False) 201 | 202 | Logger()('---------------------------------------') 203 | Logger()('Using per-answer normalization...') 204 | for tp in types: 205 | per_ans_stat = defaultdict(int) 206 | for g,p in zip(self.res_by_type[key][tp+'_gt'],self.res_by_type[key][tp+'_pred']): 207 | per_ans_stat[str(g)+'_gt']+=1 208 | if g==p: 209 | per_ans_stat[str(g)]+=1 210 | unq_acc = 0 211 | for unq_ans in set(self.res_by_type[key][tp+'_gt']): 212 | acc_curr_ans = per_ans_stat[str(unq_ans)]/per_ans_stat[str(unq_ans)+'_gt'] 213 | unq_acc +=acc_curr_ans 214 | acc = 100*unq_acc/len(set(self.res_by_type[key][tp+'_gt'])) 215 | sum_acc.append(acc+eps) 216 | Logger()("Accuracy {} for class '{}' is {:.2f}".format(key, tp, acc)) 217 | Logger().log_value('{}_epoch.tdiuc{}.perQuestionType_norm.{}'.format(self.mode, key, tp), acc, should_print=False) 218 | 219 | acc_mpt_a = float(np.mean(np.array(sum_acc))) 220 | 
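# arithmetic and harmonic means over the per-question-type accuracies (MPT = mean per type)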
Logger()('Arithmetic MPT Accuracy is {:.2f}'.format(acc_mpt_a)) 221 | Logger().log_value('{}_epoch.tdiuc{}.acc_mpt_a_norm'.format(self.mode, key), acc_mpt_a, should_print=False) 222 | 223 | acc_mpt_h = float(stats.hmean(sum_acc)) 224 | Logger()('Harmonic MPT Accuracy is {:.2f}'.format(acc_mpt_h)) 225 | Logger().log_value('{}_epoch.tdiuc{}.acc_mpt_h_norm'.format(self.mode, key), acc_mpt_h, should_print=False) 226 | -------------------------------------------------------------------------------- /rubi/models/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/rubi/models/networks/__init__.py -------------------------------------------------------------------------------- /rubi/models/networks/baseline_net.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import itertools 3 | import os 4 | import numpy as np 5 | import scipy 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from bootstrap.lib.options import Options 10 | from bootstrap.lib.logger import Logger 11 | import block 12 | from block.models.networks.vqa_net import factory_text_enc 13 | from block.models.networks.mlp import MLP 14 | 15 | from .utils import mask_softmax 16 | 17 | class BaselineNet(nn.Module): 18 | 19 | def __init__(self, 20 | txt_enc={}, 21 | self_q_att=False, 22 | agg={}, 23 | classif={}, 24 | wid_to_word={}, 25 | word_to_wid={}, 26 | aid_to_ans=[], 27 | ans_to_aid={}, 28 | fusion={}, 29 | residual=False, 30 | ): 31 | super().__init__() 32 | self.self_q_att = self_q_att 33 | self.agg = agg 34 | assert self.agg['type'] in ['max', 'mean'] 35 | self.classif = classif 36 | self.wid_to_word = wid_to_word 37 | self.word_to_wid = word_to_wid 38 | self.aid_to_ans = aid_to_ans 39 | self.ans_to_aid = ans_to_aid 40 | self.fusion = fusion 41 | self.residual = residual 42 | 43 | # Modules 44 | self.txt_enc = self.get_text_enc(self.wid_to_word, txt_enc) 45 | if self.self_q_att: 46 | self.q_att_linear0 = nn.Linear(2400, 512) 47 | self.q_att_linear1 = nn.Linear(512, 2) 48 | 49 | self.fusion_module = block.factory_fusion(self.fusion) 50 | 51 | if self.classif['mlp']['dimensions'][-1] != len(self.aid_to_ans): 52 | Logger()(f"Warning, the classif_mm output dimension ({self.classif['mlp']['dimensions'][-1]})" 53 | f"doesn't match the number of answers ({len(self.aid_to_ans)}). Modifying the output dimension.") 54 | self.classif['mlp']['dimensions'][-1] = len(self.aid_to_ans) 55 | 56 | self.classif_module = MLP(**self.classif['mlp']) 57 | 58 | Logger().log_value('nparams', 59 | sum(p.numel() for p in self.parameters() if p.requires_grad), 60 | should_print=True) 61 | 62 | Logger().log_value('nparams_txt_enc', 63 | self.get_nparams_txt_enc(), 64 | should_print=True) 65 | 66 | 67 | def get_text_enc(self, vocab_words, options): 68 | """ 69 | returns the text encoding network. 
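Delegates to block's factory_text_enc (a skip-thoughts encoder in the provided options).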
70 | """ 71 | return factory_text_enc(self.wid_to_word, options) 72 | 73 | def get_nparams_txt_enc(self): 74 | params = [p.numel() for p in self.txt_enc.parameters() if p.requires_grad] 75 | if self.self_q_att: 76 | params += [p.numel() for p in self.q_att_linear0.parameters() if p.requires_grad] 77 | params += [p.numel() for p in self.q_att_linear1.parameters() if p.requires_grad] 78 | return sum(params) 79 | 80 | def process_fusion(self, q, mm): 81 | bsize = mm.shape[0] 82 | n_regions = mm.shape[1] 83 | 84 | mm = mm.contiguous().view(bsize*n_regions, -1) 85 | mm = self.fusion_module([q, mm]) 86 | mm = mm.view(bsize, n_regions, -1) 87 | return mm 88 | 89 | def forward(self, batch): 90 | v = batch['visual'] 91 | q = batch['question'] 92 | l = batch['lengths'].data 93 | c = batch['norm_coord'] 94 | nb_regions = batch.get('nb_regions') 95 | bsize = v.shape[0] 96 | n_regions = v.shape[1] 97 | 98 | out = {} 99 | 100 | q = self.process_question(q, l,) 101 | out['q_emb'] = q 102 | q_expand = q[:,None,:].expand(bsize, n_regions, q.shape[1]) 103 | q_expand = q_expand.contiguous().view(bsize*n_regions, -1) 104 | 105 | mm = self.process_fusion(q_expand, v,) 106 | 107 | if self.residual: 108 | mm = v + mm 109 | 110 | if self.agg['type'] == 'max': 111 | mm, mm_argmax = torch.max(mm, 1) 112 | elif self.agg['type'] == 'mean': 113 | mm = mm.mean(1) 114 | 115 | out['mm'] = mm 116 | out['mm_argmax'] = mm_argmax 117 | 118 | logits = self.classif_module(mm) 119 | out['logits'] = logits 120 | return out 121 | 122 | def process_question(self, q, l, txt_enc=None, q_att_linear0=None, q_att_linear1=None): 123 | if txt_enc is None: 124 | txt_enc = self.txt_enc 125 | if q_att_linear0 is None: 126 | q_att_linear0 = self.q_att_linear0 127 | if q_att_linear1 is None: 128 | q_att_linear1 = self.q_att_linear1 129 | q_emb = txt_enc.embedding(q) 130 | 131 | q, _ = txt_enc.rnn(q_emb) 132 | 133 | if self.self_q_att: 134 | q_att = q_att_linear0(q) 135 | q_att = F.relu(q_att) 136 | q_att = q_att_linear1(q_att) 137 | q_att = mask_softmax(q_att, l) 138 | #self.q_att_coeffs = q_att 139 | if q_att.size(2) > 1: 140 | q_atts = torch.unbind(q_att, dim=2) 141 | q_outs = [] 142 | for q_att in q_atts: 143 | q_att = q_att.unsqueeze(2) 144 | q_att = q_att.expand_as(q) 145 | q_out = q_att*q 146 | q_out = q_out.sum(1) 147 | q_outs.append(q_out) 148 | q = torch.cat(q_outs, dim=1) 149 | else: 150 | q_att = q_att.expand_as(q) 151 | q = q_att * q 152 | q = q.sum(1) 153 | else: 154 | # l contains the number of words for each question 155 | # in case of multi-gpus it must be a Tensor 156 | # thus we convert it into a list during the forward pass 157 | l = list(l.data[:,0]) 158 | q = txt_enc._select_last(q, l) 159 | 160 | return q 161 | 162 | def process_answers(self, out, key=''): 163 | batch_size = out[f'logits{key}'].shape[0] 164 | _, pred = out[f'logits{key}'].data.max(1) 165 | pred.squeeze_() 166 | if batch_size != 1: 167 | out[f'answers{key}'] = [self.aid_to_ans[pred[i].item()] for i in range(batch_size)] 168 | out[f'answer_ids{key}'] = [pred[i].item() for i in range(batch_size)] 169 | else: 170 | out[f'answers{key}'] = [self.aid_to_ans[pred.item()]] 171 | out[f'answer_ids{key}'] = [pred.item()] 172 | return out 173 | -------------------------------------------------------------------------------- /rubi/models/networks/factory.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import json 7 | from 
bootstrap.lib.options import Options
8 | from bootstrap.models.networks.data_parallel import DataParallel
9 | from block.models.networks.vqa_net import VQANet as AttentionNet
10 | from bootstrap.lib.logger import Logger
11 | 
12 | from .baseline_net import BaselineNet
13 | from .rubi import RUBiNet
14 | 
15 | def factory(engine):
16 |     mode = list(engine.dataset.keys())[0]
17 |     dataset = engine.dataset[mode]
18 |     opt = Options()['model.network']
19 | 
20 |     if opt['name'] == 'baseline':
21 |         net = BaselineNet(
22 |             txt_enc=opt['txt_enc'],
23 |             self_q_att=opt['self_q_att'],
24 |             agg=opt['agg'],
25 |             classif=opt['classif'],
26 |             wid_to_word=dataset.wid_to_word,
27 |             word_to_wid=dataset.word_to_wid,
28 |             aid_to_ans=dataset.aid_to_ans,
29 |             ans_to_aid=dataset.ans_to_aid,
30 |             fusion=opt['fusion'],
31 |             residual=opt['residual'],
32 |         )
33 | 
34 |     elif opt['name'] == 'rubi':
35 |         orig_net = BaselineNet(
36 |             txt_enc=opt['txt_enc'],
37 |             self_q_att=opt['self_q_att'],
38 |             agg=opt['agg'],
39 |             classif=opt['classif'],
40 |             wid_to_word=dataset.wid_to_word,
41 |             word_to_wid=dataset.word_to_wid,
42 |             aid_to_ans=dataset.aid_to_ans,
43 |             ans_to_aid=dataset.ans_to_aid,
44 |             fusion=opt['fusion'],
45 |             residual=opt['residual'],
46 |         )
47 |         net = RUBiNet(
48 |             model=orig_net,
49 |             output_size=len(dataset.aid_to_ans),
50 |             classif=opt['rubi_params']['mlp_q']
51 |         )
52 |     else:
53 |         raise ValueError(opt['name'])
54 | 
55 |     if Options()['misc.cuda'] and torch.cuda.device_count() > 1:
56 |         net = DataParallel(net)
57 | 
58 |     return net
59 | 
60 | 
--------------------------------------------------------------------------------
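The factory above shows that the `rubi` network is simply the baseline wrapped by `RUBiNet` (defined in the next file). Below is a minimal construction sketch mirroring the `elif opt['name'] == 'rubi'` branch; the vocabulary and answer dictionaries are placeholders for what the dataset object normally provides, and actually running it requires the `block.bootstrap.pytorch` dependencies and the skip-thoughts data.

```python
from rubi.models.networks.baseline_net import BaselineNet
from rubi.models.networks.rubi import RUBiNet

# Placeholder dictionaries; the factory reads these from engine.dataset.
wid_to_word = {1: 'what'}
word_to_wid = {'what': 1}
aid_to_ans = ['yes', 'no']
ans_to_aid = {'yes': 0, 'no': 1}

base_net = BaselineNet(
    txt_enc={'name': 'skipthoughts', 'type': 'BayesianUniSkip', 'dropout': 0.25,
             'fixed_emb': False, 'dir_st': 'data/skip-thoughts'},
    self_q_att=True,
    agg={'type': 'max'},
    classif={'mlp': {'input_dim': 2048, 'dimensions': [1024, 1024, 3000]}},
    wid_to_word=wid_to_word,
    word_to_wid=word_to_wid,
    aid_to_ans=aid_to_ans,
    ans_to_aid=ans_to_aid,
    fusion={'type': 'block', 'input_dims': [4800, 2048], 'output_dim': 2048,
            'mm_dim': 1000, 'chunks': 20, 'rank': 15,
            'dropout_input': 0., 'dropout_pre_lin': 0.},
    residual=False,
)

# RUBi wraps the baseline; mlp_q maps the 4800-d question embedding to the answers.
net = RUBiNet(
    model=base_net,
    output_size=len(aid_to_ans),
    classif={'input_dim': 4800, 'dimensions': [1024, 1024, 3000]},
)
```

The option values are taken verbatim from `rubi/options/vqa2/rubi.yaml` further below; note that `BaselineNet` shrinks the classifier's last dimension to `len(aid_to_ans)` when the two disagree, as its constructor warns.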
/rubi/models/networks/rubi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from block.models.networks.mlp import MLP
4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask,
5 | 
6 | 
7 | class RUBiNet(nn.Module):
8 |     """
9 |     Wraps another VQA model.
10 |     The wrapped model must return a dictionary containing the 'logits' key (predictions before softmax) and the 'q_emb' key (question embedding).
11 |     Returns:
12 |         - logits: the original predictions of the model
13 |         - logits_q: the predictions from the question-only branch
14 |         - logits_rubi: the model predictions masked by the sigmoid of the question-only predictions
15 |         => Use `logits_rubi` and `logits_q` for the loss
16 |     """
17 |     def __init__(self, model, output_size, classif, end_classif=True):
18 |         super().__init__()
19 |         self.net = model
20 |         self.c_1 = MLP(**classif)
21 |         self.end_classif = end_classif
22 |         if self.end_classif:
23 |             self.c_2 = nn.Linear(output_size, output_size)
24 | 
25 |     def forward(self, batch):
26 |         out = {}
27 |         # model prediction
28 |         net_out = self.net(batch)
29 |         logits = net_out['logits']
30 | 
31 |         q_embedding = net_out['q_emb']  # N * q_emb
32 |         q_embedding = grad_mul_const(q_embedding, 0.0)  # don't backpropagate through question encoder
33 |         q_pred = self.c_1(q_embedding)
34 |         fusion_pred = logits * torch.sigmoid(q_pred)
35 | 
36 |         if self.end_classif:
37 |             q_out = self.c_2(q_pred)
38 |         else:
39 |             q_out = q_pred
40 | 
41 |         out['logits'] = net_out['logits']
42 |         out['logits_rubi'] = fusion_pred
43 |         out['logits_q'] = q_out
44 |         return out
45 | 
46 |     def process_answers(self, out, key=''):
47 |         out = self.net.process_answers(out)
48 |         out = self.net.process_answers(out, key='_rubi')
49 |         out = self.net.process_answers(out, key='_q')
50 |         return out
51 | 
--------------------------------------------------------------------------------
/rubi/models/networks/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def mask_softmax(x, lengths):#, dim=1)
4 |     mask = torch.zeros_like(x).to(device=x.device, non_blocking=True)
5 |     t_lengths = lengths[:,:,None].expand_as(mask)
6 |     arange_id = torch.arange(mask.size(1)).to(device=x.device, non_blocking=True)
7 |     arange_id = arange_id[None,:,None].expand_as(mask)
8 | 
9 |     mask[arange_id<t_lengths] = 1
10 | 
11 |     # numerically stable softmax over the unmasked positions only
12 |     x2 = torch.exp(x - torch.max(x))
13 |     x3 = x2 * mask
14 |     epsilon = 1e-5
15 |     x3_sum = torch.sum(x3, dim=1, keepdim=True) + epsilon
16 |     x4 = x3 / x3_sum.expand_as(x3)
17 |     return x4
18 | 
19 | 
20 | class GradMulConst(torch.autograd.Function):
21 |     """
22 |     Multiplies the gradient by a constant in the backward pass; the forward
23 |     pass is the identity. With const=0 it blocks backpropagation through
24 |     its input (used on the question embedding in rubi.py).
25 |     """
26 |     @staticmethod
27 |     def forward(ctx, x, const):
28 |         ctx.const = const
29 |         return x.view_as(x)
30 | 
31 |     @staticmethod
32 |     def backward(ctx, grad_output):
33 |         return grad_output * ctx.const, None
34 | 
35 | def grad_mul_const(x, const):
36 |     return GradMulConst.apply(x, const)
37 | 
--------------------------------------------------------------------------------
/rubi/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdancette/rubi.bootstrap.pytorch/36180e51e4692424d6f6ed284a40a01c6fb97104/rubi/optimizers/__init__.py
--------------------------------------------------------------------------------
/rubi/optimizers/factory.py:
--------------------------------------------------------------------------------
[lines 1-24 of this file were lost during extraction; only the tail survives]
25 |         elif p.dim() >= 2:
26 |             nn.init.xavier_uniform_(p.data)
27 |         else:
28 |             raise ValueError(p.dim())
29 | 
30 |     return optimizer
31 | 
--------------------------------------------------------------------------------
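The four option files that follow configure the experiments. Their learning-rate fields are consumed by the optimizer factory whose body was mostly lost above; going by the inline `#torch.linspace` and `#range` hints in the yaml, a plausible reading of the schedule is sketched below. This is an assumption about the engine's behavior, not code from the repository.

```python
import torch

base_lr = 0.0003                  # optimizer.lr
warmup = [0.5, 2.0, 7.0]          # gradual_warmup_steps, read as linspace(start, end, steps)
decay_epochs = range(14, 24, 2)   # lr_decay_epochs, read as range(start, stop, step)
decay_rate = 0.25                 # lr_decay_rate

# Warmup multipliers: 0.5x to 2.0x of the base lr over the first 7 epochs.
warmup_factors = torch.linspace(warmup[0], warmup[1], int(warmup[2]))

def lr_at_epoch(epoch):
    """Hypothetical per-epoch lr under the reading above."""
    if epoch < len(warmup_factors):
        return base_lr * warmup_factors[epoch].item()
    lr = base_lr * warmup[1]      # hold the last warmup factor
    for e in decay_epochs:        # multiply by 0.25 at epochs 14, 16, ..., 22
        if epoch >= e:
            lr *= decay_rate
    return lr
```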
/rubi/options/vqa2/baseline.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 |   dir: logs/vqa2/baseline
3 |   resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 |   import: rubi.datasets.factory
6 |   name: vqa2 # or vqa2vg
7 |   dir: data/vqa/vqa2
8 |   train_split: train
9 |   eval_split: val # or test
10 |   proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 |   nb_threads: 4
12 |   batch_size: 256
13 |   nans: 3000
14 |   minwcount: 0
15 |   nlp: mcb
16 |   samplingans: True
17 |   dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 |   vg: false
19 | model:
20 |   name: default
21 |   network:
22 |     import: rubi.models.networks.factory
23 |     name: baseline
24 |     txt_enc:
25 |       name: skipthoughts
26 |       type: BayesianUniSkip
27 |       dropout: 0.25
28 |       fixed_emb: False
29 |       dir_st: data/skip-thoughts
30 |     self_q_att: True
31 |     residual: False
32 |     fusion:
33 |       type: block
34 |       input_dims: [4800, 2048]
35 |       output_dim: 2048
36 |       mm_dim: 1000
37 |       chunks: 20
38 |       rank: 15
39 |       dropout_input: 0.
40 |       dropout_pre_lin: 0.
41 |     agg:
42 |       type: max
43 |     classif:
44 |       mlp:
45 |         input_dim: 2048
46 |         dimensions: [1024,1024,3000]
47 |   criterion:
48 |     import: rubi.models.criterions.factory
49 |     name: vqa_cross_entropy
50 |   metric:
51 |     import: rubi.models.metrics.factory
52 |     name: vqa_accuracies
53 | optimizer:
54 |   import: rubi.optimizers.factory
55 |   name: Adam
56 |   lr: 0.0003
57 |   gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
58 |   lr_decay_epochs: [14, 24, 2] #range
59 |   lr_decay_rate: .25
60 | engine:
61 |   name: logger
62 |   debug: False
63 |   print_freq: 10
64 |   nb_epochs: 22
65 |   saving_criteria:
66 |     - eval_epoch.accuracy_top1:max
67 | misc:
68 |   logs_name:
69 |   cuda: True
70 |   seed: 1337
71 | view:
72 |   name: plotly
73 |   items:
74 |     - logs:train_epoch.loss+logs:eval_epoch.loss
75 |     - logs:train_epoch.accuracy_top1+logs:eval_epoch.accuracy_top1
76 |     - logs_train_oe:train_epoch.overall+logs_val_oe:eval_epoch.overall
77 | 
--------------------------------------------------------------------------------
/rubi/options/vqa2/rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 |   dir: logs/vqa2/rubi
3 |   resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 |   import: rubi.datasets.factory
6 |   name: vqa2 # or vqa2vg
7 |   dir: data/vqa/vqa2
8 |   train_split: train
9 |   eval_split: val # or test
10 |   proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 |   nb_threads: 4
12 |   batch_size: 256
13 |   nans: 3000
14 |   minwcount: 0
15 |   nlp: mcb
16 |   samplingans: True
17 |   dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 |   name: default
20 |   network:
21 |     import: rubi.models.networks.factory
22 |     name: rubi
23 |     rubi_params:
24 |       mlp_q:
25 |         input_dim: 4800
26 |         dimensions: [1024,1024,3000]
27 |     txt_enc:
28 |       name: skipthoughts
29 |       type: BayesianUniSkip
30 |       dropout: 0.25
31 |       fixed_emb: False
32 |       dir_st: data/skip-thoughts
33 |     self_q_att: True
34 |     residual: False
35 |     fusion:
36 |       type: block
37 |       input_dims: [4800, 2048]
38 |       output_dim: 2048
39 |       mm_dim: 1000
40 |       chunks: 20
41 |       rank: 15
42 |       dropout_input: 0.
43 |       dropout_pre_lin: 0.
44 | agg: 45 | type: max 46 | classif: 47 | mlp: 48 | input_dim: 2048 49 | dimensions: [1024,1024,3000] 50 | criterion: 51 | import: rubi.models.criterions.factory 52 | name: rubi_criterion 53 | question_loss_weight: 1.0 54 | metric: 55 | import: rubi.models.metrics.factory 56 | name: vqa_rubi_metrics 57 | optimizer: 58 | import: rubi.optimizers.factory 59 | name: Adam 60 | lr: 0.0003 61 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 62 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 63 | lr_decay_epochs: [14, 24, 2] #range 64 | lr_decay_rate: .25 65 | engine: 66 | name: logger 67 | debug: False 68 | print_freq: 10 69 | nb_epochs: 22 70 | saving_criteria: 71 | - eval_epoch.accuracy_top1:max 72 | - eval_epoch.accuracy_rubi_top1:max 73 | misc: 74 | logs_name: 75 | cuda: True 76 | seed: 1337 77 | view: 78 | name: plotly 79 | items: 80 | - logs:train_epoch.loss+logs:eval_epoch.loss 81 | - logs:train_epoch.loss_mm_q+logs:eval_epoch.loss_mm_q 82 | - logs:train_epoch.loss_q+logs:eval_epoch.loss_q 83 | - logs:train_epoch.reconstruction_loss+logs:eval_epoch.reconstruction_loss 84 | ###### 85 | - logs:train_epoch.rubi_loss+logs:eval_epoch.rubi_loss 86 | - logs:train_epoch.accuracy_top1+logs:eval_epoch.accuracy_top1 87 | - logs:train_epoch.accuracy_rubi_top1+logs:eval_epoch.accuracy_rubi_top1 88 | - logs:train_epoch.accuracy_q_top1+logs:eval_epoch.accuracy_q_top1 89 | - logs_train_oe:train_epoch.overall+logs_val_oe:eval_epoch.overall 90 | - logs_q_train_oe:train_epoch.overall+logs_q_val_oe:eval_epoch.overall 91 | - logs_rubi_train_oe:train_epoch.overall+logs_rubi_val_oe:eval_epoch.overall 92 | -------------------------------------------------------------------------------- /rubi/options/vqacp2/baseline.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/baseline 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: rubi.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | adversarial_method: 19 | model: 20 | name: default 21 | network: 22 | import: rubi.models.networks.factory 23 | name: baseline 24 | txt_enc: 25 | name: skipthoughts 26 | type: BayesianUniSkip 27 | dropout: 0.25 28 | fixed_emb: False 29 | dir_st: data/skip-thoughts 30 | self_q_att: True 31 | residual: False 32 | fusion: 33 | type: block 34 | input_dims: [4800, 2048] 35 | output_dim: 2048 36 | mm_dim: 1000 37 | chunks: 20 38 | rank: 15 39 | dropout_input: 0. 40 | dropout_pre_lin: 0. 
41 | agg: 42 | type: max 43 | classif: 44 | mlp: 45 | input_dim: 2048 46 | dimensions: [1024,1024,3000] 47 | criterion: 48 | import: rubi.models.criterions.factory 49 | name: vqa_cross_entropy 50 | metric: 51 | import: rubi.models.metrics.factory 52 | name: vqa_accuracies 53 | optimizer: 54 | import: rubi.optimizers.factory 55 | name: Adam 56 | lr: 0.0003 57 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 58 | lr_decay_epochs: [14, 24, 2] #range 59 | lr_decay_rate: .25 60 | engine: 61 | name: logger 62 | debug: False 63 | print_freq: 10 64 | nb_epochs: 22 65 | saving_criteria: 66 | - eval_epoch.accuracy_top1:max 67 | misc: 68 | logs_name: 69 | cuda: True 70 | seed: 1337 71 | view: 72 | name: plotly 73 | items: 74 | - logs:train_epoch.loss+logs:eval_epoch.loss 75 | - logs:train_epoch.accuracy_top1+logs:eval_epoch.accuracy_top1 76 | - logs_train_oe:train_epoch.overall+logs_val_oe:eval_epoch.overall 77 | -------------------------------------------------------------------------------- /rubi/options/vqacp2/rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/rubi 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: rubi.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: rubi.models.networks.factory 22 | name: rubi 23 | rubi_params: 24 | mlp_q: 25 | input_dim: 4800 26 | dimensions: [1024,1024,3000] 27 | txt_enc: 28 | name: skipthoughts 29 | type: BayesianUniSkip 30 | dropout: 0.25 31 | fixed_emb: False 32 | dir_st: data/skip-thoughts 33 | self_q_att: True 34 | residual: False 35 | fusion: 36 | type: block 37 | input_dims: [4800, 2048] 38 | output_dim: 2048 39 | mm_dim: 1000 40 | chunks: 20 41 | rank: 15 42 | dropout_input: 0. 43 | dropout_pre_lin: 0. 
44 | agg: 45 | type: max 46 | classif: 47 | mlp: 48 | input_dim: 2048 49 | dimensions: [1024,1024,3000] 50 | criterion: 51 | import: rubi.models.criterions.factory 52 | name: rubi_criterion 53 | question_loss_weight: 1.0 54 | metric: 55 | import: rubi.models.metrics.factory 56 | name: vqa_rubi_metrics 57 | optimizer: 58 | import: rubi.optimizers.factory 59 | name: Adam 60 | lr: 0.0003 61 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 62 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 63 | lr_decay_epochs: [14, 24, 2] #range 64 | lr_decay_rate: .25 65 | engine: 66 | name: logger 67 | debug: False 68 | print_freq: 10 69 | nb_epochs: 22 70 | saving_criteria: 71 | - eval_epoch.accuracy_top1:max 72 | - eval_epoch.accuracy_rubi_top1:max 73 | misc: 74 | logs_name: 75 | cuda: True 76 | seed: 1337 77 | view: 78 | name: plotly 79 | items: 80 | - logs:train_epoch.loss+logs:eval_epoch.loss 81 | - logs:train_epoch.loss_mm_q+logs:eval_epoch.loss_mm_q 82 | - logs:train_epoch.loss_q+logs:eval_epoch.loss_q 83 | ###### 84 | - logs:train_epoch.accuracy_top1+logs:eval_epoch.accuracy_top1 85 | - logs:train_epoch.accuracy_rubi_top1+logs:eval_epoch.accuracy_rubi_top1 86 | - logs:train_epoch.accuracy_q_top1+logs:eval_epoch.accuracy_q_top1 87 | - logs_train_oe:train_epoch.overall+logs_val_oe:eval_epoch.overall 88 | - logs_q_train_oe:train_epoch.overall+logs_q_val_oe:eval_epoch.overall 89 | - logs_rubi_train_oe:train_epoch.overall+logs_rubi_val_oe:eval_epoch.overall 90 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """A setuptools based setup module. 2 | 3 | See: 4 | https://packaging.python.org/en/latest/distributing.html 5 | https://github.com/pypa/sampleproject 6 | """ 7 | 8 | # Always prefer setuptools over distutils 9 | from setuptools import setup, find_packages 10 | # To use a consistent encoding 11 | from codecs import open 12 | from os import path 13 | 14 | here = path.abspath(path.dirname(__file__)) 15 | 16 | # Get the long description from the README file 17 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 18 | long_description = f.read() 19 | 20 | # Arguments marked as "Required" below must be included for upload to PyPI. 21 | # Fields marked as "Optional" may be commented out. 22 | 23 | # https://stackoverflow.com/questions/458550/standard-way-to-embed-version-into-python-package/16084844#16084844 24 | exec(open(path.join(here, 'rubi', '__version__.py')).read()) 25 | setup( 26 | # This is the name of your project. The first time you publish this 27 | # package, this name will be registered for you. 
It will determine how
28 |     # users can install this project, e.g.:
29 |     #
30 |     # $ pip install sampleproject
31 |     #
32 |     # And where it will live on PyPI: https://pypi.org/project/sampleproject/
33 |     #
34 |     # There are some restrictions on what makes a valid project name
35 |     # specification here:
36 |     # https://packaging.python.org/specifications/core-metadata/#name
37 |     name='rubi.bootstrap.pytorch', # Required
38 | 
39 |     # Versions should comply with PEP 440:
40 |     # https://www.python.org/dev/peps/pep-0440/
41 |     #
42 |     # For a discussion on single-sourcing the version across setup.py and the
43 |     # project code, see
44 |     # https://packaging.python.org/en/latest/single_source_version.html
45 |     version=__version__, # Required
46 | 
47 |     # This is a one-line description or tagline of what your project does. This
48 |     # corresponds to the "Summary" metadata field:
49 |     # https://packaging.python.org/specifications/core-metadata/#summary
50 |     description='RUBi: Reducing Unimodal Biases for Visual Question Answering', # Required
51 | 
52 |     # This is an optional longer description of your project that represents
53 |     # the body of text which users will see when they visit PyPI.
54 |     #
55 |     # Often, this is the same as your README, so you can just read it in from
56 |     # that file directly (as we have already done above)
57 |     #
58 |     # This field corresponds to the "Description" metadata field:
59 |     # https://packaging.python.org/specifications/core-metadata/#description-optional
60 |     long_description=long_description, # Optional
61 | 
62 |     # This should be a valid link to your project's main homepage.
63 |     #
64 |     # This field corresponds to the "Home-Page" metadata field:
65 |     # https://packaging.python.org/specifications/core-metadata/#home-page-optional
66 |     url='https://github.com/cdancette/rubi.bootstrap.pytorch', # Optional
67 | 
68 |     # This should be your name or the name of the organization which owns the
69 |     # project.
70 |     author='Remi Cadene, Corentin Dancette', # Optional
71 | 
72 |     # This should be a valid email address corresponding to the author listed
73 |     # above.
74 |     author_email='', # Optional
75 | 
76 |     # Classifiers help users find your project by categorizing it.
77 |     #
78 |     # For a list of valid classifiers, see
79 |     # https://pypi.python.org/pypi?%3Aaction=list_classifiers
80 |     classifiers=[ # Optional
81 |         # How mature is this project? Common values are
82 |         #   3 - Alpha
83 |         #   4 - Beta
84 |         #   5 - Production/Stable
85 |         'Development Status :: 3 - Alpha',
86 | 
87 |         # Indicate who your project is intended for
88 |         'Intended Audience :: Developers',
89 |         'Topic :: Software Development :: Build Tools',
90 | 
91 |         # Pick your license as you wish (this matches LICENSE.txt)
92 |         'License :: OSI Approved :: BSD License',
93 | 
94 |         # Specify the Python versions you support here. In particular, ensure
95 |         # that you indicate whether you support Python 2, Python 3 or both.
96 |         'Programming Language :: Python :: 3.7',
97 |     ],
98 | 
99 |     # This field adds keywords for your project which will appear on the
100 |     # project page. What does your project relate to?
101 |     #
102 |     # Note that this is a string of words separated by whitespace, not a list.
103 |     keywords='pytorch rubi vqa bias block murel tdiuc visual question answering visual relationship detection relation bootstrap deep learning neurips nips', # Optional
104 | 
105 |     # You can just specify package directories manually here if your project is
106 |     # simple. Or you can use find_packages().
107 | # 108 | # Alternatively, if you just want to distribute a single Python file, use 109 | # the `py_modules` argument instead as follows, which will expect a file 110 | # called `my_module.py` to exist: 111 | # 112 | # py_modules=["my_module"], 113 | # 114 | packages=find_packages(exclude=['data', 'logs', 'tests', 'bootstrap', 'block']), # Required 115 | 116 | # This field lists other packages that your project depends on to run. 117 | # Any package you put here will be installed by pip when your project is 118 | # installed, so they must be valid existing projects. 119 | # 120 | # For an analysis of "install_requires" vs pip's requirements files see: 121 | # https://packaging.python.org/en/latest/requirements.html 122 | install_requires=[ 123 | 'bootstrap.pytorch', 124 | 'skipthoughts', 125 | 'pretrainedmodels', 126 | 'opencv-python', 127 | 'block.bootstrap.pytorch', 128 | #'cfft' 129 | ], 130 | 131 | # pip install pytorch-fft does not install the last version 132 | # we need to add a dependency link 133 | # https://python-packaging.readthedocs.io/en/latest/dependencies.html 134 | # dependency_links=[ 135 | # 'https://github.com/locuslab/pytorch_fft' 136 | # ], 137 | 138 | # List additional groups of dependencies here (e.g. development 139 | # dependencies). Users will be able to install these using the "extras" 140 | # syntax, for example: 141 | # 142 | # $ pip install sampleproject[dev] 143 | # 144 | # Similar to `install_requires` above, these must be valid existing 145 | # projects. 146 | extras_require={ # Optional 147 | # 'dev': ['check-manifest'], 148 | 'test': ['pytest'], 149 | }, 150 | 151 | # If there are data files included in your packages that need to be 152 | # installed, specify them here. 153 | # 154 | # If using Python 2.6 or earlier, then these have to be included in 155 | # MANIFEST.in as well. 156 | # package_data={ # Optional 157 | # 'sample': ['package_data.dat'], 158 | # }, 159 | 160 | # Although 'package_data' is the preferred approach, in some case you may 161 | # need to place data files outside of your packages. See: 162 | # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files 163 | # 164 | # In this case, 'data_file' will be installed into '/my_data' 165 | # data_files=[('my_data', ['data/data_file'])], # Optional 166 | 167 | # To provide executable scripts, use entry points in preference to the 168 | # "scripts" keyword. Entry points provide cross-platform support and allow 169 | # `pip` to create the appropriate form of executable for the target 170 | # platform. 171 | # 172 | # For example, the following would provide a command called `sample` which 173 | # executes the function `main` from this package when invoked: 174 | # entry_points={ # Optional 175 | # 'console_scripts': [ 176 | # 'sample=sample:main', 177 | # ], 178 | # }, 179 | ) 180 | --------------------------------------------------------------------------------
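Taken together, the pieces above implement the RUBi training scheme: the base network produces `logits` and `q_emb`, `RUBiNet` masks the logits with the sigmoid of a question-only classifier, and the `rubi_criterion` referenced in the option files (its source is not included in this dump) combines a loss on `logits_rubi` with a question-only loss weighted by `question_loss_weight`. A minimal sketch of that computation in plain PyTorch follows; the stand-in `c_1`/`c_2` layers and the exact loss combination are assumptions based on `rubi.py` and the yaml options, not the repository's criterion code.

```python
import torch
import torch.nn.functional as F

batch_size, n_answers, q_dim = 4, 3000, 4800
question_loss_weight = 1.0                     # from the criterion block of rubi.yaml

logits = torch.randn(batch_size, n_answers, requires_grad=True)  # net_out['logits']
q_emb = torch.randn(batch_size, q_dim)                           # net_out['q_emb']
c_1 = torch.nn.Linear(q_dim, n_answers)        # stand-in for the mlp_q branch
c_2 = torch.nn.Linear(n_answers, n_answers)    # end classifier of the question branch
answers = torch.randint(0, n_answers, (batch_size,))

# detach() plays the role of grad_mul_const(q_emb, 0.0) here: the question-only
# branch trains, but no gradient reaches the question encoder.
q_pred = c_1(q_emb.detach())
logits_rubi = logits * torch.sigmoid(q_pred)   # the RUBi mask (RUBiNet.forward)
logits_q = c_2(q_pred)

# Assumed combination of the two cross-entropy losses:
loss = (F.cross_entropy(logits_rubi, answers)
        + question_loss_weight * F.cross_entropy(logits_q, answers))
loss.backward()
```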