├── datasets ├── vqa │ └── .keep └── coco_extract │ └── .keep ├── results ├── log │ └── .keep ├── pred │ └── .keep ├── cache │ └── .keep └── result_test │ └── .keep ├── requirements.txt ├── misc └── approach_combo_diagram.png ├── cfgs ├── large_model.yml ├── small_model.yml ├── path_cfgs.py └── base_cfgs.py ├── LICENSE ├── setup.sh ├── core ├── model │ ├── net_utils.py │ ├── optim.py │ ├── losses.py │ ├── mca.py │ └── PointNet.py ├── data │ ├── save_glove_embeds.py │ ├── ans_punct.py │ ├── load_data.py │ └── data_utils.py ├── eval_novel.py └── exec2steps.py ├── utils ├── proc_ansdict.py ├── vqa.py └── vqaEval.py ├── README.md ├── run.py └── MCAN_LICENSE /datasets/vqa/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/log/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/pred/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/cache/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/coco_extract/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/result_test/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | spacy >= 2.0.18 2 | numpy >= 1.16.2 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /misc/approach_combo_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpencerWhitehead/novelvqa/HEAD/misc/approach_combo_diagram.png -------------------------------------------------------------------------------- /cfgs/large_model.yml: -------------------------------------------------------------------------------- 1 | LAYER: 6 2 | HIDDEN_SIZE: 1024 3 | MULTI_HEAD: 8 4 | DROPOUT_R: 0.1 5 | FLAT_MLP_SIZE: 512 6 | FLAT_GLIMPSES: 1 7 | FLAT_OUT_SIZE: 2048 8 | LR_BASE: 0.00005 9 | LR_DECAY_R: 0.2 10 | GRAD_ACCU_STEPS: 2 11 | CKPT_VERSION: 'large' 12 | CKPT_EPOCH: 13 -------------------------------------------------------------------------------- /cfgs/small_model.yml: -------------------------------------------------------------------------------- 1 | LAYER: 6 2 | HIDDEN_SIZE: 512 3 | MULTI_HEAD: 8 4 | DROPOUT_R: 0.1 5 | FLAT_MLP_SIZE: 512 6 | FLAT_GLIMPSES: 1 7 | FLAT_OUT_SIZE: 1024 8 | LR_BASE: 0.0001 9 | LR_DECAY_R: 0.2 10 | GRAD_ACCU_STEPS: 1 11 | CKPT_VERSION: 'small' 12 | CKPT_EPOCH: 13 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Spencer Whitehead and others 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, 
including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Download vqa v2 dataset 4 | VQA_DIR=./datasets/vqa 5 | mkdir -p $VQA_DIR 6 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip -O $VQA_DIR/v2_Questions_Train_mscoco.zip 7 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip -O $VQA_DIR/v2_Questions_Val_mscoco.zip 8 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip -O $VQA_DIR/v2_Questions_Test_mscoco.zip 9 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip -O $VQA_DIR/v2_Annotations_Train_mscoco.zip 10 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip -O $VQA_DIR/v2_Annotations_Val_mscoco.zip 11 | 12 | unzip $VQA_DIR/v2_Questions_Train_mscoco.zip -d $VQA_DIR/ 13 | unzip $VQA_DIR/v2_Questions_Val_mscoco.zip -d $VQA_DIR/ 14 | unzip $VQA_DIR/v2_Questions_Test_mscoco.zip -d $VQA_DIR/ 15 | unzip $VQA_DIR/v2_Annotations_Train_mscoco.zip -d $VQA_DIR/ 16 | unzip $VQA_DIR/v2_Annotations_Val_mscoco.zip -d $VQA_DIR/ 17 | 18 | 19 | # Unzip the BUTD features 20 | FEAT_DIR=./datasets/coco_extract 21 | cd $FEAT_DIR 22 | echo Unzip train2014.tar.gz ... 23 | tar -xzvf train2014.tar.gz 24 | echo Unzip val2014.tar.gz ... 25 | tar -xzvf val2014.tar.gz 26 | echo Unzip test2015.tar.gz ... 27 | tar -xzvf test2015.tar.gz 28 | cd ../.. 
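# NOTE: the feature archives (train2014.tar.gz, val2014.tar.gz, test2015.tar.gz)
# are assumed to have already been downloaded into $FEAT_DIR; this script only
# unzips them (see the MCAN setup steps referenced in the README for the downloads).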
29 | 30 | 31 | -------------------------------------------------------------------------------- /core/model/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class FC(nn.Module): 6 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): 7 | super(FC, self).__init__() 8 | self.dropout_r = dropout_r 9 | self.use_relu = use_relu 10 | 11 | self.linear = nn.Linear(in_size, out_size) 12 | 13 | if use_relu: 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | if dropout_r > 0: 17 | self.dropout = nn.Dropout(dropout_r) 18 | 19 | def forward(self, x): 20 | x = self.linear(x) 21 | 22 | if self.use_relu: 23 | x = self.relu(x) 24 | 25 | if self.dropout_r > 0: 26 | x = self.dropout(x) 27 | 28 | return x 29 | 30 | 31 | class MLP(nn.Module): 32 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): 33 | super(MLP, self).__init__() 34 | 35 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) 36 | self.linear = nn.Linear(mid_size, out_size) 37 | 38 | def forward(self, x): 39 | return self.linear(self.fc(x)) 40 | 41 | 42 | class LayerNorm(nn.Module): 43 | def __init__(self, size, eps=1e-6): 44 | super(LayerNorm, self).__init__() 45 | self.eps = eps 46 | 47 | self.a_2 = nn.Parameter(torch.ones(size)) 48 | self.b_2 = nn.Parameter(torch.zeros(size)) 49 | 50 | def forward(self, x): 51 | mean = x.mean(-1, keepdim=True) 52 | std = x.std(-1, keepdim=True) 53 | 54 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 55 | -------------------------------------------------------------------------------- /core/model/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as Optim 3 | 4 | 5 | class WarmupOptimizer(object): 6 | def __init__(self, lr_base, optimizer, data_size, batch_size): 7 | self.optimizer = optimizer 8 | self._step = 0 9 | self.lr_base = lr_base 10 | self._rate = 0 11 | self.data_size = data_size 12 | self.batch_size = batch_size 13 | 14 | 15 | def step(self): 16 | self._step += 1 17 | 18 | rate = self.rate() 19 | for p in self.optimizer.param_groups: 20 | p['lr'] = rate 21 | self._rate = rate 22 | 23 | self.optimizer.step() 24 | 25 | 26 | def zero_grad(self): 27 | self.optimizer.zero_grad() 28 | 29 | 30 | def rate(self, step=None): 31 | if step is None: 32 | step = self._step 33 | 34 | if step <= int(self.data_size / self.batch_size * 1): 35 | r = self.lr_base * 1/4. 36 | elif step <= int(self.data_size / self.batch_size * 2): 37 | r = self.lr_base * 2/4. 38 | elif step <= int(self.data_size / self.batch_size * 3): 39 | r = self.lr_base * 3/4. 
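        # After the three warmup epochs above (data_size / batch_size steps each),
        # training continues at the full base learning rate.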
40 | else: 41 | r = self.lr_base 42 | 43 | return r 44 | 45 | 46 | def get_optim(__C, model, data_size, lr_base=None): 47 | if lr_base is None: 48 | lr_base = __C.LR_BASE 49 | 50 | return WarmupOptimizer( 51 | lr_base, 52 | Optim.Adam( 53 | filter(lambda p: p.requires_grad, model.parameters()), 54 | lr=0, 55 | betas=__C.OPT_BETAS, 56 | eps=__C.OPT_EPS 57 | ), 58 | data_size, 59 | __C.BATCH_SIZE 60 | ) 61 | 62 | 63 | def adjust_lr(optim, decay_r): 64 | optim.lr_base *= decay_r 65 | -------------------------------------------------------------------------------- /utils/proc_ansdict.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see MCAN_LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | import sys 8 | sys.path.append('../') 9 | from core.data.ans_punct import prep_ans 10 | import json 11 | 12 | DATASET_PATH = '../datasets/vqa/' 13 | 14 | ANSWER_PATH = { 15 | 'train': DATASET_PATH + 'v2_mscoco_train2014_annotations.json', 16 | 'val': DATASET_PATH + 'v2_mscoco_val2014_annotations.json', 17 | 'vg': DATASET_PATH + 'VG_annotations.json', 18 | } 19 | 20 | # Loading answer word list 21 | stat_ans_list = \ 22 | json.load(open(ANSWER_PATH['train'], 'r'))['annotations'] + \ 23 | json.load(open(ANSWER_PATH['val'], 'r'))['annotations'] 24 | 25 | 26 | def ans_stat(stat_ans_list): 27 | ans_to_ix = {} 28 | ix_to_ans = {} 29 | ans_freq_dict = {} 30 | 31 | for ans in stat_ans_list: 32 | ans_proc = prep_ans(ans['multiple_choice_answer']) 33 | if ans_proc not in ans_freq_dict: 34 | ans_freq_dict[ans_proc] = 1 35 | else: 36 | ans_freq_dict[ans_proc] += 1 37 | 38 | ans_freq_filter = ans_freq_dict.copy() 39 | for ans in ans_freq_dict: 40 | if ans_freq_dict[ans] <= 8: 41 | ans_freq_filter.pop(ans) 42 | 43 | for ans in ans_freq_filter: 44 | ix_to_ans[ans_to_ix.__len__()] = ans 45 | ans_to_ix[ans] = ans_to_ix.__len__() 46 | 47 | return ans_to_ix, ix_to_ans 48 | 49 | ans_to_ix, ix_to_ans = ans_stat(stat_ans_list) 50 | # print(ans_to_ix.__len__()) 51 | json.dump([ans_to_ix, ix_to_ans], open('../core/data/answer_dict.json', 'w')) 52 | -------------------------------------------------------------------------------- /core/data/save_glove_embeds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import pickle 4 | import numpy as np 5 | 6 | from collections import namedtuple 7 | 8 | 9 | QTokenEmbed = namedtuple('QToken', ['text', 'vector']) 10 | 11 | 12 | class StoredEmbeds(object): 13 | def __init__(self, embed_fname='./ckpts/glove_embeds.pkl'): 14 | self.embed_fname = embed_fname 15 | self._embeddings = [] 16 | self._token_to_ix = {} 17 | if os.path.exists(self.embed_fname): 18 | print('Found embedding file: {}\n\tLoading...'.format(self.embed_fname)) 19 | self._token_to_ix, self._embeddings = self.load() 20 | 21 | def get_embeds(self): 22 | return copy.deepcopy(self._token_to_ix), np.array(self._embeddings) 23 | 24 | def set_embeds(self, token2idx, embed_mtx): 25 | self._token_to_ix = token2idx 26 | self._embeddings = embed_mtx 27 | 28 | def has_embeds(self): 29 | return len(self._token_to_ix) and len(self._embeddings) 30 | 31 | def load(self): 32 | with open(self.embed_fname, 'rb+') as embedf: 33 | data_ = pickle.load(embedf) 34 | return data_ 35 | 36 | def 
save(self): 37 | # Embeddings will not be overwritten if file already exists. 38 | if not os.path.exists(self.embed_fname): 39 | print('Embedding file does not exist. Saving to: {}'.format(self.embed_fname)) 40 | with open(self.embed_fname, 'wb+') as outf: 41 | pickle.dump((self._token_to_ix, self._embeddings), outf, protocol=-1) 42 | else: 43 | print('Embedding file already exists... New embeddings are not saved.') 44 | 45 | def __call__(self, word): 46 | return QTokenEmbed(text=word, vector=self._embeddings[self._token_to_ix[word]]) 47 | -------------------------------------------------------------------------------- /cfgs/path_cfgs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class PATH: 5 | def __init__(self): 6 | 7 | # vqav2 dataset root path 8 | self.DATASET_PATH = './datasets/vqa/' 9 | 10 | # bottom up features root path 11 | self.FEATURE_PATH = './datasets/coco_extract/' 12 | 13 | self.init_path() 14 | 15 | def init_path(self): 16 | 17 | self.IMG_FEAT_PATH = { 18 | 'train': self.FEATURE_PATH + 'train2014/', 19 | 'val': self.FEATURE_PATH + 'val2014/', 20 | 'test': self.FEATURE_PATH + 'test2015/', 21 | } 22 | 23 | self.QUESTION_PATH = { 24 | 'train': self.DATASET_PATH + 'v2_OpenEnded_mscoco_train2014_questions.json', 25 | 'val': self.DATASET_PATH + 'v2_OpenEnded_mscoco_val2014_questions.json', 26 | 'test': self.DATASET_PATH + 'v2_OpenEnded_mscoco_test2015_questions.json', 27 | 'vg': self.DATASET_PATH + 'VG_questions.json', 28 | } 29 | 30 | self.ANSWER_PATH = { 31 | 'train': self.DATASET_PATH + 'v2_mscoco_train2014_annotations.json', 32 | 'val': self.DATASET_PATH + 'v2_mscoco_val2014_annotations.json', 33 | 'vg': self.DATASET_PATH + 'VG_annotations.json', 34 | } 35 | 36 | self.RESULT_PATH = './results/result_test/' 37 | self.PRED_PATH = './results/pred/' 38 | self.CACHE_PATH = './results/cache/' 39 | self.LOG_PATH = './results/log/' 40 | self.CKPTS_PATH = './ckpts/' 41 | self.ATTN_PATH = './results/attn/' 42 | self.ANA_PATH = './results/analysis' 43 | 44 | if 'result_test' not in os.listdir('./results'): 45 | os.mkdir('./results/result_test') 46 | 47 | if 'pred' not in os.listdir('./results'): 48 | os.mkdir('./results/pred') 49 | 50 | if 'cache' not in os.listdir('./results'): 51 | os.mkdir('./results/cache') 52 | 53 | if 'log' not in os.listdir('./results'): 54 | os.mkdir('./results/log') 55 | 56 | if 'ckpts' not in os.listdir('./'): 57 | os.mkdir('./ckpts') 58 | 59 | def check_path(self): 60 | print('Checking dataset ...') 61 | 62 | for mode in self.IMG_FEAT_PATH: 63 | if not os.path.exists(self.IMG_FEAT_PATH[mode]): 64 | print(self.IMG_FEAT_PATH[mode] + ' DOES NOT EXIST') 65 | exit(-1) 66 | 67 | for mode in self.QUESTION_PATH: 68 | if not os.path.exists(self.QUESTION_PATH[mode]): 69 | print(self.QUESTION_PATH[mode] + ' DOES NOT EXIST') 70 | exit(-1) 71 | 72 | for mode in self.ANSWER_PATH: 73 | if not os.path.exists(self.ANSWER_PATH[mode]): 74 | print(self.ANSWER_PATH[mode] + ' DOES NOT EXIST') 75 | exit(-1) 76 | 77 | print('Finished') 78 | print('') 79 | -------------------------------------------------------------------------------- /core/eval_novel.py: -------------------------------------------------------------------------------- 1 | from core.data.load_data import DataSet 2 | from utils.vqa import VQA 3 | from utils.vqaEval import VQAEval 4 | 5 | import os, copy 6 | 7 | 8 | class Execution: 9 | def __init__(self, __C): 10 | 11 | self.__C = __C 12 | __C_eval = copy.deepcopy(__C) 13 | setattr(__C_eval, 'RUN_MODE', 
'val') 14 | print('Loading validation set for per-epoch evaluation ........') 15 | self.dataset_eval = DataSet(__C_eval) 16 | 17 | def run(self, run_mode): 18 | if run_mode == 'valNovel': 19 | self.eval(self.dataset_eval) 20 | else: 21 | exit(-1) 22 | 23 | # Evaluation 24 | def eval(self, dataset): 25 | 26 | print(self.__C.RESULT_EVAL_FILE) 27 | 28 | # Load parameters 29 | if self.__C.RESULT_EVAL_FILE is None: 30 | exit(-1) 31 | 32 | if not os.path.isfile(self.__C.RESULT_EVAL_FILE): 33 | exit(-1) 34 | 35 | result_eval_file = self.__C.RESULT_EVAL_FILE 36 | 37 | # create vqa object and vqaRes object 38 | ques_file_path = self.__C.QUESTION_PATH['val'] 39 | ans_file_path = self.__C.ANSWER_PATH['val'] 40 | 41 | vqa = VQA(ans_file_path, ques_file_path) 42 | vqaRes = vqa.loadRes(result_eval_file, ques_file_path) 43 | 44 | # create vqaEval object by taking vqa and vqaRes 45 | vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 46 | 47 | # evaluate results 48 | """ 49 | If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function 50 | By default it uses all the question ids in annotation file 51 | """ 52 | vqaEval.evaluate() 53 | 54 | # print accuracies 55 | print("\n") 56 | print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 57 | # print("Per Question Type Accuracy is the following:") 58 | # for quesType in vqaEval.accuracy['perQuestionType']: 59 | # print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType])) 60 | # print("\n") 61 | print("Per Answer Type Accuracy is the following:") 62 | for ansType in vqaEval.accuracy['perAnswerType']: 63 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 64 | print("\n") 65 | 66 | novel_ques_ids = dataset.novel_ques_ids 67 | if self.__C.NOVEL and novel_ques_ids is not None: 68 | # evaluate results on novel subset 69 | 70 | vqaEval.evaluate(novel_ques_ids) 71 | 72 | # print accuracies 73 | print("\n") 74 | print("Novel Subset Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 75 | print("Per Answer Type Accuracy is the following:") 76 | for ansType in vqaEval.accuracy['perAnswerType']: 77 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 78 | print("\n") 79 | -------------------------------------------------------------------------------- /core/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from math import sqrt 6 | import numpy as np 7 | 8 | 9 | class ContrastProjection(nn.Module): 10 | def __init__(self, __C): 11 | super().__init__() 12 | self.linear1 = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 13 | self.linear2 = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 14 | 15 | def forward(self, tokens): 16 | return self.linear2(F.relu(self.linear1(tokens))) 17 | 18 | 19 | class Losses: 20 | def __init__(self, __C): 21 | self.__C = __C 22 | self.maskval = -1e9 23 | if __C.USE_GROUNDING: 24 | self._point_loss = nn.CrossEntropyLoss().cuda() 25 | else: 26 | self._point_loss = None 27 | 28 | if __C.SKILL_CONT_LOSS: 29 | self._skill_contrast_proj = ContrastProjection(__C).cuda() 30 | self._skill_contrast_loss = nn.CrossEntropyLoss().cuda() 31 | else: 32 | self._skill_contrast_proj = None 33 | self._skill_contrast_loss = None 34 | 35 | self._skill_pool_method = __C.SKILL_POOL 36 | 37 | self._skill_temp = __C.SK_TEMP 38 | 39 | 
self._point_temp = __C.PT_TEMP 40 | 41 | def get_pointing_scores(self, tgt, refs, ref_masks, point_mask_tok): 42 | # tgt: size: batch x sent_len x 512 43 | # refs[i]: size: batch x sent_len x 512 44 | # ref_masks[i]: batch x sent_len; indicates the locations where there is padding (1 if the index is padding, 0 otherwise) 45 | # point_mask_tok: batch x 1; vector indicating where in the target sequence is the masked token used for pointing 46 | 47 | batch_size, num_toks, tok_dim = tgt.size() 48 | n_refs = len(refs) 49 | 50 | row_id = torch.from_numpy(np.array(range(batch_size))) 51 | masked_tok = tgt[row_id, point_mask_tok.squeeze(1)] # batch_size x tok_dim 52 | 53 | all_ref_hiddens = torch.cat(refs, dim=1) 54 | all_ref_masks = torch.cat(ref_masks, dim=-1) 55 | 56 | scores = torch.zeros(batch_size, num_toks * n_refs, dtype=tgt.dtype, device=tgt.device) 57 | 58 | for i in range(batch_size): 59 | scores[i, :] = torch.matmul(masked_tok[i], all_ref_hiddens[i].t()) / sqrt(tok_dim) 60 | 61 | logits = scores.masked_fill(all_ref_masks, self.maskval) # mask out padding 62 | 63 | return logits, F.softmax(logits, dim=-1) 64 | 65 | def pointing_loss(self, tgt, refs, ref_masks, point_mask_tok, pos): 66 | logits, _ = self.get_pointing_scores(tgt, refs, ref_masks, point_mask_tok) 67 | point_loss_ = self._point_loss(logits, pos.squeeze(1)) 68 | return point_loss_ 69 | 70 | def skill_contrast_loss(self, tgt_tokens, tgt_mask, all_ref_tokens, ref_masks, ref_labels): 71 | # tgt_tokens: batch x 1 x dim OR batch x # tokens x dim (if pool_method is given) 72 | # all_ref_tokens: [batch x 1 x dim OR batch x # tokens x dim] x # refs 73 | 74 | if self._skill_pool_method in {'mean', 'max'}: 75 | tgt_tokens.masked_fill_(tgt_mask.unsqueeze(2), 0.) 76 | 77 | if self._skill_pool_method == 'mean': 78 | tgt_tokens = torch.mean(tgt_tokens, dim=1, keepdim=True) 79 | elif self._skill_pool_method == 'max': 80 | tgt_tokens = torch.max(tgt_tokens, dim=1, keepdim=True) 81 | 82 | masked_ref_tokens = [] 83 | 84 | for rt, rm in zip(all_ref_tokens, ref_masks): 85 | 86 | rt.masked_fill_(rm.unsqueeze(2), 0.) 
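                # Padded positions were zeroed above so each reference can be pooled
                # down to a single vector by the mean/max branches below.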
87 | 88 | if self._skill_pool_method == 'mean': 89 | rt = torch.mean(rt, dim=1, keepdim=True) 90 | elif self._skill_pool_method == 'max': 91 | rt = torch.max(rt, dim=1, keepdim=True) 92 | masked_ref_tokens.append(rt) 93 | 94 | all_ref_tokens = torch.cat(masked_ref_tokens, dim=1) # batch x # refs x D 95 | else: 96 | all_ref_tokens = torch.cat(all_ref_tokens, dim=1) # batch x # refs x D 97 | 98 | tgt_tokens = self._skill_contrast_proj(tgt_tokens) 99 | all_ref_tokens = self._skill_contrast_proj(all_ref_tokens) 100 | 101 | norm_tgt_cls = nn.functional.normalize(tgt_tokens, p=2, dim=-1) 102 | norm_all_ref_cls = nn.functional.normalize(all_ref_tokens, p=2, dim=-1) 103 | 104 | sims_ = torch.bmm(norm_all_ref_cls, norm_tgt_cls.permute(0, 2, 1)).squeeze(2) 105 | 106 | sims_ = torch.div(sims_, self._skill_temp) 107 | 108 | return self._skill_contrast_loss(sims_, ref_labels.squeeze(-1)) 109 | -------------------------------------------------------------------------------- /core/data/ans_punct.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 5 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 6 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 7 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 8 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 9 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 10 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 11 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 12 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 13 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 14 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 15 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 16 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 17 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 18 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 19 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 20 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 21 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 22 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 23 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 24 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 25 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 26 | "someonell": "someone'll", "someones": "someone's", "somethingd": 27 | "something'd", "somethingd've": "something'd've", "something'dve": 28 | "something'd've", "somethingll": "something'll", "thats": 29 | "that's", "thered": "there'd", "thered've": "there'd've", 30 | "there'dve": "there'd've", "therere": "there're", "theres": 31 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 32 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 33 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 34 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 35 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 36 | "what's", "whatve": "what've", "whens": "when's", "whered": 37 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 38 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 39 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 40 | "whyre": 
"why're", "whys": "why's", "wont": "won't", "wouldve": 41 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 42 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 43 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 44 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 45 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 46 | "you'll", "youre": "you're", "youve": "you've" 47 | } 48 | 49 | manual_map = { 'none': '0', 50 | 'zero': '0', 51 | 'one': '1', 52 | 'two': '2', 53 | 'three': '3', 54 | 'four': '4', 55 | 'five': '5', 56 | 'six': '6', 57 | 'seven': '7', 58 | 'eight': '8', 59 | 'nine': '9', 60 | 'ten': '10'} 61 | articles = ['a', 'an', 'the'] 62 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 63 | comma_strip = re.compile("(\d)(\,)(\d)") 64 | punct = [';', r"/", '[', ']', '"', '{', '}', 65 | '(', ')', '=', '+', '\\', '_', '-', 66 | '>', '<', '@', '`', ',', '?', '!'] 67 | 68 | def process_punctuation(inText): 69 | outText = inText 70 | for p in punct: 71 | if (p + ' ' in inText or ' ' + p in inText) \ 72 | or (re.search(comma_strip, inText) != None): 73 | outText = outText.replace(p, '') 74 | else: 75 | outText = outText.replace(p, ' ') 76 | outText = period_strip.sub("", outText, re.UNICODE) 77 | return outText 78 | 79 | 80 | def process_digit_article(inText): 81 | outText = [] 82 | tempText = inText.lower().split() 83 | for word in tempText: 84 | word = manual_map.setdefault(word, word) 85 | if word not in articles: 86 | outText.append(word) 87 | else: 88 | pass 89 | for wordId, word in enumerate(outText): 90 | if word in contractions: 91 | outText[wordId] = contractions[word] 92 | outText = ' '.join(outText) 93 | return outText 94 | 95 | 96 | def prep_ans(answer): 97 | answer = process_digit_article(process_punctuation(answer)) 98 | answer = answer.replace(',', '') 99 | return answer 100 | -------------------------------------------------------------------------------- /core/model/mca.py: -------------------------------------------------------------------------------- 1 | from core.model.net_utils import FC, MLP, LayerNorm 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch, math 6 | 7 | 8 | class AttFlat(nn.Module): 9 | def __init__(self, __C): 10 | super(AttFlat, self).__init__() 11 | self.__C = __C 12 | 13 | self.mlp = MLP( 14 | in_size=__C.HIDDEN_SIZE, 15 | mid_size=__C.FLAT_MLP_SIZE, 16 | out_size=__C.FLAT_GLIMPSES, 17 | dropout_r=__C.DROPOUT_R, 18 | use_relu=True 19 | ) 20 | 21 | self.linear_merge = nn.Linear( 22 | __C.HIDDEN_SIZE * __C.FLAT_GLIMPSES, 23 | __C.FLAT_OUT_SIZE 24 | ) 25 | 26 | def forward(self, x, x_mask): 27 | att = self.mlp(x) 28 | att = att.masked_fill( 29 | x_mask.squeeze(1).squeeze(1).unsqueeze(2), 30 | -1e9 31 | ) 32 | att = F.softmax(att, dim=1) 33 | 34 | att_list = [] 35 | for i in range(self.__C.FLAT_GLIMPSES): 36 | att_list.append( 37 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 38 | ) 39 | 40 | x_atted = torch.cat(att_list, dim=1) 41 | x_atted = self.linear_merge(x_atted) 42 | 43 | return x_atted 44 | 45 | 46 | # ------------------------------ 47 | # ---- Multi-Head Attention ---- 48 | # ------------------------------ 49 | 50 | class MHAtt(nn.Module): 51 | def __init__(self, __C): 52 | super(MHAtt, self).__init__() 53 | self.__C = __C 54 | 55 | self.linear_v = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 56 | self.linear_k = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 57 | self.linear_q = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 58 | self.linear_merge = 
nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 59 | 60 | self.dropout = nn.Dropout(__C.DROPOUT_R) 61 | 62 | def forward(self, v, k, q, mask): 63 | n_batches = q.size(0) 64 | 65 | v = self.linear_v(v).view( 66 | n_batches, 67 | -1, 68 | self.__C.MULTI_HEAD, 69 | self.__C.HIDDEN_SIZE_HEAD 70 | ).transpose(1, 2) 71 | 72 | k = self.linear_k(k).view( 73 | n_batches, 74 | -1, 75 | self.__C.MULTI_HEAD, 76 | self.__C.HIDDEN_SIZE_HEAD 77 | ).transpose(1, 2) 78 | 79 | q = self.linear_q(q).view( 80 | n_batches, 81 | -1, 82 | self.__C.MULTI_HEAD, 83 | self.__C.HIDDEN_SIZE_HEAD 84 | ).transpose(1, 2) 85 | 86 | atted, attmap = self.att(v, k, q, mask) 87 | atted = atted.transpose(1, 2).contiguous().view( 88 | n_batches, 89 | -1, 90 | self.__C.HIDDEN_SIZE 91 | ) 92 | 93 | atted = self.linear_merge(atted) 94 | 95 | return atted, attmap 96 | 97 | def att(self, value, key, query, mask): 98 | d_k = query.size(-1) 99 | 100 | scores = torch.matmul( 101 | query, key.transpose(-2, -1) 102 | ) / math.sqrt(d_k) 103 | 104 | if mask is not None: 105 | scores = scores.masked_fill(mask, -1e9) 106 | 107 | att_map_ = F.softmax(scores, dim=-1) # let the attention out 108 | 109 | if self.__C.ATTN_DROPOUT: 110 | att_map = self.dropout(att_map_) 111 | else: 112 | att_map = att_map_ 113 | 114 | return torch.matmul(att_map, value), att_map_ 115 | 116 | 117 | # --------------------------- 118 | # ---- Feed Forward Nets ---- 119 | # --------------------------- 120 | 121 | class FFN(nn.Module): 122 | def __init__(self, __C): 123 | super(FFN, self).__init__() 124 | 125 | self.mlp = MLP( 126 | in_size=__C.HIDDEN_SIZE, 127 | mid_size=__C.FF_SIZE, 128 | out_size=__C.HIDDEN_SIZE, 129 | dropout_r=__C.DROPOUT_R, 130 | use_relu=True 131 | ) 132 | 133 | def forward(self, x): 134 | return self.mlp(x) 135 | 136 | 137 | # ------------------------ 138 | # ---- Self Attention ---- 139 | # ------------------------ 140 | 141 | class SA(nn.Module): 142 | def __init__(self, __C): 143 | super(SA, self).__init__() 144 | 145 | self.mhatt = MHAtt(__C) 146 | self.ffn = FFN(__C) 147 | 148 | self.dropout1 = nn.Dropout(__C.DROPOUT_R) 149 | self.norm1 = LayerNorm(__C.HIDDEN_SIZE) 150 | 151 | self.dropout2 = nn.Dropout(__C.DROPOUT_R) 152 | self.norm2 = LayerNorm(__C.HIDDEN_SIZE) 153 | 154 | def forward(self, x, x_mask): 155 | x_attended, attmap = self.mhatt(x, x, x, x_mask) 156 | 157 | x = self.norm1(x + self.dropout1(x_attended)) 158 | x = self.norm2(x + self.dropout2(self.ffn(x))) 159 | 160 | return x, attmap 161 | -------------------------------------------------------------------------------- /core/model/PointNet.py: -------------------------------------------------------------------------------- 1 | from core.model.net_utils import LayerNorm 2 | from core.model.mca import SA, AttFlat 3 | 4 | import torch.nn as nn 5 | import torch 6 | 7 | import math 8 | 9 | 10 | class MCA_Unified(nn.Module): 11 | def __init__(self, __C, answer_size=0): 12 | super(MCA_Unified, self).__init__() 13 | 14 | # add tokens for answer set 15 | self.cls_token = nn.Parameter( 16 | torch.zeros(1, __C.HIDDEN_SIZE).normal_(mean=0, std=math.sqrt(1 / __C.HIDDEN_SIZE))) 17 | 18 | self.enc_list = nn.ModuleList([SA(__C) for _ in range(__C.LAYER)]) 19 | 20 | def forward(self, x, y, x_mask, y_mask, layer_id=0): 21 | batch_size = x.size(0) 22 | 23 | cls_token_ = self.cls_token.expand(batch_size, *self.cls_token.size()) 24 | 25 | cls_mask = torch.zeros((batch_size, 1, 1, 1), dtype=x_mask.dtype, device=x_mask.device) 26 | 27 | chunks = [1, x.size(1), y.size(1)] # cls, x.size(1), y.size(1) 28 | 
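        # The unified input sequence is laid out as [CLS token | question tokens | image regions];
        # `chunks` records these sizes so the encoder output can be split back into
        # the class token, text hiddens, and image hiddens after each layer.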
t_size = x.size(1) 29 | im_size = y.size(1) 30 | combo = torch.cat([cls_token_, x, y], dim=1) 31 | combo_mask = torch.cat([cls_mask, x_mask, y_mask], dim=-1) 32 | combo_mask[:, :, 1:1 + t_size, -im_size::] = True # no text->img direct attn 33 | 34 | attmap_list = [] 35 | hidden_text_list = [] 36 | for enc in self.enc_list: 37 | combo, attmap = enc(combo, combo_mask) 38 | attmap_list.append(attmap.unsqueeze(1)) # make a new dimension for the layer dimension concatenation 39 | 40 | c, x, y = torch.split(combo, chunks, dim=1) 41 | hidden_text_list.insert(0, x) # last layer first, then second last layer 42 | 43 | text_hiddens = hidden_text_list[layer_id] 44 | attmap_list = torch.cat(attmap_list, 1) # batch x layer x head x tokid x tokid 45 | others = attmap_list, c, text_hiddens # let's get class token output as well 46 | return x, y, others 47 | 48 | 49 | # ------------------------- 50 | # ---- Main MCAN Model ---- 51 | # ------------------------- 52 | class PointNet(nn.Module): 53 | def __init__(self, __C, pretrained_emb, token_size, answer_size): 54 | super(PointNet, self).__init__() 55 | 56 | self.GROUND_LAYER = getattr(__C, 'GROUND_LAYER', 0) 57 | 58 | self.embedding = nn.Embedding( 59 | num_embeddings=token_size, 60 | embedding_dim=__C.WORD_EMBED_SIZE 61 | ) 62 | # Loading the GloVe embedding weights 63 | if __C.USE_GLOVE and pretrained_emb is not None: 64 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 65 | self.lstm = nn.LSTM( 66 | input_size=__C.WORD_EMBED_SIZE, 67 | hidden_size=__C.HIDDEN_SIZE, 68 | num_layers=1, 69 | batch_first=True 70 | ) 71 | 72 | self.img_feat_linear = nn.Linear( 73 | __C.IMG_FEAT_SIZE, 74 | __C.HIDDEN_SIZE 75 | ) 76 | self.backbone = MCA_Unified(__C, answer_size) 77 | self.attflat_img = AttFlat(__C) 78 | self.attflat_lang = AttFlat(__C) 79 | 80 | 81 | if __C.USE_POINT_PROJ: 82 | self.point_proj = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 83 | else: 84 | self.point_proj = lambda x: x 85 | 86 | self.proj_norm = LayerNorm(__C.HIDDEN_SIZE) 87 | self.linear_proj = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 88 | self.classifier = nn.Linear(__C.HIDDEN_SIZE, answer_size) 89 | 90 | def forward(self, img_feat, ques_ix, **kwargs): 91 | # Make mask 92 | lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2)) 93 | img_feat_mask = self.make_mask(img_feat) 94 | 95 | # Pre-process Language Feature 96 | lang_feat = self.embedding(ques_ix) 97 | lang_feat, _ = self.lstm(lang_feat) 98 | 99 | # Pre-process Image Feature 100 | img_feat = self.img_feat_linear(img_feat) 101 | 102 | # Backbone Framework 103 | lang_feat, img_feat, others = self.backbone( 104 | lang_feat, 105 | img_feat, 106 | lang_feat_mask, 107 | img_feat_mask, 108 | layer_id=self.GROUND_LAYER 109 | ) 110 | 111 | c, lang_hiddens = others[1], others[2] 112 | if self.training: 113 | lang_hiddens_out = self.point_proj(lang_hiddens) 114 | else: 115 | lang_hiddens_out = lang_hiddens 116 | 117 | # close to mcan's output layer 118 | proj_feat = torch.sigmoid(self.classifier(self.proj_norm(self.linear_proj(c)))) 119 | 120 | proj_feat = torch.squeeze(proj_feat) 121 | 122 | ret_others = [*others] + [img_feat, img_feat_mask] 123 | 124 | return proj_feat, lang_hiddens_out, lang_feat_mask, ret_others 125 | 126 | # Masking 127 | def make_mask(self, feature): 128 | return (torch.sum( 129 | torch.abs(feature), 130 | dim=-1 131 | ) == 0).unsqueeze(1).unsqueeze(2) 132 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Separating Skills and Concepts for Novel VQA 2 | 3 | This repository contains the PyTorch code for the CVPR 2021 paper: [Separating Skills and Concepts for Novel Visual Question Answering](https://arxiv.org/abs/2107.09106). 4 | 5 | ![Overview](misc/approach_combo_diagram.png) 6 | 7 | ## Citation 8 | 9 | If you find this repository useful in your research, please consider citing: 10 | 11 | ``` 12 | @inproceedings{whitehead2021skillconcept, 13 | author = {Whitehead, Spencer and Wu, Hui and Ji, Heng and Feris, Rogerio and Saenko, Kate}, 14 | title = {Separating Skills and Concepts for Novel Visual Question Answering}, 15 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 16 | pages = {5632--5641}, 17 | year = {2021} 18 | } 19 | ``` 20 | 21 | ## Setup 22 | 23 | #### Requirements 24 | 25 | - [Python](https://www.python.org/downloads/) >= 3.6 26 | - [PyTorch](http://pytorch.org/) >= 1.6.0 with CUDA 27 | - [SpaCy](https://spacy.io/) >= 2.3.2 and download/install `en_vectors_web_lg` to obtain the [GloVe](https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz) vectors 28 | - PyYAML 29 | 30 | 31 | #### Data Download and Organization 32 | 33 | To setup the visual features, question files, and annotation files, please refer to the ['Setup' portion of the MCAN repository](https://github.com/MILVLG/mcan-vqa#setup) (under `Prerequisites`). Follow this procedure exactly, until the `datasets` directory has the structure shown in their repository. 34 | 35 | ## Concept and Reference Set Preprocessing 36 | 37 | The scripts for running the concept discovery and reference set preprocessing yourself will be added to this repository. For the time being, we provide preprocessed files that contain concepts, skill labels (if applicable), and reference sets for each question: 38 | 39 | - [train2014 and val2014 files with concepts and skill labels](https://drive.google.com/file/d/1WS6SOxmgzUxADmHXxlGzVc2KYjGf2Nzw/view?usp=sharing): This does not include reference sets and can be used for novel composition evaluation, even independently of this repository. 40 | - [train2014 and val2014 question files with concepts, skill labels, and reference sets](https://drive.google.com/file/d/1j6xejSs_zcCcq1HJHalk-lxMfbtHCzg5/view?usp=sharing): This is the same as the above, but it includes reference sets and should be used when running the code for our approach. 41 | 42 | You should decompress the zip file and place the JSON files in the `datasets/vqa` directory: 43 | ```angular2html 44 | |-- datasets 45 | |-- coco_extract 46 | | |-- ... 47 | |-- vqa 48 | | |-- train2014_scr_questions.json 49 | | |-- train2014_scr_annotations.json 50 | | |-- val2014_sc_questions.json 51 | | |-- val2014_sc_annotations.json 52 | | |-- ... 53 | ``` 54 | 55 | 56 | ## Training 57 | 58 | The base of the command to run the training is: 59 | 60 | ```bash 61 | python run.py --RUN train ... 62 | ``` 63 | 64 | Some pertinent arguments to add are: 65 | 66 | - ```--VERSION```: the model/experiment name that will be used to save the outputs. 67 | 68 | - ```--MODEL={'small', 'large'}```: whether you want to use a small or large transformer model (see `cfgs/small_model.yml` or `cfgs/large_model.yml` for details). In the paper, use `small`, which is the default. 69 | 70 | - ```--USE_GROUNDING=True```: whether to utilize our concept grounding loss. 
Default is `True`. 71 | 72 | - ```--CONCEPT```: specifies which concepts should not have any labeled data appear in training (e.g., `--CONCEPT vehicle,car,bus,train`). When used in combination with `--SKILL`, the composition of the specified skill and concept(s) has its labeled data removed from training. For example, if we have `--SKILL count` and `--CONCEPT car,bus,train`, then the labels for compositions of counting and `car,bus,train` will not be used. 73 | 74 | - ```--SKILL_CONT_LOSS=True```: whether to utilize our skill matching loss. Note: `USE_GROUNDING` must be `True` in order to use the skill matching loss in the current implementation. 75 | 76 | - ```--SKILL```: specifies the skill for the skill-concept composition(s) that should not have any labeled data appear in training (e.g., `--SKILL count`). 77 | 78 | 79 | During training, the latest model checkpoints are saved to `ckpts/ckpt_/last_epoch.pkl` and the training logs are saved to `results/log/log_run_.txt`. Validation predictions after every epoch will be saved in the `results/cache/` directory. Additionally, accuracies on novel compositions (or novel concepts) are also evaluated after each epoch. 80 | 81 | ## Evaluating Novel Compositions/Concepts 82 | 83 | While performance on novel compositions/concepts is evaluated after every epoch, it can also be evaluated separately. 84 | 85 | Given a file containing the model predictions on the val2014 data (in the [VQA v2 evaluation format](https://visualqa.org/evaluation.html)), run the following to get results on the novel compositions/concepts: 86 | ```bash 87 | python run.py --RUN valNovel --RESULT_EVAL_FILE --CONCEPT --SKILL 88 | ``` 89 | where `--CONCEPT` and `--SKILL` should be the same as the held-out compositions/concepts from training (i.e., exact same arguments). If both `--CONCEPT` and `--SKILL` are supplied, then that novel skill-concept composition is evaluated. If only `--CONCEPT` is supplied, then that novel concept is evaluated. 90 | 91 | To obtain a file with model predictions, run: 92 | 93 | ```bash 94 | python run.py --RUN val --CKPT_PATH 95 | ``` 96 | 97 | ## Acknowledgements 98 | 99 | This repository is adapted from the [MCAN](https://github.com/MILVLG/mcan-vqa) repository. We thank the authors for providing their code. 100 | -------------------------------------------------------------------------------- /utils/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object.
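# info - Print information about the VQA annotation file.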
16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | 24 | class VQA: 25 | def __init__(self, annotation_file=None, question_file=None): 26 | """ 27 | Constructor of VQA helper class for reading and visualizing questions and answers. 28 | :param annotation_file (str): location of VQA annotation file 29 | :return: 30 | """ 31 | # load dataset 32 | self.dataset = {} 33 | self.questions = {} 34 | self.qa = {} 35 | self.qqa = {} 36 | self.imgToQA = {} 37 | if not annotation_file == None and not question_file == None: 38 | print('loading VQA annotations and questions into memory...') 39 | time_t = datetime.datetime.utcnow() 40 | dataset = json.load(open(annotation_file, 'r')) 41 | questions = json.load(open(question_file, 'r')) 42 | print(datetime.datetime.utcnow() - time_t) 43 | self.dataset = dataset 44 | self.questions = questions 45 | self.createIndex() 46 | 47 | def createIndex(self): 48 | # create index 49 | print('creating index...') 50 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 51 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 52 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 53 | for ann in self.dataset['annotations']: 54 | imgToQA[ann['image_id']] += [ann] 55 | qa[ann['question_id']] = ann 56 | for ques in self.questions['questions']: 57 | qqa[ques['question_id']] = ques 58 | print('index created!') 59 | 60 | # create class members 61 | self.qa = qa 62 | self.qqa = qqa 63 | self.imgToQA = imgToQA 64 | 65 | def info(self): 66 | """ 67 | Print information about the VQA annotation file. 68 | :return: 69 | """ 70 | for key, value in self.dataset['info'].items(): 71 | print('%s: %s' % (key, value)) 72 | 73 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 74 | """ 75 | Get question ids that satisfy given filter conditions. default skips that filter 76 | :param imgIds (int array) : get question ids for given imgs 77 | quesTypes (str array) : get question ids for given question types 78 | ansTypes (str array) : get question ids for given answer types 79 | :return: ids (int array) : integer array of question ids 80 | """ 81 | imgIds = imgIds if type(imgIds) == list else [imgIds] 82 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 83 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 84 | 85 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 86 | anns = self.dataset['annotations'] 87 | else: 88 | if not len(imgIds) == 0: 89 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], []) 90 | else: 91 | anns = self.dataset['annotations'] 92 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 93 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 94 | ids = [ann['question_id'] for ann in anns] 95 | return ids 96 | 97 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 98 | """ 99 | Get image ids that satisfy given filter conditions. 
default skips that filter 100 | :param quesIds (int array) : get image ids for given question ids 101 | quesTypes (str array) : get image ids for given question types 102 | ansTypes (str array) : get image ids for given answer types 103 | :return: ids (int array) : integer array of image ids 104 | """ 105 | quesIds = quesIds if type(quesIds) == list else [quesIds] 106 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 107 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 108 | 109 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 110 | anns = self.dataset['annotations'] 111 | else: 112 | if not len(quesIds) == 0: 113 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], []) 114 | else: 115 | anns = self.dataset['annotations'] 116 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 117 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 118 | ids = [ann['image_id'] for ann in anns] 119 | return ids 120 | 121 | def loadQA(self, ids=[]): 122 | """ 123 | Load questions and answers with the specified question ids. 124 | :param ids (int array) : integer ids specifying question ids 125 | :return: qa (object array) : loaded qa objects 126 | """ 127 | if type(ids) == list: 128 | return [self.qa[id] for id in ids] 129 | elif type(ids) == int: 130 | return [self.qa[ids]] 131 | 132 | def showQA(self, anns): 133 | """ 134 | Display the specified annotations. 135 | :param anns (array of object): annotations to display 136 | :return: None 137 | """ 138 | if len(anns) == 0: 139 | return 0 140 | for ann in anns: 141 | quesId = ann['question_id'] 142 | print("Question: %s" % (self.qqa[quesId]['question'])) 143 | for ans in ann['answers']: 144 | print("Answer %d: %s" % (ans['answer_id'], ans['answer'])) 145 | 146 | def loadRes(self, resFile, quesFile): 147 | """ 148 | Load result file and return a result object. 149 | :param resFile (str) : file name of result file 150 | :return: res (obj) : result api object 151 | """ 152 | res = VQA() 153 | res.questions = json.load(open(quesFile)) 154 | 155 | res.dataset['info'] = copy.deepcopy(self.questions.get('info', 'none')) 156 | res.dataset['task_type'] = copy.deepcopy(self.questions.get('task_type', 'none')) 157 | res.dataset['data_type'] = copy.deepcopy(self.questions.get('data_type', 'none')) 158 | res.dataset['data_subtype'] = copy.deepcopy(self.questions.get('data_subtype', 'none')) 159 | res.dataset['license'] = copy.deepcopy(self.questions.get('license', 'none')) 160 | 161 | print('Loading and preparing results... ') 162 | time_t = datetime.datetime.utcnow() 163 | anns = json.load(open(resFile)) 164 | assert type(anns) == list, 'results is not an array of objects' 165 | annsQuesIds = [ann['question_id'] for ann in anns] 166 | assert set(annsQuesIds) == set(self.getQuesIds()), \ 167 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 
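        # Copy the image id, question type, and answer type from the ground-truth
        # annotations onto each predicted answer entry.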
168 | for ann in anns: 169 | quesId = ann['question_id'] 170 | if res.dataset['task_type'] == 'Multiple Choice': 171 | assert ann['answer'] in self.qqa[quesId][ 172 | 'multiple_choices'], 'predicted answer is not one of the multiple choices' 173 | qaAnn = self.qa[quesId] 174 | ann['image_id'] = qaAnn['image_id'] 175 | ann['question_type'] = qaAnn['question_type'] 176 | ann['answer_type'] = qaAnn['answer_type'] 177 | print('DONE (t=%0.2fs)' % ((datetime.datetime.utcnow() - time_t).total_seconds())) 178 | 179 | res.dataset['annotations'] = anns 180 | res.createIndex() 181 | return res 182 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from cfgs.base_cfgs import Cfgs 2 | from core.exec2steps import Execution as Exec2Steps 3 | from core.eval_novel import Execution as NovelEval 4 | 5 | import argparse, yaml 6 | 7 | from distutils import util as dutil 8 | 9 | def str2bool(v): 10 | return bool(dutil.strtobool(v)) 11 | 12 | 13 | def parse_args(): 14 | ''' 15 | Parse input arguments 16 | ''' 17 | parser = argparse.ArgumentParser(description='Model training/evaluation args') 18 | 19 | parser.add_argument('--RUN', dest='RUN_MODE', 20 | choices=['train', 'val', 'test', 'valNovel'], 21 | help='{train, val, test, valNovel}', 22 | type=str, required=True) 23 | 24 | parser.add_argument('--MODEL', dest='MODEL', 25 | choices=['small', 'large'], 26 | help='{small, large}', 27 | default='small', type=str) 28 | 29 | parser.add_argument('--num_hidden_layers', dest='num_hidden_layers', 30 | default=6, type=int) 31 | 32 | parser.add_argument('--num_attention_heads', dest='num_attention_heads', 33 | default=8, type=int) 34 | 35 | parser.add_argument('--ATTN_DROPOUT', dest='ATTN_DROPOUT', 36 | type=str2bool, default= True) 37 | 38 | parser.add_argument('--NOVEL_AUGMENT', dest='NOVEL_AUGMENT', 39 | type=int, default=1) # during pointing, can augment exposure to novel concepts 40 | 41 | parser.add_argument('--GROUND_LAYER', dest='GROUND_LAYER', 42 | type=int, default=0) # last layer id = 0, second-last layer id = 1, etc. 43 | 44 | parser.add_argument('--RESULT_EVAL_FILE', dest='RESULT_EVAL_FILE', 45 | type=str, default=None, help='JSON file containing generated answers for evaluation.') 46 | 47 | parser.add_argument('--CONCEPT', dest='CONCEPT', 48 | type=str, default=None, 49 | help='Novel concepts with no labeled data in training (string with commas separating conepts)') 50 | 51 | parser.add_argument('--SKILL', dest='SKILL', 52 | type=str, default=None, 53 | help='Novel skill with no labeled data in training. 
' + \ 54 | 'When combined with the CONCEPT arg, this will ' + \ 55 | 'remove labeled data for skill-concept compositions') 56 | 57 | parser.add_argument('--SPLIT', dest='TRAIN_SPLIT', 58 | choices=['train', 'train+val', 'train+val+vg'], 59 | help="set training split, " 60 | "eg.'train', 'train+val+vg'" 61 | "set 'train' can trigger the " 62 | "eval after every epoch", 63 | default='train', 64 | type=str) 65 | 66 | parser.add_argument('--LR_DECAY_LIST', dest='LR_DECAY_LIST', 67 | type=int, nargs='*', default=[10, 12]) 68 | 69 | parser.add_argument('--USE_GROUNDING', dest='USE_GROUNDING', 70 | type=str2bool, default=True) 71 | 72 | parser.add_argument('--TGT_MASKING', dest='TGT_MASKING', 73 | type=str, default='target', choices=['target', 'bert', 'even', 'none']) 74 | 75 | parser.add_argument('--USE_POINT_PROJ', dest='USE_POINT_PROJ', 76 | type=str2bool, default=True) 77 | 78 | parser.add_argument('--PT_TEMP', dest='PT_TEMP', 79 | type=float, default=1.0) 80 | 81 | parser.add_argument('--GROUNDING_PROB', dest='GROUNDING_PROB', 82 | type=float, default=0.1) 83 | 84 | parser.add_argument('--SK_TEMP', dest='SK_TEMP', 85 | type=float, default=0.5) 86 | 87 | parser.add_argument('--SKILL_CONT_LOSS', dest='SKILL_CONT_LOSS', 88 | type=str2bool, default=True) 89 | 90 | parser.add_argument('--SKILL_POOL', dest='SKILL_POOL', 91 | type=str, default='mean', choices=['cls', 'mean', 'max']) 92 | 93 | parser.add_argument('--EVAL_EE', dest='EVAL_EVERY_EPOCH', 94 | help='set True to evaluate the ' 95 | 'val split when an epoch finished' 96 | "(only work when train with " 97 | "'train' split)", 98 | type=bool) 99 | 100 | parser.add_argument('--SAVE_PRED', dest='TEST_SAVE_PRED', 101 | help='set True to save the ' 102 | 'prediction vectors' 103 | '(only work in testing)', 104 | type=bool) 105 | 106 | parser.add_argument('--BS', dest='BATCH_SIZE', 107 | help='batch size during training', 108 | type=int) 109 | 110 | parser.add_argument('--MAX_EPOCH', dest='MAX_EPOCH', 111 | help='max training epoch', 112 | type=int) 113 | 114 | parser.add_argument('--GPU', dest='GPU', 115 | help="gpu select, eg.'0, 1, 2'", 116 | type=str) 117 | 118 | parser.add_argument('--SEED', dest='SEED', 119 | help='fix random seed', 120 | type=int) 121 | 122 | parser.add_argument('--VERSION', dest='VERSION', 123 | help='version control', 124 | type=str) 125 | 126 | parser.add_argument('--RESUME', dest='RESUME', 127 | help='resume training', 128 | type=str2bool) 129 | 130 | parser.add_argument('--CKPT_V', dest='CKPT_VERSION', 131 | help='checkpoint version', 132 | type=str) 133 | 134 | parser.add_argument('--CKPT_E', dest='CKPT_EPOCH', 135 | help='checkpoint epoch', 136 | type=int) 137 | 138 | parser.add_argument('--CKPT_PATH', dest='CKPT_PATH', 139 | help='load checkpoint path, we ' 140 | 'recommend that you use ' 141 | 'CKPT_VERSION and CKPT_EPOCH ' 142 | 'instead', 143 | type=str) 144 | 145 | parser.add_argument('--ACCU', dest='GRAD_ACCU_STEPS', 146 | help='reduce gpu memory usage', 147 | type=int) 148 | 149 | parser.add_argument('--NW', dest='NUM_WORKERS', 150 | help='multithreaded loading', 151 | type=int) 152 | 153 | parser.add_argument('--PINM', dest='PIN_MEM', 154 | help='use pin memory', 155 | type=bool) 156 | 157 | parser.add_argument('--VERB', dest='VERBOSE', 158 | help='verbose print', 159 | type=bool) 160 | 161 | parser.add_argument('--DATA_PATH', dest='DATASET_PATH', 162 | help='vqav2 dataset root path', 163 | type=str) 164 | 165 | parser.add_argument('--FEAT_PATH', dest='FEATURE_PATH', 166 | help='bottom up features root path', 
167 | type=str) 168 | 169 | args = parser.parse_args() 170 | return args 171 | 172 | 173 | if __name__ == '__main__': 174 | __C = Cfgs() 175 | 176 | args = parse_args() 177 | args_dict = __C.parse_to_dict(args) 178 | 179 | print(args) 180 | 181 | cfg_file = "cfgs/{}_model.yml".format(args.MODEL) 182 | with open(cfg_file, 'r') as f: 183 | yaml_dict = yaml.load(f) 184 | 185 | args_dict = {**yaml_dict, **args_dict} 186 | __C.add_args(args_dict) 187 | __C.fix_and_add_args(args_dict) 188 | __C.proc() 189 | 190 | print('Hyper Parameters:') 191 | print(__C) 192 | 193 | __C.check_path() 194 | 195 | if __C.RUN_MODE == 'valNovel': 196 | print('Compute validation accuracy on novel subsets') 197 | execution = NovelEval(__C) 198 | else: 199 | if __C.USE_GROUNDING: 200 | print('Use 2-step Loss') 201 | else: 202 | print('No grounding loss') 203 | execution = Exec2Steps(__C) 204 | 205 | execution.run(__C.RUN_MODE) 206 | -------------------------------------------------------------------------------- /utils/vqaEval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | __author__='aagrawal' 4 | 5 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 6 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). 7 | import sys 8 | import re 9 | 10 | class VQAEval: 11 | def __init__(self, vqa, vqaRes, n=2): 12 | self.n = n 13 | self.accuracy = {} 14 | self.evalQA = {} 15 | self.evalQuesType = {} 16 | self.evalAnsType = {} 17 | self.vqa = vqa 18 | self.vqaRes = vqaRes 19 | self.params = {'question_id': vqa.getQuesIds()} 20 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", 21 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", 22 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", 23 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", 24 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 25 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 27 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", 28 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", 29 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", 30 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", 31 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", 32 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", 33 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", 34 | "they'dve": "they'd've", "theyll": "they'll", "theyre": 
"they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", 35 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", 36 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", 37 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 38 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 39 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 40 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", 41 | "youll": "you'll", "youre": "you're", "youve": "you've"} 42 | self.manualMap = { 'none': '0', 43 | 'zero': '0', 44 | 'one': '1', 45 | 'two': '2', 46 | 'three': '3', 47 | 'four': '4', 48 | 'five': '5', 49 | 'six': '6', 50 | 'seven': '7', 51 | 'eight': '8', 52 | 'nine': '9', 53 | 'ten': '10' 54 | } 55 | self.articles = ['a', 56 | 'an', 57 | 'the' 58 | ] 59 | 60 | 61 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 62 | self.commaStrip = re.compile("(\d)(,)(\d)") 63 | self.punct = [';', r"/", '[', ']', '"', '{', '}', 64 | '(', ')', '=', '+', '\\', '_', '-', 65 | '>', '<', '@', '`', ',', '?', '!'] 66 | 67 | 68 | def evaluate(self, quesIds=None): 69 | if quesIds == None: 70 | quesIds = [quesId for quesId in self.params['question_id']] 71 | gts = {} 72 | res = {} 73 | for quesId in quesIds: 74 | gts[quesId] = self.vqa.qa[quesId] 75 | res[quesId] = self.vqaRes.qa[quesId] 76 | 77 | # ================================================= 78 | # Compute accuracy 79 | # ================================================= 80 | accQA = [] 81 | accQuesType = {} 82 | accAnsType = {} 83 | print ("computing accuracy") 84 | step = 0 85 | for quesId in quesIds: 86 | resAns = res[quesId]['answer'] 87 | resAns = resAns.replace('\n', ' ') 88 | resAns = resAns.replace('\t', ' ') 89 | resAns = resAns.strip() 90 | resAns = self.processPunctuation(resAns) 91 | resAns = self.processDigitArticle(resAns) 92 | gtAcc = [] 93 | gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] 94 | if len(set(gtAnswers)) > 1: 95 | for ansDic in gts[quesId]['answers']: 96 | ansDic['answer'] = self.processPunctuation(ansDic['answer']) 97 | for gtAnsDatum in gts[quesId]['answers']: 98 | otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] 99 | matchingAns = [item for item in otherGTAns if item['answer']==resAns] 100 | acc = min(1, float(len(matchingAns))/3) 101 | gtAcc.append(acc) 102 | quesType = gts[quesId]['question_type'] 103 | ansType = gts[quesId]['answer_type'] 104 | avgGTAcc = float(sum(gtAcc))/len(gtAcc) 105 | accQA.append(avgGTAcc) 106 | if quesType not in accQuesType: 107 | accQuesType[quesType] = [] 108 | accQuesType[quesType].append(avgGTAcc) 109 | if ansType not in accAnsType: 110 | accAnsType[ansType] = [] 111 | accAnsType[ansType].append(avgGTAcc) 112 | self.setEvalQA(quesId, avgGTAcc) 113 | self.setEvalQuesType(quesId, quesType, avgGTAcc) 114 | self.setEvalAnsType(quesId, ansType, avgGTAcc) 115 | if step%100 == 0: 116 | self.updateProgress(step/float(len(quesIds))) 117 | step = step + 1 118 | 119 | self.setAccuracy(accQA, accQuesType, accAnsType) 120 | print ("Done computing accuracy") 121 | 122 | def processPunctuation(self, inText): 123 | outText = inText 
124 | for p in self.punct: 125 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): 126 | outText = outText.replace(p, '') 127 | else: 128 | outText = outText.replace(p, ' ') 129 | outText = self.periodStrip.sub("", 130 | outText, 131 | re.UNICODE) 132 | return outText 133 | 134 | def processDigitArticle(self, inText): 135 | outText = [] 136 | tempText = inText.lower().split() 137 | for word in tempText: 138 | word = self.manualMap.setdefault(word, word) 139 | if word not in self.articles: 140 | outText.append(word) 141 | else: 142 | pass 143 | for wordId, word in enumerate(outText): 144 | if word in self.contractions: 145 | outText[wordId] = self.contractions[word] 146 | outText = ' '.join(outText) 147 | return outText 148 | 149 | def setAccuracy(self, accQA, accQuesType, accAnsType): 150 | self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) 151 | self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} 152 | self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} 153 | 154 | def setEvalQA(self, quesId, acc): 155 | self.evalQA[quesId] = round(100*acc, self.n) 156 | 157 | def setEvalQuesType(self, quesId, quesType, acc): 158 | if quesType not in self.evalQuesType: 159 | self.evalQuesType[quesType] = {} 160 | self.evalQuesType[quesType][quesId] = round(100*acc, self.n) 161 | 162 | def setEvalAnsType(self, quesId, ansType, acc): 163 | if ansType not in self.evalAnsType: 164 | self.evalAnsType[ansType] = {} 165 | self.evalAnsType[ansType][quesId] = round(100*acc, self.n) 166 | 167 | def updateProgress(self, progress): 168 | barLength = 20 169 | status = "" 170 | if isinstance(progress, int): 171 | progress = float(progress) 172 | if not isinstance(progress, float): 173 | progress = 0 174 | status = "error: progress var must be float\r\n" 175 | if progress < 0: 176 | progress = 0 177 | status = "Halt...\r\n" 178 | if progress >= 1: 179 | progress = 1 180 | status = "Done...\r\n" 181 | block = int(round(barLength*progress)) 182 | text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status) 183 | sys.stdout.write(text) 184 | sys.stdout.flush() 185 | 186 | -------------------------------------------------------------------------------- /cfgs/base_cfgs.py: -------------------------------------------------------------------------------- 1 | from cfgs.path_cfgs import PATH 2 | 3 | import torch, random 4 | import numpy as np 5 | from types import MethodType 6 | 7 | 8 | class Cfgs(PATH): 9 | def __init__(self): 10 | super(Cfgs, self).__init__() 11 | 12 | # Set Devices 13 | # If use multi-gpu training, set e.g.'0, 1, 2' instead 14 | self.GPU = '0' 15 | 16 | # Set RNG For CPU And GPUs 17 | self.SEED = random.randint(0, 99999999) 18 | 19 | # ------------------------- 20 | # ---- Version Control ---- 21 | # ------------------------- 22 | 23 | # Define a specific name to start new training 24 | # self.VERSION = 'Anonymous_' + str(self.SEED) 25 | self.VERSION = str(self.SEED) 26 | 27 | # Resume training 28 | self.RESUME = False 29 | 30 | # Used in Resume training and testing 31 | self.CKPT_VERSION = self.VERSION 32 | self.CKPT_EPOCH = 0 33 | 34 | # Absolutely checkpoint path, 'CKPT_VERSION' and 'CKPT_EPOCH' will be overridden 35 | self.CKPT_PATH = None 36 | 37 | # Print loss every step 38 | self.VERBOSE = 
True
39 | 
40 | # ------------------------------
41 | # ---- Data Provider Params ----
42 | # ------------------------------
43 | 
44 | # {'train', 'val', 'test'}
45 | self.RUN_MODE = 'train'
46 | 
47 | # Set True to evaluate offline after every epoch
48 | self.EVAL_EVERY_EPOCH = True
49 | 
50 | # Set True to save the prediction vectors (for ensembling)
51 | self.TEST_SAVE_PRED = False
52 | 
53 | # Define the 'train' 'val' 'test' data splits
54 | # (EVAL_EVERY_EPOCH is only triggered when SPLIT['train'] == 'train')
55 | self.SPLIT = {
56 | 'train': '',
57 | 'val': 'val',
58 | 'test': 'test',
59 | 'valNovel': 'val'
60 | }
61 | 
62 | # An external way to override the train split
63 | self.TRAIN_SPLIT = 'train+val+vg'
64 | 
65 | # Set True to use pretrained word embeddings
66 | # (GloVe: spaCy https://spacy.io/)
67 | self.USE_GLOVE = True
68 | 
69 | # Word embedding matrix size
70 | # (token size x WORD_EMBED_SIZE)
71 | self.WORD_EMBED_SIZE = 300
72 | 
73 | # Max length of question sentences
74 | self.MAX_TOKEN = 14
75 | 
76 | # Filter answers by occurrence count
77 | # self.ANS_FREQ = 8
78 | 
79 | # Max number of extracted faster-rcnn 2048-D features (padded to this size)
80 | # (Bottom-Up and Top-Down: https://github.com/peteanderson80/bottom-up-attention)
81 | self.IMG_FEAT_PAD_SIZE = 100
82 | 
83 | # Faster-rcnn 2048-D feature size
84 | self.IMG_FEAT_SIZE = 2048
85 | 
86 | self.IMG_SPATIAL_FEAT_SIZE = 7
87 | 
88 | # Default training batch size: 64
89 | self.BATCH_SIZE = 64
90 | 
91 | # Multi-thread I/O
92 | self.NUM_WORKERS = 8
93 | 
94 | # Use pin memory
95 | # (Warning: pin memory can accelerate GPU loading but may
96 | # increase CPU memory usage when NUM_WORKERS is large)
97 | self.PIN_MEM = True
98 | 
99 | # The large model cannot be trained with batch size 64 in a single pass
100 | # Gradient accumulation splits each batch to reduce GPU memory usage
101 | # (Warning: BATCH_SIZE must be divisible by GRAD_ACCU_STEPS)
102 | self.GRAD_ACCU_STEPS = 1
103 | 
104 | # Set 'external': use an external shuffle method for training shuffling
105 | # Set 'internal': use the pytorch dataloader's default shuffling
106 | self.SHUFFLE_MODE = 'external'
107 | 
108 | # ------------------------
109 | # ---- Network Params ----
110 | # ------------------------
111 | 
112 | # Model depth
113 | # (Encoder and Decoder use the same depth)
114 | self.LAYER = 6
115 | 
116 | # Model hidden size
117 | # (512 by default; larger values sharply increase GPU memory usage)
118 | self.HIDDEN_SIZE = 512
119 | 
120 | # Number of attention heads in MCA layers
121 | # (Warning: HIDDEN_SIZE must be divisible by MULTI_HEAD)
122 | self.MULTI_HEAD = 8
123 | 
124 | # Dropout rate for all dropout layers
125 | # (dropout helps prevent overfitting: [Dropout: a simple way to prevent neural networks from overfitting])
126 | self.DROPOUT_R = 0.1
127 | 
128 | # MLP size in flatten layers
129 | self.FLAT_MLP_SIZE = 512
130 | 
131 | # Flatten the last hidden states to a vector with {n} attention glimpses
132 | self.FLAT_GLIMPSES = 1
133 | self.FLAT_OUT_SIZE = 1024
134 | 
135 | self.SK_TEMP = 1.0
136 | 
137 | # --------------------------
138 | # ---- Optimizer Params ----
139 | # --------------------------
140 | 
141 | # The base learning rate
142 | self.LR_BASE = 0.0001
143 | 
144 | # Learning rate decay ratio
145 | self.LR_DECAY_R = 0.2
146 | 
147 | # Learning rate decay at {x, y, z...} epochs
148 | self.LR_DECAY_LIST = [10, 12]
149 | 
150 | # Max number of training epochs
151 | self.MAX_EPOCH = 13
152 | 
153 | # Gradient clipping
154 | # (default: -1 means disabled)
155 | self.GRAD_NORM_CLIP = -1
156 | 
157 | # Adam optimizer betas and eps
158 | self.OPT_BETAS = (0.9, 0.98)
159 | 
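# Worked example (illustrative): with LR_BASE=0.0001, LR_DECAY_R=0.2 and
# LR_DECAY_LIST=[10, 12], the training loop calls adjust_lr() when the epoch
# counter reaches 10 and 12; assuming adjust_lr scales the current base rate by
# LR_DECAY_R, the learning rate goes 1e-4 -> 2e-5 -> 4e-6 within MAX_EPOCH=13.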
self.OPT_EPS = 1e-9 160 | 161 | def parse_to_dict(self, args): 162 | args_dict = {} 163 | for arg in dir(args): 164 | if not arg.startswith('_') and not isinstance(getattr(args, arg), MethodType): 165 | if getattr(args, arg) is not None: 166 | args_dict[arg] = getattr(args, arg) 167 | 168 | return args_dict 169 | 170 | def add_args(self, args_dict): 171 | for arg in args_dict: 172 | setattr(self, arg, args_dict[arg]) 173 | 174 | def fix_and_add_args(self, args_dict): 175 | print('Manually fix question paths for reference sets ...') 176 | 177 | self.QUESTION_PATH['train'] = './datasets/vqa/train2014_scr_questions.json' 178 | self.ANSWER_PATH['train'] = './datasets/vqa/train2014_scr_annotations.json' 179 | 180 | self.QUESTION_PATH['val'] = './datasets/vqa/val2014_sc_questions.json' 181 | self.ANSWER_PATH['val'] = './datasets/vqa/val2014_sc_annotations.json' 182 | 183 | if args_dict.get('SKILL', None) is None: 184 | args_dict['SKILL'] = None 185 | 186 | if args_dict.get('CONCEPT', None) is None: 187 | args_dict['CONCEPT'] = None 188 | 189 | for arg in args_dict: 190 | setattr(self, arg, args_dict[arg]) 191 | 192 | def proc(self): 193 | assert self.RUN_MODE in ['train', 'val', 'test', 'valNovel'] 194 | 195 | # ------------ Devices setup 196 | # os.environ['CUDA_VISIBLE_DEVICES'] = self.GPU 197 | self.N_GPU = len(self.GPU.split(',')) 198 | self.DEVICES = [_ for _ in range(self.N_GPU)] 199 | torch.set_num_threads(2) 200 | 201 | # ------------ Seed setup 202 | # fix pytorch seed 203 | torch.manual_seed(self.SEED) 204 | if self.N_GPU < 2: 205 | torch.cuda.manual_seed(self.SEED) 206 | else: 207 | torch.cuda.manual_seed_all(self.SEED) 208 | torch.backends.cudnn.deterministic = True 209 | 210 | # fix numpy seed 211 | np.random.seed(self.SEED) 212 | 213 | # fix random seed 214 | random.seed(self.SEED) 215 | 216 | if self.CKPT_PATH is not None: 217 | print('Warning: you are now using CKPT_PATH args, ' 218 | 'CKPT_VERSION and CKPT_EPOCH will not work') 219 | self.CKPT_VERSION = self.CKPT_PATH.split('/')[-2] + '_' + str(random.randint(0, 99999999)) 220 | 221 | # ------------ Split setup 222 | self.SPLIT['train'] = self.TRAIN_SPLIT 223 | if 'val' in self.SPLIT['train'].split('+') or self.RUN_MODE not in ['train']: 224 | self.EVAL_EVERY_EPOCH = False 225 | 226 | if self.RUN_MODE not in ['test']: 227 | self.TEST_SAVE_PRED = False 228 | 229 | # ------------ Gradient accumulate setup 230 | assert self.BATCH_SIZE % self.GRAD_ACCU_STEPS == 0 231 | self.SUB_BATCH_SIZE = int(self.BATCH_SIZE / self.GRAD_ACCU_STEPS) 232 | 233 | # Use a small eval batch will reduce gpu memory usage 234 | self.EVAL_BATCH_SIZE = int(self.SUB_BATCH_SIZE / 2) 235 | 236 | # ------------ Networks setup 237 | # FeedForwardNet size in every MCA layer 238 | self.FF_SIZE = int(self.HIDDEN_SIZE * 4) 239 | 240 | # A pipe line hidden size in attention compute 241 | assert self.HIDDEN_SIZE % self.MULTI_HEAD == 0 242 | self.HIDDEN_SIZE_HEAD = int(self.HIDDEN_SIZE / self.MULTI_HEAD) 243 | 244 | def __str__(self): 245 | for attr in dir(self): 246 | if not attr.startswith('__') and not isinstance(getattr(self, attr), MethodType): 247 | print('{ %-17s }->' % attr, getattr(self, attr)) 248 | 249 | return '' 250 | -------------------------------------------------------------------------------- /core/data/load_data.py: -------------------------------------------------------------------------------- 1 | from core.data.data_utils import get_concept_position, prune_refsets 2 | from core.data.data_utils import refset_point_refset_index, sample_refset, 
do_token_masking 3 | from core.data.data_utils import img_feat_path_load, ques_load, tokenize, ans_stat 4 | from core.data.data_utils import proc_img_feat, proc_ques, proc_ans 5 | from core.data.data_utils import filter_concept_skill, get_novel_ids 6 | from core.data.data_utils import build_skill_references, sample_contrasting_skills 7 | import numpy as np 8 | import random 9 | import glob, json, torch 10 | import torch.utils.data as Data 11 | 12 | 13 | class DataSet(Data.Dataset): 14 | def __init__(self, __C): 15 | self.__C = __C 16 | 17 | # Loading all image paths 18 | self.img_feat_path_list = [] 19 | img_split_list = __C.SPLIT[__C.RUN_MODE].split('+') 20 | 21 | if self.__C.VQACP: 22 | img_split_list = ['train', 'val'] 23 | 24 | for split in img_split_list: 25 | if split in ['train', 'val', 'test']: 26 | self.img_feat_path_list += glob.glob(__C.IMG_FEAT_PATH[split] + '*.npz') 27 | 28 | # Loading question word list 29 | stat_ques_list = \ 30 | json.load(open(__C.QUESTION_PATH['train'], 'r'))['questions'] + \ 31 | json.load(open(__C.QUESTION_PATH['val'], 'r'))['questions'] + \ 32 | json.load(open(__C.QUESTION_PATH['test'], 'r'))['questions'] + \ 33 | json.load(open(__C.QUESTION_PATH['vg'], 'r'))['questions'] 34 | 35 | # Loading question and answer list 36 | self.ques_list = [] 37 | self.ans_list = [] 38 | 39 | split_list = __C.SPLIT[__C.RUN_MODE].split('+') 40 | for split in split_list: 41 | self.ques_list += json.load(open(__C.QUESTION_PATH[split], 'r'))['questions'] 42 | if __C.RUN_MODE in ['train', 'vqaAccRegion', 'evalAll']: 43 | self.ans_list += json.load(open(__C.ANSWER_PATH[split], 'r'))['annotations'] 44 | 45 | self.rs_idx = [] 46 | self.qid2bbanns = {} 47 | 48 | # ------------------------ 49 | # ---- Data statistic ---- 50 | # ------------------------ 51 | 52 | # {image id} -> {image feature absolutely path} 53 | self.iid_to_img_feat_path = img_feat_path_load(self.img_feat_path_list) 54 | 55 | self.novel_ques_ids = None 56 | if self.__C.NOVEL == 'remove' and self.__C.RUN_MODE == 'train': 57 | filter_concept_skill(self.ques_list, self.ans_list, concept=self.__C.CONCEPT, skill=self.__C.SKILL) 58 | elif self.__C.NOVEL == 'get_ids' and self.__C.RUN_MODE == 'val': 59 | self.novel_ques_ids, _ = \ 60 | get_novel_ids(self.ques_list, concept=self.__C.CONCEPT, skill=self.__C.SKILL) 61 | else: 62 | self.novel_ques_ids, self.novel_indices = \ 63 | get_novel_ids(self.ques_list, concept=self.__C.CONCEPT, skill=self.__C.SKILL) 64 | 65 | 66 | # Define run data size 67 | if __C.RUN_MODE in ['train']: 68 | self.data_size = self.ans_list.__len__() 69 | else: 70 | self.data_size = self.ques_list.__len__() 71 | 72 | print('== Dataset size:', self.data_size) 73 | 74 | # {question id} -> {question} 75 | self.qid_to_ques = ques_load(self.ques_list) 76 | 77 | # Tokenize 78 | self.token_to_ix, self.pretrained_emb = tokenize( 79 | stat_ques_list, 80 | __C.USE_GLOVE, 81 | save_embeds=(self.__C.RUN_MODE in {'train'}) # Embeddings will not be overwritten if file already exists. 
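            # For reference (descriptive only): tokenize() in core/data/data_utils.py
            # returns a vocabulary indexed as {'PAD': 0, 'UNK': 1, '[MASK]': 2, '[CLS]': 3,
            # <question words>...} plus a matching (token_size x WORD_EMBED_SIZE) GloVe
            # matrix, so embedding rows line up with the indices produced by proc_ques().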
82 | ) 83 | self.token_size = self.token_to_ix.__len__() 84 | print('== Question token vocab size:', self.token_size) 85 | 86 | # Answers stats 87 | # self.ans_to_ix, self.ix_to_ans = ans_stat(self.stat_ans_list, __C.ANS_FREQ) 88 | self.ans_to_ix, self.ix_to_ans = ans_stat('core/data/answer_dict.json') 89 | self.ans_size = self.ans_to_ix.__len__() 90 | print('== Answer vocab size (occurr more than {} times):'.format(8), self.ans_size) 91 | print('Finished!\n') 92 | 93 | def __getitem__(self, idx): 94 | # For code safety 95 | img_feat_iter = np.zeros(1) 96 | img_pos_feat_iter = np.zeros(1) 97 | ques_ix_iter = np.zeros(1) 98 | ans_iter = np.zeros(1) 99 | 100 | # Process ['train'] and ['val', 'test'] respectively 101 | if self.__C.RUN_MODE in ['train', 'evalAll']: 102 | # Load the run data from list 103 | ans = self.ans_list[idx] 104 | ques = self.qid_to_ques[str(ans['question_id'])] 105 | 106 | # Process image feature from (.npz) file 107 | img_feat = np.load(self.iid_to_img_feat_path[str(ans['image_id'])]) 108 | img_feat_x = img_feat['x'].transpose((1, 0)) 109 | 110 | img_feat_iter = proc_img_feat(img_feat_x, self.__C.IMG_FEAT_PAD_SIZE) 111 | 112 | # Process question 113 | ques_ix_iter = proc_ques( 114 | ques, self.token_to_ix, self.__C.MAX_TOKEN, add_cls=False 115 | ) 116 | 117 | # Process answer 118 | ans_iter = proc_ans(ans, self.ans_to_ix) 119 | 120 | else: 121 | # Load the run data from list 122 | ques = self.ques_list[idx] 123 | 124 | img_feat = np.load(self.iid_to_img_feat_path[str(ques['image_id'])]) 125 | img_feat_x = img_feat['x'].transpose((1, 0)) 126 | 127 | img_feat_iter = proc_img_feat(img_feat_x, self.__C.IMG_FEAT_PAD_SIZE) 128 | 129 | # Process question 130 | ques_ix_iter = proc_ques(ques, self.token_to_ix, self.__C.MAX_TOKEN, add_cls=False) 131 | 132 | return torch.from_numpy(img_feat_iter), \ 133 | torch.from_numpy(ques_ix_iter), \ 134 | torch.from_numpy(ans_iter) 135 | 136 | def __len__(self): 137 | return self.data_size 138 | 139 | 140 | class RefPointDataSet(DataSet): 141 | def __init__(self, __C): 142 | super().__init__(__C) 143 | self.__C = __C 144 | 145 | self.refset_sizes = list(zip(['pos', 'neg1', 'neg2'], [1, 1, 1])) 146 | 147 | self.ques_list = prune_refsets(self.ques_list, self.refset_sizes, self.__C.MAX_TOKEN) 148 | 149 | if not self.rs_idx: 150 | self.rs_idx = refset_point_refset_index( 151 | self.ques_list, 152 | self.__C.MAX_TOKEN, 153 | self.novel_indices, 154 | aug_factor=getattr(self.__C, 'NOVEL_AUGMENT', 1) 155 | ) 156 | 157 | print('== Refset Dataset size:\n', len(self.rs_idx)) 158 | 159 | def __len__(self): 160 | return len(self.rs_idx) 161 | 162 | def __getitem__(self, idx): 163 | target_idx, target_concept = self.rs_idx[idx] 164 | pos_idx_list, neg1_idx_list, neg2_idx_list = \ 165 | sample_refset(self.ques_list[target_idx], target_concept, self.refset_sizes) 166 | all_cand_idx = pos_idx_list + neg1_idx_list + neg2_idx_list 167 | 168 | random.shuffle(all_cand_idx) 169 | 170 | # This assumes that there is only one positive example 171 | data_ref = [] 172 | point_positions = [] 173 | ref_qids = [] 174 | cand_q_len = 0 175 | for i_cand in all_cand_idx: 176 | curr_cand = super().__getitem__(i_cand) 177 | curr_cand_pt_pos = get_concept_position(self.ques_list[i_cand], target_concept) 178 | if curr_cand_pt_pos > -1: 179 | curr_cand_pt_pos += cand_q_len 180 | point_positions.append(curr_cand_pt_pos) 181 | 182 | data_ref.append(curr_cand) 183 | ref_qids.append(self.ques_list[i_cand]['question_id']) 184 | cand_q_len += len(curr_cand[2]) # length of current 
candidate question 185 | 186 | data_target = super().__getitem__(target_idx) 187 | target_concept_pos = get_concept_position(self.ques_list[target_idx], target_concept) 188 | 189 | assert target_concept_pos != -1 190 | 191 | ori_id = data_target[2][target_concept_pos] 192 | new_id = do_token_masking(ori_id, self.token_to_ix, self.__C.TGT_MASKING) 193 | data_target[2][target_concept_pos] = new_id 194 | 195 | cand_labels = [target_concept_pos] 196 | 197 | cand_labels = torch.from_numpy(np.array(cand_labels)).type(torch.LongTensor) 198 | point_positions = torch.from_numpy(np.array(point_positions)).type(torch.LongTensor) 199 | 200 | qid_data_ = { 201 | 'concept': target_concept, 202 | 'tgt': self.ques_list[target_idx]['question_id'], 203 | 'refs': ref_qids 204 | } 205 | 206 | return data_target, data_ref, cand_labels, point_positions, qid_data_ 207 | 208 | 209 | class SkillContrastDataSet(DataSet): 210 | def __init__(self, __C): 211 | super().__init__(__C) 212 | 213 | print('Building skill references...') 214 | self.rs_idx = build_skill_references(self.ques_list) 215 | print('Training reference set questions with skill references: {}'.format(len(self))) 216 | self.pretrained_emb = None 217 | 218 | def __len__(self): 219 | return len(self.rs_idx) 220 | 221 | def __getitem__(self, idx): 222 | target_idx, target_concept = self.rs_idx[idx] 223 | 224 | pos_idx_list, neg1_idx_list = \ 225 | sample_contrasting_skills(self.ques_list[target_idx], n_pos_samples=1, n_neg_samples=2) 226 | 227 | all_cand_idx = pos_idx_list + neg1_idx_list 228 | point_positions = [0] 229 | 230 | # This assumes that there is only one positive example 231 | data_ref = [] 232 | ref_qids = [] 233 | for i_cand in all_cand_idx: 234 | curr_cand = super().__getitem__(i_cand) 235 | data_ref.append(curr_cand) 236 | ref_qids.append(self.ques_list[i_cand]['question_id']) 237 | 238 | data_target = super().__getitem__(target_idx) 239 | 240 | cand_labels = torch.from_numpy(np.array(point_positions)).type(torch.LongTensor) 241 | point_positions = torch.from_numpy(np.array(point_positions)).type(torch.LongTensor) 242 | 243 | qid_data_ = { 244 | 'concept': target_concept, 245 | 'tgt': self.ques_list[target_idx]['question_id'], 246 | 'refs': ref_qids 247 | } 248 | return data_target, data_ref, cand_labels, point_positions, qid_data_ 249 | -------------------------------------------------------------------------------- /MCAN_LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2019] [Vision and Language Group@ MIL] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /core/data/data_utils.py: -------------------------------------------------------------------------------- 1 | from core.data.ans_punct import prep_ans 2 | 3 | from core.data.save_glove_embeds import StoredEmbeds 4 | import numpy as np 5 | import random, re, json 6 | from torch.utils.data._utils.collate import default_collate 7 | 8 | 9 | try: 10 | import en_vectors_web_lg 11 | except ImportError: 12 | import spacy 13 | 14 | 15 | def shuffle_list(ans_list): 16 | random.shuffle(ans_list) 17 | 18 | 19 | def save_json(obj, fname): 20 | with open(fname, 'w') as f: 21 | json.dump(obj, f) 22 | 23 | 24 | def load_json(fname): 25 | with open(fname, 'r') as f: 26 | data_ = json.load(f) 27 | return data_ 28 | 29 | # ------------------------------ 30 | # ---- Initialization Utils ---- 31 | # ------------------------------ 32 | 33 | def img_feat_path_load(path_list): 34 | iid_to_path = {} 35 | 36 | for ix, path in enumerate(path_list): 37 | iid = str(int(path.split('/')[-1].split('_')[-1].split('.')[0])) 38 | iid_to_path[iid] = path 39 | 40 | return iid_to_path 41 | 42 | 43 | def img_feat_load(path_list): 44 | iid_to_feat = {} 45 | 46 | for ix, path in enumerate(path_list): 47 | iid = str(int(path.split('/')[-1].split('_')[-1].split('.')[0])) 48 | img_feat = np.load(path) 49 | img_feat_x = img_feat['x'].transpose((1, 0)) 50 | iid_to_feat[iid] = img_feat_x 51 | print('\rPre-Loading: [{} | {}] '.format(ix, path_list.__len__()), end=' ') 52 | 53 | return iid_to_feat 54 | 55 | 56 | def ques_load(ques_list): 57 | qid_to_ques = {} 58 | 59 | for ques in ques_list: 60 | qid = str(ques['question_id']) 61 | qid_to_ques[qid] = ques 62 | 63 | return qid_to_ques 64 | 65 | 66 | def get_words(question_str): 67 | return re.sub( 68 | r"([.,'!?\"()*#:;])", 69 | '', 70 | question_str.lower() 71 | ).replace('-', ' ').replace('/', ' ').split() 72 | 73 | 74 | def tokenize(stat_ques_list, use_glove, save_embeds=False): 75 | # This function basically requires use_glove to be true in order to work correctly. 76 | # Otherwise, the indicies in token_to_ix don't match the actual embedding matrix. 77 | 78 | token_to_ix = { 79 | 'PAD': 0, 80 | 'UNK': 1, 81 | '[MASK]': 2, 82 | '[CLS]': 3 83 | } 84 | 85 | spacy_tool = None 86 | pretrained_emb = [] 87 | stored_embeds = StoredEmbeds(embed_fname='./ckpts/glove_embeds.pkl') 88 | if use_glove: 89 | try: 90 | spacy_tool = en_vectors_web_lg.load() 91 | except NameError: 92 | try: 93 | spacy_tool = spacy.load('en_vectors_web_lg') 94 | except OSError: 95 | if not stored_embeds.has_embeds(): 96 | raise ValueError('Spacy could not be loaded and no stored glove embeddings were found.') 97 | return stored_embeds.get_embeds() 98 | 99 | known_vec = spacy_tool('the').vector 100 | mu = 0. 101 | sigma = np.sqrt(1. 
/ known_vec.shape[0]) 102 | 103 | pretrained_emb.append(spacy_tool('PAD').vector) 104 | pretrained_emb.append(spacy_tool('UNK').vector) 105 | pretrained_emb.append( 106 | sigma * np.random.randn(*known_vec.shape).astype(dtype=known_vec.dtype) + mu 107 | ) # Embedding for [MASK] 108 | pretrained_emb.append( 109 | sigma * np.random.randn(*known_vec.shape).astype(dtype=known_vec.dtype) + mu 110 | ) # Embedding for [CLS] 111 | 112 | for ques in stat_ques_list: 113 | words = get_words(ques['question']) 114 | 115 | for word in words: 116 | if word not in token_to_ix: 117 | token_to_ix[word] = len(token_to_ix) 118 | if use_glove: 119 | pretrained_emb.append(spacy_tool(word).vector) 120 | 121 | if save_embeds: 122 | # Embeddings will not be overwritten if file already exists. 123 | stored_embeds.set_embeds(token_to_ix, pretrained_emb) 124 | stored_embeds.save() 125 | 126 | pretrained_emb = np.array(pretrained_emb) 127 | 128 | return token_to_ix, pretrained_emb 129 | 130 | 131 | # def ans_stat(stat_ans_list, ans_freq): 132 | # ans_to_ix = {} 133 | # ix_to_ans = {} 134 | # ans_freq_dict = {} 135 | # 136 | # for ans in stat_ans_list: 137 | # ans_proc = prep_ans(ans['multiple_choice_answer']) 138 | # if ans_proc not in ans_freq_dict: 139 | # ans_freq_dict[ans_proc] = 1 140 | # else: 141 | # ans_freq_dict[ans_proc] += 1 142 | # 143 | # ans_freq_filter = ans_freq_dict.copy() 144 | # for ans in ans_freq_dict: 145 | # if ans_freq_dict[ans] <= ans_freq: 146 | # ans_freq_filter.pop(ans) 147 | # 148 | # for ans in ans_freq_filter: 149 | # ix_to_ans[ans_to_ix.__len__()] = ans 150 | # ans_to_ix[ans] = ans_to_ix.__len__() 151 | # 152 | # return ans_to_ix, ix_to_ans 153 | 154 | def ans_stat(json_file): 155 | ans_to_ix, ix_to_ans = json.load(open(json_file, 'r')) 156 | 157 | return ans_to_ix, ix_to_ans 158 | 159 | 160 | # ------------------------------------ 161 | # ---- Real-Time Processing Utils ---- 162 | # ------------------------------------ 163 | 164 | def proc_img_feat(img_feat, img_feat_pad_size): 165 | if img_feat.shape[0] > img_feat_pad_size: 166 | img_feat = img_feat[:img_feat_pad_size] 167 | 168 | img_feat = np.pad( 169 | img_feat, 170 | ((0, img_feat_pad_size - img_feat.shape[0]), (0, 0)), 171 | mode='constant', 172 | constant_values=0 173 | ) 174 | 175 | return img_feat 176 | 177 | 178 | def proc_ques(ques, token_to_ix, max_token, add_cls=False): 179 | if not add_cls: 180 | ques_ix = np.zeros(max_token, np.int64) 181 | start_ix = 0 182 | max_len = max_token 183 | else: 184 | ques_ix = np.zeros(max_token + 1, np.int64) 185 | ques_ix[0] = token_to_ix['[CLS]'] 186 | start_ix = 1 187 | max_len = max_token + 1 188 | 189 | words = get_words(ques['question']) 190 | 191 | for ix, word in enumerate(words, start=start_ix): 192 | if word in token_to_ix: 193 | ques_ix[ix] = token_to_ix[word] 194 | else: 195 | ques_ix[ix] = token_to_ix['UNK'] 196 | 197 | if ix + 1 == max_len: 198 | break 199 | 200 | return ques_ix 201 | 202 | 203 | def get_score(occur): 204 | if occur == 0: 205 | return .0 206 | elif occur == 1: 207 | return .3 208 | elif occur == 2: 209 | return .6 210 | elif occur == 3: 211 | return .9 212 | else: 213 | return 1. 
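# Example (illustrative): get_score() implements the soft VQA target -- an answer
# given by 1 of the 10 annotators scores 0.3, by 2 scores 0.6, by 3 scores 0.9,
# and by 4 or more scores 1.0. proc_ans() below turns the raw answers into such a
# vector; e.g. for answers ['white']*7 + ['gray']*3 (both in the answer vocab):
#   ans_score[ans_to_ix['white']] = 1.0
#   ans_score[ans_to_ix['gray']]  = 0.9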
214 | 215 | 216 | def proc_ans(ans, ans_to_ix): 217 | ans_score = np.zeros(ans_to_ix.__len__(), np.float32) 218 | ans_prob_dict = {} 219 | 220 | for ans_ in ans['answers']: 221 | ans_proc = prep_ans(ans_['answer']) 222 | if ans_proc not in ans_prob_dict: 223 | ans_prob_dict[ans_proc] = 1 224 | else: 225 | ans_prob_dict[ans_proc] += 1 226 | 227 | for ans_ in ans_prob_dict: 228 | if ans_ in ans_to_ix: 229 | ans_score[ans_to_ix[ans_]] = get_score(ans_prob_dict[ans_]) 230 | 231 | return ans_score 232 | 233 | 234 | def refset_collate(batch): 235 | 236 | tgt, refs, label, pos, qid_data = zip(*batch) 237 | 238 | tgt, label, pos = default_collate(tgt), default_collate(label), default_collate(pos) 239 | 240 | refs = list(refs) 241 | n_refs = len(refs[0]) # number of reference examples 242 | 243 | batched_refs = [] 244 | for i in range(n_refs): 245 | ref_i_all = [per_row[i] for per_row in refs] 246 | ref_i_batched = default_collate(ref_i_all) 247 | batched_refs.append(ref_i_batched) 248 | 249 | 250 | return tgt, batched_refs, label, pos, qid_data 251 | 252 | 253 | def refset_tocuda(refset_data): 254 | tgt, batched_refs, label, pos, qid_data = refset_data 255 | label, pos = label.cuda(), pos.cuda() 256 | 257 | tgt = (tgt[0].cuda(),tgt[1].cuda(), tgt[2].cuda()) 258 | 259 | if all(len(x) for x in batched_refs): 260 | batched_refs = [(x[0].cuda(), x[1].cuda(), x[2].cuda()) for x in batched_refs] 261 | 262 | return tgt, batched_refs, label, pos, qid_data 263 | 264 | 265 | def refset_point_refset_index(question_list, max_token, novel_indices=None, aug_factor=1): 266 | # This assumes that each concept only appears once in the question. 267 | # If this is a bad assumption, then we need to iterate over question['concepts'] 268 | 269 | n_questions = len(question_list) 270 | is_novel = [False for _ in range(n_questions)] 271 | if novel_indices: 272 | for x in novel_indices: 273 | is_novel[x] = True 274 | 275 | rs_idx = [] 276 | count_novel = 0 277 | for qidx, question in enumerate(question_list): 278 | if question.get('refsets', None): 279 | for c, crefs in question['refsets'].items(): 280 | has_refs = True 281 | for dkey, vals in crefs.items(): 282 | if not len(vals['index']) or not len(vals['question_id']): 283 | has_refs = False 284 | break 285 | 286 | # Assumes each concepts appears once. 287 | if get_concept_position(question, c) < max_token and has_refs: 288 | rs_idx.append((qidx, c)) 289 | 290 | if is_novel[qidx]: 291 | for _ in range(aug_factor-1): 292 | rs_idx.append((qidx, c)) 293 | count_novel += 1 294 | 295 | print('Added {x} number of novel questions for the refset'.format(x=count_novel)) 296 | 297 | return rs_idx 298 | 299 | 300 | def prune_refsets(question_list, refset_sizes, max_token): 301 | for question in question_list: 302 | if question.get('refsets', None): 303 | for c, crefs in question['refsets'].items(): 304 | for dkey, _ in refset_sizes: 305 | for i, idx in reversed(list(enumerate(crefs[dkey]['index']))): 306 | if len(get_words(question_list[idx]['question'])) > max_token: 307 | crefs[dkey]['index'].pop(i) 308 | crefs[dkey]['question_id'].pop(i) 309 | 310 | return question_list 311 | 312 | 313 | def get_concept_position(question, concept): 314 | # This assumes that concepts are only 1 word. Also, as in refset_index(), we assume each concept appears once. 
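    # Descriptive note: question['concepts'] is expected to map each concept string
    # to a list of occurrences, each a list whose first element is the token
    # position, e.g. {'dog': [[3]]}. The .get(concept, [[-1]])[0][0] below therefore
    # yields the first position of the first occurrence, or -1 if the concept is absent.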
315 | return question['concepts'].get(concept, [[-1]])[0][0] 316 | 317 | 318 | def do_token_masking(token_id, token_to_ix, mask_mode): 319 | # target = do what we do now 320 | # bert = do what BERT does 321 | # even = 50 / 50 Mask or keep same 322 | # https://github.com/google-research/bert/blob/0fce551b55caabcfba52c61e18f34b541aef186a/create_pretraining_data.py#L342 323 | masked_token_id = None 324 | if mask_mode == 'target': 325 | masked_token_id = token_to_ix['[MASK]'] 326 | elif mask_mode == 'bert': 327 | # 80% of the time, replace with [MASK] 328 | if random.random() < 0.8: 329 | masked_token_id = token_to_ix['[MASK]'] 330 | else: 331 | # 10% of the time, keep original 332 | if random.random() <= 0.5: 333 | masked_token_id = token_id 334 | # 10% of the time, replace with random word 335 | else: 336 | masked_token_id = random.randint(4, len(token_to_ix) - 1) # start at 4 to account for PAD, UNK, MASK, CLS 337 | elif mask_mode == 'even': 338 | if random.random() <= 0.5: 339 | masked_token_id = token_to_ix['[MASK]'] 340 | else: 341 | masked_token_id = token_id 342 | elif mask_mode is None or mask_mode.lower() == 'none': 343 | masked_token_id = token_id 344 | else: 345 | raise ValueError('mask_mode must be in [target, bert, even, none/None]') 346 | 347 | return masked_token_id 348 | 349 | 350 | def filter_concept_skill(ques_list, ans_list, concept, skill): 351 | N, N_ans = len(ques_list), len(ans_list) 352 | assert N == N_ans 353 | 354 | novel_ques_ids, novel_indices = get_novel_ids(ques_list, concept, skill) 355 | 356 | count = 0 357 | for id in reversed(novel_indices): # going back to front, delete novel idx 358 | del ques_list[id] 359 | del ans_list[id] 360 | count += 1 361 | 362 | print('Removed {x} number of novel questions from the current split'.format(x=count)) 363 | print('New dataset size is {x}'.format(x=len(ques_list))) 364 | 365 | 366 | def get_novel_ids(ques_list, concept, skill): 367 | novel_ids, novel_indices = [], [] 368 | if not concept: return novel_ids, novel_indices 369 | 370 | if isinstance(concept, str): 371 | concept = concept.split(',') 372 | 373 | concept_set = set(concept) 374 | 375 | N = len(ques_list) 376 | 377 | for i in range(N): 378 | ques = ques_list[i] 379 | 380 | if 'all_concepts' not in ques: 381 | curr_concepts = set(ques['concepts']) 382 | else: 383 | curr_concepts = set(ques['all_concepts']) 384 | 385 | found_concept = bool(len(concept_set & curr_concepts)) 386 | 387 | if not found_concept: 388 | continue 389 | 390 | if (skill is None or skill.lower() == 'none') or ques['skill'] == skill: 391 | # Found a match, add question id 392 | novel_ids.append(ques['question_id']) 393 | novel_indices.append(i) 394 | 395 | print('Found {x} number of novel question ids'.format(x= len(novel_ids))) 396 | return novel_ids, novel_indices 397 | 398 | 399 | def sample_references(question, concept, reftype_key, n_samples=1): 400 | return random.sample(question['refsets'][concept][reftype_key]['index'], k=n_samples) 401 | 402 | 403 | def sample_refset(question, concept, refset_sizes): 404 | sampled_rs = [] 405 | for dkey, n_samples in refset_sizes: 406 | sampled_rs.append(sample_references(question, concept, dkey, n_samples)) 407 | return sampled_rs 408 | 409 | 410 | def build_skill_references(question_list): 411 | skill_refs = [] 412 | for i, ques in enumerate(question_list): 413 | if ques.get('skill_refset', None): 414 | if len(ques['skill_refset']['pos']) and len(ques['skill_refset']['neg']) > 1: 415 | skill_refs.append((i, 'none')) 416 | return skill_refs 417 | 418 | 
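# Sketch of the per-question annotation layout assumed by build_skill_references()
# above and sample_contrasting_skills() below (the exact contents come from the
# dataset preprocessing, so the values here are only illustrative):
#   question['skill_refset'] = {
#       'pos': [...],   # indices (into ques_list) of questions requiring the same skill
#       'neg': [...],   # indices of questions requiring contrasting skills
#   }
# A question becomes a skill reference only when 'pos' is non-empty and 'neg' has
# more than one entry; sampling then draws n_pos_samples / n_neg_samples indices.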
419 | def sample_contrasting_skills(question, n_pos_samples, n_neg_samples): 420 | pos_samples_ = random.sample(question['skill_refset']['pos'], k=n_pos_samples) 421 | neg_samples_ = random.sample(question['skill_refset']['neg'], k=n_neg_samples) 422 | return pos_samples_, neg_samples_ 423 | -------------------------------------------------------------------------------- /core/exec2steps.py: -------------------------------------------------------------------------------- 1 | from core.data.load_data import DataSet, RefPointDataSet, SkillContrastDataSet 2 | from core.model.PointNet import PointNet 3 | from core.model.optim import get_optim, adjust_lr 4 | from core.model.losses import Losses 5 | from core.data.data_utils import shuffle_list, refset_collate, refset_tocuda 6 | from utils.vqa import VQA 7 | from utils.vqaEval import VQAEval 8 | 9 | import os, json, torch, datetime, pickle, copy, shutil, time 10 | import numpy as np 11 | import torch.nn as nn 12 | import torch.utils.data as Data 13 | import random 14 | 15 | 16 | class Execution: 17 | def __init__(self, __C): 18 | self.__C = __C 19 | 20 | print('Loading training set ........') 21 | 22 | 23 | if __C.CONCEPT is not None or __C.SKILL is not None: # take out novel concept/skill from training 24 | setattr(__C, 'NOVEL', 'remove') 25 | else: 26 | setattr(__C, 'NOVEL', 'get_ids') 27 | 28 | self.dataset = DataSet(__C) 29 | 30 | if self.__C.USE_GROUNDING: 31 | __C_ref = copy.deepcopy(__C) 32 | setattr(__C_ref, 'NOVEL', 'augment') 33 | self.refdataset = RefPointDataSet(__C_ref) 34 | else: 35 | self.refdataset = None 36 | 37 | if self.__C.SKILL_CONT_LOSS: 38 | __C_sk_ref = copy.deepcopy(__C) 39 | setattr(__C_sk_ref, 'NOVEL', 'augment') 40 | self.sk_contrast_dataset = SkillContrastDataSet(__C_sk_ref) 41 | else: 42 | self.sk_contrast_dataset = None 43 | 44 | self.dataset_eval = None 45 | if __C.EVAL_EVERY_EPOCH: 46 | __C_eval = copy.deepcopy(__C) 47 | setattr(__C_eval, 'RUN_MODE', 'val') 48 | setattr(__C_eval, 'NOVEL', 'get_ids') # for validation, we just need the ids to compute accuracy 49 | 50 | print('Loading validation set for per-epoch evaluation ........') 51 | self.dataset_eval = DataSet(__C_eval) 52 | 53 | def train(self, dataset, refdataset=None, sk_contdataset=None, dataset_eval=None): 54 | # Obtain needed information 55 | data_size = dataset.data_size 56 | token_size = dataset.token_size 57 | ans_size = dataset.ans_size 58 | pretrained_emb = dataset.pretrained_emb 59 | 60 | loss_fns = Losses(self.__C) 61 | 62 | # Define the model 63 | net = PointNet( 64 | self.__C, 65 | pretrained_emb, 66 | token_size, 67 | ans_size) 68 | 69 | print(net) 70 | 71 | net.cuda() 72 | net.train() 73 | 74 | # Define the multi-gpu training if needed 75 | if self.__C.N_GPU > 1: 76 | net = nn.DataParallel(net, device_ids=self.__C.DEVICES) 77 | 78 | # Define the binary cross entropy loss 79 | loss_fn = torch.nn.BCELoss(reduction='sum').cuda() 80 | 81 | # Load checkpoint if resume training 82 | if self.__C.RESUME: 83 | print('========== Resume training ==========') 84 | 85 | if self.__C.CKPT_PATH is not None: 86 | print('Warning: you are now using CKPT_PATH args, ' 87 | 'CKPT_VERSION and CKPT_EPOCH will not work') 88 | 89 | path = self.__C.CKPT_PATH 90 | else: 91 | path = self.__C.CKPTS_PATH + \ 92 | 'ckpt_' + self.__C.CKPT_VERSION + \ 93 | '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl' 94 | 95 | # Load the network parameters 96 | print('Loading ckpt {}'.format(path)) 97 | ckpt = torch.load(path) 98 | print('Finish!') 99 | net.load_state_dict(ckpt['state_dict']) 
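            # Descriptive note: the checkpoint dict saved at the end of each epoch
            # (see the torch.save call further down) contains 'state_dict',
            # 'optimizer' and 'lr_base'; they are restored here and just below.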
100 | 101 | # Load the optimizer paramters 102 | optim = get_optim(self.__C, net, data_size, ckpt['lr_base']) 103 | optim._step = int(data_size / self.__C.BATCH_SIZE * self.__C.CKPT_EPOCH) 104 | optim.optimizer.load_state_dict(ckpt['optimizer']) 105 | 106 | start_epoch = self.__C.CKPT_EPOCH 107 | 108 | else: 109 | if ('ckpt_' + self.__C.VERSION) in os.listdir(self.__C.CKPTS_PATH): 110 | shutil.rmtree(self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION) 111 | 112 | os.mkdir(self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION) 113 | 114 | optim = get_optim(self.__C, net, data_size) 115 | start_epoch = 0 116 | 117 | loss_sum = 0 118 | 119 | # Define multi-thread dataloader 120 | if self.__C.SHUFFLE_MODE in ['external']: 121 | dataloader = Data.DataLoader( 122 | dataset, 123 | batch_size=self.__C.BATCH_SIZE, 124 | shuffle=False, 125 | num_workers=self.__C.NUM_WORKERS, 126 | pin_memory=self.__C.PIN_MEM, 127 | drop_last=True 128 | ) 129 | else: 130 | dataloader = Data.DataLoader( 131 | dataset, 132 | batch_size=self.__C.BATCH_SIZE, 133 | shuffle=True, 134 | num_workers=self.__C.NUM_WORKERS, 135 | pin_memory=self.__C.PIN_MEM, 136 | drop_last=True 137 | ) 138 | 139 | if self.__C.USE_GROUNDING: 140 | refsetloader = Data.DataLoader( 141 | refdataset, 142 | batch_size=self.__C.BATCH_SIZE, 143 | shuffle=True, 144 | num_workers=self.__C.NUM_WORKERS, 145 | pin_memory=self.__C.PIN_MEM, 146 | drop_last=True, 147 | collate_fn=refset_collate 148 | ) 149 | 150 | refsetloader_iter = iter(refsetloader) 151 | 152 | if self.__C.SKILL_CONT_LOSS: 153 | sk_contloader = Data.DataLoader( 154 | sk_contdataset, 155 | batch_size=self.__C.BATCH_SIZE // 4, 156 | shuffle=True, 157 | num_workers=self.__C.NUM_WORKERS, 158 | pin_memory=self.__C.PIN_MEM, 159 | drop_last=True, 160 | collate_fn=refset_collate 161 | ) 162 | sk_contloader_iter = iter(sk_contloader) 163 | 164 | # Training script 165 | for epoch in range(start_epoch, self.__C.MAX_EPOCH): 166 | 167 | # Save log information 168 | logfile = open( 169 | self.__C.LOG_PATH + 170 | 'log_run_' + self.__C.VERSION + '.txt', 171 | 'a+' 172 | ) 173 | logfile.write( 174 | 'nowTime: ' + 175 | datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + 176 | '\n' 177 | ) 178 | logfile.close() 179 | 180 | # Learning Rate Decay 181 | if epoch in self.__C.LR_DECAY_LIST: 182 | adjust_lr(optim, self.__C.LR_DECAY_R) 183 | 184 | # Externally shuffle 185 | if self.__C.SHUFFLE_MODE == 'external': 186 | shuffle_list(dataset.ans_list) 187 | 188 | time_start = time.time() 189 | # Iteration 190 | for step, ( 191 | img_feat_iter, 192 | ques_ix_iter, 193 | ans_iter 194 | ) in enumerate(dataloader): 195 | 196 | optim.zero_grad() 197 | 198 | img_feat_iter = img_feat_iter.cuda() 199 | ques_ix_iter = ques_ix_iter.cuda() 200 | ans_iter = ans_iter.cuda() 201 | 202 | for accu_step in range(self.__C.GRAD_ACCU_STEPS): 203 | 204 | sub_img_feat_iter = \ 205 | img_feat_iter[accu_step * self.__C.SUB_BATCH_SIZE: 206 | (accu_step + 1) * self.__C.SUB_BATCH_SIZE] 207 | sub_ques_ix_iter = \ 208 | ques_ix_iter[accu_step * self.__C.SUB_BATCH_SIZE: 209 | (accu_step + 1) * self.__C.SUB_BATCH_SIZE] 210 | sub_ans_iter = \ 211 | ans_iter[accu_step * self.__C.SUB_BATCH_SIZE: 212 | (accu_step + 1) * self.__C.SUB_BATCH_SIZE] 213 | 214 | output = net(sub_img_feat_iter, sub_ques_ix_iter) 215 | pred = output[0] 216 | 217 | loss = loss_fn(pred, sub_ans_iter) 218 | 219 | # only mean-reduction needs be divided by grad_accu_steps 220 | # removing this line wouldn't change our results because of the Adam optimizer, 221 | # but would be 
necessary if you use SGD optimizer. 222 | # loss /= self.__C.GRAD_ACCU_STEPS 223 | loss.backward() 224 | loss_sum += loss.cpu().data.numpy() * self.__C.GRAD_ACCU_STEPS 225 | 226 | if self.__C.VERBOSE: 227 | if dataset_eval is not None: 228 | mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT['val'] 229 | else: 230 | mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT['test'] 231 | 232 | print("\r[version %s][epoch %2d][step %4d/%4d][%s] loss: %.4f, lr: %.2e" % ( 233 | self.__C.VERSION, 234 | epoch + 1, 235 | step, 236 | int(data_size / self.__C.BATCH_SIZE), 237 | mode_str, 238 | loss.cpu().data.numpy() / self.__C.SUB_BATCH_SIZE, 239 | optim._rate 240 | ), end=' ') 241 | 242 | # Gradient norm clipping 243 | if self.__C.GRAD_NORM_CLIP > 0: 244 | nn.utils.clip_grad_norm_( 245 | net.parameters(), 246 | self.__C.GRAD_NORM_CLIP 247 | ) 248 | 249 | optim.step() 250 | 251 | if self.__C.USE_GROUNDING and random.random() <= self.__C.GROUNDING_PROB: 252 | optim.zero_grad() 253 | 254 | try: 255 | point_batch = next(refsetloader_iter) 256 | except StopIteration: 257 | refsetloader_iter = iter(refsetloader) 258 | point_batch = next(refsetloader_iter) 259 | 260 | target, refs, mask_tok_pos, point_positions, qid_data = refset_tocuda(point_batch) 261 | 262 | # -------------- Forward pass: target and refs ---------------- # 263 | output = net(target[0], target[1]) 264 | target_vqa_output, target_hiddens = output[0], output[1] 265 | 266 | # -------------- Compute loss of pointing -------------- # 267 | refs_vqa_output, refs_hiddens, refs_masks = [], [], [] 268 | for i in range(len(refs)): 269 | ref_output = net(refs[i][0], refs[i][1]) 270 | r_output, r_hidden, r_hidden_mask = ref_output[0], ref_output[1], ref_output[2] 271 | refs_vqa_output.append(r_output) 272 | refs_hiddens.append(r_hidden) 273 | refs_masks.append(r_hidden_mask.squeeze(2).squeeze(1)) 274 | 275 | loss_pointing = loss_fns.pointing_loss( 276 | target_hiddens, refs_hiddens, refs_masks, mask_tok_pos, point_positions 277 | ) 278 | 279 | if self.__C.SKILL_CONT_LOSS: 280 | try: 281 | sk_cont_batch = next(sk_contloader_iter) 282 | except StopIteration: 283 | sk_contloader_iter = iter(sk_contloader) 284 | sk_cont_batch = next(sk_contloader_iter) 285 | 286 | target, refs, _, point_positions, _ = refset_tocuda(sk_cont_batch) 287 | 288 | output = net(target[0], target[1]) 289 | 290 | if self.__C.SKILL_POOL == 'cls': 291 | target_tokens = output[-1][1] 292 | target_mask = None 293 | else: 294 | target_tokens = output[1] 295 | target_mask = output[2].squeeze(2).squeeze(1) 296 | 297 | 298 | # -------------- Compute skill loss -------------- # 299 | 300 | refs_tokens = [] 301 | refs_masks = [] 302 | for i in range(len(refs)): 303 | ref_output = net(refs[i][0], refs[i][1]) 304 | 305 | if self.__C.SKILL_POOL == 'cls': 306 | r_token = ref_output[-1][1] 307 | r_token_mask = None 308 | else: 309 | r_token = ref_output[1] 310 | r_token_mask = ref_output[2].squeeze(2).squeeze(1) 311 | 312 | refs_tokens.append(r_token) 313 | refs_masks.append(r_token_mask) 314 | 315 | loss_sk_cont = loss_fns.skill_contrast_loss( 316 | target_tokens, 317 | target_mask, 318 | refs_tokens, 319 | refs_masks, 320 | point_positions, 321 | ) 322 | 323 | loss_pointing += loss_sk_cont 324 | 325 | loss_pointing.backward() 326 | optim.step() 327 | 328 | time_end = time.time() 329 | print('Finished in {}s'.format(int(time_end-time_start))) 330 | 331 | epoch_finish = epoch + 1 332 | 333 | # Save checkpoint 334 | state = { 335 | 'state_dict': net.state_dict(), 336 | 'optimizer': 
optim.optimizer.state_dict(), 337 | 'lr_base': optim.lr_base 338 | } 339 | torch.save( 340 | state, 341 | self.__C.CKPTS_PATH + 342 | 'ckpt_' + self.__C.VERSION + 343 | # '/epoch' + str(epoch_finish) + 344 | '/last_epoch.pkl' 345 | ) 346 | 347 | # Logging 348 | logfile = open( 349 | self.__C.LOG_PATH + 350 | 'log_run_' + self.__C.VERSION + '.txt', 351 | 'a+' 352 | ) 353 | logfile.write( 354 | 'epoch = ' + str(epoch_finish) + 355 | ' loss = ' + str(loss_sum / data_size) + 356 | '\n' + 357 | 'lr = ' + str(optim._rate) + 358 | '\n\n' 359 | ) 360 | logfile.close() 361 | 362 | # Eval after every epoch 363 | if dataset_eval is not None: 364 | with torch.no_grad(): 365 | self.eval( 366 | dataset_eval, 367 | state_dict=net.state_dict(), 368 | valid=True 369 | ) 370 | 371 | loss_sum = 0 372 | 373 | # Evaluation 374 | def eval(self, dataset, state_dict=None, valid=False): 375 | 376 | # Load parameters 377 | if self.__C.CKPT_PATH is not None: 378 | print('Warning: you are now using CKPT_PATH args, ' 379 | 'CKPT_VERSION and CKPT_EPOCH will not work') 380 | 381 | path = self.__C.CKPT_PATH 382 | else: 383 | path = self.__C.CKPTS_PATH + \ 384 | 'ckpt_' + self.__C.CKPT_VERSION + \ 385 | '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl' 386 | 387 | val_ckpt_flag = False 388 | if state_dict is None: 389 | val_ckpt_flag = True 390 | print('Loading ckpt {}'.format(path)) 391 | state_dict = torch.load(path)['state_dict'] 392 | print('Finish!') 393 | 394 | # Store the prediction list 395 | qid_list = [ques['question_id'] for ques in dataset.ques_list] 396 | ans_ix_list = [] 397 | pred_list = [] 398 | 399 | data_size = dataset.data_size 400 | token_size = dataset.token_size 401 | ans_size = dataset.ans_size 402 | pretrained_emb = dataset.pretrained_emb 403 | novel_ques_ids = dataset.novel_ques_ids 404 | 405 | # Define the model 406 | net = PointNet( 407 | self.__C, 408 | pretrained_emb, 409 | token_size, 410 | ans_size) 411 | 412 | net.cuda() 413 | net.eval() 414 | 415 | if self.__C.N_GPU > 1: 416 | net = nn.DataParallel(net, device_ids=self.__C.DEVICES) 417 | 418 | net.load_state_dict(state_dict) 419 | 420 | dataloader = Data.DataLoader( 421 | dataset, 422 | batch_size=self.__C.EVAL_BATCH_SIZE, 423 | shuffle=False, 424 | num_workers=self.__C.NUM_WORKERS, 425 | pin_memory=True 426 | ) 427 | 428 | for step, ( 429 | img_feat_iter, 430 | ques_ix_iter, 431 | ans_iter 432 | ) in enumerate(dataloader): 433 | print("\rEvaluation: [step %4d/%4d]" % ( 434 | step, 435 | int(data_size / self.__C.EVAL_BATCH_SIZE), 436 | ), end=' ') 437 | 438 | img_feat_iter = img_feat_iter.cuda() 439 | ques_ix_iter = ques_ix_iter.cuda() 440 | 441 | output = net(img_feat_iter, ques_ix_iter) 442 | pred = output[0] 443 | pred_np = pred.cpu().data.numpy() 444 | pred_argmax = np.argmax(pred_np, axis=1) 445 | 446 | # Save the answer index 447 | if pred_argmax.shape[0] != self.__C.EVAL_BATCH_SIZE: 448 | pred_argmax = np.pad( 449 | pred_argmax, 450 | (0, self.__C.EVAL_BATCH_SIZE - pred_argmax.shape[0]), 451 | mode='constant', 452 | constant_values=-1 453 | ) 454 | 455 | ans_ix_list.append(pred_argmax) 456 | 457 | # Save the whole prediction vector 458 | if self.__C.TEST_SAVE_PRED: 459 | if pred_np.shape[0] != self.__C.EVAL_BATCH_SIZE: 460 | pred_np = np.pad( 461 | pred_np, 462 | ((0, self.__C.EVAL_BATCH_SIZE - pred_np.shape[0]), (0, 0)), 463 | mode='constant', 464 | constant_values=-1 465 | ) 466 | 467 | pred_list.append(pred_np) 468 | 469 | print('') 470 | ans_ix_list = np.array(ans_ix_list).reshape(-1) 471 | 472 | result = [{ 473 | 'answer': 
dataset.ix_to_ans[str(ans_ix_list[qix])], # ix_to_ans(load with json) keys are type of string 474 | 'question_id': int(qid_list[qix]) 475 | } for qix in range(qid_list.__len__())] 476 | 477 | # Write the results to result file 478 | if valid: 479 | if val_ckpt_flag: 480 | result_eval_file = \ 481 | self.__C.CACHE_PATH + \ 482 | 'result_run_' + self.__C.CKPT_VERSION + \ 483 | '.json' 484 | else: 485 | result_eval_file = \ 486 | self.__C.CACHE_PATH + \ 487 | 'result_run_' + self.__C.VERSION + \ 488 | '.json' 489 | 490 | else: 491 | if self.__C.CKPT_PATH is not None: 492 | result_eval_file = \ 493 | self.__C.RESULT_PATH + \ 494 | 'result_run_' + self.__C.CKPT_VERSION + \ 495 | '.json' 496 | else: 497 | result_eval_file = \ 498 | self.__C.RESULT_PATH + \ 499 | 'result_run_' + self.__C.CKPT_VERSION + \ 500 | '_epoch' + str(self.__C.CKPT_EPOCH) + \ 501 | '.json' 502 | 503 | print('Save the result to file: {}'.format(result_eval_file)) 504 | 505 | json.dump(result, open(result_eval_file, 'w')) 506 | 507 | # Save the whole prediction vector 508 | if self.__C.TEST_SAVE_PRED: 509 | 510 | if self.__C.CKPT_PATH is not None: 511 | ensemble_file = \ 512 | self.__C.PRED_PATH + \ 513 | 'result_run_' + self.__C.CKPT_VERSION + \ 514 | '.json' 515 | else: 516 | ensemble_file = \ 517 | self.__C.PRED_PATH + \ 518 | 'result_run_' + self.__C.CKPT_VERSION + \ 519 | '_epoch' + str(self.__C.CKPT_EPOCH) + \ 520 | '.json' 521 | 522 | print('Save the prediction vector to file: {}'.format(ensemble_file)) 523 | 524 | pred_list = np.array(pred_list).reshape(-1, ans_size) 525 | result_pred = [{ 526 | 'pred': pred_list[qix], 527 | 'question_id': int(qid_list[qix]) 528 | } for qix in range(qid_list.__len__())] 529 | 530 | pickle.dump(result_pred, open(ensemble_file, 'wb+'), protocol=-1) 531 | 532 | # Run validation script 533 | if valid: 534 | # create vqa object and vqaRes object 535 | ques_file_path = self.__C.QUESTION_PATH['val'] 536 | ans_file_path = self.__C.ANSWER_PATH['val'] 537 | 538 | vqa = VQA(ans_file_path, ques_file_path) 539 | vqaRes = vqa.loadRes(result_eval_file, ques_file_path) 540 | 541 | # create vqaEval object by taking vqa and vqaRes 542 | vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 543 | 544 | # evaluate results 545 | """ 546 | If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function 547 | By default it uses all the question ids in annotation file 548 | """ 549 | vqaEval.evaluate() 550 | 551 | # print accuracies 552 | print("\n") 553 | print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 554 | print("Per Answer Type Accuracy is the following:") 555 | for ansType in vqaEval.accuracy['perAnswerType']: 556 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 557 | print("\n") 558 | 559 | if type(novel_ques_ids) is list and len(novel_ques_ids): 560 | # evaluate results on novel subset 561 | 562 | vqaEval.evaluate(novel_ques_ids) 563 | 564 | # print accuracies 565 | print("\n") 566 | print("Novel Subset Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 567 | print("Per Answer Type Accuracy is the following:") 568 | for ansType in vqaEval.accuracy['perAnswerType']: 569 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 570 | print("\n") 571 | 572 | if val_ckpt_flag: 573 | print('Write to log file: {}'.format( 574 | self.__C.LOG_PATH + 575 | 'log_run_' + self.__C.CKPT_VERSION + '.txt', 576 | 'a+') 577 | ) 
578 | 579 | logfile = open( 580 | self.__C.LOG_PATH + 581 | 'log_run_' + self.__C.CKPT_VERSION + '.txt', 582 | 'a+' 583 | ) 584 | 585 | else: 586 | print('Write to log file: {}'.format( 587 | self.__C.LOG_PATH + 588 | 'log_run_' + self.__C.VERSION + '.txt', 589 | 'a+') 590 | ) 591 | 592 | logfile = open( 593 | self.__C.LOG_PATH + 594 | 'log_run_' + self.__C.VERSION + '.txt', 595 | 'a+' 596 | ) 597 | 598 | logfile.write("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 599 | for ansType in vqaEval.accuracy['perAnswerType']: 600 | logfile.write("%s : %.02f " % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 601 | logfile.write("\n\n") 602 | logfile.close() 603 | 604 | def run(self, run_mode): 605 | if run_mode == 'train': 606 | self.empty_log(self.__C.VERSION) 607 | 608 | refdata = None 609 | if self.__C.USE_GROUNDING: 610 | refdata = self.refdataset 611 | 612 | skcontdata = None 613 | if self.__C.SKILL_CONT_LOSS: 614 | skcontdata = self.sk_contrast_dataset 615 | 616 | self.train( 617 | self.dataset, 618 | refdataset=refdata, 619 | sk_contdataset=skcontdata, 620 | dataset_eval=self.dataset_eval 621 | ) 622 | 623 | elif run_mode == 'val': 624 | with torch.no_grad(): 625 | self.eval(self.dataset, valid=True) 626 | 627 | elif run_mode == 'test': 628 | with torch.no_grad(): 629 | self.eval(self.dataset) 630 | 631 | else: 632 | exit(-1) 633 | 634 | def empty_log(self, version): 635 | print('Initializing log file ........') 636 | if (os.path.exists(self.__C.LOG_PATH + 'log_run_' + version + '.txt')): 637 | os.remove(self.__C.LOG_PATH + 'log_run_' + version + '.txt') 638 | print('Finished!') 639 | print('') 640 | 641 | 642 | 643 | 644 | --------------------------------------------------------------------------------
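Note on the evaluation loop above: the final batch's argmax predictions are padded to EVAL_BATCH_SIZE with -1 only so that every per-step array stacks to the same shape; the padded entries never reach the result file, because answers are gathered just for the first len(qid_list) questions. A minimal standalone sketch of that pad-stack-trim behaviour (the batch size of 4 and question count of 6 are made-up values for illustration, not taken from the configs):

import numpy as np

EVAL_BATCH_SIZE = 4   # assumed value, for illustration only
num_questions = 6     # stands in for len(qid_list): one full batch of 4, then a partial batch of 2

ans_ix_list = []
for step_preds in (np.array([3, 1, 0, 2]), np.array([5, 4])):  # fake per-step argmax outputs
    if step_preds.shape[0] != EVAL_BATCH_SIZE:
        # pad the short last batch with -1, mirroring the padding in eval()
        step_preds = np.pad(
            step_preds,
            (0, EVAL_BATCH_SIZE - step_preds.shape[0]),
            mode='constant',
            constant_values=-1
        )
    ans_ix_list.append(step_preds)

flat = np.array(ans_ix_list).reshape(-1)
print(flat)                   # [ 3  1  0  2  5  4 -1 -1]
print(flat[:num_questions])   # only the first len(qid_list) entries are indexed: [3 1 0 2 5 4]

The same idea applies to the full prediction vectors saved under TEST_SAVE_PRED: the 2-D pad keeps every step's array at (EVAL_BATCH_SIZE, ans_size), and the trailing -1 rows are ignored when results are written per question id.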