├── .gitignore ├── LICENSE ├── README.md ├── config ├── ban_vqa.json ├── ban_vqa_cp.json ├── butd_vqa.json ├── mutan_vqa.json └── parser.py ├── dataset.py ├── dataset_cp_v2.py ├── eval.py ├── main.py ├── misc └── regat_overview.jpg ├── model ├── __init__.py ├── bc.py ├── bilinear_attention.py ├── classifier.py ├── counting.py ├── fc.py ├── fusion.py ├── graph_att.py ├── graph_att_layer.py ├── language_model.py ├── position_emb.py ├── regat.py └── relation_encoder.py ├── tools ├── __init__.py ├── compute_softscore.py ├── create_dictionary.py ├── create_embedding.py ├── download.sh ├── environment.yml └── process.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | saved_models/ 3 | pretrained_models/ 4 | gt_logits/ 5 | .vscode/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Relation-aware Graph Attention Network for Visual Question Answering
2 |
3 | This repository is the implementation of [Relation-aware Graph Attention Network for Visual Question Answering](https://arxiv.org/abs/1903.12314).
4 |
5 | ![Overview of ReGAT](misc/regat_overview.jpg)
6 |
7 | This repository is based on and inspired by @hengyuan-hu's [work](https://github.com/hengyuan-hu/bottom-up-attention-vqa) and @Jin-Hwa Kim's [work](https://github.com/jnhwkim/ban-vqa). We sincerely thank them for sharing their code.
8 |
9 | ## Prerequisites
10 |
11 | You may need a machine with 4 GPUs (16GB memory each) and PyTorch v1.0.1 for Python 3.
12 |
13 | 1. Install [PyTorch](http://pytorch.org/) with CUDA 10.0 and Python 3.7.
14 | 2. Install [h5py](http://docs.h5py.org/en/latest/build.html).
15 | 3. Install [block.bootstrap.pytorch](https://github.com/Cadene/block.bootstrap.pytorch).
16 |
17 | If you are using miniconda, you can install all the prerequisites with `tools/environment.yml`.
18 |
19 | ## Data
20 |
21 | Our implementation uses the pretrained features from [bottom-up-attention](https://github.com/peteanderson80/bottom-up-attention), with the adaptive 10-100 features per image. It also uses the GloVe vectors and the Visual Genome question-answer pairs. For your convenience, the script below downloads the preprocessed data.
22 |
23 | ```bash
24 | source tools/download.sh
25 | ```
26 |
27 | In addition to the data, this script also downloads several pretrained models.
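To sanity-check the download, you can open one of the feature files with `h5py` (a minimal sketch; the dataset keys below are the ones read by `dataset.py`, and the path assumes the default `data/` layout shown next):

```python
import h5py

# inspect the adaptive bottom-up features for the val split
with h5py.File("data/Bottom-up-features-adaptive/val.hdf5", "r") as hf:
    print(list(hf.keys()))  # expect image_features, spatial_features, image_bb, pos_boxes, ...
    print(hf["image_features"].shape)
```
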
In the end, the data folder and pretrained_models folder should be organized as shown below: 28 | 29 | ```bash 30 | ├── data 31 | │ ├── Answers 32 | │ │ ├── v2_mscoco_train2014_annotations.json 33 | │ │ └── v2_mscoco_val2014_annotations.json 34 | │ ├── Bottom-up-features-adaptive 35 | │ │ ├── train.hdf5 36 | │ │ ├── val.hdf5 37 | │ │ └── test2015.hdf5 38 | │ ├── Bottom-up-features-fixed 39 | │ │ ├── train36.hdf5 40 | │ │ ├── val36.hdf5 41 | │ │ └── test2015_36.hdf5 42 | │ ├── cache 43 | │ │ ├── cp_v2_test_target.pkl 44 | │ │ ├── cp_v2_train_target.pkl 45 | │ │ ├── train_target.pkl 46 | │ │ ├── val_target.pkl 47 | │ │ ├── trainval_ans2label.pkl 48 | │ │ └── trainval_label2ans.pkl 49 | │ ├── cp_v2_annotations 50 | │ │ ├── vqacp_v2_test_annotations.json 51 | │ │ └── vqacp_v2_train_annotations.json 52 | │ ├── cp_v2_questions 53 | │ │ ├── vqacp_v2_test_questions.json 54 | │ │ └── vqacp_v2_train_questions.json 55 | │ ├── glove 56 | │ │ ├── dictionary.pkl 57 | │ │ ├── glove6b_init_300d.npy 58 | │ │ └──- glove6b.300d.txt 59 | │ ├── imgids 60 | │ │ ├── test2015_36_imgid2idx.pkl 61 | │ │ ├── test2015_ids.pkl 62 | │ │ ├── test2015_imgid2idx.pkl 63 | │ │ ├── train36_imgid2idx.pkl 64 | │ │ ├── train_ids.pkl 65 | │ │ ├── train_imgid2idx.pkl 66 | │ │ ├── val36_imgid2idx.pkl 67 | │ │ ├── val_ids.pkl 68 | │ │ └── val_imgid2idx.pkl 69 | │ ├── Questions 70 | │ │ ├── v2_OpenEnded_mscoco_test-dev2015_questions.json 71 | │ │ ├── v2_OpenEnded_mscoco_test2015_questions.json 72 | │ │ ├── v2_OpenEnded_mscoco_train2014_questions.json 73 | │ │ └── v2_OpenEnded_mscoco_val2014_questions.json 74 | │ ├── visualGenome 75 | │ │ ├── image_data.json 76 | │ │ └── question_answers.json 77 | ``` 78 | 79 | ```bash 80 | ├── pretrained_models (each model folder contains hps.json and model.pth) 81 | │ ├── regat_implicit 82 | │ │ ├── ban_1_implicit_vqa_196 83 | │ │ ├── ban_4_implicit_vqa_cp_4422 84 | │ │ ├── butd_implicit_vqa_6371 85 | │ │ └── mutan_implicit_vqa_2632 86 | │ ├── regat_semantic 87 | │ │ ├── ban_1_semantic_vqa_7971 88 | │ │ ├── ban_4_semantic_vqa_cp_9960 89 | │ │ ├── butd_semantic_vqa_244 90 | │ │ └── mutan_semantic_vqa_2711 91 | │ ├── regat_spatial 92 | │ │ ├── ban_1_spatial_vqa_1687 93 | │ │ ├── ban_4_spatial_vqa_cp_4488 94 | │ │ ├── butd_spatial_vqa_5942 95 | │ │ └── mutan_spatial_vqa_3842 96 | ``` 97 | 98 | ## Training 99 | 100 | ```bash 101 | python3 main.py --config config/butd_vqa.json 102 | ``` 103 | 104 | ## Evaluating 105 | 106 | ```bash 107 | # take ban_1_implicit_vqa_196 as an example 108 | # to evaluate cp_v2 performance, need to use --dataset cp_v2 --split test 109 | python3 eval.py --output_folder pretrained_models/regat_implicit/ban_1_implicit_vqa_196 110 | ``` 111 | 112 | ## Citation 113 | 114 | If you use this code as part of any published research, we'd really appreciate it if you could cite the following paper: 115 | 116 | ```text 117 | @article{li2019relation, 118 | title={Relation-aware Graph Attention Network for Visual Question Answering}, 119 | author={Li, Linjie and Gan, Zhe and Cheng, Yu and Liu, Jingjing}, 120 | journal={ICCV}, 121 | year={2019} 122 | } 123 | ``` 124 | 125 | ## License 126 | 127 | MIT License 128 | -------------------------------------------------------------------------------- /config/ban_vqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": 20, 3 | "base_lr": 0.001, 4 | "lr_decay_start": 15, 5 | "lr_decay_rate": 0.25, 6 | "lr_decay_step": 2, 7 | "lr_decay_based_on_val": true, 8 | "grad_accu_steps": 1, 9 | 
"grad_clip": 0.25, 10 | "weight_decay": 0, 11 | "batch_size": 64, 12 | "output": "saved_models/", 13 | "log_interval": -1, 14 | "dataset": "vqa", 15 | "data_folder": "./data", 16 | "use_both": false, 17 | "use_vg": false, 18 | "adaptive": true, 19 | "relation_type": "implicit", 20 | "fusion": "ban", 21 | "tfidf": true, 22 | "op": "c", 23 | "num_hid": 1024, 24 | "ban_gamma": 1, 25 | "num_heads": 16, 26 | "imp_pos_emb_dim": 64, 27 | "dir_num": 2, 28 | "spa_label_num": 11, 29 | "sem_label_num": 15, 30 | "relation_dim": 1024, 31 | "nongt_dim": 20, 32 | "num_steps": 1, 33 | "residual_connection": true, 34 | "label_bias": false 35 | } -------------------------------------------------------------------------------- /config/ban_vqa_cp.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": 20, 3 | "base_lr": 0.001, 4 | "lr_decay_start": 15, 5 | "lr_decay_rate": 0.25, 6 | "lr_decay_step": 2, 7 | "lr_decay_based_on_val": true, 8 | "grad_accu_steps": 1, 9 | "grad_clip": 0.25, 10 | "weight_decay": 0, 11 | "batch_size": 64, 12 | "output": "saved_models/", 13 | "log_interval": -1, 14 | "dataset": "vqa_cp", 15 | "data_folder": "./data", 16 | "use_both": false, 17 | "use_vg": false, 18 | "adaptive": true, 19 | "relation_type": "implicit", 20 | "fusion": "ban", 21 | "tfidf": true, 22 | "op": "c", 23 | "num_hid": 1024, 24 | "ban_gamma": 4, 25 | "num_heads": 16, 26 | "imp_pos_emb_dim": 64, 27 | "dir_num": 2, 28 | "spa_label_num": 11, 29 | "sem_label_num": 15, 30 | "relation_dim": 1024, 31 | "nongt_dim": 20, 32 | "num_steps": 1, 33 | "residual_connection": true, 34 | "label_bias": false 35 | } -------------------------------------------------------------------------------- /config/butd_vqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "adaptive": true, 3 | "base_lr": 1e-3, 4 | "batch_size": 64, 5 | "checkpoint": "", 6 | "data_folder": "./data", 7 | "dataset": "vqa", 8 | "dir_num": 2, 9 | "epochs": 20, 10 | "fusion": "butd", 11 | "imp_pos_emb_dim": 64, 12 | "num_heads": 16, 13 | "spa_label_num": 11, 14 | "sem_label_num": 15, 15 | "lr_decay_start": 15, 16 | "lr_decay_based_on_val": true, 17 | "lr_decay_rate": 0.25, 18 | "nongt_dim": 20, 19 | "num_hid": 1024, 20 | "num_steps": 1, 21 | "output": "saved_models/", 22 | "relation_dim": 1024, 23 | "relation_type": "implicit", 24 | "residual_connection": true, 25 | "label_bias": false, 26 | "tfidf": true 27 | } -------------------------------------------------------------------------------- /config/mutan_vqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": 20, 3 | "base_lr": 0.001, 4 | "lr_decay_start": 15, 5 | "lr_decay_rate": 0.25, 6 | "lr_decay_step": 2, 7 | "lr_decay_based_on_val": true, 8 | "grad_accu_steps": 1, 9 | "grad_clip": 0.25, 10 | "weight_decay": 0, 11 | "batch_size": 64, 12 | "output": "saved_models/", 13 | "log_interval": -1, 14 | "dataset": "vqa", 15 | "data_folder": "./data", 16 | "adaptive": true, 17 | "relation_type": "implicit", 18 | "fusion": "mutan", 19 | "tfidf": true, 20 | "op": "c", 21 | "num_hid": 1024, 22 | "mutan_gamma": 2, 23 | "num_heads": 16, 24 | "imp_pos_emb_dim": 64, 25 | "dir_num": 2, 26 | "spa_label_num": 11, 27 | "sem_label_num": 15, 28 | "relation_dim": 1024, 29 | "nongt_dim": 20, 30 | "num_steps": 1, 31 | "residual_connection": true, 32 | "label_bias": false 33 | } 34 | -------------------------------------------------------------------------------- /config/parser.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 10 | """ 11 | import json 12 | import sys 13 | 14 | 15 | def parse_with_config(parser): 16 | args = parser.parse_args() 17 | if args.config is not None: 18 | config_args = json.load(open(args.config)) 19 | override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] 20 | if arg.startswith('--')} 21 | for k, v in config_args.items(): 22 | if k not in override_keys: 23 | setattr(args, k, v) 24 | del args.config 25 | return args 26 | 27 | 28 | class Struct(object): 29 | def __init__(self, dict_): 30 | self.__dict__.update(dict_) -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | from __future__ import print_function 10 | import os 11 | import json 12 | import pickle 13 | import numpy as np 14 | import utils 15 | import h5py 16 | import torch 17 | from torch.utils.data import Dataset 18 | import tools.compute_softscore 19 | import itertools 20 | import math 21 | 22 | # TODO: merge dataset_cp_v2.py with dataset.py 23 | 24 | COUNTING_ONLY = False 25 | 26 | 27 | # Following Trott et al. (ICLR 2018) 28 | # Interpretable Counting for Visual Question Answering 29 | def is_howmany(q, a, label2ans): 30 | if 'how many' in q.lower() or \ 31 | ('number of' in q.lower() and 'number of the' not in q.lower()) or \ 32 | 'amount of' in q.lower() or \ 33 | 'count of' in q.lower(): 34 | if a is None or answer_filter(a, label2ans): 35 | return True 36 | else: 37 | return False 38 | else: 39 | return False 40 | 41 | 42 | def answer_filter(answers, label2ans, max_num=10): 43 | for ans in answers['labels']: 44 | if label2ans[ans].isdigit() and max_num >= int(label2ans[ans]): 45 | return True 46 | return False 47 | 48 | 49 | class Dictionary(object): 50 | def __init__(self, word2idx=None, idx2word=None): 51 | if word2idx is None: 52 | word2idx = {} 53 | if idx2word is None: 54 | idx2word = [] 55 | self.word2idx = word2idx 56 | self.idx2word = idx2word 57 | 58 | @property 59 | def ntoken(self): 60 | return len(self.word2idx) 61 | 62 | @property 63 | def padding_idx(self): 64 | return len(self.word2idx) 65 | 66 | def tokenize(self, sentence, add_word): 67 | sentence = sentence.lower() 68 | sentence = sentence.replace(',', '')\ 69 | .replace('?', '').replace('\'s', ' \'s') 70 | words = sentence.split() 71 | tokens = [] 72 | if add_word: 73 | for w in words: 74 | tokens.append(self.add_word(w)) 75 | else: 76 | for w in words: 77 | # the least frequent word (`bebe`) as UNK 78 | # for Visual Genome dataset 79 | tokens.append(self.word2idx.get(w, self.padding_idx-1)) 80 | return tokens 81 | 82 | def dump_to_file(self, path): 83 | pickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 84 | print('dictionary dumped to %s' % path) 85 | 86 | @classmethod 87 | def load_from_file(cls, path): 88 | print('loading dictionary from %s' % path) 89 | word2idx, idx2word = pickle.load(open(path, 'rb')) 90 | d = 
cls(word2idx, idx2word) 91 | return d 92 | 93 | def add_word(self, word): 94 | if word not in self.word2idx: 95 | self.idx2word.append(word) 96 | self.word2idx[word] = len(self.idx2word) - 1 97 | return self.word2idx[word] 98 | 99 | def __len__(self): 100 | return len(self.idx2word) 101 | 102 | 103 | def _create_entry(img, question, answer): 104 | if answer is not None: 105 | answer.pop('image_id') 106 | answer.pop('question_id') 107 | entry = { 108 | 'question_id': question['question_id'], 109 | 'image_id': question['image_id'], 110 | 'image': img, 111 | 'question': question['question'], 112 | 'answer': answer} 113 | return entry 114 | 115 | 116 | def _load_dataset(dataroot, name, img_id2val, label2ans): 117 | """Load entries 118 | 119 | img_id2val: dict {img_id -> val} val can be used to 120 | retrieve image or features 121 | dataroot: root path of dataset 122 | name: 'train', 'val', 'test-dev2015', test2015' 123 | """ 124 | question_path = os.path.join( 125 | dataroot, 'Questions/v2_OpenEnded_mscoco_%s_questions.json' % 126 | (name + '2014' if 'test' != name[:4] else name)) 127 | questions = sorted(json.load(open(question_path))['questions'], 128 | key=lambda x: x['question_id']) 129 | # train, val 130 | if 'test' != name[:4]: 131 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 132 | answers = pickle.load(open(answer_path, 'rb')) 133 | answers = sorted(answers, key=lambda x: x['question_id']) 134 | 135 | utils.assert_eq(len(questions), len(answers)) 136 | entries = [] 137 | for question, answer in zip(questions, answers): 138 | utils.assert_eq(question['question_id'], answer['question_id']) 139 | utils.assert_eq(question['image_id'], answer['image_id']) 140 | img_id = question['image_id'] 141 | if not COUNTING_ONLY \ 142 | or is_howmany(question['question'], answer, label2ans): 143 | entries.append(_create_entry(img_id2val[img_id], 144 | question, answer)) 145 | # test2015 146 | else: 147 | entries = [] 148 | for question in questions: 149 | img_id = question['image_id'] 150 | if not COUNTING_ONLY \ 151 | or is_howmany(question['question'], None, None): 152 | entries.append(_create_entry(img_id2val[img_id], 153 | question, None)) 154 | 155 | return entries 156 | 157 | 158 | def _load_visualgenome(dataroot, name, img_id2val, label2ans, adaptive=True): 159 | """Load entries 160 | 161 | img_id2val: dict {img_id -> val} val can be used to 162 | retrieve image or features 163 | dataroot: root path of dataset 164 | name: 'train', 'val' 165 | """ 166 | question_path = os.path.join(dataroot, 167 | 'visualGenome/question_answers.json') 168 | image_data_path = os.path.join(dataroot, 169 | 'visualGenome/image_data.json') 170 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 171 | cache_path = os.path.join(dataroot, 'cache', 'vg_%s%s_target.pkl' % 172 | (name, '_adaptive' if adaptive else '')) 173 | 174 | if os.path.isfile(cache_path): 175 | entries = pickle.load(open(cache_path, 'rb')) 176 | else: 177 | entries = [] 178 | ans2label = pickle.load(open(ans2label_path, 'rb')) 179 | vgq = json.load(open(question_path, 'r')) 180 | # 108,077 images 181 | _vgv = json.load(open(image_data_path, 'r')) 182 | vgv = {} 183 | for _v in _vgv: 184 | if _v['coco_id']: 185 | vgv[_v['image_id']] = _v['coco_id'] 186 | # used image, used question, total question, out-of-split 187 | counts = [0, 0, 0, 0] 188 | for vg in vgq: 189 | coco_id = vgv.get(vg['id'], None) 190 | if coco_id is not None: 191 | counts[0] += 1 192 | img_idx = img_id2val.get(coco_id, None) 193 | if 
img_idx is None: 194 | counts[3] += 1 195 | for q in vg['qas']: 196 | counts[2] += 1 197 | _answer = tools.compute_softscore.preprocess_answer( 198 | q['answer']) 199 | label = ans2label.get(_answer, None) 200 | if label and img_idx: 201 | counts[1] += 1 202 | answer = { 203 | 'labels': [label], 204 | 'scores': [1.]} 205 | entry = { 206 | 'question_id': q['qa_id'], 207 | 'image_id': coco_id, 208 | 'image': img_idx, 209 | 'question': q['question'], 210 | 'answer': answer} 211 | if not COUNTING_ONLY \ 212 | or is_howmany(q['question'], answer, label2ans): 213 | entries.append(entry) 214 | 215 | print('Loading VisualGenome %s' % name) 216 | print('\tUsed COCO images: %d/%d (%.4f)' % 217 | (counts[0], len(_vgv), counts[0]/len(_vgv))) 218 | print('\tOut-of-split COCO images: %d/%d (%.4f)' % 219 | (counts[3], counts[0], counts[3]/counts[0])) 220 | print('\tUsed VG questions: %d/%d (%.4f)' % 221 | (counts[1], counts[2], counts[1]/counts[2])) 222 | with open(cache_path, 'wb') as f: 223 | pickle.dump(entries, open(cache_path, 'wb')) 224 | 225 | return entries 226 | 227 | 228 | def _find_coco_id(vgv, vgv_id): 229 | for v in vgv: 230 | if v['image_id'] == vgv_id: 231 | return v['coco_id'] 232 | return None 233 | 234 | 235 | class VQAFeatureDataset(Dataset): 236 | def __init__(self, name, dictionary, relation_type, dataroot='data', 237 | adaptive=False, pos_emb_dim=64, nongt_dim=36): 238 | super(VQAFeatureDataset, self).__init__() 239 | assert name in ['train', 'val', 'test-dev2015', 'test2015'] 240 | 241 | ans2label_path = os.path.join(dataroot, 'cache', 242 | 'trainval_ans2label.pkl') 243 | label2ans_path = os.path.join(dataroot, 'cache', 244 | 'trainval_label2ans.pkl') 245 | self.ans2label = pickle.load(open(ans2label_path, 'rb')) 246 | self.label2ans = pickle.load(open(label2ans_path, 'rb')) 247 | self.num_ans_candidates = len(self.ans2label) 248 | self.dictionary = dictionary 249 | self.relation_type = relation_type 250 | self.adaptive = adaptive 251 | prefix = '36' 252 | if 'test' in name: 253 | prefix = '_36' 254 | 255 | h5_dataroot = dataroot+"/Bottom-up-features-adaptive"\ 256 | if self.adaptive else dataroot+"/Bottom-up-features-fixed" 257 | imgid_dataroot = dataroot+"/imgids" 258 | 259 | self.img_id2idx = pickle.load( 260 | open(os.path.join(imgid_dataroot, '%s%s_imgid2idx.pkl' % 261 | (name, '' if self.adaptive else prefix)), 'rb')) 262 | 263 | h5_path = os.path.join(h5_dataroot, '%s%s.hdf5' % 264 | (name, '' if self.adaptive else prefix)) 265 | 266 | print('loading features from h5 file %s' % h5_path) 267 | with h5py.File(h5_path, 'r') as hf: 268 | self.features = np.array(hf.get('image_features')) 269 | self.normalized_bb = np.array(hf.get('spatial_features')) 270 | self.bb = np.array(hf.get('image_bb')) 271 | if "semantic_adj_matrix" in hf.keys() \ 272 | and self.relation_type == "semantic": 273 | self.semantic_adj_matrix = np.array( 274 | hf.get('semantic_adj_matrix')) 275 | print("Loaded semantic adj matrix from file...", 276 | self.semantic_adj_matrix.shape) 277 | else: 278 | self.semantic_adj_matrix = None 279 | print("Setting semantic adj matrix to None...") 280 | if "image_adj_matrix" in hf.keys()\ 281 | and self.relation_type == "spatial": 282 | self.spatial_adj_matrix = np.array(hf.get('image_adj_matrix')) 283 | print("Loaded spatial adj matrix from file...", 284 | self.spatial_adj_matrix.shape) 285 | else: 286 | self.spatial_adj_matrix = None 287 | print("Setting spatial adj matrix to None...") 288 | 289 | self.pos_boxes = None 290 | if self.adaptive: 291 | self.pos_boxes = 
np.array(hf.get('pos_boxes')) 292 | self.entries = _load_dataset(dataroot, name, self.img_id2idx, 293 | self.label2ans) 294 | self.tokenize() 295 | 296 | self.tensorize() 297 | self.nongt_dim = nongt_dim 298 | self.emb_dim = pos_emb_dim 299 | self.v_dim = self.features.size(1 if self.adaptive else 2) 300 | self.s_dim = self.normalized_bb.size(1 if self.adaptive else 2) 301 | 302 | def tokenize(self, max_length=14): 303 | """Tokenizes the questions. 304 | 305 | This will add q_token in each entry of the dataset. 306 | -1 represent nil, and should be treated as padding_idx in embedding 307 | """ 308 | for entry in self.entries: 309 | tokens = self.dictionary.tokenize(entry['question'], False) 310 | tokens = tokens[:max_length] 311 | if len(tokens) < max_length: 312 | # Note here we pad to the back of the sentence 313 | padding = [self.dictionary.padding_idx] * \ 314 | (max_length - len(tokens)) 315 | tokens = tokens + padding 316 | utils.assert_eq(len(tokens), max_length) 317 | entry['q_token'] = tokens 318 | 319 | def tensorize(self): 320 | self.features = torch.from_numpy(self.features) 321 | self.normalized_bb = torch.from_numpy(self.normalized_bb) 322 | self.bb = torch.from_numpy(self.bb) 323 | if self.semantic_adj_matrix is not None: 324 | self.semantic_adj_matrix = torch.from_numpy( 325 | self.semantic_adj_matrix).double() 326 | if self.spatial_adj_matrix is not None: 327 | self.spatial_adj_matrix = torch.from_numpy( 328 | self.spatial_adj_matrix).double() 329 | if self.pos_boxes is not None: 330 | self.pos_boxes = torch.from_numpy(self.pos_boxes) 331 | 332 | for entry in self.entries: 333 | question = torch.from_numpy(np.array(entry['q_token'])) 334 | entry['q_token'] = question 335 | 336 | answer = entry['answer'] 337 | if answer is not None: 338 | labels = np.array(answer['labels']) 339 | scores = np.array(answer['scores'], dtype=np.float32) 340 | if len(labels): 341 | labels = torch.from_numpy(labels) 342 | scores = torch.from_numpy(scores) 343 | entry['answer']['labels'] = labels 344 | entry['answer']['scores'] = scores 345 | else: 346 | entry['answer']['labels'] = None 347 | entry['answer']['scores'] = None 348 | 349 | def __getitem__(self, index): 350 | entry = self.entries[index] 351 | raw_question = entry["question"] 352 | image_id = entry["image_id"] 353 | 354 | question = entry['q_token'] 355 | question_id = entry['question_id'] 356 | if self.spatial_adj_matrix is not None: 357 | spatial_adj_matrix = self.spatial_adj_matrix[entry["image"]] 358 | else: 359 | spatial_adj_matrix = torch.zeros(1).double() 360 | if self.semantic_adj_matrix is not None: 361 | semantic_adj_matrix = self.semantic_adj_matrix[entry["image"]] 362 | else: 363 | semantic_adj_matrix = torch.zeros(1).double() 364 | if not self.adaptive: 365 | # fixed number of bounding boxes 366 | features = self.features[entry['image']] 367 | normalized_bb = self.normalized_bb[entry['image']] 368 | bb = self.bb[entry["image"]] 369 | else: 370 | features = self.features[ 371 | self.pos_boxes[ 372 | entry['image']][0]:self.pos_boxes[entry['image']][1], :] 373 | normalized_bb = self.normalized_bb[ 374 | self.pos_boxes[ 375 | entry['image']][0]:self.pos_boxes[entry['image']][1], :] 376 | bb = self.bb[ 377 | self.pos_boxes[ 378 | entry['image']][0]:self.pos_boxes[entry['image']][1], :] 379 | 380 | answer = entry['answer'] 381 | if answer is not None: 382 | labels = answer['labels'] 383 | scores = answer['scores'] 384 | target = torch.zeros(self.num_ans_candidates) 385 | if labels is not None: 386 | target.scatter_(0, labels, 
scores) 387 | return features, normalized_bb, question, target,\ 388 | question_id, image_id, bb, spatial_adj_matrix,\ 389 | semantic_adj_matrix 390 | 391 | else: 392 | return features, normalized_bb, question, question_id,\ 393 | question_id, image_id, bb, spatial_adj_matrix,\ 394 | semantic_adj_matrix 395 | 396 | def __len__(self): 397 | return len(self.entries) 398 | 399 | 400 | class VisualGenomeFeatureDataset(Dataset): 401 | def __init__(self, name, features, normalized_bb, bb, 402 | spatial_adj_matrix, semantic_adj_matrix, dictionary, 403 | relation_type, dataroot='data', adaptive=False, 404 | pos_boxes=None, pos_emb_dim=64): 405 | super(VisualGenomeFeatureDataset, self).__init__() 406 | # do not use test split images! 407 | assert name in ['train', 'val'] 408 | print('loading Visual Genome data %s' % name) 409 | ans2label_path = os.path.join(dataroot, 'cache', 410 | 'trainval_ans2label.pkl') 411 | label2ans_path = os.path.join(dataroot, 'cache', 412 | 'trainval_label2ans.pkl') 413 | self.ans2label = pickle.load(open(ans2label_path, 'rb')) 414 | self.label2ans = pickle.load(open(label2ans_path, 'rb')) 415 | self.num_ans_candidates = len(self.ans2label) 416 | 417 | self.dictionary = dictionary 418 | self.adaptive = adaptive 419 | 420 | self.img_id2idx = pickle.load( 421 | open(os.path.join(dataroot, 'imgids/%s%s_imgid2idx.pkl' % 422 | (name, '' if self.adaptive else '36')), 423 | 'rb')) 424 | self.bb = bb 425 | self.features = features 426 | self.normalized_bb = normalized_bb 427 | self.spatial_adj_matrix = spatial_adj_matrix 428 | self.semantic_adj_matrix = semantic_adj_matrix 429 | 430 | if self.adaptive: 431 | self.pos_boxes = pos_boxes 432 | 433 | self.entries = _load_visualgenome(dataroot, name, self.img_id2idx, 434 | self.label2ans, 435 | adaptive=self.adaptive) 436 | self.tokenize() 437 | self.tensorize() 438 | self.emb_dim = pos_emb_dim 439 | self.v_dim = self.features.size(1 if self.adaptive else 2) 440 | self.s_dim = self.normalized_bb.size(1 if self.adaptive else 2) 441 | 442 | def tokenize(self, max_length=14): 443 | """Tokenizes the questions. 444 | 445 | This will add q_token in each entry of the dataset. 
446 | -1 represent nil, and should be treated as padding_idx in embedding 447 | """ 448 | for entry in self.entries: 449 | tokens = self.dictionary.tokenize(entry['question'], False) 450 | tokens = tokens[:max_length] 451 | if len(tokens) < max_length: 452 | # Note here we pad in front of the sentence 453 | padding = [self.dictionary.padding_idx] * \ 454 | (max_length - len(tokens)) 455 | tokens = tokens + padding 456 | utils.assert_eq(len(tokens), max_length) 457 | entry['q_token'] = tokens 458 | 459 | def tensorize(self): 460 | for entry in self.entries: 461 | question = torch.from_numpy(np.array(entry['q_token'])) 462 | entry['q_token'] = question 463 | 464 | answer = entry['answer'] 465 | labels = np.array(answer['labels']) 466 | scores = np.array(answer['scores'], dtype=np.float32) 467 | if len(labels): 468 | labels = torch.from_numpy(labels) 469 | scores = torch.from_numpy(scores) 470 | entry['answer']['labels'] = labels 471 | entry['answer']['scores'] = scores 472 | else: 473 | entry['answer']['labels'] = None 474 | entry['answer']['scores'] = None 475 | 476 | def __getitem__(self, index): 477 | entry = self.entries[index] 478 | raw_question = entry["question"] 479 | image_id = entry["image_id"] 480 | question = entry['q_token'] 481 | question_id = entry['question_id'] 482 | answer = entry['answer'] 483 | if self.spatial_adj_matrix is not None: 484 | spatial_adj_matrix = self.spatial_adj_matrix[entry["image"]] 485 | else: 486 | spatial_adj_matrix = torch.zeros(1).double() 487 | if self.semantic_adj_matrix is not None: 488 | semantic_adj_matrix = self.semantic_adj_matrix[entry["image"]] 489 | else: 490 | semantic_adj_matrix = torch.zeros(1).double() 491 | if self.adaptive: 492 | features = self.features[ 493 | self.pos_boxes[ 494 | entry['image']][0]:self.pos_boxes[entry['image']][1], :] 495 | normalized_bb = self.normalized_bb[ 496 | self.pos_boxes[ 497 | entry['image']][0]:self.pos_boxes[entry['image']][1], :] 498 | bb = self.bb[self.pos_boxes[ 499 | entry['image']][0]:self.pos_boxes[entry['image']][1], :] 500 | else: 501 | features = self.features[entry['image']] 502 | normalized_bb = self.normalized_bb[entry['image']] 503 | bb = self.bb[entry['image']] 504 | 505 | labels = answer['labels'] 506 | scores = answer['scores'] 507 | target = torch.zeros(self.num_ans_candidates) 508 | if labels is not None: 509 | target.scatter_(0, labels, scores) 510 | return features, normalized_bb, question, target, raw_question,\ 511 | image_id, bb, spatial_adj_matrix, semantic_adj_matrix 512 | 513 | def __len__(self): 514 | return len(self.entries) 515 | 516 | 517 | def tfidf_from_questions(names, dictionary, dataroot='data', 518 | target=['vqa', 'vg']): 519 | # rows, cols for uncoalesce sparse matrix 520 | inds = [[], []] 521 | df = dict() 522 | N = len(dictionary) 523 | 524 | def populate(inds, df, text): 525 | tokens = dictionary.tokenize(text, True) 526 | for t in tokens: 527 | df[t] = df.get(t, 0) + 1 528 | combin = list(itertools.combinations(tokens, 2)) 529 | for c in combin: 530 | if c[0] < N: 531 | inds[0].append(c[0]) 532 | inds[1].append(c[1]) 533 | if c[1] < N: 534 | inds[0].append(c[1]) 535 | inds[1].append(c[0]) 536 | 537 | # VQA 2.0 538 | if 'vqa' in target: 539 | for name in names: 540 | assert name in ['train', 'val', 'test-dev2015', 'test2015'] 541 | question_path = os.path.join( 542 | dataroot, 'Questions/v2_OpenEnded_mscoco_%s_questions.json' % 543 | (name + '2014' if 'test' != name[:4] else name)) 544 | questions = json.load(open(question_path))['questions'] 545 | 546 | for 
question in questions: 547 | populate(inds, df, question['question']) 548 | 549 | # Visual Genome 550 | if 'vg' in target: 551 | question_path = os.path.join(dataroot, 'visualGenome', 552 | 'question_answers.json') 553 | vgq = json.load(open(question_path, 'r')) 554 | for vg in vgq: 555 | for q in vg['qas']: 556 | populate(inds, df, q['question']) 557 | 558 | # TF-IDF 559 | vals = np.ones((len(inds[1]))) 560 | for idx, col in enumerate(inds[1]): 561 | assert df[col] >= 1, 'document frequency should be greater than zero!' 562 | vals[col] /= df[col] 563 | 564 | # Make stochastic matrix 565 | def normalize(inds, vals): 566 | z = dict() 567 | for row, val in zip(inds[0], vals): 568 | z[row] = z.get(row, 0) + val 569 | for idx, row in enumerate(inds[0]): 570 | vals[idx] /= z[row] 571 | return vals 572 | 573 | vals = normalize(inds, vals) 574 | 575 | tfidf = torch.sparse.FloatTensor(torch.LongTensor(inds), 576 | torch.FloatTensor(vals)) 577 | tfidf = tfidf.coalesce() 578 | 579 | # Latent word embeddings 580 | emb_dim = 300 581 | glove_file = dataroot+'/glove/glove.6B.%dd.txt' % emb_dim 582 | weights, word2emb = utils.create_glove_embedding_init( 583 | dictionary.idx2word[N:], glove_file) 584 | print('tf-idf stochastic matrix (%d x %d) is generated.' % (tfidf.size(0), 585 | tfidf.size(1))) 586 | 587 | return tfidf, weights 588 | 589 | 590 | # VisualGenome Train 591 | # Used COCO images: 51487/108077 (0.4764) 592 | # Out-of-split COCO images: 17464/51487 (0.3392) 593 | # Used VG questions: 325311/726932 (0.4475) 594 | 595 | # VisualGenome Val 596 | # Used COCO images: 51487/108077 (0.4764) 597 | # Out-of-split COCO images: 34023/51487 (0.6608) 598 | # Used VG questions: 166409/726932 (0.2289) 599 | -------------------------------------------------------------------------------- /dataset_cp_v2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 
10 | """ 11 | from __future__ import print_function 12 | import os 13 | import json 14 | import pickle 15 | import numpy as np 16 | import utils 17 | import h5py 18 | import torch 19 | from torch.utils.data import Dataset 20 | import tools.compute_softscore 21 | from dataset import is_howmany 22 | 23 | COUNTING_ONLY = False 24 | 25 | 26 | def _create_entry(img, question, answer): 27 | if answer is not None: 28 | answer.pop('image_id') 29 | answer.pop('question_id') 30 | entry = { 31 | 'question_id': question['question_id'], 32 | 'image_id': question['image_id'], 33 | 'image': img, 34 | 'coco_split': question["coco_split"], 35 | 'question': question['question'], 36 | 'answer': answer} 37 | return entry 38 | 39 | 40 | def _load_dataset(dataroot, name, coco_train_img_id2val, coco_val_img_id2val, 41 | label2ans): 42 | """Load entries 43 | 44 | coco_train_img_id2val/coco_val_img_id2val: 45 | dict {img_id -> val} val can be used to retrieve image or features 46 | dataroot: root path of dataset 47 | name: 'train', 'val' 48 | """ 49 | question_path = os.path.join( 50 | dataroot, 'cp_v2_questions/vqacp_v2_%s_questions.json' % name) 51 | questions = sorted(json.load(open(question_path)), 52 | key=lambda x: x['question_id']) 53 | answer_path = os.path.join(dataroot, 'cache', 'cp_v2_%s_target.pkl' % name) 54 | answers = pickle.load(open(answer_path, 'rb')) 55 | answers = sorted(answers, key=lambda x: x['question_id']) 56 | 57 | utils.assert_eq(len(questions), len(answers)) 58 | entries = [] 59 | for question, answer in zip(questions, answers): 60 | utils.assert_eq(question['question_id'], answer['question_id']) 61 | utils.assert_eq(question['image_id'], answer['image_id']) 62 | img_id = question['image_id'] 63 | coco_split = question["coco_split"] 64 | index = coco_train_img_id2val[img_id]\ 65 | if coco_split == "train2014" else coco_val_img_id2val[img_id] 66 | if not COUNTING_ONLY \ 67 | or is_howmany(question['question'], answer, label2ans): 68 | entries.append(_create_entry(index, question, answer)) 69 | return entries 70 | 71 | 72 | class Image_Feature_Loader(): 73 | def __init__(self, coco_split, relation_type, dataroot='data', 74 | adaptive=True): 75 | super(Image_Feature_Loader, self).__init__() 76 | assert coco_split in ['train', 'val'] 77 | self.adaptive = adaptive 78 | self.relation_type = relation_type 79 | prefix = '36' 80 | 81 | self.img_id2idx = pickle.load( 82 | open(os.path.join(dataroot, 'imgids/%s%s_imgid2idx.pkl' % 83 | (coco_split, '' if self.adaptive else prefix)), 84 | 'rb')) 85 | h5_dataroot = dataroot+"/Bottom-up-features-adaptive" \ 86 | if self.adaptive else dataroot+"/Bottom-up-features-fixed" 87 | h5_path = os.path.join(h5_dataroot, 88 | '%s%s.hdf5' % (coco_split, 89 | '' if self.adaptive else prefix)) 90 | 91 | print('loading features from h5 file %s' % h5_path) 92 | with h5py.File(h5_path, 'r') as hf: 93 | self.features = np.array(hf.get('image_features')) 94 | self.spatials = np.array(hf.get('spatial_features')) 95 | self.bb = np.array(hf.get('image_bb')) 96 | if "semantic_adj_matrix" in hf.keys() \ 97 | and self.relation_type == "semantic": 98 | self.semantic_adj_matrix = np.array( 99 | hf.get('semantic_adj_matrix')) 100 | print("Loaded semantic adj matrix from file...", 101 | self.semantic_adj_matrix.shape) 102 | else: 103 | self.semantic_adj_matrix = None 104 | print("Setting semantic adj matrix to None...") 105 | if "image_adj_matrix" in hf.keys() \ 106 | and self.relation_type == "spatial": 107 | self.spatial_adj_matrix = np.array(hf.get('image_adj_matrix')) 108 | 
print("Loaded spatial adj matrix from file...", 109 | self.spatial_adj_matrix.shape) 110 | else: 111 | self.spatial_adj_matrix = None 112 | print("Setting spatial adj matrix to None...") 113 | self.pos_boxes = None 114 | if self.adaptive: 115 | self.pos_boxes = np.array(hf.get('pos_boxes')) 116 | self.tensorize() 117 | 118 | def tensorize(self): 119 | self.features = torch.from_numpy(self.features) 120 | self.spatials = torch.from_numpy(self.spatials) 121 | self.bb = torch.from_numpy(self.bb) 122 | if self.semantic_adj_matrix is not None: 123 | self.semantic_adj_matrix = torch.from_numpy( 124 | self.semantic_adj_matrix).double() 125 | if self.spatial_adj_matrix is not None: 126 | self.spatial_adj_matrix = torch.from_numpy( 127 | self.spatial_adj_matrix).double() 128 | if self.pos_boxes is not None: 129 | self.pos_boxes = torch.from_numpy(self.pos_boxes) 130 | 131 | 132 | class VQA_cp_Dataset(Dataset): 133 | def __init__(self, name, dictionary, coco_train_features, 134 | coco_val_features, dataroot='data', adaptive=False, 135 | pos_emb_dim=64): 136 | super(VQA_cp_Dataset, self).__init__() 137 | assert name in ['train', 'test'] 138 | 139 | ans2label_path = os.path.join(dataroot, 'cache', 140 | 'trainval_ans2label.pkl') 141 | label2ans_path = os.path.join(dataroot, 'cache', 142 | 'trainval_label2ans.pkl') 143 | self.ans2label = pickle.load(open(ans2label_path, 'rb')) 144 | self.label2ans = pickle.load(open(label2ans_path, 'rb')) 145 | self.num_ans_candidates = len(self.ans2label) 146 | self.dictionary = dictionary 147 | self.adaptive = adaptive 148 | self.relation_type = coco_train_features.relation_type 149 | self.coco_train_features = coco_train_features 150 | self.coco_val_features = coco_val_features 151 | self.entries = _load_dataset(dataroot, name, 152 | self.coco_train_features.img_id2idx, 153 | self.coco_val_features.img_id2idx, 154 | self.label2ans) 155 | self.tokenize() 156 | self.tensorize() 157 | self.emb_dim = pos_emb_dim 158 | self.v_dim = self.coco_train_features.features.size(1 if self.adaptive 159 | else 2) 160 | self.s_dim = self.coco_train_features.spatials.size(1 if self.adaptive 161 | else 2) 162 | 163 | def tokenize(self, max_length=14): 164 | """Tokenizes the questions. 165 | 166 | This will add q_token in each entry of the dataset. 
167 | -1 represent nil, and should be treated as padding_idx in embedding 168 | """ 169 | for entry in self.entries: 170 | tokens = self.dictionary.tokenize(entry['question'], False) 171 | tokens = tokens[:max_length] 172 | if len(tokens) < max_length: 173 | # Note here we pad to the back of the sentence 174 | padding = [self.dictionary.padding_idx] * \ 175 | (max_length - len(tokens)) 176 | tokens = tokens + padding 177 | utils.assert_eq(len(tokens), max_length) 178 | entry['q_token'] = tokens 179 | 180 | def tensorize(self): 181 | for entry in self.entries: 182 | question = torch.from_numpy(np.array(entry['q_token'])) 183 | entry['q_token'] = question 184 | 185 | answer = entry['answer'] 186 | if answer is not None: 187 | labels = np.array(answer['labels']) 188 | scores = np.array(answer['scores'], dtype=np.float32) 189 | if len(labels): 190 | labels = torch.from_numpy(labels) 191 | scores = torch.from_numpy(scores) 192 | entry['answer']['labels'] = labels 193 | entry['answer']['scores'] = scores 194 | else: 195 | entry['answer']['labels'] = None 196 | entry['answer']['scores'] = None 197 | 198 | def __getitem__(self, index): 199 | entry = self.entries[index] 200 | raw_question = entry["question"] 201 | image_id = entry["image_id"] 202 | coco_split = entry["coco_split"] 203 | 204 | question = entry['q_token'] 205 | question_id = entry['question_id'] 206 | if "train" in coco_split: 207 | coco_features = self.coco_train_features 208 | elif "val" in coco_split: 209 | coco_features = self.coco_val_features 210 | else: 211 | print("Unknown coco split: %s" % coco_split) 212 | 213 | if coco_features.spatial_adj_matrix is not None: 214 | spatial_adj_matrix = coco_features.spatial_adj_matrix[ 215 | entry["image"]] 216 | else: 217 | spatial_adj_matrix = torch.zeros(1).double() 218 | if coco_features.semantic_adj_matrix is not None: 219 | semantic_adj_matrix = coco_features.semantic_adj_matrix[ 220 | entry["image"]] 221 | else: 222 | semantic_adj_matrix = torch.zeros(1).double() 223 | 224 | if not self.adaptive: 225 | # fixed number of bounding boxes 226 | features = coco_features.features[entry['image']] 227 | spatials = coco_features.spatials[entry['image']] 228 | bb = coco_features.bb[entry["image"]] 229 | else: 230 | features = coco_features.features[ 231 | coco_features.pos_boxes[ 232 | entry['image']][0]:coco_features.pos_boxes[ 233 | entry['image']][1], :] 234 | spatials = coco_features.spatials[ 235 | coco_features.pos_boxes[ 236 | entry['image']][0]:coco_features.pos_boxes[ 237 | entry['image']][1], :] 238 | bb = coco_features.bb[ 239 | coco_features.pos_boxes[ 240 | entry['image']][0]:coco_features.pos_boxes[ 241 | entry['image']][1], :] 242 | 243 | answer = entry['answer'] 244 | if answer is not None: 245 | labels = answer['labels'] 246 | scores = answer['scores'] 247 | target = torch.zeros(self.num_ans_candidates) 248 | if labels is not None: 249 | target.scatter_(0, labels, scores) 250 | return features, spatials, question, target, question_id,\ 251 | image_id, bb, spatial_adj_matrix, semantic_adj_matrix 252 | 253 | else: 254 | return features, spatials, question, question_id, question_id,\ 255 | image_id, bb, spatial_adj_matrix, semantic_adj_matrix 256 | 257 | def __len__(self): 258 | return len(self.entries) 259 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | import os 10 | import argparse 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.data import DataLoader 15 | from torch.autograd import Variable 16 | from tqdm import tqdm 17 | import json 18 | 19 | from dataset import Dictionary, VQAFeatureDataset 20 | from dataset_cp_v2 import VQA_cp_Dataset, Image_Feature_Loader 21 | from model.regat import build_regat 22 | from train import compute_score_with_logits 23 | from model.position_emb import prepare_graph_variables 24 | from config.parser import Struct 25 | import utils 26 | 27 | 28 | @torch.no_grad() 29 | def evaluate(model, dataloader, model_hps, args, device): 30 | model.eval() 31 | label2ans = dataloader.dataset.label2ans 32 | num_answers = len(label2ans) 33 | relation_type = dataloader.dataset.relation_type 34 | N = len(dataloader.dataset) 35 | results = [] 36 | score = 0 37 | pbar = tqdm(total=len(dataloader)) 38 | 39 | if args.save_logits: 40 | idx = 0 41 | pred_logits = np.zeros((N, num_answers)) 42 | gt_logits = np.zeros((N, num_answers)) 43 | 44 | for i, (v, norm_bb, q, target, qid, _, bb, 45 | spa_adj_matrix, sem_adj_matrix) in enumerate(dataloader): 46 | batch_size = v.size(0) 47 | num_objects = v.size(1) 48 | v = Variable(v).to(device) 49 | norm_bb = Variable(norm_bb).to(device) 50 | q = Variable(q).to(device) 51 | pos_emb, sem_adj_matrix, spa_adj_matrix = prepare_graph_variables( 52 | relation_type, bb, sem_adj_matrix, spa_adj_matrix, num_objects, 53 | model_hps.nongt_dim, model_hps.imp_pos_emb_dim, 54 | model_hps.spa_label_num, model_hps.sem_label_num, device) 55 | pred, att = model(v, norm_bb, q, pos_emb, sem_adj_matrix, 56 | spa_adj_matrix, None) 57 | # Check if target is a placeholder or actual targets 58 | if target.size(-1) == num_answers: 59 | target = Variable(target).to(device) 60 | batch_score = compute_score_with_logits( 61 | pred, target, device).sum() 62 | score += batch_score 63 | if args.save_logits: 64 | gt_logits[idx:batch_size+idx, :] = target.cpu().numpy() 65 | 66 | if args.save_logits: 67 | pred_logits[idx:batch_size+idx, :] = pred.cpu().numpy() 68 | idx += batch_size 69 | 70 | if args.save_answers: 71 | qid = qid.cpu() 72 | pred = pred.cpu() 73 | current_results = make_json(pred, qid, dataloader) 74 | results.extend(current_results) 75 | 76 | pbar.update(1) 77 | 78 | score = score / N 79 | results_folder = f"{args.output_folder}/results" 80 | if args.save_logits: 81 | utils.create_dir(results_folder) 82 | save_to = f"{results_folder}/logits_{args.dataset}" +\ 83 | f"_{args.split}.npy" 84 | np.save(save_to, pred_logits) 85 | 86 | utils.create_dir("./gt_logits") 87 | save_to = f"./gt_logits/{args.dataset}_{args.split}_gt.npy" 88 | if not os.path.exists(save_to): 89 | np.save(save_to, gt_logits) 90 | if args.save_answers: 91 | utils.create_dir(results_folder) 92 | save_to = f"{results_folder}/{args.dataset}_" +\ 93 | f"{args.split}.json" 94 | json.dump(results, open(save_to, "w")) 95 | return score 96 | 97 | 98 | def get_answer(p, dataloader): 99 | _m, idx = p.max(0) 100 | return dataloader.dataset.label2ans[idx.item()] 101 | 102 | 103 | def make_json(logits, qIds, dataloader): 104 | utils.assert_eq(logits.size(0), len(qIds)) 105 | results = [] 106 | for i in range(logits.size(0)): 107 | result = {} 108 | result['question_id'] = qIds[i].item() 109 | result['answer'] = get_answer(logits[i], dataloader) 110 | results.append(result) 
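        # each entry is {"question_id": int, "answer": str}, the result format
        # expected by the VQA evaluation server (written out with --save_answers)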
111 | return results 112 | 113 | 114 | def parse_args(): 115 | parser = argparse.ArgumentParser() 116 | 117 | ''' 118 | For eval logistics 119 | ''' 120 | parser.add_argument('--save_logits', action='store_true', 121 | help='save logits') 122 | parser.add_argument('--save_answers', action='store_true', 123 | help='save poredicted answers') 124 | 125 | ''' 126 | For loading expert pre-trained weights 127 | ''' 128 | parser.add_argument('--checkpoint', type=int, default=-1) 129 | parser.add_argument('--output_folder', type=str, default="", 130 | help="checkpoint folder") 131 | 132 | ''' 133 | For dataset 134 | ''' 135 | parser.add_argument('--data_folder', type=str, default='./data') 136 | parser.add_argument('--dataset', type=str, default='vqa', 137 | choices=["vqa", "vqa_cp"]) 138 | parser.add_argument('--split', type=str, default="val", 139 | choices=["train", "val", "test", "test2015"], 140 | help="test for vqa_cp, test2015 for vqa") 141 | 142 | args = parser.parse_args() 143 | return args 144 | 145 | 146 | if __name__ == '__main__': 147 | args = parse_args() 148 | if not torch.cuda.is_available(): 149 | raise ValueError("CUDA is not available," + 150 | "this code currently only support GPU.") 151 | 152 | n_device = torch.cuda.device_count() 153 | print("Found %d GPU cards for eval" % (n_device)) 154 | device = torch.device("cuda") 155 | 156 | dictionary = Dictionary.load_from_file( 157 | os.path.join(args.data_folder, 'glove/dictionary.pkl')) 158 | 159 | hps_file = f'{args.output_folder}/hps.json' 160 | model_hps = Struct(json.load(open(hps_file))) 161 | batch_size = model_hps.batch_size*n_device 162 | 163 | print("Evaluating on %s dataset with model trained on %s dataset" % 164 | (args.dataset, model_hps.dataset)) 165 | if args.dataset == "vqa_cp": 166 | coco_train_features = Image_Feature_Loader( 167 | 'train', model_hps.relation_type, 168 | adaptive=model_hps.adaptive, 169 | dataroot=model_hps.data_folder) 170 | coco_val_features = Image_Feature_Loader( 171 | 'val', model_hps.relation_type, 172 | adaptive=model_hps.adaptive, 173 | dataroot=model_hps.data_folder) 174 | eval_dset = VQA_cp_Dataset( 175 | args.split, dictionary, coco_train_features, 176 | coco_val_features, adaptive=model_hps.adaptive, 177 | pos_emb_dim=model_hps.imp_pos_emb_dim, 178 | dataroot=model_hps.data_folder) 179 | else: 180 | eval_dset = VQAFeatureDataset( 181 | args.split, dictionary, model_hps.relation_type, 182 | adaptive=model_hps.adaptive, 183 | pos_emb_dim=model_hps.imp_pos_emb_dim, 184 | dataroot=model_hps.data_folder) 185 | 186 | model = build_regat(eval_dset, model_hps).to(device) 187 | 188 | model = nn.DataParallel(model).to(device) 189 | 190 | if args.checkpoint > 0: 191 | checkpoint_path = os.path.join( 192 | args.output_folder, 193 | f"model_{args.checkpoint}.pth") 194 | else: 195 | checkpoint_path = os.path.join(args.output_folder, 196 | f"model.pth") 197 | print("Loading weights from %s" % (checkpoint_path)) 198 | if not os.path.exists(checkpoint_path): 199 | raise ValueError("No such checkpoint exists!") 200 | checkpoint = torch.load(checkpoint_path) 201 | state_dict = checkpoint.get('model_state', checkpoint) 202 | matched_state_dict = {} 203 | unexpected_keys = set() 204 | missing_keys = set() 205 | for name, param in model.named_parameters(): 206 | missing_keys.add(name) 207 | for key, data in state_dict.items(): 208 | if key in missing_keys: 209 | matched_state_dict[key] = data 210 | missing_keys.remove(key) 211 | else: 212 | unexpected_keys.add(key) 213 | print("\tUnexpected_keys:", 
list(unexpected_keys)) 214 | print("\tMissing_keys:", list(missing_keys)) 215 | model.load_state_dict(matched_state_dict, strict=False) 216 | 217 | eval_loader = DataLoader( 218 | eval_dset, batch_size, shuffle=False, 219 | num_workers=4, collate_fn=utils.trim_collate) 220 | 221 | eval_score = evaluate( 222 | model, eval_loader, model_hps, args, device) 223 | 224 | print('\teval score: %.2f' % (100 * eval_score)) 225 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | ''' 5 | 6 | import os 7 | from os.path import join, exists 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader, ConcatDataset, random_split 12 | import random 13 | import json 14 | 15 | from dataset import Dictionary, VQAFeatureDataset, VisualGenomeFeatureDataset 16 | from dataset import tfidf_from_questions 17 | from dataset_cp_v2 import VQA_cp_Dataset, Image_Feature_Loader 18 | from model.regat import build_regat 19 | from config.parser import parse_with_config 20 | from train import train 21 | import utils 22 | from utils import trim_collate 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | ''' 28 | For training logistics 29 | ''' 30 | parser.add_argument('--epochs', type=int, default=20) 31 | parser.add_argument('--base_lr', type=float, default=1e-3) 32 | parser.add_argument('--lr_decay_start', type=int, default=15) 33 | parser.add_argument('--lr_decay_rate', type=float, default=0.25) 34 | parser.add_argument('--lr_decay_step', type=int, default=2) 35 | parser.add_argument('--lr_decay_based_on_val', action='store_true', 36 | help='Learning rate decay when val score descreases') 37 | parser.add_argument('--grad_accu_steps', type=int, default=1) 38 | parser.add_argument('--grad_clip', type=float, default=0.25) 39 | parser.add_argument('--weight_decay', type=float, default=0) 40 | parser.add_argument('--batch_size', type=int, default=64) 41 | parser.add_argument('--output', type=str, default='saved_models/') 42 | parser.add_argument('--save_optim', action='store_true', 43 | help='save optimizer') 44 | parser.add_argument('--log_interval', type=int, default=-1, 45 | help='Print log for certain steps') 46 | parser.add_argument('--seed', type=int, default=-1, help='random seed') 47 | 48 | ''' 49 | loading trained models 50 | ''' 51 | parser.add_argument('--checkpoint', type=str, default="") 52 | 53 | ''' 54 | For dataset 55 | ''' 56 | parser.add_argument('--dataset', type=str, default='vqa', 57 | choices=["vqa", "vqa_cp"]) 58 | parser.add_argument('--data_folder', type=str, default='./data') 59 | parser.add_argument('--use_both', action='store_true', 60 | help='use both train/val datasets to train?') 61 | parser.add_argument('--use_vg', action='store_true', 62 | help='use visual genome dataset to train?') 63 | parser.add_argument('--adaptive', action='store_true', 64 | help='adaptive or fixed number of regions') 65 | ''' 66 | Model 67 | ''' 68 | parser.add_argument('--relation_type', type=str, default='implicit', 69 | choices=["spatial", "semantic", "implicit"]) 70 | parser.add_argument('--fusion', type=str, default='mutan', 71 | choices=["ban", "butd", "mutan"]) 72 | parser.add_argument('--tfidf', action='store_true', 73 | help='tfidf word embedding?') 74 | parser.add_argument('--op', type=str, default='c', 75 | help="op used in tfidf 
word embedding") 76 | parser.add_argument('--num_hid', type=int, default=1024) 77 | ''' 78 | Fusion Hyperparamters 79 | ''' 80 | parser.add_argument('--ban_gamma', type=int, default=1, help='glimpse') 81 | parser.add_argument('--mutan_gamma', type=int, default=2, help='glimpse') 82 | ''' 83 | Hyper-params for relations 84 | ''' 85 | # hyper-parameters for implicit relation 86 | parser.add_argument('--imp_pos_emb_dim', type=int, default=64, 87 | help='geometric embedding feature dim') 88 | 89 | # hyper-parameters for explicit relation 90 | parser.add_argument('--spa_label_num', type=int, default=11, 91 | help='number of edge labels in spatial relation graph') 92 | parser.add_argument('--sem_label_num', type=int, default=15, 93 | help='number of edge labels in \ 94 | semantic relation graph') 95 | 96 | # shared hyper-parameters 97 | parser.add_argument('--dir_num', type=int, default=2, 98 | help='number of directions in relation graph') 99 | parser.add_argument('--relation_dim', type=int, default=1024, 100 | help='relation feature dim') 101 | parser.add_argument('--nongt_dim', type=int, default=20, 102 | help='number of objects consider relations per image') 103 | parser.add_argument('--num_heads', type=int, default=16, 104 | help='number of attention heads \ 105 | for multi-head attention') 106 | parser.add_argument('--num_steps', type=int, default=1, 107 | help='number of graph propagation steps') 108 | parser.add_argument('--residual_connection', action='store_true', 109 | help='Enable residual connection in relation encoder') 110 | parser.add_argument('--label_bias', action='store_true', 111 | help='Enable bias term for relation labels \ 112 | in relation encoder') 113 | 114 | # can use config files 115 | parser.add_argument('--config', help='JSON config files') 116 | 117 | args = parse_with_config(parser) 118 | return args 119 | 120 | 121 | if __name__ == '__main__': 122 | args = parse_args() 123 | if not torch.cuda.is_available(): 124 | raise ValueError("CUDA is not available," + 125 | "this code currently only support GPU.") 126 | n_device = torch.cuda.device_count() 127 | print("Found %d GPU cards for training" % (n_device)) 128 | device = torch.device("cuda") 129 | batch_size = args.batch_size*n_device 130 | 131 | torch.backends.cudnn.benchmark = True 132 | 133 | if args.seed != -1: 134 | print("Predefined randam seed %d" % args.seed) 135 | else: 136 | # fix seed 137 | args.seed = random.randint(1, 10000) 138 | print("Choose random seed %d" % args.seed) 139 | torch.manual_seed(args.seed) 140 | torch.cuda.manual_seed_all(args.seed) 141 | 142 | if "ban" == args.fusion: 143 | fusion_methods = args.fusion+"_"+str(args.ban_gamma) 144 | else: 145 | fusion_methods = args.fusion 146 | 147 | dictionary = Dictionary.load_from_file( 148 | join(args.data_folder, 'glove/dictionary.pkl')) 149 | if args.dataset == "vqa_cp": 150 | coco_train_features = Image_Feature_Loader( 151 | 'train', args.relation_type, 152 | adaptive=args.adaptive, dataroot=args.data_folder) 153 | coco_val_features = Image_Feature_Loader( 154 | 'val', args.relation_type, 155 | adaptive=args.adaptive, dataroot=args.data_folder) 156 | val_dset = VQA_cp_Dataset( 157 | 'test', dictionary, coco_train_features, coco_val_features, 158 | adaptive=args.adaptive, pos_emb_dim=args.imp_pos_emb_dim, 159 | dataroot=args.data_folder) 160 | train_dset = VQA_cp_Dataset( 161 | 'train', dictionary, coco_train_features, 162 | coco_val_features, adaptive=args.adaptive, 163 | pos_emb_dim=args.imp_pos_emb_dim, 164 | dataroot=args.data_folder) 165 | 
else: 166 | val_dset = VQAFeatureDataset( 167 | 'val', dictionary, args.relation_type, adaptive=args.adaptive, 168 | pos_emb_dim=args.imp_pos_emb_dim, dataroot=args.data_folder) 169 | train_dset = VQAFeatureDataset( 170 | 'train', dictionary, args.relation_type, 171 | adaptive=args.adaptive, pos_emb_dim=args.imp_pos_emb_dim, 172 | dataroot=args.data_folder) 173 | 174 | model = build_regat(val_dset, args).to(device) 175 | 176 | tfidf = None 177 | weights = None 178 | if args.tfidf: 179 | tfidf, weights = tfidf_from_questions(['train', 'val', 'test2015'], 180 | dictionary) 181 | model.w_emb.init_embedding(join(args.data_folder, 182 | 'glove/glove6b_init_300d.npy'), 183 | tfidf, weights) 184 | 185 | model = nn.DataParallel(model).to(device) 186 | 187 | if args.checkpoint != "": 188 | print("Loading weights from %s" % (args.checkpoint)) 189 | if not os.path.exists(args.checkpoint): 190 | raise ValueError("No such checkpoint exists!") 191 | checkpoint = torch.load(args.checkpoint) 192 | state_dict = checkpoint.get('model_state', checkpoint) 193 | matched_state_dict = {} 194 | unexpected_keys = set() 195 | missing_keys = set() 196 | for name, param in model.named_parameters(): 197 | missing_keys.add(name) 198 | for key, data in state_dict.items(): 199 | if key in missing_keys: 200 | matched_state_dict[key] = data 201 | missing_keys.remove(key) 202 | else: 203 | unexpected_keys.add(key) 204 | print("Unexpected_keys:", list(unexpected_keys)) 205 | print("Missing_keys:", list(missing_keys)) 206 | model.load_state_dict(matched_state_dict, strict=False) 207 | 208 | # use train & val splits to optimize, only available for vqa, not vqa_cp 209 | if args.use_both and args.dataset == "vqa": 210 | length = len(val_dset) 211 | trainval_concat_dset = ConcatDataset([train_dset, val_dset]) 212 | if args.use_vg or args.use_visdial: 213 | trainval_concat_dsets_split = random_split( 214 | trainval_concat_dset, 215 | [int(0.2*length), 216 | len(trainval_concat_dset)-int(0.2*length)]) 217 | else: 218 | trainval_concat_dsets_split = random_split( 219 | trainval_concat_dset, 220 | [int(0.1*length), 221 | len(trainval_concat_dset)-int(0.1*length)]) 222 | concat_list = [trainval_concat_dsets_split[1]] 223 | 224 | # use a portion of Visual Genome dataset 225 | if args.use_vg: 226 | vg_train_dset = VisualGenomeFeatureDataset( 227 | 'train', train_dset.features, train_dset.normalized_bb, 228 | train_dset.bb, train_dset.spatial_adj_matrix, 229 | train_dset.semantic_adj_matrix, dictionary, 230 | adaptive=train_dset.adaptive, 231 | pos_boxes=train_dset.pos_boxes, 232 | dataroot=args.data_folder) 233 | vg_val_dset = VisualGenomeFeatureDataset( 234 | 'val', val_dset.features, val_dset.normalized_bb, 235 | val_dset.bb, val_dset.spatial_adj_matrix, 236 | val_dset.semantic_adj_matrix, dictionary, 237 | adaptive=val_dset.adaptive, 238 | pos_boxes=val_dset.pos_boxes, 239 | dataroot=args.data_folder) 240 | concat_list.append(vg_train_dset) 241 | concat_list.append(vg_val_dset) 242 | final_train_dset = ConcatDataset(concat_list) 243 | final_eval_dset = trainval_concat_dsets_split[0] 244 | train_loader = DataLoader(final_train_dset, batch_size, shuffle=True, 245 | num_workers=4, collate_fn=trim_collate) 246 | eval_loader = DataLoader(final_eval_dset, batch_size, 247 | shuffle=False, num_workers=4, 248 | collate_fn=trim_collate) 249 | 250 | else: 251 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, 252 | num_workers=4, collate_fn=trim_collate) 253 | eval_loader = DataLoader(val_dset, batch_size, shuffle=False, 254 | 
num_workers=4, collate_fn=trim_collate) 255 | 256 | output_meta_folder = join(args.output, "regat_%s" % args.relation_type) 257 | utils.create_dir(output_meta_folder) 258 | args.output = output_meta_folder+"/%s_%s_%s_%d" % ( 259 | fusion_methods, args.relation_type, 260 | args.dataset, args.seed) 261 | if exists(args.output) and os.listdir(args.output): 262 | raise ValueError("Output directory ({}) already exists and is not " 263 | "empty.".format(args.output)) 264 | utils.create_dir(args.output) 265 | with open(join(args.output, 'hps.json'), 'w') as writer: 266 | json.dump(vars(args), writer, indent=4) 267 | logger = utils.Logger(join(args.output, 'log.txt')) 268 | 269 | train(model, train_loader, eval_loader, args, device) 270 | -------------------------------------------------------------------------------- /misc/regat_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjieli222/VQA_ReGAT/a80abb6dda6daee6394a47970fdea1a5bd21894c/misc/regat_overview.jpg -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/bc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bilinear Attention Networks 3 | Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang 4 | https://arxiv.org/abs/1805.07932 5 | 6 | This code is from Jin-Hwa Kim's repository. 7 | https://github.com/jnhwkim/ban-vqa 8 | MIT License 9 | """ 10 | from __future__ import print_function 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn.utils.weight_norm import weight_norm 15 | from model.fc import FCNet 16 | 17 | 18 | class BCNet(nn.Module): 19 | """Simple class for non-linear bilinear connect network 20 | """ 21 | def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', 22 | dropout=[.2, .5], k=3): 23 | super(BCNet, self).__init__() 24 | 25 | self.c = 32 26 | self.k = k 27 | self.v_dim = v_dim 28 | self.q_dim = q_dim 29 | self.h_dim = h_dim 30 | self.h_out = h_out 31 | 32 | self.v_net = FCNet([v_dim, h_dim * self.k], act=act, 33 | dropout=dropout[0]) 34 | self.q_net = FCNet([q_dim, h_dim * self.k], act=act, 35 | dropout=dropout[0]) 36 | self.dropout = nn.Dropout(dropout[1]) # attention 37 | if 1 < k: 38 | self.p_net = nn.AvgPool1d(self.k, stride=self.k) 39 | 40 | if h_out is None: 41 | pass 42 | elif h_out <= self.c: 43 | self.h_mat = nn.Parameter( 44 | torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()) 45 | self.h_bias = nn.Parameter( 46 | torch.Tensor(1, h_out, 1, 1).normal_()) 47 | else: 48 | self.h_net = weight_norm( 49 | nn.Linear(h_dim * self.k, h_out), dim=None) 50 | 51 | def forward(self, v, q): 52 | if self.h_out is None: 53 | v_ = self.v_net(v) 54 | q_ = self.q_net(q) 55 | logits = torch.einsum('bvk,bqk->bvqk', (v_, q_)) 56 | return logits 57 | 58 | # low-rank bilinear pooling using einsum 59 | elif self.h_out <= self.c: 60 | v_ = self.dropout(self.v_net(v)) 61 | q_ = self.q_net(q) 62 | logits = torch.einsum('xhyk,bvk,bqk->bhvq', 63 | (self.h_mat, v_, q_)) + self.h_bias 64 | return logits # b x h_out x v x q 65 | 66 | # batch outer product, linear projection 67 | # memory efficient but slow computation 68 | else: 69 | v_ = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3) 70 | q_ = self.q_net(q).transpose(1, 2).unsqueeze(2) 71 | d_ = torch.matmul(v_, q_) 
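# (editor's note, added) shape sketch for this memory-efficient branch:
# after v_net/q_net, transpose and unsqueeze,
#   v_ : b x (h_dim*k) x v x 1,   q_ : b x (h_dim*k) x 1 x q,
# so the matmul above is a batched outer product over the object and
# question-token axes; h_net then projects the last dimension down to h_out.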
# b x h_dim x v x q 72 | # b x v x q x h_out 73 | logits = self.h_net(d_.transpose(1, 2).transpose(2, 3)) 74 | return logits.transpose(2, 3).transpose(1, 2) # b x h_out x v x q 75 | 76 | def forward_with_weights(self, v, q, w): 77 | v_ = self.v_net(v) # b x v x d 78 | q_ = self.q_net(q) # b x q x d 79 | logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_)) 80 | if 1 < self.k: 81 | logits = logits.unsqueeze(1) # b x 1 x d 82 | logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling 83 | return logits 84 | -------------------------------------------------------------------------------- /model/bilinear_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bilinear Attention Networks 3 | Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang 4 | https://arxiv.org/abs/1805.07932 5 | 6 | This code is from Jin-Hwa Kim's repository. 7 | https://github.com/jnhwkim/ban-vqa 8 | MIT License 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.utils.weight_norm import weight_norm 13 | from model.fc import FCNet 14 | from model.bc import BCNet 15 | import numpy as np 16 | 17 | 18 | class BiAttention(nn.Module): 19 | def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2, .5]): 20 | super(BiAttention, self).__init__() 21 | 22 | self.glimpse = glimpse 23 | self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, 24 | dropout=dropout, k=3), 25 | name='h_mat', dim=None) 26 | 27 | def forward(self, v, q, v_mask=True): 28 | """ 29 | v: [batch, k, vdim] 30 | q: [batch, qdim] 31 | """ 32 | p, logits = self.forward_all(v, q, v_mask) 33 | return p, logits 34 | 35 | def forward_all(self, v, q, v_mask=True, logit=False, 36 | mask_with=-float('inf')): 37 | v_num = v.size(1) 38 | q_num = q.size(1) 39 | # if visualize: 40 | # logits,bc_out = self.logits(v,q) # b x g x v x q 41 | # else: 42 | # logits = self.logits(v,q) # b x g x v x q 43 | 44 | logits = self.logits(v, q) # b x g x v x q 45 | if v_mask: 46 | mask = (0 == v.abs().sum(2)).unsqueeze(1).unsqueeze(3).expand( 47 | logits.size()) 48 | logits.data.masked_fill_(mask.data, mask_with) 49 | 50 | p = nn.functional.softmax( 51 | logits.view(-1, self.glimpse, v_num * q_num), 2) 52 | p = p.view(-1, self.glimpse, v_num, q_num) 53 | # if visualize: 54 | # return p,logits, bc_out 55 | if not logit: 56 | return p, logits 57 | return logits 58 | -------------------------------------------------------------------------------- /model/classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | GNU General Public License v3.0 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.utils.weight_norm import weight_norm 9 | 10 | 11 | class SimpleClassifier(nn.Module): 12 | def __init__(self, in_dim, hid_dim, out_dim, dropout): 13 | super(SimpleClassifier, self).__init__() 14 | layers = [ 15 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None), 16 | nn.ReLU(), 17 | nn.Dropout(dropout, inplace=True), 18 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None) 19 | ] 20 | self.main = nn.Sequential(*layers) 21 | 22 | def forward(self, x): 23 | logits = self.main(x) 24 | return logits 25 | -------------------------------------------------------------------------------- /model/counting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Learning to Count Objects in Natural Images for Visual Question Answering 3 | Yan Zhang, Jonathon Hare, Adam Prugel-Bennett 4 | ICLR 2018 5 | 6 | This code is from Yan Zhang's repository. 7 | https://github.com/Cyanogenoid/vqa-counting/blob/master/vqa-v2/counting.py 8 | MIT License 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | 15 | 16 | class Counter(nn.Module): 17 | """ Counting module as proposed in [1]. 18 | Count the number of objects from a set of bounding boxes 19 | and a set of scores for each bounding box. 20 | This produces (self.objects + 1) number of count features. 21 | 22 | Yan Zhang, Jonathon Hare, Adam Prugel-Bennett: 23 | Learning to Count Objects in Natural Images for Visual Question Answering. 24 | https://openreview.net/forum?id=B12Js_yRb 25 | """ 26 | def __init__(self, objects, already_sigmoided=False): 27 | super().__init__() 28 | self.objects = objects 29 | self.already_sigmoided = already_sigmoided 30 | self.f = nn.ModuleList([PiecewiseLin(16) for _ in range(16)]) 31 | 32 | def forward(self, boxes, attention): 33 | """ 34 | Forward propagation of attention weights 35 | and bounding boxes to produce count features. 36 | `boxes` has to be a tensor of shape (n, 4, m) 37 | with the 4 channels containing the x and y coordinates 38 | of the top left corner and the x and y coordinates 39 | of the bottom right corner in this order. 40 | `attention` has to be a tensor of shape (n, m). 41 | Each value should be in [0, 1] if already_sigmoided is set to True, 42 | but there are no restrictions if already_sigmoided is set to False. 43 | This value should be close to 1 44 | if the corresponding boundign box is relevant 45 | and close to 0 if it is not. 46 | n is the batch size, m is the number of bounding boxes per image. 
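        (editor's note, added) Concretely, with self.objects = 10 and
        m = 36 proposals, `boxes` is (n, 4, 36), `attention` is (n, 36),
        and the returned count feature is an (n, 11) vector
        (self.objects + 1 entries).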
47 | """ 48 | # only care about the highest scoring object proposals 49 | # the ones with low score will have a low impact on the count anyway 50 | boxes, attention = self.filter_most_important( 51 | self.objects, boxes, attention) 52 | # normalise the attention weights to be in [0, 1] 53 | if not self.already_sigmoided: 54 | attention = torch.sigmoid(attention) 55 | 56 | relevancy = self.outer_product(attention) 57 | distance = 1 - self.iou(boxes, boxes) 58 | 59 | # intra-object dedup 60 | score = self.f[0](relevancy) * self.f[1](distance) 61 | 62 | # inter-object dedup 63 | dedup_score = self.f[3](relevancy) * self.f[4](distance) 64 | dedup_per_entry, dedup_per_row = self.deduplicate( 65 | dedup_score, attention) 66 | score = score / dedup_per_entry 67 | 68 | # aggregate the score 69 | # can skip putting this on the diagonal 70 | # since we're just summing over it anyway 71 | correction = self.f[0](attention * attention) / dedup_per_row 72 | score = score.sum(dim=2).sum(dim=1, keepdim=True) +\ 73 | correction.sum(dim=1, keepdim=True) 74 | score = (score + 1e-20).sqrt() 75 | one_hot = self.to_one_hot(score) 76 | 77 | att_conf = (self.f[5](attention) - 0.5).abs() 78 | dist_conf = (self.f[6](distance) - 0.5).abs() 79 | conf = self.f[7](att_conf.mean(dim=1, keepdim=True) + 80 | dist_conf.mean(dim=2).mean(dim=1, keepdim=True)) 81 | 82 | return one_hot * conf 83 | 84 | def deduplicate(self, dedup_score, att): 85 | # using outer-diffs 86 | att_diff = self.outer_diff(att) 87 | score_diff = self.outer_diff(dedup_score) 88 | sim = self.f[2](1 - score_diff).prod(dim=1) * self.f[2](1 - att_diff) 89 | # similarity for each row 90 | row_sims = sim.sum(dim=2) 91 | # similarity for each entry 92 | all_sims = self.outer_product(row_sims) 93 | return all_sims, row_sims 94 | 95 | def to_one_hot(self, scores): 96 | """ Turn a bunch of non-negative scalar values into a one-hot encoding. 97 | E.g. with self.objects = 3, 0 -> [1 0 0 0], 2.75 -> [0 0 0.25 0.75]. 
98 | """ 99 | # sanity check, I don't think this ever does anything 100 | # (it certainly shouldn't) 101 | scores = scores.clamp(min=0, max=self.objects) 102 | # compute only on the support 103 | i = scores.long().data 104 | f = scores.frac() 105 | # target_l is the one-hot if the score is rounded down 106 | # target_r is the one-hot if the score is rounded up 107 | target_l = scores.data.new(i.size(0), self.objects + 1).fill_(0) 108 | target_r = scores.data.new(i.size(0), self.objects + 1).fill_(0) 109 | 110 | target_l.scatter_(dim=1, index=i.clamp(max=self.objects), value=1) 111 | target_r.scatter_(dim=1, index=(i + 1).clamp(max=self.objects), 112 | value=1) 113 | # interpolate between these with the fractional part of the score 114 | return (1 - f) * target_l + f * target_r 115 | 116 | def filter_most_important(self, n, boxes, attention): 117 | """ Only keep top-n object proposals, scored by attention weight """ 118 | attention, idx = attention.topk(n, dim=1, sorted=False) 119 | idx = idx.unsqueeze(dim=1).expand( 120 | boxes.size(0), boxes.size(1), idx.size(1)) 121 | boxes = boxes.gather(2, idx) 122 | return boxes, attention 123 | 124 | def outer(self, x): 125 | size = tuple(x.size()) + (x.size()[-1],) 126 | a = x.unsqueeze(dim=-1).expand(*size) 127 | b = x.unsqueeze(dim=-2).expand(*size) 128 | return a, b 129 | 130 | def outer_product(self, x): 131 | # Y_ij = x_i * x_j 132 | a, b = self.outer(x) 133 | return a * b 134 | 135 | def outer_diff(self, x): 136 | # like outer products, except taking the absolute difference instead 137 | # Y_ij = | x_i - x_j | 138 | a, b = self.outer(x) 139 | return (a - b).abs() 140 | 141 | def iou(self, a, b): 142 | # this is just the usual way to IoU from bounding boxes 143 | inter = self.intersection(a, b) 144 | area_a = self.area(a).unsqueeze(2).expand_as(inter) 145 | area_b = self.area(b).unsqueeze(1).expand_as(inter) 146 | return inter / (area_a + area_b - inter + 1e-12) 147 | 148 | def area(self, box): 149 | x = (box[:, 2, :] - box[:, 0, :]).clamp(min=0) 150 | y = (box[:, 3, :] - box[:, 1, :]).clamp(min=0) 151 | return x * y 152 | 153 | def intersection(self, a, b): 154 | size = (a.size(0), 2, a.size(2), b.size(2)) 155 | min_point = torch.max( 156 | a[:, :2, :].unsqueeze(dim=3).expand(*size), 157 | b[:, :2, :].unsqueeze(dim=2).expand(*size), 158 | ) 159 | max_point = torch.min( 160 | a[:, 2:, :].unsqueeze(dim=3).expand(*size), 161 | b[:, 2:, :].unsqueeze(dim=2).expand(*size), 162 | ) 163 | inter = (max_point - min_point).clamp(min=0) 164 | area = inter[:, 0, :, :] * inter[:, 1, :, :] 165 | return area 166 | 167 | 168 | class PiecewiseLin(nn.Module): 169 | def __init__(self, n): 170 | super().__init__() 171 | self.n = n 172 | self.weight = nn.Parameter(torch.ones(n + 1)) 173 | # the first weight here is always 0 with a 0 gradient 174 | self.weight.data[0] = 0 175 | 176 | def forward(self, x): 177 | # all weights are positive -> function is monotonically increasing 178 | w = self.weight.abs() 179 | # make weights sum to one -> f(1) = 1 180 | w = w / w.sum() 181 | w = w.view([self.n + 1] + [1] * x.dim()) 182 | # keep cumulative sum for O(1) time complexity 183 | csum = w.cumsum(dim=0) 184 | csum = csum.expand((self.n + 1,) + tuple(x.size())) 185 | w = w.expand_as(csum) 186 | 187 | # figure out which part of the function the input lies on 188 | y = self.n * x.unsqueeze(0) 189 | idx = y.long().data 190 | f = y.frac() 191 | 192 | # contribution of the linear parts left of the input 193 | x = csum.gather(0, idx.clamp(max=self.n)) 194 | # contribution within 
the linear segment the input falls into 195 | x = x + f * w.gather(0, (idx + 1).clamp(max=self.n)) 196 | return x.squeeze(0) 197 | -------------------------------------------------------------------------------- /model/fc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | from __future__ import print_function 10 | import torch.nn as nn 11 | from torch.nn.utils.weight_norm import weight_norm 12 | import torch 13 | 14 | 15 | class FCNet(nn.Module): 16 | """Simple class for non-linear fully connect network 17 | """ 18 | def __init__(self, dims, act='ReLU', dropout=0, bias=True): 19 | super(FCNet, self).__init__() 20 | 21 | layers = [] 22 | for i in range(len(dims)-2): 23 | in_dim = dims[i] 24 | out_dim = dims[i+1] 25 | if 0 < dropout: 26 | layers.append(nn.Dropout(dropout)) 27 | layers.append(weight_norm(nn.Linear(in_dim, out_dim, bias=bias), 28 | dim=None)) 29 | if '' != act and act is not None: 30 | layers.append(getattr(nn, act)()) 31 | if 0 < dropout: 32 | layers.append(nn.Dropout(dropout)) 33 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1], bias=bias), 34 | dim=None)) 35 | if '' != act and act is not None: 36 | layers.append(getattr(nn, act)()) 37 | 38 | self.main = nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | return self.main(x) 42 | -------------------------------------------------------------------------------- /model/fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | from model.bilinear_attention import BiAttention 14 | import torch.nn.functional as F 15 | from model.fc import FCNet 16 | from model.bc import BCNet 17 | from model.counting import Counter 18 | from torch.nn.utils.weight_norm import weight_norm 19 | from block import fusions 20 | 21 | """ 22 | Bilinear Attention Networks 23 | Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang 24 | https://arxiv.org/abs/1805.07932 25 | 26 | This code is modified from Jin-Hwa Kim's repository. 
27 | https://github.com/jnhwkim/ban-vqa 28 | MIT License 29 | """ 30 | 31 | 32 | class BAN(nn.Module): 33 | def __init__(self, v_relation_dim, num_hid, gamma, 34 | min_num_objects=10, use_counter=True): 35 | super(BAN, self).__init__() 36 | 37 | self.v_att = BiAttention(v_relation_dim, num_hid, num_hid, gamma) 38 | self.glimpse = gamma 39 | self.use_counter = use_counter 40 | b_net = [] 41 | q_prj = [] 42 | c_prj = [] 43 | q_att = [] 44 | v_prj = [] 45 | 46 | for i in range(gamma): 47 | b_net.append(BCNet(v_relation_dim, num_hid, num_hid, None, k=1)) 48 | q_prj.append(FCNet([num_hid, num_hid], '', .2)) 49 | if self.use_counter: 50 | c_prj.append(FCNet([min_num_objects + 1, num_hid], 'ReLU', .0)) 51 | 52 | self.b_net = nn.ModuleList(b_net) 53 | self.q_prj = nn.ModuleList(q_prj) 54 | self.q_att = nn.ModuleList(q_att) 55 | self.v_prj = nn.ModuleList(v_prj) 56 | if self.use_counter: 57 | self.c_prj = nn.ModuleList(c_prj) 58 | self.counter = Counter(min_num_objects) 59 | 60 | def forward(self, v_relation, q_emb, b): 61 | if self.use_counter: 62 | boxes = b[:, :, :4].transpose(1, 2) 63 | 64 | b_emb = [0] * self.glimpse 65 | # b x g x v x q 66 | att, att_logits = self.v_att.forward_all(v_relation, q_emb) 67 | 68 | for g in range(self.glimpse): 69 | # b x l x h 70 | b_emb[g] = self.b_net[g].forward_with_weights( 71 | v_relation, q_emb, att[:, g, :, :]) 72 | # atten used for counting module 73 | atten, _ = att_logits[:, g, :, :].max(2) 74 | q_emb = self.q_prj[g](b_emb[g].unsqueeze(1)) + q_emb 75 | 76 | if self.use_counter: 77 | embed = self.counter(boxes, atten) 78 | q_emb = q_emb + self.c_prj[g](embed).unsqueeze(1) 79 | joint_emb = q_emb.sum(1) 80 | return joint_emb, att 81 | 82 | """ 83 | This code is modified by Linjie Li from Hengyuan Hu's repository. 84 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 85 | GNU General Public License v3.0 86 | """ 87 | 88 | 89 | class BUTD(nn.Module): 90 | def __init__(self, v_relation_dim, q_dim, num_hid, dropout=0.2): 91 | super(BUTD, self).__init__() 92 | self.v_proj = FCNet([v_relation_dim, num_hid]) 93 | self.q_proj = FCNet([q_dim, num_hid]) 94 | self.dropout = nn.Dropout(dropout) 95 | self.linear = FCNet([num_hid, 1]) 96 | self.q_net = FCNet([q_dim, num_hid]) 97 | self.v_net = FCNet([v_relation_dim, num_hid]) 98 | 99 | def forward(self, v_relation, q_emb): 100 | """ 101 | v: [batch, k, vdim] 102 | q: [batch, qdim] 103 | b: bounding box features, not used for this fusion method 104 | """ 105 | logits = self.logits(v_relation, q_emb) 106 | att = nn.functional.softmax(logits, 1) 107 | v_emb = (att * v_relation).sum(1) # [batch, v_dim] 108 | 109 | q_repr = self.q_net(q_emb) 110 | v_repr = self.v_net(v_emb) 111 | joint_emb = q_repr * v_repr 112 | return joint_emb, att 113 | 114 | def logits(self, v, q): 115 | batch, k, _ = v.size() 116 | v_proj = self.v_proj(v) # [batch, k, qdim] 117 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) 118 | joint_repr = v_proj * q_proj 119 | joint_repr = self.dropout(joint_repr) 120 | logits = self.linear(joint_repr) 121 | return logits 122 | 123 | """ 124 | This code is modified by Linjie Li from Remi Cadene's repository. 
125 | https://github.com/Cadene/vqa.pytorch 126 | """ 127 | 128 | 129 | class MuTAN_Attention(nn.Module): 130 | def __init__(self, dim_v, dim_q, dim_out, method="Mutan", mlp_glimpses=0): 131 | super(MuTAN_Attention, self).__init__() 132 | self.mlp_glimpses = mlp_glimpses 133 | self.fusion = getattr(fusions, method)( 134 | [dim_q, dim_v], dim_out, mm_dim=1200, 135 | dropout_input=0.1) 136 | if self.mlp_glimpses > 0: 137 | self.linear0 = FCNet([dim_out, 512], '', 0) 138 | self.linear1 = FCNet([512, mlp_glimpses], '', 0) 139 | 140 | def forward(self, q, v): 141 | alpha = self.process_attention(q, v) 142 | 143 | if self.mlp_glimpses > 0: 144 | alpha = self.linear0(alpha) 145 | alpha = F.relu(alpha) 146 | alpha = self.linear1(alpha) 147 | 148 | alpha = F.softmax(alpha, dim=1) 149 | 150 | if alpha.size(2) > 1: # nb_glimpses > 1 151 | alphas = torch.unbind(alpha, dim=2) 152 | v_outs = [] 153 | for alpha in alphas: 154 | alpha = alpha.unsqueeze(2).expand_as(v) 155 | v_out = alpha*v 156 | v_out = v_out.sum(1) 157 | v_outs.append(v_out) 158 | v_out = torch.cat(v_outs, dim=1) 159 | else: 160 | alpha = alpha.expand_as(v) 161 | v_out = alpha*v 162 | v_out = v_out.sum(1) 163 | return v_out 164 | 165 | def process_attention(self, q, v): 166 | batch_size = q.size(0) 167 | n_regions = v.size(1) 168 | q = q[:, None, :].expand(q.size(0), n_regions, q.size(1)) 169 | alpha = self.fusion([ 170 | q.contiguous().view(batch_size*n_regions, -1), 171 | v.contiguous().view(batch_size*n_regions, -1) 172 | ]) 173 | alpha = alpha.view(batch_size, n_regions, -1) 174 | return alpha 175 | 176 | 177 | class MuTAN(nn.Module): 178 | def __init__(self, v_relation_dim, num_hid, num_ans_candidates, gamma): 179 | super(MuTAN, self).__init__() 180 | self.gamma = gamma 181 | self.attention = MuTAN_Attention(v_relation_dim, num_hid, 182 | dim_out=360, method="Mutan", 183 | mlp_glimpses=gamma) 184 | self.fusion = getattr(fusions, "Mutan")( 185 | [num_hid, v_relation_dim*2], num_ans_candidates, 186 | mm_dim=1200, dropout_input=0.1) 187 | 188 | def forward(self, v_relation, q_emb): 189 | # b: bounding box features, not used for this fusion method 190 | att = self.attention(q_emb, v_relation) 191 | logits = self.fusion([q_emb, att]) 192 | return logits, att 193 | -------------------------------------------------------------------------------- /model/graph_att.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | from model.fc import FCNet 14 | from model.graph_att_layer import GraphSelfAttentionLayer 15 | 16 | 17 | class GAttNet(nn.Module): 18 | def __init__(self, dir_num, label_num, in_feat_dim, out_feat_dim, 19 | nongt_dim=20, dropout=0.2, label_bias=True, 20 | num_heads=16, pos_emb_dim=-1): 21 | """ Attetion module with vectorized version 22 | 23 | Args: 24 | label_num: numer of edge labels 25 | dir_num: number of edge directions 26 | feat_dim: dimension of roi_feat 27 | pos_emb_dim: dimension of postion embedding for implicit relation, set as -1 for explicit relation 28 | 29 | Returns: 30 | output: [num_rois, ovr_feat_dim, output_dim] 31 | """ 32 | super(GAttNet, self).__init__() 33 | assert dir_num <= 2, "Got more than two directions in a graph." 
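        # (editor's note, added) dir_num == 1 attends only over the adjacency
        # matrix as given; dir_num == 2 additionally attends over its
        # transpose (see adj_matrix_list in forward), so incoming and
        # outgoing edges get separate GraphSelfAttentionLayer weights.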
34 | self.dir_num = dir_num 35 | self.label_num = label_num 36 | self.in_feat_dim = in_feat_dim 37 | self.out_feat_dim = out_feat_dim 38 | self.dropout = nn.Dropout(dropout) 39 | self.self_weights = FCNet([in_feat_dim, out_feat_dim], '', dropout) 40 | self.bias = FCNet([label_num, 1], '', 0, label_bias) 41 | self.nongt_dim = nongt_dim 42 | self.pos_emb_dim = pos_emb_dim 43 | neighbor_net = [] 44 | for i in range(dir_num): 45 | g_att_layer = GraphSelfAttentionLayer( 46 | pos_emb_dim=pos_emb_dim, 47 | num_heads=num_heads, 48 | feat_dim=out_feat_dim, 49 | nongt_dim=nongt_dim) 50 | neighbor_net.append(g_att_layer) 51 | self.neighbor_net = nn.ModuleList(neighbor_net) 52 | 53 | def forward(self, v_feat, adj_matrix, pos_emb=None): 54 | """ 55 | Args: 56 | v_feat: [batch_size,num_rois, feat_dim] 57 | adj_matrix: [batch_size, num_rois, num_rois, num_labels] 58 | pos_emb: [batch_size, num_rois, pos_emb_dim] 59 | 60 | Returns: 61 | output: [batch_size, num_rois, feat_dim] 62 | """ 63 | if self.pos_emb_dim > 0 and pos_emb is None: 64 | raise ValueError( 65 | f"position embedding is set to None " 66 | f"with pos_emb_dim {self.pos_emb_dim}") 67 | elif self.pos_emb_dim < 0 and pos_emb is not None: 68 | raise ValueError( 69 | f"position embedding is NOT None " 70 | f"with pos_emb_dim < 0") 71 | batch_size, num_rois, feat_dim = v_feat.shape 72 | nongt_dim = self.nongt_dim 73 | 74 | adj_matrix = adj_matrix.float() 75 | 76 | adj_matrix_list = [adj_matrix, adj_matrix.transpose(1, 2)] 77 | 78 | # Self - looping edges 79 | # [batch_size,num_rois, out_feat_dim] 80 | self_feat = self.self_weights(v_feat) 81 | 82 | output = self_feat 83 | neighbor_emb = [0] * self.dir_num 84 | for d in range(self.dir_num): 85 | # [batch_size,num_rois, nongt_dim,label_num] 86 | input_adj_matrix = adj_matrix_list[d][:, :, :nongt_dim, :] 87 | condensed_adj_matrix = torch.sum(input_adj_matrix, dim=-1) 88 | 89 | # [batch_size,num_rois, nongt_dim] 90 | v_biases_neighbors = self.bias(input_adj_matrix).squeeze(-1) 91 | 92 | # [batch_size,num_rois, out_feat_dim] 93 | neighbor_emb[d] = self.neighbor_net[d].forward( 94 | self_feat, condensed_adj_matrix, pos_emb, 95 | v_biases_neighbors) 96 | 97 | # [batch_size,num_rois, out_feat_dim] 98 | output = output + neighbor_emb[d] 99 | output = self.dropout(output) 100 | output = nn.functional.relu(output) 101 | 102 | return output 103 | -------------------------------------------------------------------------------- /model/graph_att_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 
10 | """ 11 | import torch 12 | import torch.nn as nn 13 | from model.fc import FCNet 14 | import math 15 | from torch.nn.utils.weight_norm import weight_norm 16 | 17 | 18 | class GraphSelfAttentionLayer(nn.Module): 19 | def __init__(self, feat_dim, nongt_dim=20, pos_emb_dim=-1, 20 | num_heads=16, dropout=[0.2, 0.5]): 21 | """ Attetion module with vectorized version 22 | 23 | Args: 24 | position_embedding: [num_rois, nongt_dim, pos_emb_dim] 25 | used in implicit relation 26 | pos_emb_dim: set as -1 if explicit relation 27 | nongt_dim: number of objects consider relations per image 28 | fc_dim: should be same as num_heads 29 | feat_dim: dimension of roi_feat 30 | num_heads: number of attention heads 31 | Returns: 32 | output: [num_rois, ovr_feat_dim, output_dim] 33 | """ 34 | super(GraphSelfAttentionLayer, self).__init__() 35 | # multi head 36 | self.fc_dim = num_heads 37 | self.feat_dim = feat_dim 38 | self.dim = (feat_dim, feat_dim, feat_dim) 39 | self.dim_group = (int(self.dim[0] / num_heads), 40 | int(self.dim[1] / num_heads), 41 | int(self.dim[2] / num_heads)) 42 | self.num_heads = num_heads 43 | self.pos_emb_dim = pos_emb_dim 44 | if self.pos_emb_dim > 0: 45 | self.pair_pos_fc1 = FCNet([pos_emb_dim, self.fc_dim], None, dropout[0]) 46 | self.query = FCNet([feat_dim, self.dim[0]], None, dropout[0]) 47 | self.nongt_dim = nongt_dim 48 | 49 | self.key = FCNet([feat_dim, self.dim[1]], None, dropout[0]) 50 | 51 | self.linear_out_ = weight_norm( 52 | nn.Conv2d(in_channels=self.fc_dim * feat_dim, 53 | out_channels=self.dim[2], 54 | kernel_size=(1, 1), 55 | groups=self.fc_dim), dim=None) 56 | 57 | def forward(self, roi_feat, adj_matrix, 58 | position_embedding, label_biases_att): 59 | """ 60 | Args: 61 | roi_feat: [batch_size, N, feat_dim] 62 | adj_matrix: [batch_size, N, nongt_dim] 63 | position_embedding: [num_rois, nongt_dim, pos_emb_dim] 64 | Returns: 65 | output: [batch_size, num_rois, ovr_feat_dim, output_dim] 66 | """ 67 | batch_size = roi_feat.size(0) 68 | num_rois = roi_feat.size(1) 69 | nongt_dim = self.nongt_dim if self.nongt_dim < num_rois else num_rois 70 | # [batch_size,nongt_dim, feat_dim] 71 | nongt_roi_feat = roi_feat[:, :nongt_dim, :] 72 | 73 | # [batch_size,num_rois, self.dim[0] = feat_dim] 74 | q_data = self.query(roi_feat) 75 | 76 | # [batch_size,num_rois, num_heads, feat_dim /num_heads] 77 | q_data_batch = q_data.view(batch_size, num_rois, self.num_heads, 78 | self.dim_group[0]) 79 | 80 | # [batch_size,num_heads, num_rois, feat_dim /num_heads] 81 | q_data_batch = torch.transpose(q_data_batch, 1, 2) 82 | 83 | # [batch_size,nongt_dim, self.dim[1] = feat_dim] 84 | k_data = self.key(nongt_roi_feat) 85 | 86 | # [batch_size,nongt_dim, num_heads, feat_dim /num_heads] 87 | k_data_batch = k_data.view(batch_size, nongt_dim, self.num_heads, 88 | self.dim_group[1]) 89 | 90 | # [batch_size,num_heads, nongt_dim, feat_dim /num_heads] 91 | k_data_batch = torch.transpose(k_data_batch, 1, 2) 92 | 93 | # [batch_size,nongt_dim, feat_dim] 94 | v_data = nongt_roi_feat 95 | 96 | # [batch_size, num_heads, num_rois, nongt_dim] 97 | aff = torch.matmul(q_data_batch, torch.transpose(k_data_batch, 2, 3)) 98 | 99 | # aff_scale, [batch_size, num_heads, num_rois, nongt_dim] 100 | aff_scale = (1.0 / math.sqrt(float(self.dim_group[1]))) * aff 101 | # aff_scale, [batch_size,num_rois,num_heads, nongt_dim] 102 | aff_scale = torch.transpose(aff_scale, 1, 2) 103 | weighted_aff = aff_scale 104 | 105 | if position_embedding is not None and self.pos_emb_dim > 0: 106 | # Adding goemetric features 107 | 
position_embedding = position_embedding.float() 108 | # [batch_size,num_rois * nongt_dim, emb_dim] 109 | position_embedding_reshape = position_embedding.view( 110 | (batch_size, -1, self.pos_emb_dim)) 111 | 112 | # position_feat_1, [batch_size,num_rois * nongt_dim, fc_dim] 113 | position_feat_1 = self.pair_pos_fc1(position_embedding_reshape) 114 | position_feat_1_relu = nn.functional.relu(position_feat_1) 115 | 116 | # aff_weight, [batch_size,num_rois, nongt_dim, fc_dim] 117 | aff_weight = position_feat_1_relu.view( 118 | (batch_size, -1, nongt_dim, self.fc_dim)) 119 | 120 | # aff_weight, [batch_size,num_rois, fc_dim, nongt_dim] 121 | aff_weight = torch.transpose(aff_weight, 2, 3) 122 | 123 | thresh = torch.FloatTensor([1e-6]).cuda() 124 | # weighted_aff, [batch_size,num_rois, fc_dim, nongt_dim] 125 | threshold_aff = torch.max(aff_weight, thresh) 126 | 127 | weighted_aff += torch.log(threshold_aff) 128 | 129 | if adj_matrix is not None: 130 | # weighted_aff_transposed, [batch_size,num_rois, nongt_dim, num_heads] 131 | weighted_aff_transposed = torch.transpose(weighted_aff, 2, 3) 132 | zero_vec = -9e15*torch.ones_like(weighted_aff_transposed) 133 | 134 | adj_matrix = adj_matrix.view( 135 | adj_matrix.shape[0], adj_matrix.shape[1], 136 | adj_matrix.shape[2], 1) 137 | adj_matrix_expand = adj_matrix.expand( 138 | (-1, -1, -1, 139 | weighted_aff_transposed.shape[-1])) 140 | weighted_aff_masked = torch.where(adj_matrix_expand > 0, 141 | weighted_aff_transposed, 142 | zero_vec) 143 | 144 | weighted_aff_masked = weighted_aff_masked + \ 145 | label_biases_att.unsqueeze(3) 146 | weighted_aff = torch.transpose(weighted_aff_masked, 2, 3) 147 | 148 | # aff_softmax, [batch_size, num_rois, fc_dim, nongt_dim] 149 | aff_softmax = nn.functional.softmax(weighted_aff, 3) 150 | 151 | # aff_softmax_reshape, [batch_size, num_rois*fc_dim, nongt_dim] 152 | aff_softmax_reshape = aff_softmax.view((batch_size, -1, nongt_dim)) 153 | 154 | # output_t, [batch_size, num_rois * fc_dim, feat_dim] 155 | output_t = torch.matmul(aff_softmax_reshape, v_data) 156 | 157 | # output_t, [batch_size*num_rois, fc_dim * feat_dim, 1, 1] 158 | output_t = output_t.view((-1, self.fc_dim * self.feat_dim, 1, 1)) 159 | 160 | # linear_out, [batch_size*num_rois, dim[2], 1, 1] 161 | linear_out = self.linear_out_(output_t) 162 | output = linear_out.view((batch_size, num_rois, self.dim[2])) 163 | return output 164 | -------------------------------------------------------------------------------- /model/language_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | import numpy as np 13 | from model.fc import FCNet 14 | import torch.nn.functional as F 15 | 16 | 17 | class WordEmbedding(nn.Module): 18 | """Word Embedding 19 | 20 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 21 | with the definition in Dictionary. 
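    (editor's note, added) When op contains 'c', a second embedding table
    emb_ is kept and concatenated with emb in forward(), doubling the output
    width; this is why build_regat passes 600 instead of 300 to
    QuestionEmbedding when 'c' is in args.op.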
22 | """ 23 | def __init__(self, ntoken, emb_dim, dropout, op=''): 24 | super(WordEmbedding, self).__init__() 25 | self.op = op 26 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 27 | if 'c' in op: 28 | self.emb_ = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 29 | self.emb_.weight.requires_grad = False # fixed 30 | self.dropout = nn.Dropout(dropout) 31 | self.ntoken = ntoken 32 | self.emb_dim = emb_dim 33 | 34 | def init_embedding(self, np_file, tfidf=None, tfidf_weights=None): 35 | weight_init = torch.from_numpy(np.load(np_file)) 36 | assert weight_init.shape == (self.ntoken, self.emb_dim) 37 | self.emb.weight.data[:self.ntoken] = weight_init 38 | if tfidf is not None: 39 | if 0 < tfidf_weights.size: 40 | weight_init = torch.cat([weight_init, 41 | torch.from_numpy(tfidf_weights)], 0) 42 | weight_init = tfidf.matmul(weight_init) # (N x N') x (N', F) 43 | if 'c' in self.op: 44 | self.emb_.weight.requires_grad = True 45 | if 'c' in self.op: 46 | self.emb_.weight.data[:self.ntoken] = weight_init.clone() 47 | 48 | def forward(self, x): 49 | emb = self.emb(x) 50 | if 'c' in self.op: 51 | emb = torch.cat((emb, self.emb_(x)), 2) 52 | emb = self.dropout(emb) 53 | return emb 54 | 55 | 56 | class QuestionEmbedding(nn.Module): 57 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, 58 | rnn_type='GRU'): 59 | """Module for question embedding 60 | """ 61 | super(QuestionEmbedding, self).__init__() 62 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 63 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU \ 64 | if rnn_type == 'GRU' else None 65 | 66 | self.rnn = rnn_cls( 67 | in_dim, num_hid, nlayers, 68 | bidirectional=bidirect, 69 | dropout=dropout, 70 | batch_first=True) 71 | 72 | self.in_dim = in_dim 73 | self.num_hid = num_hid 74 | self.nlayers = nlayers 75 | self.rnn_type = rnn_type 76 | self.ndirections = 1 + int(bidirect) 77 | 78 | def init_hidden(self, batch): 79 | # just to get the type of tensor 80 | weight = next(self.parameters()).data 81 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 82 | if self.rnn_type == 'LSTM': 83 | return (weight.new(*hid_shape).zero_(), 84 | weight.new(*hid_shape).zero_()) 85 | else: 86 | return weight.new(*hid_shape).zero_() 87 | 88 | def forward(self, x): 89 | # x: [batch, sequence, in_dim] 90 | batch = x.size(0) 91 | hidden = self.init_hidden(batch) 92 | self.rnn.flatten_parameters() 93 | output, hidden = self.rnn(x, hidden) 94 | 95 | if self.ndirections == 1: 96 | return output[:, -1] 97 | 98 | forward_ = output[:, -1, :self.num_hid] 99 | backward = output[:, 0, self.num_hid:] 100 | return torch.cat((forward_, backward), dim=1) 101 | 102 | def forward_all(self, x): 103 | # x: [batch, sequence, in_dim] 104 | batch = x.size(0) 105 | hidden = self.init_hidden(batch) 106 | self.rnn.flatten_parameters() 107 | output, hidden = self.rnn(x, hidden) 108 | return output 109 | 110 | 111 | class QuestionSelfAttention(nn.Module): 112 | def __init__(self, num_hid, dropout): 113 | super(QuestionSelfAttention, self).__init__() 114 | self.num_hid = num_hid 115 | self.drop = nn.Dropout(dropout) 116 | self.W1_self_att_q = FCNet(dims=[num_hid, num_hid], dropout=dropout, 117 | act=None) 118 | self.W2_self_att_q = FCNet(dims=[num_hid, 1], act=None) 119 | 120 | def forward(self, ques_feat): 121 | ''' 122 | ques_feat: [batch, 14, num_hid] 123 | ''' 124 | batch_size = ques_feat.shape[0] 125 | q_len = ques_feat.shape[1] 126 | 127 | # (batch*14,num_hid) 128 | ques_feat_reshape = ques_feat.contiguous().view(-1, self.num_hid) 129 | 
# (batch, 14) 130 | atten_1 = self.W1_self_att_q(ques_feat_reshape) 131 | atten_1 = torch.tanh(atten_1) 132 | atten = self.W2_self_att_q(atten_1).view(batch_size, q_len) 133 | # (batch, 1, 14) 134 | weight = F.softmax(atten.t(), dim=1).view(-1, 1, q_len) 135 | ques_feat_self_att = torch.bmm(weight, ques_feat) 136 | ques_feat_self_att = ques_feat_self_att.view(-1, self.num_hid) 137 | # (batch, num_hid) 138 | ques_feat_self_att = self.drop(ques_feat_self_att) 139 | return ques_feat_self_att 140 | -------------------------------------------------------------------------------- /model/position_emb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 10 | """ 11 | import numpy as np 12 | import math 13 | import torch 14 | from torch.autograd import Variable 15 | 16 | 17 | def bb_intersection_over_union(boxA, boxB): 18 | # determine the (x, y)-coordinates of the intersection rectangle 19 | xA = max(boxA[0], boxB[0]) 20 | yA = max(boxA[1], boxB[1]) 21 | xB = min(boxA[2], boxB[2]) 22 | yB = min(boxA[3], boxB[3]) 23 | 24 | # compute the area of intersection rectangle 25 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 26 | 27 | # compute the area of both the prediction and ground-truth 28 | # rectangles 29 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 30 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 31 | 32 | # compute the intersection over union by taking the intersection 33 | # area and dividing it by the sum of prediction + ground-truth 34 | # areas - the interesection area 35 | iou = interArea / float(boxAArea + boxBArea - interArea) 36 | 37 | # return the intersection over union value 38 | return iou 39 | 40 | 41 | def build_graph(bbox, spatial, label_num=11): 42 | """ Build spatial graph 43 | 44 | Args: 45 | bbox: [num_boxes, 4] 46 | 47 | Returns: 48 | adj_matrix: [num_boxes, num_boxes, label_num] 49 | """ 50 | 51 | num_box = bbox.shape[0] 52 | adj_matrix = np.zeros((num_box, num_box)) 53 | xmin, ymin, xmax, ymax = np.split(bbox, 4, axis=1) 54 | # [num_boxes, 1] 55 | bbox_width = xmax - xmin + 1. 56 | bbox_height = ymax - ymin + 1. 
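    # (editor's note, added) edge-label scheme produced below (11 labels):
    #   1 / 2 : box j is inside box i / box j covers box i (and vice versa)
    #   3     : the two boxes overlap with IoU >= 0.5
    #   4-11  : relative direction, one label per 45-degree sector, assigned
    #           only when the centre distance is below 0.5 * image diagonal
    # Diagonal entries are set to 12 (self-loops) and are ignored later by
    # torch_broadcast_adj_matrix, which only broadcasts labels 1..label_num.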
57 | image_h = bbox_height[0]/spatial[0, -1] 58 | image_w = bbox_width[0]/spatial[0, -2] 59 | center_x = 0.5 * (xmin + xmax) 60 | center_y = 0.5 * (ymin + ymax) 61 | image_diag = math.sqrt(image_h**2 + image_w**2) 62 | for i in range(num_box): 63 | bbA = bbox[i] 64 | if sum(bbA) == 0: 65 | continue 66 | adj_matrix[i, i] = 12 67 | for j in range(i+1, num_box): 68 | bbB = bbox[j] 69 | if sum(bbB) == 0: 70 | continue 71 | # class 1: inside (j inside i) 72 | if xmin[i] < xmin[j] and xmax[i] > xmax[j] and \ 73 | ymin[i] < ymin[j] and ymax[i] > ymax[j]: 74 | adj_matrix[i, j] = 1 75 | adj_matrix[j, i] = 2 76 | # class 2: cover (j covers i) 77 | elif (xmin[j] < xmin[i] and xmax[j] > xmax[i] and 78 | ymin[j] < ymin[i] and ymax[j] > ymax[i]): 79 | adj_matrix[i, j] = 2 80 | adj_matrix[j, i] = 1 81 | else: 82 | ioU = bb_intersection_over_union(bbA, bbB) 83 | # class 3: i and j overlap 84 | if ioU >= 0.5: 85 | adj_matrix[i, j] = 3 86 | adj_matrix[j, i] = 3 87 | else: 88 | y_diff = center_y[i] - center_y[j] 89 | x_diff = center_x[i] - center_x[j] 90 | diag = math.sqrt((y_diff)**2 + (x_diff)**2) 91 | if diag < 0.5 * image_diag: 92 | sin_ij = y_diff/diag 93 | cos_ij = x_diff/diag 94 | if sin_ij >= 0 and cos_ij >= 0: 95 | label_i = np.arcsin(sin_ij) 96 | label_j = 2*math.pi - label_i 97 | elif sin_ij < 0 and cos_ij >= 0: 98 | label_i = np.arcsin(sin_ij)+2*math.pi 99 | label_j = label_i - math.pi 100 | elif sin_ij >= 0 and cos_ij < 0: 101 | label_i = np.arccos(cos_ij) 102 | label_j = 2*math.pi - label_i 103 | else: 104 | label_i = -np.arccos(sin_ij)+2*math.pi 105 | label_j = label_i - math.pi 106 | adj_matrix[i, j] = int(np.ceil(label_i/(math.pi/4)))+3 107 | adj_matrix[j, i] = int(np.ceil(label_j/(math.pi/4)))+3 108 | return adj_matrix 109 | 110 | 111 | def torch_broadcast_adj_matrix(adj_matrix, label_num=11, 112 | device=torch.device("cuda")): 113 | """ broudcast spatial relation graph 114 | 115 | Args: 116 | adj_matrix: [batch_size,num_boxes, num_boxes] 117 | 118 | Returns: 119 | result: [batch_size,num_boxes, num_boxes, label_num] 120 | """ 121 | result = [] 122 | for i in range(1, label_num+1): 123 | index = torch.nonzero((adj_matrix == i).view(-1).data).squeeze() 124 | curr_result = torch.zeros( 125 | adj_matrix.shape[0], adj_matrix.shape[1], adj_matrix.shape[2]) 126 | curr_result = curr_result.view(-1) 127 | curr_result[index] += 1 128 | result.append(curr_result.view( 129 | (adj_matrix.shape[0], adj_matrix.shape[1], 130 | adj_matrix.shape[2], 1))) 131 | result = torch.cat(result, dim=3) 132 | return result 133 | 134 | 135 | def torch_extract_position_embedding(position_mat, feat_dim, wave_length=1000, 136 | device=torch.device("cuda")): 137 | # position_mat, [batch_size,num_rois, nongt_dim, 4] 138 | feat_range = torch.arange(0, feat_dim / 8) 139 | dim_mat = torch.pow(torch.ones((1,))*wave_length, 140 | (8. 
/ feat_dim) * feat_range) 141 | dim_mat = dim_mat.view(1, 1, 1, -1).to(device) 142 | position_mat = torch.unsqueeze(100.0 * position_mat, dim=4) 143 | div_mat = torch.div(position_mat.to(device), dim_mat) 144 | sin_mat = torch.sin(div_mat) 145 | cos_mat = torch.cos(div_mat) 146 | # embedding, [batch_size,num_rois, nongt_dim, 4, feat_dim/4] 147 | embedding = torch.cat([sin_mat, cos_mat], -1) 148 | # embedding, [batch_size,num_rois, nongt_dim, feat_dim] 149 | embedding = embedding.view(embedding.shape[0], embedding.shape[1], 150 | embedding.shape[2], feat_dim) 151 | return embedding 152 | 153 | 154 | def torch_extract_position_matrix(bbox, nongt_dim=36): 155 | """ Extract position matrix 156 | 157 | Args: 158 | bbox: [batch_size, num_boxes, 4] 159 | 160 | Returns: 161 | position_matrix: [batch_size, num_boxes, nongt_dim, 4] 162 | """ 163 | 164 | xmin, ymin, xmax, ymax = torch.split(bbox, 1, dim=-1) 165 | # [batch_size,num_boxes, 1] 166 | bbox_width = xmax - xmin + 1. 167 | bbox_height = ymax - ymin + 1. 168 | center_x = 0.5 * (xmin + xmax) 169 | center_y = 0.5 * (ymin + ymax) 170 | # [batch_size,num_boxes, num_boxes] 171 | delta_x = center_x-torch.transpose(center_x, 1, 2) 172 | delta_x = torch.div(delta_x, bbox_width) 173 | 174 | delta_x = torch.abs(delta_x) 175 | threshold = 1e-3 176 | delta_x[delta_x < threshold] = threshold 177 | delta_x = torch.log(delta_x) 178 | delta_y = center_y-torch.transpose(center_y, 1, 2) 179 | delta_y = torch.div(delta_y, bbox_height) 180 | delta_y = torch.abs(delta_y) 181 | delta_y[delta_y < threshold] = threshold 182 | delta_y = torch.log(delta_y) 183 | delta_width = torch.div(bbox_width, torch.transpose(bbox_width, 1, 2)) 184 | delta_width = torch.log(delta_width) 185 | delta_height = torch.div(bbox_height, torch.transpose(bbox_height, 1, 2)) 186 | delta_height = torch.log(delta_height) 187 | concat_list = [delta_x, delta_y, delta_width, delta_height] 188 | for idx, sym in enumerate(concat_list): 189 | sym = sym[:, :nongt_dim] 190 | concat_list[idx] = torch.unsqueeze(sym, dim=3) 191 | position_matrix = torch.cat(concat_list, 3) 192 | return position_matrix 193 | 194 | 195 | def prepare_graph_variables(relation_type, bb, sem_adj_matrix, spa_adj_matrix, 196 | num_objects, nongt_dim, pos_emb_dim, spa_label_num, 197 | sem_label_num, device): 198 | 199 | pos_emb_var, sem_adj_matrix_var, spa_adj_matrix_var = None, None, None 200 | if relation_type == "spatial": 201 | assert spa_adj_matrix.dim() > 2, "Found spa_adj_matrix of wrong shape" 202 | spa_adj_matrix = spa_adj_matrix.to(device) 203 | spa_adj_matrix = spa_adj_matrix[:, :num_objects, :num_objects] 204 | spa_adj_matrix = torch_broadcast_adj_matrix( 205 | spa_adj_matrix, label_num=spa_label_num, device=device) 206 | spa_adj_matrix_var = Variable(spa_adj_matrix).to(device) 207 | if relation_type == "semantic": 208 | assert sem_adj_matrix.dim() > 2, "Found sem_adj_matrix of wrong shape" 209 | sem_adj_matrix = sem_adj_matrix.to(device) 210 | sem_adj_matrix = sem_adj_matrix[:, :num_objects, :num_objects] 211 | sem_adj_matrix = torch_broadcast_adj_matrix( 212 | sem_adj_matrix, label_num=sem_label_num, device=device) 213 | sem_adj_matrix_var = Variable(sem_adj_matrix).to(device) 214 | else: 215 | bb = bb.to(device) 216 | pos_mat = torch_extract_position_matrix(bb, nongt_dim=nongt_dim) 217 | pos_emb = torch_extract_position_embedding( 218 | pos_mat, feat_dim=pos_emb_dim, device=device) 219 | pos_emb_var = Variable(pos_emb).to(device) 220 | return pos_emb_var, sem_adj_matrix_var, spa_adj_matrix_var 221 | 
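A minimal usage sketch for the two implicit-relation helpers above (not part of the repository; it assumes the repo root is on `PYTHONPATH`, runs on CPU for illustration, and uses feat_dim=64 to match the --imp_pos_emb_dim default):

```python
import torch
from model.position_emb import (torch_extract_position_matrix,
                                torch_extract_position_embedding)

device = torch.device("cpu")

# two images with 20 boxes each, boxes given as (xmin, ymin, xmax, ymax)
bbox = torch.rand(2, 20, 4)
bbox[:, :, 2:] += bbox[:, :, :2]   # ensure xmax >= xmin and ymax >= ymin

# pairwise log-scale offsets against the first nongt_dim boxes
pos_mat = torch_extract_position_matrix(bbox, nongt_dim=20)   # [2, 20, 20, 4]

# sinusoidal embedding of that geometry; feat_dim must be divisible by 8
pos_emb = torch_extract_position_embedding(pos_mat, feat_dim=64,
                                           device=device)     # [2, 20, 20, 64]
print(pos_emb.shape)
```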
-------------------------------------------------------------------------------- /model/regat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | from model.fusion import BAN, BUTD, MuTAN 14 | from model.language_model import WordEmbedding, QuestionEmbedding,\ 15 | QuestionSelfAttention 16 | from model.relation_encoder import ImplicitRelationEncoder,\ 17 | ExplicitRelationEncoder 18 | from model.classifier import SimpleClassifier 19 | 20 | 21 | class ReGAT(nn.Module): 22 | def __init__(self, dataset, w_emb, q_emb, q_att, v_relation, 23 | joint_embedding, classifier, glimpse, fusion, relation_type): 24 | super(ReGAT, self).__init__() 25 | self.name = "ReGAT_%s_%s" % (relation_type, fusion) 26 | self.relation_type = relation_type 27 | self.fusion = fusion 28 | self.dataset = dataset 29 | self.glimpse = glimpse 30 | self.w_emb = w_emb 31 | self.q_emb = q_emb 32 | self.q_att = q_att 33 | self.v_relation = v_relation 34 | self.joint_embedding = joint_embedding 35 | self.classifier = classifier 36 | 37 | def forward(self, v, b, q, implicit_pos_emb, sem_adj_matrix, 38 | spa_adj_matrix, labels): 39 | """Forward 40 | v: [batch, num_objs, obj_dim] 41 | b: [batch, num_objs, b_dim] 42 | q: [batch_size, seq_length] 43 | pos: [batch_size, num_objs, nongt_dim, emb_dim] 44 | sem_adj_matrix: [batch_size, num_objs, num_objs, num_edge_labels] 45 | spa_adj_matrix: [batch_size, num_objs, num_objs, num_edge_labels] 46 | 47 | return: logits, not probs 48 | """ 49 | w_emb = self.w_emb(q) 50 | q_emb_seq = self.q_emb.forward_all(w_emb) # [batch, q_len, q_dim] 51 | q_emb_self_att = self.q_att(q_emb_seq) 52 | 53 | # [batch_size, num_rois, out_dim] 54 | if self.relation_type == "semantic": 55 | v_emb = self.v_relation.forward(v, sem_adj_matrix, q_emb_self_att) 56 | elif self.relation_type == "spatial": 57 | v_emb = self.v_relation.forward(v, spa_adj_matrix, q_emb_self_att) 58 | else: # implicit 59 | v_emb = self.v_relation.forward(v, implicit_pos_emb, 60 | q_emb_self_att) 61 | 62 | if self.fusion == "ban": 63 | joint_emb, att = self.joint_embedding(v_emb, q_emb_seq, b) 64 | elif self.fusion == "butd": 65 | q_emb = self.q_emb(w_emb) # [batch, q_dim] 66 | joint_emb, att = self.joint_embedding(v_emb, q_emb) 67 | else: # mutan 68 | joint_emb, att = self.joint_embedding(v_emb, q_emb_self_att) 69 | if self.classifier: 70 | logits = self.classifier(joint_emb) 71 | else: 72 | logits = joint_emb 73 | return logits, att 74 | 75 | 76 | def build_regat(dataset, args): 77 | print("Building ReGAT model with %s relation and %s fusion method" % 78 | (args.relation_type, args.fusion)) 79 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, .0, args.op) 80 | q_emb = QuestionEmbedding(300 if 'c' not in args.op else 600, 81 | args.num_hid, 1, False, .0) 82 | q_att = QuestionSelfAttention(args.num_hid, .2) 83 | 84 | if args.relation_type == "semantic": 85 | v_relation = ExplicitRelationEncoder( 86 | dataset.v_dim, args.num_hid, args.relation_dim, 87 | args.dir_num, args.sem_label_num, 88 | num_heads=args.num_heads, 89 | num_steps=args.num_steps, nongt_dim=args.nongt_dim, 90 | residual_connection=args.residual_connection, 91 | label_bias=args.label_bias) 92 | 
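    # (editor's note, added) the "spatial" branch below reuses the same
    # ExplicitRelationEncoder; it differs from the "semantic" branch only in
    # the number of edge labels (args.spa_label_num, default 11, versus
    # args.sem_label_num, default 15).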
elif args.relation_type == "spatial": 93 | v_relation = ExplicitRelationEncoder( 94 | dataset.v_dim, args.num_hid, args.relation_dim, 95 | args.dir_num, args.spa_label_num, 96 | num_heads=args.num_heads, 97 | num_steps=args.num_steps, nongt_dim=args.nongt_dim, 98 | residual_connection=args.residual_connection, 99 | label_bias=args.label_bias) 100 | else: 101 | v_relation = ImplicitRelationEncoder( 102 | dataset.v_dim, args.num_hid, args.relation_dim, 103 | args.dir_num, args.imp_pos_emb_dim, args.nongt_dim, 104 | num_heads=args.num_heads, num_steps=args.num_steps, 105 | residual_connection=args.residual_connection, 106 | label_bias=args.label_bias) 107 | 108 | classifier = SimpleClassifier(args.num_hid, args.num_hid * 2, 109 | dataset.num_ans_candidates, 0.5) 110 | gamma = 0 111 | if args.fusion == "ban": 112 | joint_embedding = BAN(args.relation_dim, args.num_hid, args.ban_gamma) 113 | gamma = args.ban_gamma 114 | elif args.fusion == "butd": 115 | joint_embedding = BUTD(args.relation_dim, args.num_hid, args.num_hid) 116 | else: 117 | joint_embedding = MuTAN(args.relation_dim, args.num_hid, 118 | dataset.num_ans_candidates, args.mutan_gamma) 119 | gamma = args.mutan_gamma 120 | classifier = None 121 | return ReGAT(dataset, w_emb, q_emb, q_att, v_relation, joint_embedding, 122 | classifier, gamma, args.fusion, args.relation_type) 123 | -------------------------------------------------------------------------------- /model/relation_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Relation-aware Graph Attention Network for Visual Question Answering 6 | Linjie Li, Zhe Gan, Yu Cheng, Jingjing Liu 7 | https://arxiv.org/abs/1903.12314 8 | 9 | This code is written by Linjie Li. 
10 | """ 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | from model.graph_att import GAttNet as GAT 15 | from model.language_model import QuestionSelfAttention 16 | from model.fc import FCNet 17 | 18 | 19 | def q_expand_v_cat(q, v, mask=True): 20 | q = q.view(q.size(0), 1, q.size(1)) 21 | repeat_vals = (-1, v.shape[1], -1) 22 | q_expand = q.expand(*repeat_vals) 23 | if mask: 24 | v_sum = v.sum(-1) 25 | mask_index = torch.nonzero(v_sum == 0) 26 | if mask_index.dim() > 1: 27 | q_expand[mask_index[:, 0], mask_index[:, 1]] = 0 28 | v_cat_q = torch.cat((v, q_expand), dim=-1) 29 | return v_cat_q 30 | 31 | 32 | class ImplicitRelationEncoder(nn.Module): 33 | def __init__(self, v_dim, q_dim, out_dim, dir_num, pos_emb_dim, 34 | nongt_dim, num_heads=16, num_steps=1, 35 | residual_connection=True, label_bias=True): 36 | super(ImplicitRelationEncoder, self).__init__() 37 | self.v_dim = v_dim 38 | self.q_dim = q_dim 39 | self.out_dim = out_dim 40 | self.residual_connection = residual_connection 41 | self.num_steps = num_steps 42 | print("In ImplicitRelationEncoder, num of graph propogate steps:", 43 | "%d, residual_connection: %s" % (self.num_steps, 44 | self.residual_connection)) 45 | 46 | if self.v_dim != self.out_dim: 47 | self.v_transform = FCNet([v_dim, out_dim]) 48 | else: 49 | self.v_transform = None 50 | in_dim = out_dim+q_dim 51 | self.implicit_relation = GAT(dir_num, 1, in_dim, out_dim, 52 | nongt_dim=nongt_dim, 53 | label_bias=label_bias, 54 | num_heads=num_heads, 55 | pos_emb_dim=pos_emb_dim) 56 | 57 | def forward(self, v, position_embedding, q): 58 | """ 59 | Args: 60 | v: [batch_size, num_rois, v_dim] 61 | q: [batch_size, q_dim] 62 | position_embedding: [batch_size, num_rois, nongt_dim, emb_dim] 63 | 64 | Returns: 65 | output: [batch_size, num_rois, out_dim,3] 66 | """ 67 | # [batch_size, num_rois, num_rois, 1] 68 | imp_adj_mat = Variable( 69 | torch.ones( 70 | v.size(0), v.size(1), v.size(1), 1)).to(v.device) 71 | imp_v = self.v_transform(v) if self.v_transform else v 72 | 73 | for i in range(self.num_steps): 74 | v_cat_q = q_expand_v_cat(q, imp_v, mask=True) 75 | imp_v_rel = self.implicit_relation.forward(v_cat_q, 76 | imp_adj_mat, 77 | position_embedding) 78 | if self.residual_connection: 79 | imp_v += imp_v_rel 80 | else: 81 | imp_v = imp_v_rel 82 | return imp_v 83 | 84 | 85 | class ExplicitRelationEncoder(nn.Module): 86 | def __init__(self, v_dim, q_dim, out_dim, dir_num, label_num, 87 | nongt_dim=20, num_heads=16, num_steps=1, 88 | residual_connection=True, label_bias=True): 89 | super(ExplicitRelationEncoder, self).__init__() 90 | self.v_dim = v_dim 91 | self.q_dim = q_dim 92 | self.out_dim = out_dim 93 | self.num_steps = num_steps 94 | self.residual_connection = residual_connection 95 | print("In ExplicitRelationEncoder, num of graph propogation steps:", 96 | "%d, residual_connection: %s" % (self.num_steps, 97 | self.residual_connection)) 98 | 99 | if self.v_dim != self.out_dim: 100 | self.v_transform = FCNet([v_dim, out_dim]) 101 | else: 102 | self.v_transform = None 103 | in_dim = out_dim+q_dim 104 | self.explicit_relation = GAT(dir_num, label_num, in_dim, out_dim, 105 | nongt_dim=nongt_dim, 106 | num_heads=num_heads, 107 | label_bias=label_bias, 108 | pos_emb_dim=-1) 109 | 110 | def forward(self, v, exp_adj_matrix, q): 111 | """ 112 | Args: 113 | v: [batch_size, num_rois, v_dim] 114 | q: [batch_size, q_dim] 115 | exp_adj_matrix: [batch_size, num_rois, num_rois, num_labels] 116 | 117 | Returns: 118 | output: [batch_size, num_rois, 
out_dim] 119 | """ 120 | exp_v = self.v_transform(v) if self.v_transform else v 121 | 122 | for i in range(self.num_steps): 123 | v_cat_q = q_expand_v_cat(q, exp_v, mask=True) 124 | exp_v_rel = self.explicit_relation.forward(v_cat_q, exp_adj_matrix) 125 | if self.residual_connection: 126 | exp_v += exp_v_rel 127 | else: 128 | exp_v = exp_v_rel 129 | return exp_v 130 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tools/compute_softscore.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | from __future__ import print_function 10 | import os 11 | import sys 12 | import json 13 | import numpy as np 14 | import re 15 | import pickle 16 | 17 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 18 | import utils 19 | 20 | 21 | contractions = { 22 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 23 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 24 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 25 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 26 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 27 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 28 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 29 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 30 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 31 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 32 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 33 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 34 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 35 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 36 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 37 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 38 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 39 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 40 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 41 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 42 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 43 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 44 | "someonell": "someone'll", "someones": "someone's", "somethingd": 45 | "something'd", "somethingd've": "something'd've", "something'dve": 46 | "something'd've", "somethingll": "something'll", "thats": 47 | "that's", "thered": "there'd", "thered've": "there'd've", 48 | "there'dve": "there'd've", "therere": "there're", "theres": 49 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 50 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 51 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 52 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 53 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 54 | "what's", "whatve": "what've", "whens": "when's", "whered": 55 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 56 | "who'd", 
"whod've": "who'd've", "who'dve": "who'd've", "wholl": 57 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 58 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 59 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 60 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 61 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 62 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 63 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 64 | "you'll", "youre": "you're", "youve": "you've" 65 | } 66 | 67 | manual_map = {'none': '0', 68 | 'zero': '0', 69 | 'one': '1', 70 | 'two': '2', 71 | 'three': '3', 72 | 'four': '4', 73 | 'five': '5', 74 | 'six': '6', 75 | 'seven': '7', 76 | 'eight': '8', 77 | 'nine': '9', 78 | 'ten': '10'} 79 | articles = ['a', 'an', 'the'] 80 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 81 | comma_strip = re.compile("(\d)(\,)(\d)") 82 | punct = [';', r"/", '[', ']', '"', '{', '}', 83 | '(', ')', '=', '+', '\\', '_', '-', 84 | '>', '<', '@', '`', ',', '?', '!'] 85 | 86 | 87 | # Notice that VQA score is the average of 10 choose 9 candidate answers cases 88 | # See http://visualqa.org/evaluation.html 89 | def get_score(occurences): 90 | if occurences == 0: 91 | return .0 92 | elif occurences == 1: 93 | return .3 94 | elif occurences == 2: 95 | return .6 96 | elif occurences == 3: 97 | return .9 98 | else: 99 | return 1. 100 | 101 | 102 | def process_punctuation(inText): 103 | outText = inText 104 | for p in punct: 105 | if (p + ' ' in inText or ' ' + p in inText) \ 106 | or (re.search(comma_strip, inText) != None): 107 | outText = outText.replace(p, '') 108 | else: 109 | outText = outText.replace(p, ' ') 110 | outText = period_strip.sub("", outText, re.UNICODE) 111 | return outText 112 | 113 | 114 | def process_digit_article(inText): 115 | outText = [] 116 | tempText = inText.lower().split() 117 | for word in tempText: 118 | word = manual_map.setdefault(word, word) 119 | if word not in articles: 120 | outText.append(word) 121 | else: 122 | pass 123 | for wordId, word in enumerate(outText): 124 | if word in contractions: 125 | outText[wordId] = contractions[word] 126 | outText = ' '.join(outText) 127 | return outText 128 | 129 | 130 | def multiple_replace(text, wordDict): 131 | for key in wordDict: 132 | text = text.replace(key, wordDict[key]) 133 | return text 134 | 135 | 136 | def preprocess_answer(answer): 137 | answer = process_digit_article(process_punctuation(answer)) 138 | answer = answer.replace(',', '') 139 | return answer 140 | 141 | 142 | def filter_answers(answers_dset, min_occurence): 143 | """This will change the answer to preprocessed version 144 | """ 145 | occurence = {} 146 | 147 | for ans_entry in answers_dset: 148 | answers = ans_entry['answers'] 149 | gtruth = ans_entry['multiple_choice_answer'] 150 | gtruth = preprocess_answer(gtruth) 151 | if gtruth not in occurence: 152 | occurence[gtruth] = set() 153 | occurence[gtruth].add(ans_entry['question_id']) 154 | for answer in list(occurence): 155 | if len(occurence[answer]) < min_occurence: 156 | occurence.pop(answer) 157 | 158 | print('Num of answers that appear >= %d times: %d' % ( 159 | min_occurence, len(occurence))) 160 | return occurence 161 | 162 | 163 | def create_ans2label(occurence, name, cache_root='data/cache'): 164 | """Note that this will also create label2ans.pkl at the same time 165 | 166 | occurence: dict {answer -> whatever} 167 | name: prefix of the output file 168 | cache_root: str 169 | """ 170 
| ans2label = {} 171 | label2ans = [] 172 | label = 0 173 | for answer in occurence: 174 | label2ans.append(answer) 175 | ans2label[answer] = label 176 | label += 1 177 | 178 | utils.create_dir(cache_root) 179 | 180 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 181 | pickle.dump(ans2label, open(cache_file, 'wb')) 182 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 183 | pickle.dump(label2ans, open(cache_file, 'wb')) 184 | return ans2label 185 | 186 | 187 | def compute_target(answers_dset, ans2label, name, cache_root='data/cache'): 188 | """Augment answers_dset with soft score as label 189 | 190 | ***answers_dset should be preprocessed*** 191 | 192 | Write result into a cache file 193 | """ 194 | target = [] 195 | for ans_entry in answers_dset: 196 | answers = ans_entry['answers'] 197 | answer_count = {} 198 | for answer in answers: 199 | answer_ = answer['answer'] 200 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 201 | 202 | labels = [] 203 | scores = [] 204 | for answer in answer_count: 205 | if answer not in ans2label: 206 | continue 207 | labels.append(ans2label[answer]) 208 | score = get_score(answer_count[answer]) 209 | scores.append(score) 210 | 211 | target.append({ 212 | 'question_id': ans_entry['question_id'], 213 | 'image_id': ans_entry['image_id'], 214 | 'labels': labels, 215 | 'scores': scores 216 | }) 217 | 218 | utils.create_dir(cache_root) 219 | cache_file = os.path.join(cache_root, name+'_target.pkl') 220 | pickle.dump(target, open(cache_file, 'wb')) 221 | return target 222 | 223 | 224 | def get_answer(qid, answers): 225 | for ans in answers: 226 | if ans['question_id'] == qid: 227 | return ans 228 | 229 | 230 | def get_question(qid, questions): 231 | for question in questions: 232 | if question['question_id'] == qid: 233 | return question 234 | 235 | 236 | if __name__ == '__main__': 237 | train_answer_file = 'data/Answers/v2_mscoco_train2014_annotations.json' 238 | train_answers = json.load(open(train_answer_file))['annotations'] 239 | 240 | val_answer_file = 'data/Answers/v2_mscoco_val2014_annotations.json' 241 | val_answers = json.load(open(val_answer_file))['annotations'] 242 | 243 | train_question_file = 'data/Questions/v2_OpenEnded_mscoco_train2014_questions.json' 244 | train_questions = json.load(open(train_question_file))['questions'] 245 | 246 | val_question_file = 'data/Questions/v2_OpenEnded_mscoco_val2014_questions.json' 247 | val_questions = json.load(open(val_question_file))['questions'] 248 | 249 | answers = train_answers + val_answers 250 | occurence = filter_answers(answers, 9) 251 | 252 | cache_path = 'data/cache/trainval_ans2label.pkl' 253 | if os.path.isfile(cache_path): 254 | print('found %s' % cache_path) 255 | ans2label = pickle.load(open(cache_path, 'rb')) 256 | else: 257 | ans2label = create_ans2label(occurence, 'trainval') 258 | compute_target(train_answers, ans2label, 'train') 259 | compute_target(val_answers, ans2label, 'val') 260 | 261 | train_answer_file = 'data/cp_v2_annotations/vqacp_v2_train_annotations.json' 262 | train_answers = json.load(open(train_answer_file)) 263 | 264 | test_answer_file = 'data/cp_v2_annotations/vqacp_v2_test_annotations.json' 265 | test_answers = json.load(open(test_answer_file)) 266 | 267 | train_question_file = 'data/cp_v2_questions/vqacp_v2_train_questions.json' 268 | train_questions = json.load(open(train_question_file)) 269 | 270 | test_question_file = 'data/cp_v2_questions/vqacp_v2_test_questions.json' 271 | test_questions = json.load(open(test_question_file)) 272 | 
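# VQA-CP v2: build the answer vocabulary with the same >= 9 occurrence filter,
# reusing the cached trainval ans2label when it already exists, then write
# soft-score targets for the cp_v2 train/test splits.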
273 | answers = train_answers + test_answers 274 | occurence = filter_answers(answers, 9) 275 | 276 | cache_path = 'data/cache/trainval_ans2label.pkl' 277 | if os.path.isfile(cache_path): 278 | print('found %s' % cache_path) 279 | ans2label = pickle.load(open(cache_path, 'rb')) 280 | else: 281 | ans2label = create_ans2label(occurence, 'trainval') 282 | 283 | compute_target(train_answers, ans2label, 'cp_v2_train') 284 | compute_target(test_answers, ans2label, 'cp_v2_test') 285 | -------------------------------------------------------------------------------- /tools/create_dictionary.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | GNU General Public License v3.0 5 | """ 6 | from __future__ import print_function 7 | import os 8 | import sys 9 | import json 10 | import numpy as np 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | from dataset import Dictionary 13 | 14 | 15 | def create_dictionary(dataroot): 16 | dictionary = Dictionary() 17 | questions = [] 18 | files = [ 19 | 'Questions/v2_OpenEnded_mscoco_train2014_questions.json', 20 | 'Questions/v2_OpenEnded_mscoco_val2014_questions.json', 21 | 'Questions/v2_OpenEnded_mscoco_test2015_questions.json', 22 | 'Questions/v2_OpenEnded_mscoco_test-dev2015_questions.json' 23 | ] 24 | for path in files: 25 | question_path = os.path.join(dataroot, path) 26 | qs = json.load(open(question_path))['questions'] 27 | for q in qs: 28 | dictionary.tokenize(q['question'], True) 29 | return dictionary 30 | 31 | 32 | def create_glove_embedding_init(idx2word, glove_file): 33 | word2emb = {} 34 | with open(glove_file, 'r') as f: 35 | entries = f.readlines() 36 | emb_dim = len(entries[0].split(' ')) - 1 37 | print('embedding dim is %d' % emb_dim) 38 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 39 | 40 | for entry in entries: 41 | vals = entry.split(' ') 42 | word = vals[0] 43 | vals = list(map(float, vals[1:])) 44 | word2emb[word] = np.array(vals) 45 | for idx, word in enumerate(idx2word): 46 | if word not in word2emb: 47 | continue 48 | weights[idx] = word2emb[word] 49 | return weights, word2emb 50 | 51 | 52 | if __name__ == '__main__': 53 | d = create_dictionary('data') 54 | d.dump_to_file('data/glove/dictionary.pkl') 55 | 56 | d = Dictionary.load_from_file('data/glove/dictionary.pkl') 57 | emb_dim = 300 58 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 59 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 60 | np.save('data/glove/glove6b_init_%dd.npy' % emb_dim, weights) 61 | -------------------------------------------------------------------------------- /tools/create_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | GNU General Public License v3.0 5 | """ 6 | from __future__ import print_function 7 | import os 8 | import sys 9 | import json 10 | import functools 11 | import operator 12 | import numpy as np 13 | import pickle 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | from dataset import Dictionary 16 | 17 | 18 | def create_glove_embedding_init(idx2word, glove_file): 19 | word2emb = {} 20 | with open(glove_file, 'r') as f: 21 | entries = f.readlines() 22 | emb_dim = len(entries[0].split(' ')) - 1 23 | print('embedding dim is %d' % emb_dim) 24 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 25 | 26 | for entry in entries: 27 | vals = entry.split(' ') 28 | word = vals[0] 29 | vals = list(map(float, vals[1:])) 30 | word2emb[word] = np.array(vals) 31 | count = 0 32 | for idx, word in enumerate(idx2word): 33 | if word not in word2emb: 34 | updates = 0 35 | for w in word.split(' '): 36 | if w not in word2emb: 37 | continue 38 | weights[idx] += word2emb[w] 39 | updates += 1 40 | if updates == 0: 41 | count += 1 42 | continue 43 | weights[idx] = word2emb[word] 44 | return weights, word2emb 45 | 46 | 47 | if __name__ == '__main__': 48 | emb_dims = [50, 100, 200, 300] 49 | weights = [0] * len(emb_dims) 50 | label2ans = pickle.load(open('data/cache/trainval_label2ans.pkl', 'rb')) 51 | 52 | for idx, emb_dim in enumerate(emb_dims): # available embedding sizes 53 | glove_file = './data/glove/glove.6B.%dd.txt' % emb_dim 54 | weights[idx], word2emb = create_glove_embedding_init( 55 | label2ans, glove_file) 56 | np.save('./data/glove6b_emb_%dd.npy' % functools.reduce( 57 | operator.add, emb_dims), np.hstack(weights)) 58 | -------------------------------------------------------------------------------- /tools/download.sh: -------------------------------------------------------------------------------- 1 | ## Copyright (c) Microsoft Corporation. 2 | ## Licensed under the MIT license. 3 | 4 | ## This code is modified by Linjie Li from Hengyuan Hu's repository. 
5 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 6 | ## GNU General Public License v3.0 7 | 8 | ## Script for downloading data 9 | 10 | # VQA Questions 11 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip 12 | unzip data/v2_Questions_Train_mscoco.zip -d data/Questions 13 | rm data/v2_Questions_Train_mscoco.zip 14 | 15 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip 16 | unzip data/v2_Questions_Val_mscoco.zip -d data/Questions 17 | rm data/v2_Questions_Val_mscoco.zip 18 | 19 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip 20 | unzip data/v2_Questions_Test_mscoco.zip -d data/Questions 21 | rm data/v2_Questions_Test_mscoco.zip 22 | 23 | # VQA Annotations 24 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip 25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 26 | unzip data/v2_Annotations_Train_mscoco.zip -d data/Answers 27 | rm data/v2_Annotations_Train_mscoco.zip 28 | unzip data/v2_Annotations_Val_mscoco.zip -d data/Answers 29 | rm data/v2_Annotations_Val_mscoco.zip 30 | 31 | # VQA cp-v2 Questions 32 | mkdir data/cp_v2_questions 33 | wget -P data/cp_v2_questions https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json 34 | wget -P data/cp_v2_questions https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json 35 | 36 | # VQA cp-v2 Annotations 37 | mkdir data/cp_v2_annotations 38 | wget -P data/cp_v2_annotations https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json 39 | wget -P data/cp_v2_annotations https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json 40 | 41 | # Visual Genome Annotations 42 | mkdir data/visualGenome 43 | wget -P data/visualGenome https://convaisharables.blob.core.windows.net/vqa-regat/data/visualGenome/image_data.json 44 | wget -P data/visualGenome https://convaisharables.blob.core.windows.net/vqa-regat/data/visualGenome/question_answers.json 45 | 46 | # GloVe Vectors and dictionary 47 | wget -P data https://convaisharables.blob.core.windows.net/vqa-regat/data/glove.zip 48 | unzip data/glove.zip -d data/glove 49 | rm data/glove.zip 50 | 51 | # Image Features 52 | # adaptive 53 | # WARNING: This may take a while 54 | mkdir data/Bottom-up-features-adaptive 55 | wget -P data/Bottom-up-features-adaptive https://convaisharables.blob.core.windows.net/vqa-regat/data/Bottom-up-features-adaptive/train.hdf5 56 | wget -P data/Bottom-up-features-adaptive https://convaisharables.blob.core.windows.net/vqa-regat/data/Bottom-up-features-adaptive/val.hdf5 57 | wget -P data/Bottom-up-features-adaptive https://convaisharables.blob.core.windows.net/vqa-regat/data/Bottom-up-features-adaptive/test2015.hdf5 58 | 59 | # fixed 60 | # WARNING: This may take a while 61 | mkdir data/Bottom-up-features-fixed 62 | wget -P data/Bottom-up-features-fixed https://convaisharables.blob.core.windows.net/vqa-regat/data/Bottom-up-features-fixed/train36.hdf5 63 | wget -P data/Bottom-up-features-fixed https://convaisharables.blob.core.windows.net/vqa-regat/data/Bottom-up-features-fixed/val36.hdf5 64 | wget -P data/Bottom-up-features-fixed https://convaisharables.blob.core.windows.net/vqa-regat/data/Bottom-up-features-fixed/test2015_36.hdf5 65 | 66 | # imgids 67 | wget -P data/ https://convaisharables.blob.core.windows.net/vqa-regat/data/imgids.zip 68 | unzip data/imgids.zip -d data/imgids 69 | rm data/imgids.zip 70 | 71 | # Download 
Pickle caches for the pretrained model 72 | # and extract pkl files under data/cache/. 73 | wget -P data https://convaisharables.blob.core.windows.net/vqa-regat/data/cache.zip 74 | unzip data/cache.zip -d data/cache 75 | rm data/cache.zip 76 | 77 | # Download pretrained models 78 | # and extract files under pretrained_models. 79 | wget https://convaisharables.blob.core.windows.net/vqa-regat/pretrained_models.zip 80 | unzip pretrained_models.zip -d pretrained_models/ 81 | rm pretrained_models.zip 82 | -------------------------------------------------------------------------------- /tools/environment.yml: -------------------------------------------------------------------------------- 1 | name: vqa 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2019.8.28=0 9 | - certifi=2019.9.11=py37_0 10 | - cffi=1.12.3=py37h2e261b9_0 11 | - cudatoolkit=10.0.130=0 12 | - intel-openmp=2019.4=243 13 | - libedit=3.1.20181209=hc058e9b_0 14 | - libffi=3.2.1=hd88cf55_4 15 | - libgcc-ng=9.1.0=hdf63c60_0 16 | - libgfortran-ng=7.3.0=hdf63c60_0 17 | - libstdcxx-ng=9.1.0=hdf63c60_0 18 | - mkl=2019.4=243 19 | - mkl-service=2.3.0=py37he904b0f_0 20 | - mkl_fft=1.0.14=py37ha843d7b_0 21 | - mkl_random=1.1.0=py37hd6b4f25_0 22 | - ncurses=6.1=he6710b0_1 23 | - ninja=1.9.0=py37hfd86e86_0 24 | - numpy=1.17.2=py37haad9e8e_0 25 | - numpy-base=1.17.2=py37hde5b4d6_0 26 | - openssl=1.1.1d=h7b6447c_2 27 | - pip=19.2.3=py37_0 28 | - pycparser=2.19=py37_0 29 | - python=3.7.4=h265db76_1 30 | - pytorch=1.0.1=py3.7_cuda10.0.130_cudnn7.4.2_2 31 | - readline=7.0=h7b6447c_5 32 | - setuptools=41.2.0=py37_0 33 | - six=1.12.0=py37_0 34 | - sqlite=3.29.0=h7b6447c_0 35 | - tk=8.6.8=hbc83047_0 36 | - wheel=0.33.6=py37_0 37 | - xz=5.2.4=h14c3975_4 38 | - zlib=1.2.11=h7b6447c_3 39 | - pip: 40 | - argparse==1.4.0 41 | - backcall==0.1.0 42 | - block-bootstrap-pytorch==0.1.5 43 | - bootstrap-pytorch==0.0.13 44 | - click==7.0 45 | - cycler==0.10.0 46 | - decorator==4.4.0 47 | - h5py==2.10.0 48 | - ipdb==0.12.2 49 | - ipython==7.8.0 50 | - ipython-genutils==0.2.0 51 | - jedi==0.15.1 52 | - kiwisolver==1.1.0 53 | - matplotlib==3.1.1 54 | - munch==2.3.2 55 | - opencv-python==4.1.1.26 56 | - pandas==0.25.1 57 | - parso==0.5.1 58 | - pexpect==4.7.0 59 | - pickleshare==0.7.5 60 | - pillow==6.2.0 61 | - plotly==4.1.1 62 | - pretrainedmodels==0.7.4 63 | - prompt-toolkit==2.0.9 64 | - protobuf==3.9.2 65 | - ptyprocess==0.6.0 66 | - pygments==2.4.2 67 | - pyparsing==2.4.2 68 | - python-dateutil==2.8.0 69 | - pytz==2019.2 70 | - pyyaml==5.1.2 71 | - retrying==1.3.3 72 | - scipy==1.3.1 73 | - seaborn==0.9.0 74 | - skipthoughts==0.0.1 75 | - tabulate==0.8.5 76 | - tensorboardx==1.8 77 | - torchvision==0.4.0 78 | - tqdm==4.36.1 79 | - traitlets==4.3.2 80 | - wcwidth==0.1.7 81 | 82 | -------------------------------------------------------------------------------- /tools/process.sh: -------------------------------------------------------------------------------- 1 | ## Copyright (c) Microsoft Corporation. 2 | ## Licensed under the MIT license. 3 | 4 | ## This code is modified by Linjie Li from Hengyuan Hu's repository. 
5 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 6 | ## GNU General Public License v3.0 7 | 8 | ## Process data 9 | 10 | python3 tools/create_dictionary.py 11 | python3 tools/compute_softscore.py 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | import os 10 | import time 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from tqdm import tqdm 17 | 18 | import utils 19 | from model.position_emb import prepare_graph_variables 20 | 21 | 22 | def instance_bce_with_logits(logits, labels, reduction='mean'): 23 | assert logits.dim() == 2 24 | loss = F.binary_cross_entropy_with_logits( 25 | logits, labels, reduction=reduction) 26 | if reduction == "mean": 27 | loss *= labels.size(1) 28 | return loss 29 | 30 | 31 | def compute_score_with_logits(logits, labels, device): 32 | # argmax 33 | logits = torch.max(logits, 1)[1].data 34 | logits = logits.view(-1, 1) 35 | one_hots = torch.zeros(*labels.size()).to(device) 36 | one_hots.scatter_(1, logits, 1) 37 | scores = (one_hots * labels) 38 | return scores 39 | 40 | 41 | def train(model, train_loader, eval_loader, args, device=torch.device("cuda")): 42 | N = len(train_loader.dataset) 43 | lr_default = args.base_lr 44 | num_epochs = args.epochs 45 | lr_decay_epochs = range(args.lr_decay_start, num_epochs, 46 | args.lr_decay_step) 47 | gradual_warmup_steps = [0.5 * lr_default, 1.0 * lr_default, 48 | 1.5 * lr_default, 2.0 * lr_default] 49 | 50 | optim = torch.optim.Adamax(filter(lambda p: p.requires_grad, 51 | model.parameters()), 52 | lr=lr_default, betas=(0.9, 0.999), eps=1e-8, 53 | weight_decay=args.weight_decay) 54 | 55 | logger = utils.Logger(os.path.join(args.output, 'log.txt')) 56 | best_eval_score = 0 57 | 58 | utils.print_model(model, logger) 59 | logger.write('optim: adamax lr=%.4f, decay_step=%d, decay_rate=%.2f,' 60 | % (lr_default, args.lr_decay_step, 61 | args.lr_decay_rate) + 'grad_clip=%.2f' % args.grad_clip) 62 | logger.write('LR decay epochs: '+','.join( 63 | [str(i) for i in lr_decay_epochs])) 64 | last_eval_score, eval_score = 0, 0 65 | relation_type = train_loader.dataset.relation_type 66 | 67 | for epoch in range(0, num_epochs): 68 | pbar = tqdm(total=len(train_loader)) 69 | total_norm, count_norm = 0, 0 70 | total_loss, train_score = 0, 0 71 | count, average_loss, att_entropy = 0, 0, 0 72 | t = time.time() 73 | if epoch < len(gradual_warmup_steps): 74 | for i in range(len(optim.param_groups)): 75 | optim.param_groups[i]['lr'] = gradual_warmup_steps[epoch] 76 | logger.write('gradual warmup lr: %.4f' % 77 | optim.param_groups[-1]['lr']) 78 | elif (epoch in lr_decay_epochs or 79 | eval_score < last_eval_score and args.lr_decay_based_on_val): 80 | for i in range(len(optim.param_groups)): 81 | optim.param_groups[i]['lr'] *= args.lr_decay_rate 82 | logger.write('decreased lr: %.4f' % optim.param_groups[-1]['lr']) 83 | else: 84 | logger.write('lr: %.4f' % optim.param_groups[-1]['lr']) 85 | last_eval_score = eval_score 86 | 87 | mini_batch_count = 0 88 | batch_multiplier = args.grad_accu_steps 89 | for i, (v, norm_bb, q, target, _, _, bb, spa_adj_matrix, 
90 | sem_adj_matrix) in enumerate(train_loader): 91 | batch_size = v.size(0) 92 | num_objects = v.size(1) 93 | if mini_batch_count == 0: 94 | optim.step() 95 | optim.zero_grad() 96 | mini_batch_count = batch_multiplier 97 | 98 | v = Variable(v).to(device) 99 | norm_bb = Variable(norm_bb).to(device) 100 | q = Variable(q).to(device) 101 | target = Variable(target).to(device) 102 | pos_emb, sem_adj_matrix, spa_adj_matrix = prepare_graph_variables( 103 | relation_type, bb, sem_adj_matrix, spa_adj_matrix, num_objects, 104 | args.nongt_dim, args.imp_pos_emb_dim, args.spa_label_num, 105 | args.sem_label_num, device) 106 | pred, att = model(v, norm_bb, q, pos_emb, sem_adj_matrix, 107 | spa_adj_matrix, target) 108 | loss = instance_bce_with_logits(pred, target) 109 | 110 | loss /= batch_multiplier 111 | loss.backward() 112 | mini_batch_count -= 1 113 | total_norm += nn.utils.clip_grad_norm_(model.parameters(), 114 | args.grad_clip) 115 | count_norm += 1 116 | batch_score = compute_score_with_logits(pred, target, device).sum() 117 | total_loss += loss.data.item() * batch_multiplier * v.size(0) 118 | train_score += batch_score 119 | pbar.update(1) 120 | 121 | if args.log_interval > 0: 122 | average_loss += loss.data.item() * batch_multiplier 123 | if model.module.fusion == "ban": 124 | current_att_entropy = torch.sum(calc_entropy(att.data)) 125 | att_entropy += current_att_entropy / batch_size / att.size(1) 126 | count += 1 127 | if i % args.log_interval == 0: 128 | att_entropy /= count 129 | average_loss /= count 130 | print("step {} / {} (epoch {}), ave_loss {:.3f},".format( 131 | i, len(train_loader), epoch, 132 | average_loss), 133 | "att_entropy {:.3f}".format(att_entropy)) 134 | average_loss = 0 135 | count = 0 136 | att_entropy = 0 137 | 138 | total_loss /= N 139 | train_score = 100 * train_score / N 140 | if eval_loader is not None: 141 | eval_score, bound, entropy = evaluate( 142 | model, eval_loader, device, args) 143 | 144 | logger.write('epoch %d, time: %.2f' % (epoch, time.time()-t)) 145 | logger.write('\ttrain_loss: %.2f, norm: %.4f, score: %.2f' 146 | % (total_loss, total_norm / count_norm, train_score)) 147 | if eval_loader is not None: 148 | logger.write('\teval score: %.2f (%.2f)' 149 | % (100 * eval_score, 100 * bound)) 150 | 151 | if entropy is not None: 152 | info = '' 153 | for i in range(entropy.size(0)): 154 | info = info + ' %.2f' % entropy[i] 155 | logger.write('\tentropy: ' + info) 156 | if (eval_loader is not None)\ 157 | or (eval_loader is None and epoch >= args.saving_epoch): 158 | logger.write("saving current model weights to folder") 159 | model_path = os.path.join(args.output, 'model_%d.pth' % epoch) 160 | opt = optim if args.save_optim else None 161 | utils.save_model(model_path, model, epoch, opt) 162 | 163 | 164 | @torch.no_grad() 165 | def evaluate(model, dataloader, device, args): 166 | model.eval() 167 | relation_type = dataloader.dataset.relation_type 168 | score = 0 169 | upper_bound = 0 170 | num_data = 0 171 | N = len(dataloader.dataset) 172 | entropy = None 173 | if model.module.fusion == "ban": 174 | entropy = torch.Tensor(model.module.glimpse).zero_().to(device) 175 | pbar = tqdm(total=len(dataloader)) 176 | 177 | for i, (v, norm_bb, q, target, _, _, bb, spa_adj_matrix, 178 | sem_adj_matrix) in enumerate(dataloader): 179 | batch_size = v.size(0) 180 | num_objects = v.size(1) 181 | v = Variable(v).to(device) 182 | norm_bb = Variable(norm_bb).to(device) 183 | q = Variable(q).to(device) 184 | target = Variable(target).to(device) 185 | 186 | pos_emb, 
sem_adj_matrix, spa_adj_matrix = prepare_graph_variables( 187 | relation_type, bb, sem_adj_matrix, spa_adj_matrix, num_objects, 188 | args.nongt_dim, args.imp_pos_emb_dim, args.spa_label_num, 189 | args.sem_label_num, device) 190 | pred, att = model(v, norm_bb, q, pos_emb, sem_adj_matrix, 191 | spa_adj_matrix, target) 192 | batch_score = compute_score_with_logits( 193 | pred, target, device).sum() 194 | score += batch_score 195 | upper_bound += (target.max(1)[0]).sum() 196 | num_data += pred.size(0) 197 | if att is not None and 0 < model.module.glimpse\ 198 | and entropy is not None: 199 | entropy += calc_entropy(att.data)[:model.module.glimpse] 200 | pbar.update(1) 201 | 202 | score = score / len(dataloader.dataset) 203 | upper_bound = upper_bound / len(dataloader.dataset) 204 | 205 | if entropy is not None: 206 | entropy = entropy / len(dataloader.dataset) 207 | model.train() 208 | return score, upper_bound, entropy 209 | 210 | 211 | def calc_entropy(att): 212 | # size(att) = [b x g x v x q] 213 | sizes = att.size() 214 | eps = 1e-8 215 | p = att.view(-1, sizes[1], sizes[2] * sizes[3]) 216 | return (-p * (p + eps).log()).sum(2).sum(0) # g 217 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This code is modified by Linjie Li from Jin-Hwa Kim's repository. 6 | https://github.com/jnhwkim/ban-vqa 7 | MIT License 8 | """ 9 | from __future__ import print_function 10 | 11 | import errno 12 | import os 13 | import re 14 | import collections 15 | import numpy as np 16 | import operator 17 | import functools 18 | from PIL import Image 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch._six import string_classes 23 | from torch.utils.data.dataloader import default_collate 24 | 25 | 26 | EPS = 1e-7 27 | 28 | 29 | def assert_eq(real, expected): 30 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 31 | 32 | 33 | def assert_array_eq(real, expected): 34 | assert (np.abs(real-expected) < EPS).all(), \ 35 | '%s (true) vs %s (expected)' % (real, expected) 36 | 37 | 38 | def weights_init(m): 39 | """custom weights initialization.""" 40 | cname = m.__class__ 41 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 42 | m.weight.data.normal_(0.0, 0.02) 43 | elif cname == nn.BatchNorm2d: 44 | m.weight.data.normal_(1.0, 0.02) 45 | m.bias.data.fill_(0) 46 | else: 47 | print('%s is not initialized.' 
% cname) 48 | 49 | 50 | def init_net(net, net_file): 51 | if net_file: 52 | net.load_state_dict(torch.load(net_file)) 53 | else: 54 | net.apply(weights_init) 55 | 56 | 57 | def create_dir(path): 58 | if not os.path.exists(path): 59 | try: 60 | os.makedirs(path) 61 | except OSError as exc: 62 | if exc.errno != errno.EEXIST: 63 | raise 64 | 65 | 66 | def print_model(model, logger): 67 | nParams = 0 68 | print(model) 69 | for w in model.parameters(): 70 | if w.requires_grad: 71 | nParams += functools.reduce(operator.mul, w.size(), 1) 72 | if logger: 73 | logger.write('nParams=\t'+str(nParams)) 74 | 75 | 76 | def save_model(path, model, epoch, optimizer=None): 77 | model_dict = { 78 | 'epoch': epoch, 79 | 'model_state': model.state_dict() 80 | } 81 | if optimizer is not None: 82 | model_dict['optimizer_state'] = optimizer.state_dict() 83 | 84 | torch.save(model_dict, path) 85 | 86 | 87 | # Select the indices given by `lengths` in the second dimension 88 | # As a result, # of dimensions is shrinked by one 89 | # @param pad(Tensor) 90 | # @param len(list[int]) 91 | def rho_select(pad, lengths): 92 | # Index of the last output for each sequence. 93 | idx_ = (lengths-1).view(-1, 1).expand( 94 | pad.size(0), pad.size(2)).unsqueeze(1) 95 | extracted = pad.gather(1, idx_).squeeze(1) 96 | return extracted 97 | 98 | 99 | def trim_collate(batch): 100 | "Puts each data field into a tensor with outer dimension batch size" 101 | _use_shared_memory = True 102 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 103 | elem_type = type(batch[0]) 104 | if torch.is_tensor(batch[0]): 105 | out = None 106 | 107 | if 1 < batch[0].dim(): # image features 108 | max_num_boxes = max([x.size(0) for x in batch]) 109 | if _use_shared_memory: 110 | # If we're in a background process, concatenate directly into a 111 | # shared memory tensor to avoid an extra copy 112 | if batch[0].size(-1) != batch[0].size(-2): # bottom-up-feature 113 | numel = len(batch) * max_num_boxes * batch[0].size(-1) 114 | storage = batch[0].storage()._new_shared(numel) 115 | out = batch[0].new(storage) 116 | else: # adj_matrix 117 | numel = len(batch) * max_num_boxes * max_num_boxes 118 | storage = batch[0].storage()._new_shared(numel) 119 | out = batch[0].new(storage) 120 | if batch[0].size(-1) != batch[0].size(-2): # bottom-up-feature 121 | # warning: F.pad returns Variable! 
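# Pad each [num_boxes, feat_dim] tensor along the box dimension up to
# max_num_boxes before stacking into [batch, max_num_boxes, feat_dim];
# the adjacency-matrix branch below pads both box dimensions instead.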
122 | return torch.stack( 123 | [F.pad(x, 124 | (0, 0, 0, 125 | max_num_boxes-x.size(0))).data for x in batch], 126 | 0, out=out) 127 | else: # adj_matrix 128 | return torch.stack( 129 | [F.pad(x, 130 | (0, max_num_boxes-x.size(0), 131 | 0, max_num_boxes-x.size(0))).data for x in batch], 132 | 0, out=out) 133 | else: 134 | if _use_shared_memory: 135 | # If we're in a background process, concatenate directly into a 136 | # shared memory tensor to avoid an extra copy 137 | numel = sum([x.numel() for x in batch]) 138 | storage = batch[0].storage()._new_shared(numel) 139 | out = batch[0].new(storage) 140 | return torch.stack(batch, 0, out=out) 141 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 142 | and elem_type.__name__ != 'string_': 143 | elem = batch[0] 144 | if elem_type.__name__ == 'ndarray': 145 | # array of string classes and object 146 | if re.search('[SaUO]', elem.dtype.str) is not None: 147 | raise TypeError(error_msg.format(elem.dtype)) 148 | 149 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 150 | if elem.shape == (): # scalars 151 | py_type = float if elem.dtype.name.startswith('float') else int 152 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 153 | elif isinstance(batch[0], int): 154 | return torch.LongTensor(batch) 155 | 156 | elif isinstance(batch[0], float): 157 | return torch.DoubleTensor(batch) 158 | elif isinstance(batch[0], string_classes): 159 | return batch 160 | elif isinstance(batch[0], collections.Mapping): 161 | return {key: default_collate( 162 | [d[key] for d in batch]) for key in batch[0]} 163 | elif isinstance(batch[0], collections.Sequence): 164 | transposed = zip(*batch) 165 | return [trim_collate(samples) for samples in transposed] 166 | 167 | raise TypeError((error_msg.format(type(batch[0])))) 168 | 169 | 170 | class Logger(object): 171 | def __init__(self, output_name, reset=False): 172 | dirname = os.path.dirname(output_name) 173 | if not os.path.exists(dirname): 174 | os.mkdir(dirname) 175 | if os.path.exists(output_name) and not reset: 176 | self.log_file = open(output_name, 'a') 177 | else: 178 | self.log_file = open(output_name, 'w') 179 | self.infos = {} 180 | 181 | def append(self, key, val): 182 | vals = self.infos.setdefault(key, []) 183 | vals.append(val) 184 | 185 | def log(self, extra_msg=''): 186 | msgs = [extra_msg] 187 | for key, vals in self.infos.iteritems(): 188 | msgs.append('%s %.6f' % (key, np.mean(vals))) 189 | msg = '\n'.join(msgs) 190 | self.log_file.write(msg + '\n') 191 | self.log_file.flush() 192 | self.infos = {} 193 | return msg 194 | 195 | def write(self, msg): 196 | self.log_file.write(msg + '\n') 197 | self.log_file.flush() 198 | print(msg) 199 | 200 | 201 | def create_glove_embedding_init(idx2word, glove_file): 202 | word2emb = {} 203 | with open(glove_file, 'r', encoding='utf-8') as f: 204 | entries = f.readlines() 205 | emb_dim = len(entries[0].split(' ')) - 1 206 | print('embedding dim is %d' % emb_dim) 207 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 208 | 209 | for entry in entries: 210 | vals = entry.split(' ') 211 | word = vals[0] 212 | vals = list(map(float, vals[1:])) 213 | word2emb[word] = np.array(vals) 214 | for idx, word in enumerate(idx2word): 215 | if word not in word2emb: 216 | continue 217 | weights[idx] = word2emb[word] 218 | return weights, word2emb 219 | --------------------------------------------------------------------------------
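A short usage sketch for `trim_collate` may help, since it is what lets a mini-batch mix images with different numbers of region proposals: 2-D feature tensors are zero-padded to the largest box count in the batch, and square adjacency matrices are padded on both box dimensions. The sketch below is illustrative only and is not taken from the repository's own entry points; `ToyDataset`, its shapes, and the direct `from utils import trim_collate` import (which assumes the repository root is on `PYTHONPATH`) are assumptions made for this example, and the real datasets return richer tuples (question, target, bounding boxes, adjacency matrices).

```python
# Illustrative only -- ToyDataset, its shapes, and the import path are
# assumptions made for this sketch, not code from the repository.
import torch
from torch.utils.data import DataLoader, Dataset

from utils import trim_collate  # assumes the repository root is on PYTHONPATH


class ToyDataset(Dataset):
    """Yields (region features [n, 2048], adjacency matrix [n, n]) per image."""

    def __init__(self, boxes_per_image):
        self.boxes_per_image = boxes_per_image

    def __len__(self):
        return len(self.boxes_per_image)

    def __getitem__(self, idx):
        n = self.boxes_per_image[idx]
        return torch.randn(n, 2048), torch.ones(n, n)


# Images with 12, 30 and 7 regions are padded to 30 rows each in the batch.
loader = DataLoader(ToyDataset([12, 30, 7]), batch_size=3,
                    collate_fn=trim_collate)
features, adj = next(iter(loader))
print(features.shape)  # torch.Size([3, 30, 2048])
print(adj.shape)       # torch.Size([3, 30, 30])
```

Padding at collate time keeps the on-disk features ragged (the adaptive 10-100 boxes per image) while still producing rectangular tensors for the GPU.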