├── .gitignore
├── LICENSE
├── README.md
├── cfvqa
│   ├── .gitignore
│   ├── README.md
│   ├── cfvqa
│   │   ├── __init__.py
│   │   ├── __version__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── factory.py
│   │   │   ├── scripts
│   │   │   │   ├── download_vqa2.sh
│   │   │   │   └── download_vqacp2.sh
│   │   │   ├── vqa2.py
│   │   │   ├── vqacp.py
│   │   │   └── vqacp2.py
│   │   ├── engines
│   │   │   ├── __init__.py
│   │   │   ├── engine.py
│   │   │   ├── factory.py
│   │   │   └── logger.py
│   │   ├── models
│   │   │   ├── criterions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cfvqa_criterion.py
│   │   │   │   ├── cfvqaintrod_criterion.py
│   │   │   │   ├── factory.py
│   │   │   │   ├── rubi_criterion.py
│   │   │   │   └── rubiintrod_criterion.py
│   │   │   ├── metrics
│   │   │   │   ├── __init__.py
│   │   │   │   ├── factory.py
│   │   │   │   ├── vqa_cfvqa_metrics.py
│   │   │   │   ├── vqa_cfvqaintrod_metrics.py
│   │   │   │   ├── vqa_cfvqasimple_metrics.py
│   │   │   │   ├── vqa_rubi_metrics.py
│   │   │   │   └── vqa_rubiintrod_metrics.py
│   │   │   └── networks
│   │   │       ├── __init__.py
│   │   │       ├── cfvqa.py
│   │   │       ├── cfvqaintrod.py
│   │   │       ├── factory.py
│   │   │       ├── rubi.py
│   │   │       ├── rubiintrod.py
│   │   │       ├── san_net.py
│   │   │       ├── smrl_net.py
│   │   │       ├── updn_net.py
│   │   │       └── utils.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   └── factory.py
│   │   ├── options
│   │   │   ├── vqa2
│   │   │   │   ├── smrl_baseline.yaml
│   │   │   │   ├── smrl_cfvqa_sum.yaml
│   │   │   │   ├── smrl_cfvqaintrod_sum.yaml
│   │   │   │   ├── smrl_cfvqasimple_rubi.yaml
│   │   │   │   ├── smrl_cfvqasimpleintrod_rubi.yaml
│   │   │   │   ├── smrl_rubi.yaml
│   │   │   │   └── smrl_rubiintrod.yaml
│   │   │   └── vqacp2
│   │   │       ├── smrl_baseline.yaml
│   │   │       ├── smrl_cfvqa_sum.yaml
│   │   │       ├── smrl_cfvqaintrod_sum.yaml
│   │   │       ├── smrl_cfvqasimple_rubi.yaml
│   │   │       ├── smrl_cfvqasimpleintrod_rubi.yaml
│   │   │       ├── smrl_rubi.yaml
│   │   │       └── smrl_rubiintrod.yaml
│   │   └── run.py
│   ├── engine.py
│   ├── requirements.txt
│   ├── run.py
│   ├── run_vqa2_cfvqa_introd.sh
│   └── scripts
│       ├── run_vqa2_cfvqa_introd.sh
│       ├── run_vqa2_rubi_introd.sh
│       ├── run_vqa2_rubicf_introd.sh
│       ├── run_vqacp2_rubi_introd.sh
│       └── run_vqacp2_rubicf_introd.sh
├── css
│   ├── .gitignore
│   ├── README.md
│   ├── attention.py
│   ├── base_model.py
│   ├── base_model_introd.py
│   ├── classifier.py
│   ├── dataset.py
│   ├── eval.py
│   ├── fc.py
│   ├── language_model.py
│   ├── main.py
│   ├── main_introd.py
│   ├── tools
│   │   ├── compute_softscore.py
│   │   ├── compute_softscore_val.py
│   │   ├── create_dictionary.py
│   │   ├── create_dictionary_v1.py
│   │   ├── download.sh
│   │   └── process.sh
│   ├── train.py
│   ├── train_introd.py
│   ├── util
│   │   ├── cpv2_notype_mask.json
│   │   ├── cpv2_type_mask.json
│   │   ├── qid2type_cpv1.json
│   │   ├── qid2type_cpv2.json
│   │   ├── qid2type_v2.json
│   │   ├── v2_notype_mask.json
│   │   └── v2_type_mask.json
│   ├── utils.py
│   └── vqa_debias_loss_functions.py
└── images
    ├── architecture.png
    └── introd.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introspective Distillation for Robust Question Answering
2 |
3 | This repository is the PyTorch implementation of our NeurIPS 2021 paper ["Introspective Distillation for Robust Question Answering"](https://arxiv.org/abs/2111.01026).
4 |
5 | IntroD is proposed to achieve both high in-distribution (ID) and out-of-distribution (OOD) performance on question answering tasks like VQA and extractive QA. The key technical contribution is to blend the inductive biases of the OOD and ID worlds by introspecting whether a training sample fits the factual ID world or the counterfactual OOD one.
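As a minimal sketch of this introspection (names are illustrative; the actual implementations are `cfvqa/cfvqa/models/criterions/cfvqaintrod_criterion.py` and `rubiintrod_criterion.py` below), the per-sample losses of the causal OOD teacher and the factual ID teacher are compared to decide how much each sample learns from distillation versus from the ground truth:

```python
import torch.nn.functional as F

def introd_loss(logits_stu, logits_ood, logits_id, class_id):
    """Sketch: blend OOD-teacher distillation with ID cross-entropy."""
    loss_ood = F.cross_entropy(logits_ood, class_id, reduction='none')
    loss_id = F.cross_entropy(logits_id, class_id, reduction='none')
    # w -> 1 when the ID teacher fits a sample much better than the OOD
    # teacher, i.e. the sample is likely dominated by dataset bias
    w = (loss_ood / (loss_ood + loss_id)).detach()
    p_t = F.softmax(logits_ood, dim=-1).detach()                 # teacher distribution
    kd = -(p_t * F.log_softmax(logits_stu, dim=-1)).sum(dim=1)   # distillation term
    ce = F.cross_entropy(logits_stu, class_id, reduction='none')
    return (w * kd).mean() + ((1 - w) * ce).mean()
```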
6 |
7 |
8 | ![IntroD](images/introd.png)
9 |
10 |
11 |
12 | ![Architecture](images/architecture.png)
13 |
14 |
15 |
16 | If you find this paper and code helpful for your research, please consider citing our papers in your publications.
17 | ```
18 | @inproceedings{niu2021introspective,
19 | title={Introspective Distillation for Robust Question Answering},
20 | author={Niu, Yulei and Zhang, Hanwang},
21 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
22 | year={2021}
23 | }
24 | ```
25 | ```
26 | @inproceedings{niu2020counterfactual,
27 | title={Counterfactual VQA: A Cause-Effect Look at Language Bias},
28 | author={Niu, Yulei and Tang, Kaihua and Zhang, Hanwang and Lu, Zhiwu and Hua, Xian-Sheng and Wen, Ji-Rong},
29 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
30 | year={2021}
31 | }
32 | ```
--------------------------------------------------------------------------------
/cfvqa/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | data/
132 | data
133 | logs/
134 | logs
135 |
--------------------------------------------------------------------------------
/cfvqa/README.md:
--------------------------------------------------------------------------------
1 | # CFVQA+IntroD
2 |
3 | This code is implemented as a fork of [CF-VQA][1] and [RUBi][2].
4 |
5 | CF-VQA is proposed to capture and mitigate language bias in VQA from the view of causality. CF-VQA (1) captures the language bias as the direct causal effect of questions on answers, and (2) reduces the language bias by subtracting the direct language effect from the total causal effect.
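In rough pseudocode, a sketch of the 'sum' fusion variant (the learned constant plays the role of `self.constant` in `cfvqa/cfvqa/models/networks/cfvqa.py`; shapes and names are illustrative):

```python
import torch

eps = 1e-12  # same epsilon as in cfvqa.py

def fusion_sum(z_q, z_v, z_vq):
    # SUM fusion: log-sigmoid of the summed branch logits
    return torch.log(torch.sigmoid(z_q + z_v + z_vq) + eps)

def cfvqa_predict(z_q, z_v, z_vq, c):
    te = fusion_sum(z_q, z_v, z_vq)  # total effect (factual world)
    # counterfactual world: block the vision-dependent branches with c
    nde = fusion_sum(z_q, c.expand_as(z_v), c.expand_as(z_vq))
    return te - nde  # total indirect effect = debiased prediction

# illustrative shapes: a batch of 4 questions, 10 candidate answers
z_q, z_v, z_vq = (torch.randn(4, 10) for _ in range(3))
scores = cfvqa_predict(z_q, z_v, z_vq, torch.tensor(0.0))
```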
6 |
7 | ## Summary
8 |
9 | * [Installation](#installation)
10 | * [Setup and dependencies](#1-setup-and-dependencies)
11 | * [Download datasets](#2-download-datasets)
12 | * [Overview of commands and result files](#overview-of-commands-and-result-files)
13 | * [Train a model](#train-a-model)
14 | * [Evaluate a model](#evaluate-a-model)
15 | * [Useful commands](#useful-commands)
16 | * [Acknowledgment](#acknowledgment)
17 |
18 | ## Installation
19 |
20 |
21 | ### 1. Setup and dependencies
22 |
23 | Install the Anaconda or Miniconda distribution (Python 3+) from the official download site.
24 |
25 | ```bash
26 | conda create --name cfvqa python=3.7
27 | source activate cfvqa
28 | pip install -r requirements.txt
29 | ```
30 |
31 | If you hit the error `ModuleNotFoundError: No module named 'block.external'`, please follow [this issue](https://github.com/yuleiniu/cfvqa/issues/7) for the solution.
32 |
33 |
34 | ### 2. Download datasets
35 |
36 | Download annotations, images and features for VQA experiments:
37 | ```bash
38 | bash cfvqa/datasets/scripts/download_vqa2.sh
39 | bash cfvqa/datasets/scripts/download_vqacp2.sh
40 | ```
41 |
42 |
43 | ## Overview of commands and result files
44 |
45 | ### Train a model
46 |
47 | The [bootstrap/run.py](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/run.py) file loads the options contained in a YAML file, creates the corresponding experiment directory and starts the training procedure. For instance, you can train our best model on VQA-CP v2 (CFVQA+SUM+SMRL) by running:
48 | ```bash
49 | python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_cfvqa_sum.yaml
50 | ```
51 | Then, several files are going to be created in `logs/vqacp2/smrl_cfvqa_sum/`:
52 | - [options.yaml] (copy of options)
53 | - [logs.txt] (history of printed logs)
54 | - [logs.json] (batch and epoch statistics)
55 | - **[\_vq\_val\_oe.json] (statistics for the language-prior based strategy, e.g., RUBi)**
56 | - **[\_cfvqa\_val\_oe.json] (statistics for CF-VQA)**
57 | - [\_q\_val\_oe.json] (statistics for language-only branch)
58 | - [\_v\_val\_oe.json] (statistics for vision-only branch)
59 | - [\_all\_val\_oe.json] (statistics for the ensemble branch)
60 | - ckpt_last_engine.pth.tar (checkpoints of last epoch)
61 | - ckpt_last_model.pth.tar
62 | - ckpt_last_optimizer.pth.tar
63 |
64 | Many options are available in the options directory. `cfvqa` represents the complete causal graph, while `cfvqasimple` (cfvqas) represents the simplified causal graph.
65 |
66 | ### Evaluate a model
67 |
68 | There is no test set for VQA-CP v2, our main dataset, so evaluation is done on the validation set. For a model trained on VQA v2, you can evaluate on the test set instead. In the following example, [bootstrap/run.py](https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/run.py) loads the options from your experiment directory, resumes the last checkpoint and starts an evaluation on the validation set while skipping the training set (`train_split` is empty). Thanks to `--misc.logs_name`, the logs will be written to the new `logs_test.txt` and `logs_test.json` files, instead of being appended to the `logs.txt` and `logs.json` files.
69 | ```bash
70 | python -m bootstrap.run \
71 | -o ./logs/vqacp2/smrl_cfvqa_sum/options.yaml \
72 | --exp.resume last \
73 | --dataset.train_split ''\
74 | --dataset.eval_split val \
75 | --misc.logs_name test
76 | ```
77 |
78 | ## Run IntroD
79 |
80 | We take CF-VQA+IntroD on VQA-CP v2 as an example. Simply run
81 | ```bash
82 | bash scripts/run_vqacp2_cfvqa_introd.sh
83 | ```
84 | The statistics for the final student model are included in `./logs/vqacp2/smrl_cfvqaintrod_sum/_stu_val_oe.json`.
85 |
86 | ## Useful commands
87 |
88 | More useful commands can be found [here](https://github.com/yuleiniu/cfvqa#useful-commands).
89 |
90 | ## Acknowledgment
91 |
92 | Special thanks to the authors of [RUBi][2], [BLOCK][3], and [bootstrap.pytorch][4], and the datasets used in this research project.
93 |
94 | [1]: https://github.com/yuleiniu/cfvqa
95 | [2]: https://github.com/cdancette/rubi.bootstrap.pytorch
96 | [3]: https://github.com/Cadene/block.bootstrap.pytorch
97 | [4]: https://github.com/Cadene/bootstrap.pytorch
98 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/__init__.py
--------------------------------------------------------------------------------
/cfvqa/cfvqa/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.0'
2 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/datasets/__init__.py
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/factory.py:
--------------------------------------------------------------------------------
1 | from bootstrap.lib.options import Options
2 | from block.datasets.tdiuc import TDIUC
3 | from block.datasets.vrd import VRD
4 | from block.datasets.vg import VG
5 | from block.datasets.vqa_utils import ListVQADatasets
6 | from .vqa2 import VQA2
7 | from .vqacp2 import VQACP2
8 | from .vqacp import VQACP
9 |
10 | def factory(engine=None):
11 | opt = Options()['dataset']
12 |
13 | dataset = {}
14 | if opt.get('train_split', None):
15 | dataset['train'] = factory_split(opt['train_split'])
16 | if opt.get('eval_split', None):
17 | dataset['eval'] = factory_split(opt['eval_split'])
18 |
19 | return dataset
20 |
21 | def factory_split(split):
22 | opt = Options()['dataset']
23 | shuffle = ('train' in split)
24 |
25 | if opt['name'] == 'vqacp2':
26 | assert(split in ['train', 'val', 'test'])
27 | samplingans = (opt['samplingans'] and split == 'train')
28 |
29 | dataset = VQACP2(
30 | dir_data=opt['dir'],
31 | split=split,
32 | batch_size=opt['batch_size'],
33 | nb_threads=opt['nb_threads'],
34 | pin_memory=Options()['misc']['cuda'],
35 | shuffle=shuffle,
36 | nans=opt['nans'],
37 | minwcount=opt['minwcount'],
38 | nlp=opt['nlp'],
39 | proc_split=opt['proc_split'],
40 | samplingans=samplingans,
41 | dir_rcnn=opt['dir_rcnn'],
42 | dir_cnn=opt.get('dir_cnn', None),
43 | dir_vgg16=opt.get('dir_vgg16', None),
44 | )
45 |
46 | elif opt['name'] == 'vqacp':
47 | assert(split in ['train', 'val', 'test'])
48 | samplingans = (opt['samplingans'] and split == 'train')
49 |
50 | dataset = VQACP(
51 | dir_data=opt['dir'],
52 | split=split,
53 | batch_size=opt['batch_size'],
54 | nb_threads=opt['nb_threads'],
55 | pin_memory=Options()['misc']['cuda'],
56 | shuffle=shuffle,
57 | nans=opt['nans'],
58 | minwcount=opt['minwcount'],
59 | nlp=opt['nlp'],
60 | proc_split=opt['proc_split'],
61 | samplingans=samplingans,
62 | dir_rcnn=opt['dir_rcnn'],
63 | dir_cnn=opt.get('dir_cnn', None),
64 | dir_vgg16=opt.get('dir_vgg16', None),
65 | )
66 |
67 | elif opt['name'] == 'vqacpv2-with-testdev':
68 | assert(split in ['train', 'val', 'test'])
69 | samplingans = (opt['samplingans'] and split == 'train')
70 | dataset = VQACP2(
71 | dir_data=opt['dir'],
72 | split=split,
73 | batch_size=opt['batch_size'],
74 | nb_threads=opt['nb_threads'],
75 | pin_memory=Options()['misc']['cuda'],
76 | shuffle=shuffle,
77 | nans=opt['nans'],
78 | minwcount=opt['minwcount'],
79 | nlp=opt['nlp'],
80 | proc_split=opt['proc_split'],
81 | samplingans=samplingans,
82 | dir_rcnn=opt['dir_rcnn'],
83 | dir_cnn=opt.get('dir_cnn', None),
84 | dir_vgg16=opt.get('dir_vgg16', None),
85 | has_testdevset=True,
86 | )
87 |
88 | elif opt['name'] == 'vqa2':
89 | assert(split in ['train', 'val', 'test'])
90 | samplingans = (opt['samplingans'] and split == 'train')
91 |
92 | if opt['vg']:
93 | assert(opt['proc_split'] == 'trainval')
94 |
95 | # trainvalset
96 | vqa2 = VQA2(
97 | dir_data=opt['dir'],
98 | split='train',
99 | nans=opt['nans'],
100 | minwcount=opt['minwcount'],
101 | nlp=opt['nlp'],
102 | proc_split=opt['proc_split'],
103 | samplingans=samplingans,
104 | dir_rcnn=opt['dir_rcnn'])
105 |
106 | vg = VG(
107 | dir_data=opt['dir_vg'],
108 | split='train',
109 | nans=10000,
110 | minwcount=0,
111 | nlp=opt['nlp'],
112 | dir_rcnn=opt['dir_rcnn_vg'])
113 |
114 | vqa2vg = ListVQADatasets(
115 | [vqa2,vg],
116 | split='train',
117 | batch_size=opt['batch_size'],
118 | nb_threads=opt['nb_threads'],
119 | pin_memory=Options()['misc.cuda'],
120 | shuffle=shuffle)
121 |
122 | if split == 'train':
123 | dataset = vqa2vg
124 | else:
125 | dataset = VQA2(
126 | dir_data=opt['dir'],
127 | split=split,
128 | batch_size=opt['batch_size'],
129 | nb_threads=opt['nb_threads'],
130 | pin_memory=Options()['misc.cuda'],
131 | shuffle=False,
132 | nans=opt['nans'],
133 | minwcount=opt['minwcount'],
134 | nlp=opt['nlp'],
135 | proc_split=opt['proc_split'],
136 | samplingans=samplingans,
137 | dir_rcnn=opt['dir_rcnn'])
138 | dataset.sync_from(vqa2vg)
139 |
140 | else:
141 | dataset = VQA2(
142 | dir_data=opt['dir'],
143 | split=split,
144 | batch_size=opt['batch_size'],
145 | nb_threads=opt['nb_threads'],
146 | pin_memory=Options()['misc.cuda'],
147 | shuffle=shuffle,
148 | nans=opt['nans'],
149 | minwcount=opt['minwcount'],
150 | nlp=opt['nlp'],
151 | proc_split=opt['proc_split'],
152 | samplingans=samplingans,
153 | dir_rcnn=opt['dir_rcnn'],
154 | dir_cnn=opt.get('dir_cnn', None),
155 | )
156 |
157 | return dataset
158 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/scripts/download_vqa2.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data/vqa
2 | cd data/vqa
3 | wget http://data.lip6.fr/cadene/block/vqa2.tar.gz
4 | wget http://data.lip6.fr/cadene/block/coco.tar.gz
5 | tar -xzvf vqa2.tar.gz
6 | tar -xzvf coco.tar.gz
7 |
8 | mkdir -p coco/extract_rcnn
9 | cd coco/extract_rcnn
10 | wget http://data.lip6.fr/cadene/block/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36.tar
11 | tar -xvf 2018-04-27_bottom-up-attention_fixed_36.tar
12 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/scripts/download_vqacp2.sh:
--------------------------------------------------------------------------------
1 | mkdir -p data/vqa
2 | cd data/vqa
3 | wget http://data.lip6.fr/cadene/murel/vqacp2.tar.gz
4 | tar -xzvf vqacp2.tar.gz
5 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/vqa2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import copy
4 | import json
5 | import torch
6 | import numpy as np
7 | from os import path as osp
8 | from bootstrap.lib.logger import Logger
9 | from bootstrap.lib.options import Options
10 | from block.datasets.vqa_utils import AbstractVQA
11 | from copy import deepcopy
12 | import random
13 | import tqdm
14 | import h5py
15 |
16 | class VQA2(AbstractVQA):
17 |
18 | def __init__(self,
19 | dir_data='data/vqa2',
20 | split='train',
21 | batch_size=10,
22 | nb_threads=4,
23 | pin_memory=False,
24 | shuffle=False,
25 | nans=1000,
26 | minwcount=10,
27 | nlp='mcb',
28 | proc_split='train',
29 | samplingans=False,
30 | dir_rcnn='data/coco/extract_rcnn',
31 | adversarial=False,
32 | dir_cnn=None
33 | ):
34 |
35 | super(VQA2, self).__init__(
36 | dir_data=dir_data,
37 | split=split,
38 | batch_size=batch_size,
39 | nb_threads=nb_threads,
40 | pin_memory=pin_memory,
41 | shuffle=shuffle,
42 | nans=nans,
43 | minwcount=minwcount,
44 | nlp=nlp,
45 | proc_split=proc_split,
46 | samplingans=samplingans,
47 | has_valset=True,
48 | has_testset=True,
49 | has_answers_occurence=True,
50 | do_tokenize_answers=False)
51 |
52 | self.dir_rcnn = dir_rcnn
53 | self.dir_cnn = dir_cnn
54 | self.load_image_features()
55 | # to activate manually in visualization context (notebook)
56 | self.load_original_annotation = False
57 |
58 | def add_rcnn_to_item(self, item):
59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name']))
60 | item_rcnn = torch.load(path_rcnn)
61 | item['visual'] = item_rcnn['pooled_feat']
62 | item['coord'] = item_rcnn['rois']
63 | item['norm_coord'] = item_rcnn.get('norm_rois', None)
64 | item['nb_regions'] = item['visual'].size(0)
65 | return item
66 |
67 | def add_cnn_to_item(self, item):
68 | image_name = item['image_name']
69 | if image_name in self.image_names_to_index_train:
70 | index = self.image_names_to_index_train[image_name]
71 | image = torch.tensor(self.image_features_train['att'][index])
72 | elif image_name in self.image_names_to_index_val:
73 | index = self.image_names_to_index_val[image_name]
74 | image = torch.tensor(self.image_features_val['att'][index])
75 | image = image.permute(1, 2, 0).view(196, 2048)
76 | item['visual'] = image
77 | return item
78 |
79 | def load_image_features(self):
80 | if self.dir_cnn:
81 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5')
82 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5')
83 | Logger()(f"Opening file {filename_train}, {filename_val}")
84 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True)
85 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True)
86 | # load txt
87 | with open(os.path.join(self.dir_cnn, 'trainset.txt'), 'r') as f:
88 | self.image_names_to_index_train = {}
89 | for i, line in enumerate(f):
90 | self.image_names_to_index_train[line.strip()] = i
91 | with open(os.path.join(self.dir_cnn, 'valset.txt'), 'r') as f:
92 | self.image_names_to_index_val = {}
93 | for i, line in enumerate(f):
94 | self.image_names_to_index_val[line.strip()] = i
95 |
96 | def __getitem__(self, index):
97 | item = {}
98 | item['index'] = index
99 |
100 | # Process Question (word token)
101 | question = self.dataset['questions'][index]
102 | if self.load_original_annotation:
103 | item['original_question'] = question
104 |
105 | item['question_id'] = question['question_id']
106 |
107 | item['question'] = torch.tensor(question['question_wids'], dtype=torch.long)
108 | item['lengths'] = torch.tensor([len(question['question_wids'])], dtype=torch.long)
109 | item['image_name'] = question['image_name']
110 |
111 | # Process Object, Attribute and Relational features
113 | if self.dir_rcnn:
114 | item = self.add_rcnn_to_item(item)
115 | elif self.dir_cnn:
116 | item = self.add_cnn_to_item(item)
117 |
118 | # Process Answer if exists
119 | if 'annotations' in self.dataset:
120 | annotation = self.dataset['annotations'][index]
121 | if self.load_original_annotation:
122 | item['original_annotation'] = annotation
123 | if 'train' in self.split and self.samplingans:
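# sample one ground-truth answer, with probability proportional to its annotator count, rather than always taking the most frequent answer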
124 | proba = annotation['answers_count']
125 | proba = proba / np.sum(proba)
126 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba))
127 | else:
128 | item['answer_id'] = annotation['answer_id']
129 | item['class_id'] = torch.tensor([item['answer_id']], dtype=torch.long)
130 | item['answer'] = annotation['answer']
131 | item['question_type'] = annotation['question_type']
132 | else:
133 | if item['question_id'] in self.is_qid_testdev:
134 | item['is_testdev'] = True
135 | else:
136 | item['is_testdev'] = False
137 |
138 | # if Options()['model.network.name'] == 'xmn_net':
139 | # num_feat = 36
140 | # relation_mask = np.zeros((num_feat, num_feat))
141 | # boxes = item['coord']
142 | # for i in range(num_feat):
143 | # for j in range(i+1, num_feat):
144 | # # if there is no overlap between two bounding box
145 | # if boxes[0,i]>boxes[2,j] or boxes[0,j]>boxes[2,i] or boxes[1,i]>boxes[3,j] or boxes[1,j]>boxes[3,i]:
146 | # pass
147 | # else:
148 | # relation_mask[i,j] = relation_mask[j,i] = 1
149 | # relation_mask = torch.from_numpy(relation_mask).byte()
150 | # item['relation_mask'] = relation_mask
151 |
152 | return item
153 |
154 | def download(self):
155 | dir_zip = osp.join(self.dir_raw, 'zip')
156 | os.system('mkdir -p '+dir_zip)
157 | dir_ann = osp.join(self.dir_raw, 'annotations')
158 | os.system('mkdir -p '+dir_ann)
159 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip -P '+dir_zip)
160 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip -P '+dir_zip)
161 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip -P '+dir_zip)
162 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P '+dir_zip)
163 | os.system('wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P '+dir_zip)
164 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Train_mscoco.zip')+' -d '+dir_ann)
165 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Val_mscoco.zip')+' -d '+dir_ann)
166 | os.system('unzip '+osp.join(dir_zip, 'v2_Questions_Test_mscoco.zip')+' -d '+dir_ann)
167 | os.system('unzip '+osp.join(dir_zip, 'v2_Annotations_Train_mscoco.zip')+' -d '+dir_ann)
168 | os.system('unzip '+osp.join(dir_zip, 'v2_Annotations_Val_mscoco.zip')+' -d '+dir_ann)
169 | os.system('mv '+osp.join(dir_ann, 'v2_mscoco_train2014_annotations.json')+' '
170 | +osp.join(dir_ann, 'mscoco_train2014_annotations.json'))
171 | os.system('mv '+osp.join(dir_ann, 'v2_mscoco_val2014_annotations.json')+' '
172 | +osp.join(dir_ann, 'mscoco_val2014_annotations.json'))
173 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_train2014_questions.json')+' '
174 | +osp.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json'))
175 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_val2014_questions.json')+' '
176 | +osp.join(dir_ann, 'OpenEnded_mscoco_val2014_questions.json'))
177 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_test2015_questions.json')+' '
178 | +osp.join(dir_ann, 'OpenEnded_mscoco_test2015_questions.json'))
179 | os.system('mv '+osp.join(dir_ann, 'v2_OpenEnded_mscoco_test-dev2015_questions.json')+' '
180 | +osp.join(dir_ann, 'OpenEnded_mscoco_test-dev2015_questions.json'))
181 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/vqacp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import copy
4 | import json
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 | from os import path as osp
9 | from bootstrap.lib.logger import Logger
10 | from block.datasets.vqa_utils import AbstractVQA
11 | from copy import deepcopy
12 | import random
13 | import h5py
14 |
15 | class VQACP(AbstractVQA):
16 |
17 | def __init__(self,
18 | dir_data='data/vqa/vqacp2',
19 | split='train',
20 | batch_size=80,
21 | nb_threads=4,
22 | pin_memory=False,
23 | shuffle=False,
24 | nans=1000,
25 | minwcount=10,
26 | nlp='mcb',
27 | proc_split='train',
28 | samplingans=False,
29 | dir_rcnn='data/coco/extract_rcnn',
30 | dir_cnn=None,
31 | dir_vgg16=None,
32 | has_testdevset=False,
33 | ):
34 | super(VQACP, self).__init__(
35 | dir_data=dir_data,
36 | split=split,
37 | batch_size=batch_size,
38 | nb_threads=nb_threads,
39 | pin_memory=pin_memory,
40 | shuffle=shuffle,
41 | nans=nans,
42 | minwcount=minwcount,
43 | nlp=nlp,
44 | proc_split=proc_split,
45 | samplingans=samplingans,
46 | has_valset=True,
47 | has_testset=False,
48 | has_testdevset=has_testdevset,
49 | has_testset_anno=False,
50 | has_answers_occurence=True,
51 | do_tokenize_answers=False)
52 | self.dir_rcnn = dir_rcnn
53 | self.dir_cnn = dir_cnn
54 | self.dir_vgg16 = dir_vgg16
55 | self.load_image_features()
56 | self.load_original_annotation = False
57 |
58 | def add_rcnn_to_item(self, item):
59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name']))
60 | item_rcnn = torch.load(path_rcnn)
61 | item['visual'] = item_rcnn['pooled_feat']
62 | item['coord'] = item_rcnn['rois']
63 | item['norm_coord'] = item_rcnn['norm_rois']
64 | item['nb_regions'] = item['visual'].size(0)
65 | return item
66 |
67 | def load_image_features(self):
68 | if self.dir_cnn:
69 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5')
70 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5')
71 | Logger()(f"Opening file {filename_train}, {filename_val}")
72 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True)
73 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True)
74 | # load txt
75 | with open(os.path.join(self.dir_cnn, 'trainset.txt'), 'r') as f:
76 | self.image_names_to_index_train = {}
77 | for i, line in enumerate(f):
78 | self.image_names_to_index_train[line.strip()] = i
79 | with open(os.path.join(self.dir_cnn, 'valset.txt'), 'r') as f:
80 | self.image_names_to_index_val = {}
81 | for i, line in enumerate(f):
82 | self.image_names_to_index_val[line.strip()] = i
83 | elif self.dir_vgg16:
84 | # list filenames
85 | self.filenames_train = os.listdir(os.path.join(self.dir_vgg16, 'train'))
86 | self.filenames_val = os.listdir(os.path.join(self.dir_vgg16, 'val'))
87 |
88 |
89 | def add_vgg_to_item(self, item):
90 | image_name = item['image_name']
91 | filename = image_name + '.pth'
92 | if filename in self.filenames_train:
93 | path = os.path.join(self.dir_vgg16, 'train', filename)
94 | elif filename in self.filenames_val:
95 | path = os.path.join(self.dir_vgg16, 'val', filename)
96 | visual = torch.load(path)
97 | visual = visual.permute(1, 2, 0).view(14*14, 512)
98 | item['visual'] = visual
99 | return item
100 |
101 | def add_cnn_to_item(self, item):
102 | image_name = item['image_name']
103 | if image_name in self.image_names_to_index_train:
104 | index = self.image_names_to_index_train[image_name]
105 | image = torch.tensor(self.image_features_train['att'][index])
106 | elif image_name in self.image_names_to_index_val:
107 | index = self.image_names_to_index_val[image_name]
108 | image = torch.tensor(self.image_features_val['att'][index])
109 | image = image.permute(1, 2, 0).view(196, 2048)
110 | item['visual'] = image
111 | return item
112 |
113 | def __getitem__(self, index):
114 | item = {}
115 | item['index'] = index
116 |
117 | # Process Question (word token)
118 | question = self.dataset['questions'][index]
119 | if self.load_original_annotation:
120 | item['original_question'] = question
121 | item['question_id'] = question['question_id']
122 | item['question'] = torch.LongTensor(question['question_wids'])
123 | item['lengths'] = torch.LongTensor([len(question['question_wids'])])
124 | item['image_name'] = question['image_name']
125 |
126 | # Process Object, Attribute and Relational features
127 | if self.dir_rcnn:
128 | item = self.add_rcnn_to_item(item)
129 | elif self.dir_cnn:
130 | item = self.add_cnn_to_item(item)
131 | elif self.dir_vgg16:
132 | item = self.add_vgg_to_item(item)
133 |
134 | # Process Answer if exists
135 | if 'annotations' in self.dataset:
136 | annotation = self.dataset['annotations'][index]
137 | if self.load_original_annotation:
138 | item['original_annotation'] = annotation
139 | if 'train' in self.split and self.samplingans:
140 | proba = annotation['answers_count']
141 | proba = proba / np.sum(proba)
142 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba))
143 | else:
144 | item['answer_id'] = annotation['answer_id']
145 | item['class_id'] = torch.LongTensor([item['answer_id']])
146 | item['answer'] = annotation['answer']
147 | item['question_type'] = annotation['question_type']
148 |
149 | return item
150 |
151 | def download(self):
152 | dir_ann = osp.join(self.dir_raw, 'annotations')
153 | os.system('mkdir -p '+dir_ann)
154 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json -P' + dir_ann)
155 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json -P' + dir_ann)
156 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json -P' + dir_ann)
157 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json -P' + dir_ann)
158 | train_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v1_train_questions.json")))}
159 | val_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v1_test_questions.json")))}
160 | train_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v1_train_annotations.json")))}
161 | val_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v1_test_annotations.json")))}
162 | train_q['info'] = {}
163 | train_q['data_type'] = 'mscoco'
164 | train_q['data_subtype'] = "train2014cp"
165 | train_q['task_type'] = "Open-Ended"
166 | train_q['license'] = {}
167 | val_q['info'] = {}
168 | val_q['data_type'] = 'mscoco'
169 | val_q['data_subtype'] = "val2014cp"
170 | val_q['task_type'] = "Open-Ended"
171 | val_q['license'] = {}
172 | for k in ["info", 'data_type','data_subtype', 'license']:
173 | train_ann[k] = train_q[k]
174 | val_ann[k] = val_q[k]
175 | with open(osp.join(dir_ann, "OpenEnded_mscoco_train2014_questions.json"), 'w') as F:
176 | F.write(json.dumps(train_q))
177 | with open(osp.join(dir_ann, "OpenEnded_mscoco_val2014_questions.json"), 'w') as F:
178 | F.write(json.dumps(val_q))
179 | with open(osp.join(dir_ann, "mscoco_train2014_annotations.json"), 'w') as F:
180 | F.write(json.dumps(train_ann))
181 | with open(osp.join(dir_ann, "mscoco_val2014_annotations.json"), 'w') as F:
182 | F.write(json.dumps(val_ann))
183 |
184 | def add_image_names(self, dataset):
185 | for q in dataset['questions']:
186 | q['image_name'] = 'COCO_%s_%012d.jpg'%(q['coco_split'],q['image_id'])
187 | return dataset
188 |
189 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/datasets/vqacp2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import copy
4 | import json
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 | from os import path as osp
9 | from bootstrap.lib.logger import Logger
10 | from block.datasets.vqa_utils import AbstractVQA
11 | from copy import deepcopy
12 | import random
13 | import h5py
14 |
15 | class VQACP2(AbstractVQA):
16 |
17 | def __init__(self,
18 | dir_data='data/vqa/vqacp2',
19 | split='train',
20 | batch_size=80,
21 | nb_threads=4,
22 | pin_memory=False,
23 | shuffle=False,
24 | nans=1000,
25 | minwcount=10,
26 | nlp='mcb',
27 | proc_split='train',
28 | samplingans=False,
29 | dir_rcnn='data/coco/extract_rcnn',
30 | dir_cnn=None,
31 | dir_vgg16=None,
32 | has_testdevset=False,
33 | ):
34 | super(VQACP2, self).__init__(
35 | dir_data=dir_data,
36 | split=split,
37 | batch_size=batch_size,
38 | nb_threads=nb_threads,
39 | pin_memory=pin_memory,
40 | shuffle=shuffle,
41 | nans=nans,
42 | minwcount=minwcount,
43 | nlp=nlp,
44 | proc_split=proc_split,
45 | samplingans=samplingans,
46 | has_valset=True,
47 | has_testset=False,
48 | has_testdevset=has_testdevset,
49 | has_testset_anno=False,
50 | has_answers_occurence=True,
51 | do_tokenize_answers=False)
52 | self.dir_rcnn = dir_rcnn
53 | self.dir_cnn = dir_cnn
54 | self.dir_vgg16 = dir_vgg16
55 | self.load_image_features()
56 | self.load_original_annotation = False
57 |
58 | def add_rcnn_to_item(self, item):
59 | path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name']))
60 | item_rcnn = torch.load(path_rcnn)
61 | item['visual'] = item_rcnn['pooled_feat']
62 | item['coord'] = item_rcnn['rois']
63 | item['norm_coord'] = item_rcnn['norm_rois']
64 | item['nb_regions'] = item['visual'].size(0)
65 | return item
66 |
67 | def load_image_features(self):
68 | if self.dir_cnn:
69 | filename_train = os.path.join(self.dir_cnn, 'trainset.hdf5')
70 | filename_val = os.path.join(self.dir_cnn, 'valset.hdf5')
71 | Logger()(f"Opening file {filename_train}, {filename_val}")
72 | self.image_features_train = h5py.File(filename_train, 'r', swmr=True)
73 | self.image_features_val = h5py.File(filename_val, 'r', swmr=True)
74 | # load txt
75 | with open(os.path.join(self.dir_cnn, 'trainset.txt'), 'r') as f:
76 | self.image_names_to_index_train = {}
77 | for i, line in enumerate(f):
78 | self.image_names_to_index_train[line.strip()] = i
79 | with open(os.path.join(self.dir_cnn, 'valset.txt'), 'r') as f:
80 | self.image_names_to_index_val = {}
81 | for i, line in enumerate(f):
82 | self.image_names_to_index_val[line.strip()] = i
83 | elif self.dir_vgg16:
84 | # list filenames
85 | self.filenames_train = os.listdir(os.path.join(self.dir_vgg16, 'train'))
86 | self.filenames_val = os.listdir(os.path.join(self.dir_vgg16, 'val'))
87 |
88 |
89 | def add_vgg_to_item(self, item):
90 | image_name = item['image_name']
91 | filename = image_name + '.pth'
92 | if filename in self.filenames_train:
93 | path = os.path.join(self.dir_vgg16, 'train', filename)
94 | elif filename in self.filenames_val:
95 | path = os.path.join(self.dir_vgg16, 'val', filename)
96 | visual = torch.load(path)
97 | visual = visual.permute(1, 2, 0).view(14*14, 512)
98 | item['visual'] = visual
99 | return item
100 |
101 | def add_cnn_to_item(self, item):
102 | image_name = item['image_name']
103 | if image_name in self.image_names_to_index_train:
104 | index = self.image_names_to_index_train[image_name]
105 | image = torch.tensor(self.image_features_train['att'][index])
106 | elif image_name in self.image_names_to_index_val:
107 | index = self.image_names_to_index_val[image_name]
108 | image = torch.tensor(self.image_features_val['att'][index])
109 | image = image.permute(1, 2, 0).view(196, 2048)
110 | item['visual'] = image
111 | return item
112 |
113 | def __getitem__(self, index):
114 | item = {}
115 | item['index'] = index
116 |
117 | # Process Question (word token)
118 | question = self.dataset['questions'][index]
119 | if self.load_original_annotation:
120 | item['original_question'] = question
121 | item['question_id'] = question['question_id']
122 | item['question'] = torch.LongTensor(question['question_wids'])
123 | item['lengths'] = torch.LongTensor([len(question['question_wids'])])
124 | item['image_name'] = question['image_name']
125 |
126 | # Process Object, Attribute and Relational features
127 | if self.dir_rcnn:
128 | item = self.add_rcnn_to_item(item)
129 | elif self.dir_cnn:
130 | item = self.add_cnn_to_item(item)
131 | elif self.dir_vgg16:
132 | item = self.add_vgg_to_item(item)
133 |
134 | # Process Answer if exists
135 | if 'annotations' in self.dataset:
136 | annotation = self.dataset['annotations'][index]
137 | if self.load_original_annotation:
138 | item['original_annotation'] = annotation
139 | if 'train' in self.split and self.samplingans:
140 | proba = annotation['answers_count']
141 | proba = proba / np.sum(proba)
142 | item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=proba))
143 | else:
144 | item['answer_id'] = annotation['answer_id']
145 | item['class_id'] = torch.LongTensor([item['answer_id']])
146 | item['answer'] = annotation['answer']
147 | item['question_type'] = annotation['question_type']
148 |
149 | return item
150 |
151 | def download(self):
152 | dir_ann = osp.join(self.dir_raw, 'annotations')
153 | os.system('mkdir -p '+dir_ann)
154 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json -P' + dir_ann)
155 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json -P' + dir_ann)
156 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json -P' + dir_ann)
157 | os.system('wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json -P' + dir_ann)
158 | train_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v2_train_questions.json")))}
159 | val_q = {"questions":json.load(open(osp.join(dir_ann, "vqacp_v2_test_questions.json")))}
160 | train_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v2_train_annotations.json")))}
161 | val_ann = {"annotations":json.load(open(osp.join(dir_ann, "vqacp_v2_test_annotations.json")))}
162 | train_q['info'] = {}
163 | train_q['data_type'] = 'mscoco'
164 | train_q['data_subtype'] = "train2014cp"
165 | train_q['task_type'] = "Open-Ended"
166 | train_q['license'] = {}
167 | val_q['info'] = {}
168 | val_q['data_type'] = 'mscoco'
169 | val_q['data_subtype'] = "val2014cp"
170 | val_q['task_type'] = "Open-Ended"
171 | val_q['license'] = {}
172 | for k in ["info", 'data_type','data_subtype', 'license']:
173 | train_ann[k] = train_q[k]
174 | val_ann[k] = val_q[k]
175 | with open(osp.join(dir_ann, "OpenEnded_mscoco_train2014_questions.json"), 'w') as F:
176 | F.write(json.dumps(train_q))
177 | with open(osp.join(dir_ann, "OpenEnded_mscoco_val2014_questions.json"), 'w') as F:
178 | F.write(json.dumps(val_q))
179 | with open(osp.join(dir_ann, "mscoco_train2014_annotations.json"), 'w') as F:
180 | F.write(json.dumps(train_ann))
181 | with open(osp.join(dir_ann, "mscoco_val2014_annotations.json"), 'w') as F:
182 | F.write(json.dumps(val_ann))
183 |
184 | def add_image_names(self, dataset):
185 | for q in dataset['questions']:
186 | q['image_name'] = 'COCO_%s_%012d.jpg'%(q['coco_split'],q['image_id'])
187 | return dataset
188 |
189 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/engines/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import factory
--------------------------------------------------------------------------------
/cfvqa/cfvqa/engines/factory.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from bootstrap.lib.options import Options
3 | from bootstrap.lib.logger import Logger
4 | from .engine import Engine
5 | from .logger import LoggerEngine
6 |
7 |
8 | def factory():
9 | Logger()('Creating engine...')
10 |
11 | if Options()['engine'].get('import', False):
12 | # import usually is "yourmodule.engine.factory"
13 | module = importlib.import_module(Options()['engine']['import'])
14 | engine = module.factory()
15 |
16 | elif Options()['engine']['name'] == 'default':
17 | engine = Engine()
18 |
19 | elif Options()['engine']['name'] == 'logger':
20 | engine = LoggerEngine()
21 |
22 | else:
23 | raise ValueError
24 |
25 | return engine
26 |
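# Example (a sketch, assuming bootstrap.pytorch's YAML options layout):
#
#     engine:
#         name: logger   # selects LoggerEngine above
#         # or point 'import' at a module exposing a factory():
#         # import: yourmodule.engines.factory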
--------------------------------------------------------------------------------
/cfvqa/cfvqa/engines/logger.py:
--------------------------------------------------------------------------------
1 | from bootstrap.lib.logger import Logger
2 | from .engine import Engine
3 |
4 |
5 | class LoggerEngine(Engine):
6 | """ LoggerEngine is similar to Engine. The only difference is a more powerful is_best method.
7 | It is able to look into the logger dictionary that contains the list of all the logged variables
8 | indexed by name.
9 |
10 | Example usage:
11 |
12 | .. code-block:: python
13 |
14 | out = {
15 | 'loss': 0.2,
16 | 'acctop1': 87.02
17 | }
18 | engine.is_best(out, 'loss:min')
19 |
20 | # Logger().values['eval_epoch.recall_at_1'] contains a list
21 | # of all the recall at 1 values for each evaluation epoch
22 | engine.is_best(out, 'eval_epoch.recall_at_1')
23 | """
24 |
25 | def __init__(self):
26 | super(LoggerEngine, self).__init__()
27 |
28 | def is_best(self, out, saving_criteria):
29 | if ':min' in saving_criteria:
30 | name = saving_criteria.replace(':min', '')
31 | order = '<'
32 | elif ':max' in saving_criteria:
33 | name = saving_criteria.replace(':max', '')
34 | order = '>'
35 | else:
36 | error_msg = """'--engine.saving_criteria' named '{}' does not specify order,
37 | you need to choose between '{}' or '{}' to specify whether the criterion needs to be minimized or maximized""".format(
38 | saving_criteria, saving_criteria + ':min', saving_criteria + ':max')
39 | raise ValueError(error_msg)
40 |
41 | if name in out:
42 | new_value = out[name]
43 | elif name in Logger().values:
44 | new_value = Logger().values[name][-1]
45 | else:
46 | raise ValueError("name '{}' not in outputs '{}' and not in logger '{}'".format(
47 | name, list(out.keys()), list(Logger().values.keys())))
48 |
49 | if name not in self.best_out:
50 | self.best_out[name] = new_value
51 | else:
52 | if eval('{} {} {}'.format(new_value, order, self.best_out[name])):
53 | self.best_out[name] = new_value
54 | return True
55 |
56 | return False
57 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/criterions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/models/criterions/__init__.py
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/criterions/cfvqa_criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from bootstrap.lib.logger import Logger
5 | from bootstrap.lib.options import Options
6 |
7 | class CFVQACriterion(nn.Module):
8 |
9 | def __init__(self, question_loss_weight=1.0, vision_loss_weight=1.0, is_va=True):
10 | super().__init__()
11 | self.is_va = is_va
12 |
13 | Logger()(f'CFVQACriterion, with question_loss_weight = ({question_loss_weight})')
14 | if self.is_va:
15 | Logger()(f'CFVQACriterion, with vision_loss_weight = ({vision_loss_weight})')
16 |
17 | self.fusion_loss = nn.CrossEntropyLoss()
18 | self.question_loss = nn.CrossEntropyLoss()
19 | self.question_loss_weight = question_loss_weight
20 | if self.is_va:
21 | self.vision_loss = nn.CrossEntropyLoss()
22 | self.vision_loss_weight = vision_loss_weight
23 |
24 | def forward(self, net_out, batch):
25 | out = {}
26 | class_id = batch['class_id'].squeeze(1)
27 |
28 | logits_rubi = net_out['logits_all']
29 | fusion_loss = self.fusion_loss(logits_rubi, class_id)
30 |
31 | logits_q = net_out['logits_q']
32 | question_loss = self.question_loss(logits_q, class_id)
33 |
34 | if self.is_va:
35 | logits_v = net_out['logits_v']
36 | vision_loss = self.vision_loss(logits_v, class_id)
37 |
38 | nde = net_out['z_nde']
39 | p_te = torch.nn.functional.softmax(logits_rubi, -1).clone().detach()
40 | p_nde = torch.nn.functional.softmax(nde, -1)
41 | kl_loss = - p_te*p_nde.log()
42 | kl_loss = kl_loss.sum(1).mean()
43 |
44 | loss = fusion_loss \
45 | + self.question_loss_weight * question_loss \
46 | + kl_loss
47 | if self.is_va:
48 | loss += self.vision_loss_weight * vision_loss
49 |
50 | out['loss'] = loss
51 | out['loss_mm_q'] = fusion_loss
52 | out['loss_q'] = question_loss
53 | if self.is_va:
54 | out['loss_v'] = vision_loss
55 | return out
56 |
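# Quick smoke test (illustrative, not part of the repository):
#
#     import torch
#     from cfvqa.models.criterions.cfvqa_criterion import CFVQACriterion
#     crit = CFVQACriterion(question_loss_weight=1.0,
#                           vision_loss_weight=1.0, is_va=True)
#     net_out = {k: torch.randn(4, 10) for k in
#                ['logits_all', 'logits_q', 'logits_v', 'z_nde']}
#     batch = {'class_id': torch.randint(0, 10, (4, 1))}
#     print(crit(net_out, batch)['loss'])  # scalar training loss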
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/criterions/cfvqaintrod_criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from bootstrap.lib.logger import Logger
5 | from bootstrap.lib.options import Options
6 |
7 | class CFVQAIntroDCriterion(nn.Module):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | self.cls_loss = nn.CrossEntropyLoss(reduction='none')
13 |
14 | def forward(self, net_out, batch):
15 | out = {}
16 | logits_all = net_out['logits_all']
17 | class_id = batch['class_id'].squeeze(1)
18 |
19 | # KD
20 | logits_t = net_out['logits_cfvqa']
21 | logits_s = net_out['logits_stu']
22 | p_t = torch.nn.functional.softmax(logits_t, -1).clone().detach()
23 | kd_loss = - p_t*F.log_softmax(logits_s, -1)
24 | kd_loss = kd_loss.sum(1)
25 |
26 | cls_loss = self.cls_loss(logits_s, class_id)
27 |
28 | # weight estimation
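# compare how well the OOD (causal) teacher and the ID (ensemble) teacher fit each sample; samples the ID teacher fits much better get a weight near 1, i.e. mostly distillation supervision below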
29 | cls_loss_ood = self.cls_loss(logits_t, class_id)
30 | cls_loss_id = self.cls_loss(logits_all, class_id)
31 | weight = cls_loss_ood/(cls_loss_ood+cls_loss_id)
32 | weight = weight.detach()
33 |
34 | loss = (weight*kd_loss).mean() + ((1-weight)*cls_loss).mean()
35 |
36 | out['loss'] = loss
37 | return out
38 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/criterions/factory.py:
--------------------------------------------------------------------------------
1 | from bootstrap.lib.options import Options
2 | from block.models.criterions.vqa_cross_entropy import VQACrossEntropyLoss
3 | from .rubi_criterion import RUBiCriterion
4 | from .cfvqa_criterion import CFVQACriterion
5 | from .cfvqaintrod_criterion import CFVQAIntroDCriterion
6 | from .rubiintrod_criterion import RUBiIntroDCriterion
7 |
8 | def factory(engine, mode):
9 | name = Options()['model.criterion.name']
10 | split = engine.dataset[mode].split
11 | eval_only = 'train' not in engine.dataset
12 |
13 | opt = Options()['model.criterion']
14 | if split == "test" and 'tdiuc' not in Options()['dataset.name']:
15 | return None
16 | if name == 'vqa_cross_entropy':
17 | criterion = VQACrossEntropyLoss()
18 | elif name == "rubi_criterion":
19 | criterion = RUBiCriterion(
20 | question_loss_weight=opt['question_loss_weight']
21 | )
22 | elif name == "cfvqa_criterion":
23 | criterion = CFVQACriterion(
24 | question_loss_weight=opt['question_loss_weight'],
25 | vision_loss_weight=opt['vision_loss_weight'],
26 | is_va=True,
27 | )
28 | elif name == "cfvqasimple_criterion":
29 | criterion = CFVQACriterion(
30 | question_loss_weight=opt['question_loss_weight'],
31 | is_va=False,
32 | )
33 | elif name == "cfvqaintrod_criterion":
34 | criterion = CFVQAIntroDCriterion()
35 | elif name == "rubiintrod_criterion":
36 | criterion = RUBiIntroDCriterion()
37 | else:
38 | raise ValueError(name)
39 | return criterion
40 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/criterions/rubi_criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from bootstrap.lib.logger import Logger
5 | from bootstrap.lib.options import Options
6 |
7 | class RUBiCriterion(nn.Module):
8 |
9 | def __init__(self, question_loss_weight=1.0):
10 | super().__init__()
11 |
12 | Logger()(f'RUBiCriterion, with question_loss_weight = ({question_loss_weight})')
13 |
14 | self.question_loss_weight = question_loss_weight
15 | self.fusion_loss = nn.CrossEntropyLoss()
16 | self.question_loss = nn.CrossEntropyLoss()
17 |
18 | def forward(self, net_out, batch):
19 | out = {}
20 | # logits = net_out['logits']
21 | logits_q = net_out['logits_q']
22 | logits_rubi = net_out['logits_all']
23 | class_id = batch['class_id'].squeeze(1)
24 | fusion_loss = self.fusion_loss(logits_rubi, class_id)
25 | question_loss = self.question_loss(logits_q, class_id)
26 | loss = fusion_loss + self.question_loss_weight * question_loss
27 |
28 | out['loss'] = loss
29 | out['loss_mm_q'] = fusion_loss
30 | out['loss_q'] = question_loss
31 | return out
32 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/criterions/rubiintrod_criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from bootstrap.lib.logger import Logger
5 | from bootstrap.lib.options import Options
6 |
7 | class RUBiIntroDCriterion(nn.Module):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | self.cls_loss = nn.CrossEntropyLoss(reduction='none')
13 |
14 | def forward(self, net_out, batch):
15 | out = {}
16 | logits_all = net_out['logits_all']
17 | class_id = batch['class_id'].squeeze(1)
18 |
19 | # KD
20 | logits_t = net_out['logits']
21 | logits_s = net_out['logits_stu']
22 | p_t = torch.nn.functional.softmax(logits_t, -1).clone().detach()
23 | kd_loss = - p_t*F.log_softmax(logits_s, -1)
24 | kd_loss = kd_loss.sum(1)
25 |
26 | cls_loss = self.cls_loss(logits_s, class_id)
27 |
28 | # weight estimation
29 | cls_loss_ood = self.cls_loss(logits_t, class_id)
30 | cls_loss_id = self.cls_loss(logits_all, class_id)
31 | weight = cls_loss_ood/(cls_loss_ood+cls_loss_id)
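# hard introspection: rounding gives a 0/1 weight, so each sample is supervised either purely by distillation or purely by the ground truth (contrast with the soft weights in cfvqaintrod_criterion.py)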
32 | weight = torch.round(weight)
33 | weight = weight.detach()
34 |
35 | loss = (weight*kd_loss).mean() + ((1-weight)*cls_loss).mean()
36 |
37 | out['loss'] = loss
38 | return out
39 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/models/metrics/__init__.py
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/metrics/factory.py:
--------------------------------------------------------------------------------
1 | from bootstrap.lib.options import Options
2 | from block.models.metrics.vqa_accuracies import VQAAccuracies
3 | from .vqa_rubi_metrics import VQARUBiMetrics
4 | from .vqa_cfvqa_metrics import VQACFVQAMetrics
5 | from .vqa_cfvqasimple_metrics import VQACFVQASimpleMetrics
6 | from .vqa_cfvqaintrod_metrics import VQACFVQAIntroDMetrics
7 | from .vqa_rubiintrod_metrics import VQARUBiIntroDMetrics
8 |
9 | def factory(engine, mode):
10 | name = Options()['model.metric.name']
11 | metric = None
12 |
13 | if name == 'vqa_accuracies':
14 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name'])
15 | if mode == 'train':
16 | split = engine.dataset['train'].split
17 | if split == 'train':
18 | metric = VQAAccuracies(engine,
19 | mode='train',
20 | open_ended=open_ended,
21 | tdiuc=True,
22 | dir_exp=Options()['exp.dir'],
23 | dir_vqa=Options()['dataset.dir'])
24 | elif split == 'trainval':
25 | metric = None
26 | else:
27 | raise ValueError(split)
28 | elif mode == 'eval':
29 | metric = VQAAccuracies(engine,
30 | mode='eval',
31 | open_ended=open_ended,
32 | tdiuc=('tdiuc' in Options()['dataset.name'] or Options()['dataset.eval_split'] != 'test'),
33 | dir_exp=Options()['exp.dir'],
34 | dir_vqa=Options()['dataset.dir'])
35 | else:
36 | metric = None
37 |
38 | elif name == "vqa_rubi_metrics":
39 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name'])
40 | metric = VQARUBiMetrics(engine,
41 | mode=mode,
42 | open_ended=open_ended,
43 | tdiuc=True,
44 | dir_exp=Options()['exp.dir'],
45 | dir_vqa=Options()['dataset.dir']
46 | )
47 |
48 | elif name == "vqa_cfvqa_metrics":
49 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name'])
50 | metric = VQACFVQAMetrics(engine,
51 | mode=mode,
52 | open_ended=open_ended,
53 | tdiuc=True,
54 | dir_exp=Options()['exp.dir'],
55 | dir_vqa=Options()['dataset.dir']
56 | )
57 |
58 | elif name == "vqa_cfvqasimple_metrics":
59 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name'])
60 | metric = VQACFVQASimpleMetrics(engine,
61 | mode=mode,
62 | open_ended=open_ended,
63 | tdiuc=True,
64 | dir_exp=Options()['exp.dir'],
65 | dir_vqa=Options()['dataset.dir']
66 | )
67 |
68 | elif name == "vqa_cfvqaintrod_metrics":
69 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name'])
70 | metric = VQACFVQAIntroDMetrics(engine,
71 | mode=mode,
72 | open_ended=open_ended,
73 | tdiuc=True,
74 | dir_exp=Options()['exp.dir'],
75 | dir_vqa=Options()['dataset.dir']
76 | )
77 |
78 | elif name == "vqa_rubiintrod_metrics":
79 | open_ended = ('tdiuc' not in Options()['dataset.name'] and 'gqa' not in Options()['dataset.name'])
80 | metric = VQARUBiIntroDMetrics(engine,
81 | mode=mode,
82 | open_ended=open_ended,
83 | tdiuc=True,
84 | dir_exp=Options()['exp.dir'],
85 | dir_vqa=Options()['dataset.dir']
86 | )
87 |
88 | else:
89 | raise ValueError(name)
90 | return metric
91 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/models/networks/__init__.py
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/cfvqa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from block.models.networks.mlp import MLP
4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask,
5 |
6 | eps = 1e-12
7 |
8 | class CFVQA(nn.Module):
9 | """
10 | Wraps another model
11 |     The original model must return a dictionary containing the 'logits' key (predictions before softmax)
12 | Returns:
13 | - logits_vq: the original predictions of the model, i.e., NIE
14 | - logits_q: the predictions from the question-only branch
15 | - logits_v: the predictions from the vision-only branch
16 | - logits_all: the predictions from the ensemble model
17 | - logits_cfvqa: the predictions based on CF-VQA, i.e., TIE
18 | => Use `logits_all`, `logits_q` and `logits_v` for the loss
19 | """
20 | def __init__(self, model, output_size, classif_q, classif_v, fusion_mode, end_classif=True, is_va=True):
21 | super().__init__()
22 | self.net = model
23 | self.end_classif = end_classif
24 |
25 | assert fusion_mode in ['rubi', 'hm', 'sum'], "Fusion mode should be rubi/hm/sum."
26 | self.fusion_mode = fusion_mode
27 | self.is_va = is_va and (not fusion_mode=='rubi') # RUBi does not consider V->A
28 |
29 | # Q->A branch
30 | self.q_1 = MLP(**classif_q)
31 | if self.end_classif: # default: True (following RUBi)
32 | self.q_2 = nn.Linear(output_size, output_size)
33 |
34 | # V->A branch
35 | if self.is_va: # default: True (containing V->A)
36 | self.v_1 = MLP(**classif_v)
37 | if self.end_classif: # default: True (following RUBi)
38 | self.v_2 = nn.Linear(output_size, output_size)
39 |
40 | self.constant = nn.Parameter(torch.tensor(0.0))
41 |
42 | def forward(self, batch):
43 | out = {}
44 | # model prediction
45 | net_out = self.net(batch)
46 | logits = net_out['logits']
47 |
48 | # Q->A branch
49 | q_embedding = net_out['q_emb'] # N * q_emb
50 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate
51 | q_pred = self.q_1(q_embedding)
52 |
53 | # V->A branch
54 | if self.is_va:
55 | v_embedding = net_out['v_emb'] # N * v_emb
56 | v_embedding = grad_mul_const(v_embedding, 0.0) # don't backpropagate
57 | v_pred = self.v_1(v_embedding)
58 | else:
59 | v_pred = None
60 |
61 | # both q, k and v are the facts
62 | z_qkv = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=True, v_fact=True) # te
63 | # q is the fact while k and v are the counterfactuals
64 |         z_q = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=False, v_fact=False) # nde
65 |
66 | logits_cfvqa = z_qkv - z_q
67 |
68 | if self.end_classif:
69 | q_out = self.q_2(q_pred)
70 | if self.is_va:
71 | v_out = self.v_2(v_pred)
72 | else:
73 | q_out = q_pred
74 | if self.is_va:
75 | v_out = v_pred
76 |
77 | out['logits_all'] = z_qkv # for optimization
78 | out['logits_vq'] = logits # predictions of the original VQ branch, i.e., NIE
79 | out['logits_cfvqa'] = logits_cfvqa # predictions of CFVQA, i.e., TIE
80 | out['logits_q'] = q_out # for optimization
81 | if self.is_va:
82 | out['logits_v'] = v_out # for optimization
83 |
84 | if self.is_va:
85 |             out['z_nde'] = self.fusion(logits.clone().detach(), q_pred.clone().detach(), v_pred.clone().detach(), q_fact=True, k_fact=False, v_fact=False) # nde
86 |         else:
87 |             out['z_nde'] = self.fusion(logits.clone().detach(), q_pred.clone().detach(), None, q_fact=True, k_fact=False, v_fact=False) # nde
88 |
89 | return out
90 |
91 | def process_answers(self, out, key=''):
92 | out = self.net.process_answers(out, key='_all')
93 | out = self.net.process_answers(out, key='_vq')
94 | out = self.net.process_answers(out, key='_cfvqa')
95 | out = self.net.process_answers(out, key='_q')
96 | if self.is_va:
97 | out = self.net.process_answers(out, key='_v')
98 | return out
99 |
100 | def fusion(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False):
101 |
102 | z_k, z_q, z_v = self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact)
103 |
104 | if self.fusion_mode == 'rubi':
105 | z = z_k * torch.sigmoid(z_q)
106 |
107 | elif self.fusion_mode == 'hm':
108 | if self.is_va:
109 | z = z_k * z_q * z_v
110 | else:
111 | z = z_k * z_q
112 | z = torch.log(z + eps) - torch.log1p(z)
113 |
114 | elif self.fusion_mode == 'sum':
115 | if self.is_va:
116 | z = z_k + z_q + z_v
117 | else:
118 | z = z_k + z_q
119 | z = torch.log(torch.sigmoid(z) + eps)
120 |
121 | return z
122 |
123 | def transform(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False):
124 |
125 | if not k_fact:
126 | z_k = self.constant * torch.ones_like(z_k).cuda()
127 |
128 | if not q_fact:
129 | z_q = self.constant * torch.ones_like(z_q).cuda()
130 |
131 | if self.is_va:
132 | if not v_fact:
133 | z_v = self.constant * torch.ones_like(z_v).cuda()
134 |
135 | if self.fusion_mode == 'hm':
136 | z_k = torch.sigmoid(z_k)
137 | z_q = torch.sigmoid(z_q)
138 | if self.is_va:
139 | z_v = torch.sigmoid(z_v)
140 |
141 | return z_k, z_q, z_v
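To make the `sum` fusion and the TIE computation above concrete, here is a standalone CPU sketch with toy tensors (the learnable `constant` is fixed to 0.0 by hand, and `.cuda()` is dropped): the total effect fuses all three branches, the NDE replaces the VQ and V logits with the constant, and their difference is the debiased `logits_cfvqa`.

```python
import torch

eps = 1e-12
torch.manual_seed(0)
batch, n_ans = 2, 5                          # toy sizes
logits = torch.randn(batch, n_ans)           # VQ branch (z_k)
q_pred = torch.randn(batch, n_ans)           # Q->A branch (z_q)
v_pred = torch.randn(batch, n_ans)           # V->A branch (z_v)
constant = torch.tensor(0.0)                 # stands in for the learned nn.Parameter

def fuse_sum(z_k, z_q, z_v):
    return torch.log(torch.sigmoid(z_k + z_q + z_v) + eps)

c = constant * torch.ones_like(logits)       # counterfactual replacement
z_qkv = fuse_sum(logits, q_pred, v_pred)     # total effect (all facts)
z_q   = fuse_sum(c, q_pred, c)               # NDE: only the question is factual
logits_cfvqa = z_qkv - z_q                   # TIE, the debiased prediction
print(logits_cfvqa.shape)                    # torch.Size([2, 5])
```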
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/cfvqaintrod.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from block.models.networks.mlp import MLP
4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask,
5 |
6 | eps = 1e-12
7 |
8 | class CFVQAIntroD(nn.Module):
9 | """
10 | Wraps another model
11 |     The original model must return a dictionary containing the 'logits' key (predictions before softmax)
12 | Returns:
13 | - logits_vq: the original predictions of the model, i.e., NIE
14 | - logits_q: the predictions from the question-only branch
15 | - logits_v: the predictions from the vision-only branch
16 | - logits_all: the predictions from the ensemble model
17 | - logits_cfvqa: the predictions based on CF-VQA, i.e., TIE
18 |     => Use `logits_all`, `logits_q` and `logits_v` for the loss; `logits_stu` holds the student predictions
19 | """
20 | def __init__(self, model, model_teacher, output_size, classif_q, classif_v, fusion_mode, end_classif=True, is_va=True):
21 | super().__init__()
22 | self.net_student = model
23 | self.net = model_teacher
24 | self.end_classif = end_classif
25 |
26 | assert fusion_mode in ['rubi', 'hm', 'sum'], "Fusion mode should be rubi/hm/sum."
27 | self.fusion_mode = fusion_mode
28 | self.is_va = is_va and (not fusion_mode=='rubi') # RUBi does not consider V->A
29 |
30 | # Q->A branch
31 | self.q_1 = MLP(**classif_q)
32 | if self.end_classif: # default: True (following RUBi)
33 | self.q_2 = nn.Linear(output_size, output_size)
34 |
35 | # V->A branch
36 | if self.is_va: # default: True (containing V->A)
37 | self.v_1 = MLP(**classif_v)
38 | if self.end_classif: # default: True (following RUBi)
39 | self.v_2 = nn.Linear(output_size, output_size)
40 |
41 | self.constant = nn.Parameter(torch.tensor(0.0))
42 | self.constant.requires_grad = True
43 |
44 | self.net.eval()
45 | self.q_1.eval()
46 | if self.end_classif:
47 | self.q_2.eval()
48 | if self.is_va:
49 | self.v_1.eval()
50 | if self.end_classif:
51 | self.v_2.eval()
52 |
53 | def forward(self, batch):
54 | out = {}
55 | # model prediction
56 | net_out = self.net(batch)
57 | logits = net_out['logits']
58 |
59 | # Q->A branch
60 | q_embedding = net_out['q_emb'] # N * q_emb
61 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate
62 | q_pred = self.q_1(q_embedding)
63 |
64 | # V->A branch
65 | if self.is_va:
66 | v_embedding = net_out['v_emb'] # N * v_emb
67 | v_embedding = grad_mul_const(v_embedding, 0.0) # don't backpropagate
68 | v_pred = self.v_1(v_embedding)
69 | else:
70 | v_pred = None
71 |
72 | # both q, k and v are the facts
73 | z_qkv = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=True, v_fact=True) # te
74 | # q is the fact while k and v are the counterfactuals
75 |         z_q = self.fusion(logits, q_pred, v_pred, q_fact=True, k_fact=False, v_fact=False) # nde
76 |
77 | logits_cfvqa = z_qkv - z_q
78 |
79 | if self.end_classif:
80 | q_out = self.q_2(q_pred)
81 | if self.is_va:
82 | v_out = self.v_2(v_pred)
83 | else:
84 | q_out = q_pred
85 | if self.is_va:
86 | v_out = v_pred
87 |
88 | out['logits_all'] = z_qkv # for optimization
89 | out['logits_vq'] = logits # predictions of the original VQ branch, i.e., NIE
90 | out['logits_cfvqa'] = logits_cfvqa # predictions of CFVQA, i.e., TIE
91 | out['logits_q'] = q_out # for optimization
92 | if self.is_va:
93 | out['logits_v'] = v_out # for optimization
94 |
95 | # student model
96 | logits_stu = self.net_student(batch)
97 | out['logits_stu'] = logits_stu['logits']
98 |
99 | return out
100 |
101 | def process_answers(self, out, key=''):
102 | out = self.net.process_answers(out, key='_all')
103 | out = self.net.process_answers(out, key='_vq')
104 | out = self.net.process_answers(out, key='_cfvqa')
105 | out = self.net.process_answers(out, key='_q')
106 | if self.is_va:
107 | out = self.net.process_answers(out, key='_v')
108 |
109 | # student model
110 | out = self.net.process_answers(out, key='_stu')
111 |
112 | return out
113 |
114 | def fusion(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False):
115 |
116 | z_k, z_q, z_v = self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact)
117 |
118 | if self.fusion_mode == 'rubi':
119 | z = z_k * torch.sigmoid(z_q)
120 |
121 | elif self.fusion_mode == 'hm':
122 | if self.is_va:
123 | z = z_k * z_q * z_v
124 | else:
125 | z = z_k * z_q
126 | z = torch.log(z + eps) - torch.log1p(z)
127 |
128 | elif self.fusion_mode == 'sum':
129 | if self.is_va:
130 | z = z_k + z_q + z_v
131 | else:
132 | z = z_k + z_q
133 | z = torch.log(torch.sigmoid(z) + eps)
134 |
135 | return z
136 |
137 | def transform(self, z_k, z_q, z_v, q_fact=False, k_fact=False, v_fact=False):
138 |
139 | if not k_fact:
140 | z_k = self.constant * torch.ones_like(z_k).cuda()
141 |
142 | if not q_fact:
143 | z_q = self.constant * torch.ones_like(z_q).cuda()
144 |
145 | if self.is_va:
146 | if not v_fact:
147 | z_v = self.constant * torch.ones_like(z_v).cuda()
148 |
149 | if self.fusion_mode == 'hm':
150 | z_k = torch.sigmoid(z_k)
151 | z_q = torch.sigmoid(z_q)
152 | if self.is_va:
153 | z_v = torch.sigmoid(z_v)
154 |
155 | return z_k, z_q, z_v
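Note that the teacher modules above are only switched to `eval()` mode; nothing sets `requires_grad=False`. Gradient isolation comes from the criterion, which detaches every teacher-derived term, so only `logits_stu` carries gradient. A tiny sketch of that pattern, with two hypothetical linear layers standing in for teacher and student:

```python
import torch
import torch.nn as nn

teacher = nn.Linear(4, 3)   # hypothetical stand-in for the CF-VQA teacher
student = nn.Linear(4, 3)   # hypothetical stand-in for net_student
teacher.eval()              # affects dropout/batchnorm only, not gradients

x = torch.randn(2, 4)
p_t = torch.softmax(teacher(x), -1).detach()        # detached, as in the criterion
kd = -(p_t * torch.log_softmax(student(x), -1)).sum(1).mean()
kd.backward()

print(teacher.weight.grad)                # None: nothing flows to the teacher
print(student.weight.grad is not None)    # True
```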
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/factory.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import torch
4 | import torch.nn as nn
5 | import os
6 | import json
7 | from bootstrap.lib.options import Options
8 | from bootstrap.models.networks.data_parallel import DataParallel
9 | from block.models.networks.vqa_net import VQANet as AttentionNet
10 | from bootstrap.lib.logger import Logger
11 |
12 | from .rubi import RUBiNet
13 | from .cfvqa import CFVQA
14 | from .cfvqaintrod import CFVQAIntroD
15 | from .rubiintrod import RUBiIntroD
16 |
17 | def factory(engine):
18 | mode = list(engine.dataset.keys())[0]
19 | dataset = engine.dataset[mode]
20 | opt = Options()['model.network']
21 |
22 | if opt['base'] == 'smrl':
23 | from .smrl_net import SMRLNet as BaselineNet
24 | elif opt['base'] == 'updn':
25 | from .updn_net import UpDnNet as BaselineNet
26 | elif opt['base'] == 'san':
27 | from .san_net import SANNet as BaselineNet
28 | else:
29 | raise ValueError(opt['base'])
30 |
31 | orig_net = BaselineNet(
32 | txt_enc=opt['txt_enc'],
33 | self_q_att=opt['self_q_att'],
34 | agg=opt['agg'],
35 | classif=opt['classif'],
36 | wid_to_word=dataset.wid_to_word,
37 | word_to_wid=dataset.word_to_wid,
38 | aid_to_ans=dataset.aid_to_ans,
39 | ans_to_aid=dataset.ans_to_aid,
40 | fusion=opt['fusion'],
41 | residual=opt['residual'],
42 | q_single=opt['q_single'],
43 | )
44 |
45 | if opt['name'] == 'baseline':
46 | net = orig_net
47 |
48 | elif opt['name'] == 'rubi':
49 | net = RUBiNet(
50 | model=orig_net,
51 | output_size=len(dataset.aid_to_ans),
52 | classif=opt['rubi_params']['mlp_q']
53 | )
54 |
55 | elif opt['name'] == 'cfvqa':
56 | net = CFVQA(
57 | model=orig_net,
58 | output_size=len(dataset.aid_to_ans),
59 | classif_q=opt['cfvqa_params']['mlp_q'],
60 | classif_v=opt['cfvqa_params']['mlp_v'],
61 | fusion_mode=opt['fusion_mode'],
62 | is_va=True
63 | )
64 |
65 | elif opt['name'] == 'cfvqasimple':
66 | net = CFVQA(
67 | model=orig_net,
68 | output_size=len(dataset.aid_to_ans),
69 | classif_q=opt['cfvqa_params']['mlp_q'],
70 | classif_v=None,
71 | fusion_mode=opt['fusion_mode'],
72 | is_va=False
73 | )
74 |
75 | elif opt['name'] == 'cfvqaintrod':
76 | orig_net_teacher = BaselineNet(
77 | txt_enc=opt['txt_enc'],
78 | self_q_att=opt['self_q_att'],
79 | agg=opt['agg'],
80 | classif=opt['classif'],
81 | wid_to_word=dataset.wid_to_word,
82 | word_to_wid=dataset.word_to_wid,
83 | aid_to_ans=dataset.aid_to_ans,
84 | ans_to_aid=dataset.ans_to_aid,
85 | fusion=opt['fusion'],
86 | residual=opt['residual'],
87 | q_single=opt['q_single'],
88 | )
89 | net = CFVQAIntroD(
90 | model=orig_net,
91 | model_teacher=orig_net_teacher,
92 | output_size=len(dataset.aid_to_ans),
93 | classif_q=opt['cfvqa_params']['mlp_q'],
94 | classif_v=opt['cfvqa_params']['mlp_v'],
95 | fusion_mode=opt['fusion_mode']
96 | )
97 |
98 | elif opt['name'] == 'cfvqasimpleintrod':
99 | orig_net_teacher = BaselineNet(
100 | txt_enc=opt['txt_enc'],
101 | self_q_att=opt['self_q_att'],
102 | agg=opt['agg'],
103 | classif=opt['classif'],
104 | wid_to_word=dataset.wid_to_word,
105 | word_to_wid=dataset.word_to_wid,
106 | aid_to_ans=dataset.aid_to_ans,
107 | ans_to_aid=dataset.ans_to_aid,
108 | fusion=opt['fusion'],
109 | residual=opt['residual'],
110 | q_single=opt['q_single'],
111 | )
112 | net = CFVQAIntroD(
113 | model=orig_net,
114 | model_teacher=orig_net_teacher,
115 | output_size=len(dataset.aid_to_ans),
116 | classif_q=opt['cfvqa_params']['mlp_q'],
117 | classif_v=None,
118 | fusion_mode=opt['fusion_mode'],
119 | is_va=False
120 | )
121 |
122 | elif opt['name'] == 'rubiintrod':
123 | orig_net_teacher = BaselineNet(
124 | txt_enc=opt['txt_enc'],
125 | self_q_att=opt['self_q_att'],
126 | agg=opt['agg'],
127 | classif=opt['classif'],
128 | wid_to_word=dataset.wid_to_word,
129 | word_to_wid=dataset.word_to_wid,
130 | aid_to_ans=dataset.aid_to_ans,
131 | ans_to_aid=dataset.ans_to_aid,
132 | fusion=opt['fusion'],
133 | residual=opt['residual'],
134 | q_single=opt['q_single'],
135 | )
136 | net = RUBiIntroD(
137 | model=orig_net,
138 | model_teacher=orig_net_teacher,
139 | output_size=len(dataset.aid_to_ans),
140 | classif=opt['rubi_params']['mlp_q']
141 | )
142 |
143 | else:
144 | raise ValueError(opt['name'])
145 |
146 | if Options()['misc.cuda'] and torch.cuda.device_count() > 1:
147 | net = DataParallel(net)
148 |
149 | return net
150 |
151 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/rubi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from block.models.networks.mlp import MLP
4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask,
5 |
6 |
7 | class RUBiNet(nn.Module):
8 | """
9 | Wraps another model
10 |     The original model must return a dictionary containing the 'logits' key (predictions before softmax)
11 |     Returns:
12 |     - logits: the original predictions of the model
13 |     - logits_q: the predictions from the question-only branch
14 |     - logits_all: the predictions of the model, rescaled by the question-only mask
15 |     => Use `logits_all` and `logits_q` for the loss
16 | """
17 | def __init__(self, model, output_size, classif, end_classif=True):
18 | super().__init__()
19 | self.net = model
20 | self.c_1 = MLP(**classif)
21 | self.end_classif = end_classif
22 | if self.end_classif:
23 | self.c_2 = nn.Linear(output_size, output_size)
24 |
25 | def forward(self, batch):
26 | out = {}
27 | # model prediction
28 | net_out = self.net(batch)
29 | logits = net_out['logits']
30 |
31 | q_embedding = net_out['q_emb'] # N * q_emb
32 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate through question encoder
33 | q_pred = self.c_1(q_embedding)
34 | fusion_pred = logits * torch.sigmoid(q_pred)
35 |
36 | if self.end_classif:
37 | q_out = self.c_2(q_pred)
38 | else:
39 | q_out = q_pred
40 |
41 | out['logits'] = net_out['logits']
42 | out['logits_all'] = fusion_pred
43 | out['logits_q'] = q_out
44 | return out
45 |
46 | def process_answers(self, out, key=''):
47 | out = self.net.process_answers(out)
48 | out = self.net.process_answers(out, key='_all')
49 | out = self.net.process_answers(out, key='_q')
50 | return out
51 |
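The forward pass above combines two mechanics: `grad_mul_const(q_embedding, 0.0)` lets the question-only branch read the shared question embedding without back-propagating into the encoder, and the element-wise `logits * torch.sigmoid(q_pred)` rescales each answer logit by the question-only confidence. A toy sketch of the masking step (random tensors, invented sizes):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)   # base model predictions
q_pred = torch.randn(2, 5)                       # question-only predictions

mask = torch.sigmoid(q_pred)       # in (0, 1): per-answer question-only confidence
fusion_pred = logits * mask        # the 'logits_all' used by the RUBi loss
fusion_pred.sum().backward()       # gradient to each logit is scaled by its mask
print(logits.grad)                 # equals mask: question-biased answers dominate
```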
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/rubiintrod.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from block.models.networks.mlp import MLP
4 | from .utils import grad_mul_const # mask_softmax, grad_reverse, grad_reverse_mask,
5 |
6 |
7 | class RUBiIntroD(nn.Module):
8 | """
9 | Wraps another model
10 |     The original model must return a dictionary containing the 'logits' key (predictions before softmax)
11 |     Returns:
12 |     - logits: the original predictions of the teacher model
13 |     - logits_q: the predictions from the question-only branch
14 |     - logits_all: the teacher predictions rescaled by the question-only mask
15 |     => The IntroD criterion uses `logits`, `logits_all` and the student's `logits_stu`
16 | """
17 | def __init__(self, model, model_teacher, output_size, classif, end_classif=True):
18 | super().__init__()
19 | self.net_student = model
20 | self.net = model_teacher
21 | self.c_1 = MLP(**classif)
22 | self.end_classif = end_classif
23 | if self.end_classif:
24 | self.c_2 = nn.Linear(output_size, output_size)
25 |
26 | self.net.eval()
27 | self.c_1.eval()
28 |         if self.end_classif: self.c_2.eval()  # c_2 only exists when end_classif is True
29 |
30 | def forward(self, batch):
31 | out = {}
32 | # model prediction
33 | net_out = self.net(batch)
34 | logits = net_out['logits']
35 |
36 | q_embedding = net_out['q_emb'] # N * q_emb
37 | q_embedding = grad_mul_const(q_embedding, 0.0) # don't backpropagate through question encoder
38 | q_pred = self.c_1(q_embedding)
39 | fusion_pred = logits * torch.sigmoid(q_pred)
40 |
41 | if self.end_classif:
42 | q_out = self.c_2(q_pred)
43 | else:
44 | q_out = q_pred
45 |
46 | out['logits'] = net_out['logits']
47 | out['logits_all'] = fusion_pred
48 | out['logits_q'] = q_out
49 |
50 | # student model
51 | logits_stu = self.net_student(batch)
52 | out['logits_stu'] = logits_stu['logits']
53 |
54 | return out
55 |
56 | def process_answers(self, out, key=''):
57 | out = self.net.process_answers(out)
58 | out = self.net.process_answers(out, key='_all')
59 | out = self.net.process_answers(out, key='_q')
60 | out = self.net.process_answers(out, key='_stu')
61 | return out
62 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/smrl_net.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import itertools
3 | import os
4 | import numpy as np
5 | import scipy
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from bootstrap.lib.options import Options
10 | from bootstrap.lib.logger import Logger
11 | import block
12 | from block.models.networks.vqa_net import factory_text_enc
13 | from block.models.networks.mlp import MLP
14 |
15 | from .utils import mask_softmax
16 |
17 | class SMRLNet(nn.Module):
18 |
19 | def __init__(self,
20 | txt_enc={},
21 | self_q_att=False,
22 | agg={},
23 | classif={},
24 | wid_to_word={},
25 | word_to_wid={},
26 | aid_to_ans=[],
27 | ans_to_aid={},
28 | fusion={},
29 | residual=False,
30 | q_single=False,
31 | ):
32 | super().__init__()
33 | self.self_q_att = self_q_att
34 | self.agg = agg
35 | assert self.agg['type'] in ['max', 'mean']
36 | self.classif = classif
37 | self.wid_to_word = wid_to_word
38 | self.word_to_wid = word_to_wid
39 | self.aid_to_ans = aid_to_ans
40 | self.ans_to_aid = ans_to_aid
41 | self.fusion = fusion
42 | self.residual = residual
43 |
44 | # Modules
45 | self.txt_enc = self.get_text_enc(self.wid_to_word, txt_enc)
46 | if self.self_q_att:
47 | self.q_att_linear0 = nn.Linear(2400, 512)
48 | self.q_att_linear1 = nn.Linear(512, 2)
49 |
50 | if q_single:
51 | self.txt_enc_single = self.get_text_enc(self.wid_to_word, txt_enc)
52 | if self.self_q_att:
53 | self.q_att_linear0_single = nn.Linear(2400, 512)
54 | self.q_att_linear1_single = nn.Linear(512, 2)
55 | else:
56 | self.txt_enc_single = None
57 |
58 | self.fusion_module = block.factory_fusion(self.fusion)
59 |
60 | if self.classif['mlp']['dimensions'][-1] != len(self.aid_to_ans):
61 |             Logger()(f"Warning, the classif_mm output dimension ({self.classif['mlp']['dimensions'][-1]}) "
62 |                 f"doesn't match the number of answers ({len(self.aid_to_ans)}). Modifying the output dimension.")
63 | self.classif['mlp']['dimensions'][-1] = len(self.aid_to_ans)
64 |
65 | self.classif_module = MLP(**self.classif['mlp'])
66 |
67 | Logger().log_value('nparams',
68 | sum(p.numel() for p in self.parameters() if p.requires_grad),
69 | should_print=True)
70 |
71 | Logger().log_value('nparams_txt_enc',
72 | self.get_nparams_txt_enc(),
73 | should_print=True)
74 |
75 |
76 | def get_text_enc(self, vocab_words, options):
77 | """
78 | returns the text encoding network.
79 | """
80 | return factory_text_enc(self.wid_to_word, options)
81 |
82 | def get_nparams_txt_enc(self):
83 | params = [p.numel() for p in self.txt_enc.parameters() if p.requires_grad]
84 | if self.self_q_att:
85 | params += [p.numel() for p in self.q_att_linear0.parameters() if p.requires_grad]
86 | params += [p.numel() for p in self.q_att_linear1.parameters() if p.requires_grad]
87 | return sum(params)
88 |
89 | def process_fusion(self, q, mm):
90 | bsize = mm.shape[0]
91 | n_regions = mm.shape[1]
92 |
93 | mm = mm.contiguous().view(bsize*n_regions, -1)
94 | mm = self.fusion_module([q, mm])
95 | mm = mm.view(bsize, n_regions, -1)
96 | return mm
97 |
98 | def forward(self, batch):
99 | v = batch['visual']
100 | q = batch['question']
101 | l = batch['lengths'].data
102 | c = batch['norm_coord']
103 | nb_regions = batch.get('nb_regions')
104 | bsize = v.shape[0]
105 | n_regions = v.shape[1]
106 |
107 | out = {}
108 |
109 | q = self.process_question(q, l,)
110 | out['q_emb'] = q
111 | q_expand = q[:,None,:].expand(bsize, n_regions, q.shape[1])
112 | q_expand = q_expand.contiguous().view(bsize*n_regions, -1)
113 |
114 | # single txt encoder
115 | if self.txt_enc_single is not None:
116 |             out['q_emb'] = self.process_question(batch['question'], l, self.txt_enc_single, self.q_att_linear0_single, self.q_att_linear1_single)  # raw tokens: q was overwritten above
117 |
118 | mm = self.process_fusion(q_expand, v,)
119 |
120 | if self.residual:
121 | mm = v + mm
122 |
123 | if self.agg['type'] == 'max':
124 | mm, mm_argmax = torch.max(mm, 1)
125 |         elif self.agg['type'] == 'mean':
126 |             mm, mm_argmax = mm.mean(1), None  # no argmax for mean pooling
127 |
128 | out['v_emb'] = v.mean(1)
129 | out['mm'] = mm
130 | out['mm_argmax'] = mm_argmax
131 |
132 | logits = self.classif_module(mm)
133 | out['logits'] = logits
134 | return out
135 |
136 | def process_question(self, q, l, txt_enc=None, q_att_linear0=None, q_att_linear1=None):
137 | if txt_enc is None:
138 | txt_enc = self.txt_enc
139 | if q_att_linear0 is None:
140 | q_att_linear0 = self.q_att_linear0
141 | if q_att_linear1 is None:
142 | q_att_linear1 = self.q_att_linear1
143 | q_emb = txt_enc.embedding(q)
144 |
145 | q, _ = txt_enc.rnn(q_emb)
146 |
147 | if self.self_q_att:
148 | q_att = q_att_linear0(q)
149 | q_att = F.relu(q_att)
150 | q_att = q_att_linear1(q_att)
151 | q_att = mask_softmax(q_att, l)
152 | #self.q_att_coeffs = q_att
153 | if q_att.size(2) > 1:
154 | q_atts = torch.unbind(q_att, dim=2)
155 | q_outs = []
156 | for q_att in q_atts:
157 | q_att = q_att.unsqueeze(2)
158 | q_att = q_att.expand_as(q)
159 | q_out = q_att*q
160 | q_out = q_out.sum(1)
161 | q_outs.append(q_out)
162 | q = torch.cat(q_outs, dim=1)
163 | else:
164 | q_att = q_att.expand_as(q)
165 | q = q_att * q
166 | q = q.sum(1)
167 | else:
168 | # l contains the number of words for each question
169 | # in case of multi-gpus it must be a Tensor
170 | # thus we convert it into a list during the forward pass
171 | l = list(l.data[:,0])
172 | q = txt_enc._select_last(q, l)
173 |
174 | return q
175 |
176 | def process_answers(self, out, key=''):
177 | batch_size = out[f'logits{key}'].shape[0]
178 | _, pred = out[f'logits{key}'].data.max(1)
179 | pred.squeeze_()
180 | if batch_size != 1:
181 | out[f'answers{key}'] = [self.aid_to_ans[pred[i].item()] for i in range(batch_size)]
182 | out[f'answer_ids{key}'] = [pred[i].item() for i in range(batch_size)]
183 | else:
184 | out[f'answers{key}'] = [self.aid_to_ans[pred.item()]]
185 | out[f'answer_ids{key}'] = [pred.item()]
186 | return out
187 |
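The self-attention branch of `process_question` above is why the option files use `input_dim: 4800` for the question MLPs: `q_att_linear1` yields 2 attention glimpses over the 2400-d skip-thoughts states, and the two attended vectors are concatenated. A shape-only sketch under those assumptions, with random states standing in for the GRU output and a plain softmax in place of `mask_softmax`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, n_words, hid = 2, 7, 2400                 # toy batch, real hidden size
q = torch.randn(batch, n_words, hid)             # stands in for txt_enc.rnn output

q_att = nn.Linear(512, 2)(F.relu(nn.Linear(hid, 512)(q)))   # [batch, n_words, 2]
q_att = F.softmax(q_att, dim=1)                  # mask_softmax, minus the padding mask

glimpses = [(a.unsqueeze(2) * q).sum(1)          # one attended vector per glimpse
            for a in torch.unbind(q_att, dim=2)]
q_emb = torch.cat(glimpses, dim=1)
print(q_emb.shape)                               # torch.Size([2, 4800]) = 2 * 2400
```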
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/updn_net.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import itertools
3 | import os
4 | import numpy as np
5 | import scipy
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from bootstrap.lib.options import Options
10 | from bootstrap.lib.logger import Logger
11 | import block
12 | from block.models.networks.vqa_net import factory_text_enc
13 | from block.models.networks.mlp import MLP
14 |
15 | from .utils import mask_softmax
16 |
17 | from torch.nn.utils.weight_norm import weight_norm
18 |
19 | class UpDnNet(nn.Module):
20 |
21 | def __init__(self,
22 | txt_enc={},
23 | self_q_att=False,
24 | agg={},
25 | classif={},
26 | wid_to_word={},
27 | word_to_wid={},
28 | aid_to_ans=[],
29 | ans_to_aid={},
30 | fusion={},
31 | residual=False,
32 | q_single=False,
33 | ):
34 | super().__init__()
35 | self.self_q_att = self_q_att
36 | self.agg = agg
37 | assert self.agg['type'] in ['max', 'mean']
38 | self.classif = classif
39 | self.wid_to_word = wid_to_word
40 | self.word_to_wid = word_to_wid
41 | self.aid_to_ans = aid_to_ans
42 | self.ans_to_aid = ans_to_aid
43 | self.fusion = fusion
44 | self.residual = residual
45 |
46 | # Modules
47 | self.txt_enc = self.get_text_enc(self.wid_to_word, txt_enc)
48 | if self.self_q_att:
49 | self.q_att_linear0 = nn.Linear(2400, 512)
50 | self.q_att_linear1 = nn.Linear(512, 2)
51 |
52 | if q_single:
53 | self.txt_enc_single = self.get_text_enc(self.wid_to_word, txt_enc)
54 | if self.self_q_att:
55 | self.q_att_linear0_single = nn.Linear(2400, 512)
56 | self.q_att_linear1_single = nn.Linear(512, 2)
57 | else:
58 | self.txt_enc_single = None
59 |
60 | if self.classif['mlp']['dimensions'][-1] != len(self.aid_to_ans):
61 |             Logger()(f"Warning, the classif_mm output dimension ({self.classif['mlp']['dimensions'][-1]}) "
62 |                 f"doesn't match the number of answers ({len(self.aid_to_ans)}). Modifying the output dimension.")
63 | self.classif['mlp']['dimensions'][-1] = len(self.aid_to_ans)
64 |
65 | self.classif_module = MLP(**self.classif['mlp'])
66 |
67 | # UpDn
68 | q_dim = self.fusion['input_dims'][0]
69 | v_dim = self.fusion['input_dims'][1]
70 | output_dim = self.fusion['output_dim']
71 | self.v_att = Attention(v_dim, q_dim, output_dim)
72 | self.q_net = FCNet([q_dim, output_dim])
73 | self.v_net = FCNet([v_dim, output_dim])
74 |
75 | Logger().log_value('nparams',
76 | sum(p.numel() for p in self.parameters() if p.requires_grad),
77 | should_print=True)
78 |
79 | Logger().log_value('nparams_txt_enc',
80 | self.get_nparams_txt_enc(),
81 | should_print=True)
82 |
83 |
84 | def get_text_enc(self, vocab_words, options):
85 | """
86 | returns the text encoding network.
87 | """
88 | return factory_text_enc(self.wid_to_word, options)
89 |
90 | def get_nparams_txt_enc(self):
91 | params = [p.numel() for p in self.txt_enc.parameters() if p.requires_grad]
92 | if self.self_q_att:
93 | params += [p.numel() for p in self.q_att_linear0.parameters() if p.requires_grad]
94 | params += [p.numel() for p in self.q_att_linear1.parameters() if p.requires_grad]
95 | return sum(params)
96 |
97 | def forward(self, batch):
98 | v = batch['visual']
99 | q = batch['question']
100 | l = batch['lengths'].data
101 | c = batch['norm_coord']
102 | nb_regions = batch.get('nb_regions')
103 |
104 | out = {}
105 |
106 | q_emb = self.process_question(q, l,)
107 | out['v_emb'] = v.mean(1)
108 | out['q_emb'] = q_emb
109 |
110 | # single txt encoder
111 | if self.txt_enc_single is not None:
112 | out['q_emb'] = self.process_question(q, l, self.txt_enc_single, self.q_att_linear0_single, self.q_att_linear1_single)
113 |
114 | # New
115 | att = self.v_att(v, q_emb)
116 | v_emb = (att * v).sum(1)
117 | q_repr = self.q_net(q_emb)
118 | v_repr = self.v_net(v_emb)
119 | joint_repr = q_repr * v_repr
120 |
121 | logits = self.classif_module(joint_repr)
122 | out['logits'] = logits
123 |
124 | return out
125 |
126 | def process_question(self, q, l, txt_enc=None, q_att_linear0=None, q_att_linear1=None):
127 | if txt_enc is None:
128 | txt_enc = self.txt_enc
129 | if q_att_linear0 is None:
130 | q_att_linear0 = self.q_att_linear0
131 | if q_att_linear1 is None:
132 | q_att_linear1 = self.q_att_linear1
133 | q_emb = txt_enc.embedding(q)
134 |
135 | q, _ = txt_enc.rnn(q_emb)
136 |
137 | if self.self_q_att:
138 | q_att = q_att_linear0(q)
139 | q_att = F.relu(q_att)
140 | q_att = q_att_linear1(q_att)
141 | q_att = mask_softmax(q_att, l)
142 | #self.q_att_coeffs = q_att
143 | if q_att.size(2) > 1:
144 | q_atts = torch.unbind(q_att, dim=2)
145 | q_outs = []
146 | for q_att in q_atts:
147 | q_att = q_att.unsqueeze(2)
148 | q_att = q_att.expand_as(q)
149 | q_out = q_att*q
150 | q_out = q_out.sum(1)
151 | q_outs.append(q_out)
152 | q = torch.cat(q_outs, dim=1)
153 | else:
154 | q_att = q_att.expand_as(q)
155 | q = q_att * q
156 | q = q.sum(1)
157 | else:
158 | # l contains the number of words for each question
159 | # in case of multi-gpus it must be a Tensor
160 | # thus we convert it into a list during the forward pass
161 | l = list(l.data[:,0])
162 | q = txt_enc._select_last(q, l)
163 |
164 | return q
165 |
166 | def process_answers(self, out, key=''):
167 | batch_size = out[f'logits{key}'].shape[0]
168 | _, pred = out[f'logits{key}'].data.max(1)
169 | pred.squeeze_()
170 | if batch_size != 1:
171 | out[f'answers{key}'] = [self.aid_to_ans[pred[i].item()] for i in range(batch_size)]
172 | out[f'answer_ids{key}'] = [pred[i].item() for i in range(batch_size)]
173 | else:
174 | out[f'answers{key}'] = [self.aid_to_ans[pred.item()]]
175 | out[f'answer_ids{key}'] = [pred.item()]
176 | return out
177 |
178 | class Attention(nn.Module):
179 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2):
180 | super(Attention, self).__init__()
181 |
182 | self.v_proj = FCNet([v_dim, num_hid])
183 | self.q_proj = FCNet([q_dim, num_hid])
184 | self.dropout = nn.Dropout(dropout)
185 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)
186 |
187 | def forward(self, v, q):
188 | """
189 | v: [batch, k, vdim]
190 | q: [batch, qdim]
191 | """
192 | logits = self.logits(v, q)
193 | w = nn.functional.softmax(logits, 1)
194 | return w
195 |
196 | def logits(self, v, q):
197 | batch, k, _ = v.size()
198 | v_proj = self.v_proj(v) # [batch, k, qdim]
199 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1)
200 | joint_repr = v_proj * q_proj
201 | joint_repr = self.dropout(joint_repr)
202 | logits = self.linear(joint_repr)
203 | return logits
204 |
205 | class FCNet(nn.Module):
206 |     """Simple class for a non-linear fully connected network
207 | """
208 | def __init__(self, dims):
209 | super(FCNet, self).__init__()
210 |
211 | layers = []
212 | for i in range(len(dims)-2):
213 | in_dim = dims[i]
214 | out_dim = dims[i+1]
215 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
216 | layers.append(nn.ReLU())
217 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
218 | layers.append(nn.ReLU())
219 |
220 | self.main = nn.Sequential(*layers)
221 |
222 | def forward(self, x):
223 | return self.main(x)
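For orientation, the attention above produces one scalar weight per region (softmax over the k regions), which pools the visual features before the element-wise joint representation. A shape-check sketch that reuses the `Attention` and `FCNet` classes defined in this file (toy batch, dimensions matching the option files):

```python
import torch

batch, k, vdim, qdim, num_hid = 2, 36, 2048, 4800, 2048
v = torch.randn(batch, k, vdim)                  # region features
q = torch.randn(batch, qdim)                     # pooled question embedding

att = Attention(vdim, qdim, num_hid)(v, q)       # [batch, k, 1], sums to 1 over k
v_emb = (att * v).sum(1)                         # [batch, vdim] attended pooling
joint = FCNet([qdim, num_hid])(q) * FCNet([vdim, num_hid])(v_emb)
print(att.shape, v_emb.shape, joint.shape)       # [2, 36, 1] [2, 2048] [2, 2048]
```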
--------------------------------------------------------------------------------
/cfvqa/cfvqa/models/networks/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def mask_softmax(x, lengths):#, dim=1)
4 | mask = torch.zeros_like(x).to(device=x.device, non_blocking=True)
5 | t_lengths = lengths[:,:,None].expand_as(mask)
6 | arange_id = torch.arange(mask.size(1)).to(device=x.device, non_blocking=True)
7 | arange_id = arange_id[None,:,None].expand_as(mask)
8 |
9 |     mask[arange_id<t_lengths] = 1
10 |     # https://stackoverflow.com/questions/42599498/numercially-stable-softmax
11 |     # https://discuss.pytorch.org/t/apply-mask-softmax/14212/7
12 |     x2 = torch.exp(x - torch.max(x))
13 |     x3 = x2 * mask
14 |     epsilon = 1e-5
15 |     x3_sum = torch.sum(x3, dim=1, keepdim=True) + epsilon
16 |     x4 = x3 / x3_sum.expand_as(x3)
17 |     return x4
18 | 
19 | 
20 | class GradMulConst(torch.autograd.Function):
21 |     """
22 |     Identity in the forward pass; multiplies the gradient by `const` in the
23 |     backward pass (const=0.0 blocks backpropagation through its input).
24 |     """
25 |     @staticmethod
26 |     def forward(ctx, x, const):
27 |         ctx.const = const
28 |         return x.view_as(x)
29 | 
30 |     @staticmethod
31 |     def backward(ctx, grad_output):
32 |         return grad_output * ctx.const, None
33 | 
34 | def grad_mul_const(x, const):
35 |     return GradMulConst.apply(x, const)
--------------------------------------------------------------------------------
/cfvqa/cfvqa/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/cfvqa/cfvqa/optimizers/__init__.py
--------------------------------------------------------------------------------
/cfvqa/cfvqa/optimizers/factory.py:
--------------------------------------------------------------------------------
26 |                 nn.init.xavier_uniform_(p.data)
27 |             else:
28 |                 raise ValueError(p.dim())
29 | 
30 |     return optimizer
31 | 
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_baseline.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqa2/smrl_baseline
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: false
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: baseline
25 | cfvqa_params:
26 | mlp_q:
27 | input_dim: 4800
28 | dimensions: [1024,1024,3000]
29 | mlp_v:
30 | input_dim: 2048
31 | dimensions: [1024,1024,3000]
32 | txt_enc:
33 | name: skipthoughts
34 | type: BayesianUniSkip
35 | dropout: 0.25
36 | fixed_emb: False
37 | dir_st: data/skip-thoughts
38 | self_q_att: True
39 | residual: False
40 | q_single: False
41 | fusion:
42 | type: block
43 | input_dims: [4800, 2048]
44 | output_dim: 2048
45 | mm_dim: 1000
46 | chunks: 20
47 | rank: 15
48 | dropout_input: 0.
49 | dropout_pre_lin: 0.
50 | agg:
51 | type: max
52 | classif:
53 | mlp:
54 | input_dim: 2048
55 | dimensions: [1024,1024,3000]
56 | criterion:
57 | import: cfvqa.models.criterions.factory
58 | name: vqa_cross_entropy
59 | question_loss_weight: 1.0
60 | vision_loss_weight: 1.0
61 | loss_temp: 100.0
62 | metric:
63 | import: cfvqa.models.metrics.factory
64 | name: vqa_accuracies
65 | optimizer:
66 | import: cfvqa.optimizers.factory
67 | name: Adam
68 | lr: 0.0003
69 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
70 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
71 | lr_decay_epochs: [14, 24, 2] #range
72 | lr_decay_rate: .25
73 | engine:
74 | name: logger
75 | debug: False
76 | print_freq: 10
77 | nb_epochs: 22
78 | saving_criteria:
79 | - eval_epoch.accuracy_top1:max
80 | misc:
81 | logs_name:
82 | cuda: True
83 | seed: 1337
84 |
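The factories earlier in this dump read these YAML trees through bootstrap's `Options` singleton with dotted keys (e.g. `Options()['model.network.base']`). A self-contained stand-in illustrating the access pattern, using a plain dict instead of the real `Options` class:

```python
# Hypothetical subset of the YAML above; the real values come from bootstrap's
# Options() singleton, which resolves dotted keys against the loaded config.
opt = {
    'model': {'network': {'base': 'smrl', 'name': 'baseline'}},
    'optimizer': {'name': 'Adam', 'lr': 0.0003},
}

def lookup(tree, dotted_key):
    node = tree
    for part in dotted_key.split('.'):  # walk one nesting level per dot
        node = node[part]
    return node

assert lookup(opt, 'model.network.base') == 'smrl'
assert lookup(opt, 'optimizer.lr') == 0.0003
```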
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_cfvqa_sum.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqa2/smrl_cfvqa_sum
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: False
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: cfvqa
25 | fusion_mode: sum
26 | cfvqa_params:
27 | mlp_q:
28 | input_dim: 4800
29 | dimensions: [1024,1024,3000]
30 | mlp_v:
31 | input_dim: 2048
32 | dimensions: [1024,1024,3000]
33 | txt_enc:
34 | name: skipthoughts
35 | type: BayesianUniSkip
36 | dropout: 0.25
37 | fixed_emb: False
38 | dir_st: data/skip-thoughts
39 | self_q_att: True
40 | residual: False
41 | q_single: False
42 | fusion:
43 | type: block
44 | input_dims: [4800, 2048]
45 | output_dim: 2048
46 | mm_dim: 1000
47 | chunks: 20
48 | rank: 15
49 | dropout_input: 0.
50 | dropout_pre_lin: 0.
51 | agg:
52 | type: max
53 | classif:
54 | mlp:
55 | input_dim: 2048
56 | dimensions: [1024,1024,3000]
57 | criterion:
58 | import: cfvqa.models.criterions.factory
59 | name: cfvqa_criterion
60 | question_loss_weight: 1.0
61 | vision_loss_weight: 1.0
62 | metric:
63 | import: cfvqa.models.metrics.factory
64 | name: vqa_cfvqa_metrics
65 | optimizer:
66 | import: cfvqa.optimizers.factory
67 | name: Adam
68 | lr: 0.0003
69 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
70 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
71 | lr_decay_epochs: [14, 24, 2] #range
72 | lr_decay_rate: .25
73 | engine:
74 | name: logger
75 | debug: False
76 | print_freq: 10
77 | nb_epochs: 22
78 | saving_criteria:
79 | - eval_epoch.accuracy_all_top1:max
80 | - eval_epoch.accuracy_vq_top1:max
81 | - eval_epoch.accuracy_cfvqa_top1:max
82 | misc:
83 | logs_name:
84 | cuda: True
85 | seed: 1337
86 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_cfvqaintrod_sum.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqa2/smrl_cfvqaintrod_sum
3 | resume: last # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: false
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: cfvqaintrod
25 | fusion_mode: sum
26 | cfvqa_params:
27 | mlp_q:
28 | input_dim: 4800
29 | dimensions: [1024,1024,3000]
30 | mlp_v:
31 | input_dim: 2048
32 | dimensions: [1024,1024,3000]
33 | txt_enc:
34 | name: skipthoughts
35 | type: BayesianUniSkip
36 | dropout: 0.25
37 | fixed_emb: False
38 | dir_st: data/skip-thoughts
39 | self_q_att: True
40 | residual: False
41 | q_single: False
42 | fusion:
43 | type: block
44 | input_dims: [4800, 2048]
45 | output_dim: 2048
46 | mm_dim: 1000
47 | chunks: 20
48 | rank: 15
49 | dropout_input: 0.
50 | dropout_pre_lin: 0.
51 | agg:
52 | type: max
53 | classif:
54 | mlp:
55 | input_dim: 2048
56 | dimensions: [1024,1024,3000]
57 | criterion:
58 | import: cfvqa.models.criterions.factory
59 | name: cfvqaintrod_criterion
60 | metric:
61 | import: cfvqa.models.metrics.factory
62 | name: vqa_cfvqaintrod_metrics
63 | optimizer:
64 | import: cfvqa.optimizers.factory
65 | name: Adam
66 | lr: 0.0003
67 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
68 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
69 | lr_decay_epochs: [14, 24, 2] #range
70 | lr_decay_rate: .25
71 | engine:
72 | name: logger
73 | debug: False
74 | print_freq: 10
75 | nb_epochs: 22
76 | saving_criteria:
77 | - eval_epoch.accuracy_stu_top1:max
78 | misc:
79 | logs_name:
80 | cuda: True
81 | seed: 1337
82 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_cfvqasimple_rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqa2/smrl_cfvqasimple_rubi
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: False
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: cfvqasimple
25 | fusion_mode: rubi
26 |     is_va: False  # cfvqasimple drops the V->A branch
27 | cfvqa_params:
28 | mlp_q:
29 | input_dim: 4800
30 | dimensions: [1024,1024,3000]
31 | txt_enc:
32 | name: skipthoughts
33 | type: BayesianUniSkip
34 | dropout: 0.25
35 | fixed_emb: False
36 | dir_st: data/skip-thoughts
37 | self_q_att: True
38 | residual: False
39 | q_single: False
40 | fusion:
41 | type: block
42 | input_dims: [4800, 2048]
43 | output_dim: 2048
44 | mm_dim: 1000
45 | chunks: 20
46 | rank: 15
47 | dropout_input: 0.
48 | dropout_pre_lin: 0.
49 | agg:
50 | type: max
51 | classif:
52 | mlp:
53 | input_dim: 2048
54 | dimensions: [1024,1024,3000]
55 | criterion:
56 | import: cfvqa.models.criterions.factory
57 | name: cfvqasimple_criterion
58 | question_loss_weight: 1.0
59 | metric:
60 | import: cfvqa.models.metrics.factory
61 | name: vqa_cfvqasimple_metrics
62 | optimizer:
63 | import: cfvqa.optimizers.factory
64 | name: Adam
65 | lr: 0.0003
66 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
67 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
68 | lr_decay_epochs: [14, 24, 2] #range
69 | lr_decay_rate: .25
70 | engine:
71 | name: logger
72 | debug: False
73 | print_freq: 10
74 | nb_epochs: 22
75 | saving_criteria:
76 | - eval_epoch.accuracy_all_top1:max
77 | - eval_epoch.accuracy_vq_top1:max
78 | - eval_epoch.accuracy_cfvqa_top1:max
79 | misc:
80 | logs_name:
81 | cuda: True
82 | seed: 1337
83 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_cfvqasimpleintrod_rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 |   dir: logs/vqa2/smrl_cfvqasimpleintrod_rubi
3 | resume: last # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: false
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: cfvqasimpleintrod
25 |     fusion_mode: rubi
26 | cfvqa_params:
27 | mlp_q:
28 | input_dim: 4800
29 | dimensions: [1024,1024,3000]
30 | mlp_v:
31 | input_dim: 2048
32 | dimensions: [1024,1024,3000]
33 | txt_enc:
34 | name: skipthoughts
35 | type: BayesianUniSkip
36 | dropout: 0.25
37 | fixed_emb: False
38 | dir_st: data/skip-thoughts
39 | self_q_att: True
40 | residual: False
41 | q_single: False
42 | fusion:
43 | type: block
44 | input_dims: [4800, 2048]
45 | output_dim: 2048
46 | mm_dim: 1000
47 | chunks: 20
48 | rank: 15
49 | dropout_input: 0.
50 | dropout_pre_lin: 0.
51 | agg:
52 | type: max
53 | classif:
54 | mlp:
55 | input_dim: 2048
56 | dimensions: [1024,1024,3000]
57 | criterion:
58 | import: cfvqa.models.criterions.factory
59 | name: cfvqaintrod_criterion
60 | metric:
61 | import: cfvqa.models.metrics.factory
62 | name: vqa_cfvqaintrod_metrics
63 | optimizer:
64 | import: cfvqa.optimizers.factory
65 | name: Adam
66 | lr: 0.0003
67 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
68 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
69 | lr_decay_epochs: [14, 24, 2] #range
70 | lr_decay_rate: .25
71 | engine:
72 | name: logger
73 | debug: False
74 | print_freq: 10
75 | nb_epochs: 22
76 | saving_criteria:
77 | - eval_epoch.accuracy_stu_top1:max
78 | misc:
79 | logs_name:
80 | cuda: True
81 | seed: 1337
82 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqa2/smrl_rubi
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: false
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: rubi
25 | rubi_params:
26 | mlp_q:
27 | input_dim: 4800
28 | dimensions: [1024,1024,3000]
29 | txt_enc:
30 | name: skipthoughts
31 | type: BayesianUniSkip
32 | dropout: 0.25
33 | fixed_emb: False
34 | dir_st: data/skip-thoughts
35 | self_q_att: True
36 | residual: False
37 | q_single: False
38 | fusion:
39 | type: block
40 | input_dims: [4800, 2048]
41 | output_dim: 2048
42 | mm_dim: 1000
43 | chunks: 20
44 | rank: 15
45 | dropout_input: 0.
46 | dropout_pre_lin: 0.
47 | agg:
48 | type: max
49 | classif:
50 | mlp:
51 | input_dim: 2048
52 | dimensions: [1024,1024,3000]
53 | criterion:
54 | import: cfvqa.models.criterions.factory
55 | name: rubi_criterion
56 | question_loss_weight: 1.0
57 | metric:
58 | import: cfvqa.models.metrics.factory
59 | name: vqa_rubi_metrics
60 | optimizer:
61 | import: cfvqa.optimizers.factory
62 | name: Adam
63 | lr: 0.0003
64 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
65 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
66 | lr_decay_epochs: [14, 24, 2] #range
67 | lr_decay_rate: .25
68 | engine:
69 | name: logger
70 | debug: False
71 | print_freq: 10
72 | nb_epochs: 22
73 | saving_criteria:
74 | - eval_epoch.accuracy_top1:max
75 | - eval_epoch.accuracy_all_top1:max
76 | misc:
77 | logs_name:
78 | cuda: True
79 | seed: 1337
80 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqa2/smrl_rubiintrod.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqa2/smrl_rubiintrod
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqa2 # or vqa2vg
7 | dir: data/vqa/vqa2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | vg: false
19 | model:
20 | name: default
21 | network:
22 | import: cfvqa.models.networks.factory
23 | base: smrl
24 | name: rubiintrod
25 | rubi_params:
26 | mlp_q:
27 | input_dim: 4800
28 | dimensions: [1024,1024,3000]
29 | txt_enc:
30 | name: skipthoughts
31 | type: BayesianUniSkip
32 | dropout: 0.25
33 | fixed_emb: False
34 | dir_st: data/skip-thoughts
35 | self_q_att: True
36 | residual: False
37 | q_single: False
38 | fusion:
39 | type: block
40 | input_dims: [4800, 2048]
41 | output_dim: 2048
42 | mm_dim: 1000
43 | chunks: 20
44 | rank: 15
45 | dropout_input: 0.
46 | dropout_pre_lin: 0.
47 | agg:
48 | type: max
49 | classif:
50 | mlp:
51 | input_dim: 2048
52 | dimensions: [1024,1024,3000]
53 | criterion:
54 | import: cfvqa.models.criterions.factory
55 | name: rubiintrod_criterion
56 | metric:
57 | import: cfvqa.models.metrics.factory
58 | name: vqa_rubiintrod_metrics
59 | optimizer:
60 | import: cfvqa.optimizers.factory
61 | name: Adam
62 | lr: 0.0003
63 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
64 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
65 | lr_decay_epochs: [14, 24, 2] #range
66 | lr_decay_rate: .25
67 | engine:
68 | name: logger
69 | debug: False
70 | print_freq: 10
71 | nb_epochs: 22
72 | saving_criteria:
73 | - eval_epoch.accuracy_stu_top1:max
74 | misc:
75 | logs_name:
76 | cuda: True
77 | seed: 1337
78 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_baseline.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_baseline
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: baseline
24 | cfvqa_params:
25 | mlp_q:
26 | input_dim: 4800
27 | dimensions: [1024,1024,3000]
28 | mlp_v:
29 | input_dim: 2048
30 | dimensions: [1024,1024,3000]
31 | txt_enc:
32 | name: skipthoughts
33 | type: BayesianUniSkip
34 | dropout: 0.25
35 | fixed_emb: False
36 | dir_st: data/skip-thoughts
37 | self_q_att: True
38 | residual: False
39 | q_single: False
40 | fusion:
41 | type: block
42 | input_dims: [4800, 2048]
43 | output_dim: 2048
44 | mm_dim: 1000
45 | chunks: 20
46 | rank: 15
47 | dropout_input: 0.
48 | dropout_pre_lin: 0.
49 | agg:
50 | type: max
51 | classif:
52 | mlp:
53 | input_dim: 2048
54 | dimensions: [1024,1024,3000]
55 | criterion:
56 | import: cfvqa.models.criterions.factory
57 | name: vqa_cross_entropy
58 | question_loss_weight: 1.0
59 | vision_loss_weight: 1.0
60 | loss_temp: 100.0
61 | metric:
62 | import: cfvqa.models.metrics.factory
63 | name: vqa_accuracies
64 | optimizer:
65 | import: cfvqa.optimizers.factory
66 | name: Adam
67 | lr: 0.0003
68 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
69 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
70 | lr_decay_epochs: [14, 24, 2] #range
71 | lr_decay_rate: .25
72 | engine:
73 | name: logger
74 | debug: False
75 | print_freq: 10
76 | nb_epochs: 22
77 | saving_criteria:
78 | - eval_epoch.accuracy_top1:max
79 | misc:
80 | logs_name:
81 | cuda: True
82 | seed: 1337
83 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_cfvqa_sum.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_cfvqa_sum
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: cfvqa
24 | fusion_mode: sum
25 | cfvqa_params:
26 | mlp_q:
27 | input_dim: 4800
28 | dimensions: [1024,1024,3000]
29 | mlp_v:
30 | input_dim: 2048
31 | dimensions: [1024,1024,3000]
32 | txt_enc:
33 | name: skipthoughts
34 | type: BayesianUniSkip
35 | dropout: 0.25
36 | fixed_emb: False
37 | dir_st: data/skip-thoughts
38 | self_q_att: True
39 | residual: False
40 | q_single: False
41 | fusion:
42 | type: block
43 | input_dims: [4800, 2048]
44 | output_dim: 2048
45 | mm_dim: 1000
46 | chunks: 20
47 | rank: 15
48 | dropout_input: 0.
49 | dropout_pre_lin: 0.
50 | agg:
51 | type: max
52 | classif:
53 | mlp:
54 | input_dim: 2048
55 | dimensions: [1024,1024,3000]
56 | criterion:
57 | import: cfvqa.models.criterions.factory
58 | name: cfvqa_criterion
59 | question_loss_weight: 1.0
60 | vision_loss_weight: 1.0
61 | metric:
62 | import: cfvqa.models.metrics.factory
63 | name: vqa_cfvqa_metrics
64 | optimizer:
65 | import: cfvqa.optimizers.factory
66 | name: Adam
67 | lr: 0.0003
68 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
69 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
70 | lr_decay_epochs: [14, 24, 2] #range
71 | lr_decay_rate: .25
72 | engine:
73 | name: logger
74 | debug: False
75 | print_freq: 10
76 | nb_epochs: 22
77 | saving_criteria:
78 | - eval_epoch.accuracy_all_top1:max
79 | - eval_epoch.accuracy_vq_top1:max
80 | - eval_epoch.accuracy_cfvqa_top1:max
81 | misc:
82 | logs_name:
83 | cuda: True
84 | seed: 1337
85 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_cfvqaintrod_sum.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_cfvqaintrod_sum
3 | resume: last # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: cfvqaintrod
24 | fusion_mode: sum
25 | cfvqa_params:
26 | mlp_q:
27 | input_dim: 4800
28 | dimensions: [1024,1024,3000]
29 | mlp_v:
30 | input_dim: 2048
31 | dimensions: [1024,1024,3000]
32 | txt_enc:
33 | name: skipthoughts
34 | type: BayesianUniSkip
35 | dropout: 0.25
36 | fixed_emb: False
37 | dir_st: data/skip-thoughts
38 | self_q_att: True
39 | residual: False
40 | q_single: False
41 | fusion:
42 | type: block
43 | input_dims: [4800, 2048]
44 | output_dim: 2048
45 | mm_dim: 1000
46 | chunks: 20
47 | rank: 15
48 | dropout_input: 0.
49 | dropout_pre_lin: 0.
50 | agg:
51 | type: max
52 | classif:
53 | mlp:
54 | input_dim: 2048
55 | dimensions: [1024,1024,3000]
56 | criterion:
57 | import: cfvqa.models.criterions.factory
58 | name: cfvqaintrod_criterion
59 | metric:
60 | import: cfvqa.models.metrics.factory
61 | name: vqa_cfvqaintrod_metrics
62 | optimizer:
63 | import: cfvqa.optimizers.factory
64 | name: Adam
65 | lr: 0.0003
66 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
67 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
68 | lr_decay_epochs: [14, 24, 2] #range
69 | lr_decay_rate: .25
70 | engine:
71 | name: logger
72 | debug: False
73 | print_freq: 10
74 | nb_epochs: 22
75 | saving_criteria:
76 | - eval_epoch.accuracy_stu_top1:max
77 | misc:
78 | logs_name:
79 | cuda: True
80 | seed: 1337
81 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_cfvqasimple_rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_cfvqasimple_rubi
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: cfvqasimple
24 | fusion_mode: rubi
25 | cfvqa_params:
26 | mlp_q:
27 | input_dim: 4800
28 | dimensions: [1024,1024,3000]
29 | mlp_v:
30 | input_dim: 2048
31 | dimensions: [1024,1024,3000]
32 | txt_enc:
33 | name: skipthoughts
34 | type: BayesianUniSkip
35 | dropout: 0.25
36 | fixed_emb: False
37 | dir_st: data/skip-thoughts
38 | self_q_att: True
39 | residual: False
40 | q_single: False
41 | fusion:
42 | type: block
43 | input_dims: [4800, 2048]
44 | output_dim: 2048
45 | mm_dim: 1000
46 | chunks: 20
47 | rank: 15
48 | dropout_input: 0.
49 | dropout_pre_lin: 0.
50 | agg:
51 | type: max
52 | classif:
53 | mlp:
54 | input_dim: 2048
55 | dimensions: [1024,1024,3000]
56 | criterion:
57 | import: cfvqa.models.criterions.factory
58 | name: cfvqasimple_criterion
59 | question_loss_weight: 1.0
60 | metric:
61 | import: cfvqa.models.metrics.factory
62 | name: vqa_cfvqasimple_metrics
63 | optimizer:
64 | import: cfvqa.optimizers.factory
65 | name: Adam
66 | lr: 0.0003
67 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
68 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
69 | lr_decay_epochs: [14, 24, 2] #range
70 | lr_decay_rate: .25
71 | engine:
72 | name: logger
73 | debug: False
74 | print_freq: 10
75 | nb_epochs: 22
76 | saving_criteria:
77 | - eval_epoch.accuracy_all_top1:max
78 | - eval_epoch.accuracy_vq_top1:max
79 | - eval_epoch.accuracy_cfvqa_top1:max
80 | misc:
81 | logs_name:
82 | cuda: True
83 | seed: 1337
84 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_cfvqasimpleintrod_rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_cfvqasimpleintrod_rubi
3 | resume: last # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: cfvqasimpleintrod
24 | fusion_mode: rubi
25 | is_vq: False
26 | cfvqa_params:
27 | mlp_q:
28 | input_dim: 4800
29 | dimensions: [1024,1024,3000]
30 | txt_enc:
31 | name: skipthoughts
32 | type: BayesianUniSkip
33 | dropout: 0.25
34 | fixed_emb: False
35 | dir_st: data/skip-thoughts
36 | self_q_att: True
37 | residual: False
38 | q_single: False
39 | fusion:
40 | type: block
41 | input_dims: [4800, 2048]
42 | output_dim: 2048
43 | mm_dim: 1000
44 | chunks: 20
45 | rank: 15
46 | dropout_input: 0.
47 | dropout_pre_lin: 0.
48 | agg:
49 | type: max
50 | classif:
51 | mlp:
52 | input_dim: 2048
53 | dimensions: [1024,1024,3000]
54 | criterion:
55 | import: cfvqa.models.criterions.factory
56 | name: cfvqaintrod_criterion
57 | metric:
58 | import: cfvqa.models.metrics.factory
59 | name: vqa_cfvqaintrod_metrics
60 | optimizer:
61 | import: cfvqa.optimizers.factory
62 | name: Adam
63 | lr: 0.0003
64 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
65 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
66 | lr_decay_epochs: [14, 24, 2] #range
67 | lr_decay_rate: .25
68 | engine:
69 | name: logger
70 | debug: False
71 | print_freq: 10
72 | nb_epochs: 22
73 | saving_criteria:
74 | - eval_epoch.accuracy_stu_top1:max
75 | misc:
76 | logs_name:
77 | cuda: True
78 | seed: 1337
79 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_rubi.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_rubi
3 | resume: # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: rubi
24 | rubi_params:
25 | mlp_q:
26 | input_dim: 4800
27 | dimensions: [1024,1024,3000]
28 | txt_enc:
29 | name: skipthoughts
30 | type: BayesianUniSkip
31 | dropout: 0.25
32 | fixed_emb: False
33 | dir_st: data/skip-thoughts
34 | self_q_att: True
35 | residual: False
36 | q_single: False
37 | fusion:
38 | type: block
39 | input_dims: [4800, 2048]
40 | output_dim: 2048
41 | mm_dim: 1000
42 | chunks: 20
43 | rank: 15
44 | dropout_input: 0.
45 | dropout_pre_lin: 0.
46 | agg:
47 | type: max
48 | classif:
49 | mlp:
50 | input_dim: 2048
51 | dimensions: [1024,1024,3000]
52 | criterion:
53 | import: cfvqa.models.criterions.factory
54 | name: rubi_criterion
55 | question_loss_weight: 1.0
56 | metric:
57 | import: cfvqa.models.metrics.factory
58 | name: vqa_rubi_metrics
59 | optimizer:
60 | import: cfvqa.optimizers.factory
61 | name: Adam
62 | lr: 0.0003
63 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
64 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
65 | lr_decay_epochs: [14, 24, 2] #range
66 | lr_decay_rate: .25
67 | engine:
68 | name: logger
69 | debug: False
70 | print_freq: 10
71 | nb_epochs: 22
72 | saving_criteria:
73 | - eval_epoch.accuracy_top1:max
74 | - eval_epoch.accuracy_all_top1:max
75 | misc:
76 | logs_name:
77 | cuda: True
78 | seed: 1337
79 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/options/vqacp2/smrl_rubiintrod.yaml:
--------------------------------------------------------------------------------
1 | exp:
2 | dir: logs/vqacp2/smrl_rubiintrod
3 | resume: last # last, best_[...], or empty (from scratch)
4 | dataset:
5 | import: cfvqa.datasets.factory
6 | name: vqacp2 # or vqa2vg
7 | dir: data/vqa/vqacp2
8 | train_split: train
9 | eval_split: val # or test
10 | proc_split: train # or trainval (preprocessing split, must be equal to train_split)
11 | nb_threads: 4
12 | batch_size: 256
13 | nans: 3000
14 | minwcount: 0
15 | nlp: mcb
16 | samplingans: True
17 | dir_rcnn: data/vqa/coco/extract_rcnn/2018-04-27_bottom-up-attention_fixed_36
18 | model:
19 | name: default
20 | network:
21 | import: cfvqa.models.networks.factory
22 | base: smrl
23 | name: rubiintrod
24 | rubi_params:
25 | mlp_q:
26 | input_dim: 4800
27 | dimensions: [1024,1024,3000]
28 | txt_enc:
29 | name: skipthoughts
30 | type: BayesianUniSkip
31 | dropout: 0.25
32 | fixed_emb: False
33 | dir_st: data/skip-thoughts
34 | self_q_att: True
35 | residual: False
36 | q_single: False
37 | fusion:
38 | type: block
39 | input_dims: [4800, 2048]
40 | output_dim: 2048
41 | mm_dim: 1000
42 | chunks: 20
43 | rank: 15
44 | dropout_input: 0.
45 | dropout_pre_lin: 0.
46 | agg:
47 | type: max
48 | classif:
49 | mlp:
50 | input_dim: 2048
51 | dimensions: [1024,1024,3000]
52 | criterion:
53 | import: cfvqa.models.criterions.factory
54 | name: rubiintrod_criterion
55 | question_loss_weight: 1.0
56 | metric:
57 | import: cfvqa.models.metrics.factory
58 | name: vqa_rubiintrod_metrics
59 | optimizer:
60 | import: cfvqa.optimizers.factory
61 | name: Adam
62 | lr: 0.0003
63 | gradual_warmup_steps: [0.5, 2.0, 7.0] #torch.linspace
64 | gradual_warmup_steps_mm: [0.5, 2.0, 7.0] #torch.linspace
65 | lr_decay_epochs: [14, 24, 2] #range
66 | lr_decay_rate: .25
67 | engine:
68 | name: logger
69 | debug: False
70 | print_freq: 10
71 | nb_epochs: 22
72 | saving_criteria:
73 | - eval_epoch.accuracy_top1:max
74 | - eval_epoch.accuracy_all_top1:max
75 | misc:
76 | logs_name:
77 | cuda: True
78 | seed: 1337
79 |
--------------------------------------------------------------------------------
/cfvqa/cfvqa/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import click
3 | import traceback
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 |
7 | from bootstrap.lib import utils
8 | from bootstrap.lib.logger import Logger
9 | from bootstrap.lib.options import Options
10 | from cfvqa import engines
11 | from bootstrap import datasets
12 | from bootstrap import models
13 | from bootstrap import optimizers
14 | from bootstrap import views
15 |
16 |
17 | def init_experiment_directory(exp_dir, resume=None):
18 | # create the experiment directory
19 | if not os.path.isdir(exp_dir):
20 | os.system('mkdir -p ' + exp_dir)
21 | else:
22 | if resume is None:
23 |             if click.confirm('Exp directory already exists in {}. Erase?'
24 |                     .format(exp_dir), default=False):
25 | os.system('rm -r ' + exp_dir)
26 | os.system('mkdir -p ' + exp_dir)
27 | else:
28 | os._exit(1)
29 |
30 |
31 | def init_logs_options_files(exp_dir, resume=None):
32 | # get the logs name which is used for the txt, json and yaml files
33 | # default is `logs.txt`, `logs.json` and `options.yaml`
34 | if 'logs_name' in Options()['misc'] and Options()['misc']['logs_name'] is not None:
35 | logs_name = 'logs_{}'.format(Options()['misc']['logs_name'])
36 | path_yaml = os.path.join(exp_dir, 'options_{}.yaml'.format(logs_name))
37 | elif resume and Options()['dataset']['train_split'] is None:
38 | eval_split = Options()['dataset']['eval_split']
39 | path_yaml = os.path.join(exp_dir, 'options_eval_{}.yaml'.format(eval_split))
40 | logs_name = 'logs_eval_{}'.format(eval_split)
41 | else:
42 | path_yaml = os.path.join(exp_dir, 'options.yaml')
43 | logs_name = 'logs'
44 |
45 | # create the options.yaml file
46 | if not os.path.isfile(path_yaml):
47 | Options().save(path_yaml)
48 |
49 | # create the logs.txt and logs.json files
50 | Logger(exp_dir, name=logs_name)
51 |
52 |
53 | def run(path_opts=None):
54 | # first call to Options() load the options yaml file from --path_opts command line argument if path_opts=None
55 | Options(path_opts)
56 |     # initialize seeds to be able to reproduce the experiment on reload
57 | utils.set_random_seed(Options()['misc']['seed'])
58 |
59 | init_experiment_directory(Options()['exp']['dir'], Options()['exp']['resume'])
60 | init_logs_options_files(Options()['exp']['dir'], Options()['exp']['resume'])
61 |
62 | Logger().log_dict('options', Options(), should_print=True) # display options
63 | Logger()(os.uname()) # display server name
64 |
65 | if torch.cuda.is_available():
66 | cudnn.benchmark = True
67 | Logger()('Available GPUs: {}'.format(utils.available_gpu_ids()))
68 |
69 | # engine can train, eval, optimize the model
70 | # engine can save and load the model and optimizer
71 | engine = engines.factory()
72 |
73 | # dataset is a dictionary that contains all the needed datasets indexed by modes
74 | # (example: dataset.keys() -> ['train','eval'])
75 | engine.dataset = datasets.factory(engine)
76 |
77 | # model includes a network, a criterion and a metric
78 |     # model can register engine hooks (begin epoch, begin batch, end batch, etc.)
79 | # (example: "calculate mAP at the end of the evaluation epoch")
80 |     # note: the model can access the datasets using engine.dataset
81 | engine.model = models.factory(engine)
82 |
83 | # optimizer can register engine hooks
84 | engine.optimizer = optimizers.factory(engine.model, engine)
85 |
86 | # view will save a view.html in the experiment directory
87 | # with some nice plots and curves to monitor training
88 | engine.view = views.factory(engine)
89 |
90 | # load the model and optimizer from a checkpoint
91 | if Options()['exp']['resume']:
92 | engine.resume()
93 |
94 | # if no training split, evaluate the model on the evaluation split
95 | # (example: $ python main.py --dataset.train_split --dataset.eval_split test)
96 | if not Options()['dataset']['train_split']:
97 | engine.eval()
98 |
99 | # optimize the model on the training split for several epochs
100 | # (example: $ python main.py --dataset.train_split train)
101 |     # if an evaluation split is set, evaluate the model after each epoch
102 | # (example: $ python main.py --dataset.train_split train --dataset.eval_split val)
103 | if Options()['dataset']['train_split']:
104 | engine.train()
105 | # with torch.autograd.profiler.profile(use_cuda=Options()['misc.cuda']) as prof:
106 | # engine.train()
107 | # path_tracing = 'tracing_1.0_cuda,{}_all.html'.format(Options()['misc.cuda'])
108 | # prof.export_chrome_trace(path_tracing)
109 |
110 |
111 | def main(path_opts=None, run=None):
112 | try:
113 | run(path_opts=path_opts)
114 | # to avoid traceback for -h flag in arguments line
115 | except SystemExit:
116 | pass
117 | except:
118 | # to be able to write the error trace to exp_dir/logs.txt
119 | try:
120 | Logger()(traceback.format_exc(), Logger.ERROR)
121 | except:
122 | pass
123 |
124 |
125 | if __name__ == '__main__':
126 | main(run=run)
127 |
128 |
--------------------------------------------------------------------------------
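
As the comments in `run()` note, leaving `dataset.train_split` empty makes the engine evaluation-only, while setting it trains (and evaluates after each epoch when an eval split is given). A plausible evaluation-only invocation, assuming the usual `bootstrap.pytorch` dotted command-line overrides and a checkpoint saved under one of the `saving_criteria` names from the option files:

```
python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_cfvqa_sum.yaml \
    --exp.resume best_eval_epoch.accuracy_all_top1 \
    --dataset.train_split '' \
    --dataset.eval_split val
```
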
/cfvqa/requirements.txt:
--------------------------------------------------------------------------------
1 | block.bootstrap.pytorch
2 | h5py
3 | plotly==3.10.0
--------------------------------------------------------------------------------
/cfvqa/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import click
3 | import traceback
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 |
7 | from bootstrap.lib import utils
8 | from bootstrap.lib.logger import Logger
9 | from bootstrap.lib.options import Options
10 | from cfvqa import engines
11 | from bootstrap import datasets
12 | from bootstrap import models
13 | from bootstrap import optimizers
14 | from bootstrap import views
15 |
16 |
17 | def init_experiment_directory(exp_dir, resume=None):
18 | # create the experiment directory
19 | if not os.path.isdir(exp_dir):
20 | os.system('mkdir -p ' + exp_dir)
21 | else:
22 | if resume is None:
23 |             if click.confirm('Exp directory already exists in {}. Erase?'
24 |                     .format(exp_dir), default=False):
25 | os.system('rm -r ' + exp_dir)
26 | os.system('mkdir -p ' + exp_dir)
27 | else:
28 | os._exit(1)
29 |
30 |
31 | def init_logs_options_files(exp_dir, resume=None):
32 | # get the logs name which is used for the txt, json and yaml files
33 | # default is `logs.txt`, `logs.json` and `options.yaml`
34 | if 'logs_name' in Options()['misc'] and Options()['misc']['logs_name'] is not None:
35 | logs_name = 'logs_{}'.format(Options()['misc']['logs_name'])
36 | path_yaml = os.path.join(exp_dir, 'options_{}.yaml'.format(logs_name))
37 | elif resume and Options()['dataset']['train_split'] is None:
38 | eval_split = Options()['dataset']['eval_split']
39 | path_yaml = os.path.join(exp_dir, 'options_eval_{}.yaml'.format(eval_split))
40 | logs_name = 'logs_eval_{}'.format(eval_split)
41 | else:
42 | path_yaml = os.path.join(exp_dir, 'options.yaml')
43 | logs_name = 'logs'
44 |
45 | # create the options.yaml file
46 | if not os.path.isfile(path_yaml):
47 | Options().save(path_yaml)
48 |
49 | # create the logs.txt and logs.json files
50 | Logger(exp_dir, name=logs_name)
51 |
52 |
53 | def run(path_opts=None):
54 | # first call to Options() load the options yaml file from --path_opts command line argument if path_opts=None
55 | Options(path_opts)
56 |     # initialize seeds to be able to reproduce the experiment on reload
57 | utils.set_random_seed(Options()['misc']['seed'])
58 |
59 | init_experiment_directory(Options()['exp']['dir'], Options()['exp']['resume'])
60 | init_logs_options_files(Options()['exp']['dir'], Options()['exp']['resume'])
61 |
62 | Logger().log_dict('options', Options(), should_print=True) # display options
63 | Logger()(os.uname()) # display server name
64 |
65 | if torch.cuda.is_available():
66 | cudnn.benchmark = True
67 | Logger()('Available GPUs: {}'.format(utils.available_gpu_ids()))
68 |
69 | # engine can train, eval, optimize the model
70 | # engine can save and load the model and optimizer
71 | engine = engines.factory()
72 |
73 | # dataset is a dictionary that contains all the needed datasets indexed by modes
74 | # (example: dataset.keys() -> ['train','eval'])
75 | engine.dataset = datasets.factory(engine)
76 |
77 | # model includes a network, a criterion and a metric
78 |     # model can register engine hooks (begin epoch, begin batch, end batch, etc.)
79 | # (example: "calculate mAP at the end of the evaluation epoch")
80 |     # note: the model can access the datasets using engine.dataset
81 | engine.model = models.factory(engine)
82 |
83 | # optimizer can register engine hooks
84 | engine.optimizer = optimizers.factory(engine.model, engine)
85 |
86 | # view will save a view.html in the experiment directory
87 | # with some nice plots and curves to monitor training
88 | engine.view = views.factory(engine)
89 |
90 | # load the model and optimizer from a checkpoint
91 | if Options()['exp']['resume']:
92 | engine.resume()
93 |
94 | # if no training split, evaluate the model on the evaluation split
95 | # (example: $ python main.py --dataset.train_split --dataset.eval_split test)
96 | if not Options()['dataset']['train_split']:
97 | engine.eval()
98 |
99 | # optimize the model on the training split for several epochs
100 | # (example: $ python main.py --dataset.train_split train)
101 |     # if an evaluation split is set, evaluate the model after each epoch
102 | # (example: $ python main.py --dataset.train_split train --dataset.eval_split val)
103 | if Options()['dataset']['train_split']:
104 | engine.train()
105 | # with torch.autograd.profiler.profile(use_cuda=Options()['misc.cuda']) as prof:
106 | # engine.train()
107 | # path_tracing = 'tracing_1.0_cuda,{}_all.html'.format(Options()['misc.cuda'])
108 | # prof.export_chrome_trace(path_tracing)
109 |
110 |
111 | def main(path_opts=None, run=None):
112 | try:
113 | run(path_opts=path_opts)
114 | # to avoid traceback for -h flag in arguments line
115 | except SystemExit:
116 | pass
117 | except:
118 | # to be able to write the error trace to exp_dir/logs.txt
119 | try:
120 | Logger()(traceback.format_exc(), Logger.ERROR)
121 | except:
122 | pass
123 |
124 |
125 | if __name__ == '__main__':
126 | main(run=run)
127 |
128 |
--------------------------------------------------------------------------------
/cfvqa/run_vqa2_cfvqa_introd.sh:
--------------------------------------------------------------------------------
1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_cfvqa_sum.yaml
2 | mkdir -p ./logs/vqa2/smrl_cfvqaintrod_sum/
3 | cp -r ./logs/vqa2/smrl_cfvqa_sum/* ./logs/vqa2/smrl_cfvqaintrod_sum/
4 | python -m run -o ./cfvqa/options/vqa2/smrl_cfvqaintrod_sum.yaml
--------------------------------------------------------------------------------
/cfvqa/scripts/run_vqa2_cfvqa_introd.sh:
--------------------------------------------------------------------------------
1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_cfvqa_sum.yaml
2 | mkdir -p ./logs/vqa2/smrl_cfvqaintrod_sum/
3 | cp -r ./logs/vqa2/smrl_cfvqa_sum/* ./logs/vqa2/smrl_cfvqaintrod_sum/
4 | python -m run -o ./cfvqa/options/vqa2/smrl_cfvqaintrod_sum.yaml
--------------------------------------------------------------------------------
/cfvqa/scripts/run_vqa2_rubi_introd.sh:
--------------------------------------------------------------------------------
1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_rubi.yaml
2 | mkdir -p ./logs/vqa2/smrl_rubiintrod/
3 | cp -r ./logs/vqa2/smrl_rubi/* ./logs/vqa2/smrl_rubiintrod/
4 | python -m run -o ./cfvqa/options/vqa2/smrl_rubiintrod.yaml
--------------------------------------------------------------------------------
/cfvqa/scripts/run_vqa2_rubicf_introd.sh:
--------------------------------------------------------------------------------
1 | python -m bootstrap.run -o cfvqa/options/vqa2/smrl_cfvqasimple_rubi.yaml
2 | mkdir -p ./logs/vqa2/smrl_cfvqasimpleintrod_rubi/
3 | cp -r ./logs/vqa2/smrl_cfvqasimple_rubi/* ./logs/vqa2/smrl_cfvqasimpleintrod_rubi/
4 | python -m run -o ./cfvqa/options/vqa2/smrl_cfvqasimpleintrod_rubi.yaml
--------------------------------------------------------------------------------
/cfvqa/scripts/run_vqacp2_rubi_introd.sh:
--------------------------------------------------------------------------------
1 | python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_rubi.yaml
2 | mkdir -p ./logs/vqacp2/smrl_rubiintrod/
3 | cp -r ./logs/vqacp2/smrl_rubi/* ./logs/vqacp2/smrl_rubiintrod/
4 | python -m run -o ./cfvqa/options/vqacp2/smrl_rubiintrod.yaml
--------------------------------------------------------------------------------
/cfvqa/scripts/run_vqacp2_rubicf_introd.sh:
--------------------------------------------------------------------------------
1 | python -m bootstrap.run -o cfvqa/options/vqacp2/smrl_cfvqasimple_rubi.yaml
2 | mkdir -p ./logs/vqacp2/smrl_cfvqasimpleintrod_rubi/
3 | cp -r ./logs/vqacp2/smrl_cfvqasimple_rubi/* ./logs/vqacp2/smrl_cfvqasimpleintrod_rubi/
4 | python -m run -o ./cfvqa/options/vqacp2/smrl_cfvqasimpleintrod_rubi.yaml
--------------------------------------------------------------------------------
/css/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | data/
132 | data
133 | logs/
134 | logs
135 |
--------------------------------------------------------------------------------
/css/README.md:
--------------------------------------------------------------------------------
1 | # CSS+IntroD
2 |
3 | This code implements IntroD on top of LMH and CSS for the VQA task. It is modified from [css](https://github.com/yanxinzju/CSS-VQA) and [lmh](https://github.com/chrisc36/bottom-up-attention-vqa), and the data is provided by the authors of "Counterfactual Samples Synthesizing for Robust VQA".
4 |
5 | ## Prerequisites
6 |
7 | (Follow [css](https://github.com/yanxinzju/CSS-VQA))
8 |
9 | Make sure you are on a machine with an NVIDIA GPU and Python 2.7, with about 100 GB of disk space. The required packages are:
10 | - h5py==2.10.0
11 | - pytorch==1.1.0
12 | - Click==7.0
13 | - numpy==1.16.5
14 | - tqdm==4.35.0
15 |
16 | ## Data Setup
17 |
18 | (The data is provided by [css](https://github.com/yanxinzju/CSS-VQA).)
19 |
20 | You can use
21 | ```
22 | bash tools/download.sh
23 | ```
24 | to download part of the data.
25 | The rest of the data and the trained models can be obtained from [BaiduYun](https://pan.baidu.com/s/1oHdwYDSJXC1mlmvu8cQhKw) (passwd: 3jot) or [GoogleDrive](https://drive.google.com/drive/folders/13e-b76otJukupbjfC-n1s05L202PaFKQ?usp=sharing).
26 | Unzip feature1.zip and feature2.zip and merge them into data/rcnn_feature/.
27 | Then use
28 | ```
29 | bash tools/process.sh
30 | ```
31 | to process the data.
32 |
33 | ## CSS+IntroD
34 |
35 | ### Step 1: train the teacher model
36 |
37 | (Follow [css](https://github.com/yanxinzju/CSS-VQA))
38 |
39 | Run
40 | ```
41 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset cpv2 --mode q_v_debias --debias learned_mixin --topq 1 --topv -1 --qvp 5 --output ./logs/vqacp2/css/
42 | ```
43 | to train a teacher model on VQA-CP v2.
44 |
45 | Or
46 |
47 | Run
48 | ```
49 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset v2 --mode q_v_debias --debias learned_mixin --topq 1 --topv -1 --qvp 5 --output ./logs/vqa2/css/
50 | ```
51 | to train a teacher model on VQA v2.
52 |
53 | ### Step 2: train the student model
54 |
55 | Run
56 | ```
57 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset cpv2 --output ./logs/vqacp2/css_introd/ --source ./logs/vqacp2/css/
58 | ```
59 | to train a student model on VQA-CP v2.
60 |
61 | Or
62 |
63 | Run
64 | ```
65 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset v2 --output ./logs/vqa2/css_introd/ --source ./logs/vqa2/css/
66 | ```
67 | to train a student model on VQA v2.
68 |
69 | ## LMH+IntroD
70 |
71 | ### Step 1: train the teacher model
72 |
73 | (Follow [css](https://github.com/yanxinzju/CSS-VQA))
74 |
75 | Run
76 | ```
77 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset cpv2 --mode updn --debias learned_mixin --output ./logs/vqacp2/lmh/
78 | ```
79 | to train a teacher model on VQA-CP v2.
80 |
81 | Or
82 |
83 | Run
84 | ```
85 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset v2 --mode updn --debias learned_mixin --output ./logs/vqa2/lmh/
86 | ```
87 | to train a teacher model on VQA v2.
88 |
89 |
90 | ### Step 2: train the student model
91 |
92 | Run
93 | ```
94 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset cpv2 --output ./logs/vqacp2/lmh_introd/ --source ./logs/vqacp2/lmh/
95 | ```
96 | to train a student model on VQA-CP v2.
97 |
98 | Or
99 |
100 | Run
101 | ```
102 | CUDA_VISIBLE_DEVICES=0 python main_introd.py --dataset v2 --output ./logs/vqa2/lmh_introd/ --source ./logs/vqa2/lmh/
103 | ```
104 | to train a student model on VQA v2.
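105 | 
106 | ## Evaluation
107 | 
108 | A trained model can be evaluated with `eval.py`, assuming the training step saved its checkpoint as `model.pth` in the output directory (which matches the default `--model_state` path format), e.g.
109 | ```
110 | CUDA_VISIBLE_DEVICES=0 python eval.py --dataset cpv2 --debias learned_mixin --model_state ./logs/vqacp2/lmh_introd/model.pth
111 | ```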
--------------------------------------------------------------------------------
/css/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.weight_norm import weight_norm
4 | from fc import FCNet
5 |
6 |
7 | class Attention(nn.Module):
8 | def __init__(self, v_dim, q_dim, num_hid):
9 | super(Attention, self).__init__()
10 | self.nonlinear = FCNet([v_dim + q_dim, num_hid])
11 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)
12 |
13 | def forward(self, v, q):
14 | """
15 | v: [batch, k, vdim]
16 | q: [batch, qdim]
17 | """
18 | logits = self.logits(v, q)
19 | w = nn.functional.softmax(logits, 1)
20 | return w
21 |
22 | def logits(self, v, q):
23 | num_objs = v.size(1)
24 | q = q.unsqueeze(1).repeat(1, num_objs, 1)
25 | vq = torch.cat((v, q), 2)
26 | joint_repr = self.nonlinear(vq)
27 | logits = self.linear(joint_repr)
28 | return logits
29 |
30 |
31 | class NewAttention(nn.Module):
32 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2):
33 | super(NewAttention, self).__init__()
34 |
35 | self.v_proj = FCNet([v_dim, num_hid])
36 | self.q_proj = FCNet([q_dim, num_hid])
37 | self.dropout = nn.Dropout(dropout)
38 | self.linear = weight_norm(nn.Linear(q_dim, 1), dim=None)
39 |
40 | def forward(self, v, q):
41 | """
42 | v: [batch, k, vdim]
43 | q: [batch, qdim]
44 | """
45 | logits = self.logits(v, q)
46 |         # unlike Attention above, raw logits are returned here; the caller
47 |         # (BaseModel.forward in base_model.py) applies softmax or mask_softmax
48 | return logits
49 |
50 | def logits(self, v, q):
51 | batch, k, _ = v.size()
52 | v_proj = self.v_proj(v) # [batch, k, qdim]
53 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1)
54 | joint_repr = v_proj * q_proj
55 | joint_repr = self.dropout(joint_repr)
56 | logits = self.linear(joint_repr)
57 | return logits
58 |
--------------------------------------------------------------------------------
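
Both modules score each of the `k` object features against the question. A small shape check for `NewAttention`, assuming the usual bottom-up dimensions (36 objects of size 2048, a 1024-d question embedding; note that the final linear layer expects `q_dim == num_hid`, which holds here):

```
import torch
from attention import NewAttention

att = NewAttention(v_dim=2048, q_dim=1024, num_hid=1024)
v = torch.randn(8, 36, 2048)  # [batch, k, vdim] object features
q = torch.randn(8, 1024)      # [batch, qdim] question embedding
logits = att(v, q)            # [8, 36, 1] raw per-object logits
# normalization is left to the caller: softmax or mask_softmax in base_model.py
```
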
/css/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attention import Attention, NewAttention
4 | from language_model import WordEmbedding, QuestionEmbedding
5 | from classifier import SimpleClassifier
6 | from fc import FCNet
7 | import numpy as np
8 |
9 | def mask_softmax(x, mask):
10 |     mask = mask.unsqueeze(2).float()
11 |     x2 = torch.exp(x - torch.max(x))  # numerically stable exponentials
12 |     x3 = x2 * mask                    # zero out masked-off positions
13 |     epsilon = 1e-5
14 |     x3_sum = torch.sum(x3, dim=1, keepdim=True) + epsilon
15 |     x4 = x3 / x3_sum.expand_as(x3)    # softmax over dim 1, restricted to the mask
16 |     return x4
17 |
18 |
19 | class BaseModel(nn.Module):
20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier):
21 | super(BaseModel, self).__init__()
22 | self.w_emb = w_emb
23 | self.q_emb = q_emb
24 | self.v_att = v_att
25 | self.q_net = q_net
26 | self.v_net = v_net
27 | self.classifier = classifier
28 | self.debias_loss_fn = None
29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2))
30 | self.bias_lin = torch.nn.Linear(1024, 1)
31 |
32 |     def forward(self, v, q, labels, bias, v_mask):
33 |         """Forward
34 | 
35 |         v: [batch, num_objs, obj_dim]
36 |         q: [batch_size, seq_length]
37 |         labels, bias, v_mask: answer scores, answer prior, and object mask (each may be None)
38 | 
39 |         return: logits, not probs
40 |         """
41 | w_emb = self.w_emb(q)
42 | q_emb = self.q_emb(w_emb) # [batch, q_dim]
43 |
44 | att = self.v_att(v, q_emb)
45 | if v_mask is None:
46 | att = nn.functional.softmax(att, 1)
47 | else:
48 | att= mask_softmax(att,v_mask)
49 |
50 | v_emb = (att * v).sum(1) # [batch, v_dim]
51 |
52 | q_repr = self.q_net(q_emb)
53 | v_repr = self.v_net(v_emb)
54 | joint_repr = q_repr * v_repr
55 |
56 | logits = self.classifier(joint_repr)
57 |
58 | if labels is not None:
59 | loss = self.debias_loss_fn(joint_repr, logits, bias, labels)
60 | else:
61 | loss = None
62 |         return logits, loss, w_emb
63 |
64 | def build_baseline0(dataset, num_hid):
65 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
66 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
67 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid)
68 | q_net = FCNet([num_hid, num_hid])
69 | v_net = FCNet([dataset.v_dim, num_hid])
70 | classifier = SimpleClassifier(
71 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5)
72 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier)
73 |
74 |
75 | def build_baseline0_newatt(dataset, num_hid):
76 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
77 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
78 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid)
79 | q_net = FCNet([q_emb.num_hid, num_hid])
80 | v_net = FCNet([dataset.v_dim, num_hid])
81 | classifier = SimpleClassifier(
82 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5)
83 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier)
--------------------------------------------------------------------------------
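
A minimal construction sketch mirroring the wiring in `main.py` below, with a `SimpleNamespace` standing in for the dataset (only the three attributes read by `build_baseline0_newatt` are provided, and the sizes are hypothetical):

```
from types import SimpleNamespace
import base_model
from vqa_debias_loss_functions import LearnedMixin

dset = SimpleNamespace(
    dictionary=SimpleNamespace(ntoken=19901),  # hypothetical vocabulary size
    v_dim=2048,
    num_ans_candidates=2274,                   # hypothetical answer vocabulary size
)
model = base_model.build_baseline0_newatt(dset, num_hid=1024)
model.debias_loss_fn = LearnedMixin(0.36)      # attached the same way main.py does it
```
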
/css/base_model_introd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attention import Attention, NewAttention
4 | from language_model import WordEmbedding, QuestionEmbedding
5 | from classifier import SimpleClassifier
6 | from fc import FCNet
7 | import numpy as np
8 |
9 | def mask_softmax(x, mask):
10 |     mask = mask.unsqueeze(2).float()
11 |     x2 = torch.exp(x - torch.max(x))  # numerically stable exponentials
12 |     x3 = x2 * mask                    # zero out masked-off positions
13 |     epsilon = 1e-5
14 |     x3_sum = torch.sum(x3, dim=1, keepdim=True) + epsilon
15 |     x4 = x3 / x3_sum.expand_as(x3)    # softmax over dim 1, restricted to the mask
16 |     return x4
17 |
18 |
19 | class BaseModel(nn.Module):
20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier):
21 | super(BaseModel, self).__init__()
22 | self.w_emb = w_emb
23 | self.q_emb = q_emb
24 | self.v_att = v_att
25 | self.q_net = q_net
26 | self.v_net = v_net
27 | self.classifier = classifier
28 | self.debias_loss_fn = None
29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2))
30 | self.bias_lin = torch.nn.Linear(1024, 1)
31 |
32 |     def forward(self, v, q, labels, bias, v_mask):
33 |         """Forward
34 | 
35 |         v: [batch, num_objs, obj_dim]
36 |         q: [batch_size, seq_length]
37 |         labels, bias, v_mask: answer scores, answer prior, and object mask (each may be None)
38 | 
39 |         return: logits and the ensembled logits_all, not probs
40 |         """
41 | w_emb = self.w_emb(q)
42 | q_emb = self.q_emb(w_emb) # [batch, q_dim]
43 |
44 | att = self.v_att(v, q_emb)
45 | if v_mask is None:
46 | att = nn.functional.softmax(att, 1)
47 | else:
48 | att= mask_softmax(att,v_mask)
49 |
50 | v_emb = (att * v).sum(1) # [batch, v_dim]
51 |
52 | q_repr = self.q_net(q_emb)
53 | v_repr = self.v_net(v_emb)
54 | joint_repr = q_repr * v_repr
55 |
56 | logits = self.classifier(joint_repr)
57 |
58 | if labels is not None:
59 | logits_all, loss = self.debias_loss_fn(joint_repr, logits, bias, labels)
60 | else:
61 | logits_all = None
62 | loss = None
63 | return logits, logits_all, loss, w_emb
64 |
65 | def build_baseline0(dataset, num_hid):
66 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
67 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
68 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid)
69 | q_net = FCNet([num_hid, num_hid])
70 | v_net = FCNet([dataset.v_dim, num_hid])
71 | classifier = SimpleClassifier(
72 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5)
73 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier)
74 |
75 |
76 | def build_baseline0_newatt(dataset, num_hid):
77 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
78 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
79 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid)
80 | q_net = FCNet([q_emb.num_hid, num_hid])
81 | v_net = FCNet([dataset.v_dim, num_hid])
82 | classifier = SimpleClassifier(
83 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5)
84 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier)
--------------------------------------------------------------------------------
/css/classifier.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.utils.weight_norm import weight_norm
3 |
4 |
5 | class SimpleClassifier(nn.Module):
6 | def __init__(self, in_dim, hid_dim, out_dim, dropout):
7 | super(SimpleClassifier, self).__init__()
8 | layers = [
9 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
10 | nn.ReLU(),
11 | nn.Dropout(dropout, inplace=True),
12 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
13 | ]
14 | self.main = nn.Sequential(*layers)
15 |
16 | def forward(self, x):
17 | logits = self.main(x)
18 | return logits
19 |
--------------------------------------------------------------------------------
/css/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import cPickle
4 | from collections import defaultdict, Counter
5 | from os.path import dirname, join
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 | import numpy as np
11 | import os
12 |
13 | # from new_dataset import Dictionary, VQAFeatureDataset
14 | from dataset import Dictionary, VQAFeatureDataset
15 | import base_model
16 | from train import train
17 | import utils
18 |
19 | from vqa_debias_loss_functions import *
20 | from tqdm import tqdm
21 | from torch.autograd import Variable
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method")
26 |
27 | # Arguments we added
28 | parser.add_argument(
29 | '--cache_features', default=True,
30 | help="Cache image features in RAM. Makes things much faster, "
31 | "especially if the filesystem is slow, but requires at least 48gb of RAM")
32 | parser.add_argument(
33 |         '--dataset', default='cpv2', help="Dataset to evaluate on: cpv2 (VQA-CP v2, the default), v2, or cpv1")
34 | parser.add_argument(
35 | '-p', "--entropy_penalty", default=0.36, type=float,
36 | help="Entropy regularizer weight for the learned_mixin model")
37 | parser.add_argument(
38 | '--debias', default="learned_mixin",
39 | choices=["learned_mixin", "reweight", "bias_product", "none"],
40 | help="Kind of ensemble loss to use")
41 |     # Arguments from the original model; we leave these at their
42 |     # defaults since evaluation does not train the model
43 | parser.add_argument('--num_hid', type=int, default=1024)
44 | parser.add_argument('--model', type=str, default='baseline0_newatt')
45 | parser.add_argument('--batch_size', type=int, default=512)
46 | parser.add_argument('--seed', type=int, default=1111, help='random seed')
47 | parser.add_argument('--model_state', type=str, default='logs/exp0/model.pth')
48 | args = parser.parse_args()
49 | return args
50 |
51 | def compute_score_with_logits(logits, labels):
52 | # logits = torch.max(logits, 1)[1].data # argmax
53 | logits = torch.argmax(logits,1)
54 | one_hots = torch.zeros(*labels.size()).cuda()
55 | one_hots.scatter_(1, logits.view(-1, 1), 1)
56 | scores = (one_hots * labels)
57 | return scores
58 |
59 |
60 | def evaluate(model,dataloader,qid2type):
61 | score = 0
62 | upper_bound = 0
63 | score_yesno = 0
64 | score_number = 0
65 | score_other = 0
66 | total_yesno = 0
67 | total_number = 0
68 | total_other = 0
69 | model.train(False)
70 | # import pdb;pdb.set_trace()
71 |
72 |
73 | for v, q, a, b,qids,hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
74 | v = Variable(v, requires_grad=False).cuda()
75 | q = Variable(q, requires_grad=False).cuda()
76 |             pred, _, _ = model(v, q, None, None, None)
77 | batch_score= compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1)
78 | score += batch_score.sum()
79 | upper_bound += (a.max(1)[0]).sum()
80 | qids = qids.detach().cpu().int().numpy()
81 | for j in range(len(qids)):
82 | qid=qids[j]
83 | typ = qid2type[str(qid)]
84 | if typ == 'yes/no':
85 | score_yesno += batch_score[j]
86 | total_yesno += 1
87 | elif typ == 'other':
88 | score_other += batch_score[j]
89 | total_other += 1
90 | elif typ == 'number':
91 | score_number += batch_score[j]
92 | total_number += 1
93 | else:
94 |                 print('Unknown question type: %s' % typ)
95 | score = score / len(dataloader.dataset)
96 | upper_bound = upper_bound / len(dataloader.dataset)
97 | score_yesno /= total_yesno
98 | score_other /= total_other
99 | score_number /= total_number
100 | print('\teval overall score: %.2f' % (100 * score))
101 | print('\teval up_bound score: %.2f' % (100 * upper_bound))
102 | print('\teval y/n score: %.2f' % (100 * score_yesno))
103 | print('\teval other score: %.2f' % (100 * score_other))
104 | print('\teval number score: %.2f' % (100 * score_number))
105 |
106 | def evaluate_ai(model,dataloader,qid2type,label2ans):
107 | score=0
108 | upper_bound=0
109 |
110 | ai_top1=0
111 | ai_top2=0
112 | ai_top3=0
113 |
114 | for v, q, a, b, qids, hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
115 | v = Variable(v, requires_grad=False).cuda().float().requires_grad_()
116 | q = Variable(q, requires_grad=False).cuda()
117 | a=a.cuda()
118 | hintscore=hintscore.cuda().float()
119 | pred, _, _ = model(v, q, None, None, None)
120 | vqa_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] # [b , 36, 2048]
121 |
122 | vqa_grad_cam=vqa_grad.sum(2)
123 | sv_ind=torch.argmax(vqa_grad_cam,1)
124 |
125 | x_ind_top1=torch.topk(vqa_grad_cam,k=1)[1]
126 | x_ind_top2=torch.topk(vqa_grad_cam,k=2)[1]
127 | x_ind_top3=torch.topk(vqa_grad_cam,k=3)[1]
128 |
129 | y_score_top1 = hintscore.gather(1,x_ind_top1).sum(1)/1
130 | y_score_top2 = hintscore.gather(1,x_ind_top2).sum(1)/2
131 | y_score_top3 = hintscore.gather(1,x_ind_top3).sum(1)/3
132 |
133 |
134 | batch_score=compute_score_with_logits(pred,a.cuda()).cpu().numpy().sum(1)
135 | score+=batch_score.sum()
136 | upper_bound+=(a.max(1)[0]).sum()
137 | qids=qids.detach().cpu().int().numpy()
138 | for j in range(len(qids)):
139 | if batch_score[j]>0:
140 | ai_top1 += y_score_top1[j]
141 | ai_top2 += y_score_top2[j]
142 | ai_top3 += y_score_top3[j]
143 |
144 |
145 |
146 | score = score / len(dataloader.dataset)
147 | upper_bound = upper_bound / len(dataloader.dataset)
148 | ai_top1=(ai_top1.item() * 1.0) / len(dataloader.dataset)
149 | ai_top2=(ai_top2.item() * 1.0) / len(dataloader.dataset)
150 | ai_top3=(ai_top3.item() * 1.0) / len(dataloader.dataset)
151 |
152 | print('\teval overall score: %.2f' % (100 * score))
153 | print('\teval up_bound score: %.2f' % (100 * upper_bound))
154 | print('\ttop1_ai_score: %.2f' % (100 * ai_top1))
155 | print('\ttop2_ai_score: %.2f' % (100 * ai_top2))
156 | print('\ttop3_ai_score: %.2f' % (100 * ai_top3))
157 |
158 | def main():
159 | args = parse_args()
160 | dataset = args.dataset
161 |
162 |
163 | with open('util/qid2type_%s.json'%args.dataset,'r') as f:
164 | qid2type=json.load(f)
165 |
166 | if dataset=='cpv1':
167 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl')
168 | elif dataset=='cpv2' or dataset=='v2':
169 | dictionary = Dictionary.load_from_file('data/dictionary.pkl')
170 |
171 | print("Building test dataset...")
172 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset,
173 | cache_image_features=args.cache_features)
174 |
175 | # Build the model using the original constructor
176 | constructor = 'build_%s' % args.model
177 | model = getattr(base_model, constructor)(eval_dset, args.num_hid).cuda()
178 |
179 | if args.debias == "bias_product":
180 | model.debias_loss_fn = BiasProduct()
181 | elif args.debias == "none":
182 | model.debias_loss_fn = Plain()
183 | elif args.debias == "reweight":
184 | model.debias_loss_fn = ReweightByInvBias()
185 | elif args.debias == "learned_mixin":
186 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty)
187 | else:
188 |         raise RuntimeError(args.debias)
189 |
190 |
191 | model_state = torch.load(args.model_state)
192 | model.load_state_dict(model_state)
193 |
194 |
195 | model = model.cuda()
196 | batch_size = args.batch_size
197 |
198 | torch.manual_seed(args.seed)
199 | torch.cuda.manual_seed(args.seed)
200 | torch.backends.cudnn.benchmark = True
201 |
202 | # The original version uses multiple workers, but that just seems slower on my setup
203 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0)
204 |
205 |
206 |
207 | print("Starting eval...")
208 |
209 | evaluate(model,eval_loader,qid2type)
210 |
211 |
212 |
213 | if __name__ == '__main__':
214 | main()
215 |
--------------------------------------------------------------------------------
/css/fc.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.nn as nn
3 | from torch.nn.utils.weight_norm import weight_norm
4 |
5 |
6 | class FCNet(nn.Module):
7 | """Simple class for non-linear fully connect network
8 | """
9 | def __init__(self, dims):
10 | super(FCNet, self).__init__()
11 |
12 | layers = []
13 | for i in range(len(dims)-2):
14 | in_dim = dims[i]
15 | out_dim = dims[i+1]
16 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
17 | layers.append(nn.ReLU())
18 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
19 | layers.append(nn.ReLU())
20 |
21 | self.main = nn.Sequential(*layers)
22 |
23 | def forward(self, x):
24 | return self.main(x)
25 |
26 |
27 | if __name__ == '__main__':
28 | fc1 = FCNet([10, 20, 10])
29 | print(fc1)
30 |
31 | print('============')
32 | fc2 = FCNet([10, 20])
33 | print(fc2)
34 |
--------------------------------------------------------------------------------
/css/language_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 |
6 |
7 | class WordEmbedding(nn.Module):
8 | """Word Embedding
9 |
10 | The ntoken-th dim is used for padding_idx, which agrees *implicitly*
11 | with the definition in Dictionary.
12 | """
13 | def __init__(self, ntoken, emb_dim, dropout):
14 | super(WordEmbedding, self).__init__()
15 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken)
16 | self.dropout = nn.Dropout(dropout)
17 | self.ntoken = ntoken
18 | self.emb_dim = emb_dim
19 |
20 | def init_embedding(self, np_file):
21 | weight_init = torch.from_numpy(np.load(np_file))
22 | assert weight_init.shape == (self.ntoken, self.emb_dim)
23 | self.emb.weight.data[:self.ntoken] = weight_init
24 |
25 | def forward(self, x):
26 | emb = self.emb(x)
27 | emb = self.dropout(emb)
28 | return emb
29 |
30 |
31 | class QuestionEmbedding(nn.Module):
32 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'):
33 | """Module for question embedding
34 | """
35 | super(QuestionEmbedding, self).__init__()
36 | assert rnn_type == 'LSTM' or rnn_type == 'GRU'
37 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
38 |
39 | self.rnn = rnn_cls(
40 | in_dim, num_hid, nlayers,
41 | bidirectional=bidirect,
42 | dropout=dropout,
43 | batch_first=True)
44 |
45 | self.in_dim = in_dim
46 | self.num_hid = num_hid
47 | self.nlayers = nlayers
48 | self.rnn_type = rnn_type
49 | self.ndirections = 1 + int(bidirect)
50 |
51 | def init_hidden(self, batch):
52 | # just to get the type of tensor
53 | weight = next(self.parameters()).data
54 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
55 | if self.rnn_type == 'LSTM':
56 | return (Variable(weight.new(*hid_shape).zero_()),
57 | Variable(weight.new(*hid_shape).zero_()))
58 | else:
59 | return Variable(weight.new(*hid_shape).zero_())
60 |
61 | def forward(self, x):
62 | # x: [batch, sequence, in_dim]
63 | batch = x.size(0)
64 | hidden = self.init_hidden(batch)
65 | self.rnn.flatten_parameters()
66 | output, hidden = self.rnn(x, hidden)
67 |
68 | if self.ndirections == 1:
69 | return output[:, -1]
70 |
71 | forward_ = output[:, -1, :self.num_hid]
72 | backward = output[:, 0, self.num_hid:]
73 | return torch.cat((forward_, backward), dim=1)
74 |
75 | def forward_all(self, x):
76 | # x: [batch, sequence, in_dim]
77 | batch = x.size(0)
78 | hidden = self.init_hidden(batch)
79 | self.rnn.flatten_parameters()
80 | output, hidden = self.rnn(x, hidden)
81 | return output
82 |
--------------------------------------------------------------------------------
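
A small usage sketch of the two embedding modules, under assumed sizes (a hypothetical 19901-word vocabulary and 14-token questions):

```
import torch
from language_model import WordEmbedding, QuestionEmbedding

w_emb = WordEmbedding(ntoken=19901, emb_dim=300, dropout=0.0)
q_emb = QuestionEmbedding(in_dim=300, num_hid=1024, nlayers=1, bidirect=False, dropout=0.0)

tokens = torch.randint(0, 19901, (8, 14))  # [batch, seq_length] word indices
emb = w_emb(tokens)                        # [8, 14, 300]
final = q_emb(emb)                         # [8, 1024] last hidden state
per_token = q_emb.forward_all(emb)         # [8, 14, 1024] hidden state at every step
```
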
/css/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import cPickle as pickle
4 | from collections import defaultdict, Counter
5 | from os.path import dirname, join
6 | import os
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | import numpy as np
12 |
13 | from dataset import Dictionary, VQAFeatureDataset
14 | import base_model
15 | from train import train
16 | import utils
17 | import click
18 |
19 | from vqa_debias_loss_functions import *
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method")
24 |
25 | # Arguments we added
26 | parser.add_argument(
27 | '--cache_features', default=True,
28 | help="Cache image features in RAM. Makes things much faster, "
29 | "especially if the filesystem is slow, but requires at least 48gb of RAM")
30 | parser.add_argument(
31 | '--dataset', default='cpv2',
32 | choices=["v2", "cpv2", "cpv1", "cpv2val"],
33 | help="Run on VQA-2.0 instead of VQA-CP 2.0"
34 | )
35 | parser.add_argument(
36 | '-p', "--entropy_penalty", default=0.36, type=float,
37 | help="Entropy regularizer weight for the learned_mixin model")
38 | parser.add_argument(
39 | '--mode', default="updn",
40 | choices=["updn", "q_debias","v_debias","q_v_debias"],
41 | help="Kind of ensemble loss to use")
42 | parser.add_argument(
43 | '--debias', default="learned_mixin",
44 | choices=["learned_mixin_rw2", "learned_mixin_rw", "learned_mixin", "reweight", "bias_product", "none",'focal'],
45 | help="Kind of ensemble loss to use")
46 | parser.add_argument(
47 | '--topq', type=int,default=1,
48 | choices=[1,2,3],
49 | help="num of words to be masked in questio")
50 | parser.add_argument(
51 | '--keep_qtype', default=True,
52 | help="keep qtype or not")
53 | parser.add_argument(
54 | '--topv', type=int,default=1,
55 | choices=[1,3,5,-1],
56 | help="num of object bbox to be masked in image")
57 | parser.add_argument(
58 | '--top_hint',type=int, default=9,
59 | choices=[9,18,27,36],
60 | help="num of hint")
61 | parser.add_argument(
62 | '--qvp', type=int,default=0,
63 | choices=[0,1,2,3,4,5,6,7,8,9,10],
64 | help="ratio of q_bias and v_bias")
65 | parser.add_argument(
66 | '--eval_each_epoch', default=True,
67 | help="Evaluate every epoch, instead of at the end")
68 |
69 | # Arguments from the original model, we leave this default, except we
70 |     # set --epochs to 15 since the model maxes out its performance on VQA 2.0 well before then
71 | # parser.add_argument('--epochs', type=int, default=30)
72 | parser.add_argument('--epochs', type=int, default=15)
73 | parser.add_argument('--num_hid', type=int, default=1024)
74 | parser.add_argument('--model', type=str, default='baseline0_newatt')
75 | parser.add_argument('--output', type=str, default='logs/exp0')
76 | parser.add_argument('--batch_size', type=int, default=512)
77 | parser.add_argument('--seed', type=int, default=1111, help='random seed')
78 | parser.add_argument('--feature', type=str, default='css')
79 | args = parser.parse_args()
80 | return args
81 |
82 | def get_bias(train_dset,eval_dset):
83 | # Compute the bias:
84 | # The bias here is just the expected score for each answer/question type
85 | answer_voc_size = train_dset.num_ans_candidates
86 |
87 | # question_type -> answer -> total score
88 | question_type_to_probs = defaultdict(Counter)
89 |
90 |     # question_type -> num_occurrences
91 | question_type_to_count = Counter()
92 | for ex in train_dset.entries:
93 | ans = ex["answer"]
94 | q_type = ans["question_type"]
95 | question_type_to_count[q_type] += 1
96 | if ans["labels"] is not None:
97 | for label, score in zip(ans["labels"], ans["scores"]):
98 | question_type_to_probs[q_type][label] += score
99 | question_type_to_prob_array = {}
100 |
101 | for q_type, count in question_type_to_count.items():
102 | prob_array = np.zeros(answer_voc_size, np.float32)
103 | for label, total_score in question_type_to_probs[q_type].items():
104 | prob_array[label] += total_score
105 | prob_array /= count
106 | question_type_to_prob_array[q_type] = prob_array
107 |
108 | for ds in [train_dset,eval_dset]:
109 | for ex in ds.entries:
110 | q_type = ex["answer"]["question_type"]
111 | ex["bias"] = question_type_to_prob_array[q_type]
112 |
113 |
114 | def main():
115 | args = parse_args()
116 | dataset=args.dataset
117 | # args.output=os.path.join('logs',args.output)
118 | if not os.path.isdir(args.output):
119 | utils.create_dir(args.output)
120 | else:
121 | if click.confirm('Exp directory already exists in {}. Erase?'
122 |                          .format(args.output), default=False):
123 | os.system('rm -r ' + args.output)
124 | utils.create_dir(args.output)
125 |
126 | else:
127 | os._exit(1)
128 |
129 | if dataset=='cpv1':
130 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl')
131 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val':
132 | dictionary = Dictionary.load_from_file('data/dictionary.pkl')
133 |
134 | print("Building train dataset...")
135 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset,
136 | cache_image_features=args.cache_features)
137 |
138 | print("Building test dataset...")
139 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset,
140 | cache_image_features=args.cache_features)
141 |
142 | get_bias(train_dset,eval_dset)
143 |
144 |
145 | # Build the model using the original constructor
146 | constructor = 'build_%s' % args.model
147 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda()
148 | if dataset=='cpv1':
149 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy')
150 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val':
151 | model.w_emb.init_embedding('data/glove6b_init_300d.npy')
152 |
153 | # Add the loss_fn based our arguments
154 | if args.debias == "bias_product":
155 | model.debias_loss_fn = BiasProduct()
156 | elif args.debias == "none":
157 | model.debias_loss_fn = Plain()
158 | elif args.debias == "reweight":
159 | model.debias_loss_fn = ReweightByInvBias()
160 | elif args.debias == "learned_mixin":
161 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty)
162 | elif args.debias == "focal":
163 | model.debias_loss_fn = Focal()
164 | else:
165 | raise RuntimeError(args.mode)
166 |
167 |
168 | with open('util/qid2type_%s.json'%args.dataset,'r') as f:
169 | qid2type=json.load(f)
170 | model=model.cuda()
171 | batch_size = args.batch_size
172 |
173 | torch.manual_seed(args.seed)
174 | torch.cuda.manual_seed(args.seed)
175 | torch.backends.cudnn.benchmark = True
176 |
177 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0)
178 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0)
179 |
180 | print("Starting training...")
181 | train(model, train_loader, eval_loader, args,qid2type)
182 |
183 | if __name__ == '__main__':
184 | main()
185 |
186 |
187 |
188 |
189 |
190 |
191 |
--------------------------------------------------------------------------------
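
The bias computed by `get_bias` above is just the per-question-type expected score of each answer. A tiny worked example with made-up numbers (the question type "is this" seen twice; "yes" scored 1.0 and 0.6, "no" scored 0.3):

```
import numpy as np
from collections import Counter, defaultdict

question_type_to_probs = defaultdict(Counter)
question_type_to_probs["is this"][0] += 1.0   # label 0 = "yes" (hypothetical index)
question_type_to_probs["is this"][0] += 0.6
question_type_to_probs["is this"][1] += 0.3   # label 1 = "no" (hypothetical index)
question_type_to_count = Counter({"is this": 2})

prob_array = np.zeros(10, np.float32)         # toy answer_voc_size of 10
for label, total_score in question_type_to_probs["is this"].items():
    prob_array[label] += total_score
prob_array /= question_type_to_count["is this"]
# prob_array[0] == 0.8 and prob_array[1] == 0.15: every "is this" entry gets this
# array as ex["bias"], the answer prior that the debiasing losses ensemble against
```
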
/css/main_introd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import cPickle as pickle
4 | from collections import defaultdict, Counter
5 | from os.path import dirname, join
6 | import os
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | import numpy as np
12 |
13 | from dataset import Dictionary, VQAFeatureDataset
14 | import base_model_introd as base_model
15 | from train_introd import train
16 | import utils
17 | import click
18 |
19 | from vqa_debias_loss_functions import *
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method")
24 |
25 | # Arguments we added
26 | parser.add_argument(
27 | '--cache_features', default=True,
28 | help="Cache image features in RAM. Makes things much faster, "
29 | "especially if the filesystem is slow, but requires at least 48gb of RAM")
30 | parser.add_argument(
31 | '--dataset', default='cpv2',
32 | choices=["v2", "cpv2", "cpv1", "cpv2val"],
33 | help="Run on VQA-2.0 instead of VQA-CP 2.0"
34 | )
35 | parser.add_argument(
36 | '-p', "--entropy_penalty", default=0.36, type=float,
37 | help="Entropy regularizer weight for the learned_mixin model")
38 | parser.add_argument(
39 | '--mode', default="updn",
40 | choices=["updn", "q_debias","v_debias","q_v_debias"],
41 | help="Kind of ensemble loss to use")
42 | parser.add_argument(
43 | '--debias', default="learned_mixin",
44 | choices=["learned_mixin_rw2", "learned_mixin_rw", "learned_mixin", "reweight", "bias_product", "none",'focal'],
45 | help="Kind of ensemble loss to use")
46 | parser.add_argument(
47 | '--topq', type=int,default=1,
48 | choices=[1,2,3],
49 | help="num of words to be masked in questio")
50 | parser.add_argument(
51 | '--keep_qtype', default=True,
52 | help="keep qtype or not")
53 | parser.add_argument(
54 | '--topv', type=int,default=1,
55 | choices=[1,3,5,-1],
56 | help="num of object bbox to be masked in image")
57 | parser.add_argument(
58 | '--top_hint',type=int, default=9,
59 | choices=[9,18,27,36],
60 | help="num of hint")
61 | parser.add_argument(
62 | '--qvp', type=int,default=0,
63 | choices=[0,1,2,3,4,5,6,7,8,9,10],
64 | help="ratio of q_bias and v_bias")
65 | parser.add_argument(
66 | '--eval_each_epoch', default=True,
67 | help="Evaluate every epoch, instead of at the end")
68 |
69 |     # Arguments from the original model; we leave these at their defaults, except
70 |     # --epochs, set to 30 since the model maxes out its performance on VQA 2.0 well before then
71 | parser.add_argument('--epochs', type=int, default=30)
72 | parser.add_argument('--num_hid', type=int, default=1024)
73 | parser.add_argument('--model', type=str, default='baseline0_newatt')
74 | parser.add_argument('--source', type=str, default='./logs/vqacp2/css/')
75 | parser.add_argument('--output', type=str, default='logs/exp0')
76 | parser.add_argument('--batch_size', type=int, default=512)
77 | parser.add_argument('--seed', type=int, default=1111, help='random seed')
78 | args = parser.parse_args()
79 | return args
80 |
81 | def get_bias(train_dset,eval_dset):
82 | # Compute the bias:
83 | # The bias here is just the expected score for each answer/question type
84 | answer_voc_size = train_dset.num_ans_candidates
85 |
86 | # question_type -> answer -> total score
87 | question_type_to_probs = defaultdict(Counter)
88 |
89 |     # question_type -> num_occurrences
90 | question_type_to_count = Counter()
91 | for ex in train_dset.entries:
92 | ans = ex["answer"]
93 | q_type = ans["question_type"]
94 | question_type_to_count[q_type] += 1
95 | if ans["labels"] is not None:
96 | for label, score in zip(ans["labels"], ans["scores"]):
97 | question_type_to_probs[q_type][label] += score
98 | question_type_to_prob_array = {}
99 |
100 | for q_type, count in question_type_to_count.items():
101 | prob_array = np.zeros(answer_voc_size, np.float32)
102 | for label, total_score in question_type_to_probs[q_type].items():
103 | prob_array[label] += total_score
104 | prob_array /= count
105 | question_type_to_prob_array[q_type] = prob_array
106 |
107 | for ds in [train_dset,eval_dset]:
108 | for ex in ds.entries:
109 | q_type = ex["answer"]["question_type"]
110 | ex["bias"] = question_type_to_prob_array[q_type]
111 |
112 |
113 | def main():
114 | args = parse_args()
115 | dataset=args.dataset
116 | # args.output=os.path.join('logs',args.output)
117 | if not os.path.isdir(args.output):
118 | utils.create_dir(args.output)
119 | else:
120 |         if click.confirm('Exp directory already exists in {}. Erase?'
121 |                          .format(args.output), default=False):
122 | os.system('rm -r ' + args.output)
123 | utils.create_dir(args.output)
124 |
125 | else:
126 | os._exit(1)
127 |
128 | if dataset=='cpv1':
129 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl')
130 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val':
131 | dictionary = Dictionary.load_from_file('data/dictionary.pkl')
132 |
133 | print("Building train dataset...")
134 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset,
135 | cache_image_features=args.cache_features)
136 |
137 | print("Building test dataset...")
138 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset,
139 | cache_image_features=args.cache_features)
140 |
141 | get_bias(train_dset,eval_dset)
142 |
143 | # Build the model using the original constructor
144 | constructor = 'build_%s' % args.model
145 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda()
146 | if dataset=='cpv1':
147 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy')
148 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val':
149 | model.w_emb.init_embedding('data/glove6b_init_300d.npy')
150 |
151 | model_student = getattr(base_model, constructor)(train_dset, args.num_hid).cuda()
152 | if dataset=='cpv1':
153 | model_student.w_emb.init_embedding('data/glove6b_init_300d_v1.npy')
154 | elif dataset=='cpv2' or dataset=='v2' or dataset=='cpv2val':
155 | model_student.w_emb.init_embedding('data/glove6b_init_300d.npy')
156 |
157 |     state_dict = torch.load(join(args.source, "model.pth"))  # pretrained teacher checkpoint
158 | model.debias_loss_fn = LearnedMixinKD()
159 | model.load_state_dict(state_dict, strict=False)
160 |
161 | model_student.debias_loss_fn = PlainKD()
162 |     model.train(False)  # the teacher stays in eval mode; only the student is optimized
163 |
164 | with open('util/qid2type_%s.json'%args.dataset,'r') as f:
165 | qid2type=json.load(f)
166 | model=model.cuda()
167 | model_student=model_student.cuda()
168 | batch_size = args.batch_size
169 |
170 | torch.manual_seed(args.seed)
171 | torch.cuda.manual_seed(args.seed)
172 | torch.backends.cudnn.benchmark = True
173 |
174 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0)
175 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0)
176 |
177 | print("Starting training...")
178 | train(model, model_student, train_loader, eval_loader, args,qid2type)
179 |
180 | if __name__ == '__main__':
181 | main()
182 |
183 |
184 |
185 |
186 |
187 |
188 |
--------------------------------------------------------------------------------
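`get_bias` above turns the training split into per-question-type answer priors: for each question type it averages every answer's soft score over all occurrences of that type, then attaches the resulting vector to every entry as `ex["bias"]`. A self-contained numeric sketch of the same computation (the two toy entries and the 3-answer vocabulary are made up):

    from collections import Counter, defaultdict
    import numpy as np

    entries = [  # two hypothetical "what color" annotations
        {"answer": {"question_type": "what color", "labels": [0], "scores": [1.0]}},
        {"answer": {"question_type": "what color", "labels": [0, 1], "scores": [0.6, 0.3]}},
    ]
    probs, counts = defaultdict(Counter), Counter()
    for ex in entries:
        ans = ex["answer"]
        counts[ans["question_type"]] += 1
        for label, score in zip(ans["labels"], ans["scores"]):
            probs[ans["question_type"]][label] += score

    prior = np.zeros(3, np.float32)   # answer vocabulary of size 3
    for label, total in probs["what color"].items():
        prior[label] = total
    prior /= counts["what color"]
    print(prior)                      # [0.8, 0.15, 0.0]
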
/css/tools/compute_softscore_val.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import sys
6 | import json
7 | import numpy as np
8 | import re
9 | import pickle as cPickle  # Python 3: cPickle was merged into pickle
10 |
11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12 | from dataset import Dictionary
13 | import utils
14 |
15 |
16 | contractions = {
17 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
18 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
19 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
20 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
21 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
22 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
23 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
24 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
25 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
26 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
27 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
28 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
29 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
30 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
31 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
32 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
33 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
34 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
35 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
36 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
37 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
38 | "someoned've": "someone'd've", "someone'dve": "someone'd've",
39 | "someonell": "someone'll", "someones": "someone's", "somethingd":
40 | "something'd", "somethingd've": "something'd've", "something'dve":
41 | "something'd've", "somethingll": "something'll", "thats":
42 | "that's", "thered": "there'd", "thered've": "there'd've",
43 | "there'dve": "there'd've", "therere": "there're", "theres":
44 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
45 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
46 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
47 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
48 | "weren't", "whatll": "what'll", "whatre": "what're", "whats":
49 | "what's", "whatve": "what've", "whens": "when's", "whered":
50 | "where'd", "wheres": "where's", "whereve": "where've", "whod":
51 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
52 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
53 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
54 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
55 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
56 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
57 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
58 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
59 | "you'll", "youre": "you're", "youve": "you've"
60 | }
61 |
62 | manual_map = { 'none': '0',
63 | 'zero': '0',
64 | 'one': '1',
65 | 'two': '2',
66 | 'three': '3',
67 | 'four': '4',
68 | 'five': '5',
69 | 'six': '6',
70 | 'seven': '7',
71 | 'eight': '8',
72 | 'nine': '9',
73 | 'ten': '10'}
74 | articles = ['a', 'an', 'the']
75 | period_strip = re.compile(r"(?<!\d)(\.)(?!\d)")
76 | comma_strip = re.compile(r"(\d)(\,)(\d)")
77 | punct = [';', r"/", '[', ']', '"', '{', '}',
78 | '(', ')', '=', '+', '\\', '_', '-',
79 | '>', '<', '@', '`', ',', '?', '!']
80 |
81 |
82 | def get_score(occurences):
83 | if occurences == 0:
84 | return 0
85 | elif occurences == 1:
86 | return 0.3
87 | elif occurences == 2:
88 | return 0.6
89 | elif occurences == 3:
90 | return 0.9
91 | else:
92 | return 1
93 |
94 |
95 | def process_punctuation(inText):
96 | outText = inText
97 | for p in punct:
98 | if (p + ' ' in inText or ' ' + p in inText) \
99 | or (re.search(comma_strip, inText) != None):
100 | outText = outText.replace(p, '')
101 | else:
102 | outText = outText.replace(p, ' ')
103 |     outText = period_strip.sub("", outText)  # re.UNICODE was mistakenly passed as the count argument
104 | return outText
105 |
106 |
107 | def process_digit_article(inText):
108 | outText = []
109 | tempText = inText.lower().split()
110 | for word in tempText:
111 |         word = manual_map.get(word, word)  # setdefault would pollute manual_map
112 | if word not in articles:
113 | outText.append(word)
114 | else:
115 | pass
116 | for wordId, word in enumerate(outText):
117 | if word in contractions:
118 | outText[wordId] = contractions[word]
119 | outText = ' '.join(outText)
120 | return outText
121 |
122 |
123 | def multiple_replace(text, wordDict):
124 | for key in wordDict:
125 | text = text.replace(key, wordDict[key])
126 | return text
127 |
128 |
129 | def preprocess_answer(answer):
130 | answer = process_digit_article(process_punctuation(answer))
131 | answer = answer.replace(',', '')
132 | return answer
133 |
134 |
135 | def filter_answers(answers_dset, min_occurence):
136 | """This will change the answer to preprocessed version
137 | """
138 | occurence = {}
139 | for ans_entry in answers_dset:
140 | gtruth = ans_entry['multiple_choice_answer']
141 | gtruth = preprocess_answer(gtruth)
142 | if gtruth not in occurence:
143 | occurence[gtruth] = set()
144 | occurence[gtruth].add(ans_entry['question_id'])
145 |     for answer in list(occurence.keys()):  # copy the keys; we pop while iterating
146 | if len(occurence[answer]) < min_occurence:
147 | occurence.pop(answer)
148 |
149 | print('Num of answers that appear >= %d times: %d' % (
150 | min_occurence, len(occurence)))
151 | return occurence
152 |
153 |
154 |
155 |
156 |
157 |
158 | def create_ans2label(occurence, name, cache_root):
159 | """Note that this will also create label2ans.pkl at the same time
160 |
161 | occurence: dict {answer -> whatever}
162 | name: prefix of the output file
163 | cache_root: str
164 | """
165 | ans2label = {}
166 | label2ans = []
167 | label = 0
168 | for answer in occurence:
169 | label2ans.append(answer)
170 | ans2label[answer] = label
171 | label += 1
172 |
173 | utils.create_dir(cache_root)
174 |
175 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
176 | cPickle.dump(ans2label, open(cache_file, 'wb'))
177 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
178 | cPickle.dump(label2ans, open(cache_file, 'wb'))
179 | return ans2label
180 |
181 |
182 | def compute_target(answers_dset, ans2label, name, cache_root):
183 | """Augment answers_dset with soft score as label
184 |
185 | ***answers_dset should be preprocessed***
186 |
187 | Write result into a cache file
188 | """
189 | target = []
190 | for ans_entry in answers_dset:
191 | answers = ans_entry['answers']
192 | answer_count = {}
193 | for answer in answers:
194 | answer_ = answer['answer']
195 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
196 |
197 | labels = []
198 | scores = []
199 | for answer in answer_count:
200 | if answer not in ans2label:
201 | continue
202 | labels.append(ans2label[answer])
203 | score = get_score(answer_count[answer])
204 | scores.append(score)
205 |
206 | label_counts = {}
207 | for k, v in answer_count.items():
208 | if k in ans2label:
209 | label_counts[ans2label[k]] = v
210 |
211 | target.append({
212 | 'question_id': ans_entry['question_id'],
213 | 'question_type': ans_entry['question_type'],
214 | 'image_id': ans_entry['image_id'],
215 | 'label_counts': label_counts,
216 | 'labels': labels,
217 | 'scores': scores
218 | })
219 |
220 | print(cache_root)
221 | utils.create_dir(cache_root)
222 | cache_file = os.path.join(cache_root, name+'_target.pkl')
223 | print(cache_file)
224 | with open(cache_file, 'wb') as f:
225 | cPickle.dump(target, f)
226 | return target
227 |
228 |
229 |
230 | def get_answer(qid, answers):
231 | for ans in answers:
232 | if ans['question_id'] == qid:
233 | return ans
234 |
235 |
236 | def get_question(qid, questions):
237 | for question in questions:
238 | if question['question_id'] == qid:
239 | return question
240 |
241 | def load_cp():
242 | train_answer_file = "data/vqacp2val/vqacp_v2_train_annotations.json"
243 | with open(train_answer_file) as f:
244 | train_answers = json.load(f) # ['annotations']
245 |
246 | val_answer_file = "data/vqacp2val/vqacp_v2_test_annotations.json"
247 | with open(val_answer_file) as f:
248 | val_answers = json.load(f) # ['annotations']
249 |
250 | occurence = filter_answers(train_answers, 9)
251 | ans2label = create_ans2label(occurence, 'trainval', "data/cpval-cache")
252 | compute_target(train_answers, ans2label, 'train', "data/cpval-cache")
253 | compute_target(val_answers, ans2label, 'val', "data/cpval-cache")
254 |
255 | def main():
256 | # parser = argparse.ArgumentParser("Dataset preprocessing")
257 | # parser.add_argument("dataset", choices=["cp_v2", "v2",'cp_v1'])
258 | # args = parser.parse_args()
259 | # if args.dataset == "v2":
260 | # load_v2()
261 | # elif args.dataset == "cp_v1":
262 | # load_cp_v1()
263 | # elif args.dataset=='cp_v2':
264 | # load_cp()
265 | load_cp()
266 |
267 |
268 |
269 | if __name__ == '__main__':
270 | main()
271 |
--------------------------------------------------------------------------------
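`get_score` above is the bucketed VQA-style soft-accuracy rule: 0, 0.3, 0.6, 0.9 for zero to three agreeing annotators, and 1.0 from four agreements on; `compute_target` applies it to each question's answer tallies. A quick illustration with a hypothetical count dictionary (the helper below just mirrors the buckets defined in the file):

    def get_score(occurrences):
        # same buckets as get_score in tools/compute_softscore_val.py
        return [0, 0.3, 0.6, 0.9][occurrences] if occurrences < 4 else 1.0

    answer_count = {'red': 6, 'maroon': 2, 'pink': 1}   # hypothetical tallies
    print({ans: get_score(n) for ans, n in answer_count.items()})
    # {'red': 1.0, 'maroon': 0.6, 'pink': 0.3}
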
/css/tools/create_dictionary.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import numpy as np
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | from dataset import Dictionary
8 |
9 |
10 |
11 |
12 | def create_dictionary(dataroot):
13 | dictionary = Dictionary()
14 | questions = []
15 | files = [
16 | 'v2_OpenEnded_mscoco_train2014_questions.json',
17 | 'v2_OpenEnded_mscoco_val2014_questions.json',
18 | 'v2_OpenEnded_mscoco_test2015_questions.json',
19 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json'
20 | ]
21 | for path in files:
22 | question_path = os.path.join(dataroot, path)
23 | qs = json.load(open(question_path))['questions']
24 | for q in qs:
25 | dictionary.tokenize(q['question'], True)
26 | dictionary.tokenize('wordmask',True)
27 | return dictionary
28 |
29 |
30 | def create_glove_embedding_init(idx2word, glove_file):
31 | word2emb = {}
32 | with open(glove_file, 'r') as f:
33 | entries = f.readlines()
34 | emb_dim = len(entries[0].split(' ')) - 1
35 | print('embedding dim is %d' % emb_dim)
36 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
37 |
38 | for entry in entries:
39 | vals = entry.split(' ')
40 | word = vals[0]
41 |         vals = list(map(float, vals[1:]))  # map is lazy in Python 3
42 | word2emb[word] = np.array(vals)
43 | for idx, word in enumerate(idx2word):
44 | if word not in word2emb:
45 | continue
46 | weights[idx] = word2emb[word]
47 | return weights, word2emb
48 |
49 |
50 | if __name__ == '__main__':
51 | d = create_dictionary('data')
52 | d.dump_to_file('data/dictionary.pkl')
53 |
54 | d = Dictionary.load_from_file('data/dictionary.pkl')
55 | emb_dim = 300
56 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim
57 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file)
58 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights)
59 |
--------------------------------------------------------------------------------
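`create_glove_embedding_init` above simply copies the GloVe vector of every dictionary word into a weight matrix and leaves out-of-vocabulary rows at zero. A tiny self-contained check of that logic (the two GloVe rows and the words are invented):

    import numpy as np

    glove_lines = ["cat 0.1 0.2 0.3", "dog 0.4 0.5 0.6"]   # fake GloVe rows
    idx2word = ["cat", "unseen", "dog"]

    emb_dim = len(glove_lines[0].split(' ')) - 1
    weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
    word2emb = {line.split(' ')[0]: np.array(list(map(float, line.split(' ')[1:])))
                for line in glove_lines}
    for idx, word in enumerate(idx2word):
        if word in word2emb:
            weights[idx] = word2emb[word]
    print(weights[1])   # all zeros: "unseen" has no GloVe vector
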
/css/tools/create_dictionary_v1.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import numpy as np
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | from dataset import Dictionary
8 |
9 |
10 | def create_dictionary(dataroot):
11 | dictionary = Dictionary()
12 | questions = []
13 | files = [
14 | 'OpenEnded_mscoco_train2014_questions.json',
15 | 'OpenEnded_mscoco_val2014_questions.json',
16 | 'OpenEnded_mscoco_test2015_questions.json',
17 | 'OpenEnded_mscoco_test-dev2015_questions.json'
18 | ]
19 | for path in files:
20 | question_path = os.path.join(dataroot, path)
21 | qs = json.load(open(question_path))['questions']
22 | for q in qs:
23 | dictionary.tokenize(q['question'], True)
24 | return dictionary
25 |
26 |
27 | def create_glove_embedding_init(idx2word, glove_file):
28 | word2emb = {}
29 | with open(glove_file, 'r') as f:
30 | entries = f.readlines()
31 | emb_dim = len(entries[0].split(' ')) - 1
32 | print('embedding dim is %d' % emb_dim)
33 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
34 |
35 | for entry in entries:
36 | vals = entry.split(' ')
37 | word = vals[0]
38 |         vals = list(map(float, vals[1:]))  # map is lazy in Python 3
39 | word2emb[word] = np.array(vals)
40 | for idx, word in enumerate(idx2word):
41 | if word not in word2emb:
42 | continue
43 | weights[idx] = word2emb[word]
44 | return weights, word2emb
45 |
46 |
47 | if __name__ == '__main__':
48 | d = create_dictionary('data')
49 | d.dump_to_file('data/dictionary_v1.pkl')
50 |
51 | d = Dictionary.load_from_file('data/dictionary_v1.pkl')
52 | emb_dim = 300
53 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim
54 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file)
55 | np.save('data/glove6b_init_%dd_v1.npy' % emb_dim, weights)
56 |
--------------------------------------------------------------------------------
/css/tools/download.sh:
--------------------------------------------------------------------------------
1 | ## Script for downloading data
2 |
3 | # GloVe Vectors
4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip
5 | unzip data/glove.6B.zip -d data/glove
6 | rm data/glove.6B.zip
7 |
8 | # VQA-CP2
9 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json
10 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json
11 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json
12 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json
13 |
14 | # VQA-CP1
15 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json
16 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json
17 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json
18 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json
19 |
20 | # VQA-V2
21 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
22 | unzip data/v2_Questions_Train_mscoco.zip -d data
23 | rm data/v2_Questions_Train_mscoco.zip
24 |
25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
26 | unzip data/v2_Questions_Val_mscoco.zip -d data
27 | rm data/v2_Questions_Val_mscoco.zip
28 |
29 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip
30 | unzip data/v2_Questions_Test_mscoco.zip -d data
31 | rm data/v2_Questions_Test_mscoco.zip
32 |
33 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
34 | unzip data/v2_Annotations_Train_mscoco.zip -d data
35 | rm data/v2_Annotations_Train_mscoco.zip
36 |
37 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
38 | unzip data/v2_Annotations_Val_mscoco.zip -d data
39 | rm data/v2_Annotations_Val_mscoco.zip
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/css/tools/process.sh:
--------------------------------------------------------------------------------
1 | # Process data
2 |
3 | python tools/create_dictionary.py
4 | python tools/create_dictionary_v1.py
5 | python tools/compute_softscore.py v2
6 | python tools/compute_softscore.py cp_v1
7 | python tools/compute_softscore.py cp_v2
8 |
9 |
--------------------------------------------------------------------------------
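Note that process.sh only builds the standard v2 / cp_v1 / cp_v2 caches; the `cpv2val` split accepted by `--dataset cpv2val` is produced separately with `python tools/compute_softscore_val.py`, which hardcodes the `data/vqacp2val` annotation paths and the `data/cpval-cache` output directory.
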
/css/train_introd.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | import time
5 | from os.path import join
6 |
7 | import torch
8 | import torch.nn as nn
9 | import utils
10 | from torch.autograd import Variable
11 | import numpy as np
12 | from tqdm import tqdm
13 | import random
14 | import copy
15 | from torch.nn import functional as F
16 |
17 |
18 | def compute_score_with_logits(logits, labels):
19 | logits = torch.argmax(logits, 1)
20 | one_hots = torch.zeros(*labels.size()).cuda()
21 | one_hots.scatter_(1, logits.view(-1, 1), 1)
22 | scores = (one_hots * labels)
23 | return scores
24 |
25 | def train(model_teacher, model, train_loader, eval_loader,args,qid2type):
26 | dataset=args.dataset
27 | num_epochs=args.epochs
28 | mode=args.mode
29 | run_eval=args.eval_each_epoch
30 | output=args.output
31 | optim = torch.optim.Adamax(model.parameters())
32 | logger = utils.Logger(os.path.join(output, 'log.txt'))
33 | total_step = 0
34 | best_eval_score = 0
35 |
36 | logsigmoid = torch.nn.LogSigmoid()
37 |
38 |     KLDivLoss = torch.nn.KLDivLoss(reduction='none')  # unused: the KL term below is computed manually
39 |
40 | for epoch in range(num_epochs):
41 | total_loss = 0
42 | train_score = 0
43 |
44 | t = time.time()
45 | for i, (v, q, a, b, _, _, _, _) in tqdm(enumerate(train_loader), ncols=100,
46 | desc="Epoch %d" % (epoch + 1), total=len(train_loader)):
47 |
48 | total_step += 1
49 |
50 |
51 | #########################################
52 | v = Variable(v).cuda().requires_grad_()
53 | q = Variable(q).cuda()
54 | # q_mask=Variable(q_mask).cuda()
55 | a = Variable(a).cuda()
56 | b = Variable(b).cuda()
57 | # hintscore = Variable(hintscore).cuda()
58 | # type_mask=Variable(type_mask).float().cuda()
59 | # notype_mask=Variable(notype_mask).float().cuda()
60 | #########################################
61 |
62 | pred_nie, pred_te, _, _ = model_teacher(v, q, a, b, None)
63 | pred, _, loss_ce, _ = model(v, q, a, b, None)
64 |
65 | aa = a/torch.clamp(a.sum(1, keepdim=True), min=1e-24)
66 | loss_te = -(aa*logsigmoid(pred_te) + (1-aa)*logsigmoid(-pred_te)).sum(1)
67 | loss_nie = -(aa*logsigmoid(pred_nie) + (1-aa)*logsigmoid(-pred_nie)).sum(1)
68 |
69 | loss_te = torch.clamp(loss_te, min=1e-12)
70 | loss_nie = torch.clamp(loss_nie, min=1e-12)
71 | w = loss_nie/(loss_te+loss_nie)
72 |             w = w.clone().detach()  # stop gradients: the blending weight is treated as a constant
73 |
74 | # KL
75 | prob_nie = F.softmax(pred_nie, -1).clone().detach()
76 | loss_kl = - prob_nie*F.log_softmax(pred, -1)
77 | loss_kl = loss_kl.sum(1)
78 |
79 | loss = (w*loss_kl + (1-w)*loss_ce).mean()
80 |
81 | if (loss != loss).any():
82 | raise ValueError("NaN loss")
83 | loss.backward()
84 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
85 | optim.step()
86 | optim.zero_grad()
87 |
88 | total_loss += loss.item() * q.size(0)
89 | batch_score = compute_score_with_logits(pred, a.data).sum()
90 | train_score += batch_score
91 |
92 | total_loss /= len(train_loader.dataset)
93 | train_score = 100 * train_score / len(train_loader.dataset)
94 |
95 | if run_eval:
96 | model.train(False)
97 | results = evaluate(model, eval_loader, qid2type)
98 | results["epoch"] = epoch + 1
99 | results["step"] = total_step
100 | results["train_loss"] = total_loss
101 | results["train_score"] = train_score
102 |
103 | model.train(True)
104 |
105 | eval_score = results["score"]
106 | bound = results["upper_bound"]
107 | yn = results['score_yesno']
108 | other = results['score_other']
109 | num = results['score_number']
110 |
111 | logger.write('epoch %d, time: %.2f' % (epoch + 1, time.time() - t))
112 | logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))
113 |
114 | if run_eval:
115 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))
116 | logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num))
117 |
118 | if eval_score > best_eval_score:
119 | model_path = os.path.join(output, 'model.pth')
120 | torch.save(model.state_dict(), model_path)
121 | best_eval_score = eval_score
122 |
123 |
124 | def evaluate(model, dataloader, qid2type):
125 | score = 0
126 | upper_bound = 0
127 | score_yesno = 0
128 | score_number = 0
129 | score_other = 0
130 | total_yesno = 0
131 | total_number = 0
132 | total_other = 0
133 |
134 | for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
135 | v = Variable(v, requires_grad=False).cuda()
136 | q = Variable(q, requires_grad=False).cuda()
137 | pred, _, _, _ = model(v, q, None, None, None)
138 | batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1)
139 | score += batch_score.sum()
140 | upper_bound += (a.max(1)[0]).sum()
141 | qids = qids.detach().cpu().int().numpy()
142 | for j in range(len(qids)):
143 | qid = qids[j]
144 | typ = qid2type[str(qid)]
145 | if typ == 'yes/no':
146 | score_yesno += batch_score[j]
147 | total_yesno += 1
148 | elif typ == 'other':
149 | score_other += batch_score[j]
150 | total_other += 1
151 | elif typ == 'number':
152 | score_number += batch_score[j]
153 | total_number += 1
154 | else:
155 |                 print('Unknown question type for qid %s: %s' % (qid, typ))
156 |
157 |
158 | score = score / len(dataloader.dataset)
159 | upper_bound = upper_bound / len(dataloader.dataset)
160 | score_yesno /= total_yesno
161 | score_other /= total_other
162 | score_number /= total_number
163 |
164 | results = dict(
165 | score=score,
166 | upper_bound=upper_bound,
167 | score_yesno=score_yesno,
168 | score_other=score_other,
169 | score_number=score_number,
170 | )
171 | return results
172 |
--------------------------------------------------------------------------------
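The core of `train` above is the introspective blending weight: for each example, the teacher's debiased (NIE) and total-effect (TE) branch losses set `w = loss_nie / (loss_te + loss_nie)`, and the student minimizes `w * loss_kl + (1 - w) * loss_ce`. When the NIE branch fits the labels better than the TE branch, `w` shrinks and the hard-label CE term dominates; when it fits worse, the student distills more from the teacher's NIE distribution. A numeric sketch of the weighting on made-up logits (one example, three answers):

    import torch

    logsig = torch.nn.LogSigmoid()
    a = torch.tensor([[0.0, 1.0, 0.0]])          # soft target
    pred_te = torch.tensor([[2.0, -1.0, 0.0]])   # teacher TE logits: fit badly
    pred_nie = torch.tensor([[-1.0, 2.0, 0.0]])  # teacher NIE logits: fit well

    aa = a / torch.clamp(a.sum(1, keepdim=True), min=1e-24)
    loss_te = -(aa * logsig(pred_te) + (1 - aa) * logsig(-pred_te)).sum(1)
    loss_nie = -(aa * logsig(pred_nie) + (1 - aa) * logsig(-pred_nie)).sum(1)
    w = loss_nie / (loss_te + loss_nie)
    print(w)   # ~0.22: the NIE branch fits better, so the CE term dominates
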
/css/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import errno
4 | import os
5 | import numpy as np
6 | # from PIL import Image
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | EPS = 1e-7
12 |
13 |
14 | def assert_eq(real, expected):
15 |     # restore the check: "assert real == real" was a no-op that always passed
16 |     assert real == expected, '%s (true) vs %s (expected)' % (real, expected)
17 |
18 |
19 | def assert_array_eq(real, expected):
20 | assert (np.abs(real-expected) < EPS).all(), \
21 | '%s (true) vs %s (expected)' % (real, expected)
22 |
23 |
24 | def load_folder(folder, suffix):
25 | imgs = []
26 | for f in sorted(os.listdir(folder)):
27 | if f.endswith(suffix):
28 | imgs.append(os.path.join(folder, f))
29 | return imgs
30 |
31 |
32 | # def load_imageid(folder):
33 | # images = load_folder(folder, 'jpg')
34 | # img_ids = set()
35 | # for img in images:
36 | # img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1])
37 | # img_ids.add(img_id)
38 | # return img_ids
39 |
40 |
41 | # def pil_loader(path):
42 | # with open(path, 'rb') as f:
43 | # with Image.open(f) as img:
44 | # return img.convert('RGB')
45 |
46 |
47 | def weights_init(m):
48 | """custom weights initialization."""
49 | cname = m.__class__
50 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d:
51 | m.weight.data.normal_(0.0, 0.02)
52 | elif cname == nn.BatchNorm2d:
53 | m.weight.data.normal_(1.0, 0.02)
54 | m.bias.data.fill_(0)
55 | else:
56 | print('%s is not initialized.' % cname)
57 |
58 |
59 | def init_net(net, net_file):
60 | if net_file:
61 | net.load_state_dict(torch.load(net_file))
62 | else:
63 | net.apply(weights_init)
64 |
65 |
66 | def create_dir(path):
67 | if not os.path.exists(path):
68 | try:
69 | os.makedirs(path)
70 | except OSError as exc:
71 | if exc.errno != errno.EEXIST:
72 | raise
73 |
74 |
75 | class Logger(object):
76 | def __init__(self, output_name):
77 | dirname = os.path.dirname(output_name)
78 |         if dirname and not os.path.exists(dirname):
79 |             os.makedirs(dirname)  # handle nested paths; bare filenames have dirname ''
80 |
81 | self.log_file = open(output_name, 'w')
82 | self.infos = {}
83 |
84 | def append(self, key, val):
85 | vals = self.infos.setdefault(key, [])
86 | vals.append(val)
87 |
88 | def log(self, extra_msg=''):
89 | msgs = [extra_msg]
90 |         for key, vals in self.infos.items():  # iteritems() is Python 2 only
91 | msgs.append('%s %.6f' % (key, np.mean(vals)))
92 | msg = '\n'.join(msgs)
93 | self.log_file.write(msg + '\n')
94 | self.log_file.flush()
95 | self.infos = {}
96 | return msg
97 |
98 | def write(self, msg):
99 | self.log_file.write(msg + '\n')
100 | self.log_file.flush()
101 | print(msg)
102 |
--------------------------------------------------------------------------------
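`Logger` above is the only logging facility the training scripts use: `write()` tees a line to the log file and stdout, while `append()`/`log()` accumulate values and flush their running means. A short usage sketch (the path is illustrative):

    from utils import Logger

    logger = Logger('logs/exp0/log.txt')   # creates the log directory if missing
    logger.write('epoch 1, time: 12.34')   # goes to log.txt and stdout
    logger.append('loss', 0.71)
    logger.append('loss', 0.65)
    logger.log()                           # writes "loss 0.680000" and resets
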
/css/vqa_debias_loss_functions.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, defaultdict, Counter
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import torch
7 | import inspect
8 |
9 |
10 | def convert_sigmoid_logits_to_binary_logprobs(logits):
11 | """computes log(sigmoid(logits)), log(1-sigmoid(logits))"""
12 | log_prob = -F.softplus(-logits)
13 | log_one_minus_prob = -logits + log_prob
14 | return log_prob, log_one_minus_prob
15 |
16 |
17 | def elementwise_logsumexp(a, b):
18 | """computes log(exp(x) + exp(b))"""
19 | return torch.max(a, b) + torch.log1p(torch.exp(-torch.abs(a - b)))
20 |
21 |
22 | def renormalize_binary_logits(a, b):
23 | """Normalize so exp(a) + exp(b) == 1"""
24 | norm = elementwise_logsumexp(a, b)
25 | return a - norm, b - norm
26 |
27 |
28 | class DebiasLossFn(nn.Module):
29 | """General API for our loss functions"""
30 |
31 | def forward(self, hidden, logits, bias, labels):
32 | """
33 | :param hidden: [batch, n_hidden] hidden features from the last layer in the model
34 | :param logits: [batch, n_answers_options] sigmoid logits for each answer option
35 | :param bias: [batch, n_answers_options]
36 | bias probabilities for each answer option between 0 and 1
37 | :param labels: [batch, n_answers_options]
38 | scores for each answer option, between 0 and 1
39 | :return: Scalar loss
40 | """
41 | raise NotImplementedError()
42 |
43 | def to_json(self):
44 | """Get a json representation of this loss function.
45 |
46 | We construct this by looking up the __init__ args
47 | """
48 | cls = self.__class__
49 | init = cls.__init__
50 | if init is object.__init__:
51 | return [] # No init args
52 |
53 |         init_signature = inspect.getfullargspec(init)  # getargspec was removed in Python 3.11
54 |         if init_signature.varargs is not None:
55 |             raise NotImplementedError("varargs not supported")
56 |         if init_signature.varkw is not None:
57 |             raise NotImplementedError("keyword arguments not supported")
58 | args = [x for x in init_signature.args if x != "self"]
59 | out = OrderedDict()
60 | out["name"] = cls.__name__
61 | for key in args:
62 | out[key] = getattr(self, key)
63 | return out
64 |
65 |
66 | class Plain(DebiasLossFn):
67 | def forward(self, hidden, logits, bias, labels):
68 | loss = F.binary_cross_entropy_with_logits(logits, labels)
69 |
70 | loss *= labels.size(1)
71 |
72 | return loss
73 |
74 | class PlainKD(DebiasLossFn):
75 | def forward(self, hidden, logits, bias, labels):
76 | loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
77 |
78 | # loss *= labels.size(1)
79 | loss = loss.sum(1)
80 |
81 | # return loss
82 | return None, loss
83 |
84 |
85 | class Focal(DebiasLossFn):
86 | def forward(self, hidden, logits, bias, labels):
87 |         # down-weight answers that the bias branch already predicts confidently
88 |         focal_logits = torch.log(F.softmax(logits, dim=1) + 1e-5) * (1 - F.softmax(bias, dim=1)) ** 2
89 |         loss = F.binary_cross_entropy_with_logits(focal_logits, labels)
90 |         loss *= labels.size(1)
91 | return loss
92 |
93 | class ReweightByInvBias(DebiasLossFn):
94 | def forward(self, hidden, logits, bias, labels):
95 | # Manually compute the binary cross entropy since the old version of torch always aggregates
96 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
97 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob)
98 | weights = (1 - bias)
99 | loss *= weights # Apply the weights
100 | return loss.sum() / weights.sum()
101 |
102 |
103 | class BiasProduct(DebiasLossFn):
104 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0):
105 | """
106 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
107 | :param smooth_init: How to initialize `a`
108 | :param constant_smooth: Constant to add to the bias to smooth it
109 | """
110 | super(BiasProduct, self).__init__()
111 | self.constant_smooth = constant_smooth
112 | self.smooth_init = smooth_init
113 | self.smooth = smooth
114 | if smooth:
115 | self.smooth_param = torch.nn.Parameter(
116 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
117 | else:
118 | self.smooth_param = None
119 |
120 | def forward(self, hidden, logits, bias, labels):
121 | smooth = self.constant_smooth
122 | if self.smooth:
123 |             smooth += torch.sigmoid(self.smooth_param)  # F.sigmoid is deprecated
124 |
125 | # Convert the bias into log-space, with a factor for both the
126 | # binary outputs for each answer option
127 | bias_lp = torch.log(bias + smooth)
128 | bias_l_inv = torch.log1p(-bias + smooth)
129 |
130 |         # Convert the logits into log-space with the same format
131 |         log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
132 |
133 |
134 | # Add the bias
135 | log_prob += bias_lp
136 | log_one_minus_prob += bias_l_inv
137 |
138 | # Re-normalize the factors in logspace
139 | log_prob, log_one_minus_prob = renormalize_binary_logits(log_prob, log_one_minus_prob)
140 |
141 | # Compute the binary cross entropy
142 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)
143 | return loss
144 |
145 |
146 | class LearnedMixin(DebiasLossFn):
147 | def __init__(self, w, smooth=True, smooth_init=-1, constant_smooth=0.0):
148 | """
149 | :param w: Weight of the entropy penalty
150 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
151 | :param smooth_init: How to initialize `a`
152 | :param constant_smooth: Constant to add to the bias to smooth it
153 | """
154 | super(LearnedMixin, self).__init__()
155 | self.w = w
156 | # self.w=0
157 | self.smooth_init = smooth_init
158 | self.constant_smooth = constant_smooth
159 | self.bias_lin = torch.nn.Linear(1024, 1)
160 | self.smooth = smooth
161 | if self.smooth:
162 | self.smooth_param = torch.nn.Parameter(
163 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
164 | else:
165 | self.smooth_param = None
166 |
167 | def forward(self, hidden, logits, bias, labels):
168 | factor = self.bias_lin.forward(hidden) # [batch, 1]
169 | factor = F.softplus(factor)
170 |
171 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2]
172 |
173 | # Smooth
174 | bias += self.constant_smooth
175 | if self.smooth:
176 |             soften_factor = torch.sigmoid(self.smooth_param)
177 | bias = bias + soften_factor.unsqueeze(1)
178 |
179 | bias = torch.log(bias) # Convert to logspace
180 |
181 | # Scale by the factor
182 | # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2]
183 | bias = bias * factor.unsqueeze(1)
184 |
185 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
186 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2)
187 |
188 | # Add the bias in
189 | logits = bias + log_probs
190 |
191 | # Renormalize to get log probabilities
192 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1])
193 |
194 | # Compute loss
195 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)
196 |
197 | # Re-normalized version of the bias
198 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1])
199 | bias_logprob = bias - bias_norm.unsqueeze(2)
200 |
201 | # Compute and add the entropy penalty
202 | entropy = -(torch.exp(bias_logprob) * bias_logprob).sum(2).mean()
203 | return loss + self.w * entropy
204 |
205 |
206 | class LearnedMixinKD(DebiasLossFn):
207 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0):
208 | """
209 |         LearnedMixin without the entropy penalty (there is no w); forward also returns per-example log-odds for distillation.
210 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
211 | :param smooth_init: How to initialize `a`
212 | :param constant_smooth: Constant to add to the bias to smooth it
213 | """
214 | super(LearnedMixinKD, self).__init__()
215 | self.smooth_init = smooth_init
216 | self.constant_smooth = constant_smooth
217 | self.bias_lin = torch.nn.Linear(1024, 1)
218 | self.smooth = smooth
219 | if self.smooth:
220 | self.smooth_param = torch.nn.Parameter(
221 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
222 | else:
223 | self.smooth_param = None
224 |
225 | def forward(self, hidden, logits, bias, labels):
226 | factor = self.bias_lin.forward(hidden) # [batch, 1]
227 | factor = F.softplus(factor)
228 |
229 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2]
230 |
231 | # Smooth
232 | bias += self.constant_smooth
233 | if self.smooth:
234 |             soften_factor = torch.sigmoid(self.smooth_param)
235 | bias = bias + soften_factor.unsqueeze(1)
236 |
237 | bias = torch.log(bias) # Convert to logspace
238 |
239 | # Scale by the factor
240 | # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2]
241 | bias = bias * factor.unsqueeze(1)
242 |
243 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
244 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2)
245 |
246 | # Add the bias in
247 | logits = bias + log_probs
248 |
249 | # Renormalize to get log probabilities
250 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1])
251 |
252 | # Compute loss
253 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)
254 |
255 | # Re-normalized version of the bias
256 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1])
257 | bias_logprob = bias - bias_norm.unsqueeze(2)
258 |
259 |         prob_all = torch.exp(log_prob)
260 |         # convert the debiased probabilities back to log-odds: log(p / (1 - p))
261 |         p = torch.clamp(1 - prob_all, min=1e-12)
262 |         p = torch.clamp(prob_all / p, min=1e-12)
263 |
264 |         logits_all = torch.log(p)
265 |
266 | return logits_all, loss
--------------------------------------------------------------------------------
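The three helpers at the top of `vqa_debias_loss_functions.py` form the log-space toolkit every loss class builds on. A quick numeric check that they behave as their docstrings claim (sigmoid probabilities recovered from logits, and renormalized pairs summing to one):

    import torch
    from vqa_debias_loss_functions import (
        convert_sigmoid_logits_to_binary_logprobs,
        elementwise_logsumexp,
        renormalize_binary_logits,
    )

    logits = torch.tensor([-2.0, 0.0, 3.0])
    log_p, log_1mp = convert_sigmoid_logits_to_binary_logprobs(logits)
    print(torch.allclose(log_p.exp(), torch.sigmoid(logits)))          # True
    print(torch.allclose(log_p.exp() + log_1mp.exp(), torch.ones(3)))  # True

    a, b = torch.tensor([0.1, 2.0]), torch.tensor([1.5, -0.5])
    print(torch.allclose(elementwise_logsumexp(a, b), torch.logaddexp(a, b)))  # True
    na, nb = renormalize_binary_logits(a, b)
    print(na.exp() + nb.exp())   # tensor([1., 1.])
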
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/images/architecture.png
--------------------------------------------------------------------------------
/images/introd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuleiniu/introd/a40407c7efee9c34e3d4270d7947f5be2f926413/images/introd.png
--------------------------------------------------------------------------------