├── .gitignore
├── LICENSE
├── README.md
├── cfvqa
│   ├── .gitignore
│   ├── README.md
│   ├── cfvqa
│   │   ├── __init__.py
│   │   ├── __version__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── factory.py
│   │   │   ├── scripts
│   │   │   │   ├── download_vqa2.sh
│   │   │   │   └── download_vqacp2.sh
│   │   │   ├── vqa2.py
│   │   │   ├── vqacp.py
│   │   │   └── vqacp2.py
│   │   ├── engines
│   │   │   ├── __init__.py
│   │   │   ├── engine.py
│   │   │   ├── factory.py
│   │   │   └── logger.py
│   │   ├── models
│   │   │   ├── criterions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cfvqa_criterion.py
│   │   │   │   ├── cfvqaintrod_criterion.py
│   │   │   │   ├── factory.py
│   │   │   │   ├── rubi_criterion.py
│   │   │   │   └── rubiintrod_criterion.py
│   │   │   ├── metrics
│   │   │   │   ├── __init__.py
│   │   │   │   ├── factory.py
│   │   │   │   ├── vqa_cfvqa_metrics.py
│   │   │   │   ├── vqa_cfvqaintrod_metrics.py
│   │   │   │   ├── vqa_cfvqasimple_metrics.py
│   │   │   │   ├── vqa_rubi_metrics.py
│   │   │   │   └── vqa_rubiintrod_metrics.py
│   │   │   └── networks
│   │   │       ├── __init__.py
│   │   │       ├── cfvqa.py
│   │   │       ├── cfvqaintrod.py
│   │   │       ├── factory.py
│   │   │       ├── rubi.py
│   │   │       ├── rubiintrod.py
│   │   │       ├── san_net.py
│   │   │       ├── smrl_net.py
│   │   │       ├── updn_net.py
│   │   │       └── utils.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   └── factory.py
│   │   ├── options
│   │   │   ├── vqa2
│   │   │   │   ├── smrl_baseline.yaml
│   │   │   │   ├── smrl_cfvqa_sum.yaml
│   │   │   │   ├── smrl_cfvqaintrod_sum.yaml
│   │   │   │   ├── smrl_cfvqasimple_rubi.yaml
│   │   │   │   ├── smrl_cfvqasimpleintrod_rubi.yaml
│   │   │   │   ├── smrl_rubi.yaml
│   │   │   │   └── smrl_rubiintrod.yaml
│   │   │   └── vqacp2
│   │   │       ├── smrl_baseline.yaml
│   │   │       ├── smrl_cfvqa_sum.yaml
│   │   │       ├── smrl_cfvqaintrod_sum.yaml
│   │   │       ├── smrl_cfvqasimple_rubi.yaml
│   │   │       ├── smrl_cfvqasimpleintrod_rubi.yaml
│   │   │       ├── smrl_rubi.yaml
│   │   │       └── smrl_rubiintrod.yaml
│   │   └── run.py
│   ├── engine.py
│   ├── requirements.txt
│   ├── run.py
│   ├── run_vqa2_cfvqa_introd.sh
│   └── scripts
│       ├── run_vqa2_cfvqa_introd.sh
│       ├── run_vqa2_rubi_introd.sh
│       ├── run_vqa2_rubicf_introd.sh
│       ├── run_vqacp2_rubi_introd.sh
│       └── run_vqacp2_rubicf_introd.sh
├── css
│   ├── .gitignore
│   ├── README.md
│   ├── attention.py
│   ├── base_model.py
│   ├── base_model_introd.py
│   ├── classifier.py
│   ├── dataset.py
│   ├── eval.py
│   ├── fc.py
│   ├── language_model.py
│   ├── main.py
│   ├── main_introd.py
│   ├── tools
│   │   ├── compute_softscore.py
│   │   ├── compute_softscore_val.py
│   │   ├── create_dictionary.py
│   │   ├── create_dictionary_v1.py
│   │   ├── download.sh
│   │   └── process.sh
│   ├── train.py
│   ├── train_introd.py
│   ├── util
│   │   ├── cpv2_notype_mask.json
│   │   ├── cpv2_type_mask.json
│   │   ├── qid2type_cpv1.json
│   │   ├── qid2type_cpv2.json
│   │   ├── qid2type_v2.json
│   │   ├── v2_notype_mask.json
│   │   └── v2_type_mask.json
│   ├── utils.py
│   └── vqa_debias_loss_functions.py
└── images
    ├── architecture.png
    └── introd.png
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introspective Distillation for Robust Question Answering 2 | 3 | This repository is the PyTorch implementation of our paper ["Introspective Distillation for Robust Question Answering"](https://arxiv.org/abs/2111.01026) in NeurIPS 2021. 4 | 5 | IntroD is proposed to achieve both high in-distribution (ID) and out-of-distribution (OOD) performance for question answering tasks like VQA and extractive QA. The key technical contribution is to blend the inductive bias of OOD and ID by introspecting whether a training sample fits in the factual ID world or the counterfactual OOD one. 6 | 7 | ![architecture](images/architecture.png) ![introd](images/introd.png)
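Concretely, the introspection amounts to a per-sample blending weight computed from how well the causal (OOD) teacher and the factual (ID) teacher each fit the ground-truth answer. Below is a minimal sketch of this weighting, written to mirror the shipped criterion (`cfvqa/cfvqa/models/criterions/cfvqaintrod_criterion.py`); the tensor names (`logits_ood`, `logits_id`, `logits_stu`) are illustrative placeholders:

```python
import torch.nn as nn
import torch.nn.functional as F

ce = nn.CrossEntropyLoss(reduction='none')  # per-sample cross-entropy

def introd_loss(logits_ood, logits_id, logits_stu, labels):
    # Knowledge distillation from the OOD (causal, e.g. CF-VQA) teacher.
    p_ood = F.softmax(logits_ood, dim=-1).detach()
    kd_loss = -(p_ood * F.log_softmax(logits_stu, dim=-1)).sum(dim=1)

    # Factual ID supervision: plain cross-entropy on the ground truth.
    cls_loss = ce(logits_stu, labels)

    # Introspection: a sample that the OOD teacher fits poorly relative to
    # the ID teacher puts more weight on the distillation term, and vice versa.
    loss_ood = ce(logits_ood, labels)
    loss_id = ce(logits_id, labels)
    weight = (loss_ood / (loss_ood + loss_id)).detach()

    return (weight * kd_loss).mean() + ((1 - weight) * cls_loss).mean()
```

Note that the weight is detached, so the introspection only gates the two loss terms and never backpropagates through the teachers.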
14 | 15 | 16 | If you find this paper and code helpful for your research, please consider citing our papers in your publications. 17 | ``` 18 | @inproceedings{niu2021introspective, 19 | title={Introspective Distillation for Robust Question Answering}, 20 | author={Niu, Yulei and Zhang, Hanwang}, 21 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, 22 | year={2021} 23 | } 24 | ``` 25 | ``` 26 | @inproceedings{niu2020counterfactual, 27 | title={Counterfactual VQA: A Cause-Effect Look at Language Bias}, 28 | author={Niu, Yulei and Tang, Kaihua and Zhang, Hanwang and Lu, Zhiwu and Hua, Xian-Sheng and Wen, Ji-Rong}, 29 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 30 | year={2021} 31 | } 32 | ``` -------------------------------------------------------------------------------- /cfvqa/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/ 132 | data 133 | logs/ 134 | logs 135 | -------------------------------------------------------------------------------- /cfvqa/README.md: -------------------------------------------------------------------------------- 1 | # CFVQA+IntroD 2 | 3 | This code is implemented as a fork of [CF-VQA][1] and [RUBi][2]. 4 | 5 | CF-VQA is proposed to capture and mitigate language bias in VQA from the view of causality. CF-VQA (1) captures the language bias as the direct causal effect of questions on answers, and (2) reduces the language bias by subtracting the direct language effect from the total causal effect. 6 | 7 | ## Summary 8 | 9 | * [Installation](#installation) 10 | * [Setup and dependencies](#1-setup-and-dependencies) 11 | * [Download datasets](#2-download-datasets) 12 | * [Overview of commands and result files](#overview-of-commands-and-result-files) 13 | * [Train a model](#train-a-model) 14 | * [Evaluate a model](#evaluate-a-model) 15 | * [Useful commands](#useful-commands) 16 | * [Acknowledgment](#acknowledgment) 17 | 18 | ## Installation 19 | 20 | 21 | ### 1. Setup and dependencies 22 | 23 | Install the Anaconda or Miniconda distribution (Python 3+) from its download site. 24 | 25 | ```bash 26 | conda create --name cfvqa python=3.7 27 | source activate cfvqa 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | If you encounter `ModuleNotFoundError: No module named 'block.external'`, please follow [this issue](https://github.com/yuleiniu/cfvqa/issues/7) for the solution. 32 | 33 | 34 | ### 2. Download datasets 35 | 36 | Download annotations, images and features for the VQA experiments: 37 | ```bash 38 | bash cfvqa/datasets/scripts/download_vqa2.sh 39 | bash cfvqa/datasets/scripts/download_vqacp2.sh 40 | ``` 41 | 42 | 43 | ## Overview of commands and result files 44 | 45 | ### Train a model 46 | 47 | The [bootstrap/run.py](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/run.py) file loads the options contained in a yaml file, creates the corresponding experiment directory and starts the training procedure. For instance, you can train our best model on VQA-CP v2 (CFVQA+SUM+SMRL) by running: 48 | ```bash 49 | python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_cfvqa_sum.yaml 50 | ``` 51 | Then, several files will be created in `logs/vqacp2/smrl_cfvqa_sum/`: 52 | - [options.yaml] (copy of options) 53 | - [logs.txt] (history of printed logs) 54 | - [logs.json] (batch and epoch statistics) 55 | - **[\_vq\_val\_oe.json] (statistics for the language-prior based strategy, e.g., RUBi)** 56 | - **[\_cfvqa\_val\_oe.json] (statistics for CF-VQA)** 57 | - [\_q\_val\_oe.json] (statistics for the language-only branch) 58 | - [\_v\_val\_oe.json] (statistics for the vision-only branch) 59 | - [\_all\_val\_oe.json] (statistics for the ensemble branch) 60 | - ckpt_last_engine.pth.tar (checkpoints of the last epoch) 61 | - ckpt_last_model.pth.tar 62 | - ckpt_last_optimizer.pth.tar 63 | 64 | Many options are available in the options directory. 
`cfvqa` represents the complete causal graph, while `cfvqasimple` represents the simplified causal graph. 65 | 66 | ### Evaluate a model 67 | 68 | There is no test set for VQA-CP v2, our main dataset, so the evaluation is done on the validation set. For a model trained on VQA v2, you can evaluate your model on the test set. In the example below, [bootstrap/run.py](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/run.py) loads the options from your experiment directory, resumes the last checkpoint and starts an evaluation on the split given by `--dataset.eval_split` while skipping the training set (`train_split` is empty). Thanks to `--misc.logs_name`, the logs will be written in the new `logs_test.txt` and `logs_test.json` files, instead of being appended to the `logs.txt` and `logs.json` files. 69 | ```bash 70 | python -m bootstrap.run \ 71 | -o ./logs/vqacp2/smrl_cfvqa_sum/options.yaml \ 72 | --exp.resume last \ 73 | --dataset.train_split '' \ 74 | --dataset.eval_split val \ 75 | --misc.logs_name test 76 | ``` 77 | 78 | ## Run IntroD 79 | 80 | We take CF-VQA+IntroD on VQA-CP v2 as an example. Simply run 81 | ```bash 82 | bash scripts/run_vqacp2_cfvqa_introd.sh 83 | ``` 84 | The statistics for the final student model are included in `./logs/vqacp2/smrl_cfvqaintrod_sum/_stu_val_oe.json`. 85 | 86 | ## Useful commands 87 | 88 | More useful commands can be found [here](https://github.com/yuleiniu/cfvqa#useful-commands). 89 | 90 | ## Acknowledgment 91 | 92 | Special thanks to the authors of [RUBi][2], [BLOCK][3], and [bootstrap.pytorch][4], and to the creators of the datasets used in this research project. 93 | 94 | [1]: https://github.com/yuleiniu/cfvqa 95 | [2]: https://github.com/cdancette/rubi.bootstrap.pytorch 96 | [3]: https://github.com/Cadene/block.bootstrap.pytorch 97 | [4]: https://github.com/Cadene/bootstrap.pytorch 98 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/__init__.py -------------------------------------------------------------------------------- /cfvqa/cfvqa/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.0' 2 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/datasets/__init__.py -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/factory.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.options import Options 2 | from block.datasets.tdiuc import TDIUC 3 | from block.datasets.vrd import VRD 4 | from block.datasets.vg import VG 5 | from block.datasets.vqa_utils import ListVQADatasets 6 | from .vqa2 import VQA2 7 | from .vqacp2 import VQACP2 8 | from .vqacp import VQACP 9 | 10 | def factory(engine=None): 11 | opt = Options()['dataset'] 12 | 13 | dataset = {} 14 | if opt.get('train_split', None): 15 | dataset['train'] = factory_split(opt['train_split']) 16 | if opt.get('eval_split', None): 17 | dataset['eval'] = factory_split(opt['eval_split']) 18 | 19 | return dataset 20 | 21 | def 
factory_split(split): 22 | opt = Options()['dataset'] 23 | shuffle = ('train' in split) 24 | 25 | if opt['name'] == 'vqacp2': 26 | assert(split in ['train', 'val', 'test']) 27 | samplingans = (opt['samplingans'] and split == 'train') 28 | 29 | dataset = VQACP2( 30 | dir_data=opt['dir'], 31 | split=split, 32 | batch_size=opt['batch_size'], 33 | nb_threads=opt['nb_threads'], 34 | pin_memory=Options()['misc']['cuda'], 35 | shuffle=shuffle, 36 | nans=opt['nans'], 37 | minwcount=opt['minwcount'], 38 | nlp=opt['nlp'], 39 | proc_split=opt['proc_split'], 40 | samplingans=samplingans, 41 | dir_rcnn=opt['dir_rcnn'], 42 | dir_cnn=opt.get('dir_cnn', None), 43 | dir_vgg16=opt.get('dir_vgg16', None), 44 | ) 45 | 46 | elif opt['name'] == 'vqacp': 47 | assert(split in ['train', 'val', 'test']) 48 | samplingans = (opt['samplingans'] and split == 'train') 49 | 50 | dataset = VQACP( 51 | dir_data=opt['dir'], 52 | split=split, 53 | batch_size=opt['batch_size'], 54 | nb_threads=opt['nb_threads'], 55 | pin_memory=Options()['misc']['cuda'], 56 | shuffle=shuffle, 57 | nans=opt['nans'], 58 | minwcount=opt['minwcount'], 59 | nlp=opt['nlp'], 60 | proc_split=opt['proc_split'], 61 | samplingans=samplingans, 62 | dir_rcnn=opt['dir_rcnn'], 63 | dir_cnn=opt.get('dir_cnn', None), 64 | dir_vgg16=opt.get('dir_vgg16', None), 65 | ) 66 | 67 | elif opt['name'] == 'vqacpv2-with-testdev': 68 | assert(split in ['train', 'val', 'test']) 69 | samplingans = (opt['samplingans'] and split == 'train') 70 | dataset = VQACP2( 71 | dir_data=opt['dir'], 72 | split=split, 73 | batch_size=opt['batch_size'], 74 | nb_threads=opt['nb_threads'], 75 | pin_memory=Options()['misc']['cuda'], 76 | shuffle=shuffle, 77 | nans=opt['nans'], 78 | minwcount=opt['minwcount'], 79 | nlp=opt['nlp'], 80 | proc_split=opt['proc_split'], 81 | samplingans=samplingans, 82 | dir_rcnn=opt['dir_rcnn'], 83 | dir_cnn=opt.get('dir_cnn', None), 84 | dir_vgg16=opt.get('dir_vgg16', None), 85 | has_testdevset=True, 86 | ) 87 | 88 | elif opt['name'] == 'vqa2': 89 | assert(split in ['train', 'val', 'test']) 90 | samplingans = (opt['samplingans'] and split == 'train') 91 | 92 | if opt['vg']: 93 | assert(opt['proc_split'] == 'trainval') 94 | 95 | # trainvalset 96 | vqa2 = VQA2( 97 | dir_data=opt['dir'], 98 | split='train', 99 | nans=opt['nans'], 100 | minwcount=opt['minwcount'], 101 | nlp=opt['nlp'], 102 | proc_split=opt['proc_split'], 103 | samplingans=samplingans, 104 | dir_rcnn=opt['dir_rcnn']) 105 | 106 | vg = VG( 107 | dir_data=opt['dir_vg'], 108 | split='train', 109 | nans=10000, 110 | minwcount=0, 111 | nlp=opt['nlp'], 112 | dir_rcnn=opt['dir_rcnn_vg']) 113 | 114 | vqa2vg = ListVQADatasets( 115 | [vqa2,vg], 116 | split='train', 117 | batch_size=opt['batch_size'], 118 | nb_threads=opt['nb_threads'], 119 | pin_memory=Options()['misc.cuda'], 120 | shuffle=shuffle) 121 | 122 | if split == 'train': 123 | dataset = vqa2vg 124 | else: 125 | dataset = VQA2( 126 | dir_data=opt['dir'], 127 | split=split, 128 | batch_size=opt['batch_size'], 129 | nb_threads=opt['nb_threads'], 130 | pin_memory=Options()['misc.cuda'], 131 | shuffle=False, 132 | nans=opt['nans'], 133 | minwcount=opt['minwcount'], 134 | nlp=opt['nlp'], 135 | proc_split=opt['proc_split'], 136 | samplingans=samplingans, 137 | dir_rcnn=opt['dir_rcnn']) 138 | dataset.sync_from(vqa2vg) 139 | 140 | else: 141 | dataset = VQA2( 142 | dir_data=opt['dir'], 143 | split=split, 144 | batch_size=opt['batch_size'], 145 | nb_threads=opt['nb_threads'], 146 | pin_memory=Options()['misc.cuda'], 147 | shuffle=shuffle, 148 | 
nans=opt['nans'], 149 | minwcount=opt['minwcount'], 150 | nlp=opt['nlp'], 151 | proc_split=opt['proc_split'], 152 | samplingans=samplingans, 153 | dir_rcnn=opt['dir_rcnn'], 154 | dir_cnn=opt.get('dir_cnn', None), 155 | ) 156 | 157 | return dataset 158 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/scripts/download_vqa2.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/vqa 2 | cd data/vqa 3 | wget http://data.lip6.fr/cadene/block/vqa2.tar.gz 4 | wget http://data.lip6.fr/cadene/block/coco.tar.gz 5 | tar -xzvf vqa2.tar.gz 6 | tar -xzvf coco.tar.gz 7 | 8 | mkdir -p coco/extract_rcnn  # still inside data/vqa after the cd above 9 | cd coco/extract_rcnn 10 | wget http://data.lip6.fr/cadene/block/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36.tar 11 | tar -xvf 2018-04-27_bottom-up-attention_fixed_36.tar 12 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/scripts/download_vqacp2.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/vqa 2 | cd data/vqa 3 | wget http://data.lip6.fr/cadene/murel/vqacp2.tar.gz 4 | tar -xzvf vqacp2.tar.gz 5 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/vqa2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import json 5 | import torch 6 | import numpy as np 7 | from os import path as osp 8 | from bootstrap.lib.logger import Logger 9 | from bootstrap.lib.options import Options 10 | from block.datasets.vqa_utils import AbstractVQA 11 | from copy import deepcopy 12 | import random 13 | import tqdm 14 | import h5py 15 | 16 | class VQA2(AbstractVQA): 17 | 18 | def __init__(self, 19 | dir_data='data/vqa2', 20 | split='train', 21 | batch_size=10, 22 | nb_threads=4, 23 | pin_memory=False, 24 | shuffle=False, 25 | nans=1000, 26 | minwcount=10, 27 | nlp='mcb', 28 | proc_split='train', 29 | samplingans=False, 30 | dir_rcnn='data/coco/extract_rcnn', 31 | adversarial=False, 32 | dir_cnn=None 33 | ): 34 | 35 | super(VQA2, self).__init__( 36 | dir_data=dir_data, 37 | split=split, 38 | batch_size=batch_size, 39 | nb_threads=nb_threads, 40 | pin_memory=pin_memory, 41 | shuffle=shuffle, 42 | nans=nans, 43 | minwcount=minwcount, 44 | nlp=nlp, 45 | proc_split=proc_split, 46 | samplingans=samplingans, 47 | has_valset=True, 48 | has_testset=True, 49 | has_answers_occurence=True, 50 | do_tokenize_answers=False) 51 | 52 | self.dir_rcnn = dir_rcnn 53 | self.dir_cnn = dir_cnn 54 | self.load_image_features() 55 | # to activate manually in visualization context (notebook) 56 | self.load_original_annotation = False 57 | 58 | def add_rcnn_to_item(self, item): 59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name'])) 60 | item_rcnn = torch.load(path_rcnn) 61 | item['visual'] = item_rcnn['pooled_feat'] 62 | item['coord'] = item_rcnn['rois'] 63 | item['norm_coord'] = item_rcnn.get('norm_rois', None) 64 | item['nb_regions'] = item['visual'].size(0) 65 | return item 66 | 67 | def add_cnn_to_item(self, item): 68 | image_name = item['image_name'] 69 | if image_name in self.image_names_to_index_train: 70 | index = self.image_names_to_index_train[image_name] 71 | image = torch.tensor(self.image_features_train['att'][index]) 72 | elif image_name in self.image_names_to_index_val: 73 | index = 
self.image_names_to_index_val[image_name] 74 | image = torch.tensor(self.image_features_val['att'][index]) 75 | image = image.permute(1, 2, 0).view(196, 2048) 76 | item['visual'] = image 77 | return item 78 | 79 | def load_image_features(self): 80 | if self.dir_cnn: 81 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5') 82 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5') 83 | Logger()(f"Opening file {filename_train}, {filename_val}") 84 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True) 85 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True) 86 | # load txt 87 | with open(os.path.join(self.dir_cnn, 'trainset.txt'), 'r') as f: 88 | self.image_names_to_index_train = {} 89 | for i, line in enumerate(f): 90 | self.image_names_to_index_train[line.strip()] = i 91 | with open(os.path.join(self.dir_cnn, 'valset.txt'), 'r') as f: 92 | self.image_names_to_index_val = {} 93 | for i, line in enumerate(f): 94 | self.image_names_to_index_val[line.strip()] = i 95 | 96 | def __getitem__(self, index): 97 | item = {} 98 | item['index'] = index 99 | 100 | # Process Question (word token) 101 | question = self.dataset['questions'][index] 102 | if self.load_original_annotation: 103 | item['original_question'] = question 104 | 105 | item['question_id'] = question['question_id'] 106 | 107 | item['question'] = torch.tensor(question['question_wids'], dtype=torch.long) 108 | item['lengths'] = torch.tensor([len(question['question_wids'])], dtype=torch.long) 109 | item['image_name'] = question['image_name'] 110 | 111 | # Process Object, Attribute and Relational features 112 | # (either Faster R-CNN region features or CNN grid features) 113 | if self.dir_rcnn: 114 | item = self.add_rcnn_to_item(item) 115 | elif self.dir_cnn: 116 | item = self.add_cnn_to_item(item) 117 | 118 | # Process Answer if exists 119 | if 'annotations' in self.dataset: 120 | annotation = self.dataset['annotations'][index] 121 | if self.load_original_annotation: 122 | item['original_annotation'] = annotation 123 | if 'train' in self.split and self.samplingans: 124 | proba = annotation['answers_count'] 125 | proba = proba / np.sum(proba) 126 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba)) 127 | else: 128 | item['answer_id'] = annotation['answer_id'] 129 | item['class_id'] = torch.tensor([item['answer_id']], dtype=torch.long) 130 | item['answer'] = annotation['answer'] 131 | item['question_type'] = annotation['question_type'] 132 | else: 133 | if item['question_id'] in self.is_qid_testdev: 134 | item['is_testdev'] = True 135 | else: 136 | item['is_testdev'] = False 137 | 138 | # if Options()['model.network.name'] == 'xmn_net': 139 | # num_feat = 36 140 | # relation_mask = np.zeros((num_feat, num_feat)) 141 | # boxes = item['coord'] 142 | # for i in range(num_feat): 143 | # for j in range(i+1, num_feat): 144 | # # if there is no overlap between two bounding box 145 | # if boxes[0,i]>boxes[2,j] or boxes[0,j]>boxes[2,i] or boxes[1,i]>boxes[3,j] or boxes[1,j]>boxes[3,i]: 146 | # pass 147 | # else: 148 | # relation_mask[i,j] = relation_mask[j,i] = 1 149 | # relation_mask = torch.from_numpy(relation_mask).byte() 150 | # item['relation_mask'] = relation_mask 151 | 152 | return item 153 | 154 | def download(self): 155 | dir_zip = osp.join(self.dir_raw, 'zip') 156 | os.system('mkdir -p '+dir_zip) 157 | dir_ann = osp.join(self.dir_raw, 'annotations') 158 | os.system('mkdir -p '+dir_ann) 159 | os.system('wget 
http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip -P '+dir_zip) 160 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip -P '+dir_zip) 161 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip -P '+dir_zip) 162 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P '+dir_zip) 163 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P '+dir_zip) 164 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Train_mscoco.zip')+' -d '+dir_ann) 165 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Val_mscoco.zip')+' -d '+dir_ann) 166 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Test_mscoco.zip')+' -d '+dir_ann) 167 | os.system('unzip '+osp.join(dir_zip, 'v2_Annotations_Train_mscoco.zip')+' -d '+dir_ann) 168 | os.system('unzip '+osp.join(dir_zip, 'v2_Annotations_Val_mscoco.zip')+' -d '+dir_ann) 169 | os.system('mv '+osp.join(dir_ann, 'v2_mscoco_train2014_annotations.json')+' ' 170 | +osp.join(dir_ann, 'mscoco_train2014_annotations.json')) 171 | os.system('mv '+osp.join(dir_ann, 'v2_mscoco_val2014_annotations.json')+' ' 172 | +osp.join(dir_ann, 'mscoco_val2014_annotations.json')) 173 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_train2014_questions.json')+' ' 174 | +osp.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')) 175 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_val2014_questions.json')+' ' 176 | +osp.join(dir_ann, 'OpenEnded_mscoco_val2014_questions.json')) 177 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_test2015_questions.json')+' ' 178 | +osp.join(dir_ann, 'OpenEnded_mscoco_test2015_questions.json')) 179 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_test-dev2015_questions.json')+' ' 180 | +osp.join(dir_ann, 'OpenEnded_mscoco_test-dev2015_questions.json')) 181 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/vqacp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import json 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | from os import path as osp 9 | from bootstrap.lib.logger import Logger 10 | from block.datasets.vqa_utils import AbstractVQA 11 | from copy import deepcopy 12 | import random 13 | import h5py 14 | 15 | class VQACP(AbstractVQA): 16 | 17 | def __init__(self, 18 | dir_data='data/vqa/vqacp2', 19 | split='train', 20 | batch_size=80, 21 | nb_threads=4, 22 | pin_memory=False, 23 | shuffle=False, 24 | nans=1000, 25 | minwcount=10, 26 | nlp='mcb', 27 | proc_split='train', 28 | samplingans=False, 29 | dir_rcnn='data/coco/extract_rcnn', 30 | dir_cnn=None, 31 | dir_vgg16=None, 32 | has_testdevset=False, 33 | ): 34 | super(VQACP, self).__init__( 35 | dir_data=dir_data, 36 | split=split, 37 | batch_size=batch_size, 38 | nb_threads=nb_threads, 39 | pin_memory=pin_memory, 40 | shuffle=shuffle, 41 | nans=nans, 42 | minwcount=minwcount, 43 | nlp=nlp, 44 | proc_split=proc_split, 45 | samplingans=samplingans, 46 | has_valset=True, 47 | has_testset=False, 48 | has_testdevset=has_testdevset, 49 | has_testset_anno=False, 50 | has_answers_occurence=True, 51 | do_tokenize_answers=False) 52 | self.dir_rcnn = dir_rcnn 53 | self.dir_cnn = dir_cnn 54 | self.dir_vgg16 = dir_vgg16 55 | self.load_image_features() 56 | self.load_original_annotation = False 57 | 58 | def add_rcnn_to_item(self, item): 59 | path_rcnn = os.path.join(self.dir_rcnn, 
'{}.pth'.format(item['image_name'])) 60 | item_rcnn = torch.load(path_rcnn) 61 | item['visual'] = item_rcnn['pooled_feat'] 62 | item['coord'] = item_rcnn['rois'] 63 | item['norm_coord'] = item_rcnn['norm_rois'] 64 | item['nb_regions'] = item['visual'].size(0) 65 | return item 66 | 67 | def load_image_features(self): 68 | if self.dir_cnn: 69 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5') 70 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5') 71 | Logger()(f"Opening file {filename_train}, {filename_val}") 72 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True) 73 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True) 74 | # load txt 75 | with open(os.path.join(self.dir_cnn, 'trainset.txt'), 'r') as f: 76 | self.image_names_to_index_train = {} 77 | for i, line in enumerate(f): 78 | self.image_names_to_index_train[line.strip()] = i 79 | with open(os.path.join(self.dir_cnn, 'valset.txt'), 'r') as f: 80 | self.image_names_to_index_val = {} 81 | for i, line in enumerate(f): 82 | self.image_names_to_index_val[line.strip()] = i 83 | elif self.dir_vgg16: 84 | # list filenames 85 | self.filenames_train = os.listdir(os.path.join(self.dir_vgg16, 'train')) 86 | self.filenames_val = os.listdir(os.path.join(self.dir_vgg16, 'val')) 87 | 88 | 89 | def add_vgg_to_item(self, item): 90 | image_name = item['image_name'] 91 | filename = image_name + '.pth' 92 | if filename in self.filenames_train: 93 | path = os.path.join(self.dir_vgg16, 'train', filename) 94 | elif filename in self.filenames_val: 95 | path = os.path.join(self.dir_vgg16, 'val', filename) 96 | visual = torch.load(path) 97 | visual = visual.permute(1, 2, 0).view(14*14, 512) 98 | item['visual'] = visual 99 | return item 100 | 101 | def add_cnn_to_item(self, item): 102 | image_name = item['image_name'] 103 | if image_name in self.image_names_to_index_train: 104 | index = self.image_names_to_index_train[image_name] 105 | image = torch.tensor(self.image_features_train['att'][index]) 106 | elif image_name in self.image_names_to_index_val: 107 | index = self.image_names_to_index_val[image_name] 108 | image = torch.tensor(self.image_features_val['att'][index]) 109 | image = image.permute(1, 2, 0).view(196, 2048) 110 | item['visual'] = image 111 | return item 112 | 113 | def __getitem__(self, index): 114 | item = {} 115 | item['index'] = index 116 | 117 | # Process Question (word token) 118 | question = self.dataset['questions'][index] 119 | if self.load_original_annotation: 120 | item['original_question'] = question 121 | item['question_id'] = question['question_id'] 122 | item['question'] = torch.LongTensor(question['question_wids']) 123 | item['lengths'] = torch.LongTensor([len(question['question_wids'])]) 124 | item['image_name'] = question['image_name'] 125 | 126 | # Process Object, Attribute and Relational features 127 | if self.dir_rcnn: 128 | item = self.add_rcnn_to_item(item) 129 | elif self.dir_cnn: 130 | item = self.add_cnn_to_item(item) 131 | elif self.dir_vgg16: 132 | item = self.add_vgg_to_item(item) 133 | 134 | # Process Answer if exists 135 | if 'annotations' in self.dataset: 136 | annotation = self.dataset['annotations'][index] 137 | if self.load_original_annotation: 138 | item['original_annotation'] = annotation 139 | if 'train' in self.split and self.samplingans: 140 | proba = annotation['answers_count'] 141 | proba = proba / np.sum(proba) 142 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba)) 143 | else: 144 | 
item['answer_id'] = annotation['answer_id'] 145 | item['class_id'] = torch.LongTensor([item['answer_id']]) 146 | item['answer'] = annotation['answer'] 147 | item['question_type'] = annotation['question_type'] 148 | 149 | return item 150 | 151 | def download(self): 152 | dir_ann = osp.join(self.dir_raw, 'annotations') 153 | os.system('mkdir -p '+dir_ann) 154 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json -P' + dir_ann) 155 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json -P' + dir_ann) 156 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json -P' + dir_ann) 157 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json -P' + dir_ann) 158 | train_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v1_train_questions.json")))} 159 | val_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v1_test_questions.json")))} 160 | train_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v1_train_annotations.json")))} 161 | val_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v1_test_annotations.json")))} 162 | train_q['info'] = {} 163 | train_q['data_type'] = 'mscoco' 164 | train_q['data_subtype'] = "train2014cp" 165 | train_q['task_type'] = "Open-Ended" 166 | train_q['license'] = {} 167 | val_q['info'] = {} 168 | val_q['data_type'] = 'mscoco' 169 | val_q['data_subtype'] = "val2014cp" 170 | val_q['task_type'] = "Open-Ended" 171 | val_q['license'] = {} 172 | for k in ["info", 'data_type','data_subtype', 'license']: 173 | train_ann[k] = train_q[k] 174 | val_ann[k] = val_q[k] 175 | with open(osp.join(dir_ann, "OpenEnded_mscoco_train2014_questions.json"), 'w') as F: 176 | F.write(json.dumps(train_q)) 177 | with open(osp.join(dir_ann, "OpenEnded_mscoco_val2014_questions.json"), 'w') as F: 178 | F.write(json.dumps(val_q)) 179 | with open(osp.join(dir_ann, "mscoco_train2014_annotations.json"), 'w') as F: 180 | F.write(json.dumps(train_ann)) 181 | with open(osp.join(dir_ann, "mscoco_val2014_annotations.json"), 'w') as F: 182 | F.write(json.dumps(val_ann)) 183 | 184 | def add_image_names(self, dataset): 185 | for q in dataset['questions']: 186 | q['image_name'] = 'COCO_%s_%012d.jpg'%(q['coco_split'],q['image_id']) 187 | return dataset 188 | 189 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/datasets/vqacp2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import json 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | from os import path as osp 9 | from bootstrap.lib.logger import Logger 10 | from block.datasets.vqa_utils import AbstractVQA 11 | from copy import deepcopy 12 | import random 13 | import h5py 14 | 15 | class VQACP2(AbstractVQA): 16 | 17 | def __init__(self, 18 | dir_data='data/vqa/vqacp2', 19 | split='train', 20 | batch_size=80, 21 | nb_threads=4, 22 | pin_memory=False, 23 | shuffle=False, 24 | nans=1000, 25 | minwcount=10, 26 | nlp='mcb', 27 | proc_split='train', 28 | samplingans=False, 29 | dir_rcnn='data/coco/extract_rcnn', 30 | dir_cnn=None, 31 | dir_vgg16=None, 32 | has_testdevset=False, 33 | ): 34 | super(VQACP2, self).__init__( 35 | dir_data=dir_data, 36 | split=split, 37 | batch_size=batch_size, 38 | nb_threads=nb_threads, 39 | pin_memory=pin_memory, 40 | shuffle=shuffle, 41 | nans=nans, 42 | minwcount=minwcount, 43 | nlp=nlp, 44 | proc_split=proc_split, 45 | 
samplingans=samplingans, 46 | has_valset=True, 47 | has_testset=False, 48 | has_testdevset=has_testdevset, 49 | has_testset_anno=False, 50 | has_answers_occurence=True, 51 | do_tokenize_answers=False) 52 | self.dir_rcnn = dir_rcnn 53 | self.dir_cnn = dir_cnn 54 | self.dir_vgg16 = dir_vgg16 55 | self.load_image_features() 56 | self.load_original_annotation = False 57 | 58 | def add_rcnn_to_item(self, item): 59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name'])) 60 | item_rcnn = torch.load(path_rcnn) 61 | item['visual'] = item_rcnn['pooled_feat'] 62 | item['coord'] = item_rcnn['rois'] 63 | item['norm_coord'] = item_rcnn['norm_rois'] 64 | item['nb_regions'] = item['visual'].size(0) 65 | return item 66 | 67 | def load_image_features(self): 68 | if self.dir_cnn: 69 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5') 70 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5') 71 | Logger()(f"Opening file {filename_train}, {filename_val}") 72 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True) 73 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True) 74 | # load txt 75 | with open(os.path.join(self.dir_cnn, 'trainset.txt'), 'r') as f: 76 | self.image_names_to_index_train = {} 77 | for i, line in enumerate(f): 78 | self.image_names_to_index_train[line.strip()] = i 79 | with open(os.path.join(self.dir_cnn, 'valset.txt'), 'r') as f: 80 | self.image_names_to_index_val = {} 81 | for i, line in enumerate(f): 82 | self.image_names_to_index_val[line.strip()] = i 83 | elif self.dir_vgg16: 84 | # list filenames 85 | self.filenames_train = os.listdir(os.path.join(self.dir_vgg16, 'train')) 86 | self.filenames_val = os.listdir(os.path.join(self.dir_vgg16, 'val')) 87 | 88 | 89 | def add_vgg_to_item(self, item): 90 | image_name = item['image_name'] 91 | filename = image_name + '.pth' 92 | if filename in self.filenames_train: 93 | path = os.path.join(self.dir_vgg16, 'train', filename) 94 | elif filename in self.filenames_val: 95 | path = os.path.join(self.dir_vgg16, 'val', filename) 96 | visual = torch.load(path) 97 | visual = visual.permute(1, 2, 0).view(14*14, 512) 98 | item['visual'] = visual 99 | return item 100 | 101 | def add_cnn_to_item(self, item): 102 | image_name = item['image_name'] 103 | if image_name in self.image_names_to_index_train: 104 | index = self.image_names_to_index_train[image_name] 105 | image = torch.tensor(self.image_features_train['att'][index]) 106 | elif image_name in self.image_names_to_index_val: 107 | index = self.image_names_to_index_val[image_name] 108 | image = torch.tensor(self.image_features_val['att'][index]) 109 | image = image.permute(1, 2, 0).view(196, 2048) 110 | item['visual'] = image 111 | return item 112 | 113 | def __getitem__(self, index): 114 | item = {} 115 | item['index'] = index 116 | 117 | # Process Question (word token) 118 | question = self.dataset['questions'][index] 119 | if self.load_original_annotation: 120 | item['original_question'] = question 121 | item['question_id'] = question['question_id'] 122 | item['question'] = torch.LongTensor(question['question_wids']) 123 | item['lengths'] = torch.LongTensor([len(question['question_wids'])]) 124 | item['image_name'] = question['image_name'] 125 | 126 | # Process Object, Attribute and Relational features 127 | if self.dir_rcnn: 128 | item = self.add_rcnn_to_item(item) 129 | elif self.dir_cnn: 130 | item = self.add_cnn_to_item(item) 131 | elif self.dir_vgg16: 132 | item = self.add_vgg_to_item(item) 133 
| 134 | # Process Answer if exists 135 | if 'annotations' in self.dataset: 136 | annotation = self.dataset['annotations'][index] 137 | if self.load_original_annotation: 138 | item['original_annotation'] = annotation 139 | if 'train' in self.split and self.samplingans: 140 | proba = annotation['answers_count'] 141 | proba = proba / np.sum(proba) 142 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba)) 143 | else: 144 | item['answer_id'] = annotation['answer_id'] 145 | item['class_id'] = torch.LongTensor([item['answer_id']]) 146 | item['answer'] = annotation['answer'] 147 | item['question_type'] = annotation['question_type'] 148 | 149 | return item 150 | 151 | def download(self): 152 | dir_ann = osp.join(self.dir_raw, 'annotations') 153 | os.system('mkdir -p '+dir_ann) 154 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json -P' + dir_ann) 155 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json -P' + dir_ann) 156 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json -P' + dir_ann) 157 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json -P' + dir_ann) 158 | train_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v2_train_questions.json")))} 159 | val_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v2_test_questions.json")))} 160 | train_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v2_train_annotations.json")))} 161 | val_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v2_test_annotations.json")))} 162 | train_q['info'] = {} 163 | train_q['data_type'] = 'mscoco' 164 | train_q['data_subtype'] = "train2014cp" 165 | train_q['task_type'] = "Open-Ended" 166 | train_q['license'] = {} 167 | val_q['info'] = {} 168 | val_q['data_type'] = 'mscoco' 169 | val_q['data_subtype'] = "val2014cp" 170 | val_q['task_type'] = "Open-Ended" 171 | val_q['license'] = {} 172 | for k in ["info", 'data_type','data_subtype', 'license']: 173 | train_ann[k] = train_q[k] 174 | val_ann[k] = val_q[k] 175 | with open(osp.join(dir_ann, "OpenEnded_mscoco_train2014_questions.json"), 'w') as F: 176 | F.write(json.dumps(train_q)) 177 | with open(osp.join(dir_ann, "OpenEnded_mscoco_val2014_questions.json"), 'w') as F: 178 | F.write(json.dumps(val_q)) 179 | with open(osp.join(dir_ann, "mscoco_train2014_annotations.json"), 'w') as F: 180 | F.write(json.dumps(train_ann)) 181 | with open(osp.join(dir_ann, "mscoco_val2014_annotations.json"), 'w') as F: 182 | F.write(json.dumps(val_ann)) 183 | 184 | def add_image_names(self, dataset): 185 | for q in dataset['questions']: 186 | q['image_name'] = 'COCO_%s_%012d.jpg'%(q['coco_split'],q['image_id']) 187 | return dataset 188 | 189 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import factory -------------------------------------------------------------------------------- /cfvqa/cfvqa/engines/factory.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from bootstrap.lib.options import Options 3 | from bootstrap.lib.logger import Logger 4 | from .engine import Engine 5 | from .logger import LoggerEngine 6 | 7 | 8 | def factory(): 9 | Logger()('Creating engine...') 10 | 11 | if Options()['engine'].get('import', False): 12 | # import usually is 
"yourmodule.engine.factory" 13 | module = importlib.import_module(Options()['engine']['import']) 14 | engine = module.factory() 15 | 16 | elif Options()['engine']['name'] == 'default': 17 | engine = Engine() 18 | 19 | elif Options()['engine']['name'] == 'logger': 20 | engine = LoggerEngine() 21 | 22 | else: 23 | raise ValueError 24 | 25 | return engine 26 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/engines/logger.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.logger import Logger 2 | from .engine import Engine 3 | 4 | 5 | class LoggerEngine(Engine): 6 | """ LoggerEngine is similar to Engine. The only difference is a more powerful is_best method. 7 | It is able to look into the logger dictionary that contains the list of all the logged variables 8 | indexed by name. 9 | 10 | Example usage: 11 | 12 | .. code-block:: python 13 | 14 | out = { 15 | 'loss': 0.2, 16 | 'acctop1': 87.02 17 | } 18 | engine.is_best(out, 'loss:min') 19 | 20 | # Logger().values['eval_epoch.recall_at_1'] contains a list 21 | # of all the recall at 1 values for each evaluation epoch 22 | engine.is_best(out, 'eval_epoch.recall_at_1') 23 | """ 24 | 25 | def __init__(self): 26 | super(LoggerEngine, self).__init__() 27 | 28 | def is_best(self, out, saving_criteria): 29 | if ':min' in saving_criteria: 30 | name = saving_criteria.replace(':min', '') 31 | order = '<' 32 | elif ':max' in saving_criteria: 33 | name = saving_criteria.replace(':max', '') 34 | order = '>' 35 | else: 36 | error_msg = """'--engine.saving_criteria' named '{}' does not specify order, 37 | you need to chose between '{}' or '{}' to specify if the criteria needs to be minimize or maximize""".format( 38 | saving_criteria, saving_criteria + ':min', saving_criteria + ':max') 39 | raise ValueError(error_msg) 40 | 41 | if name in out: 42 | new_value = out[name] 43 | elif name in Logger().values: 44 | new_value = Logger().values[name][-1] 45 | else: 46 | raise ValueError("name '{}' not in outputs '{}' and not in logger '{}'".format( 47 | name, list(out.keys()), list(Logger().values.keys()))) 48 | 49 | if name not in self.best_out: 50 | self.best_out[name] = new_value 51 | else: 52 | if eval('{} {} {}'.format(new_value, order, self.best_out[name])): 53 | self.best_out[name] = new_value 54 | return True 55 | 56 | return False 57 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/criterions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/models/criterions/__init__.py -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/criterions/cfvqa_criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from bootstrap.lib.logger import Logger 5 | from bootstrap.lib.options import Options 6 | 7 | class CFVQACriterion(nn.Module): 8 | 9 | def __init__(self, question_loss_weight=1.0, vision_loss_weight=1.0, is_va=True): 10 | super().__init__() 11 | self.is_va = is_va 12 | 13 | Logger()(f'CFVQACriterion, with question_loss_weight = ({question_loss_weight})') 14 | if self.is_va: 15 | Logger()(f'CFVQACriterion, with vision_loss_weight = ({vision_loss_weight})') 16 | 17 | self.fusion_loss = 
nn.CrossEntropyLoss() 18 | self.question_loss = nn.CrossEntropyLoss() 19 | self.question_loss_weight = question_loss_weight 20 | if self.is_va: 21 | self.vision_loss = nn.CrossEntropyLoss() 22 | self.vision_loss_weight = vision_loss_weight 23 | 24 | def forward(self, net_out, batch): 25 | out = {} 26 | class_id = batch['class_id'].squeeze(1) 27 | 28 | logits_rubi = net_out['logits_all'] 29 | fusion_loss = self.fusion_loss(logits_rubi, class_id) 30 | 31 | logits_q = net_out['logits_q'] 32 | question_loss = self.question_loss(logits_q, class_id) 33 | 34 | if self.is_va: 35 | logits_v = net_out['logits_v'] 36 | vision_loss = self.vision_loss(logits_v, class_id) 37 | 38 | nde = net_out['z_nde'] 39 | p_te = torch.nn.functional.softmax(logits_rubi, -1).clone().detach() 40 | p_nde = torch.nn.functional.softmax(nde, -1) 41 | kl_loss = - p_te*p_nde.log() 42 | kl_loss = kl_loss.sum(1).mean() 43 | 44 | loss = fusion_loss \ 45 | + self.question_loss_weight * question_loss \ 46 | + kl_loss 47 | if self.is_va: 48 | loss += self.vision_loss_weight * vision_loss 49 | 50 | out['loss'] = loss 51 | out['loss_mm_q'] = fusion_loss 52 | out['loss_q'] = question_loss 53 | if self.is_va: 54 | out['loss_v'] = vision_loss 55 | return out 56 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/criterions/cfvqaintrod_criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from bootstrap.lib.logger import Logger 5 | from bootstrap.lib.options import Options 6 | 7 | class CFVQAIntroDCriterion(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.cls_loss = nn.CrossEntropyLoss(reduction='none') 13 | 14 | def forward(self, net_out, batch): 15 | out = {} 16 | logits_all = net_out['logits_all'] 17 | class_id = batch['class_id'].squeeze(1) 18 | 19 | # KD 20 | logits_t = net_out['logits_cfvqa'] 21 | logits_s = net_out['logits_stu'] 22 | p_t = torch.nn.functional.softmax(logits_t, -1).clone().detach() 23 | kd_loss = - p_t*F.log_softmax(logits_s, -1) 24 | kd_loss = kd_loss.sum(1) 25 | 26 | cls_loss = self.cls_loss(logits_s, class_id) 27 | 28 | # weight estimation 29 | cls_loss_ood = self.cls_loss(logits_t, class_id) 30 | cls_loss_id = self.cls_loss(logits_all, class_id) 31 | weight = cls_loss_ood/(cls_loss_ood+cls_loss_id) 32 | weight = weight.detach() 33 | 34 | loss = (weight*kd_loss).mean() + ((1-weight)*cls_loss).mean() 35 | 36 | out['loss'] = loss 37 | return out 38 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/criterions/factory.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.options import Options 2 | from block.models.criterions.vqa_cross_entropy import VQACrossEntropyLoss 3 | from .rubi_criterion import RUBiCriterion 4 | from .cfvqa_criterion import CFVQACriterion 5 | from .cfvqaintrod_criterion import CFVQAIntroDCriterion 6 | from .rubiintrod_criterion import RUBiIntroDCriterion 7 | 8 | def factory(engine, mode): 9 | name = Options()['model.criterion.name'] 10 | split = engine.dataset[mode].split 11 | eval_only = 'train' not in engine.dataset 12 | 13 | opt = Options()['model.criterion'] 14 | if split == "test" and 'tdiuc' not in Options()['dataset.name']: 15 | return None 16 | if name == 'vqa_cross_entropy': 17 | criterion = VQACrossEntropyLoss() 18 | elif name == "rubi_criterion": 19 | criterion = 
RUBiCriterion( 20 | question_loss_weight=opt['question_loss_weight'] 21 | ) 22 | elif name == "cfvqa_criterion": 23 | criterion = CFVQACriterion( 24 | question_loss_weight=opt['question_loss_weight'], 25 | vision_loss_weight=opt['vision_loss_weight'], 26 | is_va=True, 27 | ) 28 | elif name == "cfvqasimple_criterion": 29 | criterion = CFVQACriterion( 30 | question_loss_weight=opt['question_loss_weight'], 31 | is_va=False, 32 | ) 33 | elif name == "cfvqaintrod_criterion": 34 | criterion = CFVQAIntroDCriterion() 35 | elif name == "rubiintrod_criterion": 36 | criterion = RUBiIntroDCriterion() 37 | else: 38 | raise ValueError(name) 39 | return criterion 40 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/criterions/rubi_criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from bootstrap.lib.logger import Logger 5 | from bootstrap.lib.options import Options 6 | 7 | class RUBiCriterion(nn.Module): 8 | 9 | def __init__(self, question_loss_weight=1.0): 10 | super().__init__() 11 | 12 | Logger()(f'RUBiCriterion, with question_loss_weight = ({question_loss_weight})') 13 | 14 | self.question_loss_weight = question_loss_weight 15 | self.fusion_loss = nn.CrossEntropyLoss() 16 | self.question_loss = nn.CrossEntropyLoss() 17 | 18 | def forward(self, net_out, batch): 19 | out = {} 20 | # logits = net_out['logits'] 21 | logits_q = net_out['logits_q'] 22 | logits_rubi = net_out['logits_all'] 23 | class_id = batch['class_id'].squeeze(1) 24 | fusion_loss = self.fusion_loss(logits_rubi, class_id) 25 | question_loss = self.question_loss(logits_q, class_id) 26 | loss = fusion_loss + self.question_loss_weight * question_loss 27 | 28 | out['loss'] = loss 29 | out['loss_mm_q'] = fusion_loss 30 | out['loss_q'] = question_loss 31 | return out 32 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/criterions/rubiintrod_criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from bootstrap.lib.logger import Logger 5 | from bootstrap.lib.options import Options 6 | 7 | class RUBiIntroDCriterion(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.cls_loss = nn.CrossEntropyLoss(reduction='none') 13 | 14 | def forward(self, net_out, batch): 15 | out = {} 16 | logits_all = net_out['logits_all'] 17 | class_id = batch['class_id'].squeeze(1) 18 | 19 | # KD 20 | logits_t = net_out['logits'] 21 | logits_s = net_out['logits_stu'] 22 | p_t = torch.nn.functional.softmax(logits_t, -1).clone().detach() 23 | kd_loss = - p_t*F.log_softmax(logits_s, -1) 24 | kd_loss = kd_loss.sum(1) 25 | 26 | cls_loss = self.cls_loss(logits_s, class_id) 27 | 28 | # weight estimation 29 | cls_loss_ood = self.cls_loss(logits_t, class_id) 30 | cls_loss_id = self.cls_loss(logits_all, class_id) 31 | weight = cls_loss_ood/(cls_loss_ood+cls_loss_id) 32 | weight = torch.round(weight) 33 | weight = weight.detach() 34 | 35 | loss = (weight*kd_loss).mean() + ((1-weight)*cls_loss).mean() 36 | 37 | out['loss'] = loss 38 | return out 39 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/metrics/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/models/metrics/__init__.py -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/metrics/factory.py: -------------------------------------------------------------------------------- 1 | from bootstrap.lib.options import Options 2 | from block.models.metrics.vqa_accuracies import VQAAccuracies 3 | from .vqa_rubi_metrics import VQARUBiMetrics 4 | from .vqa_cfvqa_metrics import VQACFVQAMetrics 5 | from .vqa_cfvqasimple_metrics import VQACFVQASimpleMetrics 6 | from .vqa_cfvqaintrod_metrics import VQACFVQAIntroDMetrics 7 | from .vqa_rubiintrod_metrics import VQARUBiIntroDMetrics 8 | 9 | def factory(engine, mode): 10 | name = Options()['model.metric.name'] 11 | metric = None 12 | 13 | if name == 'vqa_accuracies': 14 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 15 | if mode == 'train': 16 | split = engine.dataset['train'].split 17 | if split == 'train': 18 | metric = VQAAccuracies(engine, 19 | mode='train', 20 | open_ended=open_ended, 21 | tdiuc=True, 22 | dir_exp=Options()['exp.dir'], 23 | dir_vqa=Options()['dataset.dir']) 24 | elif split == 'trainval': 25 | metric = None 26 | else: 27 | raise ValueError(split) 28 | elif mode == 'eval': 29 | metric = VQAAccuracies(engine, 30 | mode='eval', 31 | open_ended=open_ended, 32 | tdiuc=('tdiuc' in Options()['dataset.name'] or Options()['dataset.eval_split'] != 'test'), 33 | dir_exp=Options()['exp.dir'], 34 | dir_vqa=Options()['dataset.dir']) 35 | else: 36 | metric = None 37 | 38 | elif name == "vqa_rubi_metrics": 39 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 40 | metric = VQARUBiMetrics(engine, 41 | mode=mode, 42 | open_ended=open_ended, 43 | tdiuc=True, 44 | dir_exp=Options()['exp.dir'], 45 | dir_vqa=Options()['dataset.dir'] 46 | ) 47 | 48 | elif name == "vqa_cfvqa_metrics": 49 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 50 | metric = VQACFVQAMetrics(engine, 51 | mode=mode, 52 | open_ended=open_ended, 53 | tdiuc=True, 54 | dir_exp=Options()['exp.dir'], 55 | dir_vqa=Options()['dataset.dir'] 56 | ) 57 | 58 | elif name == "vqa_cfvqasimple_metrics": 59 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 60 | metric = VQACFVQASimpleMetrics(engine, 61 | mode=mode, 62 | open_ended=open_ended, 63 | tdiuc=True, 64 | dir_exp=Options()['exp.dir'], 65 | dir_vqa=Options()['dataset.dir'] 66 | ) 67 | 68 | elif name == "vqa_cfvqaintrod_metrics": 69 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 70 | metric = VQACFVQAIntroDMetrics(engine, 71 | mode=mode, 72 | open_ended=open_ended, 73 | tdiuc=True, 74 | dir_exp=Options()['exp.dir'], 75 | dir_vqa=Options()['dataset.dir'] 76 | ) 77 | 78 | elif name == "vqa_rubiintrod_metrics": 79 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name']) 80 | metric = VQARUBiIntroDMetrics(engine, 81 | mode=mode, 82 | open_ended=open_ended, 83 | tdiuc=True, 84 | dir_exp=Options()['exp.dir'], 85 | dir_vqa=Options()['dataset.dir'] 86 | ) 87 | 88 | else: 89 | raise ValueError(name) 90 | return metric 91 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/models/networks/__init__.py -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/cfvqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from block.models.networks.mlp import MLP 4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask, 5 | 6 | eps = 1e-12 7 | 8 | class CFVQA(nn.Module): 9 | """ 10 | Wraps another model 11 | The original model must return a dictionnary containing the 'logits' key (predictions before softmax) 12 | Returns: 13 | - logits_vq: the original predictions of the model, i.e., NIE 14 | - logits_q: the predictions from the question-only branch 15 | - logits_v: the predictions from the vision-only branch 16 | - logits_all: the predictions from the ensemble model 17 | - logits_cfvqa: the predictions based on CF-VQA, i.e., TIE 18 | => Use `logits_all`, `logits_q` and `logits_v` for the loss 19 | """ 20 | def __init__(self, model, output_size, classif_q, classif_v, fusion_mode, end_classif=True, is_va=True): 21 | super().__init__() 22 | self.net = model 23 | self.end_classif = end_classif 24 | 25 | assert fusion_mode in ['rubi', 'hm', 'sum'], "Fusion mode should be rubi/hm/sum." 26 | self.fusion_mode = fusion_mode 27 | self.is_va = is_va and (not fusion_mode=='rubi') # RUBi does not consider V->A 28 | 29 | # Q->A branch 30 | self.q_1 = MLP(**classif_q) 31 | if self.end_classif: # default: True (following RUBi) 32 | self.q_2 = nn.Linear(output_size, output_size) 33 | 34 | # V->A branch 35 | if self.is_va: # default: True (containing V->A) 36 | self.v_1 = MLP(**classif_v) 37 | if self.end_classif: # default: True (following RUBi) 38 | self.v_2 = nn.Linear(output_size, output_size) 39 | 40 | self.constant = nn.Parameter(torch.tensor(0.0)) 41 | 42 | def forward(self, batch): 43 | out = {} 44 | # model prediction 45 | net_out = self.net(batch) 46 | logits = net_out['logits'] 47 | 48 | # Q->A branch 49 | q_embedding = net_out['q_emb'] # N * q_emb 50 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate 51 | q_pred = self.q_1(q_embedding) 52 | 53 | # V->A branch 54 | if self.is_va: 55 | v_embedding = net_out['v_emb'] # N * v_emb 56 | v_embedding = grad_mul_const(v_embedding, 0.0) # don't backpropagate 57 | v_pred = self.v_1(v_embedding) 58 | else: 59 | v_pred = None 60 | 61 | # both q, k and v are the facts 62 | z_qkv = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=True, v_fact=True) # te 63 | # q is the fact while k and v are the counterfactuals 64 | z_q = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=False, v_fact=False) # nie 65 | 66 | logits_cfvqa = z_qkv - z_q 67 | 68 | if self.end_classif: 69 | q_out = self.q_2(q_pred) 70 | if self.is_va: 71 | v_out = self.v_2(v_pred) 72 | else: 73 | q_out = q_pred 74 | if self.is_va: 75 | v_out = v_pred 76 | 77 | out['logits_all'] = z_qkv # for optimization 78 | out['logits_vq'] = logits # predictions of the original VQ branch, i.e., NIE 79 | out['logits_cfvqa'] = logits_cfvqa # predictions of CFVQA, i.e., TIE 80 | out['logits_q'] = q_out # for optimization 81 | if self.is_va: 82 | out['logits_v'] = v_out # for optimization 83 | 84 | if self.is_va: 85 | out['z_nde'] = self.fusion(logits.clone().detach(), q_pred.clone().detach(), v_pred.clone().detach(), 
q_fact=True, k_fact=False, v_fact=False) # tie 86 | else: 87 | out['z_nde'] = self.fusion(logits.clone().detach(), q_pred.clone().detach(), None, q_fact=True, k_fact=False, v_fact=False) # tie 88 | 89 | return out 90 | 91 | def process_answers(self, out, key=''): 92 | out = self.net.process_answers(out, key='_all') 93 | out = self.net.process_answers(out, key='_vq') 94 | out = self.net.process_answers(out, key='_cfvqa') 95 | out = self.net.process_answers(out, key='_q') 96 | if self.is_va: 97 | out = self.net.process_answers(out, key='_v') 98 | return out 99 | 100 | def fusion(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False): 101 | 102 | z_k, z_q, z_v = self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact) 103 | 104 | if self.fusion_mode == 'rubi': 105 | z = z_k * torch.sigmoid(z_q) 106 | 107 | elif self.fusion_mode == 'hm': 108 | if self.is_va: 109 | z = z_k * z_q * z_v 110 | else: 111 | z = z_k * z_q 112 | z = torch.log(z + eps) - torch.log1p(z) 113 | 114 | elif self.fusion_mode == 'sum': 115 | if self.is_va: 116 | z = z_k + z_q + z_v 117 | else: 118 | z = z_k + z_q 119 | z = torch.log(torch.sigmoid(z) + eps) 120 | 121 | return z 122 | 123 | def transform(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False): 124 | 125 | if not k_fact: 126 | z_k = self.constant * torch.ones_like(z_k).cuda() 127 | 128 | if not q_fact: 129 | z_q = self.constant * torch.ones_like(z_q).cuda() 130 | 131 | if self.is_va: 132 | if not v_fact: 133 | z_v = self.constant * torch.ones_like(z_v).cuda() 134 | 135 | if self.fusion_mode == 'hm': 136 | z_k = torch.sigmoid(z_k) 137 | z_q = torch.sigmoid(z_q) 138 | if self.is_va: 139 | z_v = torch.sigmoid(z_v) 140 | 141 | return z_k, z_q, z_v -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/cfvqaintrod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from block.models.networks.mlp import MLP 4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask, 5 | 6 | eps = 1e-12 7 | 8 | class CFVQAIntroD(nn.Module): 9 | """ 10 | Wraps another model 11 | The original model must return a dictionnary containing the 'logits' key (predictions before softmax) 12 | Returns: 13 | - logits_vq: the original predictions of the model, i.e., NIE 14 | - logits_q: the predictions from the question-only branch 15 | - logits_v: the predictions from the vision-only branch 16 | - logits_all: the predictions from the ensemble model 17 | - logits_cfvqa: the predictions based on CF-VQA, i.e., TIE 18 | => Use `logits_all`, `logits_q` and `logits_v` for the loss 19 | """ 20 | def __init__(self, model, model_teacher, output_size, classif_q, classif_v, fusion_mode, end_classif=True, is_va=True): 21 | super().__init__() 22 | self.net_student = model 23 | self.net = model_teacher 24 | self.end_classif = end_classif 25 | 26 | assert fusion_mode in ['rubi', 'hm', 'sum'], "Fusion mode should be rubi/hm/sum." 
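# For reference, the three modes accepted here are combined in fusion()/transform()
# below (identically to CFVQA in cfvqa.py above); with z_k the VQ-branch logits,
# z_q the Q->A logits and, when is_va, z_v the V->A logits, they are roughly:
#   'rubi': z = z_k * sigmoid(z_q)
#   'hm':   z = sigmoid(z_k) * sigmoid(z_q) * sigmoid(z_v); z = log(z + eps) - log1p(z)
#   'sum':  z = log(sigmoid(z_k + z_q + z_v) + eps)
# The debiased CF-VQA score is then the TIE: fusion(all facts) - fusion(q fact only).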
27 | self.fusion_mode = fusion_mode 28 | self.is_va = is_va and (not fusion_mode=='rubi') # RUBi does not consider V->A 29 | 30 | # Q->A branch 31 | self.q_1 = MLP(**classif_q) 32 | if self.end_classif: # default: True (following RUBi) 33 | self.q_2 = nn.Linear(output_size, output_size) 34 | 35 | # V->A branch 36 | if self.is_va: # default: True (containing V->A) 37 | self.v_1 = MLP(**classif_v) 38 | if self.end_classif: # default: True (following RUBi) 39 | self.v_2 = nn.Linear(output_size, output_size) 40 | 41 | self.constant = nn.Parameter(torch.tensor(0.0)) 42 | self.constant.requires_grad = True 43 | 44 | self.net.eval() 45 | self.q_1.eval() 46 | if self.end_classif: 47 | self.q_2.eval() 48 | if self.is_va: 49 | self.v_1.eval() 50 | if self.end_classif: 51 | self.v_2.eval() 52 | 53 | def forward(self, batch): 54 | out = {} 55 | # model prediction 56 | net_out = self.net(batch) 57 | logits = net_out['logits'] 58 | 59 | # Q->A branch 60 | q_embedding = net_out['q_emb'] # N * q_emb 61 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate 62 | q_pred = self.q_1(q_embedding) 63 | 64 | # V->A branch 65 | if self.is_va: 66 | v_embedding = net_out['v_emb'] # N * v_emb 67 | v_embedding = grad_mul_const(v_embedding, 0.0) # don't backpropagate 68 | v_pred = self.v_1(v_embedding) 69 | else: 70 | v_pred = None 71 | 72 | # both q, k and v are the facts 73 | z_qkv = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=True, v_fact=True) # te 74 | # q is the fact while k and v are the counterfactuals 75 | z_q = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=False, v_fact=False) # nie 76 | 77 | logits_cfvqa = z_qkv - z_q 78 | 79 | if self.end_classif: 80 | q_out = self.q_2(q_pred) 81 | if self.is_va: 82 | v_out = self.v_2(v_pred) 83 | else: 84 | q_out = q_pred 85 | if self.is_va: 86 | v_out = v_pred 87 | 88 | out['logits_all'] = z_qkv # for optimization 89 | out['logits_vq'] = logits # predictions of the original VQ branch, i.e., NIE 90 | out['logits_cfvqa'] = logits_cfvqa # predictions of CFVQA, i.e., TIE 91 | out['logits_q'] = q_out # for optimization 92 | if self.is_va: 93 | out['logits_v'] = v_out # for optimization 94 | 95 | # student model 96 | logits_stu = self.net_student(batch) 97 | out['logits_stu'] = logits_stu['logits'] 98 | 99 | return out 100 | 101 | def process_answers(self, out, key=''): 102 | out = self.net.process_answers(out, key='_all') 103 | out = self.net.process_answers(out, key='_vq') 104 | out = self.net.process_answers(out, key='_cfvqa') 105 | out = self.net.process_answers(out, key='_q') 106 | if self.is_va: 107 | out = self.net.process_answers(out, key='_v') 108 | 109 | # student model 110 | out = self.net.process_answers(out, key='_stu') 111 | 112 | return out 113 | 114 | def fusion(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False): 115 | 116 | z_k, z_q, z_v = self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact) 117 | 118 | if self.fusion_mode == 'rubi': 119 | z = z_k * torch.sigmoid(z_q) 120 | 121 | elif self.fusion_mode == 'hm': 122 | if self.is_va: 123 | z = z_k * z_q * z_v 124 | else: 125 | z = z_k * z_q 126 | z = torch.log(z + eps) - torch.log1p(z) 127 | 128 | elif self.fusion_mode == 'sum': 129 | if self.is_va: 130 | z = z_k + z_q + z_v 131 | else: 132 | z = z_k + z_q 133 | z = torch.log(torch.sigmoid(z) + eps) 134 | 135 | return z 136 | 137 | def transform(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False): 138 | 139 | if not k_fact: 140 | z_k = self.constant * torch.ones_like(z_k).cuda() 141 | 142 | if not 
q_fact: 143 | z_q = self.constant * torch.ones_like(z_q).cuda() 144 | 145 | if self.is_va: 146 | if not v_fact: 147 | z_v = self.constant * torch.ones_like(z_v).cuda() 148 | 149 | if self.fusion_mode == 'hm': 150 | z_k = torch.sigmoid(z_k) 151 | z_q = torch.sigmoid(z_q) 152 | if self.is_va: 153 | z_v = torch.sigmoid(z_v) 154 | 155 | return z_k, z_q, z_v -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/factory.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import json 7 | from bootstrap.lib.options import Options 8 | from bootstrap.models.networks.data_parallel import DataParallel 9 | from block.models.networks.vqa_net import VQANet as AttentionNet 10 | from bootstrap.lib.logger import Logger 11 | 12 | from .rubi import RUBiNet 13 | from .cfvqa import CFVQA 14 | from .cfvqaintrod import CFVQAIntroD 15 | from .rubiintrod import RUBiIntroD 16 | 17 | def factory(engine): 18 | mode = list(engine.dataset.keys())[0] 19 | dataset = engine.dataset[mode] 20 | opt = Options()['model.network'] 21 | 22 | if opt['base'] == 'smrl': 23 | from .smrl_net import SMRLNet as BaselineNet 24 | elif opt['base'] == 'updn': 25 | from .updn_net import UpDnNet as BaselineNet 26 | elif opt['base'] == 'san': 27 | from .san_net import SANNet as BaselineNet 28 | else: 29 | raise ValueError(opt['base']) 30 | 31 | orig_net = BaselineNet( 32 | txt_enc=opt['txt_enc'], 33 | self_q_att=opt['self_q_att'], 34 | agg=opt['agg'], 35 | classif=opt['classif'], 36 | wid_to_word=dataset.wid_to_word, 37 | word_to_wid=dataset.word_to_wid, 38 | aid_to_ans=dataset.aid_to_ans, 39 | ans_to_aid=dataset.ans_to_aid, 40 | fusion=opt['fusion'], 41 | residual=opt['residual'], 42 | q_single=opt['q_single'], 43 | ) 44 | 45 | if opt['name'] == 'baseline': 46 | net = orig_net 47 | 48 | elif opt['name'] == 'rubi': 49 | net = RUBiNet( 50 | model=orig_net, 51 | output_size=len(dataset.aid_to_ans), 52 | classif=opt['rubi_params']['mlp_q'] 53 | ) 54 | 55 | elif opt['name'] == 'cfvqa': 56 | net = CFVQA( 57 | model=orig_net, 58 | output_size=len(dataset.aid_to_ans), 59 | classif_q=opt['cfvqa_params']['mlp_q'], 60 | classif_v=opt['cfvqa_params']['mlp_v'], 61 | fusion_mode=opt['fusion_mode'], 62 | is_va=True 63 | ) 64 | 65 | elif opt['name'] == 'cfvqasimple': 66 | net = CFVQA( 67 | model=orig_net, 68 | output_size=len(dataset.aid_to_ans), 69 | classif_q=opt['cfvqa_params']['mlp_q'], 70 | classif_v=None, 71 | fusion_mode=opt['fusion_mode'], 72 | is_va=False 73 | ) 74 | 75 | elif opt['name'] == 'cfvqaintrod': 76 | orig_net_teacher = BaselineNet( 77 | txt_enc=opt['txt_enc'], 78 | self_q_att=opt['self_q_att'], 79 | agg=opt['agg'], 80 | classif=opt['classif'], 81 | wid_to_word=dataset.wid_to_word, 82 | word_to_wid=dataset.word_to_wid, 83 | aid_to_ans=dataset.aid_to_ans, 84 | ans_to_aid=dataset.ans_to_aid, 85 | fusion=opt['fusion'], 86 | residual=opt['residual'], 87 | q_single=opt['q_single'], 88 | ) 89 | net = CFVQAIntroD( 90 | model=orig_net, 91 | model_teacher=orig_net_teacher, 92 | output_size=len(dataset.aid_to_ans), 93 | classif_q=opt['cfvqa_params']['mlp_q'], 94 | classif_v=opt['cfvqa_params']['mlp_v'], 95 | fusion_mode=opt['fusion_mode'] 96 | ) 97 | 98 | elif opt['name'] == 'cfvqasimpleintrod': 99 | orig_net_teacher = BaselineNet( 100 | txt_enc=opt['txt_enc'], 101 | self_q_att=opt['self_q_att'], 102 | agg=opt['agg'], 103 | classif=opt['classif'], 104 | 
wid_to_word=dataset.wid_to_word, 105 | word_to_wid=dataset.word_to_wid, 106 | aid_to_ans=dataset.aid_to_ans, 107 | ans_to_aid=dataset.ans_to_aid, 108 | fusion=opt['fusion'], 109 | residual=opt['residual'], 110 | q_single=opt['q_single'], 111 | ) 112 | net = CFVQAIntroD( 113 | model=orig_net, 114 | model_teacher=orig_net_teacher, 115 | output_size=len(dataset.aid_to_ans), 116 | classif_q=opt['cfvqa_params']['mlp_q'], 117 | classif_v=None, 118 | fusion_mode=opt['fusion_mode'], 119 | is_va=False 120 | ) 121 | 122 | elif opt['name'] == 'rubiintrod': 123 | orig_net_teacher = BaselineNet( 124 | txt_enc=opt['txt_enc'], 125 | self_q_att=opt['self_q_att'], 126 | agg=opt['agg'], 127 | classif=opt['classif'], 128 | wid_to_word=dataset.wid_to_word, 129 | word_to_wid=dataset.word_to_wid, 130 | aid_to_ans=dataset.aid_to_ans, 131 | ans_to_aid=dataset.ans_to_aid, 132 | fusion=opt['fusion'], 133 | residual=opt['residual'], 134 | q_single=opt['q_single'], 135 | ) 136 | net = RUBiIntroD( 137 | model=orig_net, 138 | model_teacher=orig_net_teacher, 139 | output_size=len(dataset.aid_to_ans), 140 | classif=opt['rubi_params']['mlp_q'] 141 | ) 142 | 143 | else: 144 | raise ValueError(opt['name']) 145 | 146 | if Options()['misc.cuda'] and torch.cuda.device_count() > 1: 147 | net = DataParallel(net) 148 | 149 | return net 150 | 151 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/rubi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from block.models.networks.mlp import MLP 4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask, 5 | 6 | 7 | class RUBiNet(nn.Module): 8 | """ 9 | Wraps another model 10 | The original model must return a dictionary containing the 'logits' key (predictions before softmax) 11 | Returns: 12 | - logits: the original predictions of the model 13 | - logits_q: the predictions from the question-only branch 14 | - logits_rubi: the model predictions reweighted by the question-only mask.
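(exposed under the key 'logits_all' in the output dictionary of this implementation)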
15 | => Use `logits_rubi` and `logits_q` for the loss 16 | """ 17 | def __init__(self, model, output_size, classif, end_classif=True): 18 | super().__init__() 19 | self.net = model 20 | self.c_1 = MLP(**classif) 21 | self.end_classif = end_classif 22 | if self.end_classif: 23 | self.c_2 = nn.Linear(output_size, output_size) 24 | 25 | def forward(self, batch): 26 | out = {} 27 | # model prediction 28 | net_out = self.net(batch) 29 | logits = net_out['logits'] 30 | 31 | q_embedding = net_out['q_emb'] # N * q_emb 32 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate through question encoder 33 | q_pred = self.c_1(q_embedding) 34 | fusion_pred = logits * torch.sigmoid(q_pred) 35 | 36 | if self.end_classif: 37 | q_out = self.c_2(q_pred) 38 | else: 39 | q_out = q_pred 40 | 41 | out['logits'] = net_out['logits'] 42 | out['logits_all'] = fusion_pred 43 | out['logits_q'] = q_out 44 | return out 45 | 46 | def process_answers(self, out, key=''): 47 | out = self.net.process_answers(out) 48 | out = self.net.process_answers(out, key='_all') 49 | out = self.net.process_answers(out, key='_q') 50 | return out 51 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/rubiintrod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from block.models.networks.mlp import MLP 4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask, 5 | 6 | 7 | class RUBiIntroD(nn.Module): 8 | """ 9 | Wraps another model 10 | The original model must return a dictionary containing the 'logits' key (predictions before softmax) 11 | Returns: 12 | - logits: the original predictions of the model 13 | - logits_q: the predictions from the question-only branch 14 | - logits_rubi: the model predictions reweighted by the question-only mask.
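(exposed under the key 'logits_all' in the output dictionary of this implementation)
- logits_stu: the predictions from the plain student model trained by introspective distillation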
15 | => Use `logits_rubi` and `logits_q` for the loss 16 | """ 17 | def __init__(self, model, model_teacher, output_size, classif, end_classif=True): 18 | super().__init__() 19 | self.net_student = model 20 | self.net = model_teacher 21 | self.c_1 = MLP(**classif) 22 | self.end_classif = end_classif 23 | if self.end_classif: 24 | self.c_2 = nn.Linear(output_size, output_size) 25 | 26 | self.net.eval() 27 | self.c_1.eval() 28 | self.c_2.eval() 29 | 30 | def forward(self, batch): 31 | out = {} 32 | # model prediction 33 | net_out = self.net(batch) 34 | logits = net_out['logits'] 35 | 36 | q_embedding = net_out['q_emb'] # N * q_emb 37 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate through question encoder 38 | q_pred = self.c_1(q_embedding) 39 | fusion_pred = logits * torch.sigmoid(q_pred) 40 | 41 | if self.end_classif: 42 | q_out = self.c_2(q_pred) 43 | else: 44 | q_out = q_pred 45 | 46 | out['logits'] = net_out['logits'] 47 | out['logits_all'] = fusion_pred 48 | out['logits_q'] = q_out 49 | 50 | # student model 51 | logits_stu = self.net_student(batch) 52 | out['logits_stu'] = logits_stu['logits'] 53 | 54 | return out 55 | 56 | def process_answers(self, out, key=''): 57 | out = self.net.process_answers(out) 58 | out = self.net.process_answers(out, key='_all') 59 | out = self.net.process_answers(out, key='_q') 60 | out = self.net.process_answers(out, key='_stu') 61 | return out 62 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/smrl_net.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import itertools 3 | import os 4 | import numpy as np 5 | import scipy 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from bootstrap.lib.options import Options 10 | from bootstrap.lib.logger import Logger 11 | import block 12 | from block.models.networks.vqa_net import factory_text_enc 13 | from block.models.networks.mlp import MLP 14 | 15 | from .utils import mask_softmax 16 | 17 | class SMRLNet(nn.Module): 18 | 19 | def __init__(self, 20 | txt_enc={}, 21 | self_q_att=False, 22 | agg={}, 23 | classif={}, 24 | wid_to_word={}, 25 | word_to_wid={}, 26 | aid_to_ans=[], 27 | ans_to_aid={}, 28 | fusion={}, 29 | residual=False, 30 | q_single=False, 31 | ): 32 | super().__init__() 33 | self.self_q_att = self_q_att 34 | self.agg = agg 35 | assert self.agg['type'] in ['max', 'mean'] 36 | self.classif = classif 37 | self.wid_to_word = wid_to_word 38 | self.word_to_wid = word_to_wid 39 | self.aid_to_ans = aid_to_ans 40 | self.ans_to_aid = ans_to_aid 41 | self.fusion = fusion 42 | self.residual = residual 43 | 44 | # Modules 45 | self.txt_enc = self.get_text_enc(self.wid_to_word, txt_enc) 46 | if self.self_q_att: 47 | self.q_att_linear0 = nn.Linear(2400, 512) 48 | self.q_att_linear1 = nn.Linear(512, 2) 49 | 50 | if q_single: 51 | self.txt_enc_single = self.get_text_enc(self.wid_to_word, txt_enc) 52 | if self.self_q_att: 53 | self.q_att_linear0_single = nn.Linear(2400, 512) 54 | self.q_att_linear1_single = nn.Linear(512, 2) 55 | else: 56 | self.txt_enc_single = None 57 | 58 | self.fusion_module = block.factory_fusion(self.fusion) 59 | 60 | if self.classif['mlp']['dimensions'][-1] != len(self.aid_to_ans): 61 | Logger()(f"Warning, the classif_mm output dimension ({self.classif['mlp']['dimensions'][-1]})" 62 | f"doesn't match the number of answers ({len(self.aid_to_ans)}). 
Modifying the output dimension.") 63 | self.classif['mlp']['dimensions'][-1] = len(self.aid_to_ans) 64 | 65 | self.classif_module = MLP(**self.classif['mlp']) 66 | 67 | Logger().log_value('nparams', 68 | sum(p.numel() for p in self.parameters() if p.requires_grad), 69 | should_print=True) 70 | 71 | Logger().log_value('nparams_txt_enc', 72 | self.get_nparams_txt_enc(), 73 | should_print=True) 74 | 75 | 76 | def get_text_enc(self, vocab_words, options): 77 | """ 78 | returns the text encoding network. 79 | """ 80 | return factory_text_enc(self.wid_to_word, options) 81 | 82 | def get_nparams_txt_enc(self): 83 | params = [p.numel() for p in self.txt_enc.parameters() if p.requires_grad] 84 | if self.self_q_att: 85 | params += [p.numel() for p in self.q_att_linear0.parameters() if p.requires_grad] 86 | params += [p.numel() for p in self.q_att_linear1.parameters() if p.requires_grad] 87 | return sum(params) 88 | 89 | def process_fusion(self, q, mm): 90 | bsize = mm.shape[0] 91 | n_regions = mm.shape[1] 92 | 93 | mm = mm.contiguous().view(bsize*n_regions, -1) 94 | mm = self.fusion_module([q, mm]) 95 | mm = mm.view(bsize, n_regions, -1) 96 | return mm 97 | 98 | def forward(self, batch): 99 | v = batch['visual'] 100 | q = batch['question'] 101 | l = batch['lengths'].data 102 | c = batch['norm_coord'] 103 | nb_regions = batch.get('nb_regions') 104 | bsize = v.shape[0] 105 | n_regions = v.shape[1] 106 | 107 | out = {} 108 | 109 | q = self.process_question(q, l,) 110 | out['q_emb'] = q 111 | q_expand = q[:,None,:].expand(bsize, n_regions, q.shape[1]) 112 | q_expand = q_expand.contiguous().view(bsize*n_regions, -1) 113 | 114 | # single txt encoder 115 | if self.txt_enc_single is not None: 116 | out['q_emb'] = self.process_question(batch['question'], l, self.txt_enc_single, self.q_att_linear0_single, self.q_att_linear1_single) # use the raw token ids: q has already been pooled above 117 | 118 | mm = self.process_fusion(q_expand, v,) 119 | 120 | if self.residual: 121 | mm = v + mm 122 | 123 | if self.agg['type'] == 'max': 124 | mm, mm_argmax = torch.max(mm, 1) 125 | elif self.agg['type'] == 'mean': 126 | mm, mm_argmax = mm.mean(1), None # keep mm_argmax defined for the 'mean' branch 127 | 128 | out['v_emb'] = v.mean(1) 129 | out['mm'] = mm 130 | out['mm_argmax'] = mm_argmax 131 | 132 | logits = self.classif_module(mm) 133 | out['logits'] = logits 134 | return out 135 | 136 | def process_question(self, q, l, txt_enc=None, q_att_linear0=None, q_att_linear1=None): 137 | if txt_enc is None: 138 | txt_enc = self.txt_enc 139 | if q_att_linear0 is None: 140 | q_att_linear0 = self.q_att_linear0 141 | if q_att_linear1 is None: 142 | q_att_linear1 = self.q_att_linear1 143 | q_emb = txt_enc.embedding(q) 144 | 145 | q, _ = txt_enc.rnn(q_emb) 146 | 147 | if self.self_q_att: 148 | q_att = q_att_linear0(q) 149 | q_att = F.relu(q_att) 150 | q_att = q_att_linear1(q_att) 151 | q_att = mask_softmax(q_att, l) 152 | #self.q_att_coeffs = q_att 153 | if q_att.size(2) > 1: 154 | q_atts = torch.unbind(q_att, dim=2) 155 | q_outs = [] 156 | for q_att in q_atts: 157 | q_att = q_att.unsqueeze(2) 158 | q_att = q_att.expand_as(q) 159 | q_out = q_att*q 160 | q_out = q_out.sum(1) 161 | q_outs.append(q_out) 162 | q = torch.cat(q_outs, dim=1) 163 | else: 164 | q_att = q_att.expand_as(q) 165 | q = q_att * q 166 | q = q.sum(1) 167 | else: 168 | # l contains the number of words for each question 169 | # in case of multi-gpus it must be a Tensor 170 | # thus we convert it into a list during the forward pass 171 | l = list(l.data[:,0]) 172 | q = txt_enc._select_last(q, l) 173 | 174 | return q 175 | 176 | def process_answers(self, out, key=''): 177 | batch_size =
out[f'logits{key}'].shape[0] 178 | _, pred = out[f'logits{key}'].data.max(1) 179 | pred.squeeze_() 180 | if batch_size != 1: 181 | out[f'answers{key}'] = [self.aid_to_ans[pred[i].item()] for i in range(batch_size)] 182 | out[f'answer_ids{key}'] = [pred[i].item() for i in range(batch_size)] 183 | else: 184 | out[f'answers{key}'] = [self.aid_to_ans[pred.item()]] 185 | out[f'answer_ids{key}'] = [pred.item()] 186 | return out 187 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/updn_net.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import itertools 3 | import os 4 | import numpy as np 5 | import scipy 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from bootstrap.lib.options import Options 10 | from bootstrap.lib.logger import Logger 11 | import block 12 | from block.models.networks.vqa_net import factory_text_enc 13 | from block.models.networks.mlp import MLP 14 | 15 | from .utils import mask_softmax 16 | 17 | from torch.nn.utils.weight_norm import weight_norm 18 | 19 | class UpDnNet(nn.Module): 20 | 21 | def __init__(self, 22 | txt_enc={}, 23 | self_q_att=False, 24 | agg={}, 25 | classif={}, 26 | wid_to_word={}, 27 | word_to_wid={}, 28 | aid_to_ans=[], 29 | ans_to_aid={}, 30 | fusion={}, 31 | residual=False, 32 | q_single=False, 33 | ): 34 | super().__init__() 35 | self.self_q_att = self_q_att 36 | self.agg = agg 37 | assert self.agg['type'] in ['max', 'mean'] 38 | self.classif = classif 39 | self.wid_to_word = wid_to_word 40 | self.word_to_wid = word_to_wid 41 | self.aid_to_ans = aid_to_ans 42 | self.ans_to_aid = ans_to_aid 43 | self.fusion = fusion 44 | self.residual = residual 45 | 46 | # Modules 47 | self.txt_enc = self.get_text_enc(self.wid_to_word, txt_enc) 48 | if self.self_q_att: 49 | self.q_att_linear0 = nn.Linear(2400, 512) 50 | self.q_att_linear1 = nn.Linear(512, 2) 51 | 52 | if q_single: 53 | self.txt_enc_single = self.get_text_enc(self.wid_to_word, txt_enc) 54 | if self.self_q_att: 55 | self.q_att_linear0_single = nn.Linear(2400, 512) 56 | self.q_att_linear1_single = nn.Linear(512, 2) 57 | else: 58 | self.txt_enc_single = None 59 | 60 | if self.classif['mlp']['dimensions'][-1] != len(self.aid_to_ans): 61 | Logger()(f"Warning, the classif_mm output dimension ({self.classif['mlp']['dimensions'][-1]})" 62 | f"doesn't match the number of answers ({len(self.aid_to_ans)}). Modifying the output dimension.") 63 | self.classif['mlp']['dimensions'][-1] = len(self.aid_to_ans) 64 | 65 | self.classif_module = MLP(**self.classif['mlp']) 66 | 67 | # UpDn 68 | q_dim = self.fusion['input_dims'][0] 69 | v_dim = self.fusion['input_dims'][1] 70 | output_dim = self.fusion['output_dim'] 71 | self.v_att = Attention(v_dim, q_dim, output_dim) 72 | self.q_net = FCNet([q_dim, output_dim]) 73 | self.v_net = FCNet([v_dim, output_dim]) 74 | 75 | Logger().log_value('nparams', 76 | sum(p.numel() for p in self.parameters() if p.requires_grad), 77 | should_print=True) 78 | 79 | Logger().log_value('nparams_txt_enc', 80 | self.get_nparams_txt_enc(), 81 | should_print=True) 82 | 83 | 84 | def get_text_enc(self, vocab_words, options): 85 | """ 86 | returns the text encoding network. 
87 | """ 88 | return factory_text_enc(self.wid_to_word, options) 89 | 90 | def get_nparams_txt_enc(self): 91 | params = [p.numel() for p in self.txt_enc.parameters() if p.requires_grad] 92 | if self.self_q_att: 93 | params += [p.numel() for p in self.q_att_linear0.parameters() if p.requires_grad] 94 | params += [p.numel() for p in self.q_att_linear1.parameters() if p.requires_grad] 95 | return sum(params) 96 | 97 | def forward(self, batch): 98 | v = batch['visual'] 99 | q = batch['question'] 100 | l = batch['lengths'].data 101 | c = batch['norm_coord'] 102 | nb_regions = batch.get('nb_regions') 103 | 104 | out = {} 105 | 106 | q_emb = self.process_question(q, l,) 107 | out['v_emb'] = v.mean(1) 108 | out['q_emb'] = q_emb 109 | 110 | # single txt encoder 111 | if self.txt_enc_single is not None: 112 | out['q_emb'] = self.process_question(q, l, self.txt_enc_single, self.q_att_linear0_single, self.q_att_linear1_single) 113 | 114 | # New 115 | att = self.v_att(v, q_emb) 116 | v_emb = (att * v).sum(1) 117 | q_repr = self.q_net(q_emb) 118 | v_repr = self.v_net(v_emb) 119 | joint_repr = q_repr * v_repr 120 | 121 | logits = self.classif_module(joint_repr) 122 | out['logits'] = logits 123 | 124 | return out 125 | 126 | def process_question(self, q, l, txt_enc=None, q_att_linear0=None, q_att_linear1=None): 127 | if txt_enc is None: 128 | txt_enc = self.txt_enc 129 | if q_att_linear0 is None: 130 | q_att_linear0 = self.q_att_linear0 131 | if q_att_linear1 is None: 132 | q_att_linear1 = self.q_att_linear1 133 | q_emb = txt_enc.embedding(q) 134 | 135 | q, _ = txt_enc.rnn(q_emb) 136 | 137 | if self.self_q_att: 138 | q_att = q_att_linear0(q) 139 | q_att = F.relu(q_att) 140 | q_att = q_att_linear1(q_att) 141 | q_att = mask_softmax(q_att, l) 142 | #self.q_att_coeffs = q_att 143 | if q_att.size(2) > 1: 144 | q_atts = torch.unbind(q_att, dim=2) 145 | q_outs = [] 146 | for q_att in q_atts: 147 | q_att = q_att.unsqueeze(2) 148 | q_att = q_att.expand_as(q) 149 | q_out = q_att*q 150 | q_out = q_out.sum(1) 151 | q_outs.append(q_out) 152 | q = torch.cat(q_outs, dim=1) 153 | else: 154 | q_att = q_att.expand_as(q) 155 | q = q_att * q 156 | q = q.sum(1) 157 | else: 158 | # l contains the number of words for each question 159 | # in case of multi-gpus it must be a Tensor 160 | # thus we convert it into a list during the forward pass 161 | l = list(l.data[:,0]) 162 | q = txt_enc._select_last(q, l) 163 | 164 | return q 165 | 166 | def process_answers(self, out, key=''): 167 | batch_size = out[f'logits{key}'].shape[0] 168 | _, pred = out[f'logits{key}'].data.max(1) 169 | pred.squeeze_() 170 | if batch_size != 1: 171 | out[f'answers{key}'] = [self.aid_to_ans[pred[i].item()] for i in range(batch_size)] 172 | out[f'answer_ids{key}'] = [pred[i].item() for i in range(batch_size)] 173 | else: 174 | out[f'answers{key}'] = [self.aid_to_ans[pred.item()]] 175 | out[f'answer_ids{key}'] = [pred.item()] 176 | return out 177 | 178 | class Attention(nn.Module): 179 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2): 180 | super(Attention, self).__init__() 181 | 182 | self.v_proj = FCNet([v_dim, num_hid]) 183 | self.q_proj = FCNet([q_dim, num_hid]) 184 | self.dropout = nn.Dropout(dropout) 185 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None) 186 | 187 | def forward(self, v, q): 188 | """ 189 | v: [batch, k, vdim] 190 | q: [batch, qdim] 191 | """ 192 | logits = self.logits(v, q) 193 | w = nn.functional.softmax(logits, 1) 194 | return w 195 | 196 | def logits(self, v, q): 197 | batch, k, _ = v.size() 198 | v_proj = 
self.v_proj(v) # [batch, k, num_hid] 199 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) 200 | joint_repr = v_proj * q_proj 201 | joint_repr = self.dropout(joint_repr) 202 | logits = self.linear(joint_repr) 203 | return logits 204 | 205 | class FCNet(nn.Module): 206 | """Simple class for a non-linear fully connected network 207 | """ 208 | def __init__(self, dims): 209 | super(FCNet, self).__init__() 210 | 211 | layers = [] 212 | for i in range(len(dims)-2): 213 | in_dim = dims[i] 214 | out_dim = dims[i+1] 215 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 216 | layers.append(nn.ReLU()) 217 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 218 | layers.append(nn.ReLU()) 219 | 220 | self.main = nn.Sequential(*layers) 221 | 222 | def forward(self, x): 223 | return self.main(x) -------------------------------------------------------------------------------- /cfvqa/cfvqa/models/networks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mask_softmax(x, lengths):#, dim=1) 4 | mask = torch.zeros_like(x).to(device=x.device, non_blocking=True) 5 | t_lengths = lengths[:,:,None].expand_as(mask) 6 | arange_id = torch.arange(mask.size(1)).to(device=x.device, non_blocking=True) 7 | arange_id = arange_id[None,:,None].expand_as(mask) 8 | 9 | mask[arange_id<t_lengths] = 1 10 | # numerically stable masked softmax over the word dimension 11 | x2 = torch.exp(x - torch.max(x)) 12 | x3 = x2 * mask 13 | epsilon = 1e-5 14 | x4 = x3 / (x3.sum(dim=1, keepdim=True) + epsilon) 15 | return x4 16 | 17 | 18 | class GradMulConst(torch.autograd.Function): 19 | """ 20 | Multiplies the gradient by a constant on the backward pass (identity on the forward pass); grad_mul_const(x, 0.0) blocks backpropagation. 21 | """ 22 | @staticmethod 23 | def forward(ctx, x, const): 24 | ctx.const = const 25 | return x.view_as(x) 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | return grad_output * ctx.const, None 30 | 31 | def grad_mul_const(x, const): 32 | return GradMulConst.apply(x, const) -------------------------------------------------------------------------------- /cfvqa/cfvqa/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/optimizers/__init__.py -------------------------------------------------------------------------------- /cfvqa/cfvqa/optimizers/factory.py: -------------------------------------------------------------------------------- … 25 | elif p.dim() >= 2: 26 | nn.init.xavier_uniform_(p.data) 27 | else: 28 | raise ValueError(p.dim()) 29 | 30 | return optimizer 31 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_baseline.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_baseline 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: false 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: baseline 25 | cfvqa_params: 26 | mlp_q: 27 | input_dim: 4800 28 | dimensions: [1024,1024,3000] 29 | mlp_v: 30 | input_dim: 2048 31 | dimensions: [1024,1024,3000] 32 | txt_enc: 33 | name: skipthoughts 34 | type: BayesianUniSkip 35 | dropout: 0.25 36 | fixed_emb: False 37 | dir_st: data/skip-thoughts 38 | self_q_att: True 39 | residual: False 40 | q_single: False 41 | fusion: 42 | type: block 43 | input_dims: [4800, 2048] 44 | output_dim: 2048 45 | mm_dim: 1000 46 | chunks: 20 47 | rank: 15 48 | dropout_input: 0. 49 | dropout_pre_lin: 0.
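# The 'block' fusion above is the BLOCK bilinear fusion from the
# block.bootstrap.pytorch library: it fuses the 4800-d self-attended
# skip-thoughts question embedding with the 2048-d bottom-up-attention region
# features (see dir_rcnn above) into a 2048-d multimodal vector; every config
# in this repository reuses these settings unchanged.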
50 | agg: 51 | type: max 52 | classif: 53 | mlp: 54 | input_dim: 2048 55 | dimensions: [1024,1024,3000] 56 | criterion: 57 | import: cfvqa.models.criterions.factory 58 | name: vqa_cross_entropy 59 | question_loss_weight: 1.0 60 | vision_loss_weight: 1.0 61 | loss_temp: 100.0 62 | metric: 63 | import: cfvqa.models.metrics.factory 64 | name: vqa_accuracies 65 | optimizer: 66 | import: cfvqa.optimizers.factory 67 | name: Adam 68 | lr: 0.0003 69 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 70 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 71 | lr_decay_epochs: [14, 24, 2] #range 72 | lr_decay_rate: .25 73 | engine: 74 | name: logger 75 | debug: False 76 | print_freq: 10 77 | nb_epochs: 22 78 | saving_criteria: 79 | - eval_epoch.accuracy_top1:max 80 | misc: 81 | logs_name: 82 | cuda: True 83 | seed: 1337 84 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_cfvqa_sum.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_cfvqa_sum 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: False 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: cfvqa 25 | fusion_mode: sum 26 | cfvqa_params: 27 | mlp_q: 28 | input_dim: 4800 29 | dimensions: [1024,1024,3000] 30 | mlp_v: 31 | input_dim: 2048 32 | dimensions: [1024,1024,3000] 33 | txt_enc: 34 | name: skipthoughts 35 | type: BayesianUniSkip 36 | dropout: 0.25 37 | fixed_emb: False 38 | dir_st: data/skip-thoughts 39 | self_q_att: True 40 | residual: False 41 | q_single: False 42 | fusion: 43 | type: block 44 | input_dims: [4800, 2048] 45 | output_dim: 2048 46 | mm_dim: 1000 47 | chunks: 20 48 | rank: 15 49 | dropout_input: 0. 50 | dropout_pre_lin: 0. 
51 | agg: 52 | type: max 53 | classif: 54 | mlp: 55 | input_dim: 2048 56 | dimensions: [1024,1024,3000] 57 | criterion: 58 | import: cfvqa.models.criterions.factory 59 | name: cfvqa_criterion 60 | question_loss_weight: 1.0 61 | vision_loss_weight: 1.0 62 | metric: 63 | import: cfvqa.models.metrics.factory 64 | name: vqa_cfvqa_metrics 65 | optimizer: 66 | import: cfvqa.optimizers.factory 67 | name: Adam 68 | lr: 0.0003 69 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 70 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 71 | lr_decay_epochs: [14, 24, 2] #range 72 | lr_decay_rate: .25 73 | engine: 74 | name: logger 75 | debug: False 76 | print_freq: 10 77 | nb_epochs: 22 78 | saving_criteria: 79 | - eval_epoch.accuracy_all_top1:max 80 | - eval_epoch.accuracy_vq_top1:max 81 | - eval_epoch.accuracy_cfvqa_top1:max 82 | misc: 83 | logs_name: 84 | cuda: True 85 | seed: 1337 86 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_cfvqaintrod_sum.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_cfvqaintrod_sum 3 | resume: last # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: false 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: cfvqaintrod 25 | fusion_mode: sum 26 | cfvqa_params: 27 | mlp_q: 28 | input_dim: 4800 29 | dimensions: [1024,1024,3000] 30 | mlp_v: 31 | input_dim: 2048 32 | dimensions: [1024,1024,3000] 33 | txt_enc: 34 | name: skipthoughts 35 | type: BayesianUniSkip 36 | dropout: 0.25 37 | fixed_emb: False 38 | dir_st: data/skip-thoughts 39 | self_q_att: True 40 | residual: False 41 | q_single: False 42 | fusion: 43 | type: block 44 | input_dims: [4800, 2048] 45 | output_dim: 2048 46 | mm_dim: 1000 47 | chunks: 20 48 | rank: 15 49 | dropout_input: 0. 50 | dropout_pre_lin: 0. 
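# Note how the IntroD configs differ from their teacher counterparts: 'resume'
# is set to 'last' (exp.dir is presumably seeded with the pretrained CF-VQA
# teacher checkpoint by scripts/run_vqa2_cfvqa_introd.sh before the student is
# trained), the criterion takes no branch loss weights, and checkpoints are
# selected on the student's accuracy ('eval_epoch.accuracy_stu_top1').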
51 | agg: 52 | type: max 53 | classif: 54 | mlp: 55 | input_dim: 2048 56 | dimensions: [1024,1024,3000] 57 | criterion: 58 | import: cfvqa.models.criterions.factory 59 | name: cfvqaintrod_criterion 60 | metric: 61 | import: cfvqa.models.metrics.factory 62 | name: vqa_cfvqaintrod_metrics 63 | optimizer: 64 | import: cfvqa.optimizers.factory 65 | name: Adam 66 | lr: 0.0003 67 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 68 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 69 | lr_decay_epochs: [14, 24, 2] #range 70 | lr_decay_rate: .25 71 | engine: 72 | name: logger 73 | debug: False 74 | print_freq: 10 75 | nb_epochs: 22 76 | saving_criteria: 77 | - eval_epoch.accuracy_stu_top1:max 78 | misc: 79 | logs_name: 80 | cuda: True 81 | seed: 1337 82 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_cfvqasimple_rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_cfvqasimple_rubi 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: False 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: cfvqasimple 25 | fusion_mode: rubi 26 | is_vq: False 27 | cfvqa_params: 28 | mlp_q: 29 | input_dim: 4800 30 | dimensions: [1024,1024,3000] 31 | txt_enc: 32 | name: skipthoughts 33 | type: BayesianUniSkip 34 | dropout: 0.25 35 | fixed_emb: False 36 | dir_st: data/skip-thoughts 37 | self_q_att: True 38 | residual: False 39 | q_single: False 40 | fusion: 41 | type: block 42 | input_dims: [4800, 2048] 43 | output_dim: 2048 44 | mm_dim: 1000 45 | chunks: 20 46 | rank: 15 47 | dropout_input: 0. 48 | dropout_pre_lin: 0. 
49 | agg: 50 | type: max 51 | classif: 52 | mlp: 53 | input_dim: 2048 54 | dimensions: [1024,1024,3000] 55 | criterion: 56 | import: cfvqa.models.criterions.factory 57 | name: cfvqasimple_criterion 58 | question_loss_weight: 1.0 59 | metric: 60 | import: cfvqa.models.metrics.factory 61 | name: vqa_cfvqasimple_metrics 62 | optimizer: 63 | import: cfvqa.optimizers.factory 64 | name: Adam 65 | lr: 0.0003 66 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 67 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 68 | lr_decay_epochs: [14, 24, 2] #range 69 | lr_decay_rate: .25 70 | engine: 71 | name: logger 72 | debug: False 73 | print_freq: 10 74 | nb_epochs: 22 75 | saving_criteria: 76 | - eval_epoch.accuracy_all_top1:max 77 | - eval_epoch.accuracy_vq_top1:max 78 | - eval_epoch.accuracy_cfvqa_top1:max 79 | misc: 80 | logs_name: 81 | cuda: True 82 | seed: 1337 83 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_cfvqasimpleintrod_rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_cfvqasimpleintrod_rubi 3 | resume: last # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: false 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: cfvqasimpleintrod 25 | fusion_mode: rubi 26 | cfvqa_params: 27 | mlp_q: 28 | input_dim: 4800 29 | dimensions: [1024,1024,3000] 30 | mlp_v: 31 | input_dim: 2048 32 | dimensions: [1024,1024,3000] 33 | txt_enc: 34 | name: skipthoughts 35 | type: BayesianUniSkip 36 | dropout: 0.25 37 | fixed_emb: False 38 | dir_st: data/skip-thoughts 39 | self_q_att: True 40 | residual: False 41 | q_single: False 42 | fusion: 43 | type: block 44 | input_dims: [4800, 2048] 45 | output_dim: 2048 46 | mm_dim: 1000 47 | chunks: 20 48 | rank: 15 49 | dropout_input: 0. 50 | dropout_pre_lin: 0.
51 | agg: 52 | type: max 53 | classif: 54 | mlp: 55 | input_dim: 2048 56 | dimensions: [1024,1024,3000] 57 | criterion: 58 | import: cfvqa.models.criterions.factory 59 | name: cfvqaintrod_criterion 60 | metric: 61 | import: cfvqa.models.metrics.factory 62 | name: vqa_cfvqaintrod_metrics 63 | optimizer: 64 | import: cfvqa.optimizers.factory 65 | name: Adam 66 | lr: 0.0003 67 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 68 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 69 | lr_decay_epochs: [14, 24, 2] #range 70 | lr_decay_rate: .25 71 | engine: 72 | name: logger 73 | debug: False 74 | print_freq: 10 75 | nb_epochs: 22 76 | saving_criteria: 77 | - eval_epoch.accuracy_stu_top1:max 78 | misc: 79 | logs_name: 80 | cuda: True 81 | seed: 1337 82 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_rubi 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: false 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: rubi 25 | rubi_params: 26 | mlp_q: 27 | input_dim: 4800 28 | dimensions: [1024,1024,3000] 29 | txt_enc: 30 | name: skipthoughts 31 | type: BayesianUniSkip 32 | dropout: 0.25 33 | fixed_emb: False 34 | dir_st: data/skip-thoughts 35 | self_q_att: True 36 | residual: False 37 | q_single: False 38 | fusion: 39 | type: block 40 | input_dims: [4800, 2048] 41 | output_dim: 2048 42 | mm_dim: 1000 43 | chunks: 20 44 | rank: 15 45 | dropout_input: 0. 46 | dropout_pre_lin: 0. 
47 | agg: 48 | type: max 49 | classif: 50 | mlp: 51 | input_dim: 2048 52 | dimensions: [1024,1024,3000] 53 | criterion: 54 | import: cfvqa.models.criterions.factory 55 | name: rubi_criterion 56 | question_loss_weight: 1.0 57 | metric: 58 | import: cfvqa.models.metrics.factory 59 | name: vqa_rubi_metrics 60 | optimizer: 61 | import: cfvqa.optimizers.factory 62 | name: Adam 63 | lr: 0.0003 64 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 65 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 66 | lr_decay_epochs: [14, 24, 2] #range 67 | lr_decay_rate: .25 68 | engine: 69 | name: logger 70 | debug: False 71 | print_freq: 10 72 | nb_epochs: 22 73 | saving_criteria: 74 | - eval_epoch.accuracy_top1:max 75 | - eval_epoch.accuracy_all_top1:max 76 | misc: 77 | logs_name: 78 | cuda: True 79 | seed: 1337 80 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqa2/smrl_rubiintrod.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqa2/smrl_rubiintrod 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqa2 # or vqa2vg 7 | dir: data/vqa/vqa2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | vg: false 19 | model: 20 | name: default 21 | network: 22 | import: cfvqa.models.networks.factory 23 | base: smrl 24 | name: rubiintrod 25 | rubi_params: 26 | mlp_q: 27 | input_dim: 4800 28 | dimensions: [1024,1024,3000] 29 | txt_enc: 30 | name: skipthoughts 31 | type: BayesianUniSkip 32 | dropout: 0.25 33 | fixed_emb: False 34 | dir_st: data/skip-thoughts 35 | self_q_att: True 36 | residual: False 37 | q_single: False 38 | fusion: 39 | type: block 40 | input_dims: [4800, 2048] 41 | output_dim: 2048 42 | mm_dim: 1000 43 | chunks: 20 44 | rank: 15 45 | dropout_input: 0. 46 | dropout_pre_lin: 0. 
47 | agg: 48 | type: max 49 | classif: 50 | mlp: 51 | input_dim: 2048 52 | dimensions: [1024,1024,3000] 53 | criterion: 54 | import: cfvqa.models.criterions.factory 55 | name: rubiintrod_criterion 56 | metric: 57 | import: cfvqa.models.metrics.factory 58 | name: vqa_rubiintrod_metrics 59 | optimizer: 60 | import: cfvqa.optimizers.factory 61 | name: Adam 62 | lr: 0.0003 63 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 64 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 65 | lr_decay_epochs: [14, 24, 2] #range 66 | lr_decay_rate: .25 67 | engine: 68 | name: logger 69 | debug: False 70 | print_freq: 10 71 | nb_epochs: 22 72 | saving_criteria: 73 | - eval_epoch.accuracy_stu_top1:max 74 | misc: 75 | logs_name: 76 | cuda: True 77 | seed: 1337 78 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_baseline.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_baseline 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: baseline 24 | cfvqa_params: 25 | mlp_q: 26 | input_dim: 4800 27 | dimensions: [1024,1024,3000] 28 | mlp_v: 29 | input_dim: 2048 30 | dimensions: [1024,1024,3000] 31 | txt_enc: 32 | name: skipthoughts 33 | type: BayesianUniSkip 34 | dropout: 0.25 35 | fixed_emb: False 36 | dir_st: data/skip-thoughts 37 | self_q_att: True 38 | residual: False 39 | q_single: False 40 | fusion: 41 | type: block 42 | input_dims: [4800, 2048] 43 | output_dim: 2048 44 | mm_dim: 1000 45 | chunks: 20 46 | rank: 15 47 | dropout_input: 0. 48 | dropout_pre_lin: 0. 
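# A note on the optimizer block shared by all of these configs: judging by the
# inline hints, 'gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace' is
# consumed as torch.linspace(0.5, 2.0, 7), i.e. the base lr of 0.0003 is scaled
# by 0.5, 0.75, ..., 2.0 over the first seven epochs, and
# 'lr_decay_epochs: [14, 24, 2] #range' as range(14, 24, 2), i.e. the lr is
# multiplied by lr_decay_rate (0.25) at epochs 14, 16, 18, 20 and 22.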
49 | agg: 50 | type: max 51 | classif: 52 | mlp: 53 | input_dim: 2048 54 | dimensions: [1024,1024,3000] 55 | criterion: 56 | import: cfvqa.models.criterions.factory 57 | name: vqa_cross_entropy 58 | question_loss_weight: 1.0 59 | vision_loss_weight: 1.0 60 | loss_temp: 100.0 61 | metric: 62 | import: cfvqa.models.metrics.factory 63 | name: vqa_accuracies 64 | optimizer: 65 | import: cfvqa.optimizers.factory 66 | name: Adam 67 | lr: 0.0003 68 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 69 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 70 | lr_decay_epochs: [14, 24, 2] #range 71 | lr_decay_rate: .25 72 | engine: 73 | name: logger 74 | debug: False 75 | print_freq: 10 76 | nb_epochs: 22 77 | saving_criteria: 78 | - eval_epoch.accuracy_top1:max 79 | misc: 80 | logs_name: 81 | cuda: True 82 | seed: 1337 83 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_cfvqa_sum.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_cfvqa_sum 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: cfvqa 24 | fusion_mode: sum 25 | cfvqa_params: 26 | mlp_q: 27 | input_dim: 4800 28 | dimensions: [1024,1024,3000] 29 | mlp_v: 30 | input_dim: 2048 31 | dimensions: [1024,1024,3000] 32 | txt_enc: 33 | name: skipthoughts 34 | type: BayesianUniSkip 35 | dropout: 0.25 36 | fixed_emb: False 37 | dir_st: data/skip-thoughts 38 | self_q_att: True 39 | residual: False 40 | q_single: False 41 | fusion: 42 | type: block 43 | input_dims: [4800, 2048] 44 | output_dim: 2048 45 | mm_dim: 1000 46 | chunks: 20 47 | rank: 15 48 | dropout_input: 0. 49 | dropout_pre_lin: 0. 
50 | agg: 51 | type: max 52 | classif: 53 | mlp: 54 | input_dim: 2048 55 | dimensions: [1024,1024,3000] 56 | criterion: 57 | import: cfvqa.models.criterions.factory 58 | name: cfvqa_criterion 59 | question_loss_weight: 1.0 60 | vision_loss_weight: 1.0 61 | metric: 62 | import: cfvqa.models.metrics.factory 63 | name: vqa_cfvqa_metrics 64 | optimizer: 65 | import: cfvqa.optimizers.factory 66 | name: Adam 67 | lr: 0.0003 68 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 69 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 70 | lr_decay_epochs: [14, 24, 2] #range 71 | lr_decay_rate: .25 72 | engine: 73 | name: logger 74 | debug: False 75 | print_freq: 10 76 | nb_epochs: 22 77 | saving_criteria: 78 | - eval_epoch.accuracy_all_top1:max 79 | - eval_epoch.accuracy_vq_top1:max 80 | - eval_epoch.accuracy_cfvqa_top1:max 81 | misc: 82 | logs_name: 83 | cuda: True 84 | seed: 1337 85 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_cfvqaintrod_sum.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_cfvqaintrod_sum 3 | resume: last # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: cfvqaintrod 24 | fusion_mode: sum 25 | cfvqa_params: 26 | mlp_q: 27 | input_dim: 4800 28 | dimensions: [1024,1024,3000] 29 | mlp_v: 30 | input_dim: 2048 31 | dimensions: [1024,1024,3000] 32 | txt_enc: 33 | name: skipthoughts 34 | type: BayesianUniSkip 35 | dropout: 0.25 36 | fixed_emb: False 37 | dir_st: data/skip-thoughts 38 | self_q_att: True 39 | residual: False 40 | q_single: False 41 | fusion: 42 | type: block 43 | input_dims: [4800, 2048] 44 | output_dim: 2048 45 | mm_dim: 1000 46 | chunks: 20 47 | rank: 15 48 | dropout_input: 0. 49 | dropout_pre_lin: 0. 
50 | agg: 51 | type: max 52 | classif: 53 | mlp: 54 | input_dim: 2048 55 | dimensions: [1024,1024,3000] 56 | criterion: 57 | import: cfvqa.models.criterions.factory 58 | name: cfvqaintrod_criterion 59 | metric: 60 | import: cfvqa.models.metrics.factory 61 | name: vqa_cfvqaintrod_metrics 62 | optimizer: 63 | import: cfvqa.optimizers.factory 64 | name: Adam 65 | lr: 0.0003 66 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 67 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 68 | lr_decay_epochs: [14, 24, 2] #range 69 | lr_decay_rate: .25 70 | engine: 71 | name: logger 72 | debug: False 73 | print_freq: 10 74 | nb_epochs: 22 75 | saving_criteria: 76 | - eval_epoch.accuracy_stu_top1:max 77 | misc: 78 | logs_name: 79 | cuda: True 80 | seed: 1337 81 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_cfvqasimple_rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_cfvqasimple_rubi 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: cfvqasimple 24 | fusion_mode: rubi 25 | cfvqa_params: 26 | mlp_q: 27 | input_dim: 4800 28 | dimensions: [1024,1024,3000] 29 | mlp_v: 30 | input_dim: 2048 31 | dimensions: [1024,1024,3000] 32 | txt_enc: 33 | name: skipthoughts 34 | type: BayesianUniSkip 35 | dropout: 0.25 36 | fixed_emb: False 37 | dir_st: data/skip-thoughts 38 | self_q_att: True 39 | residual: False 40 | q_single: False 41 | fusion: 42 | type: block 43 | input_dims: [4800, 2048] 44 | output_dim: 2048 45 | mm_dim: 1000 46 | chunks: 20 47 | rank: 15 48 | dropout_input: 0. 49 | dropout_pre_lin: 0. 
50 | agg: 51 | type: max 52 | classif: 53 | mlp: 54 | input_dim: 2048 55 | dimensions: [1024,1024,3000] 56 | criterion: 57 | import: cfvqa.models.criterions.factory 58 | name: cfvqasimple_criterion 59 | question_loss_weight: 1.0 60 | metric: 61 | import: cfvqa.models.metrics.factory 62 | name: vqa_cfvqasimple_metrics 63 | optimizer: 64 | import: cfvqa.optimizers.factory 65 | name: Adam 66 | lr: 0.0003 67 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 68 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 69 | lr_decay_epochs: [14, 24, 2] #range 70 | lr_decay_rate: .25 71 | engine: 72 | name: logger 73 | debug: False 74 | print_freq: 10 75 | nb_epochs: 22 76 | saving_criteria: 77 | - eval_epoch.accuracy_all_top1:max 78 | - eval_epoch.accuracy_vq_top1:max 79 | - eval_epoch.accuracy_cfvqa_top1:max 80 | misc: 81 | logs_name: 82 | cuda: True 83 | seed: 1337 84 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_cfvqasimpleintrod_rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_cfvqasimpleintrod_rubi 3 | resume: last # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: cfvqasimpleintrod 24 | fusion_mode: rubi 25 | is_vq: False 26 | cfvqa_params: 27 | mlp_q: 28 | input_dim: 4800 29 | dimensions: [1024,1024,3000] 30 | txt_enc: 31 | name: skipthoughts 32 | type: BayesianUniSkip 33 | dropout: 0.25 34 | fixed_emb: False 35 | dir_st: data/skip-thoughts 36 | self_q_att: True 37 | residual: False 38 | q_single: False 39 | fusion: 40 | type: block 41 | input_dims: [4800, 2048] 42 | output_dim: 2048 43 | mm_dim: 1000 44 | chunks: 20 45 | rank: 15 46 | dropout_input: 0. 47 | dropout_pre_lin: 0. 
48 | agg: 49 | type: max 50 | classif: 51 | mlp: 52 | input_dim: 2048 53 | dimensions: [1024,1024,3000] 54 | criterion: 55 | import: cfvqa.models.criterions.factory 56 | name: cfvqaintrod_criterion 57 | metric: 58 | import: cfvqa.models.metrics.factory 59 | name: vqa_cfvqaintrod_metrics 60 | optimizer: 61 | import: cfvqa.optimizers.factory 62 | name: Adam 63 | lr: 0.0003 64 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 65 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 66 | lr_decay_epochs: [14, 24, 2] #range 67 | lr_decay_rate: .25 68 | engine: 69 | name: logger 70 | debug: False 71 | print_freq: 10 72 | nb_epochs: 22 73 | saving_criteria: 74 | - eval_epoch.accuracy_stu_top1:max 75 | misc: 76 | logs_name: 77 | cuda: True 78 | seed: 1337 79 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_rubi.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_rubi 3 | resume: # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: rubi 24 | rubi_params: 25 | mlp_q: 26 | input_dim: 4800 27 | dimensions: [1024,1024,3000] 28 | txt_enc: 29 | name: skipthoughts 30 | type: BayesianUniSkip 31 | dropout: 0.25 32 | fixed_emb: False 33 | dir_st: data/skip-thoughts 34 | self_q_att: True 35 | residual: False 36 | q_single: False 37 | fusion: 38 | type: block 39 | input_dims: [4800, 2048] 40 | output_dim: 2048 41 | mm_dim: 1000 42 | chunks: 20 43 | rank: 15 44 | dropout_input: 0. 45 | dropout_pre_lin: 0. 
46 | agg: 47 | type: max 48 | classif: 49 | mlp: 50 | input_dim: 2048 51 | dimensions: [1024,1024,3000] 52 | criterion: 53 | import: cfvqa.models.criterions.factory 54 | name: rubi_criterion 55 | question_loss_weight: 1.0 56 | metric: 57 | import: cfvqa.models.metrics.factory 58 | name: vqa_rubi_metrics 59 | optimizer: 60 | import: cfvqa.optimizers.factory 61 | name: Adam 62 | lr: 0.0003 63 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 64 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 65 | lr_decay_epochs: [14, 24, 2] #range 66 | lr_decay_rate: .25 67 | engine: 68 | name: logger 69 | debug: False 70 | print_freq: 10 71 | nb_epochs: 22 72 | saving_criteria: 73 | - eval_epoch.accuracy_top1:max 74 | - eval_epoch.accuracy_all_top1:max 75 | misc: 76 | logs_name: 77 | cuda: True 78 | seed: 1337 79 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/options/vqacp2/smrl_rubiintrod.yaml: -------------------------------------------------------------------------------- 1 | exp: 2 | dir: logs/vqacp2/smrl_rubiintrod 3 | resume: last # last, best_[...], or empty (from scratch) 4 | dataset: 5 | import: cfvqa.datasets.factory 6 | name: vqacp2 # or vqa2vg 7 | dir: data/vqa/vqacp2 8 | train_split: train 9 | eval_split: val # or test 10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split) 11 | nb_threads: 4 12 | batch_size: 256 13 | nans: 3000 14 | minwcount: 0 15 | nlp: mcb 16 | samplingans: True 17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36 18 | model: 19 | name: default 20 | network: 21 | import: cfvqa.models.networks.factory 22 | base: smrl 23 | name: rubiintrod 24 | rubi_params: 25 | mlp_q: 26 | input_dim: 4800 27 | dimensions: [1024,1024,3000] 28 | txt_enc: 29 | name: skipthoughts 30 | type: BayesianUniSkip 31 | dropout: 0.25 32 | fixed_emb: False 33 | dir_st: data/skip-thoughts 34 | self_q_att: True 35 | residual: False 36 | q_single: False 37 | fusion: 38 | type: block 39 | input_dims: [4800, 2048] 40 | output_dim: 2048 41 | mm_dim: 1000 42 | chunks: 20 43 | rank: 15 44 | dropout_input: 0. 45 | dropout_pre_lin: 0. 
46 | agg: 47 | type: max 48 | classif: 49 | mlp: 50 | input_dim: 2048 51 | dimensions: [1024,1024,3000] 52 | criterion: 53 | import: cfvqa.models.criterions.factory 54 | name: rubiintrod_criterion 55 | question_loss_weight: 1.0 56 | metric: 57 | import: cfvqa.models.metrics.factory 58 | name: vqa_rubiintrod_metrics 59 | optimizer: 60 | import: cfvqa.optimizers.factory 61 | name: Adam 62 | lr: 0.0003 63 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace 64 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace 65 | lr_decay_epochs: [14, 24, 2] #range 66 | lr_decay_rate: .25 67 | engine: 68 | name: logger 69 | debug: False 70 | print_freq: 10 71 | nb_epochs: 22 72 | saving_criteria: 73 | - eval_epoch.accuracy_top1:max 74 | - eval_epoch.accuracy_all_top1:max 75 | misc: 76 | logs_name: 77 | cuda: True 78 | seed: 1337 79 | -------------------------------------------------------------------------------- /cfvqa/cfvqa/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | import traceback 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | 7 | from bootstrap.lib import utils 8 | from bootstrap.lib.logger import Logger 9 | from bootstrap.lib.options import Options 10 | from cfvqa import engines 11 | from bootstrap import datasets 12 | from bootstrap import models 13 | from bootstrap import optimizers 14 | from bootstrap import views 15 | 16 | 17 | def init_experiment_directory(exp_dir, resume=None): 18 | # create the experiment directory 19 | if not os.path.isdir(exp_dir): 20 | os.system('mkdir -p ' + exp_dir) 21 | else: 22 | if resume is None: 23 | if click.confirm('Exp directory already exists in {}. Erase?' 24 | .format(exp_dir), default=False): 25 | os.system('rm -r ' + exp_dir) 26 | os.system('mkdir -p ' + exp_dir) 27 | else: 28 | os._exit(1) 29 | 30 | 31 | def init_logs_options_files(exp_dir, resume=None): 32 | # get the logs name which is used for the txt, json and yaml files 33 | # default is `logs.txt`, `logs.json` and `options.yaml` 34 | if 'logs_name' in Options()['misc'] and Options()['misc']['logs_name'] is not None: 35 | logs_name = 'logs_{}'.format(Options()['misc']['logs_name']) 36 | path_yaml = os.path.join(exp_dir, 'options_{}.yaml'.format(logs_name)) 37 | elif resume and Options()['dataset']['train_split'] is None: 38 | eval_split = Options()['dataset']['eval_split'] 39 | path_yaml = os.path.join(exp_dir, 'options_eval_{}.yaml'.format(eval_split)) 40 | logs_name = 'logs_eval_{}'.format(eval_split) 41 | else: 42 | path_yaml = os.path.join(exp_dir, 'options.yaml') 43 | logs_name = 'logs' 44 | 45 | # create the options.yaml file 46 | if not os.path.isfile(path_yaml): 47 | Options().save(path_yaml) 48 | 49 | # create the logs.txt and logs.json files 50 | Logger(exp_dir, name=logs_name) 51 | 52 | 53 | def run(path_opts=None): 54 | # the first call to Options() loads the options yaml file from the --path_opts command line argument if path_opts=None 55 | Options(path_opts) 56 | # initialize seeds to be able to reproduce the experiment on reload 57 | utils.set_random_seed(Options()['misc']['seed']) 58 | 59 | init_experiment_directory(Options()['exp']['dir'], Options()['exp']['resume']) 60 | init_logs_options_files(Options()['exp']['dir'], Options()['exp']['resume']) 61 | 62 | Logger().log_dict('options', Options(), should_print=True) # display options 63 | Logger()(os.uname()) # display server name 64 | 65 | if torch.cuda.is_available(): 66 | cudnn.benchmark = True 67 | Logger()('Available GPUs: {}'.format(utils.available_gpu_ids()))
68 | 69 | # engine can train, eval, optimize the model 70 | # engine can save and load the model and optimizer 71 | engine = engines.factory() 72 | 73 | # dataset is a dictionary that contains all the needed datasets indexed by modes 74 | # (example: dataset.keys() -> ['train','eval']) 75 | engine.dataset = datasets.factory(engine) 76 | 77 | # model includes a network, a criterion and a metric 78 | # model can register engine hooks (begin epoch, end batch, end epoch, etc.) 79 | # (example: "calculate mAP at the end of the evaluation epoch") 80 | # note: model can access the datasets using engine.dataset 81 | engine.model = models.factory(engine) 82 | 83 | # optimizer can register engine hooks 84 | engine.optimizer = optimizers.factory(engine.model, engine) 85 | 86 | # view will save a view.html in the experiment directory 87 | # with some nice plots and curves to monitor training 88 | engine.view = views.factory(engine) 89 | 90 | # load the model and optimizer from a checkpoint 91 | if Options()['exp']['resume']: 92 | engine.resume() 93 | 94 | # if no training split, evaluate the model on the evaluation split 95 | # (example: $ python main.py --dataset.train_split --dataset.eval_split test) 96 | if not Options()['dataset']['train_split']: 97 | engine.eval() 98 | 99 | # optimize the model on the training split for several epochs 100 | # (example: $ python main.py --dataset.train_split train) 101 | # if an evaluation split is given, evaluate the model after each epoch 102 | # (example: $ python main.py --dataset.train_split train --dataset.eval_split val) 103 | if Options()['dataset']['train_split']: 104 | engine.train() 105 | # with torch.autograd.profiler.profile(use_cuda=Options()['misc.cuda']) as prof: 106 | # engine.train() 107 | # path_tracing = 'tracing_1.0_cuda,{}_all.html'.format(Options()['misc.cuda']) 108 | # prof.export_chrome_trace(path_tracing) 109 | 110 | 111 | def main(path_opts=None, run=None): 112 | try: 113 | run(path_opts=path_opts) 114 | # to avoid a traceback for the -h flag in the command line arguments 115 | except SystemExit: 116 | pass 117 | except: 118 | # to be able to write the error trace to exp_dir/logs.txt 119 | try: 120 | Logger()(traceback.format_exc(), Logger.ERROR) 121 | except: 122 | pass 123 | 124 | 125 | if __name__ == '__main__': 126 | main(run=run) 127 | 128 | -------------------------------------------------------------------------------- /cfvqa/requirements.txt: -------------------------------------------------------------------------------- 1 | block.bootstrap.pytorch 2 | h5py 3 | plotly==3.10.0 -------------------------------------------------------------------------------- /cfvqa/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | import traceback 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | 7 | from bootstrap.lib import utils 8 | from bootstrap.lib.logger import Logger 9 | from bootstrap.lib.options import Options 10 | from cfvqa import engines 11 | from bootstrap import datasets 12 | from bootstrap import models 13 | from bootstrap import optimizers 14 | from bootstrap import views 15 | 16 | 17 | def init_experiment_directory(exp_dir, resume=None): 18 | # create the experiment directory 19 | if not os.path.isdir(exp_dir): 20 | os.system('mkdir -p ' + exp_dir) 21 | else: 22 | if resume is None: 23 | if click.confirm('Exp directory already exists in {}. Erase?'
24 | .format(exp_dir), default=False): 25 | os.system('rm -r ' + exp_dir) 26 | os.system('mkdir -p ' + exp_dir) 27 | else: 28 | os._exit(1) 29 | 30 | 31 | def init_logs_options_files(exp_dir, resume=None): 32 | # get the logs name which is used for the txt, json and yaml files 33 | # default is `logs.txt`, `logs.json` and `options.yaml` 34 | if 'logs_name' in Options()['misc'] and Options()['misc']['logs_name'] is not None: 35 | logs_name = 'logs_{}'.format(Options()['misc']['logs_name']) 36 | path_yaml = os.path.join(exp_dir, 'options_{}.yaml'.format(logs_name)) 37 | elif resume and Options()['dataset']['train_split'] is None: 38 | eval_split = Options()['dataset']['eval_split'] 39 | path_yaml = os.path.join(exp_dir, 'options_eval_{}.yaml'.format(eval_split)) 40 | logs_name = 'logs_eval_{}'.format(eval_split) 41 | else: 42 | path_yaml = os.path.join(exp_dir, 'options.yaml') 43 | logs_name = 'logs' 44 | 45 | # create the options.yaml file 46 | if not os.path.isfile(path_yaml): 47 | Options().save(path_yaml) 48 | 49 | # create the logs.txt and logs.json files 50 | Logger(exp_dir, name=logs_name) 51 | 52 | 53 | def run(path_opts=None): 54 | # the first call to Options() loads the options yaml file from the --path_opts command line argument if path_opts=None 55 | Options(path_opts) 56 | # initialize seeds to be able to reproduce the experiment on reload 57 | utils.set_random_seed(Options()['misc']['seed']) 58 | 59 | init_experiment_directory(Options()['exp']['dir'], Options()['exp']['resume']) 60 | init_logs_options_files(Options()['exp']['dir'], Options()['exp']['resume']) 61 | 62 | Logger().log_dict('options', Options(), should_print=True) # display options 63 | Logger()(os.uname()) # display server name 64 | 65 | if torch.cuda.is_available(): 66 | cudnn.benchmark = True 67 | Logger()('Available GPUs: {}'.format(utils.available_gpu_ids())) 68 | 69 | # engine can train, eval, optimize the model 70 | # engine can save and load the model and optimizer 71 | engine = engines.factory() 72 | 73 | # dataset is a dictionary that contains all the needed datasets indexed by modes 74 | # (example: dataset.keys() -> ['train','eval']) 75 | engine.dataset = datasets.factory(engine) 76 | 77 | # model includes a network, a criterion and a metric 78 | # model can register engine hooks (begin epoch, end batch, end epoch, etc.)
79 | # (example: "calculate mAP at the end of the evaluation epoch") 80 | # note: model can access the datasets using engine.dataset 81 | engine.model = models.factory(engine) 82 | 83 | # optimizer can register engine hooks 84 | engine.optimizer = optimizers.factory(engine.model, engine) 85 | 86 | # view will save a view.html in the experiment directory 87 | # with some nice plots and curves to monitor training 88 | engine.view = views.factory(engine) 89 | 90 | # load the model and optimizer from a checkpoint 91 | if Options()['exp']['resume']: 92 | engine.resume() 93 | 94 | # if no training split, evaluate the model on the evaluation split 95 | # (example: $ python main.py --dataset.train_split --dataset.eval_split test) 96 | if not Options()['dataset']['train_split']: 97 | engine.eval() 98 | 99 | # optimize the model on the training split for several epochs 100 | # (example: $ python main.py --dataset.train_split train) 101 | # if an evaluation split is given, evaluate the model after each epoch 102 | # (example: $ python main.py --dataset.train_split train --dataset.eval_split val) 103 | if Options()['dataset']['train_split']: 104 | engine.train() 105 | # with torch.autograd.profiler.profile(use_cuda=Options()['misc.cuda']) as prof: 106 | # engine.train() 107 | # path_tracing = 'tracing_1.0_cuda,{}_all.html'.format(Options()['misc.cuda']) 108 | # prof.export_chrome_trace(path_tracing) 109 | 110 | 111 | def main(path_opts=None, run=None): 112 | try: 113 | run(path_opts=path_opts) 114 | # to avoid a traceback for the -h flag in the command line arguments 115 | except SystemExit: 116 | pass 117 | except: 118 | # to be able to write the error trace to exp_dir/logs.txt 119 | try: 120 | Logger()(traceback.format_exc(), Logger.ERROR) 121 | except: 122 | pass 123 | 124 | 125 | if __name__ == '__main__': 126 | main(run=run) 127 | 128 | -------------------------------------------------------------------------------- /cfvqa/run_vqa2_cfvqa_introd.sh: -------------------------------------------------------------------------------- 1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_cfvqa_sum.yaml 2 | mkdir -p ./logs/vqa2/smrl_cfvqaintrod_sum/ 3 | cp -r ./logs/vqa2/smrl_cfvqa_sum/* ./logs/vqa2/smrl_cfvqaintrod_sum/ 4 | python -m run -o ./cfvqa/options/vqa2/smrl_cfvqaintrod_sum.yaml -------------------------------------------------------------------------------- /cfvqa/scripts/run_vqa2_cfvqa_introd.sh: -------------------------------------------------------------------------------- 1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_cfvqa_sum.yaml 2 | mkdir -p ./logs/vqa2/smrl_cfvqaintrod_sum/ 3 | cp -r ./logs/vqa2/smrl_cfvqa_sum/* ./logs/vqa2/smrl_cfvqaintrod_sum/ 4 | python -m run -o ./cfvqa/options/vqa2/smrl_cfvqaintrod_sum.yaml -------------------------------------------------------------------------------- /cfvqa/scripts/run_vqa2_rubi_introd.sh: -------------------------------------------------------------------------------- 1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_rubi.yaml 2 | mkdir -p ./logs/vqa2/smrl_rubiintrod/ 3 | cp -r ./logs/vqa2/smrl_rubi/* ./logs/vqa2/smrl_rubiintrod/ 4 | python -m run -o ./cfvqa/options/vqa2/smrl_rubiintrod.yaml -------------------------------------------------------------------------------- /cfvqa/scripts/run_vqa2_rubicf_introd.sh: -------------------------------------------------------------------------------- 1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_cfvqasimple_rubi.yaml 2 | mkdir -p ./logs/vqa2/smrl_cfvqasimpleintrod_rubi/ 3 | cp -r ./logs/vqa2/smrl_cfvqasimple_rubi/* ./logs/vqa2/smrl_cfvqasimpleintrod_rubi/
4 | python -m run -o ./cfvqa/options/vqa2/smrl_cfvqasimpleintrod_rubi.yaml -------------------------------------------------------------------------------- /cfvqa/scripts/run_vqacp2_rubi_introd.sh: -------------------------------------------------------------------------------- 1 | python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_rubi.yaml 2 | mkdir -p ./logs/vqacp2/smrl_rubiintrod/ 3 | cp -r ./logs/vqacp2/smrl_rubi/* ./logs/vqacp2/smrl_rubiintrod/ 4 | python -m run -o ./cfvqa/options/vqacp2/smrl_rubiintrod.yaml -------------------------------------------------------------------------------- /cfvqa/scripts/run_vqacp2_rubicf_introd.sh: -------------------------------------------------------------------------------- 1 | python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_cfvqasimple_rubi.yaml 2 | mkdir -p ./logs/vqacp2/smrl_cfvqasimpleintrod_rubi/ 3 | cp -r ./logs/vqacp2/smrl_cfvqasimple_rubi/* ./logs/vqacp2/smrl_cfvqasimpleintrod_rubi/ 4 | python -m run -o ./cfvqa/options/vqacp2/smrl_cfvqasimpleintrod_rubi.yaml -------------------------------------------------------------------------------- /css/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/ 132 | data 133 | logs/ 134 | logs 135 | -------------------------------------------------------------------------------- /css/README.md: -------------------------------------------------------------------------------- 1 | # CSS+IntroD 2 | 3 | This code provides the implementation of IntroD on top of LMH and CSS for the VQA task. It is modified from [css](https://github.com/yanxinzju/CSS-VQA) and [lmh](https://github.com/chrisc36/bottom-up-attention-vqa). The data is provided by the authors of "Counterfactual Samples Synthesizing for Robust VQA". 4 | 5 | ## Prerequisites 6 | 7 | (Follow [css](https://github.com/yanxinzju/CSS-VQA)) 8 | 9 | Make sure you are on a machine with an NVIDIA GPU and Python 2.7, with about 100 GB of disk space. The required packages are:
10 | h5py==2.10.0
11 | pytorch==1.1.0
12 | Click==7.0
13 | numpy==1.16.5
14 | tqdm==4.35.0
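For reference, a minimal environment sketch matching the versions listed above (the environment name and conda channel are illustrative assumptions; note that `pytorch` is the conda package name, while the pip package is called `torch`):

```
conda create -n introd python=2.7
conda activate introd
conda install pytorch=1.1.0 -c pytorch
pip install h5py==2.10.0 Click==7.0 numpy==1.16.5 tqdm==4.35.0
```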
15 | 16 | ## Data Setup 17 | 18 | (The data is provided by [css](https://github.com/yanxinzju/CSS-VQA).) 19 | 20 | You can use 21 | ``` 22 | bash tools/download.sh 23 | ``` 24 | to download part of the data.
25 | The rest of the data and the trained models can be obtained from [BaiduYun](https://pan.baidu.com/s/1oHdwYDSJXC1mlmvu8cQhKw) (password: 3jot) or [GoogleDrive](https://drive.google.com/drive/folders/13e-b76otJukupbjfC-n1s05L202PaFKQ?usp=sharing). 26 | Unzip feature1.zip and feature2.zip and merge their contents into data/rcnn_feature/.
27 | Then use 28 | ``` 29 | bash tools/process.sh 30 | ``` 31 | to process the data.
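After processing, a quick sanity check can confirm that the expected artifacts are in place. This is a minimal sketch; the paths are the ones referenced by the training code further below, so adjust them if your layout differs:

```
import os

# Artifacts that main.py and the feature setup above expect to find.
for path in ['data/dictionary.pkl',
             'data/glove6b_init_300d.npy',
             'data/rcnn_feature']:
    print('%s: %s' % (path, 'OK' if os.path.exists(path) else 'MISSING'))
```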
32 | 33 | ## CSS+IntroD 34 | 35 | ### Step 1: train the teacher model 36 | 37 | (Follow [css](https://github.com/yanxinzju/CSS-VQA)) 38 | 39 | Run 40 | ``` 41 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset cpv2 --mode q_v_debias --debias learned_mixin --topq 1 --topv -1 --qvp 5 --output ./logs/vqacp2/css/ 42 | ``` 43 | to train a teacher model on VQA-CP v2. 44 | 45 | Or 46 | 47 | Run 48 | ``` 49 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset v2 --mode q_v_debias --debias learned_mixin --topq 1 --topv -1 --qvp 5 --output ./logs/vqa2/css/ 50 | ``` 51 | to train a teacher model on VQA v2. 52 | 53 | ### Step 2: train the student model 54 | 55 | Run 56 | ``` 57 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset cpv2 --output ./logs/vqacp2/css_introd/ --source ./logs/vqacp2/css/ 58 | ``` 59 | to train a student model on VQA-CP v2. 60 | 61 | Or 62 | 63 | Run 64 | ``` 65 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset v2 --output ./logs/vqa2/css_introd/ --source ./logs/vqa2/css/ 66 | ``` 67 | to train a student model on VQA v2. 68 | 69 | ## LMH+IntroD 70 | 71 | ### Step 1: train the teacher model 72 | 73 | (Follow [css](https://github.com/yanxinzju/CSS-VQA)) 74 | 75 | Run 76 | ``` 77 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset cpv2 --mode updn --debias learned_mixin --output ./logs/vqacp2/lmh/ 78 | ``` 79 | to train a teacher model on VQA-CP v2. 80 | 81 | Or 82 | 83 | Run 84 | ``` 85 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset v2 --mode updn --debias learned_mixin --output ./logs/vqa2/lmh/ 86 | ``` 87 | to train a teacher model on VQA v2. 88 | 89 | 90 | ### Step 2: train the student model 91 | 92 | Run 93 | ``` 94 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset cpv2 --output ./logs/vqacp2/lmh_introd/ --source ./logs/vqacp2/lmh/ 95 | ``` 96 | to train a student model on VQA-CP v2. 97 | 98 | Or 99 | 100 | Run 101 | ``` 102 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset v2 --output ./logs/vqa2/lmh_introd/ --source ./logs/vqa2/lmh/ 103 | ``` 104 | to train a student model on VQA v2.
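A trained checkpoint can then be scored with `eval.py` (shown further below). A hypothetical invocation for the LMH teacher, assuming the checkpoint was saved as `model.pth` in its output directory (the filename `main_introd.py` loads from `--source`):

```
CUDA_VISIBLE_DEVICES=0 python eval.py --dataset cpv2 --debias learned_mixin --model_state ./logs/vqacp2/lmh/model.pth
```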
-------------------------------------------------------------------------------- /css/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | from fc import FCNet 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, v_dim, q_dim, num_hid): 9 | super(Attention, self).__init__() 10 | self.nonlinear = FCNet([v_dim + q_dim, num_hid]) 11 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None) 12 | 13 | def forward(self, v, q): 14 | """ 15 | v: [batch, k, vdim] 16 | q: [batch, qdim] 17 | """ 18 | logits = self.logits(v, q) 19 | w = nn.functional.softmax(logits, 1) 20 | return w 21 | 22 | def logits(self, v, q): 23 | num_objs = v.size(1) 24 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 25 | vq = torch.cat((v, q), 2) 26 | joint_repr = self.nonlinear(vq) 27 | logits = self.linear(joint_repr) 28 | return logits 29 | 30 | 31 | class NewAttention(nn.Module): 32 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2): 33 | super(NewAttention, self).__init__() 34 | 35 | self.v_proj = FCNet([v_dim, num_hid]) 36 | self.q_proj = FCNet([q_dim, num_hid]) 37 | self.dropout = nn.Dropout(dropout) 38 | self.linear = weight_norm(nn.Linear(q_dim, 1), dim=None) 39 | 40 | def forward(self, v, q): 41 | """ 42 | v: [batch, k, vdim] 43 | q: [batch, qdim] 44 | """ 45 | logits = self.logits(v, q) 46 | # w = nn.functional.softmax(logits, 1) 47 | # return w 48 | return logits 49 | 50 | def logits(self, v, q): 51 | batch, k, _ = v.size() 52 | v_proj = self.v_proj(v) # [batch, k, qdim] 53 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) 54 | joint_repr = v_proj * q_proj 55 | joint_repr = self.dropout(joint_repr) 56 | logits = self.linear(joint_repr) 57 | return logits 58 | -------------------------------------------------------------------------------- /css/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attention import Attention, NewAttention 4 | from language_model import WordEmbedding, QuestionEmbedding 5 | from classifier import SimpleClassifier 6 | from fc import FCNet 7 | import numpy as np 8 | 9 | def mask_softmax(x,mask): 10 | mask=mask.unsqueeze(2).float() 11 | x2=torch.exp(x-torch.max(x)) 12 | x3=x2*mask 13 | epsilon=1e-5 14 | x3_sum=torch.sum(x3,dim=1,keepdim=True)+epsilon 15 | x4=x3/x3_sum.expand_as(x3) 16 | return x4 17 | 18 | 19 | class BaseModel(nn.Module): 20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier): 21 | super(BaseModel, self).__init__() 22 | self.w_emb = w_emb 23 | self.q_emb = q_emb 24 | self.v_att = v_att 25 | self.q_net = q_net 26 | self.v_net = v_net 27 | self.classifier = classifier 28 | self.debias_loss_fn = None 29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2)) 30 | self.bias_lin = torch.nn.Linear(1024, 1) 31 | 32 | def forward(self, v, q, labels, bias,v_mask): 33 | """Forward 34 | 35 | v: [batch, num_objs, obj_dim] 36 | b: [batch, num_objs, b_dim] 37 | q: [batch_size, seq_length] 38 | 39 | return: logits, not probs 40 | """ 41 | w_emb = self.w_emb(q) 42 | q_emb = self.q_emb(w_emb) # [batch, q_dim] 43 | 44 | att = self.v_att(v, q_emb) 45 | if v_mask is None: 46 | att = nn.functional.softmax(att, 1) 47 | else: 48 | att= mask_softmax(att,v_mask) 49 | 50 | v_emb = (att * v).sum(1) # [batch, v_dim] 51 | 52 | q_repr = self.q_net(q_emb) 53 | v_repr = self.v_net(v_emb) 54 | 
joint_repr = q_repr * v_repr 55 | 56 | logits = self.classifier(joint_repr) 57 | 58 | if labels is not None: 59 | loss = self.debias_loss_fn(joint_repr, logits, bias, labels) 60 | else: 61 | loss = None 62 | return logits, loss,w_emb 63 | 64 | def build_baseline0(dataset, num_hid): 65 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 66 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 67 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid) 68 | q_net = FCNet([num_hid, num_hid]) 69 | v_net = FCNet([dataset.v_dim, num_hid]) 70 | classifier = SimpleClassifier( 71 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5) 72 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) 73 | 74 | 75 | def build_baseline0_newatt(dataset, num_hid): 76 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 77 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 78 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid) 79 | q_net = FCNet([q_emb.num_hid, num_hid]) 80 | v_net = FCNet([dataset.v_dim, num_hid]) 81 | classifier = SimpleClassifier( 82 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5) 83 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) -------------------------------------------------------------------------------- /css/base_model_introd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attention import Attention, NewAttention 4 | from language_model import WordEmbedding, QuestionEmbedding 5 | from classifier import SimpleClassifier 6 | from fc import FCNet 7 | import numpy as np 8 | 9 | def mask_softmax(x,mask): 10 | mask=mask.unsqueeze(2).float() 11 | x2=torch.exp(x-torch.max(x)) 12 | x3=x2*mask 13 | epsilon=1e-5 14 | x3_sum=torch.sum(x3,dim=1,keepdim=True)+epsilon 15 | x4=x3/x3_sum.expand_as(x3) 16 | return x4 17 | 18 | 19 | class BaseModel(nn.Module): 20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier): 21 | super(BaseModel, self).__init__() 22 | self.w_emb = w_emb 23 | self.q_emb = q_emb 24 | self.v_att = v_att 25 | self.q_net = q_net 26 | self.v_net = v_net 27 | self.classifier = classifier 28 | self.debias_loss_fn = None 29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2)) 30 | self.bias_lin = torch.nn.Linear(1024, 1) 31 | 32 | def forward(self, v, q, labels, bias,v_mask): 33 | """Forward 34 | 35 | v: [batch, num_objs, obj_dim] 36 | b: [batch, num_objs, b_dim] 37 | q: [batch_size, seq_length] 38 | 39 | return: logits, not probs 40 | """ 41 | w_emb = self.w_emb(q) 42 | q_emb = self.q_emb(w_emb) # [batch, q_dim] 43 | 44 | att = self.v_att(v, q_emb) 45 | if v_mask is None: 46 | att = nn.functional.softmax(att, 1) 47 | else: 48 | att= mask_softmax(att,v_mask) 49 | 50 | v_emb = (att * v).sum(1) # [batch, v_dim] 51 | 52 | q_repr = self.q_net(q_emb) 53 | v_repr = self.v_net(v_emb) 54 | joint_repr = q_repr * v_repr 55 | 56 | logits = self.classifier(joint_repr) 57 | 58 | if labels is not None: 59 | logits_all, loss = self.debias_loss_fn(joint_repr, logits, bias, labels) 60 | else: 61 | logits_all = None 62 | loss = None 63 | return logits, logits_all, loss, w_emb 64 | 65 | def build_baseline0(dataset, num_hid): 66 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 67 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 68 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid) 69 | q_net = FCNet([num_hid, num_hid]) 70 | v_net = FCNet([dataset.v_dim, 
num_hid]) 71 | classifier = SimpleClassifier( 72 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5) 73 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) 74 | 75 | 76 | def build_baseline0_newatt(dataset, num_hid): 77 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 78 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 79 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid) 80 | q_net = FCNet([q_emb.num_hid, num_hid]) 81 | v_net = FCNet([dataset.v_dim, num_hid]) 82 | classifier = SimpleClassifier( 83 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5) 84 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) -------------------------------------------------------------------------------- /css/classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.weight_norm import weight_norm 3 | 4 | 5 | class SimpleClassifier(nn.Module): 6 | def __init__(self, in_dim, hid_dim, out_dim, dropout): 7 | super(SimpleClassifier, self).__init__() 8 | layers = [ 9 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None), 10 | nn.ReLU(), 11 | nn.Dropout(dropout, inplace=True), 12 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None) 13 | ] 14 | self.main = nn.Sequential(*layers) 15 | 16 | def forward(self, x): 17 | logits = self.main(x) 18 | return logits 19 | -------------------------------------------------------------------------------- /css/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cPickle 4 | from collections import defaultdict, Counter 5 | from os.path import dirname, join 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | import os 12 | 13 | # from new_dataset import Dictionary, VQAFeatureDataset 14 | from dataset import Dictionary, VQAFeatureDataset 15 | import base_model 16 | from train import train 17 | import utils 18 | 19 | from vqa_debias_loss_functions import * 20 | from tqdm import tqdm 21 | from torch.autograd import Variable 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 26 | 27 | # Arguments we added 28 | parser.add_argument( 29 | '--cache_features', default=True, 30 | help="Cache image features in RAM. 
Makes things much faster, " 31 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 32 | parser.add_argument( 33 | '--dataset', default='cpv2', help="Run on VQA-2.0 instead of VQA-CP 2.0") 34 | parser.add_argument( 35 | '-p', "--entropy_penalty", default=0.36, type=float, 36 | help="Entropy regularizer weight for the learned_mixin model") 37 | parser.add_argument( 38 | '--debias', default="learned_mixin", 39 | choices=["learned_mixin", "reweight", "bias_product", "none"], 40 | help="Kind of ensemble loss to use") 41 | # Arguments from the original model, we leave this default, except we 42 | # set --epochs to 15 since the model maxes out its performance on VQA 2.0 well before then 43 | parser.add_argument('--num_hid', type=int, default=1024) 44 | parser.add_argument('--model', type=str, default='baseline0_newatt') 45 | parser.add_argument('--batch_size', type=int, default=512) 46 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 47 | parser.add_argument('--model_state', type=str, default='logs/exp0/model.pth') 48 | args = parser.parse_args() 49 | return args 50 | 51 | def compute_score_with_logits(logits, labels): 52 | # logits = torch.max(logits, 1)[1].data # argmax 53 | logits = torch.argmax(logits,1) 54 | one_hots = torch.zeros(*labels.size()).cuda() 55 | one_hots.scatter_(1, logits.view(-1, 1), 1) 56 | scores = (one_hots * labels) 57 | return scores 58 | 59 | 60 | def evaluate(model,dataloader,qid2type): 61 | score = 0 62 | upper_bound = 0 63 | score_yesno = 0 64 | score_number = 0 65 | score_other = 0 66 | total_yesno = 0 67 | total_number = 0 68 | total_other = 0 69 | model.train(False) 70 | # import pdb;pdb.set_trace() 71 | 72 | 73 | for v, q, a, b,qids,hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 74 | v = Variable(v, requires_grad=False).cuda() 75 | q = Variable(q, requires_grad=False).cuda() 76 | pred, _ ,_= model(v, q, None, None,None) 77 | batch_score= compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1) 78 | score += batch_score.sum() 79 | upper_bound += (a.max(1)[0]).sum() 80 | qids = qids.detach().cpu().int().numpy() 81 | for j in range(len(qids)): 82 | qid=qids[j] 83 | typ = qid2type[str(qid)] 84 | if typ == 'yes/no': 85 | score_yesno += batch_score[j] 86 | total_yesno += 1 87 | elif typ == 'other': 88 | score_other += batch_score[j] 89 | total_other += 1 90 | elif typ == 'number': 91 | score_number += batch_score[j] 92 | total_number += 1 93 | else: 94 | print('Hahahahahahahahahahaha') 95 | score = score / len(dataloader.dataset) 96 | upper_bound = upper_bound / len(dataloader.dataset) 97 | score_yesno /= total_yesno 98 | score_other /= total_other 99 | score_number /= total_number 100 | print('\teval overall score: %.2f' % (100 * score)) 101 | print('\teval up_bound score: %.2f' % (100 * upper_bound)) 102 | print('\teval y/n score: %.2f' % (100 * score_yesno)) 103 | print('\teval other score: %.2f' % (100 * score_other)) 104 | print('\teval number score: %.2f' % (100 * score_number)) 105 | 106 | def evaluate_ai(model,dataloader,qid2type,label2ans): 107 | score=0 108 | upper_bound=0 109 | 110 | ai_top1=0 111 | ai_top2=0 112 | ai_top3=0 113 | 114 | for v, q, a, b, qids, hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 115 | v = Variable(v, requires_grad=False).cuda().float().requires_grad_() 116 | q = Variable(q, requires_grad=False).cuda() 117 | a=a.cuda() 118 | hintscore=hintscore.cuda().float() 119 | pred, _, _ = model(v, q, None, None, None) 120 | 
vqa_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] # [b , 36, 2048] 121 | 122 | vqa_grad_cam=vqa_grad.sum(2) 123 | sv_ind=torch.argmax(vqa_grad_cam,1) 124 | 125 | x_ind_top1=torch.topk(vqa_grad_cam,k=1)[1] 126 | x_ind_top2=torch.topk(vqa_grad_cam,k=2)[1] 127 | x_ind_top3=torch.topk(vqa_grad_cam,k=3)[1] 128 | 129 | y_score_top1 = hintscore.gather(1,x_ind_top1).sum(1)/1 130 | y_score_top2 = hintscore.gather(1,x_ind_top2).sum(1)/2 131 | y_score_top3 = hintscore.gather(1,x_ind_top3).sum(1)/3 132 | 133 | 134 | batch_score=compute_score_with_logits(pred,a.cuda()).cpu().numpy().sum(1) 135 | score+=batch_score.sum() 136 | upper_bound+=(a.max(1)[0]).sum() 137 | qids=qids.detach().cpu().int().numpy() 138 | for j in range(len(qids)): 139 | if batch_score[j]>0: 140 | ai_top1 += y_score_top1[j] 141 | ai_top2 += y_score_top2[j] 142 | ai_top3 += y_score_top3[j] 143 | 144 | 145 | 146 | score = score / len(dataloader.dataset) 147 | upper_bound = upper_bound / len(dataloader.dataset) 148 | ai_top1=(ai_top1.item() * 1.0) / len(dataloader.dataset) 149 | ai_top2=(ai_top2.item() * 1.0) / len(dataloader.dataset) 150 | ai_top3=(ai_top3.item() * 1.0) / len(dataloader.dataset) 151 | 152 | print('\teval overall score: %.2f' % (100 * score)) 153 | print('\teval up_bound score: %.2f' % (100 * upper_bound)) 154 | print('\ttop1_ai_score: %.2f' % (100 * ai_top1)) 155 | print('\ttop2_ai_score: %.2f' % (100 * ai_top2)) 156 | print('\ttop3_ai_score: %.2f' % (100 * ai_top3)) 157 | 158 | def main(): 159 | args = parse_args() 160 | dataset = args.dataset 161 | 162 | 163 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 164 | qid2type=json.load(f) 165 | 166 | if dataset=='cpv1': 167 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 168 | elif dataset=='cpv2' or dataset=='v2': 169 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 170 | 171 | print("Building test dataset...") 172 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 173 | cache_image_features=args.cache_features) 174 | 175 | # Build the model using the original constructor 176 | constructor = 'build_%s' % args.model 177 | model = getattr(base_model, constructor)(eval_dset, args.num_hid).cuda() 178 | 179 | if args.debias == "bias_product": 180 | model.debias_loss_fn = BiasProduct() 181 | elif args.debias == "none": 182 | model.debias_loss_fn = Plain() 183 | elif args.debias == "reweight": 184 | model.debias_loss_fn = ReweightByInvBias() 185 | elif args.debias == "learned_mixin": 186 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty) 187 | else: 188 | raise RuntimeError(args.mode) 189 | 190 | 191 | model_state = torch.load(args.model_state) 192 | model.load_state_dict(model_state) 193 | 194 | 195 | model = model.cuda() 196 | batch_size = args.batch_size 197 | 198 | torch.manual_seed(args.seed) 199 | torch.cuda.manual_seed(args.seed) 200 | torch.backends.cudnn.benchmark = True 201 | 202 | # The original version uses multiple workers, but that just seems slower on my setup 203 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 204 | 205 | 206 | 207 | print("Starting eval...") 208 | 209 | evaluate(model,eval_loader,qid2type) 210 | 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /css/fc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | from 
torch.nn.utils.weight_norm import weight_norm 4 | 5 | 6 | class FCNet(nn.Module): 7 | """Simple class for non-linear fully connect network 8 | """ 9 | def __init__(self, dims): 10 | super(FCNet, self).__init__() 11 | 12 | layers = [] 13 | for i in range(len(dims)-2): 14 | in_dim = dims[i] 15 | out_dim = dims[i+1] 16 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 17 | layers.append(nn.ReLU()) 18 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 19 | layers.append(nn.ReLU()) 20 | 21 | self.main = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.main(x) 25 | 26 | 27 | if __name__ == '__main__': 28 | fc1 = FCNet([10, 20, 10]) 29 | print(fc1) 30 | 31 | print('============') 32 | fc2 = FCNet([10, 20]) 33 | print(fc2) 34 | -------------------------------------------------------------------------------- /css/language_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | class WordEmbedding(nn.Module): 8 | """Word Embedding 9 | 10 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 11 | with the definition in Dictionary. 12 | """ 13 | def __init__(self, ntoken, emb_dim, dropout): 14 | super(WordEmbedding, self).__init__() 15 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 16 | self.dropout = nn.Dropout(dropout) 17 | self.ntoken = ntoken 18 | self.emb_dim = emb_dim 19 | 20 | def init_embedding(self, np_file): 21 | weight_init = torch.from_numpy(np.load(np_file)) 22 | assert weight_init.shape == (self.ntoken, self.emb_dim) 23 | self.emb.weight.data[:self.ntoken] = weight_init 24 | 25 | def forward(self, x): 26 | emb = self.emb(x) 27 | emb = self.dropout(emb) 28 | return emb 29 | 30 | 31 | class QuestionEmbedding(nn.Module): 32 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'): 33 | """Module for question embedding 34 | """ 35 | super(QuestionEmbedding, self).__init__() 36 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 37 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU 38 | 39 | self.rnn = rnn_cls( 40 | in_dim, num_hid, nlayers, 41 | bidirectional=bidirect, 42 | dropout=dropout, 43 | batch_first=True) 44 | 45 | self.in_dim = in_dim 46 | self.num_hid = num_hid 47 | self.nlayers = nlayers 48 | self.rnn_type = rnn_type 49 | self.ndirections = 1 + int(bidirect) 50 | 51 | def init_hidden(self, batch): 52 | # just to get the type of tensor 53 | weight = next(self.parameters()).data 54 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 55 | if self.rnn_type == 'LSTM': 56 | return (Variable(weight.new(*hid_shape).zero_()), 57 | Variable(weight.new(*hid_shape).zero_())) 58 | else: 59 | return Variable(weight.new(*hid_shape).zero_()) 60 | 61 | def forward(self, x): 62 | # x: [batch, sequence, in_dim] 63 | batch = x.size(0) 64 | hidden = self.init_hidden(batch) 65 | self.rnn.flatten_parameters() 66 | output, hidden = self.rnn(x, hidden) 67 | 68 | if self.ndirections == 1: 69 | return output[:, -1] 70 | 71 | forward_ = output[:, -1, :self.num_hid] 72 | backward = output[:, 0, self.num_hid:] 73 | return torch.cat((forward_, backward), dim=1) 74 | 75 | def forward_all(self, x): 76 | # x: [batch, sequence, in_dim] 77 | batch = x.size(0) 78 | hidden = self.init_hidden(batch) 79 | self.rnn.flatten_parameters() 80 | output, hidden = self.rnn(x, hidden) 81 | return output 82 | 
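As a quick illustration of the shapes flowing through `language_model.py`, here is a minimal sketch (toy vocabulary and batch sizes; it only assumes the file is importable as `language_model` from the css/ directory):

```
import torch
from language_model import WordEmbedding, QuestionEmbedding

w_emb = WordEmbedding(ntoken=100, emb_dim=300, dropout=0.0)
q_emb = QuestionEmbedding(300, num_hid=1024, nlayers=1, bidirect=False, dropout=0.0)

q = torch.randint(0, 100, (2, 14))  # [batch, seq_length] token ids
emb = w_emb(q)                      # -> [2, 14, 300] word embeddings
q_repr = q_emb(emb)                 # -> [2, 1024] last GRU hidden state
print(emb.shape)
print(q_repr.shape)
```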
-------------------------------------------------------------------------------- /css/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cPickle as pickle 4 | from collections import defaultdict, Counter 5 | from os.path import dirname, join 6 | import os 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | 13 | from dataset import Dictionary, VQAFeatureDataset 14 | import base_model 15 | from train import train 16 | import utils 17 | import click 18 | 19 | from vqa_debias_loss_functions import * 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 24 | 25 | # Arguments we added 26 | parser.add_argument( 27 | '--cache_features', default=True, 28 | help="Cache image features in RAM. Makes things much faster, " 29 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 30 | parser.add_argument( 31 | '--dataset', default='cpv2', 32 | choices=["v2", "cpv2", "cpv1", "cpv2val"], 33 | help="Run on VQA-2.0 instead of VQA-CP 2.0" 34 | ) 35 | parser.add_argument( 36 | '-p', "--entropy_penalty", default=0.36, type=float, 37 | help="Entropy regularizer weight for the learned_mixin model") 38 | parser.add_argument( 39 | '--mode', default="updn", 40 | choices=["updn", "q_debias","v_debias","q_v_debias"], 41 | help="Kind of ensemble loss to use") 42 | parser.add_argument( 43 | '--debias', default="learned_mixin", 44 | choices=["learned_mixin_rw2", "learned_mixin_rw", "learned_mixin", "reweight", "bias_product", "none",'focal'], 45 | help="Kind of ensemble loss to use") 46 | parser.add_argument( 47 | '--topq', type=int,default=1, 48 | choices=[1,2,3], 49 | help="num of words to be masked in questio") 50 | parser.add_argument( 51 | '--keep_qtype', default=True, 52 | help="keep qtype or not") 53 | parser.add_argument( 54 | '--topv', type=int,default=1, 55 | choices=[1,3,5,-1], 56 | help="num of object bbox to be masked in image") 57 | parser.add_argument( 58 | '--top_hint',type=int, default=9, 59 | choices=[9,18,27,36], 60 | help="num of hint") 61 | parser.add_argument( 62 | '--qvp', type=int,default=0, 63 | choices=[0,1,2,3,4,5,6,7,8,9,10], 64 | help="ratio of q_bias and v_bias") 65 | parser.add_argument( 66 | '--eval_each_epoch', default=True, 67 | help="Evaluate every epoch, instead of at the end") 68 | 69 | # Arguments from the original model, we leave this default, except we 70 | # set --epochs to 30 since the model maxes out its performance on VQA 2.0 well before then 71 | # parser.add_argument('--epochs', type=int, default=30) 72 | parser.add_argument('--epochs', type=int, default=15) 73 | parser.add_argument('--num_hid', type=int, default=1024) 74 | parser.add_argument('--model', type=str, default='baseline0_newatt') 75 | parser.add_argument('--output', type=str, default='logs/exp0') 76 | parser.add_argument('--batch_size', type=int, default=512) 77 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 78 | parser.add_argument('--feature', type=str, default='css') 79 | args = parser.parse_args() 80 | return args 81 | 82 | def get_bias(train_dset,eval_dset): 83 | # Compute the bias: 84 | # The bias here is just the expected score for each answer/question type 85 | answer_voc_size = train_dset.num_ans_candidates 86 | 87 | # question_type -> answer -> total score 88 | question_type_to_probs = defaultdict(Counter) 89 | 90 | # 
question_type -> num_occurances 91 | question_type_to_count = Counter() 92 | for ex in train_dset.entries: 93 | ans = ex["answer"] 94 | q_type = ans["question_type"] 95 | question_type_to_count[q_type] += 1 96 | if ans["labels"] is not None: 97 | for label, score in zip(ans["labels"], ans["scores"]): 98 | question_type_to_probs[q_type][label] += score 99 | question_type_to_prob_array = {} 100 | 101 | for q_type, count in question_type_to_count.items(): 102 | prob_array = np.zeros(answer_voc_size, np.float32) 103 | for label, total_score in question_type_to_probs[q_type].items(): 104 | prob_array[label] += total_score 105 | prob_array /= count 106 | question_type_to_prob_array[q_type] = prob_array 107 | 108 | for ds in [train_dset,eval_dset]: 109 | for ex in ds.entries: 110 | q_type = ex["answer"]["question_type"] 111 | ex["bias"] = question_type_to_prob_array[q_type] 112 | 113 | 114 | def main(): 115 | args = parse_args() 116 | dataset=args.dataset 117 | # args.output=os.path.join('logs',args.output) 118 | if not os.path.isdir(args.output): 119 | utils.create_dir(args.output) 120 | else: 121 | if click.confirm('Exp directory already exists in {}. Erase?' 122 | .format(args.output, default=False)): 123 | os.system('rm -r ' + args.output) 124 | utils.create_dir(args.output) 125 | 126 | else: 127 | os._exit(1) 128 | 129 | if dataset=='cpv1': 130 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 131 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val': 132 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 133 | 134 | print("Building train dataset...") 135 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset, 136 | cache_image_features=args.cache_features) 137 | 138 | print("Building test dataset...") 139 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 140 | cache_image_features=args.cache_features) 141 | 142 | get_bias(train_dset,eval_dset) 143 | 144 | 145 | # Build the model using the original constructor 146 | constructor = 'build_%s' % args.model 147 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda() 148 | if dataset=='cpv1': 149 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy') 150 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val': 151 | model.w_emb.init_embedding('data/glove6b_init_300d.npy') 152 | 153 | # Add the loss_fn based our arguments 154 | if args.debias == "bias_product": 155 | model.debias_loss_fn = BiasProduct() 156 | elif args.debias == "none": 157 | model.debias_loss_fn = Plain() 158 | elif args.debias == "reweight": 159 | model.debias_loss_fn = ReweightByInvBias() 160 | elif args.debias == "learned_mixin": 161 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty) 162 | elif args.debias == "focal": 163 | model.debias_loss_fn = Focal() 164 | else: 165 | raise RuntimeError(args.mode) 166 | 167 | 168 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 169 | qid2type=json.load(f) 170 | model=model.cuda() 171 | batch_size = args.batch_size 172 | 173 | torch.manual_seed(args.seed) 174 | torch.cuda.manual_seed(args.seed) 175 | torch.backends.cudnn.benchmark = True 176 | 177 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0) 178 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 179 | 180 | print("Starting training...") 181 | train(model, train_loader, eval_loader, args,qid2type) 182 | 183 | if __name__ == '__main__': 184 | main() 185 | 186 | 187 | 188 | 189 | 190 | 191 | 
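The `get_bias` helper above boils down to the expected score of every answer for each question type; a toy walk-through of the same arithmetic (all names and numbers illustrative):

```
import numpy as np
from collections import Counter, defaultdict

# Two 'how many' training entries over an answer vocabulary of size 3.
entries = [
    {"answer": {"question_type": "how many", "labels": [0], "scores": [1.0]}},
    {"answer": {"question_type": "how many", "labels": [0, 1], "scores": [0.6, 0.4]}},
]
counts = Counter()
probs = defaultdict(Counter)
for ex in entries:
    ans = ex["answer"]
    counts[ans["question_type"]] += 1
    for label, score in zip(ans["labels"], ans["scores"]):
        probs[ans["question_type"]][label] += score

prior = np.zeros(3, np.float32)
for label, total in probs["how many"].items():
    prior[label] = total
prior /= counts["how many"]
print(prior)  # -> 0.8, 0.2, 0.0: the "bias" attached to every 'how many' entry
```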
-------------------------------------------------------------------------------- /css/main_introd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cPickle as pickle 4 | from collections import defaultdict, Counter 5 | from os.path import dirname, join 6 | import os 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | 13 | from dataset import Dictionary, VQAFeatureDataset 14 | import base_model_introd as base_model 15 | from train_introd import train 16 | import utils 17 | import click 18 | 19 | from vqa_debias_loss_functions import * 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 24 | 25 | # Arguments we added 26 | parser.add_argument( 27 | '--cache_features', default=True, 28 | help="Cache image features in RAM. Makes things much faster, " 29 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 30 | parser.add_argument( 31 | '--dataset', default='cpv2', 32 | choices=["v2", "cpv2", "cpv1", "cpv2val"], 33 | help="Run on VQA-2.0 instead of VQA-CP 2.0" 34 | ) 35 | parser.add_argument( 36 | '-p', "--entropy_penalty", default=0.36, type=float, 37 | help="Entropy regularizer weight for the learned_mixin model") 38 | parser.add_argument( 39 | '--mode', default="updn", 40 | choices=["updn", "q_debias","v_debias","q_v_debias"], 41 | help="Kind of ensemble loss to use") 42 | parser.add_argument( 43 | '--debias', default="learned_mixin", 44 | choices=["learned_mixin_rw2", "learned_mixin_rw", "learned_mixin", "reweight", "bias_product", "none",'focal'], 45 | help="Kind of ensemble loss to use") 46 | parser.add_argument( 47 | '--topq', type=int,default=1, 48 | choices=[1,2,3], 49 | help="num of words to be masked in questio") 50 | parser.add_argument( 51 | '--keep_qtype', default=True, 52 | help="keep qtype or not") 53 | parser.add_argument( 54 | '--topv', type=int,default=1, 55 | choices=[1,3,5,-1], 56 | help="num of object bbox to be masked in image") 57 | parser.add_argument( 58 | '--top_hint',type=int, default=9, 59 | choices=[9,18,27,36], 60 | help="num of hint") 61 | parser.add_argument( 62 | '--qvp', type=int,default=0, 63 | choices=[0,1,2,3,4,5,6,7,8,9,10], 64 | help="ratio of q_bias and v_bias") 65 | parser.add_argument( 66 | '--eval_each_epoch', default=True, 67 | help="Evaluate every epoch, instead of at the end") 68 | 69 | # Arguments from the original model, we leave this default, except we 70 | # set --epochs to 30 since the model maxes out its performance on VQA 2.0 well before then 71 | parser.add_argument('--epochs', type=int, default=30) 72 | parser.add_argument('--num_hid', type=int, default=1024) 73 | parser.add_argument('--model', type=str, default='baseline0_newatt') 74 | parser.add_argument('--source', type=str, default='./logs/vqacp2/css/') 75 | parser.add_argument('--output', type=str, default='logs/exp0') 76 | parser.add_argument('--batch_size', type=int, default=512) 77 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 78 | args = parser.parse_args() 79 | return args 80 | 81 | def get_bias(train_dset,eval_dset): 82 | # Compute the bias: 83 | # The bias here is just the expected score for each answer/question type 84 | answer_voc_size = train_dset.num_ans_candidates 85 | 86 | # question_type -> answer -> total score 87 | question_type_to_probs = defaultdict(Counter) 88 | 89 | # question_type -> 
num_occurrences 90 | question_type_to_count = Counter() 91 | for ex in train_dset.entries: 92 | ans = ex["answer"] 93 | q_type = ans["question_type"] 94 | question_type_to_count[q_type] += 1 95 | if ans["labels"] is not None: 96 | for label, score in zip(ans["labels"], ans["scores"]): 97 | question_type_to_probs[q_type][label] += score 98 | question_type_to_prob_array = {} 99 | 100 | for q_type, count in question_type_to_count.items(): 101 | prob_array = np.zeros(answer_voc_size, np.float32) 102 | for label, total_score in question_type_to_probs[q_type].items(): 103 | prob_array[label] += total_score 104 | prob_array /= count 105 | question_type_to_prob_array[q_type] = prob_array 106 | 107 | for ds in [train_dset,eval_dset]: 108 | for ex in ds.entries: 109 | q_type = ex["answer"]["question_type"] 110 | ex["bias"] = question_type_to_prob_array[q_type] 111 | 112 | 113 | def main(): 114 | args = parse_args() 115 | dataset=args.dataset 116 | # args.output=os.path.join('logs',args.output) 117 | if not os.path.isdir(args.output): 118 | utils.create_dir(args.output) 119 | else: 120 | if click.confirm('Exp directory already exists in {}. Erase?' 121 | .format(args.output), default=False): 122 | os.system('rm -r ' + args.output) 123 | utils.create_dir(args.output) 124 | 125 | else: 126 | os._exit(1) 127 | 128 | if dataset=='cpv1': 129 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 130 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val': 131 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 132 | 133 | print("Building train dataset...") 134 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset, 135 | cache_image_features=args.cache_features) 136 | 137 | print("Building test dataset...") 138 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 139 | cache_image_features=args.cache_features) 140 | 141 | get_bias(train_dset,eval_dset) 142 | 143 | # Build the model using the original constructor 144 | constructor = 'build_%s' % args.model 145 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda() 146 | if dataset=='cpv1': 147 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy') 148 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val': 149 | model.w_emb.init_embedding('data/glove6b_init_300d.npy') 150 | 151 | model_student = getattr(base_model, constructor)(train_dset, args.num_hid).cuda() 152 | if dataset=='cpv1': 153 | model_student.w_emb.init_embedding('data/glove6b_init_300d_v1.npy') 154 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val': 155 | model_student.w_emb.init_embedding('data/glove6b_init_300d.npy') 156 | 157 | state_dict = torch.load(join(args.source, "model.pth")) 158 | model.debias_loss_fn = LearnedMixinKD() 159 | model.load_state_dict(state_dict, strict=False) 160 | 161 | model_student.debias_loss_fn = PlainKD() 162 | model.train(False) 163 | 164 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 165 | qid2type=json.load(f) 166 | model=model.cuda() 167 | model_student=model_student.cuda() 168 | batch_size = args.batch_size 169 | 170 | torch.manual_seed(args.seed) 171 | torch.cuda.manual_seed(args.seed) 172 | torch.backends.cudnn.benchmark = True 173 | 174 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0) 175 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 176 | 177 | print("Starting training...") 178 | train(model, model_student, train_loader, eval_loader, args,qid2type) 179 | 180 | if __name__ 
== '__main__': 181 | main() 182 | 183 | 184 | 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /css/tools/compute_softscore_val.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import json 7 | import numpy as np 8 | import re 9 | import cPickle 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | from dataset import Dictionary 13 | import utils 14 | 15 | 16 | contractions = { 17 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 18 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 19 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 20 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 21 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 22 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 23 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 24 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 25 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 26 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 27 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 28 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 29 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 30 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 31 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 32 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 33 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 34 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 35 | "shouldn't've", "somebodyd": "somebody'd", "somebodyd've": 36 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 37 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 38 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 39 | "someonell": "someone'll", "someones": "someone's", "somethingd": 40 | "something'd", "somethingd've": "something'd've", "something'dve": 41 | "something'd've", "somethingll": "something'll", "thats": 42 | "that's", "thered": "there'd", "thered've": "there'd've", 43 | "there'dve": "there'd've", "therere": "there're", "theres": 44 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 45 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 46 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 47 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 48 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 49 | "what's", "whatve": "what've", "whens": "when's", "whered": 50 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 51 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 52 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 53 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 54 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 55 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 56 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 57 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 58 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 59 | "you'll", "youre": "you're", "youve": "you've" 60 | } 61 | 62 | manual_map = { 'none': '0', 63 | 'zero': '0', 64 | 
'one': '1', 65 | 'two': '2', 66 | 'three': '3', 67 | 'four': '4', 68 | 'five': '5', 69 | 'six': '6', 70 | 'seven': '7', 71 | 'eight': '8', 72 | 'nine': '9', 73 | 'ten': '10'} 74 | articles = ['a', 'an', 'the'] 75 | period_strip = re.compile(r"(?<!\d)(\.)(?!\d)") 76 | comma_strip = re.compile(r"(\d)(\,)(\d)") 77 | punct = [';', r"/", '[', ']', '"', '{', '}', 78 | '(', ')', '=', '+', '\\', '_', '-', 79 | '>', '<', '@', '`', ',', '?', '!'] 80 | 81 | 82 | def get_score(occurences): 83 | if occurences == 0: 84 | return 0 85 | elif occurences == 1: 86 | return 0.3 87 | elif occurences == 2: 88 | return 0.6 89 | elif occurences == 3: 90 | return 0.9 91 | else: 92 | return 1 93 | 94 | 95 | def process_punctuation(inText): 96 | outText = inText 97 | for p in punct: 98 | if (p + ' ' in inText or ' ' + p in inText) \ 99 | or (re.search(comma_strip, inText) != None): 100 | outText = outText.replace(p, '') 101 | else: 102 | outText = outText.replace(p, ' ') 103 | outText = period_strip.sub("", outText) 104 | return outText 105 | 106 | 107 | def process_digit_article(inText): 108 | outText = [] 109 | tempText = inText.lower().split() 110 | for word in tempText: 111 | word = manual_map.get(word, word) 112 | if word not in articles: 113 | outText.append(word) 114 | else: 115 | pass 116 | for wordId, word in enumerate(outText): 117 | if word in contractions: 118 | outText[wordId] = contractions[word] 119 | outText = ' '.join(outText) 120 | return outText 121 | 122 | 123 | def multiple_replace(text, wordDict): 124 | for key in wordDict: 125 | text = text.replace(key, wordDict[key]) 126 | return text 127 | 128 | 129 | def preprocess_answer(answer): 130 | answer = process_digit_article(process_punctuation(answer)) 131 | answer = answer.replace(',', '') 132 | return answer 133 | 134 | 135 | def filter_answers(answers_dset, min_occurence): 136 | """Keep only (preprocessed) answers that occur in at least min_occurence questions 137 | """ 138 | occurence = {} 139 | for ans_entry in answers_dset: 140 | gtruth = ans_entry['multiple_choice_answer'] 141 | gtruth = preprocess_answer(gtruth) 142 | if gtruth not in occurence: 143 | occurence[gtruth] = set() 144 | occurence[gtruth].add(ans_entry['question_id']) 145 | for answer in list(occurence.keys()): 146 | if len(occurence[answer]) < min_occurence: 147 | occurence.pop(answer) 148 | 149 | print('Num of answers that appear >= %d times: %d' % ( 150 | min_occurence, len(occurence))) 151 | return occurence 152 | 153 | 154 | 155 | 156 | 157 | 158 | def create_ans2label(occurence, name, cache_root): 159 | """Note that this will also create label2ans.pkl at the same time 160 | 161 | occurence: dict {answer -> whatever} 162 | name: prefix of the output file 163 | cache_root: str 164 | """ 165 | ans2label = {} 166 | label2ans = [] 167 | label = 0 168 | for answer in occurence: 169 | label2ans.append(answer) 170 | ans2label[answer] = label 171 | label += 1 172 | 173 | utils.create_dir(cache_root) 174 | 175 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 176 | cPickle.dump(ans2label, open(cache_file, 'wb')) 177 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 178 | cPickle.dump(label2ans, open(cache_file, 'wb')) 179 | return ans2label 180 | 181 | 182 | def compute_target(answers_dset, ans2label, name, cache_root): 183 | """Augment answers_dset with soft score as label 184 | 185 | ***answers_dset should be preprocessed*** 186 | 187 | Write result into a cache file 188 | """ 189 | target = [] 190 | for ans_entry in answers_dset: 191 | answers = ans_entry['answers']
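        # Each annotation carries the 10 human answers for one question. The loop
        # below counts how often each distinct answer string occurs, and get_score()
        # converts those occurrence counts into the usual VQA soft scores
        # (1 match -> 0.3, 2 -> 0.6, 3 -> 0.9, 4+ -> 1.0), which become the
        # per-answer training targets ('labels'/'scores') cached to disk.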
192 | answer_count = {} 193 | for answer in answers: 194 | answer_ = answer['answer'] 195 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 196 | 197 | labels = [] 198 | scores = [] 199 | for answer in answer_count: 200 | if answer not in ans2label: 201 | continue 202 | labels.append(ans2label[answer]) 203 | score = get_score(answer_count[answer]) 204 | scores.append(score) 205 | 206 | label_counts = {} 207 | for k, v in answer_count.items(): 208 | if k in ans2label: 209 | label_counts[ans2label[k]] = v 210 | 211 | target.append({ 212 | 'question_id': ans_entry['question_id'], 213 | 'question_type': ans_entry['question_type'], 214 | 'image_id': ans_entry['image_id'], 215 | 'label_counts': label_counts, 216 | 'labels': labels, 217 | 'scores': scores 218 | }) 219 | 220 | print(cache_root) 221 | utils.create_dir(cache_root) 222 | cache_file = os.path.join(cache_root, name+'_target.pkl') 223 | print(cache_file) 224 | with open(cache_file, 'wb') as f: 225 | cPickle.dump(target, f) 226 | return target 227 | 228 | 229 | 230 | def get_answer(qid, answers): 231 | for ans in answers: 232 | if ans['question_id'] == qid: 233 | return ans 234 | 235 | 236 | def get_question(qid, questions): 237 | for question in questions: 238 | if question['question_id'] == qid: 239 | return question 240 | 241 | def load_cp(): 242 | train_answer_file = "data/vqacp2val/vqacp_v2_train_annotations.json" 243 | with open(train_answer_file) as f: 244 | train_answers = json.load(f) # ['annotations'] 245 | 246 | val_answer_file = "data/vqacp2val/vqacp_v2_test_annotations.json" 247 | with open(val_answer_file) as f: 248 | val_answers = json.load(f) # ['annotations'] 249 | 250 | occurence = filter_answers(train_answers, 9) 251 | ans2label = create_ans2label(occurence, 'trainval', "data/cpval-cache") 252 | compute_target(train_answers, ans2label, 'train', "data/cpval-cache") 253 | compute_target(val_answers, ans2label, 'val', "data/cpval-cache") 254 | 255 | def main(): 256 | # parser = argparse.ArgumentParser("Dataset preprocessing") 257 | # parser.add_argument("dataset", choices=["cp_v2", "v2",'cp_v1']) 258 | # args = parser.parse_args() 259 | # if args.dataset == "v2": 260 | # load_v2() 261 | # elif args.dataset == "cp_v1": 262 | # load_cp_v1() 263 | # elif args.dataset=='cp_v2': 264 | # load_cp() 265 | load_cp() 266 | 267 | 268 | 269 | if __name__ == '__main__': 270 | main() 271 | -------------------------------------------------------------------------------- /css/tools/create_dictionary.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | 11 | 12 | def create_dictionary(dataroot): 13 | dictionary = Dictionary() 14 | questions = [] 15 | files = [ 16 | 'v2_OpenEnded_mscoco_train2014_questions.json', 17 | 'v2_OpenEnded_mscoco_val2014_questions.json', 18 | 'v2_OpenEnded_mscoco_test2015_questions.json', 19 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json' 20 | ] 21 | for path in files: 22 | question_path = os.path.join(dataroot, path) 23 | qs = json.load(open(question_path))['questions'] 24 | for q in qs: 25 | dictionary.tokenize(q['question'], True) 26 | dictionary.tokenize('wordmask',True) 27 | return dictionary 28 | 29 | 30 | def create_glove_embedding_init(idx2word, glove_file): 31 | word2emb = {} 32 | with open(glove_file, 'r') as f: 33 | entries = 
f.readlines() 34 | emb_dim = len(entries[0].split(' ')) - 1 35 | print('embedding dim is %d' % emb_dim) 36 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 37 | 38 | for entry in entries: 39 | vals = entry.split(' ') 40 | word = vals[0] 41 | vals = map(float, vals[1:]) 42 | word2emb[word] = np.array(vals) 43 | for idx, word in enumerate(idx2word): 44 | if word not in word2emb: 45 | continue 46 | weights[idx] = word2emb[word] 47 | return weights, word2emb 48 | 49 | 50 | if __name__ == '__main__': 51 | d = create_dictionary('data') 52 | d.dump_to_file('data/dictionary.pkl') 53 | 54 | d = Dictionary.load_from_file('data/dictionary.pkl') 55 | emb_dim = 300 56 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 57 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 58 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights) 59 | -------------------------------------------------------------------------------- /css/tools/create_dictionary_v1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | def create_dictionary(dataroot): 11 | dictionary = Dictionary() 12 | questions = [] 13 | files = [ 14 | 'OpenEnded_mscoco_train2014_questions.json', 15 | 'OpenEnded_mscoco_val2014_questions.json', 16 | 'OpenEnded_mscoco_test2015_questions.json', 17 | 'OpenEnded_mscoco_test-dev2015_questions.json' 18 | ] 19 | for path in files: 20 | question_path = os.path.join(dataroot, path) 21 | qs = json.load(open(question_path))['questions'] 22 | for q in qs: 23 | dictionary.tokenize(q['question'], True) 24 | return dictionary 25 | 26 | 27 | def create_glove_embedding_init(idx2word, glove_file): 28 | word2emb = {} 29 | with open(glove_file, 'r') as f: 30 | entries = f.readlines() 31 | emb_dim = len(entries[0].split(' ')) - 1 32 | print('embedding dim is %d' % emb_dim) 33 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 34 | 35 | for entry in entries: 36 | vals = entry.split(' ') 37 | word = vals[0] 38 | vals = map(float, vals[1:]) 39 | word2emb[word] = np.array(vals) 40 | for idx, word in enumerate(idx2word): 41 | if word not in word2emb: 42 | continue 43 | weights[idx] = word2emb[word] 44 | return weights, word2emb 45 | 46 | 47 | if __name__ == '__main__': 48 | d = create_dictionary('data') 49 | d.dump_to_file('data/dictionary_v1.pkl') 50 | 51 | d = Dictionary.load_from_file('data/dictionary_v1.pkl') 52 | emb_dim = 300 53 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 54 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 55 | np.save('data/glove6b_init_%dd_v1.npy' % emb_dim, weights) 56 | -------------------------------------------------------------------------------- /css/tools/download.sh: -------------------------------------------------------------------------------- 1 | ## Script for downloading data 2 | 3 | # GloVe Vectors 4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip 5 | unzip data/glove.6B.zip -d data/glove 6 | rm data/glove.6B.zip 7 | 8 | # VQA-CP2 9 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json 10 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json 11 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json 12 | wget -P data 
https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json 13 | 14 | # VQA-CP1 15 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json 16 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json 17 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json 18 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json 19 | 20 | # VQA-V2 21 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip 22 | unzip data/v2_Questions_Train_mscoco.zip -d data 23 | rm data/v2_Questions_Train_mscoco.zip 24 | 25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip 26 | unzip data/v2_Questions_Val_mscoco.zip -d data 27 | rm data/v2_Questions_Val_mscoco.zip 28 | 29 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip 30 | unzip data/v2_Questions_Test_mscoco.zip -d data 31 | rm data/v2_Questions_Test_mscoco.zip 32 | 33 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip 34 | unzip data/v2_Annotations_Train_mscoco.zip -d data 35 | rm data/v2_Annotations_Train_mscoco.zip 36 | 37 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 38 | unzip data/v2_Annotations_Val_mscoco.zip -d data 39 | rm data/v2_Annotations_Val_mscoco.zip 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /css/tools/process.sh: -------------------------------------------------------------------------------- 1 | # Process data 2 | 3 | python tools/create_dictionary.py 4 | python tools/create_dictionary_v1.py 5 | python tools/compute_softscore.py v2 6 | python tools/compute_softscore.py cp_v1 7 | python tools/compute_softscore.py cp_v2 8 | 9 | -------------------------------------------------------------------------------- /css/train_introd.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import time 5 | from os.path import join 6 | 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from tqdm import tqdm 13 | import random 14 | import copy 15 | from torch.nn import functional as F 16 | 17 | 18 | def compute_score_with_logits(logits, labels): 19 | logits = torch.argmax(logits, 1) 20 | one_hots = torch.zeros(*labels.size()).cuda() 21 | one_hots.scatter_(1, logits.view(-1, 1), 1) 22 | scores = (one_hots * labels) 23 | return scores 24 | 25 | def train(model_teacher, model, train_loader, eval_loader,args,qid2type): 26 | dataset=args.dataset 27 | num_epochs=args.epochs 28 | mode=args.mode 29 | run_eval=args.eval_each_epoch 30 | output=args.output 31 | optim = torch.optim.Adamax(model.parameters()) 32 | logger = utils.Logger(os.path.join(output, 'log.txt')) 33 | total_step = 0 34 | best_eval_score = 0 35 | 36 | logsigmoid = torch.nn.LogSigmoid() 37 | 38 | KLDivLoss = torch.nn.KLDivLoss(reduction='none') 39 | 40 | for epoch in range(num_epochs): 41 | total_loss = 0 42 | train_score = 0 43 | 44 | t = time.time() 45 | for i, (v, q, a, b, _, _, _, _) in tqdm(enumerate(train_loader), ncols=100, 46 | desc="Epoch %d" % (epoch + 1), total=len(train_loader)): 47 | 48 | total_step += 1 49 | 50 | 51 | ######################################### 52 | v = Variable(v).cuda().requires_grad_() 53 | q = Variable(q).cuda() 
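        # Introspective distillation step (lines 62-79 below): the frozen teacher
        # returns a debiased prediction (pred_nie) and a factual, bias-aware one
        # (pred_te). Their per-sample binary cross-entropies are compared to form
        # w = loss_nie / (loss_te + loss_nie): w -> 1 when the factual branch fits
        # the sample better (an ID-looking sample), so the student is taught mainly
        # by KL-distilling the debiased prediction; w -> 0 for OOD-looking samples,
        # which fall back to the plain ground-truth cross-entropy loss_ce.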
54 | # q_mask=Variable(q_mask).cuda() 55 | a = Variable(a).cuda() 56 | b = Variable(b).cuda() 57 | # hintscore = Variable(hintscore).cuda() 58 | # type_mask=Variable(type_mask).float().cuda() 59 | # notype_mask=Variable(notype_mask).float().cuda() 60 | ######################################### 61 | 62 | pred_nie, pred_te, _, _ = model_teacher(v, q, a, b, None) 63 | pred, _, loss_ce, _ = model(v, q, a, b, None) 64 | 65 | aa = a/torch.clamp(a.sum(1, keepdim=True), min=1e-24) 66 | loss_te = -(aa*logsigmoid(pred_te) + (1-aa)*logsigmoid(-pred_te)).sum(1) 67 | loss_nie = -(aa*logsigmoid(pred_nie) + (1-aa)*logsigmoid(-pred_nie)).sum(1) 68 | 69 | loss_te = torch.clamp(loss_te, min=1e-12) 70 | loss_nie = torch.clamp(loss_nie, min=1e-12) 71 | w = loss_nie/(loss_te+loss_nie) 72 | w = w.clone().detach() 73 | 74 | # KL 75 | prob_nie = F.softmax(pred_nie, -1).clone().detach() 76 | loss_kl = - prob_nie*F.log_softmax(pred, -1) 77 | loss_kl = loss_kl.sum(1) 78 | 79 | loss = (w*loss_kl + (1-w)*loss_ce).mean() 80 | 81 | if (loss != loss).any(): 82 | raise ValueError("NaN loss") 83 | loss.backward() 84 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 85 | optim.step() 86 | optim.zero_grad() 87 | 88 | total_loss += loss.item() * q.size(0) 89 | batch_score = compute_score_with_logits(pred, a.data).sum() 90 | train_score += batch_score 91 | 92 | total_loss /= len(train_loader.dataset) 93 | train_score = 100 * train_score / len(train_loader.dataset) 94 | 95 | if run_eval: 96 | model.train(False) 97 | results = evaluate(model, eval_loader, qid2type) 98 | results["epoch"] = epoch + 1 99 | results["step"] = total_step 100 | results["train_loss"] = total_loss 101 | results["train_score"] = train_score 102 | 103 | model.train(True) 104 | 105 | eval_score = results["score"] 106 | bound = results["upper_bound"] 107 | yn = results['score_yesno'] 108 | other = results['score_other'] 109 | num = results['score_number'] 110 | 111 | logger.write('epoch %d, time: %.2f' % (epoch + 1, time.time() - t)) 112 | logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score)) 113 | 114 | if run_eval: 115 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound)) 116 | logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num)) 117 | 118 | if eval_score > best_eval_score: 119 | model_path = os.path.join(output, 'model.pth') 120 | torch.save(model.state_dict(), model_path) 121 | best_eval_score = eval_score 122 | 123 | 124 | def evaluate(model, dataloader, qid2type): 125 | score = 0 126 | upper_bound = 0 127 | score_yesno = 0 128 | score_number = 0 129 | score_other = 0 130 | total_yesno = 0 131 | total_number = 0 132 | total_other = 0 133 | 134 | for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 135 | v = Variable(v, requires_grad=False).cuda() 136 | q = Variable(q, requires_grad=False).cuda() 137 | pred, _, _, _ = model(v, q, None, None, None) 138 | batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1) 139 | score += batch_score.sum() 140 | upper_bound += (a.max(1)[0]).sum() 141 | qids = qids.detach().cpu().int().numpy() 142 | for j in range(len(qids)): 143 | qid = qids[j] 144 | typ = qid2type[str(qid)] 145 | if typ == 'yes/no': 146 | score_yesno += batch_score[j] 147 | total_yesno += 1 148 | elif typ == 'other': 149 | score_other += batch_score[j] 150 | total_other += 1 151 | elif typ == 'number': 152 | score_number += batch_score[j] 153 | total_number += 1 154 | else: 155 | 
print('Unknown question type: %s' % typ) 156 | 157 | 158 | score = score / len(dataloader.dataset) 159 | upper_bound = upper_bound / len(dataloader.dataset) 160 | score_yesno /= total_yesno 161 | score_other /= total_other 162 | score_number /= total_number 163 | 164 | results = dict( 165 | score=score, 166 | upper_bound=upper_bound, 167 | score_yesno=score_yesno, 168 | score_other=score_other, 169 | score_number=score_number, 170 | ) 171 | return results 172 | -------------------------------------------------------------------------------- /css/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import errno 4 | import os 5 | import numpy as np 6 | # from PIL import Image 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | EPS = 1e-7 12 | 13 | 14 | def assert_eq(real, expected): 15 | # assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 16 | assert real == real, '%s (true) vs %s (expected)' % (real, expected) 17 | 18 | 19 | def assert_array_eq(real, expected): 20 | assert (np.abs(real-expected) < EPS).all(), \ 21 | '%s (true) vs %s (expected)' % (real, expected) 22 | 23 | 24 | def load_folder(folder, suffix): 25 | imgs = [] 26 | for f in sorted(os.listdir(folder)): 27 | if f.endswith(suffix): 28 | imgs.append(os.path.join(folder, f)) 29 | return imgs 30 | 31 | 32 | # def load_imageid(folder): 33 | # images = load_folder(folder, 'jpg') 34 | # img_ids = set() 35 | # for img in images: 36 | # img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1]) 37 | # img_ids.add(img_id) 38 | # return img_ids 39 | 40 | 41 | # def pil_loader(path): 42 | # with open(path, 'rb') as f: 43 | # with Image.open(f) as img: 44 | # return img.convert('RGB') 45 | 46 | 47 | def weights_init(m): 48 | """custom weights initialization.""" 49 | cname = m.__class__ 50 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 51 | m.weight.data.normal_(0.0, 0.02) 52 | elif cname == nn.BatchNorm2d: 53 | m.weight.data.normal_(1.0, 0.02) 54 | m.bias.data.fill_(0) 55 | else: 56 | print('%s is not initialized.' 
% cname) 57 | 58 | 59 | def init_net(net, net_file): 60 | if net_file: 61 | net.load_state_dict(torch.load(net_file)) 62 | else: 63 | net.apply(weights_init) 64 | 65 | 66 | def create_dir(path): 67 | if not os.path.exists(path): 68 | try: 69 | os.makedirs(path) 70 | except OSError as exc: 71 | if exc.errno != errno.EEXIST: 72 | raise 73 | 74 | 75 | class Logger(object): 76 | def __init__(self, output_name): 77 | dirname = os.path.dirname(output_name) 78 | if not os.path.exists(dirname): 79 | os.makedirs(dirname) 80 | 81 | self.log_file = open(output_name, 'w') 82 | self.infos = {} 83 | 84 | def append(self, key, val): 85 | vals = self.infos.setdefault(key, []) 86 | vals.append(val) 87 | 88 | def log(self, extra_msg=''): 89 | msgs = [extra_msg] 90 | for key, vals in self.infos.iteritems(): 91 | msgs.append('%s %.6f' % (key, np.mean(vals))) 92 | msg = '\n'.join(msgs) 93 | self.log_file.write(msg + '\n') 94 | self.log_file.flush() 95 | self.infos = {} 96 | return msg 97 | 98 | def write(self, msg): 99 | self.log_file.write(msg + '\n') 100 | self.log_file.flush() 101 | print(msg) 102 | -------------------------------------------------------------------------------- /css/vqa_debias_loss_functions.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict, Counter 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import torch 7 | import inspect 8 | 9 | 10 | def convert_sigmoid_logits_to_binary_logprobs(logits): 11 | """computes log(sigmoid(logits)), log(1-sigmoid(logits))""" 12 | log_prob = -F.softplus(-logits) 13 | log_one_minus_prob = -logits + log_prob 14 | return log_prob, log_one_minus_prob 15 | 16 | 17 | def elementwise_logsumexp(a, b): 18 | """computes log(exp(a) + exp(b))""" 19 | return torch.max(a, b) + torch.log1p(torch.exp(-torch.abs(a - b))) 20 | 21 | 22 | def renormalize_binary_logits(a, b): 23 | """Normalize so exp(a) + exp(b) == 1""" 24 | norm = elementwise_logsumexp(a, b) 25 | return a - norm, b - norm 26 | 27 | 28 | class DebiasLossFn(nn.Module): 29 | """General API for our loss functions""" 30 | 31 | def forward(self, hidden, logits, bias, labels): 32 | """ 33 | :param hidden: [batch, n_hidden] hidden features from the last layer in the model 34 | :param logits: [batch, n_answers_options] sigmoid logits for each answer option 35 | :param bias: [batch, n_answers_options] 36 | bias probabilities for each answer option between 0 and 1 37 | :param labels: [batch, n_answers_options] 38 | scores for each answer option, between 0 and 1 39 | :return: Scalar loss 40 | """ 41 | raise NotImplementedError() 42 | 43 | def to_json(self): 44 | """Get a json representation of this loss function. 
45 | 46 | We construct this by looking up the __init__ args 47 | """ 48 | cls = self.__class__ 49 | init = cls.__init__ 50 | if init is object.__init__: 51 | return [] # No init args 52 | 53 | init_signature = inspect.getargspec(init) 54 | if init_signature.varargs is not None: 55 | raise NotImplementedError("varargs not supported") 56 | if init_signature.keywords is not None: 57 | raise NotImplementedError("keywords not supported") 58 | args = [x for x in init_signature.args if x != "self"] 59 | out = OrderedDict() 60 | out["name"] = cls.__name__ 61 | for key in args: 62 | out[key] = getattr(self, key) 63 | return out 64 | 65 | 66 | class Plain(DebiasLossFn): 67 | def forward(self, hidden, logits, bias, labels): 68 | loss = F.binary_cross_entropy_with_logits(logits, labels) 69 | 70 | loss *= labels.size(1) 71 | 72 | return loss 73 | 74 | class PlainKD(DebiasLossFn): 75 | def forward(self, hidden, logits, bias, labels): 76 | loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none') 77 | 78 | # loss *= labels.size(1) 79 | loss = loss.sum(1) 80 | 81 | # return loss 82 | return None, loss 83 | 84 | 85 | class Focal(DebiasLossFn): 86 | def forward(self, hidden, logits, bias, labels): 87 | # import pdb;pdb.set_trace() 88 | focal_logits=torch.log(F.softmax(logits,dim=1)+1e-5) * ((1-F.softmax(bias,dim=1))*(1-F.softmax(bias,dim=1))) 89 | loss=F.binary_cross_entropy_with_logits(focal_logits,labels) 90 | loss*=labels.size(1) 91 | return loss 92 | 93 | class ReweightByInvBias(DebiasLossFn): 94 | def forward(self, hidden, logits, bias, labels): 95 | # Manually compute the binary cross entropy since the old version of torch always aggregates 96 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 97 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob) 98 | weights = (1 - bias) 99 | loss *= weights # Apply the weights 100 | return loss.sum() / weights.sum() 101 | 102 | 103 | class BiasProduct(DebiasLossFn): 104 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0): 105 | """ 106 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 107 | :param smooth_init: How to initialize `a` 108 | :param constant_smooth: Constant to add to the bias to smooth it 109 | """ 110 | super(BiasProduct, self).__init__() 111 | self.constant_smooth = constant_smooth 112 | self.smooth_init = smooth_init 113 | self.smooth = smooth 114 | if smooth: 115 | self.smooth_param = torch.nn.Parameter( 116 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 117 | else: 118 | self.smooth_param = None 119 | 120 | def forward(self, hidden, logits, bias, labels): 121 | smooth = self.constant_smooth 122 | if self.smooth: 123 | smooth += F.sigmoid(self.smooth_param) 124 | 125 | # Convert the bias into log-space, with a factor for both the 126 | # binary outputs for each answer option 127 | bias_lp = torch.log(bias + smooth) 128 | bias_l_inv = torch.log1p(-bias + smooth) 129 | 130 | # Convert the logits into log-space with the same format 131 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 132 | # import pdb;pdb.set_trace() 133 | 134 | # Add the bias 135 | log_prob += bias_lp 136 | log_one_minus_prob += bias_l_inv 137 | 138 | # Re-normalize the factors in logspace 139 | log_prob, log_one_minus_prob = renormalize_binary_logits(log_prob, log_one_minus_prob) 140 | 141 | # Compute the binary cross entropy 142 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 143 | return 
loss 144 | 145 | 146 | class LearnedMixin(DebiasLossFn): 147 | def __init__(self, w, smooth=True, smooth_init=-1, constant_smooth=0.0): 148 | """ 149 | :param w: Weight of the entropy penalty 150 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 151 | :param smooth_init: How to initialize `a` 152 | :param constant_smooth: Constant to add to the bias to smooth it 153 | """ 154 | super(LearnedMixin, self).__init__() 155 | self.w = w 156 | # self.w=0 157 | self.smooth_init = smooth_init 158 | self.constant_smooth = constant_smooth 159 | self.bias_lin = torch.nn.Linear(1024, 1) 160 | self.smooth = smooth 161 | if self.smooth: 162 | self.smooth_param = torch.nn.Parameter( 163 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 164 | else: 165 | self.smooth_param = None 166 | 167 | def forward(self, hidden, logits, bias, labels): 168 | factor = self.bias_lin.forward(hidden) # [batch, 1] 169 | factor = F.softplus(factor) 170 | 171 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2] 172 | 173 | # Smooth 174 | bias += self.constant_smooth 175 | if self.smooth: 176 | soften_factor = F.sigmoid(self.smooth_param) 177 | bias = bias + soften_factor.unsqueeze(1) 178 | 179 | bias = torch.log(bias) # Convert to logspace 180 | 181 | # Scale by the factor 182 | # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2] 183 | bias = bias * factor.unsqueeze(1) 184 | 185 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 186 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2) 187 | 188 | # Add the bias in 189 | logits = bias + log_probs 190 | 191 | # Renormalize to get log probabilities 192 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1]) 193 | 194 | # Compute loss 195 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 196 | 197 | # Re-normalized version of the bias 198 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1]) 199 | bias_logprob = bias - bias_norm.unsqueeze(2) 200 | 201 | # Compute and add the entropy penalty 202 | entropy = -(torch.exp(bias_logprob) * bias_logprob).sum(2).mean() 203 | return loss + self.w * entropy 204 | 205 | 206 | class LearnedMixinKD(DebiasLossFn): 207 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0): 208 | """ 209 | KD variant of LearnedMixin: no entropy penalty, and forward() returns (ensemble logits, loss) 210 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 211 | :param smooth_init: How to initialize `a` 212 | :param constant_smooth: Constant to add to the bias to smooth it 213 | """ 214 | super(LearnedMixinKD, self).__init__() 215 | self.smooth_init = smooth_init 216 | self.constant_smooth = constant_smooth 217 | self.bias_lin = torch.nn.Linear(1024, 1) 218 | self.smooth = smooth 219 | if self.smooth: 220 | self.smooth_param = torch.nn.Parameter( 221 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 222 | else: 223 | self.smooth_param = None 224 | 225 | def forward(self, hidden, logits, bias, labels): 226 | factor = self.bias_lin.forward(hidden) # [batch, 1] 227 | factor = F.softplus(factor) 228 | 229 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2] 230 | 231 | # Smooth 232 | bias += self.constant_smooth 233 | if self.smooth: 234 | soften_factor = F.sigmoid(self.smooth_param) 235 | bias = bias + soften_factor.unsqueeze(1) 236 | 237 | bias = torch.log(bias) # Convert to logspace 238 | 239 | # Scale by the factor 240 | # [batch, n_answers, 2] * [batch, 1, 1] -> 
[batch, n_answers, 2] 241 | bias = bias * factor.unsqueeze(1) 242 | 243 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 244 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2) 245 | 246 | # Add the bias in 247 | logits = bias + log_probs 248 | 249 | # Renormalize to get log probabilities 250 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1]) 251 | 252 | # Compute loss 253 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 254 | 255 | # Re-normalized version of the bias 256 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1]) 257 | bias_logprob = bias - bias_norm.unsqueeze(2) 258 | 259 | prob_all = torch.exp(log_prob) 260 | 261 | p = torch.clamp(1-prob_all, min=1e-12) 262 | p = torch.clamp(prob_all/p, min=1e-12) 263 | 264 | logits_all = torch.log(p) 265 | 266 | return logits_all, loss -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/images/architecture.png -------------------------------------------------------------------------------- /images/introd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/images/introd.png --------------------------------------------------------------------------------
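
Editor's addendum (not part of the repository): a small numeric sketch of the introspection weight used in css/train_introd.py. Only the weighting formula is taken from the code; the loss values below are invented for illustration.

import torch

# Per-sample losses of the teacher's factual (TE) and debiased (NIE) branches,
# as computed in train_introd.py; the numbers here are made up.
loss_te = torch.tensor([0.2, 2.0])    # factual branch fits sample 0 well
loss_nie = torch.tensor([1.8, 0.4])   # debiased branch fits sample 1 well

w = loss_nie / (loss_te + loss_nie)   # -> tensor([0.9000, 0.1667])

# Placeholder per-sample distillation (KL) and ground-truth (CE) losses.
loss_kl = torch.tensor([1.2, 0.8])
loss_ce = torch.tensor([0.5, 1.0])

loss = (w * loss_kl + (1 - w) * loss_ce).mean()
# Sample 0 looks in-distribution (low TE loss), so it is taught mostly through
# the distillation term; sample 1 looks out-of-distribution and leans on the
# plain cross-entropy term instead.
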