├── .gitignore ├── LICENSE ├── README.md ├── ansemb ├── __init__.py ├── config.py ├── dataset │ ├── __init__.py │ ├── base.py │ ├── data_utils.py │ ├── preprocess.py │ ├── v7w.py │ ├── vg.py │ └── vqa.py ├── loss.py ├── models │ ├── __init__.py │ ├── embedding.py │ └── layers.py ├── utils.py └── vector.py ├── data ├── .keep ├── README.md ├── prepare_data.sh ├── preprocess_v7w.py └── preprocess_vqa.py ├── requirements.txt ├── tools ├── _init_paths.py ├── dump_vqa_eval_json.py ├── preprocess_answer.py └── preprocess_question.py ├── train_v7w_embedding.py ├── train_vg_embedding.py └── train_vqa_embedding.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Hexiang Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Answer Embedding 2 | Code Release for [Learning Answer Embeddings for Visual Question Answering](http://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Learning_Answer_Embeddings_CVPR_2018_paper.pdf). (CVPR 2018) 3 | 4 | ## Usage 5 | 6 | 7 | ``` 8 | usage: train_v7w_embedding.py [-h] [--gpu_id GPU_ID] [--batch_size BATCH_SIZE] 9 | [--max_negative_answer MAX_NEGATIVE_ANSWER] 10 | [--answer_batch_size ANSWER_BATCH_SIZE] 11 | [--loss_temperature LOSS_TEMPERATURE] 12 | [--pretrained_model PRETRAINED_MODEL] 13 | [--context_embedding {SAN,BoW}] 14 | [--answer_embedding {BoW,RNN}] [--name NAME] 15 | 16 | optional arguments: 17 | -h, --help show this help message and exit 18 | --gpu_id GPU_ID 19 | --batch_size BATCH_SIZE 20 | --max_negative_answer MAX_NEGATIVE_ANSWER 21 | --answer_batch_size ANSWER_BATCH_SIZE 22 | --loss_temperature LOSS_TEMPERATURE 23 | --pretrained_model PRETRAINED_MODEL 24 | --context_embedding {SAN,BoW} 25 | --answer_embedding {BoW,RNN} 26 | --name NAME 27 | ``` 28 | 29 | ## BibTeX 30 | 31 | Please cite the following BibTeX entry if you use any resources from this repo in your research. 32 | 33 | ``` 34 | @inproceedings{hu2018learning, 35 | title={Learning Answer Embeddings for Visual Question Answering}, 36 | author={Hu, Hexiang and Chao, Wei-Lun and Sha, Fei}, 37 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 38 | pages={5428--5436}, 39 | year={2018} 40 | } 41 | ``` 42 | 43 | ## Acknowledgement 44 | Part of this code uses components from [pytorch-vqa](https://github.com/Cyanogenoid/pytorch-vqa) and [torchtext](https://github.com/pytorch/text). We thank the authors for releasing their code. 45 | 46 | ## References 47 | 48 | 1. Being Negative but Constructively: 49 | Lessons Learnt from Creating Better Visual Question Answering Datasets ([qaVG website](http://www.teds.usc.edu/website_vqa/)) 50 | 2. Visual7W: Grounded Question Answering in Images 51 | ([website](http://web.stanford.edu/~yukez/visual7w/index.html)) 52 | 3.
Making the V in VQA Matter: Elevating the Role of Image Understanding in Visual Question Answering [website](http://www.visualqa.org/) 53 | -------------------------------------------------------------------------------- /ansemb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hexiang-hu/answer_embedding/154182974565de3fd24b669d7d298278e1e8a5d0/ansemb/__init__.py -------------------------------------------------------------------------------- /ansemb/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import random 4 | import torch 5 | from easydict import EasyDict as edict 6 | 7 | this_dir = osp.dirname(__file__) 8 | project_root = osp.abspath( osp.join(this_dir, '..') ) 9 | 10 | cfg = edict() 11 | cfg.cache_path = osp.join(project_root, '.cache') 12 | cfg.output_path = osp.join(project_root, 'outputs') 13 | cfg.embedding_size = 1024 # embedding dimensionality 14 | cfg.seed = 1618 15 | cfg.question_vocab_path = osp.join(project_root, 'data', 'question.vocab.json') # a joint question vocab across all dataset 16 | 17 | # preprocess config 18 | cfg.image_size = 448 19 | cfg.output_size = cfg.image_size // 32 20 | cfg.output_features = 2048 21 | cfg.central_fraction = 0.875 22 | 23 | # Train params 24 | cfg.TRAIN = edict() 25 | 26 | cfg.TRAIN.epochs = 50 27 | cfg.TRAIN.batch_size = 128 28 | cfg.TRAIN.base_lr = 1e-3 # default Adam lr 29 | cfg.TRAIN.lr_decay = 15 # in epochs 30 | # cfg.TRAIN.data_workers = 20 31 | cfg.TRAIN.data_workers = 10 32 | 33 | cfg.TRAIN.answer_batch_size = 3000 # batch size for answer network 34 | cfg.TRAIN.max_negative_answer = 8000 # max negative answers to sample 35 | 36 | # Test params 37 | cfg.TEST = edict() 38 | cfg.TEST.max_answer_index = 3000 # max answer index for computing acc 39 | 40 | # Dataset params 41 | 42 | # VQA2 params 43 | cfg.VQA2 = edict() 44 | 45 | cfg.VQA2.qa_path = osp.join(project_root, 'data', 'vqa2') 46 | cfg.VQA2.feature_path = osp.join(project_root, 'features', 'vqa-resnet-14x14.h5') 47 | cfg.VQA2.answer_vocab_path = osp.join(project_root, 'data', 'answer.vocab.vqa.json') 48 | cfg.VQA2.train_img_path = osp.join(cfg.VQA2.qa_path, 'images', 'train2014') 49 | cfg.VQA2.val_img_path = osp.join(cfg.VQA2.qa_path, 'images', 'val2014') 50 | cfg.VQA2.test_img_path = osp.join(cfg.VQA2.qa_path, 'images', 'test-dev2015') 51 | 52 | cfg.VQA2.train_qa = 'train2014' 53 | cfg.VQA2.val_qa = 'val2014' 54 | cfg.VQA2.test_qa = 'test-dev2015' 55 | 56 | cfg.VQA2.task = 'OpenEnded' 57 | cfg.VQA2.dataset = 'mscoco' 58 | 59 | # VG params 60 | cfg.VG = edict() 61 | 62 | cfg.VG.qa_path = osp.join(project_root, 'data', 'vg') 63 | cfg.VG.feature_path = osp.join(project_root, 'features', 'vg-resnet-14x14.h5') 64 | cfg.VG.answer_vocab_path = osp.join(project_root, 'data', 'answer.vocab.vg.json') 65 | 66 | cfg.VG.train_qa = 'VG_train_decoys.json' 67 | cfg.VG.val_qa = 'VG_val_decoys.json' 68 | cfg.VG.test_qa = 'VG_test_decoys.json' 69 | 70 | cfg.VG.img_path = osp.join(cfg.VG.qa_path, 'images') 71 | 72 | # V7W params 73 | cfg.Visual7W = edict() 74 | 75 | cfg.Visual7W.qa_path = osp.join(project_root, 'data', 'v7w') 76 | cfg.Visual7W.feature_path = osp.join(project_root, 'features', 'vg-resnet-14x14.h5') 77 | cfg.Visual7W.answer_vocab_path = osp.join(project_root, 'data', 'answer.vocab.v7w.json') 78 | 79 | cfg.Visual7W.train_qa = 'v7w_train_questions.json' 80 | cfg.Visual7W.val_qa = 'v7w_val_questions.json' 81 
| cfg.Visual7W.test_qa = 'v7w_test_questions.json' 82 | 83 | ################################################################################# 84 | # A curated dataset for V7W, which removes bias towards modality 85 | # - See [Chao et al. NAACL-HLT 2018 Being Negative but Constructively...] for 86 | # details. We refer to this dataset as V7W throughout the paper. 87 | ################################################################################# 88 | 89 | cfg.Visual7W.train_v7w_decoys = 'v7w_train_decoys.json' 90 | cfg.Visual7W.val_v7w_decoys = 'v7w_val_decoys.json' 91 | cfg.Visual7W.test_v7w_decoys = 'v7w_test_decoys.json' 92 | 93 | cfg.Visual7W.img_path = osp.join(cfg.Visual7W.qa_path, 'images') 94 | 95 | def set_random_seed(seed): 96 | random.seed(seed) 97 | if torch.cuda.is_available(): 98 | torch.cuda.manual_seed(seed) 99 | torch.manual_seed(seed) 100 | np.random.seed(seed) 101 | 102 | def update_train_configs(_cfg): 103 | cfg.TRAIN.batch_size = _cfg.batch_size 104 | cfg.TRAIN.answer_batch_size = _cfg.answer_batch_size 105 | cfg.TRAIN.max_negative_answer = _cfg.max_negative_answer 106 | 107 | if hasattr(_cfg, 'learning_rate'): # Handle optional attributes 108 | cfg.TRAIN.base_lr = _cfg.learning_rate 109 | 110 | if hasattr(_cfg, 'max_answer_index'): 111 | cfg.TEST.max_answer_index = _cfg.max_answer_index 112 | 113 | -------------------------------------------------------------------------------- /ansemb/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hexiang-hu/answer_embedding/154182974565de3fd24b669d7d298278e1e8a5d0/ansemb/dataset/__init__.py -------------------------------------------------------------------------------- /ansemb/dataset/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import random 5 | import nltk 6 | import h5py 7 | 8 | from collections import Counter 9 | import torch 10 | import torch.utils.data as data 11 | from torch.utils.data.dataloader import default_collate 12 | 13 | from ansemb.config import cfg 14 | from ansemb.dataset.preprocess import invert_dict 15 | 16 | class VisualQA(data.Dataset): 17 | def __init__(self, 18 | vector, 19 | image_features_path, 20 | answer_vocab_path=None): 21 | super(VisualQA, self).__init__() 22 | 23 | # vocab 24 | self.vector = vector 25 | 26 | # process question 27 | with open(cfg.question_vocab_path, 'r') as fd: question_vocab = json.load(fd) 28 | self.token_to_index = question_vocab['question'] 29 | self._max_question_length = question_vocab['max_question_length'] 30 | 31 | self.index_to_token = invert_dict(self.token_to_index) 32 | 33 | if answer_vocab_path is not None: 34 | print('import answer vocabulary from: {}'.format(answer_vocab_path)) 35 | # process answer 36 | with open(answer_vocab_path, 'r') as fd: answer_vocab = json.load(fd) 37 | self.answer_to_index = answer_vocab['answer'] 38 | self.index_to_answer = invert_dict(self.answer_to_index) 39 | 40 | self.cached_answers = {} 41 | self.unk_vector = self.vector['UNK'] 42 | 43 | @property 44 | def max_question_length(self): 45 | return self._max_question_length 46 | 47 | @property 48 | def max_answer_length(self): 49 | assert hasattr(self, 'answers'), 'Dataloader must have access to answers' 50 | if not hasattr(self, '_max_answer_length'): 51 | self._max_answer_length = max(map(len, self.answers)) 52 | return self._max_answer_length 53 | 54 | @property 55 | def num_tokens(self): 56 |
return len(self.token_to_index) 57 | 58 | @property 59 | def num_answers(self): 60 | return len(self.answer_to_index) 61 | 62 | def __len__(self): 63 | return len(self.questions) 64 | 65 | ### Internal data utility--------------------------------------- 66 | def _load_image(self, image_id): 67 | """ Load an image """ 68 | if not hasattr(self, 'features_file'): 69 | # Loading the h5 file has to be done here and not in __init__ because when the DataLoader 70 | # forks for multiple works, every child would use the same file object and fail 71 | # Having multiple readers using different file objects is fine though, so we just init in here. 72 | self.features_file = h5py.File(self.image_features_path, 'r') 73 | index = self.image_id_to_index[image_id] 74 | dataset = self.features_file['features'] 75 | img = dataset[index].astype('float32') 76 | return torch.from_numpy(img) 77 | 78 | def _get_answer_vectors(self, answer_indices): 79 | if isinstance(answer_indices[0], list): 80 | N, C = len(answer_indices), len(answer_indices[0]) 81 | vector = torch.zeros(N, C, self.vector.dim) 82 | for i, answer_ids in enumerate(answer_indices): 83 | for j, answer_id in enumerate(answer_ids): 84 | if answer_id != -1: 85 | vector[i, j, :] = self._encode_answer_vector(self.index_to_answer[answer_id]) 86 | else: 87 | vector[i, j, :] = self.unk_vector 88 | else: 89 | vector = torch.zeros(len(answer_indices), self.vector.dim) 90 | for idx, answer_id in enumerate(answer_indices): 91 | if answer_id != -1: 92 | vector[idx, :] = self._encode_answer_vector(self.index_to_answer[answer_id]) 93 | else: 94 | vector[idx, :] = self.unk_vector 95 | return vector, [] 96 | 97 | def _get_answer_sequences(self, answer_indices): 98 | seqs, lengths = [], [] 99 | max_seq_length = 0 100 | if isinstance(answer_indices[0], list): 101 | N, C = len(answer_indices), len(answer_indices[0]) 102 | for i, answer_ids in enumerate(answer_indices): 103 | _seqs = [] 104 | for j, answer_id in enumerate(answer_ids): 105 | if answer_id != -1: 106 | _seqs.append( self._encode_answer_sequence(self.index_to_answer[answer_id]) ) 107 | else: 108 | _seqs.append([ self.unk_vector ]) 109 | if max_seq_length < len(_seqs[-1]): max_seq_length = len(_seqs[-1]) # determing max length 110 | seqs.append(_seqs) 111 | 112 | vector = torch.zeros(N, C, max_seq_length, self.vector.dim) 113 | for i, _seqs in enumerate(seqs): 114 | for j, seq in enumerate(_seqs): 115 | if len(seq) != 0: vector[i, j, :len(seq), :] = torch.cat(seq, dim=0) 116 | lengths.append(len(seq)) 117 | assert len(lengths) == N*C, 'Wrong lengths - length: {} vs N: {}, C: {} vs seqs: {}'.format(len(lengths), N, C, len(seqs)) 118 | else: 119 | for idx, answer_id in enumerate(answer_indices): 120 | if answer_id != -1: 121 | seqs.append( self._encode_answer_sequence(self.index_to_answer[answer_id]) ) 122 | else: 123 | seqs.append([ self.unk_vector ]) 124 | 125 | if max_seq_length < len(seqs[-1]): max_seq_length = len(seqs[-1]) # determing max length 126 | 127 | vector = torch.zeros(len(answer_indices), max_seq_length, self.vector.dim) 128 | for idx, seq in enumerate(seqs): 129 | if len(seq) != 0: vector[idx, :len(seq), :] = torch.cat(seq, dim=0) 130 | lengths.append(len(seq)) 131 | 132 | return vector, lengths 133 | 134 | def _create_image_id_to_index(self): 135 | """ Create a mapping from a COCO image id into the corresponding index into the h5 file """ 136 | with h5py.File(self.image_features_path, 'r') as features_file: 137 | image_ids = features_file['ids'][()] 138 | image_id_to_index = {id: i for i, id 
in enumerate(image_ids)} 139 | return image_id_to_index 140 | 141 | def _encode_question(self, question): 142 | """ Turn a question into a vector of indices and a question length """ 143 | vec = torch.zeros(self.max_question_length).long() 144 | for i, token in enumerate(question): 145 | index = self.token_to_index.get(token, 0) 146 | vec[i] = index 147 | return vec, len(question) 148 | 149 | def _encode_answer_vector(self, answer): 150 | if isinstance(self.cached_answers.get(answer, -1), int): 151 | tokens = nltk.word_tokenize(answer) 152 | answer_vec = torch.zeros(1, self.vector.dim) 153 | cnt = 0 154 | for i, token in enumerate(tokens): 155 | if self.vector.check(token): 156 | answer_vec += self.vector[token] 157 | cnt += 1 158 | self.cached_answers[answer] = answer_vec / (cnt + 1e-12) 159 | 160 | return self.cached_answers[answer] 161 | 162 | def _encode_answer_sequence(self, answer): 163 | if isinstance(self.cached_answers.get(answer, -1), int): 164 | tokens = nltk.word_tokenize(answer) 165 | answer_seq = [] 166 | for i, token in enumerate(tokens): 167 | if self.vector.check(token): 168 | answer_seq.append(self.vector[token].view(1, self.vector.dim)) 169 | else: 170 | answer_seq.append(self.vector[''].view(1, self.vector.dim)) 171 | self.cached_answers[answer] = answer_seq 172 | 173 | return self.cached_answers[answer] 174 | 175 | def _encode_multihot_labels(self, answers, max_answer_index=cfg.TEST.max_answer_index): 176 | """ Turn an answer into a vector """ 177 | tail_index = max_answer_index 178 | answer_vec = torch.zeros(max_answer_index) 179 | for answer in answers: 180 | index = self.answer_to_index.get(answer) 181 | if index is not None: 182 | if index < max_answer_index: 183 | answer_vec[index] += 1 184 | return answer_vec 185 | 186 | def evaluate(self, predictions): 187 | raise NotImplementedError 188 | -------------------------------------------------------------------------------- /ansemb/dataset/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import re 5 | import nltk 6 | import random 7 | 8 | from collections import Counter 9 | from ansemb.config import cfg 10 | import torch 11 | import torch.utils.data as data 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | class Composite(data.Dataset): 15 | """ Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset. 
""" 16 | def __init__(self, *datasets): 17 | self.datasets = datasets 18 | 19 | def __getitem__(self, item): 20 | current = self.datasets[0] 21 | for d in self.datasets: 22 | if item < len(d): 23 | return d[item] 24 | item -= len(d) 25 | else: 26 | raise IndexError('Index too large for composite dataset') 27 | 28 | def __len__(self): 29 | return sum(map(len, self.datasets)) 30 | 31 | def _get_answer_vectors(self, answer_indices): 32 | return self.datasets[0]._get_answer_vectors(answer_indices) 33 | 34 | def _get_answer_sequences(self, answer_indices): 35 | return self.datasets[0]._get_answer_sequences(answer_indices) 36 | 37 | @property 38 | def vector(self): 39 | return self.datasets[0].vector 40 | 41 | @property 42 | def token_to_index(self): 43 | return self.datasets[0].token_to_index 44 | 45 | @property 46 | def answer_to_index(self): 47 | return self.datasets[0].answer_to_index 48 | 49 | @property 50 | def index_to_answer(self): 51 | return self.datasets[0].index_to_answer 52 | 53 | @property 54 | def num_tokens(self): 55 | return self.datasets[0].num_tokens 56 | 57 | @property 58 | def num_answer_tokens(self): 59 | return self.datasets[0].num_answer_tokens 60 | 61 | @property 62 | def vocab(self): 63 | return self.datasets[0].vocab 64 | 65 | def generate_batch_answer(indices, counts): 66 | unique_answers = set( aid for aids in indices for aid in aids ) 67 | negative_answers = random.sample( set(range(cfg.TRAIN.max_negative_answer)) - unique_answers, 68 | max(cfg.TRAIN.answer_batch_size - len(unique_answers), 0)) 69 | unique_answers = list(unique_answers) + negative_answers 70 | # unique_answers = list(set( aid for aids in indices for aid in aids )) 71 | answer_dict = { k: i for i, k in enumerate(unique_answers)} 72 | answer_vector = torch.zeros(len(indices), len(unique_answers)) 73 | for i in range(len(counts)): 74 | for j, c in zip(indices[i], counts[i]):answer_vector[i, answer_dict[j]] = c 75 | 76 | return unique_answers, answer_vector 77 | 78 | def collate_fn(batch): 79 | # put question lengths in descending order so that we can use packed sequences later 80 | _images, _questions, _answer_indices, _answer_counts, _choices, _labels, _indices, _question_lengths = zip(*batch) 81 | 82 | # universal contents 83 | images = default_collate(_images) 84 | questions = default_collate(_questions) 85 | indices = default_collate(_indices) 86 | question_lengths = default_collate(_question_lengths) 87 | 88 | if ( _answer_indices[0] == None ) and ( _answer_counts[0] == None ): 89 | return images, questions, indices, question_lengths 90 | 91 | # flatten nested list 92 | _unique_answers, _answer_vectors = generate_batch_answer(_answer_indices, _answer_counts) 93 | 94 | unique_answers = default_collate(_unique_answers) 95 | answer_vectors = default_collate(_answer_vectors) 96 | 97 | if _choices[0] is not None: 98 | return images, questions, unique_answers, answer_vectors, _choices, indices, question_lengths 99 | elif _labels[0] is not None: 100 | answer_labels = default_collate(_labels) 101 | return images, questions, unique_answers, answer_vectors, answer_labels, indices, question_lengths 102 | else: 103 | raise NotImplementedError('Something is wrong with dataloader') 104 | 105 | def eval_collate_fn(batch): 106 | # put question lengths in descending order so that we can use packed sequences later 107 | batch.sort(key=lambda x: x[-1], reverse=True) 108 | return data.dataloader.default_collate(batch) 109 | -------------------------------------------------------------------------------- 
/ansemb/dataset/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import re 5 | import nltk 6 | import random 7 | import itertools 8 | from collections import Counter 9 | 10 | from PIL import Image 11 | import h5py 12 | import torch 13 | import torch.utils.data as data 14 | import torchvision.transforms as transforms 15 | 16 | import ansemb.utils as utils 17 | 18 | # this is used for normalizing questions 19 | _special_chars = re.compile('[^a-z0-9 ]*') 20 | 21 | # these try to emulate the original normalization scheme for answers 22 | _period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)') 23 | _comma_strip = re.compile(r'(\d)(,)(\d)') 24 | _punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!') 25 | _punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars))) 26 | _punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars)) 27 | 28 | def process_punctuation(s): 29 | # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour 30 | # this version should be faster since we use re instead of repeated operations on str's 31 | original_s = s 32 | if _punctuation.search(s) is None: 33 | return s 34 | s = _punctuation_with_a_space.sub('', s) 35 | if re.search(_comma_strip, s) is not None: 36 | s = s.replace(',', '') 37 | s = _punctuation.sub(' ', s) 38 | s = _period_strip.sub('', s) 39 | if s.strip() == '': return original_s.strip() 40 | else: return s.strip() 41 | 42 | def extract_vocab(iterable, top_k=None, start=0, input_vocab=None): 43 | """ Turns an iterable of list of tokens into a vocabulary. 44 | These tokens could be single answers or word tokens in questions. 
45 | """ 46 | all_tokens = itertools.chain.from_iterable(iterable) 47 | counter = Counter(all_tokens) 48 | if top_k: 49 | most_common = counter.most_common(top_k) 50 | most_common = (t for t, c in most_common) 51 | else: 52 | most_common = counter.keys() 53 | # descending in count, then lexicographical order 54 | tokens = sorted(most_common, key=lambda x: (counter[x], x), reverse=True) 55 | 56 | vocab = {t: i for i, t in enumerate(tokens, start=start)} 57 | return vocab 58 | 59 | class CocoImages(data.Dataset): 60 | def __init__(self, path, transform=None): 61 | super(CocoImages, self).__init__() 62 | self.path = path 63 | self.id_to_filename = self._find_images() 64 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 65 | print('found {} images in {}'.format(len(self), self.path)) 66 | self.transform = transform 67 | 68 | def _find_images(self): 69 | id_to_filename = {} 70 | for filename in os.listdir(self.path): 71 | if not filename.endswith('.jpg'): continue 72 | id_and_extension = filename.split('_')[-1] 73 | id = int(id_and_extension.split('.')[0]) 74 | id_to_filename[id] = filename 75 | return id_to_filename 76 | 77 | def __getitem__(self, item): 78 | id = self.sorted_ids[item] 79 | path = os.path.join(self.path, self.id_to_filename[id]) 80 | img = Image.open(path).convert('RGB') 81 | 82 | if self.transform is not None: img = self.transform(img) 83 | return id, img 84 | 85 | def __len__(self): 86 | return len(self.sorted_ids) 87 | 88 | class VGImages(data.Dataset): 89 | def __init__(self, path, transform=None): 90 | super(VGImages, self).__init__() 91 | self.path = path 92 | self.id_to_filename = self._find_images() 93 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 94 | print('found {} images in {}'.format(len(self), self.path)) 95 | self.transform = transform 96 | 97 | def _find_images(self): 98 | id_to_filename = {} 99 | for filename in os.listdir(self.path): 100 | if not filename.endswith('.jpg'): 101 | continue 102 | id = int(filename.split('.')[0]) 103 | id_to_filename[id] = filename 104 | return id_to_filename 105 | 106 | def __getitem__(self, item): 107 | id = self.sorted_ids[item] 108 | path = os.path.join(self.path, self.id_to_filename[id]) 109 | img = Image.open(path).convert('RGB') 110 | 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | return id, img 114 | 115 | def __len__(self): 116 | return len(self.sorted_ids) 117 | 118 | def invert_dict(d): return {v: k for k, v in d.items()} 119 | 120 | class VizwizImages(data.Dataset): 121 | def __init__(self, path, transform=None): 122 | super(VizwizImages, self).__init__() 123 | self.path = path 124 | self.id_to_filename = self._find_images() 125 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 126 | print('found {} images in {}'.format(len(self), self.path)) 127 | self.transform = transform 128 | 129 | def _find_images(self): 130 | id_to_filename = {} 131 | for filename in os.listdir(self.path): 132 | if not filename.endswith('.jpg'): continue 133 | id_and_extension = filename.split('_')[-1] 134 | id = int(id_and_extension.split('.')[0]) 135 | id_to_filename[id] = filename 136 | return id_to_filename 137 | 138 | def __getitem__(self, item): 139 | id = self.sorted_ids[item] 140 | path = os.path.join(self.path, self.id_to_filename[id]) 141 | img = Image.open(path).convert('RGB') 142 | 143 | if self.transform is not None: img = self.transform(img) 144 | return id, img 
145 | 146 | def __len__(self): 147 | return len(self.sorted_ids) 148 | -------------------------------------------------------------------------------- /ansemb/dataset/v7w.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import nltk 5 | 6 | from collections import Counter, OrderedDict 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | from torch.utils.data.dataloader import default_collate 11 | 12 | from ansemb.config import cfg 13 | from ansemb.dataset.preprocess import process_punctuation 14 | from ansemb.dataset.base import VisualQA 15 | 16 | import ansemb.utils as utils 17 | import ansemb.dataset.data_utils as data_utils 18 | 19 | def path_for(train=False, val=False, test=False): 20 | assert train + val + test == 1 21 | if train: split = cfg.Visual7W.train_qa 22 | elif val: split = cfg.Visual7W.val_qa 23 | elif test: split = cfg.Visual7W.test_qa 24 | else: raise ValueError('Unsupported split.') 25 | 26 | return os.path.join(cfg.Visual7W.qa_path, split) 27 | 28 | def path_for_decoys(train=False, val=False, test=False): 29 | assert train + val + test == 1 30 | if train: split = cfg.Visual7W.train_v7w_decoys 31 | elif val: split = cfg.Visual7W.val_v7w_decoys 32 | elif test: split = cfg.Visual7W.test_v7w_decoys 33 | else: raise ValueError('Unsupported split.') 34 | 35 | return os.path.join(cfg.Visual7W.qa_path, split) 36 | 37 | def get_loader(vector, train=False, val=False, test=False, vocab_path=None): 38 | assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True' 39 | split = Visual7W( 40 | path_for(train=train, val=val, test=test), 41 | path_for_decoys(train=train, val=val, test=test), 42 | cfg.Visual7W.feature_path, 43 | vector, 44 | vocab_path=vocab_path, 45 | answerable_only=train, 46 | ) 47 | loader = torch.utils.data.DataLoader( 48 | split, 49 | batch_size=cfg.TRAIN.batch_size, 50 | shuffle=train, # only shuffle the data in training 51 | pin_memory=True, 52 | num_workers=cfg.TRAIN.data_workers, 53 | collate_fn=data_utils.collate_fn, 54 | ) 55 | return loader 56 | 57 | class Visual7W(VisualQA): 58 | def __init__(self, questions_path, decoys_path, image_features_path, 59 | vector, vocab_path=None, answerable_only=False): 60 | answer_vocab_path = cfg.Visual7W.answer_vocab_path if vocab_path is None else vocab_path 61 | super(Visual7W, self).__init__(vector, 62 | image_features_path, 63 | answer_vocab_path=answer_vocab_path) 64 | 65 | # load annotation 66 | with open(questions_path, 'r') as fd: questions_json = json.load(fd) 67 | with open(decoys_path, 'r') as fd: decoys_json = json.load(fd) 68 | 69 | # q and a 70 | cache_filepath = osp.join(cfg.cache_path, "visual7w.{}.pt".format(questions_path.split('/')[-1])) 71 | 72 | if not os.path.exists( cache_filepath ): 73 | print('extracting answers...') 74 | self.answers = [ ans for ans in prepare_answers(questions_json) ] 75 | self.v7w_answers = [ ans for ans in prepare_v7w_answers(questions_json, decoys_json) ] 76 | 77 | print('encoding questions...') 78 | self.questions = list(prepare_questions(questions_json)) 79 | self.questions = [self._encode_question(q) for q in self.questions] 80 | 81 | self.answer_indices = [ [ self.answer_to_index.get(_a, -1) for _a in a ] for a in self.answers ] 82 | self.v7w_answer_indices = [ [ self.answer_to_index.get(_a, -1) for _a in a ] for a in self.v7w_answers ] 83 | print('saving cache to: {}'.format(cache_filepath)) 84 | 
torch.save({'questions': self.questions, 85 | 'answer_indices': self.answer_indices, 86 | 'v7w_answer_indices': self.v7w_answer_indices 87 | }, cache_filepath) 88 | else: 89 | print('loading cache from: {}'.format(cache_filepath)) 90 | _cache = torch.load(cache_filepath) 91 | self.questions = _cache['questions'] 92 | self.answer_indices = _cache['answer_indices'] 93 | self.v7w_answer_indices = _cache['v7w_answer_indices'] 94 | 95 | # process images 96 | self.image_features_path = image_features_path 97 | self.image_id_to_index = self._create_image_id_to_index() 98 | self.image_ids = [q['image_index'] for q in questions_json] 99 | 100 | # switch for determining which answer choices are served (original Visual7W vs. curated V7W decoys) [default is False] 101 | self.serving_v7w = False 102 | 103 | def set_v7w_server(self, value): 104 | self.serving_v7w = value 105 | if self.serving_v7w: 106 | print('Now serving V7W.') 107 | else: 108 | print('Now serving Visual7W.') 109 | 110 | def __getitem__(self, item): 111 | question, question_length = self.questions[item] 112 | 113 | # sample answers 114 | answer_indices = self.answer_indices[item] 115 | 116 | # TODO: beautify the hack that hard-codes decoy counts to zero (the correct answer, kept last, gets count 1) 117 | if self.serving_v7w: 118 | choices = self.v7w_answer_indices[item] 119 | counts = [0, 0, 0, 0, 0, 0, 1] 120 | else: 121 | choices = self.answer_indices[item] 122 | counts = [0, 0, 0, 1] 123 | 124 | image_id = self.image_ids[item] 125 | image = self._load_image(image_id) 126 | return image, question, answer_indices, counts, choices, None, item, question_length 127 | 128 | def prepare_questions(questions_json): 129 | questions = [q['question'] for q in questions_json] 130 | for question in questions: 131 | question = question.lower()[:-1] 132 | yield nltk.word_tokenize(process_punctuation(question)) 133 | 134 | ################################################################################# 135 | # Note that we processed the data so that the correct choice is always the last one; 136 | # this is a hack just for convenience.
So models that take all answer choices as input 137 | # should be cautious about cheating by learning the bias in order 138 | ################################################################################# 139 | 140 | def prepare_answers(answers_json): 141 | answers = [ [ _a.lower().strip('.') for _a in a['multiple_choices'] ] for a in answers_json] 142 | for answer in answers: 143 | yield [ process_punctuation(a) for a in answer ] 144 | 145 | def prepare_v7w_answers(answers_json, decoys_json): 146 | answers= [] 147 | for ans, decoy in zip(answers_json, decoys_json): 148 | assert ans['qa_id'] == decoy['qa_id'], 'inconsistent qa_id: {}, decoy_id: {}'.format(ans['qa_id'], decoy['qa_id']) 149 | answers.append(decoy['IoU_decoys'] + decoy['QoU_decoys'] + [ ans['answer'] ]) 150 | 151 | answers = [ [ _a.lower().strip('.') for _a in a ] for a in answers] 152 | for answer in answers: 153 | yield [ process_punctuation(a) for a in answer ] 154 | -------------------------------------------------------------------------------- /ansemb/dataset/vg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import nltk 5 | 6 | from collections import Counter 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | from torch.utils.data.dataloader import default_collate 11 | 12 | from ansemb.config import cfg 13 | from ansemb.dataset.base import VisualQA 14 | from ansemb.dataset.preprocess import process_punctuation, invert_dict 15 | 16 | import ansemb.utils as utils 17 | import ansemb.dataset.data_utils as data_utils 18 | 19 | def path_for(train=False, val=False, test=False): 20 | assert train + val + test == 1 21 | if train: split = cfg.VG.train_qa 22 | elif val: split = cfg.VG.val_qa 23 | else: split = cfg.VG.test_qa 24 | 25 | return os.path.join(cfg.VG.qa_path, split) 26 | 27 | def create_trainval_loader(vector): 28 | datasets = [ 29 | VG(path_for(train=True), cfg.VG.feature_path, vector), 30 | VG(path_for(val=True), cfg.VG.feature_path, vector) 31 | ] 32 | 33 | data_loader = torch.utils.data.DataLoader( 34 | data_utils.Composite(*datasets), 35 | batch_size=cfg.TRAIN.batch_size, 36 | shuffle=True, # only shuffle the data in training 37 | pin_memory=True, 38 | num_workers=cfg.TRAIN.data_workers, 39 | collate_fn=data_utils.collate_fn, 40 | ) 41 | return data_loader 42 | 43 | def get_loader(vector, train=False, val=False, test=False, vocab_path=None): 44 | assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True' 45 | split = VG( 46 | path_for(train=train, val=val, test=test), 47 | cfg.VG.feature_path, 48 | vector, 49 | vocab_path=vocab_path, 50 | answerable_only=train, 51 | ) 52 | loader = torch.utils.data.DataLoader( 53 | split, 54 | batch_size=cfg.TRAIN.batch_size, 55 | shuffle=train, # only shuffle the data in training 56 | pin_memory=True, 57 | num_workers=cfg.TRAIN.data_workers, 58 | collate_fn=data_utils.collate_fn, 59 | ) 60 | return loader 61 | 62 | def invert_dict(d): 63 | return {v: k for k, v in d.items()} 64 | 65 | class VG(VisualQA): 66 | def __init__(self, questions_path, image_features_path, 67 | vector, vocab_path=None, answerable_only=False): 68 | answer_vocab_path = cfg.VG.answer_vocab_path if vocab_path is None else vocab_path 69 | super(VG, self).__init__(vector, 70 | image_features_path, 71 | answer_vocab_path=answer_vocab_path) 72 | 73 | # load annotation 74 | with open(questions_path, 'r') as fd: questions_json = 
json.load(fd) 75 | 76 | # q and a 77 | cache_filepath = osp.join(cfg.cache_path, "{}.pt".format(questions_path.split('/')[-1])) 78 | 79 | if not os.path.exists( cache_filepath ): 80 | print('extracting answers...') 81 | self.answers = list(prepare_answers(questions_json)) 82 | self.choices = list(prepare_choices(questions_json)) 83 | 84 | print('encoding questions...') 85 | self.questions = list(prepare_questions(questions_json)) 86 | self.questions = [self._encode_question(q) for q in self.questions] 87 | 88 | self.answer_indices = [ [ self.answer_to_index.get(_a, -1) for _a in a ] for a in self.answers ] 89 | self.choice_indices = [ [ self.answer_to_index.get(_a, -1) for _a in a ] for a in self.choices ] 90 | self.answer_vectors = torch.cat( [self._encode_answer_vector(answer) 91 | for answer, index in self.answer_to_index.items() ], dim=0).float() 92 | print('saving cache to: {}'.format(cache_filepath)) 93 | torch.save({'questions': self.questions, 'answer_indices': self.answer_indices, 94 | 'choice_indices': self.choice_indices, 'answer_vectors': self.answer_vectors}, cache_filepath) 95 | else: 96 | print('loading cache from: {}'.format(cache_filepath)) 97 | _cache = torch.load(cache_filepath) 98 | self.questions = _cache['questions'] 99 | self.answer_indices = _cache['answer_indices'] 100 | self.choice_indices = _cache['choice_indices'] 101 | self.answer_vectors = _cache['answer_vectors'] 102 | 103 | # process images 104 | self.image_features_path = image_features_path 105 | self.image_id_to_index = self._create_image_id_to_index() 106 | self.image_ids = [q['image_id'] for q in questions_json] 107 | 108 | def __getitem__(self, item): 109 | question, question_length = self.questions[item] 110 | 111 | # sample answers 112 | answer_indices = self.answer_indices[item] 113 | counts = [1] 114 | 115 | choices = self.choice_indices[item] 116 | image_id = self.image_ids[item] 117 | image = self._load_image(image_id) 118 | return image, question, answer_indices, counts, choices, None, item, question_length 119 | 120 | def evaluate(self, predictions): 121 | raise NotImplementedError 122 | 123 | def prepare_questions(questions_json): 124 | questions = [q['question'] for q in questions_json] 125 | for question in questions: 126 | question = question.lower()[:-1] 127 | yield nltk.word_tokenize(process_punctuation(question)) 128 | 129 | def prepare_answers(answers_json): 130 | answers = [a['answer'] for a in answers_json] 131 | for answer in answers: 132 | yield [ process_punctuation(answer.lower().strip('.')) ] 133 | 134 | ################################################################################# 135 | # Note that we processed the data so that correct choice is always the last one, 136 | # this is a hack just for convience. 
So models that take all answer choices as input 137 | # should be cautious about cheating by learning the bias in order 138 | ################################################################################# 139 | 140 | def prepare_choices(answers_json): 141 | answers = [ a['IoU_decoys'] + a['QoU_decoys'] + [ a['answer'] ] for a in answers_json] 142 | for answer in answers: 143 | yield [ process_punctuation(a.lower().strip('.')) for a in answer ] 144 | -------------------------------------------------------------------------------- /ansemb/dataset/vqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import nltk 5 | import random 6 | 7 | from collections import Counter 8 | import torch 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | from torch.utils.data.dataloader import default_collate 12 | 13 | from ansemb.config import cfg 14 | from ansemb.dataset.base import VisualQA 15 | from ansemb.dataset.preprocess import process_punctuation, invert_dict 16 | 17 | import ansemb.utils as utils 18 | import ansemb.dataset.data_utils as data_utils 19 | 20 | def path_for(train=False, val=False, test=False, question=False, answer=False): 21 | assert train + val + test == 1 22 | assert question + answer == 1 23 | assert not (test and answer), 'loading answers from test split not supported' 24 | if train: split = cfg.VQA2.train_qa 25 | elif val: split = cfg.VQA2.val_qa 26 | else: split = cfg.VQA2.test_qa 27 | if question: fmt = 'v2_{0}_{1}_{2}_questions.json' 28 | else: fmt = 'v2_{1}_{2}_annotations.json' 29 | s = fmt.format(cfg.VQA2.task, cfg.VQA2.dataset, split) 30 | return os.path.join(cfg.VQA2.qa_path, s) 31 | 32 | 33 | def get_loader(vector, train=False, val=False, test=False, vocab_path=None, _batch_size=None): 34 | """ Returns a data loader for the desired split """ 35 | assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True' 36 | batch_size = _batch_size if _batch_size is not None else cfg.TRAIN.batch_size 37 | if test == True: 38 | split = VQAEval( path_for(test=test, question=True), 39 | cfg.VQA2.feature_path, 40 | vector, 41 | vocab_path=vocab_path) 42 | loader = torch.utils.data.DataLoader( 43 | split, 44 | batch_size=batch_size, 45 | shuffle=False, 46 | pin_memory=True, 47 | num_workers=cfg.TRAIN.data_workers, 48 | collate_fn=data_utils.collate_fn, 49 | ) 50 | else: 51 | split = VQA( 52 | path_for(train=train, val=val, question=True), 53 | path_for(train=train, val=val, answer=True), 54 | cfg.VQA2.feature_path, 55 | vector, 56 | vocab_path=vocab_path, 57 | ) 58 | loader = torch.utils.data.DataLoader( 59 | split, 60 | batch_size=batch_size, 61 | shuffle=train, # only shuffle the data in training 62 | pin_memory=True, 63 | num_workers=cfg.TRAIN.data_workers, 64 | collate_fn=data_utils.collate_fn, 65 | ) 66 | return loader 67 | 68 | def create_trainval_loader(vector): 69 | datasets = [ 70 | VQA(path_for(train=True, question=True), path_for(train=True, answer=True), cfg.VQA2.feature_path, vector), 71 | VQA(path_for(val=True, question=True), path_for(val=True, answer=True), cfg.VQA2.feature_path, vector) 72 | ] 73 | 74 | data_loader = torch.utils.data.DataLoader( 75 | data_utils.Composite(*datasets), 76 | batch_size=cfg.TRAIN.batch_size, 77 | shuffle=True, # only shuffle the data in training 78 | pin_memory=True, 79 | num_workers=cfg.TRAIN.data_workers, 80 | collate_fn=data_utils.collate_fn, 81 | ) 82 | return data_loader 83 | 84 | 
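# ------------------------------------------------------------------------------
# Editor's note (illustrative, not part of the original source): with the default
# config above (cfg.VQA2.task = 'OpenEnded', cfg.VQA2.dataset = 'mscoco',
# cfg.VQA2.train_qa = 'train2014'), the path_for() calls used by the loaders
# resolve to the official VQA v2.0 file names, e.g.
#   path_for(train=True, question=True)
#     -> <cfg.VQA2.qa_path>/v2_OpenEnded_mscoco_train2014_questions.json
#   path_for(train=True, answer=True)
#     -> <cfg.VQA2.qa_path>/v2_mscoco_train2014_annotations.json
# ------------------------------------------------------------------------------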
class VQA(VisualQA): 85 | """ VQA dataset, open-ended """ 86 | def __init__(self, questions_path, answers_path, image_features_path, vector, 87 | vocab_path=None): 88 | answer_vocab_path = cfg.VQA2.answer_vocab_path if vocab_path is None else vocab_path 89 | super(VQA, self).__init__(vector, image_features_path, 90 | answer_vocab_path=answer_vocab_path) 91 | 92 | # load annotation 93 | with open(questions_path, 'r') as fd: questions_json = json.load(fd) 94 | with open(answers_path, 'r') as fd: answers_json = json.load(fd) 95 | 96 | # q and a 97 | cache_filepath = osp.join(cfg.cache_path, "{}.{}.pt".format(questions_path.split('/')[-1], 98 | answers_path.split('/')[-1])) 99 | 100 | print('extracting answers...') 101 | self.answers = list(prepare_answers(answers_json['annotations'])) 102 | 103 | if not os.path.exists( cache_filepath ): 104 | print('encoding questions...') 105 | self.questions = list(prepare_questions(questions_json['questions'])) 106 | self.questions = [self._encode_question(q) for q in self.questions] 107 | 108 | self.answer_indices = [ [ self.answer_to_index.get(_a, -1) for _a in a ] for a in self.answers ] 109 | self.answer_vectors = torch.cat( [self._encode_answer_vector(answer) 110 | for answer, index in self.answer_to_index.items() ], dim=0).float() 111 | print('saving cache to: {}'.format(cache_filepath)) 112 | torch.save({'questions': self.questions, 'answer_indices': self.answer_indices, 113 | 'answer_vectors': self.answer_vectors}, cache_filepath) 114 | else: 115 | print('loading cache from: {}'.format(cache_filepath)) 116 | _cache = torch.load(cache_filepath) 117 | self.questions = _cache['questions'] 118 | self.answer_indices = _cache['answer_indices'] 119 | self.answer_vectors = _cache['answer_vectors'] 120 | 121 | # process images 122 | self.image_features_path = image_features_path 123 | self.image_id_to_index = self._create_image_id_to_index() 124 | self.image_ids = [q['image_id'] for q in questions_json['questions']] 125 | 126 | self.vqa_minus = torch.Tensor([ (anno['answer_type'] != 'yes/no') for anno in answers_json['annotations'] ]).byte() 127 | 128 | def __getitem__(self, item): 129 | question, question_length = self.questions[item] 130 | 131 | # sample answers 132 | answer_cands = Counter(self.answer_indices[item]) 133 | answer_indices = list(answer_cands.keys()) 134 | counts = list(answer_cands.values()) 135 | 136 | label = self._encode_multihot_labels(self.answers[item]) 137 | image_id = self.image_ids[item] 138 | image = self._load_image(image_id) 139 | return image, question, answer_indices, counts, None, label, item, question_length 140 | 141 | def prepare_questions(questions_json): 142 | """ Tokenize and normalize questions from a given question json in the usual VQA format. """ 143 | questions = [q['question'] for q in questions_json] 144 | for question in questions: 145 | question = question.lower()[:-1] 146 | yield nltk.word_tokenize(process_punctuation(question)) 147 | 148 | def prepare_answers(answers_json): 149 | """ Normalize answers from a given answer json in the usual VQA format. """ 150 | answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json] 151 | for answer_list in answers: 152 | ret = list(map(process_punctuation, answer_list)) 153 | yield ret 154 | 155 | def prepare_multiple_choice_answer(answers_json): 156 | """ Normalize answers from a given answer json in the usual VQA format. 
""" 157 | multiple_choice_answers = [ ans_dict['multiple_choice_answer'] for ans_dict in answers_json] 158 | for answer in multiple_choice_answers: 159 | yield [ process_punctuation(answer) ] 160 | 161 | class VQAEval(VisualQA): 162 | """ VQA dataset, open-ended """ 163 | def __init__(self, questions_path, image_features_path, vector, vocab_path=None): 164 | answer_vocab_path = cfg.VQA2.answer_vocab_path if vocab_path is None else vocab_path 165 | super(VQAEval, self).__init__(vector, 166 | image_features_path, 167 | answer_vocab_path=answer_vocab_path) 168 | 169 | with open(questions_path, 'r') as fd: questions_json = json.load(fd) 170 | 171 | # q and a 172 | cache_filepath = osp.join(cfg.cache_path, "{}.pt".format(questions_path.split('/')[-1])) 173 | 174 | if not os.path.exists( cache_filepath ): 175 | print('encoding questions...') 176 | self.questions = list(prepare_questions(questions_json['questions'])) 177 | self.questions = [self._encode_question(q) for q in self.questions] 178 | 179 | print('saving cache to: {}'.format(cache_filepath)) 180 | torch.save({'questions': self.questions}, cache_filepath) 181 | else: 182 | print('loading cache from: {}'.format(cache_filepath)) 183 | _cache = torch.load(cache_filepath) 184 | self.questions = _cache['questions'] 185 | 186 | # v 187 | self.image_features_path = image_features_path 188 | self.image_id_to_index = self._create_image_id_to_index() 189 | self.image_ids = [q['image_id'] for q in questions_json['questions']] 190 | 191 | def __getitem__(self, item): 192 | question, question_length = self.questions[item] 193 | 194 | image_id = self.image_ids[item] 195 | image = self._load_image(image_id) 196 | return image, question, None, None, None, None, item, question_length 197 | 198 | -------------------------------------------------------------------------------- /ansemb/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | from ansemb.utils import cosine_sim 7 | 8 | class ContrastiveLoss(nn.Module): 9 | def __init__(self, margin=0.2, measure=False, dual=False, beta=10): 10 | super(ContrastiveLoss, self).__init__() 11 | self.margin = margin 12 | if measure == 'cosine': 13 | self.sim = cosine_sim 14 | else: 15 | raise ValueError('Unknown similarity.[{}]'.format(measure)) 16 | 17 | self.beta = beta 18 | 19 | def forward(self, scores, match, weight=None): 20 | N, M = scores.size() 21 | 22 | match_byte = ( match > 0 ).byte() 23 | pos_iq = torch.zeros(N, M) 24 | pos_iq = scores[match_byte].view(N, 1).expand(N, M) 25 | 26 | margin_weight = 1 27 | if weight is not None: 28 | _, _match = torch.max(match, 1) 29 | _weight = F.normalize(weight, p=2, dim=1) 30 | margin_weight = (1 - ((1 + cosine_sim(_weight, _weight)[_match]) / 2)**self.beta).clamp(min=0, max=1) 31 | 32 | cost_iq = (self.margin*margin_weight + scores - pos_iq).clamp(min=0) 33 | cost_iq = cost_iq.masked_fill_(match_byte, 0) 34 | 35 | return cost_iq.sum() 36 | -------------------------------------------------------------------------------- /ansemb/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hexiang-hu/answer_embedding/154182974565de3fd24b669d7d298278e1e8a5d0/ansemb/models/__init__.py -------------------------------------------------------------------------------- /ansemb/models/embedding.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | import time 8 | from ansemb.models.layers import * 9 | 10 | class StackedAttentionEmbedding(nn.Module): 11 | def __init__(self, embedding_tokens, embedding_weights=None, 12 | output_features=2048, embedding_size=1024, 13 | rnn_bidirectional=True, embedding_requires_grad=True): 14 | super(StackedAttentionEmbedding, self).__init__() 15 | question_features = 1024 16 | rnn_features = int(question_features // 2) if rnn_bidirectional else int(question_features) 17 | vision_features = output_features 18 | glimpses = 2 19 | 20 | vocab_size = embedding_weights.size(0) 21 | vector_dim = embedding_weights.size(1) 22 | self.embedding = nn.Embedding(vocab_size, vector_dim, padding_idx=0) 23 | 24 | self.drop = nn.Dropout(0.5) 25 | self.text = Seq2SeqRNN( 26 | input_features=vector_dim, 27 | rnn_features=int(rnn_features), 28 | rnn_type='LSTM', 29 | rnn_bidirectional=rnn_bidirectional, 30 | ) 31 | self.attention = Attention( 32 | v_features=vision_features, 33 | q_features=question_features, 34 | mid_features=512, 35 | glimpses=2, 36 | drop=0.5, 37 | ) 38 | self.mlp = GroupMLP( 39 | in_features=glimpses * vision_features + question_features, 40 | mid_features=4096, 41 | out_features=embedding_size, 42 | drop=0.5, 43 | groups=64, 44 | ) 45 | 46 | for m in self.modules(): 47 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 48 | init.xavier_uniform(m.weight) 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | 52 | self.embedding.weight.data = embedding_weights 53 | self.embedding.weight.requires_grad = embedding_requires_grad 54 | 55 | def forward(self, v, q, q_len): 56 | q = self.text(self.drop(self.embedding(q)), list(q_len.data)) 57 | # q = self.text(self.embedding(q), list(q_len.data)) 58 | 59 | v = F.normalize(v, p=2, dim=1) 60 | a = self.attention(v, q) 61 | v = apply_attention(v, a) 62 | 63 | combined = torch.cat([v, q], dim=1) 64 | embedding = self.mlp(combined) 65 | return embedding 66 | 67 | class VisualSemanticEmbedding(nn.Module): 68 | def __init__(self, embedding_tokens, embedding_weights=None, 69 | output_features=2048, embedding_size=1024, 70 | rnn_bidirectional=True, embedding_requires_grad=True): 71 | super(VisualSemanticEmbedding, self).__init__() 72 | question_features = 300 73 | rnn_features = int(question_features // 2) if rnn_bidirectional else int(question_features) 74 | vision_features = output_features 75 | glimpses = 2 76 | 77 | # self.text = BagOfWordsMLPProcessor( 78 | self.text = BagOfWordsProcessor( 79 | embedding_tokens=embedding_weights.size(0) if embedding_weights is not None else embedding_tokens, 80 | embedding_weights=embedding_weights, 81 | embedding_features=300, 82 | embedding_requires_grad=True, 83 | rnn_features=rnn_features, 84 | drop=0.5, 85 | ) 86 | self.mlp = GroupMLP( 87 | in_features= vision_features + question_features, 88 | mid_features=4096, 89 | out_features=embedding_size, 90 | drop=0.5, 91 | groups=64, 92 | ) 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 96 | init.xavier_uniform(m.weight) 97 | if m.bias is not None: 98 | m.bias.data.zero_() 99 | 100 | def forward(self, v, q, q_len): 101 | q = F.normalize(self.text(q, list(q_len.data)), p=2, dim=1) 102 | v = F.normalize(F.avg_pool2d(v, (v.size(2), v.size(3))).squeeze(), p=2, dim=1) 103 | 
104 | combined = torch.cat([v, q], dim=1) 105 | embedding = self.mlp(combined) 106 | return embedding 107 | 108 | class MLPEmbedding(nn.Module): 109 | def __init__(self, embedding_features, 110 | embedding_weights=None, 111 | embedding_size=1024): 112 | super(MLPEmbedding, self).__init__() 113 | 114 | self.mlp = GroupMLP( 115 | in_features=embedding_features, 116 | mid_features=4096, 117 | out_features=embedding_size, 118 | drop=0.5, 119 | groups=64, 120 | ) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 124 | init.xavier_uniform(m.weight) 125 | if m.bias is not None: 126 | m.bias.data.zero_() 127 | 128 | def forward(self, a, a_len=None): 129 | return self.mlp(F.normalize(a, p=2)) 130 | 131 | class RNNEmbedding(nn.Module): 132 | def __init__(self, embedding_features, 133 | embedding_weights=None, 134 | rnn_bidirectional=True, 135 | embedding_size=1024): 136 | super(RNNEmbedding, self).__init__() 137 | 138 | rnn_features = int(embedding_size // 2) if rnn_bidirectional else int(embedding_size) 139 | self.text = MaxoutRNN( 140 | input_features=embedding_features, 141 | rnn_features=int(rnn_features), 142 | rnn_type='GRU', 143 | num_layers=2, 144 | rnn_bidirectional=rnn_bidirectional, 145 | drop=0.5, 146 | ) 147 | 148 | def forward(self, a, a_len): 149 | return self.text(a, a_len) 150 | # return self.text(F.normalize(a, p=2), a_len) 151 | -------------------------------------------------------------------------------- /ansemb/models/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | 9 | class MLP(nn.Sequential): 10 | def __init__(self, in_features, mid_features, out_features, drop=0.0, groups=1): 11 | super(MLP, self).__init__() 12 | self.add_module('drop1', nn.Dropout(drop)) 13 | self.add_module('lin1', nn.Linear(in_features, mid_features)) 14 | self.add_module('relu', nn.LeakyReLU()) 15 | self.add_module('drop2', nn.Dropout(drop)) 16 | self.add_module('lin2', nn.Linear(mid_features, out_features)) 17 | 18 | class MaxoutRNN(nn.Module): 19 | def __init__(self, input_features, rnn_features, num_layers=1, drop=0.0, 20 | rnn_type='LSTM', rnn_bidirectional=False): 21 | super(MaxoutRNN, self).__init__() 22 | self.bidirectional = rnn_bidirectional 23 | 24 | if rnn_type == 'LSTM': 25 | self.rnn = nn.LSTM(input_size=input_features, 26 | hidden_size=rnn_features, dropout=drop, 27 | num_layers=num_layers, batch_first=True, 28 | bidirectional=rnn_bidirectional) 29 | elif rnn_type == 'GRU': 30 | self.rnn = nn.GRU(input_size=input_features, 31 | hidden_size=rnn_features, dropout=drop, 32 | num_layers=num_layers, batch_first=True, 33 | bidirectional=rnn_bidirectional) 34 | else: 35 | raise ValueError('Unsupported RNN type') 36 | 37 | self.features = rnn_features 38 | 39 | self._init_rnn(self.rnn.weight_ih_l0) 40 | self._init_rnn(self.rnn.weight_hh_l0) 41 | self.rnn.bias_ih_l0.data.zero_() 42 | self.rnn.bias_hh_l0.data.zero_() 43 | 44 | def _init_rnn(self, weight): 45 | for w in weight.chunk(3, 0): 46 | init.xavier_uniform(w) 47 | 48 | def forward(self, q_emb, q_len, hidden=None): 49 | lengths = torch.LongTensor(q_len) 50 | lens, indices = torch.sort(lengths, 0, True) 51 | 52 | packed_batch = pack_padded_sequence(q_emb[indices.cuda()], lens.tolist(), 
batch_first=True) 53 | if hidden is not None: 54 | N_, H_ = hidden.size() 55 | hs, _ = self.rnn(packed_batch, hidden[indices.cuda()].view(1, N_, H_)) 56 | else: 57 | hs, _ = self.rnn(packed_batch) 58 | outputs, _ = pad_packed_sequence(hs, batch_first=True, padding_value=np.float('-inf')) 59 | 60 | _, _indices = torch.sort(indices, 0) 61 | outputs = outputs[_indices.cuda()] 62 | N, L, H = outputs.size() 63 | return F.max_pool1d(outputs.transpose(1, 2), L).squeeze().view(N, H) 64 | 65 | class Seq2SeqRNN(nn.Module): 66 | def __init__(self, input_features, rnn_features, num_layers=1, drop=0.0, 67 | rnn_type='LSTM', rnn_bidirectional=False): 68 | super(Seq2SeqRNN, self).__init__() 69 | self.bidirectional = rnn_bidirectional 70 | 71 | if rnn_type == 'LSTM': 72 | self.rnn = nn.LSTM(input_size=input_features, 73 | hidden_size=rnn_features, dropout=drop, 74 | num_layers=num_layers, batch_first=True, 75 | bidirectional=rnn_bidirectional) 76 | elif rnn_type == 'GRU': 77 | self.rnn = nn.GRU(input_size=input_features, 78 | hidden_size=rnn_features, dropout=drop, 79 | num_layers=num_layers, batch_first=True, 80 | bidirectional=rnn_bidirectional) 81 | else: 82 | raise ValueError('Unsupported Type') 83 | 84 | self.init_weight(rnn_bidirectional, rnn_type) 85 | 86 | def init_weight(self, bidirectional, rnn_type): 87 | self._init_rnn(self.rnn.weight_ih_l0, rnn_type) 88 | self._init_rnn(self.rnn.weight_hh_l0, rnn_type) 89 | self.rnn.bias_ih_l0.data.zero_() 90 | self.rnn.bias_hh_l0.data.zero_() 91 | 92 | if bidirectional: 93 | self._init_rnn(self.rnn.weight_ih_l0_reverse, rnn_type) 94 | self._init_rnn(self.rnn.weight_hh_l0_reverse, rnn_type) 95 | self.rnn.bias_ih_l0_reverse.data.zero_() 96 | self.rnn.bias_hh_l0_reverse.data.zero_() 97 | 98 | def _init_rnn(self, weight, rnn_type): 99 | chunk_size = 4 if rnn_type == 'LSTM' else 3 100 | for w in weight.chunk(chunk_size, 0): 101 | init.xavier_uniform(w) 102 | 103 | def forward(self, q_emb, q_len): 104 | lengths = torch.LongTensor(q_len) 105 | lens, indices = torch.sort(lengths, 0, True) 106 | 107 | packed = pack_padded_sequence(q_emb[indices.cuda()], lens.tolist(), batch_first=True) 108 | if isinstance(self.rnn, nn.LSTM): 109 | _, ( outputs, _ ) = self.rnn(packed) 110 | elif isinstance(self.rnn, nn.GRU): 111 | _, outputs = self.rnn(packed) 112 | 113 | if self.bidirectional: 114 | outputs = torch.cat([ outputs[0, :, :], outputs[1, :, :] ], dim=1) 115 | else: 116 | outputs = outputs.squeeze(0) 117 | 118 | _, _indices = torch.sort(indices, 0) 119 | outputs = outputs[_indices.cuda()] 120 | 121 | return outputs 122 | 123 | class SelfAttention(nn.Module): 124 | def __init__(self, v_features, mid_features, glimpses, drop=0.0): 125 | super(SelfAttention, self).__init__() 126 | self.v_conv = nn.Conv2d(v_features, mid_features, 1, bias=False) # let self.lin take care of bias 127 | self.x_conv = nn.Conv2d(mid_features, glimpses, 1) 128 | 129 | self.drop = nn.Dropout(drop) 130 | self.relu = nn.LeakyReLU(inplace=True) 131 | 132 | def forward(self, v): 133 | v = self.v_conv(self.drop(v)) 134 | x = self.relu(v) 135 | x = self.x_conv(self.drop(x)) 136 | return x 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, v_features, q_features, mid_features, glimpses, drop=0.0): 140 | super(Attention, self).__init__() 141 | self.v_conv = nn.Conv2d(v_features, mid_features, 1, bias=False) # let self.lin take care of bias 142 | self.q_lin = nn.Linear(q_features, mid_features) 143 | self.x_conv = nn.Conv2d(mid_features, glimpses, 1) 144 | 145 | self.drop = nn.Dropout(drop) 146 
| self.relu = nn.LeakyReLU(inplace=True) 147 | 148 | def forward(self, v, q): 149 | v = self.v_conv(self.drop(v)) 150 | q = self.q_lin(self.drop(q)) 151 | q = tile_2d_over_nd(q, v) 152 | x = self.relu(v + q) 153 | x = self.x_conv(self.drop(x)) 154 | return x 155 | 156 | def apply_attention(input, attention): 157 | """ Apply any number of attention maps over the input. 158 | The attention map has to have the same size in all dimensions except dim=1. 159 | """ 160 | n, c = input.size()[:2] 161 | glimpses = attention.size(1) 162 | 163 | # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged 164 | input = input.view(n, c, -1) 165 | attention = attention.view(n, glimpses, -1) 166 | s = input.size(2) 167 | 168 | # apply a softmax to each attention map separately 169 | # since softmax only takes 2d inputs, we have to collapse the first two dimensions together 170 | # so that each glimpse is normalized separately 171 | attention = attention.view(n * glimpses, -1) 172 | attention = F.softmax(attention) 173 | 174 | # apply the weighting by creating a new dim to tile both tensors over 175 | target_size = [n, glimpses, c, s] 176 | input = input.view(n, 1, c, s).expand(*target_size) 177 | attention = attention.view(n, glimpses, 1, s).expand(*target_size) 178 | weighted = input * attention 179 | # sum over only the spatial dimension 180 | weighted_mean = weighted.sum(dim=3) 181 | # the shape at this point is (n, glimpses, c, 1) 182 | return weighted_mean.view(n, -1) 183 | 184 | 185 | def tile_2d_over_nd(feature_vector, feature_map): 186 | """ Repeat the same feature vector over all spatial positions of a given feature map. 187 | The feature vector should have the same batch size and number of features as the feature map. 
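  For example, a (N, C) question vector tiled over a (N, C, H, W) feature map
  comes back with shape (N, C, H, W), so it can be added element-wise to the
  projected visual features inside Attention.forward.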
188 | """ 189 | n, c = feature_vector.size() 190 | spatial_size = feature_map.dim() - 2 191 | tiled = feature_vector.view(n, c, *([1] * spatial_size)).expand_as(feature_map) 192 | return tiled 193 | 194 | class BagOfWordsProcessor(nn.Module): 195 | def __init__(self, embedding_tokens, embedding_features, rnn_features, 196 | embedding_weights, embedding_requires_grad, drop=0.0): 197 | super(BagOfWordsProcessor, self).__init__() 198 | self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0) 199 | 200 | self.embedding.weight.data = embedding_weights 201 | self.embedding.weight.requires_grad = embedding_requires_grad 202 | 203 | def forward(self, q, q_len): 204 | embedded = self.embedding(q) 205 | q_len = Variable(torch.Tensor(q_len).view(-1, 1) + 1e-12, requires_grad=False).cuda() 206 | 207 | return torch.div( torch.sum(embedded, 1), q_len ) 208 | 209 | class GroupMLP(nn.Module): 210 | def __init__(self, in_features, mid_features, out_features, drop=0.5, groups=1): 211 | super(GroupMLP, self).__init__() 212 | 213 | self.conv1 = nn.Conv1d(in_features, mid_features, 1) 214 | self.drop = nn.Dropout(p=drop) 215 | self.relu = nn.LeakyReLU() 216 | self.conv2 = nn.Conv1d(mid_features, out_features, 1, groups=groups) 217 | 218 | def forward(self, a): 219 | N, C = a.size() 220 | h = self.relu(self.conv1(a.view(N, C, 1))) 221 | return self.conv2(self.drop(h)).view(N, -1) 222 | 223 | -------------------------------------------------------------------------------- /ansemb/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as transforms 7 | 8 | from ansemb.config import cfg 9 | 10 | def cosine_sim(im, s): 11 | return im.mm(s.t()) 12 | 13 | def batch_mc_acc(predicted): 14 | """ Compute the accuracies for a batch of predictions and answers """ 15 | N, C = predicted.squeeze().size() 16 | _, predicted_index = predicted.max(dim=1, keepdim=True) 17 | return (predicted_index == C - 1).float() 18 | 19 | def batch_top1(predicted, true): 20 | """ Compute the accuracies for a batch of predictions and answers """ 21 | _, predicted_index = predicted.max(dim=1, keepdim=True) 22 | return true.gather(dim=1, index=predicted_index).clamp(max=1) 23 | 24 | def batch_accuracy(predicted, true): 25 | """ Compute the accuracies for a batch of predictions and answers """ 26 | _, predicted_index = predicted.max(dim=1, keepdim=True) 27 | agreeing = true.gather(dim=1, index=predicted_index) 28 | return (agreeing * 0.3).clamp(max=1) 29 | 30 | def update_learning_rate(optimizer, epoch): 31 | learning_rate = cfg.TRAIN.base_lr * 0.5**(float(epoch) / cfg.TRAIN.lr_decay) 32 | for param_group in optimizer.param_groups: param_group['lr'] = learning_rate 33 | 34 | return learning_rate 35 | 36 | class Tracker: 37 | """ Keep track of results over time, while having access to monitors to display information about them. """ 38 | def __init__(self): 39 | self.data = {} 40 | 41 | def track(self, name, *monitors): 42 | """ Track a set of results with given monitors under some name (e.g. 'val_acc'). 43 | When appending to the returned list storage, use the monitors to retrieve useful information. 
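    A minimal usage sketch, mirroring how the training scripts use this:

        tracker = Tracker()
        acc_tracker = tracker.track('val_acc', Tracker.MeanMonitor())
        acc_tracker.append(0.5)
        acc_tracker.append(1.0)
        print(acc_tracker.mean.value)  # 0.75, since monitors are exposed under their .name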
44 | """ 45 | l = Tracker.ListStorage(monitors) 46 | self.data.setdefault(name, []).append(l) 47 | return l 48 | 49 | def to_dict(self): 50 | # turn list storages into regular lists 51 | return {k: list(map(list, v)) for k, v in self.data.items()} 52 | 53 | class ListStorage: 54 | """ Storage of data points that updates the given monitors """ 55 | def __init__(self, monitors=[]): 56 | self.data = [] 57 | self.monitors = monitors 58 | for monitor in self.monitors: 59 | setattr(self, monitor.name, monitor) 60 | 61 | def append(self, item): 62 | for monitor in self.monitors: 63 | monitor.update(item) 64 | self.data.append(item) 65 | 66 | def __iter__(self): 67 | return iter(self.data) 68 | 69 | class MeanMonitor: 70 | """ Take the mean over the given values """ 71 | name = 'mean' 72 | 73 | def __init__(self): 74 | self.n = 0 75 | self.total = 0 76 | 77 | def update(self, value): 78 | self.total += value 79 | self.n += 1 80 | 81 | @property 82 | def value(self): 83 | return self.total / self.n 84 | 85 | class MovingMeanMonitor: 86 | """ Take an exponentially moving mean over the given values """ 87 | name = 'mean' 88 | 89 | def __init__(self, momentum=0.9): 90 | self.momentum = momentum 91 | self.first = True 92 | self.value = None 93 | 94 | def update(self, value): 95 | if self.first: 96 | self.value = value 97 | self.first = False 98 | else: 99 | m = self.momentum 100 | self.value = m * self.value + (1 - m) * value 101 | 102 | def get_transform(target_size, central_fraction=1.0): 103 | return transforms.Compose([ 104 | transforms.Scale(int(target_size / central_fraction)), 105 | transforms.CenterCrop(target_size), 106 | transforms.ToTensor(), 107 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 108 | std=[0.229, 0.224, 0.225]), 109 | ]) 110 | -------------------------------------------------------------------------------- /ansemb/vector.py: -------------------------------------------------------------------------------- 1 | ## The following code is modified based on https://github.com/pytorch/text/blob/master/torchtext/vocab.py 2 | import array 3 | import zipfile 4 | from tqdm import tqdm 5 | from six.moves.urllib.request import urlretrieve 6 | import os 7 | import os.path as osp 8 | import torch 9 | import io 10 | 11 | def reporthook(t): 12 | """https://github.com/tqdm/tqdm""" 13 | last_b = [0] 14 | 15 | def inner(b=1, bsize=1, tsize=None): 16 | """ 17 | b: int, optionala 18 | Number of blocks just transferred [default: 1]. 19 | bsize: int, optional 20 | Size of each block (in tqdm units) [default: 1]. 21 | tsize: int, optional 22 | Total size (in tqdm units). If [default: None] remains unchanged. 
23 | """ 24 | if tsize is not None: 25 | t.total = tsize 26 | t.update((b - last_b[0]) * bsize) 27 | last_b[0] = b 28 | return inner 29 | 30 | class Vector(object): 31 | def __init__(self, cache_path='.vector_cache', 32 | vector_type='glove.840B', unk_init=torch.Tensor.zero_): 33 | urls = { 34 | 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 35 | 'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 36 | 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 37 | } 38 | url = urls[vector_type] if urls.get(vector_type, False) != False else None 39 | name = osp.splitext(osp.basename(url))[0] + '.txt' 40 | 41 | self.unk_init = unk_init 42 | self.cache(name, cache_path, url=url) 43 | 44 | def __getitem__(self, token): 45 | if self.stoi.get(token, -1) != -1: 46 | return self.vectors[self.stoi[token]] 47 | else: 48 | return self.unk_init(torch.Tensor(1, self.dim)) 49 | 50 | def _prepare(self, vocab): 51 | word2vec = torch.Tensor( len(vocab), self.dim ) 52 | for token, idx in vocab.items(): 53 | word2vec[idx, :] = self[token] 54 | 55 | return word2vec 56 | 57 | def check(self, token): 58 | if self.stoi.get(token, -1) != -1: 59 | return True 60 | else: 61 | return False 62 | 63 | def cache(self, name, cache_path, url=None): 64 | path = osp.join(cache_path, name) 65 | path_pt = "{}.pt".format(path) 66 | 67 | if not osp.isfile(path_pt): 68 | # download vocab file if it does not exists 69 | if not osp.exists(path) and url: 70 | dest = osp.join(cache_path, os.path.basename(url)) 71 | if not osp.exists(dest): 72 | print('[-] Downloading vectors from {}'.format(url)) 73 | if not osp.exists(cache_path): os.mkdir(cache_path) 74 | 75 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 76 | urlretrieve(url, dest, reporthook=reporthook(t)) 77 | 78 | print('[-] Extracting vectors into {}'.format(path)) 79 | ext = os.path.splitext(dest)[1][1:] 80 | if ext == 'zip': 81 | with zipfile.ZipFile(dest, "r") as zf: zf.extractall(cache_path) 82 | 83 | if not os.path.isfile(path): 84 | raise RuntimeError('no vectors found at {}'.format(path)) 85 | 86 | # build vocab list 87 | itos, vectors, dim = [], array.array(str('d')), None 88 | 89 | # Try to read the whole file with utf-8 encoding. 90 | binary_lines = False 91 | try: 92 | with io.open(path, encoding="utf8") as f: 93 | lines = [line for line in f] 94 | # If there are malformed lines, read in binary mode 95 | # and manually decode each word from utf-8 96 | except: 97 | print("[!] Could not read {} as UTF8 file, " 98 | "reading file as bytes and skipping " 99 | "words with malformed UTF8.".format(path)) 100 | with open(path, 'rb') as f: 101 | lines = [line for line in f] 102 | binary_lines = True 103 | 104 | print("[-] Loading vectors from {}".format(path)) 105 | for line in tqdm(lines, total=len(lines)): 106 | # Explicitly splitting on " " is important, so we don't 107 | # get rid of Unicode non-breaking spaces in the vectors. 108 | entries = line.rstrip().split(" ") 109 | word, entries = entries[0], entries[1:] 110 | if dim is None and len(entries) > 1: 111 | dim = len(entries) 112 | elif len(entries) == 1: 113 | print("Skipping token {} with 1-dimensional " 114 | "vector {}; likely a header".format(word, entries)) 115 | continue 116 | elif dim != len(entries): 117 | raise RuntimeError( 118 | "Vector for token {} has {} dimensions, but previously " 119 | "read vectors have {} dimensions. 
All vectors must have " 120 | "the same number of dimensions.".format(word, len(entries), dim)) 121 | 122 | vectors.extend(float(x) for x in entries) 123 | itos.append(word) 124 | 125 | self.itos = itos 126 | self.stoi = {word: i for i, word in enumerate(itos)} 127 | self.vectors = torch.Tensor(vectors).view(-1, dim) 128 | self.dim = dim 129 | print('* Caching vectors to {}'.format(path_pt)) 130 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 131 | else: 132 | print('* Loading vectors to {}'.format(path_pt)) 133 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 134 | 135 | -------------------------------------------------------------------------------- /data/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hexiang-hu/answer_embedding/154182974565de3fd24b669d7d298278e1e8a5d0/data/.keep -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data preprocessing 2 | 3 | 4 | ## QA data preprocess 5 | 6 | Please run the following command to download the data and prepare json 7 | 8 | ``` 9 | sh prepare_data.sh 10 | ``` 11 | 12 | ## Image data preprocess 13 | 14 | After running the above script, please download the images for each dataset and put them in the subdirectory `/image` for visual feature extraction. 15 | 16 | ## References 17 | - **[qaVG] dataset**. Being Negative but Constructively: 18 | Lessons Learnt from Creating Better Visual Question Answering Datasets ([qaVG website](http://www.teds.usc.edu/website_vqa/)) 19 | - **[V7W] dataset**. Grounded Question Answering in Images 20 | ([website](http://web.stanford.edu/~yukez/visual7w/index.html)) 21 | - **[VQA2] dataset**. Making the V in VQA Matter: Elevating the Role of Image Understanding in Visual Question Answering [website](http://www.visualqa.org/) 22 | -------------------------------------------------------------------------------- /data/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | mkdir -p v7w 3 | if [ -f "v7w/v7w_train_questions.json" ]; then 4 | echo "visual7w dataset exists, skipping..." 5 | else 6 | echo "Download visual7w" 7 | wget -q "http://web.stanford.edu/~yukez/papers/resources/dataset_v7w_telling.zip" -O v7w/dataset.zip 8 | unzip -j v7w/dataset.zip -d v7w/ 9 | python preprocess_v7w.py 10 | fi 11 | 12 | mkdir -p vg 13 | if [ -f "vg/VG_train_decoys.json" ]; then 14 | echo "qaVG dataset exists, skipping..." 15 | else 16 | echo "Download qaVG" 17 | wget -q "http://hexianghu.com/vqa-negative-decoys/Visual_Genome_decoys.zip" -O vg/Visual_Genome_decoys.zip 18 | unzip -j vg/Visual_Genome_decoys.zip -d vg/ 19 | fi 20 | 21 | mkdir -p vqa2 22 | if [ -f "vqa2/vqa_train_questions.json" ]; then 23 | echo "vqa2 dataset exists, skipping..." 
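  # The else branch below pulls the official VQA v2 archives from S3 (train/val/test
  # questions plus train/val annotations) and unzips them flat into vqa2/; the
  # shared preprocess_vqa.py step then runs at the end of this script.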
24 | else 25 | echo "Download vqa2" 26 | wget -q "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip" -O vqa2/v2_Questions_Train_mscoco.zip 27 | unzip -j vqa2/v2_Questions_Train_mscoco.zip -d vqa2/ 28 | wget -q "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip" -O vqa2/v2_Questions_Val_mscoco.zip 29 | unzip -j vqa2/v2_Questions_Val_mscoco.zip -d vqa2/ 30 | wget -q "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip" -O vqa2/v2_Questions_Test_mscoco.zip 31 | unzip -j vqa2/v2_Questions_Test_mscoco.zip -d vqa2/ 32 | wget -q "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip" -O vqa2/v2_Annotations_Train_mscoco.zip 33 | unzip -j vqa2/v2_Annotations_Train_mscoco.zip -d vqa2/ 34 | wget -q "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip" -O vqa2/v2_Annotations_Val_mscoco.zip 35 | unzip -j vqa2/v2_Annotations_Val_mscoco.zip -d vqa2/ 36 | fi 37 | 38 | python preprocess_vqa.py 39 | -------------------------------------------------------------------------------- /data/preprocess_v7w.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | import json 4 | 5 | this_dir = osp.dirname(__file__) 6 | with open(osp.join(this_dir, 'v7w', 'dataset_v7w_telling.json'), 'r') as _file: 7 | qa_anno = json.load(_file)['images'] 8 | 9 | print 'parsing json.' 10 | output_jsons = { 'train': 'v7w_train_questions.json', 'val': 'v7w_val_questions.json', 'test': 'v7w_test_questions.json' } 11 | qa_pairs = { 'train': [], 'val': [], 'test': [] } 12 | for qas in qa_anno: 13 | qa_pairs[qas['split']].extend( [ { 'qa_id': qa['qa_id'], 'question': qa['question'], 'answer': qa['answer'], 14 | 'multiple_choices': qa['multiple_choices'] + [ qa['answer'] ], 'question_type': qa['type'], 15 | 'image_index': qas['image_id'], 'filename': qas['filename'] } for qa in qas['qa_pairs']] ) 16 | 17 | print 'writing out parsed jsons.' 
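# Each record collected above (and dumped below) has this layout; the example uses
# made-up values and is only meant to document the schema:
#   { "qa_id": 12345, "question": "What is on the table?", "answer": "A laptop.",
#     "multiple_choices": ["A phone.", "A book.", "A cup.", "A laptop."],
#     "question_type": "what", "image_index": 2417997, "filename": "v7w_2417997.jpg" }
# The ground-truth answer is appended as the last entry of multiple_choices, which is
# the convention batch_mc_acc in ansemb/utils.py scores against (predicted index C - 1).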
18 | for k, v in qa_pairs.items(): 19 | with open(osp.join(this_dir, 'v7w', output_jsons[k]), 'w') as _file: json.dump(v, _file) 20 | -------------------------------------------------------------------------------- /data/preprocess_vqa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | import json 4 | import re 5 | from collections import Counter, defaultdict, OrderedDict 6 | from tqdm import tqdm 7 | from IPython import embed 8 | 9 | this_dir = osp.dirname(__file__) 10 | 11 | contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \ 12 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \ 13 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \ 14 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \ 15 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \ 16 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \ 17 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \ 18 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \ 19 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \ 20 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \ 21 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \ 22 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \ 23 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \ 24 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \ 25 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \ 26 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \ 27 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \ 28 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \ 29 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \ 30 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \ 31 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \ 32 | "youll": "you'll", "youre": "you're", "youve": "you've"} 33 | manualMap = { 'none': '0', 34 | 'zero': '0', 35 | 'one': '1', 36 | 'two': '2', 37 | 'three': '3', 38 | 'four': '4', 39 | 'five': '5', 40 | 'six': '6', 41 | 
'seven': '7', 42 | 'eight': '8', 43 | 'nine': '9', 44 | 'ten': '10' 45 | } 46 | articles = ['a', 47 | 'an', 48 | 'the' 49 | ] 50 | 51 | periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 52 | commaStrip = re.compile("(\d)(\,)(\d)") 53 | punct = [';', r"/", '[', ']', '"', '{', '}', 54 | '(', ')', '=', '+', '\\', '_', '-', 55 | '>', '<', '@', '`', ',', '?', '!'] 56 | 57 | def processPunctuation(inText): 58 | outText = inText 59 | for p in punct: 60 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(commaStrip, inText) != None): 61 | outText = outText.replace(p, '') 62 | else: 63 | outText = outText.replace(p, ' ') 64 | outText = periodStrip.sub("", outText, re.UNICODE) 65 | 66 | return outText 67 | 68 | def processDigitArticle(inText): 69 | outText = [] 70 | tempText = inText.lower().split() 71 | for word in tempText: 72 | word = manualMap.setdefault(word, word) 73 | if word not in articles: 74 | outText.append(word) 75 | else: 76 | pass 77 | for wordId, word in enumerate(outText): 78 | if word in contractions: 79 | outText[wordId] = contractions[word] 80 | outText = ' '.join(outText) 81 | return outText 82 | 83 | def preprocess_answer(answer): 84 | answer = answer.replace('\n', ' ') 85 | answer = answer.replace('\t', ' ') 86 | answer = answer.strip() 87 | answer = processPunctuation(answer) 88 | answer = processDigitArticle(answer) 89 | 90 | return answer 91 | 92 | train_apath = osp.join(this_dir, 'vqa2', 'v2_mscoco_train2014_annotations.json') 93 | val_apath = osp.join(this_dir, 'vqa2', 'v2_mscoco_val2014_annotations.json') 94 | 95 | train_qpath = osp.join(this_dir, 'vqa2', 'v2_OpenEnded_mscoco_train2014_questions.json') 96 | val_qpath = osp.join(this_dir, 'vqa2', 'v2_OpenEnded_mscoco_val2014_questions.json') 97 | 98 | output_jsons = { 'train': 'vqa_train_questions.json', 'val': 'vqa_val_questions.json', 'test': 'vqa_test_questions.json' } 99 | answer_conf = { 'maybe': 0.5, 'yes': 1 , 'no': 0.1} 100 | qa_pairs = { 'train': [], 'val': [] } 101 | 102 | print('loading train json.') 103 | with open(train_apath, 'r') as afile: train_anno = json.load(afile)['annotations'] 104 | with open(train_qpath, 'r') as qfile: train_ques = json.load(qfile)['questions'] 105 | assert len(train_anno) == len(train_ques) 106 | 107 | print('loading val json.') 108 | with open(val_apath, 'r') as afile: val_anno = json.load(afile)['annotations'] 109 | with open(val_qpath, 'r') as qfile: val_ques = json.load(qfile)['questions'] 110 | assert len(val_anno) == len(val_ques) 111 | 112 | print('parsing train json.') 113 | for q, a in tqdm(zip( train_ques, train_anno )): 114 | assert q['question_id'] == a['question_id'] 115 | qa_pairs['train'].extend([{ 116 | 'qa_id': q['question_id'], 'question': q['question'], 'answers': [ preprocess_answer(_a['answer']) for _a in a['answers']], 'multiple_choice_answer': preprocess_answer(a['multiple_choice_answer']), 117 | 'question_type': a['question_type'], 'image_index': a['image_id'], 'filename': 'COCO_train2014_{:012}.jpg'.format(a['image_id']) 118 | }]) 119 | 120 | print('parsing val json.') 121 | for q, a in tqdm(zip( val_ques, val_anno )): 122 | assert q['question_id'] == a['question_id'] 123 | qa_pairs['val'].extend([{ 124 | 'qa_id': q['question_id'], 'question': q['question'], 'answers': [ preprocess_answer(_a['answer']) for _a in a['answers']], 'multiple_choice_answer': preprocess_answer(a['multiple_choice_answer']), 125 | 'question_type': a['question_type'], 'image_index': a['image_id'], 'filename': 'COCO_val2014_{:012}.jpg'.format(a['image_id']) 126 | }]) 127 | 128 | 
print('writing out parsed jsons.') 129 | for k, v in qa_pairs.items(): 130 | with open(osp.join(this_dir, 'vqa2', output_jsons[k]), 'w') as _file: json.dump(v, _file) 131 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | nltk==3.4.0.3 3 | h5py==2.9.0 4 | python>3.0 5 | numpy>=1.10 6 | -------------------------------------------------------------------------------- /tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: sys.path.insert(0, path) 6 | 7 | this_dir = osp.dirname(__file__) 8 | 9 | # Add LIB_PATH to PYTHONPATH 10 | lib_path = osp.abspath( osp.join(this_dir, '..') ) 11 | add_path(lib_path) 12 | -------------------------------------------------------------------------------- /tools/dump_vqa_eval_json.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import json 3 | import argparse 4 | import torch 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--input_path', required=True, type=str) 8 | parser.add_argument('--iter', default=-1, type=int) 9 | parser.add_argument('--input_json', default='data/vqa2/v2_OpenEnded_mscoco_val2014_questions.json', type=str) 10 | parser.add_argument('--test_json', default='data/vqa2/v2_OpenEnded_mscoco_test2015_questions.json', type=str) 11 | parser.add_argument('--output_path', required=True, type=str) 12 | args = parser.parse_args() 13 | 14 | def invert_dict(d): 15 | return {v: k for k, v in d.items()} 16 | 17 | def main(args): 18 | data = torch.load(args.input_path) 19 | evaluation = data['eval'] 20 | if isinstance( evaluation, list ): evaluation = evaluation[args.iter] 21 | 22 | answer_ids = torch.cat(evaluation['answer_ids']).numpy().tolist() 23 | question_ids = torch.cat(evaluation['question_ids']).numpy().tolist() 24 | print( 'Total number of answer: {}'.format( len(answer_ids) ) ) 25 | 26 | with open(args.input_json, 'r') as fd: questions = json.load(fd)['questions'] 27 | assert len(answer_ids) == len(questions) 28 | 29 | invert_vocab = data['vocab']['index_to_answer'] 30 | 31 | diff_qids = None 32 | if 'test' in args.input_json: 33 | with open(args.test_json, 'r') as fd: test_questions = json.load(fd)['questions'] 34 | dev_qids = set( que['question_id'] for que in questions) 35 | test_qids = set( que['question_id'] for que in test_questions) 36 | diff_qids = test_qids - dev_qids 37 | 38 | with open(args.output_path, 'w') as fd: 39 | val_json = [ { 'answer': invert_vocab[answer], 'question_id': questions[idx]['question_id'] } 40 | for answer, idx in zip(answer_ids, question_ids) ] 41 | 42 | if diff_qids is not None: 43 | val_json.extend([ {'question_id': qid, 'answer': 'yes' } for qid in diff_qids ]) 44 | 45 | json.dump(val_json, fd) 46 | 47 | if __name__ == '__main__': 48 | print(args.__dict__) 49 | main(args) 50 | -------------------------------------------------------------------------------- /tools/preprocess_answer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import Counter 3 | import itertools 4 | import nltk 5 | import argparse 6 | 7 | import _init_paths 8 | from ansemb.config import cfg 9 | import ansemb.utils as utils 10 | import ansemb.dataset.preprocess as prepro 11 | from ansemb.vector 
import Vector 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--max_answers', default=None, type=int) 15 | parser.add_argument('--dataset', default='v7w', choices=['v7w', 'vqa', 'vg', 'vizwiz']) 16 | parser.add_argument('--vocab_json_path', default='data/answer.vocab.{}.json', type=str) 17 | args = parser.parse_args() 18 | 19 | # import dataset specific dataloader 20 | exec('import ansemb.dataset.{} as data'.format(args.dataset)) 21 | 22 | def parse_input(args): 23 | if args.dataset == 'v7w' or args.dataset == 'vg': 24 | train_questions = data.path_for(train=True) 25 | val_questions = data.path_for(val=True) 26 | test_questions = data.path_for(test=True) 27 | 28 | with open(train_questions, 'r') as fd: 29 | annotations = json.load(fd) 30 | with open(val_questions, 'r') as fd: 31 | annotations.extend( json.load(fd) ) 32 | with open(test_questions, 'r') as fd: 33 | annotations.extend( json.load(fd) ) 34 | elif args.dataset == 'vqa': 35 | train_answers = data.path_for(train=True, answer=True) 36 | val_answers = data.path_for(val=True, answer=True) 37 | 38 | with open(train_answers, 'r') as fd: 39 | annotations = json.load(fd)['annotations'] 40 | with open(val_answers, 'r') as fd: 41 | annotations.extend( json.load(fd)['annotations'] ) 42 | elif args.dataset == 'vizwiz': 43 | train_answers = data.path_for(train=True) 44 | val_answers = data.path_for(val=True) 45 | test_answers = data.path_for(test=True) 46 | with open(train_answers, 'r') as fd: 47 | annotations = json.load(fd)['annotations'] 48 | with open(val_answers, 'r') as fd: 49 | annotations.extend( json.load(fd)['annotations'] ) 50 | with open(test_answers, 'r') as fd: 51 | annotations.extend( json.load(fd)['annotations'] ) 52 | else: 53 | raise ValueError('Unsupported Dataset') 54 | 55 | return annotations 56 | 57 | def main(args): 58 | output_format = args.vocab_json_path.format 59 | 60 | # process input json files 61 | annotations = parse_input(args) 62 | 63 | word2vec = Vector() 64 | 65 | answers = data.prepare_answers(annotations) 66 | answer_vocab = prepro.extract_vocab(answers, top_k=args.max_answers) 67 | 68 | vocabs = { 'answer': answer_vocab } 69 | print('* Dump output vocab to: {}'.format(output_format(args.dataset))) 70 | with open(output_format(args.dataset), 'w') as fd: 71 | json.dump(vocabs, fd) 72 | 73 | if __name__ == '__main__': 74 | main(args) 75 | -------------------------------------------------------------------------------- /tools/preprocess_question.py: -------------------------------------------------------------------------------- 1 | import json 2 | import itertools 3 | import nltk 4 | import argparse 5 | 6 | import _init_paths 7 | import ansemb.dataset.vqa as data_vqa 8 | import ansemb.dataset.vg as data_vg 9 | import ansemb.dataset.v7w as data_v7w 10 | from ansemb.dataset.preprocess import extract_vocab 11 | import ansemb.utils as utils 12 | from ansemb.vector import Vector 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--vocab_json_path', default='data/question.vocab.json', type=str) 16 | args = parser.parse_args() 17 | 18 | def main(args): 19 | vqa_train_questions = data_vqa.path_for(train=True, question=True) 20 | vqa_val_questions = data_vqa.path_for(val=True, question=True) 21 | 22 | vg_train_questions = data_vg.path_for(train=True) 23 | vg_val_questions = data_vg.path_for(val=True) 24 | vg_test_questions = data_vg.path_for(test=True) 25 | 26 | v7w_train_questions = data_v7w.path_for(train=True) 27 | v7w_val_questions = data_v7w.path_for(val=True) 28 | 
v7w_test_questions = data_v7w.path_for(test=True) 29 | 30 | with open(vqa_train_questions, 'r') as fd: 31 | vqaq = json.load(fd)['questions'] 32 | with open(vqa_val_questions, 'r') as fd: 33 | vqaq.extend( json.load(fd)['questions'] ) 34 | 35 | with open(vg_train_questions, 'r') as fd: 36 | vg = json.load(fd) 37 | with open(vg_val_questions, 'r') as fd: 38 | vg.extend( json.load(fd) ) 39 | with open(vg_test_questions, 'r') as fd: 40 | vg.extend( json.load(fd)) 41 | 42 | with open(v7w_train_questions, 'r') as fd: 43 | v7w = json.load(fd) 44 | with open(v7w_val_questions, 'r') as fd: 45 | v7w.extend( json.load(fd) ) 46 | with open(v7w_test_questions, 'r') as fd: 47 | v7w.extend( json.load(fd) ) 48 | 49 | word2vec = Vector() 50 | max_question_length = max( max(map(len, data_vqa.prepare_questions(vqaq))), 51 | max(map(len, data_vg.prepare_questions(vg))), 52 | max(map(len, data_v7w.prepare_questions(v7w)))) 53 | 54 | questions = [ q for q in data_vqa.prepare_questions(vqaq) ] 55 | questions.extend([ q for q in data_vg.prepare_questions(vg) ]) 56 | questions.extend([ q for q in data_v7w.prepare_questions(v7w) ]) 57 | 58 | question_vocab = extract_vocab(questions, start=1, input_vocab=word2vec) 59 | question_vocab[''] = 0 # set token 0 as unknown token 60 | 61 | vocabs = { 62 | 'question': question_vocab, 63 | 'max_question_length': max_question_length, 64 | } 65 | with open(args.vocab_json_path, 'w') as fd: 66 | json.dump(vocabs, fd) 67 | 68 | if __name__ == '__main__': 69 | main(args) 70 | -------------------------------------------------------------------------------- /train_v7w_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import math 4 | import json 5 | import numpy as np 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | from tqdm import tqdm 15 | import copy 16 | import argparse 17 | import time 18 | 19 | cudnn.enabled = True 20 | cudnn.benchmark = True 21 | 22 | from ansemb.config import cfg, set_random_seed, update_train_configs 23 | import ansemb.dataset.v7w as data 24 | import ansemb.models.embedding as model 25 | import ansemb.utils as utils 26 | 27 | from ansemb.utils import cosine_sim 28 | from ansemb.vector import Vector 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--gpu_id', default=0, type=int) 32 | parser.add_argument('--batch_size', default=128, type=float) 33 | parser.add_argument('--max_negative_answer', default=50000, type=float) 34 | parser.add_argument('--answer_batch_size', default=3000, type=float) 35 | parser.add_argument('--loss_temperature', default=0.01, type=float) 36 | parser.add_argument('--pretrained_model', default=None, type=str) 37 | parser.add_argument('--context_embedding', default='SAN', choices=['SAN', 'BoW']) 38 | parser.add_argument('--answer_embedding', default='BoW', choices=['BoW', 'RNN']) 39 | parser.add_argument('--name', default=None, type=str) 40 | args = parser.parse_args() 41 | 42 | # fix random seed 43 | set_random_seed(cfg.seed) 44 | 45 | def test(context_net, answer_net, loader, tracker, args, prefix='', epoch=0): 46 | context_net.eval() 47 | answer_net.eval() 48 | tracker_class, tracker_params = tracker.MeanMonitor, {} 49 | accs = [] 50 | 51 | tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0) 52 | acc_tracker = 
tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 53 | 54 | var_params = { 'volatile': True, 'requires_grad': False } 55 | 56 | cnt = 0 57 | for v, q, _, _, a_ids, idx, q_len in tq: 58 | if args.answer_embedding == 'RNN': 59 | answer_var, answer_len = loader.dataset._get_answer_sequences(a_ids) 60 | answer_var = Variable(answer_var.cuda(), **var_params) 61 | N, C, L, _ = answer_var.size() 62 | answer_embedding = answer_net(answer_var.view(N*C, L, -1), answer_len) 63 | else: 64 | answer_var, answer_len = loader.dataset._get_answer_vectors(a_ids) 65 | answer_var = Variable(answer_var.cuda(), **var_params) 66 | N, C, _ = answer_var.size() 67 | answer_embedding = answer_net(answer_var.view(N*C, -1), answer_len) 68 | 69 | v = Variable(v.cuda(), **var_params) 70 | q = Variable(q.cuda(), **var_params) 71 | q_len = Variable(q_len.cuda(), **var_params) 72 | 73 | context_embedding = context_net(v, q, q_len) 74 | 75 | _, D = context_embedding.size() 76 | out = torch.sum( torch.mul( context_embedding.view(N, 1, D).expand(N, C, D), answer_embedding.view(N, C, D)), 2) 77 | 78 | acc = utils.batch_mc_acc(out.data.view(N, C, -1)).cpu() 79 | 80 | accs.append(acc.view(-1)) 81 | 82 | acc_tracker.append(acc.mean()) 83 | 84 | fmt = '{:.4f}'.format 85 | tq.set_postfix(acc=fmt(acc_tracker.mean.value)) 86 | 87 | return accs 88 | 89 | def train(context_net, answer_net, loader, optimizer, tracker, args, prefix='', epoch=0): 90 | """ Run an epoch over the given loader """ 91 | context_net.train() 92 | answer_net.train() 93 | tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99} 94 | 95 | tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0) 96 | loss_tracker = tracker.track('{}_loss'.format(prefix), tracker_class(**tracker_params)) 97 | acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 98 | lr_tracker = tracker.track('{}_lr'.format(prefix), tracker_class(**tracker_params)) 99 | 100 | var_params = { 'volatile': False, 'requires_grad': False, } 101 | log_softmax = nn.LogSoftmax().cuda() 102 | cnt = 0 103 | start_tm=time.time() 104 | for v, q, avocab, a, _, idx, q_len in tq: 105 | data_tm = time.time() - start_tm 106 | start_tm=time.time() 107 | 108 | if args.answer_embedding == 'RNN': 109 | answer_var, answer_len = loader.dataset._get_answer_sequences(avocab) 110 | else: 111 | answer_var, answer_len = loader.dataset._get_answer_vectors(avocab) 112 | answer_var = Variable(answer_var.cuda(), **var_params) 113 | 114 | v = Variable(v.cuda(), **var_params) 115 | q = Variable(q.cuda(), **var_params) 116 | a = Variable(a.cuda(), **var_params) 117 | q_len = Variable(q_len.cuda(), **var_params) 118 | 119 | encode_tm = time.time() - start_tm 120 | start_tm=time.time() 121 | 122 | context_embedding = context_net(v, q, q_len) 123 | answer_embedding = answer_net(answer_var, answer_len) 124 | 125 | predicts = cosine_sim(context_embedding, answer_embedding) / args.loss_temperature #temperature 126 | nll = -log_softmax(predicts) 127 | loss = (nll * a ).sum(dim=1).mean() 128 | 129 | acc = utils.batch_top1(predicts.data, a.data).cpu() 130 | 131 | global total_iterations 132 | lr = utils.update_learning_rate(optimizer, epoch) 133 | 134 | optimizer.zero_grad() 135 | loss.backward() 136 | 137 | optimizer.step() 138 | 139 | model_tm = time.time() - start_tm 140 | start_tm=time.time() 141 | 142 | loss_tracker.append(loss.data[0]) 143 | acc_tracker.append(acc.mean()) 144 | lr_tracker.append(lr) 145 | fmt = '{:.6f}'.format 146 | 
tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value), lr=fmt(lr_tracker.mean.value), t_data=data_tm, t_model=model_tm) 147 | 148 | def main(args): 149 | if args.name is None: 150 | from datetime import datetime 151 | name = args.context_embedding+"_"+args.answer_embedding+"_v7w_batch_softmax_embedding_"+datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 152 | else: 153 | name = args.context_embedding+"_"+args.answer_embedding+"_"+args.name 154 | 155 | output_filepath = os.path.join(cfg.output_path, '{}.pth'.format(name)) 156 | print('Output data would be saved to {}'.format(output_filepath)) 157 | 158 | 159 | word2vec = Vector() 160 | train_loader = data.get_loader(word2vec, train=True) 161 | val_loader = data.get_loader(word2vec, val=True) 162 | test_loader = data.get_loader(word2vec, test=True) 163 | 164 | question_word2vec = word2vec._prepare(train_loader.dataset.token_to_index) 165 | 166 | if args.context_embedding == 'SAN': 167 | context_net = model.StackedAttentionEmbedding( 168 | train_loader.dataset.num_tokens, 169 | question_word2vec).cuda() 170 | elif args.context_embedding == 'BoW': 171 | context_net = model.VisualSemanticEmbedding( 172 | train_loader.dataset.num_tokens, 173 | question_word2vec).cuda() 174 | else: 175 | raise TypeError('Unsupported Context Model') 176 | 177 | if args.answer_embedding == 'BoW': 178 | answer_net = model.MLPEmbedding(train_loader.dataset.vector.dim).cuda() 179 | elif args.answer_embedding == 'RNN': 180 | answer_net = model.RNNEmbedding(train_loader.dataset.vector.dim).cuda() 181 | else: 182 | raise TypeError('Unsupported Answer Model') 183 | 184 | print('Context Model:') 185 | print(context_net) 186 | 187 | print('Answer Model:') 188 | print(answer_net) 189 | 190 | if args.pretrained_model is not None: 191 | states = torch.load(args.pretrained_model) 192 | answer_state, context_state = states['answer_net'], states['context_net'] 193 | 194 | answer_net.load_state_dict(answer_state) 195 | context_net.load_state_dict(context_state) 196 | 197 | params_for_optimization = list(context_net.parameters()) + list(answer_net.parameters()) 198 | optimizer = optim.Adam([p for p in params_for_optimization if p.requires_grad]) 199 | 200 | tracker = utils.Tracker() 201 | if args.pretrained_model is not None: 202 | accs = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=-1) 203 | print('* Visual7W Val Accuracy: {}'.format(torch.cat(accs).mean())) 204 | accs = test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=-1) 205 | print('* Visual7W Test Accuracy: {}'.format(torch.cat(accs).mean())) 206 | 207 | test_loader.dataset.set_v7w_server(True) 208 | accs = test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=-1) 209 | print('* V7W Test Accuracy: {}'.format(torch.cat(accs).mean())) 210 | test_loader.dataset.set_v7w_server(False) 211 | 212 | raise NotImplementedError 213 | 214 | best_val_acc = 0 215 | best_context_net, best_answser_net = None, None 216 | _eval = [] 217 | for i in range(cfg.TRAIN.epochs): 218 | _ = train(context_net, answer_net, train_loader, optimizer, tracker, args, prefix='train', epoch=i) 219 | 220 | r = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=i) 221 | print('* Visual7W Val Accuracy: {}'.format(torch.cat(r).mean())) 222 | val_acc = 0.5 * torch.mean( torch.cat(r, dim=0) ) 223 | 224 | val_loader.dataset.set_v7w_server(True) 225 | r = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=i) 
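    # Note: this second evaluation runs with the loader switched into its
    # "v7w server" mode (set_v7w_server(True) above), which presumably swaps in the
    # dataset's own multiple-choice candidates; the selection score below weights
    # this run and the default-decoy run from a few lines up equally (0.5 each).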
226 | print('* Visual7W Val Accuracy: {}'.format(torch.cat(r).mean())) 227 | val_loader.dataset.set_v7w_server(False) 228 | 229 | val_acc += 0.5 * torch.mean( torch.cat(r, dim=0) ) 230 | _eval.append({ 'accuracies': r }) 231 | 232 | if best_val_acc < val_acc: 233 | accs = test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=i) 234 | print('* Visual7W Test Accuracy: {}'.format(torch.cat(accs).mean())) 235 | 236 | test_loader.dataset.set_v7w_server(True) 237 | accs = test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=i) 238 | print('* V7W Test Accuracy: {}'.format(torch.cat(accs).mean())) 239 | test_loader.dataset.set_v7w_server(False) 240 | 241 | best_val_acc = val_acc 242 | best_context_net = copy.deepcopy( context_net.state_dict() ) 243 | best_answer_net = copy.deepcopy( answer_net.state_dict() ) 244 | 245 | results = { 246 | 'name': name, 247 | 'tracker': tracker.to_dict(), 248 | 'config': cfg, 249 | 'context_net': best_context_net, 250 | 'answer_net': best_answer_net, 251 | 'eval': _eval, 252 | 'vocab': { 'answer_to_index': train_loader.dataset.answer_to_index, 253 | 'index_to_answer': train_loader.dataset.index_to_answer } 254 | } 255 | torch.save(results, output_filepath) 256 | 257 | test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=-1) 258 | 259 | if __name__ == '__main__': 260 | torch.cuda.set_device(args.gpu_id) 261 | print(args.__dict__) 262 | print(cfg) 263 | update_train_configs(args) 264 | main(args) 265 | -------------------------------------------------------------------------------- /train_vg_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import math 4 | import json 5 | import numpy as np 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | from tqdm import tqdm 15 | import copy 16 | import argparse 17 | import time 18 | 19 | cudnn.enabled = True 20 | cudnn.benchmark = True 21 | 22 | from ansemb.config import cfg, set_random_seed, update_train_configs 23 | import ansemb.dataset.vg as data 24 | import ansemb.models.embedding as model 25 | import ansemb.utils as utils 26 | 27 | from ansemb.utils import cosine_sim 28 | from ansemb.vector import Vector 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--gpu_id', default=0, type=int) 32 | parser.add_argument('--batch_size', default=128, type=float) 33 | parser.add_argument('--max_negative_answer', default=80000, type=float) 34 | parser.add_argument('--answer_batch_size', default=5000, type=float) 35 | parser.add_argument('--max_answer_index', default=5000, type=float) 36 | parser.add_argument('--loss_temperature', default=0.01, type=float) 37 | parser.add_argument('--pretrained_model', default=None, type=str) 38 | parser.add_argument('--context_embedding', default='SAN', choices=['SAN', 'BoW']) 39 | parser.add_argument('--answer_embedding', default='BoW', choices=['BoW', 'RNN']) 40 | parser.add_argument('--name', default=None, type=str) 41 | args = parser.parse_args() 42 | 43 | # fix random seed 44 | set_random_seed(cfg.seed) 45 | 46 | def test(context_net, answer_net, loader, tracker, args, prefix='', epoch=0): 47 | context_net.eval() 48 | answer_net.eval() 49 | tracker_class, tracker_params = tracker.MeanMonitor, {} 50 | accs = [] 51 | 52 | tq = tqdm(loader, desc='{} 
E{:03d}'.format(prefix, epoch), ncols=0) 53 | acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 54 | 55 | var_params = { 'volatile': True, 'requires_grad': False } 56 | 57 | cnt = 0 58 | for v, q, avocab, a, a_ids, idx, q_len in tq: 59 | if args.answer_embedding == 'RNN': 60 | answer_var, answer_len = loader.dataset._get_answer_sequences(a_ids) 61 | answer_var = Variable(answer_var.cuda(), **var_params) 62 | N, C, L, _ = answer_var.size() 63 | answer_embedding = answer_net(answer_var.view(N*C, L, -1), answer_len) 64 | else: 65 | answer_var, answer_len = loader.dataset._get_answer_vectors(a_ids) 66 | answer_var = Variable(answer_var.cuda(), **var_params) 67 | N, C, _ = answer_var.size() 68 | answer_embedding = answer_net(answer_var.view(N*C, -1), answer_len) 69 | 70 | v = Variable(v.cuda(), **var_params) 71 | q = Variable(q.cuda(), **var_params) 72 | q_len = Variable(q_len.cuda(), **var_params) 73 | 74 | context_embedding = context_net(v, q, q_len) 75 | 76 | _, D = context_embedding.size() 77 | out = torch.sum( torch.mul( context_embedding.view(N, 1, D).expand(N, C, D), answer_embedding.view(N, C, D)), 2) 78 | 79 | acc = utils.batch_mc_acc(out.data.view(N, C, -1)).cpu() 80 | 81 | accs.append(acc.view(-1)) 82 | 83 | acc_tracker.append(acc.mean()) 84 | 85 | fmt = '{:.4f}'.format 86 | tq.set_postfix(acc=fmt(acc_tracker.mean.value)) 87 | 88 | return accs 89 | 90 | def train(context_net, answer_net, loader, optimizer, tracker, args, prefix='', epoch=0): 91 | """ Run an epoch over the given loader """ 92 | context_net.train() 93 | answer_net.train() 94 | tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99} 95 | 96 | tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0) 97 | loss_tracker = tracker.track('{}_loss'.format(prefix), tracker_class(**tracker_params)) 98 | acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 99 | lr_tracker = tracker.track('{}_lr'.format(prefix), tracker_class(**tracker_params)) 100 | 101 | var_params = { 'volatile': False, 'requires_grad': False, } 102 | log_softmax = nn.LogSoftmax().cuda() 103 | cnt = 0 104 | start_tm=time.time() 105 | for v, q, avocab, a, _, idx, q_len in tq: 106 | data_tm = time.time() - start_tm 107 | start_tm=time.time() 108 | 109 | if args.answer_embedding == 'RNN': 110 | answer_var, answer_len = loader.dataset._get_answer_sequences(avocab) 111 | else: 112 | answer_var, answer_len = loader.dataset._get_answer_vectors(avocab) 113 | answer_var = Variable(answer_var.cuda(), **var_params) 114 | 115 | v = Variable(v.cuda(), **var_params) 116 | q = Variable(q.cuda(), **var_params) 117 | a = Variable(a.cuda(), **var_params) 118 | q_len = Variable(q_len.cuda(), **var_params) 119 | 120 | encode_tm = time.time() - start_tm 121 | start_tm=time.time() 122 | 123 | context_embedding = context_net(v, q, q_len) 124 | answer_embedding = answer_net(answer_var, answer_len) 125 | 126 | predicts = cosine_sim(context_embedding, answer_embedding) / args.loss_temperature #temperature 127 | nll = -log_softmax(predicts) 128 | loss = (nll * a ).sum(dim=1).mean() 129 | 130 | acc = utils.batch_top1(predicts.data, a.data).cpu() 131 | 132 | global total_iterations 133 | lr = utils.update_learning_rate(optimizer, epoch) 134 | 135 | optimizer.zero_grad() 136 | loss.backward() 137 | 138 | optimizer.step() 139 | 140 | model_tm = time.time() - start_tm 141 | start_tm=time.time() 142 | 143 | loss_tracker.append(loss.data[0]) 144 | acc_tracker.append(acc.mean()) 145 | 
lr_tracker.append(lr) 146 | fmt = '{:.6f}'.format 147 | tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value), lr=fmt(lr_tracker.mean.value), t_data=data_tm, t_model=model_tm) 148 | 149 | def main(args): 150 | if args.name is None: 151 | from datetime import datetime 152 | name = args.context_embedding+"_"+args.answer_embedding+"_vg_batch_softmax_embedding_"+datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 153 | else: 154 | name = args.context_embedding+"_"+args.answer_embedding+"_"+args.name 155 | 156 | output_filepath = os.path.join(cfg.output_path, '{}.pth'.format(name)) 157 | print('Output data would be saved to {}'.format(output_filepath)) 158 | 159 | word2vec = Vector() 160 | train_loader = data.get_loader(word2vec, train=True) 161 | val_loader = data.get_loader(word2vec, val=True) 162 | test_loader = data.get_loader(word2vec, test=True) 163 | 164 | question_word2vec = word2vec._prepare(train_loader.dataset.token_to_index) 165 | 166 | if args.context_embedding == 'SAN': 167 | context_net = model.StackedAttentionEmbedding( 168 | train_loader.dataset.num_tokens, 169 | question_word2vec).cuda() 170 | elif args.context_embedding == 'BoW': 171 | context_net = model.VisualSemanticEmbedding( 172 | train_loader.dataset.num_tokens, 173 | question_word2vec).cuda() 174 | else: 175 | raise TypeError('Unsupported Context Model') 176 | 177 | if args.answer_embedding == 'BoW': 178 | answer_net = model.MLPEmbedding(train_loader.dataset.vector.dim).cuda() 179 | elif args.answer_embedding == 'RNN': 180 | answer_net = model.RNNEmbedding(train_loader.dataset.vector.dim).cuda() 181 | else: 182 | raise TypeError('Unsupported Answer Model') 183 | 184 | print('Context Model:') 185 | print(context_net) 186 | 187 | print('Answer Model:') 188 | print(answer_net) 189 | 190 | if args.pretrained_model is not None: 191 | states = torch.load(args.pretrained_model) 192 | answer_state, context_state = states['answer_net'], states['context_net'] 193 | 194 | answer_net.load_state_dict(answer_state) 195 | context_net.load_state_dict(context_state) 196 | 197 | params_for_optimization = list(context_net.parameters()) + list(answer_net.parameters()) 198 | optimizer = optim.Adam([p for p in params_for_optimization if p.requires_grad]) 199 | 200 | tracker = utils.Tracker() 201 | if args.pretrained_model is not None: 202 | accs = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=-1) 203 | print('* Val Accuracy: {}'.format(torch.cat(accs).mean())) 204 | accs = test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=-1) 205 | print('* Test Accuracy: {}'.format(torch.cat(accs).mean())) 206 | raise NotImplementedError 207 | 208 | best_val_acc = 0 209 | best_context_net, best_answser_net = None, None 210 | _eval = [] 211 | for i in range(cfg.TRAIN.epochs): 212 | _ = train(context_net, answer_net, train_loader, optimizer, tracker, args, prefix='train', epoch=i) 213 | r = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=i) 214 | 215 | _eval.append({ 'accuracies': r }) 216 | val_acc = torch.mean( torch.cat(r, dim=0) ) 217 | if best_val_acc < val_acc: 218 | # test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=i) 219 | best_val_acc = val_acc 220 | best_context_net = copy.deepcopy( context_net.state_dict() ) 221 | best_answer_net = copy.deepcopy( answer_net.state_dict() ) 222 | 223 | results = { 224 | 'name': name, 225 | 'tracker': tracker.to_dict(), 226 | 'config': cfg, 227 | 'context_net': best_context_net, 
228 | 'answer_net': best_answer_net, 229 | 'eval': _eval, 230 | 'vocab': { 'answer_to_index': train_loader.dataset.answer_to_index, 231 | 'index_to_answer': train_loader.dataset.index_to_answer } 232 | } 233 | torch.save(results, output_filepath) 234 | 235 | test(context_net, answer_net, test_loader, tracker, args, prefix='test', epoch=-1) 236 | 237 | if __name__ == '__main__': 238 | torch.cuda.set_device(args.gpu_id) 239 | print(args.__dict__) 240 | print(cfg) 241 | update_train_configs(args) 242 | main(args) 243 | -------------------------------------------------------------------------------- /train_vqa_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import math 4 | import json 5 | import numpy as np 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | from tqdm import tqdm 15 | import copy 16 | import argparse 17 | import time 18 | 19 | cudnn.enabled = True 20 | cudnn.benchmark = True 21 | 22 | from ansemb.config import cfg, set_random_seed, update_train_configs 23 | import ansemb.dataset.vqa as data 24 | import ansemb.models.embedding as model 25 | import ansemb.utils as utils 26 | 27 | from ansemb.utils import cosine_sim 28 | from ansemb.vector import Vector 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--gpu_id', default=0, type=int) 32 | parser.add_argument('--finetune', action='store_true') 33 | parser.add_argument('--batch_size', default=128, type=float) 34 | parser.add_argument('--max_negative_answer', default=12000, type=int) 35 | parser.add_argument('--answer_batch_size', default=3000, type=int) 36 | parser.add_argument('--loss_temperature', default=0.01, type=float) 37 | parser.add_argument('--pretrained_model', default=None, type=str) 38 | parser.add_argument('--context_embedding', default='SAN', choices=['SAN', 'BoW']) 39 | parser.add_argument('--answer_embedding', default='BoW', choices=['BoW', 'RNN']) 40 | parser.add_argument('--name', default=None, type=str) 41 | args = parser.parse_args() 42 | 43 | # fix random seed 44 | set_random_seed(cfg.seed) 45 | 46 | def test(context_net, answer_net, loader, tracker, args, prefix='', epoch=0): 47 | context_net.eval() 48 | answer_net.eval() 49 | tracker_class, tracker_params = tracker.MeanMonitor, {} 50 | ans_ids, que_ids = [], [] 51 | accs, masks = [], [] 52 | 53 | tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0) 54 | acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 55 | 56 | var_params = { 'volatile': True, 'requires_grad': False } 57 | if args.answer_embedding == 'RNN': 58 | answer_var, answer_len = loader.dataset._get_answer_sequences(range(cfg.TEST.max_answer_index)) 59 | else: 60 | answer_var, answer_len = loader.dataset._get_answer_vectors(range(cfg.TEST.max_answer_index)) 61 | 62 | answer_var = Variable(answer_var.cuda(), **var_params) 63 | answer_embedding = answer_net.forward(answer_var, answer_len) 64 | 65 | cnt = 0 66 | for v, q, _, _, labels, idx, q_len in tq: 67 | v = Variable(v.cuda(), **var_params) 68 | q = Variable(q.cuda(), **var_params) 69 | q_len = Variable(q_len.cuda(), **var_params) 70 | 71 | context_embedding = context_net(v, q, q_len) 72 | 73 | predicts = cosine_sim(context_embedding, answer_embedding) / args.loss_temperature #temperature 74 | acc = 
75 | 
76 |     _, flag = labels.max(1)
77 |     flag = ( flag >= 2 ).byte()
78 |     masks.append(flag)
79 | 
80 |     accs.append(acc.view(-1))
81 |     acc_tracker.append(acc.mean())
82 | 
83 |     # collect stats
84 |     _, _ans_ids = predicts.data.cpu().max(dim=1)
85 |     ans_ids.append(_ans_ids.view(-1))
86 |     que_ids.append(idx.view(-1).clone())
87 | 
88 |     fmt = '{:.4f}'.format
89 |     tq.set_postfix(acc=fmt(acc_tracker.mean.value))
90 | 
91 |   return accs, masks, ans_ids, que_ids
92 | 
93 | def train(context_net, answer_net, loader, optimizer, tracker, args, prefix='', epoch=0):
94 |   """ Run an epoch over the given loader """
95 |   context_net.train()
96 |   answer_net.train()
97 |   tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99}
98 | 
99 |   tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
100 |   loss_tracker = tracker.track('{}_loss'.format(prefix), tracker_class(**tracker_params))
101 |   acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params))
102 |   lr_tracker = tracker.track('{}_lr'.format(prefix), tracker_class(**tracker_params))
103 | 
104 |   var_params = { 'volatile': False, 'requires_grad': False, }
105 |   log_softmax = nn.LogSoftmax(dim=1).cuda()
106 |   cnt = 0
107 |   start_tm = time.time()
108 |   for v, q, avocab, a, labels, idx, q_len in tq:
109 |     data_tm = time.time() - start_tm
110 |     start_tm = time.time()
111 | 
112 |     if args.answer_embedding == 'RNN':
113 |       answer_var, answer_len = loader.dataset._get_answer_sequences(avocab)
114 |     else:
115 |       answer_var, answer_len = loader.dataset._get_answer_vectors(avocab)
116 |     answer_var = Variable(answer_var.cuda(), **var_params)
117 | 
118 |     v = Variable(v.cuda(), **var_params)
119 |     q = Variable(q.cuda(), **var_params)
120 |     a = Variable(a.cuda(), **var_params)
121 |     q_len = Variable(q_len.cuda(), **var_params)
122 | 
123 |     encode_tm = time.time() - start_tm
124 |     start_tm = time.time()
125 | 
126 |     context_embedding = context_net(v, q, q_len)
127 |     answer_embedding = answer_net(answer_var, answer_len)
128 | 
129 |     predicts = cosine_sim(context_embedding, answer_embedding) / args.loss_temperature  # temperature-scaled cosine similarity
130 |     nll = -log_softmax(predicts)  # per-candidate negative log-likelihood
131 |     loss = (nll * a / a.sum(1, keepdim=True)).sum(dim=1).mean()  # soft-label cross entropy
132 | 
133 |     acc = utils.batch_accuracy(predicts.data, a.data).cpu()
134 | 
135 |     global total_iterations
136 |     lr = utils.update_learning_rate(optimizer, epoch)
137 | 
138 |     optimizer.zero_grad()
139 |     loss.backward()
140 | 
141 |     optimizer.step()
142 | 
143 |     model_tm = time.time() - start_tm
144 |     start_tm = time.time()
145 | 
146 |     loss_tracker.append(loss.data[0])
147 |     acc_tracker.append(acc.mean())
148 |     lr_tracker.append(lr)
149 |     fmt = '{:.6f}'.format
150 |     tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value), lr=fmt(lr_tracker.mean.value), t_data=data_tm, t_model=model_tm, t_encode=encode_tm)
151 | 
152 | def main(args):
153 |   if args.name is None:
154 |     from datetime import datetime
155 |     name = args.context_embedding+"_"+args.answer_embedding+"_vqa_batch_softmax_embedding_"+datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
156 |     name = ( name + '_finetune' ) if args.finetune else name
157 |   else:
158 |     name = args.context_embedding+"_"+args.answer_embedding+"_vqa_batch_softmax_embedding_"+args.name
159 | 
160 |   output_filepath = os.path.join(cfg.output_path, '{}.pth'.format(name))
161 |   print('Output data will be saved to {}'.format(output_filepath))
162 | 
163 |   word2vec = Vector()
164 |   train_loader = data.get_loader(word2vec, train=True)
165 |   val_loader = data.get_loader(word2vec, val=True)
166 | 
167 |   question_word2vec = word2vec._prepare(train_loader.dataset.token_to_index)
168 | 
169 |   if args.context_embedding == 'SAN':
170 |     context_net = model.StackedAttentionEmbedding(
171 |                     train_loader.dataset.num_tokens,
172 |                     question_word2vec).cuda()
173 |   elif args.context_embedding == 'BoW':
174 |     context_net = model.VisualSemanticEmbedding(
175 |                     train_loader.dataset.num_tokens,
176 |                     question_word2vec).cuda()
177 |   else:
178 |     raise TypeError('Unsupported Context Model')
179 | 
180 |   if args.answer_embedding == 'BoW':
181 |     answer_net = model.MLPEmbedding(train_loader.dataset.vector.dim).cuda()
182 |   elif args.answer_embedding == 'RNN':
183 |     answer_net = model.RNNEmbedding(train_loader.dataset.vector.dim).cuda()
184 |   else:
185 |     raise TypeError('Unsupported Answer Model')
186 | 
187 |   print('Context Model:')
188 |   print(context_net)
189 | 
190 |   print('Answer Model:')
191 |   print(answer_net)
192 | 
193 |   if args.pretrained_model is not None:
194 |     states = torch.load(args.pretrained_model)
195 |     answer_state, context_state = states['answer_net'], states['context_net']
196 | 
197 |     answer_net.load_state_dict(answer_state)
198 |     context_net.load_state_dict(context_state)
199 | 
200 |   params_for_optimization = list(context_net.parameters()) + list(answer_net.parameters())
201 |   optimizer = optim.Adam([p for p in params_for_optimization if p.requires_grad])
202 | 
203 |   tracker = utils.Tracker()
204 |   if args.pretrained_model is not None:
205 |     accs, masks, ans_ids, que_ids = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=-1)
206 | 
207 |     total_accs = torch.cat(accs)
208 |     total_masks = torch.cat(masks)
209 |     print('* VQA2 Val Accuracy: {}'.format(total_accs.mean()))
210 |     accs_ = total_accs[total_masks].sum()
211 |     print('* VQA2- Val Accuracy: {}'.format(accs_ / total_masks.sum()))
212 | 
213 |     results = { 'name': name,
214 |                 'eval': [{'answer_ids': ans_ids, 'question_ids': que_ids}],
215 |                 'vocab': { 'answer_to_index': train_loader.dataset.answer_to_index,
216 |                            'index_to_answer': train_loader.dataset.index_to_answer } }
217 |     print('* Dumping output to: {}'.format(output_filepath))
218 |     torch.save(results, output_filepath)
219 | 
220 |     if not args.finetune:
221 |       return  # evaluation only: stop here unless --finetune is set
222 | 
223 |   best_val_acc = 0
224 |   best_context_net, best_answer_net = None, None
225 |   _eval = []
226 |   for i in range(cfg.TRAIN.epochs):
227 |     train(context_net, answer_net, train_loader, optimizer, tracker, args, prefix='train', epoch=i)
228 |     accs, _, ans_ids, que_ids = test(context_net, answer_net, val_loader, tracker, args, prefix='val', epoch=i)
229 | 
230 |     _eval.append({ 'accuracies': accs, 'answer_ids': ans_ids, 'question_ids': que_ids })
231 |     val_acc = torch.mean( torch.cat(accs, dim=0) )
232 |     if best_val_acc < val_acc:
233 |       best_val_acc = val_acc
234 |       best_context_net = copy.deepcopy( context_net.state_dict() )
235 |       best_answer_net = copy.deepcopy( answer_net.state_dict() )
236 | 
237 |     results = {
238 |       'name': name,
239 |       'tracker': tracker.to_dict(),
240 |       'config': cfg,
241 |       'context_net': best_context_net,
242 |       'answer_net': best_answer_net,
243 |       'eval': _eval,
244 |       'vocab': { 'answer_to_index': train_loader.dataset.answer_to_index,
245 |                  'index_to_answer': train_loader.dataset.index_to_answer }
246 |     }
247 |     torch.save(results, output_filepath)
248 | 
249 | if __name__ == '__main__':
250 |   torch.cuda.set_device(args.gpu_id)
251 |   print(args.__dict__)
252 |   print(cfg)
253 |   update_train_configs(args)
254 |   main(args)
255 | 
--------------------------------------------------------------------------------
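
Both training scripts above implement the same recipe: embed the question+image context and the candidate answers, score every context against every answer by temperature-scaled cosine similarity, and train with a soft-label softmax over the candidate answer set. The snippet below is a minimal, self-contained sketch of that loss using current PyTorch tensor APIs rather than the older `Variable` interface used by the scripts; the function name, dimensions, and random inputs are illustrative stand-ins, not part of this repository.

```
import torch
import torch.nn.functional as F

def batch_softmax_embedding_loss(context_emb, answer_emb, soft_labels, temperature=0.01):
  """context_emb: (B, D) context embeddings; answer_emb: (A, D) candidate answer
  embeddings; soft_labels: (B, A) non-negative answer weights (e.g. annotator counts)."""
  # Pairwise cosine similarity, sharpened by the loss temperature
  # (mirrors cosine_sim(...) / args.loss_temperature in train()).
  logits = F.normalize(context_emb, dim=1) @ F.normalize(answer_emb, dim=1).t() / temperature
  # Soft-label cross entropy: per-candidate NLL weighted by the normalized labels
  # (mirrors (nll * a / a.sum(1, keepdim=True)).sum(dim=1).mean()).
  nll = -F.log_softmax(logits, dim=1)
  weights = soft_labels / soft_labels.sum(dim=1, keepdim=True)
  return (nll * weights).sum(dim=1).mean()

# Example with random stand-in tensors: 8 contexts, 300 candidate answers, 256-d embeddings.
loss = batch_softmax_embedding_loss(torch.randn(8, 256), torch.randn(300, 256),
                                    torch.rand(8, 300).clamp(min=1e-6))
print(loss.item())
```

At evaluation time the same similarity is reused for ranking: `test()` embeds the full candidate set (`range(cfg.TEST.max_answer_index)`) once up front and then scores each batch of contexts against those fixed answer embeddings.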