├── CAS_scripts ├── CAS_dataset_vqacp.py └── Candidate_Answers_Selector.py ├── LMH_lxmert_model.py ├── LMH_vqa_debias_loss_functions.py ├── QTD_model.py ├── README.md ├── SAR_concatenate_dataset_vqacp.py ├── SAR_main.py ├── SAR_replace_dataset_vqacp.py ├── SAR_test.py ├── SAR_train.py ├── attention.py ├── classifier.py ├── comput_score.py ├── data ├── create_dictionary.py ├── download_data.sh ├── preprocess_image.py ├── preprocess_text.py └── utils.py ├── data4VE ├── offline-QTD_model.pth ├── test_dataset4VE_demo.json └── train_dataset4VE_demo.json ├── fc.py ├── language_model.py ├── lxmert_model.py ├── model.jpg ├── opts_SAR.py ├── saved_models_cp2 └── log.txt └── utils.py /CAS_scripts/CAS_dataset_vqacp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import json 8 | import _pickle as cPickle 9 | import numpy as np 10 | import utils 11 | import warnings 12 | 13 | with warnings.catch_warnings(): 14 | warnings.filterwarnings("ignore", category=FutureWarning) 15 | import h5py 16 | from xml.etree.ElementTree import parse 17 | import torch 18 | from torch.utils.data import Dataset 19 | import zarr 20 | import random 21 | COUNTING_ONLY = False 22 | 23 | class VQAFeatureDataset(Dataset): 24 | def __init__(self, name, dictionary, dataroot, image_dataroot, ratio, adaptive=False): 25 | super(VQAFeatureDataset, self).__init__() 26 | assert name in ['train', 'test'] 27 | ans2label_path = os.path.join(dataroot, 'cache', 'train_test_ans2label.pkl') 28 | label2ans_path = os.path.join(dataroot, 'cache', 'train_test_label2ans.pkl') 29 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 30 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 31 | self.num_ans_candidates = len(self.ans2label) 32 | 33 | self.dictionary = dictionary 34 | self.adaptive = adaptive 35 | 36 | print('loading image features and bounding boxes') 37 | # Load image features and bounding boxes 38 | self.features = zarr.open(os.path.join(image_dataroot, 'trainval.zarr'), mode='r') 39 | self.spatials = zarr.open(os.path.join(image_dataroot, 'trainval_boxes.zarr'), mode='r') 40 | 41 | 42 | 43 | self.v_dim = self.features[list(self.features.keys())[1]].shape[1] 44 | self.s_dim = self.spatials[list(self.spatials.keys())[1]].shape[1] 45 | print('loading image features and bounding boxes done!') 46 | 47 | self.entries = _load_dataset(dataroot, name, self.label2ans, ratio) 48 | self.tokenize() 49 | self.tensorize() 50 | 51 | def tokenize(self, max_length=14): 52 | """Tokenizes the questions. 53 | 54 | This will add q_token in each entry of the dataset. 
55 | -1 represent nil, and should be treated as padding_idx in embedding 56 | """ 57 | for entry in self.entries: 58 | tokens = self.dictionary.tokenize(entry['question'], False) 59 | tokens = tokens[:max_length] 60 | if len(tokens) < max_length: 61 | # Note here we pad in front of the sentence 62 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 63 | tokens = tokens + padding 64 | utils.assert_eq(len(tokens), max_length) 65 | entry['q_token'] = tokens 66 | 67 | def tensorize(self): 68 | for entry in self.entries: 69 | question = torch.from_numpy(np.array(entry['q_token'])) 70 | entry['q_token'] = question 71 | answer = entry['answer'] 72 | if None != answer: 73 | labels = np.array(answer['labels']) 74 | scores = np.array(answer['scores'], dtype=np.float32) 75 | if len(labels): 76 | 77 | labels = torch.from_numpy(labels) 78 | scores = torch.from_numpy(scores) 79 | entry['answer']['labels'] = labels 80 | entry['answer']['scores'] = scores 81 | else: 82 | entry['answer']['labels'] = None 83 | entry['answer']['scores'] = None 84 | def __getitem__(self, index): 85 | entry = self.entries[index] 86 | if not self.adaptive: 87 | features = torch.from_numpy(np.array(self.features[entry['image']])) 88 | spatials = torch.from_numpy(np.array(self.spatials[entry['image']])) 89 | 90 | question = entry['q_token'] 91 | question_id = entry['question_id'] 92 | image_id = entry['image_id'] 93 | answer = entry['answer'] 94 | 95 | if None != answer: 96 | labels = answer['labels'] 97 | scores = answer['scores'] 98 | target = torch.zeros(self.num_ans_candidates) 99 | if labels is not None: 100 | target.scatter_(0, labels, scores) 101 | return features, spatials, question, target, question_id, image_id 102 | else: 103 | return features, spatials, question, question_id, image_id 104 | 105 | def __len__(self): 106 | return len(self.entries) 107 | 108 | 109 | -------------------------------------------------------------------------------- /CAS_scripts/Candidate_Answers_Selector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import torch.nn.functional as F 4 | def compute_TopKscore_with_logits(logits, labels, n): 5 | prediction_ans_k, top_ans_ind = torch.topk(F.softmax(logits, dim=-1), k=n, dim=-1, sorted=True) 6 | logits_ind = top_ans_ind 7 | one_hots = torch.zeros(*labels.size()).cuda() 8 | one_hots.scatter_(1, logits_ind.view(-1, n), 1) 9 | scores = (one_hots * labels) 10 | scores = torch.max(scores, 1)[0].data 11 | topN_scores = [] 12 | for i in range(len(labels)): 13 | topN_scores.append(labels[i][logits_ind[i]]) 14 | return scores, top_ans_ind, topN_scores 15 | 16 | 17 | 18 | @torch.no_grad() 19 | def evaluate(model, dataloader): 20 | ''' 21 | When setting dataloader == train_dataloader, we can get the training set of 22 | the Datasets for Answer Re-ranking module. 23 | So does it for the test set. 
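Illustrative usage (a sketch; assumes `model` is a trained CAS such as SSL and the dataloaders come from the SSL codebase): calling `score, upper_bound, entropy, json_str = evaluate(model, test_loader)` writes the top-20 candidate file (e.g. `./TestSet_top20_condidates.json`) that serves as input to the Answer Re-ranking module; `test_loader` here is just an illustrative name for the corresponding SSL dataloader.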
24 | ''' 25 | score = 0 26 | score20 = 0 27 | upper_bound = 0 28 | num_data = 0 29 | entropy = 0 30 | topN_dict_list = [] 31 | for i, (v, b, q, a, q_id, v_id) in enumerate(dataloader): 32 | v = v.cuda() 33 | b = b.cuda() 34 | q = q.cuda() 35 | 36 | q_id = q_id.cuda() 37 | 38 | v_id = v_id.cuda() 39 | 40 | pred, att = model(q, v, False) 41 | batch_score = compute_score_with_logits(pred, a.cuda()).sum() 42 | score += batch_score.item() 43 | batch_score20, top20, top20_scores = compute_TopKscore_with_logits(pred, a.cuda(), 20) 44 | batch_score20 = batch_score20.sum() 45 | score20 += batch_score20.item() 46 | for i in range(len(q)): 47 | topN_dict = {} 48 | topN_dict['question_id'] = q_id[i].cpu().numpy().tolist() 49 | topN_dict['image_id'] = v_id[i].cpu().numpy().tolist() 50 | topN_dict['top20'] = top20[i].cpu().numpy().tolist() 51 | topN_dict['top20_scores'] = top20_scores[i].cpu().numpy().tolist() 52 | topN_dict_list.append(topN_dict) 53 | 54 | 55 | upper_bound += (a.max(1)[0]).sum().item() 56 | num_data += pred.size(0) 57 | 58 | entropy += calc_entropy(att.data) 59 | 60 | json_str = json.dumps(topN_dict_list, indent=4) 61 | ################################################################## 62 | ### To build the Dataset for the Answer Re-ranking module ### 63 | ### based on Visual Entailment. ### 64 | ### (build training and test sets, respectively) ### 65 | ################################################################## 66 | #with open('./TrainingSet_top20_condidates.json', 'w') as json_file: 67 | with open('./TestSet_top20_condidates.json', 'w') as json_file: 68 | json_file.write(json_str) 69 | 70 | score = score / len(dataloader.dataset) 71 | score20 = score20 / len(dataloader.dataset) 72 | upper_bound = upper_bound / len(dataloader.dataset) 73 | 74 | if entropy is not None: 75 | entropy = entropy / len(dataloader.dataset) 76 | print("score:",score) 77 | print("score20",score20) 78 | return score, upper_bound, entropy, json_str 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /LMH_lxmert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from transformers import LxmertTokenizer, LxmertModel 5 | import numpy as np 6 | from language_model import WordEmbedding, QuestionEmbedding 7 | from classifier import SimpleClassifier, PaperClassifier 8 | 9 | from torch.nn import functional as F 10 | from fc import FCNet, GTH 11 | from attention import Att_0, Att_1, Att_2, Att_3, Att_P, Att_PD, Att_3S 12 | import torch 13 | import random 14 | from LMH_vqa_debias_loss_functions import LearnedMixin 15 | class Model(nn.Module): 16 | def __init__(self, opt): 17 | super(Model, self).__init__() 18 | self.opt = opt 19 | self.model = LxmertModel.from_pretrained('unc-nlp/lxmert-base-uncased', return_dict=True) 20 | self.model = nn.DataParallel(self.model) 21 | self.candi_ans_num = opt.train_candi_ans_num 22 | self.batchsize = opt.batch_size 23 | self.Linear_layer = nn.Linear(768, 1) 24 | norm = opt.norm 25 | activation = opt.activation 26 | dropC = opt.dropC 27 | self.debias_loss_fn = LearnedMixin(0.36) 28 | self.classifier = SimpleClassifier(in_dim=768, hid_dim=2 * 768, out_dim=1, 29 | dropout=dropC, norm=norm, act=activation) 30 | 31 | def forward(self, qa_text, v, b, epo, name, bias, labels): 32 | """ 33 | qa_text (btachsize, candi_ans_num, max_length) 34 | v (batchsize, obj_num, v_dim) 35 | b (batchsize, obj_num, b_dim) 36 | 37 | 
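Internally, qa_text is flattened to (batchsize * candi_ans_num, max_length) and v/b are repeated to match, so LXMERT scores every (question+candidate answer, image) pair independently; the pooled outputs are classified and reshaped back to (batchsize, candi_ans_num) logits.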
return: logits 38 | """ 39 | qa_text = qa_text.cuda() 40 | v= v.cuda() 41 | b= b.cuda() 42 | bias = bias.cuda() 43 | 44 | if name == 'train': 45 | self.candi_ans_num = self.opt.train_candi_ans_num 46 | elif name == 'test': 47 | self.candi_ans_num = self.opt.test_candi_ans_num 48 | qa_text_reshape = qa_text.reshape(qa_text.shape[0] * self.candi_ans_num, -1) 49 | 50 | v_repeat = v.repeat(1, self.candi_ans_num, 1) 51 | v_reshape = v_repeat.reshape( v.shape[0] * self.candi_ans_num,v.shape[1], v.shape[2] ) 52 | b_repeat = b.repeat(1, self.candi_ans_num , 1) 53 | b_reshape = b_repeat.reshape( b.shape[0] * self.candi_ans_num,b.shape[1], b.shape[2] ) 54 | 55 | 56 | outputs = self.model(qa_text_reshape, v_reshape, b_reshape) 57 | pool_out = outputs.pooled_output 58 | 59 | logits = self.classifier(pool_out) 60 | logits_reshape = logits.reshape(-1, self.candi_ans_num) 61 | pool_out_reshape = pool_out.reshape(v.shape[0], self.candi_ans_num, -1) 62 | 63 | if labels is not None: 64 | loss = self.debias_loss_fn(torch.mean(pool_out_reshape,dim=1,keepdim=False), logits_reshape,bias, labels) 65 | else: 66 | loss = None 67 | return logits_reshape, loss 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /LMH_vqa_debias_loss_functions.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict, Counter 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import torch 7 | import inspect 8 | 9 | 10 | def convert_sigmoid_logits_to_binary_logprobs(logits): 11 | """computes log(sigmoid(logits)), log(1-sigmoid(logits))""" 12 | log_prob = -F.softplus(-logits) 13 | log_one_minus_prob = -logits + log_prob 14 | return log_prob, log_one_minus_prob 15 | 16 | 17 | def elementwise_logsumexp(a, b): 18 | """computes log(exp(x) + exp(b))""" 19 | return torch.max(a, b) + torch.log1p(torch.exp(-torch.abs(a - b))) 20 | 21 | 22 | def renormalize_binary_logits(a, b): 23 | """Normalize so exp(a) + exp(b) == 1""" 24 | norm = elementwise_logsumexp(a, b) 25 | return a - norm, b - norm 26 | 27 | 28 | class DebiasLossFn(nn.Module): 29 | """General API for our loss functions""" 30 | 31 | def forward(self, hidden, logits, bias, labels): 32 | """ 33 | :param hidden: [batch, n_hidden] hidden features from the last layer in the model 34 | :param logits: [batch, n_answers_options] sigmoid logits for each answer option 35 | :param bias: [batch, n_answers_options] 36 | bias probabilities for each answer option between 0 and 1 37 | :param labels: [batch, n_answers_options] 38 | scores for each answer option, between 0 and 1 39 | :return: Scalar loss 40 | """ 41 | raise NotImplementedError() 42 | 43 | def to_json(self): 44 | """Get a json representation of this loss function. 
45 | 46 | We construct this by looking up the __init__ args 47 | """ 48 | cls = self.__class__ 49 | init = cls.__init__ 50 | if init is object.__init__: 51 | return [] # No init args 52 | 53 | init_signature = inspect.getargspec(init) 54 | if init_signature.varargs is not None: 55 | raise NotImplementedError("varags not supported") 56 | if init_signature.keywords is not None: 57 | raise NotImplementedError("keywords not supported") 58 | args = [x for x in init_signature.args if x != "self"] 59 | out = OrderedDict() 60 | out["name"] = cls.__name__ 61 | for key in args: 62 | out[key] = getattr(self, key) 63 | return out 64 | 65 | 66 | class Plain(DebiasLossFn): 67 | def forward(self, hidden, logits, bias, labels): 68 | loss = F.binary_cross_entropy_with_logits(logits, labels) 69 | loss *= labels.size(1) 70 | return loss 71 | 72 | 73 | class ReweightByInvBias(DebiasLossFn): 74 | def forward(self, hidden, logits, bias, labels): 75 | # Manually compute the binary cross entropy since the old version of torch always aggregates 76 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 77 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob) 78 | weights = (1 - bias) 79 | loss *= weights # Apply the weights 80 | return loss.sum() / weights.sum() 81 | 82 | 83 | class BiasProduct(DebiasLossFn): 84 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0): 85 | """ 86 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 87 | :param smooth_init: How to initialize `a` 88 | :param constant_smooth: Constant to add to the bias to smooth it 89 | """ 90 | super(BiasProduct, self).__init__() 91 | self.constant_smooth = constant_smooth 92 | self.smooth_init = smooth_init 93 | self.smooth = smooth 94 | if smooth: 95 | self.smooth_param = torch.nn.Parameter( 96 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 97 | else: 98 | self.smooth_param = None 99 | 100 | def forward(self, hidden, logits, bias, labels): 101 | smooth = self.constant_smooth 102 | if self.smooth: 103 | smooth += F.sigmoid(self.smooth_param) 104 | 105 | # Convert the bias into log-space, with a factor for both the 106 | # binary outputs for each answer option 107 | bias_lp = torch.log(bias + smooth) 108 | bias_l_inv = torch.log1p(-bias + smooth) 109 | 110 | # Convert the the logits into log-space with the same format 111 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 112 | 113 | # Add the bias 114 | log_prob += bias_lp 115 | log_one_minus_prob += bias_l_inv 116 | 117 | # Re-normalize the factors in logspace 118 | log_prob, log_one_minus_prob = renormalize_binary_logits(log_prob, log_one_minus_prob) 119 | 120 | # Compute the binary cross entropy 121 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 122 | return loss 123 | 124 | 125 | class LearnedMixin(DebiasLossFn): 126 | def __init__(self, w, smooth=True, smooth_init=-1, constant_smooth=0.0): 127 | """ 128 | :param w: Weight of the entropy penalty 129 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 130 | :param smooth_init: How to initialize `a` 131 | :param constant_smooth: Constant to add to the bias to smooth it 132 | """ 133 | super(LearnedMixin, self).__init__() 134 | self.w = w 135 | self.smooth_init = smooth_init 136 | self.constant_smooth = constant_smooth 137 | self.bias_lin = torch.nn.Linear(1024, 1) 138 | self.bias_lin4lxmert = torch.nn.Linear(768, 1) 139 | self.smooth = smooth 140 | if self.smooth: 
141 | self.smooth_param = torch.nn.Parameter( 142 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 143 | else: 144 | self.smooth_param = None 145 | 146 | def forward(self, hidden, logits, bias, labels): 147 | factor = self.bias_lin4lxmert.forward(hidden) # [batch, 1] 148 | factor = F.softplus(factor) 149 | 150 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2] 151 | 152 | # Smooth 153 | bias += self.constant_smooth 154 | if self.smooth: 155 | soften_factor = F.sigmoid(self.smooth_param) 156 | bias = bias + soften_factor.unsqueeze(1) 157 | 158 | bias = torch.log(bias) # Convert to logspace 159 | 160 | # Scale by the factor 161 | # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2] 162 | bias = bias * factor.unsqueeze(1) 163 | 164 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 165 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2) 166 | 167 | # Add the bias in 168 | logits = bias + log_probs 169 | 170 | # Renormalize to get log probabilities 171 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1]) 172 | 173 | # Compute loss 174 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 175 | 176 | # Re-normalized version of the bias 177 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1]) 178 | bias_logprob = bias - bias_norm.unsqueeze(2) 179 | 180 | # Compute and add the entropy penalty 181 | entropy = -(torch.exp(bias_logprob) * bias_logprob).sum(2).mean() 182 | return loss + self.w * entropy 183 | -------------------------------------------------------------------------------- /QTD_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import _pickle as cPickle 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from language_model import WordEmbedding, QuestionEmbedding 7 | from classifier import SimpleClassifier, PaperClassifier 8 | from fc import FCNet, GTH 9 | from attention import Att_0, Att_1, Att_2, Att_3, Att_P, Att_PD, Att_3S 10 | import torch 11 | import random 12 | 13 | class Dictionary(object): 14 | def __init__(self, word2idx=None, idx2word=None): 15 | if word2idx is None: 16 | word2idx = {} 17 | if idx2word is None: 18 | idx2word = [] 19 | self.word2idx = word2idx 20 | self.idx2word = idx2word 21 | 22 | @property 23 | def ntoken(self): 24 | return len(self.word2idx) 25 | 26 | @property 27 | def padding_idx(self): 28 | return len(self.word2idx) 29 | 30 | def tokenize(self, sentence, add_word): 31 | sentence = sentence.lower() 32 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 33 | words = sentence.split() 34 | tokens = [] 35 | if add_word: 36 | for w in words: 37 | tokens.append(self.add_word(w)) 38 | else: 39 | for w in words: 40 | # the least frequent word (`bebe`) as UNK for Visual Genome dataset 41 | tokens.append(self.word2idx.get(w, self.padding_idx - 1)) 42 | return tokens 43 | 44 | def dump_to_file(self, path): 45 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 46 | print('dictionary dumped to %s' % path) 47 | 48 | @classmethod 49 | def load_from_file(cls, path): 50 | print('loading dictionary from %s' % path) 51 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 52 | d = cls(word2idx, idx2word) 53 | return d 54 | 55 | def add_word(self, word): 56 | if word not in self.word2idx: 57 | self.idx2word.append(word) 58 | self.word2idx[word] = len(self.idx2word) - 1 
59 | return self.word2idx[word] 60 | 61 | def __len__(self): 62 | return len(self.idx2word) 63 | 64 | 65 | 66 | 67 | class Model(nn.Module): 68 | def __init__(self, opt): 69 | super(Model, self).__init__() 70 | self.dictionary = Dictionary.load_from_file(opt.dataroot + 'dictionary.pkl') 71 | num_hid = 128 72 | activation = opt.activation 73 | dropG = opt.dropG 74 | dropW = opt.dropW 75 | dropout = opt.dropout 76 | dropL = opt.dropL 77 | norm = opt.norm 78 | dropC = opt.dropC 79 | self.opt = opt 80 | 81 | self.w_emb = WordEmbedding(opt.ntokens, emb_dim=300, dropout=dropW) 82 | self.w_emb.init_embedding(opt.dataroot + 'glove6b_init_300d.npy') 83 | self.q_emb = QuestionEmbedding(in_dim=300, num_hid=num_hid, nlayers=1, 84 | bidirect=False, dropout=dropG, rnn_type='GRU') 85 | self.q_net = FCNet([self.q_emb.num_hid, num_hid], dropout=dropL, norm=norm, act=activation) 86 | self.classifier = SimpleClassifier(in_dim=num_hid, hid_dim = num_hid//2 , out_dim= 2,#opt.test_candi_ans_num, 87 | dropout=dropC, norm=norm, act=activation) 88 | self.normal = nn.BatchNorm1d(num_hid,affine=False) 89 | 90 | def forward(self, q): 91 | q = self.tokenize(q) 92 | q = torch.from_numpy(np.array(q)) 93 | w_emb = self.w_emb(q.cuda()) 94 | q_emb = self.q_emb(w_emb) 95 | q_repr = self.q_net(q_emb) 96 | batch_size = q.size(0) 97 | logits_pos = self.classifier(q_repr) 98 | return logits_pos 99 | def tokenize(self, q_text, max_length=14): 100 | """Tokenizes the questions. 101 | 102 | This will add q_token in each entry of the dataset. 103 | -1 represent nil, and should be treated as padding_idx in embedding 104 | """ 105 | token_list = [] 106 | for q_iter in q_text: 107 | tokens = self.dictionary.tokenize(q_iter, False) 108 | tokens = tokens[:max_length] 109 | if len(tokens) < max_length: 110 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 111 | tokens = tokens + padding 112 | assert len(tokens) == max_length 113 | token_list.append(tokens) 114 | return token_list 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAR-VQA 2 | Here is the implementation of our ACL-2021 paper [Check It Again: Progressive Visual Question Answering via Visual Entailment](https://aclanthology.org/2021.acl-long.317.pdf). 3 | This repository contains code modified from [here for SAR+SSL](https://github.com/CrossmodalGroup/SSL-VQA) and [here for SAR+LMH](https://github.com/chrisc36/bottom-up-attention-vqa), many thanks! 4 | ![image](https://github.com/PhoebusSi/SAR/blob/master/model.jpg) 5 | ## Requirements 6 | * python 3.7.6 7 | * pytorch 1.5.0 8 | * zarr 9 | * tqdm 10 | * spacy 11 | * h5py 12 | 13 | ## Download and preprocess the data 14 | ```Bash 15 | cd data 16 | bash download_data.sh 17 | python preprocess_image.py --data trainval 18 | python create_dictionary.py --dataroot vqacp2/ 19 | python preprocess_text.py --dataroot vqacp2/ --version v2 20 | cd .. 21 | ``` 22 | 23 | ## Train the Candidate Answers Selector & Build the datasets for the Answer Re-ranking module 24 | * The VQA model applied as the Candidate Answer Selector (CAS) is a free choice in our framework. In this paper, we mainly use SSL as CAS. 25 | 26 | 27 | * The model training settings for CAS can be found in [SSL](https://github.com/CrossmodalGroup/SSL-VQA).
28 | 29 | 30 | * To build the Dataset for the Answer Re-ranking module based on Visual Entailment, we modified SSL's `VQAFeatureDataset()` in [dataset_vqacp.py](https://github.com/CrossmodalGroup/SSL-VQA/blob/master/dataset_vqacp.py) and `evaluate()` in [train.py](https://github.com/CrossmodalGroup/SSL-VQA/blob/master/train.py). The modified code is available in `CAS_scripts`; simply replace the corresponding class/function in [SSL](https://github.com/CrossmodalGroup/SSL-VQA). 31 | 32 | 33 | * After the Candidate Answers Selecting module, we obtain the `train_top20_candidates.json` and `test_top20_candidates.json` files as the training and test sets for the Answer Re-ranking module, respectively. Demos of the two output JSON files are provided in the `data4VE` folder: `train_dataset4VE_demo.json` and `test_dataset4VE_demo.json`. 34 | 35 | ## Build the Top20-Candidate-Answers dataset (entries) for training/testing the model of the Answer Re-ranking module 36 | If you don't want to train a CAS model (e.g. SSL) to build the datasets as described above, you can download the rebuilt top20-candidate-answers datasets (with different Question-Answer Combination strategies) from here: [C-train](https://drive.google.com/file/d/1XJ6u0111n1_36tIy7o97WtcvnvVeRNLQ/view?usp=sharing), [C-test](https://drive.google.com/file/d/1XvkwCQIMIM-YFoU4dRNa7Y-qwXKgnOqb/view?usp=sharing), [R-train](https://drive.google.com/file/d/1g3XdVyedgGpK0ZQ_bNE3V6QP0b_DyQ_7/view?usp=sharing), [R-test](https://drive.google.com/file/d/14wHKJrg8hL2ycgPMsq4j6Mz087zoEiYt/view?usp=sharing). 37 | 38 | * Put the downloaded Pickle files into the `data4VE` folder; the code will then load and rebuild them into the `entries` that are fed to the dataloader's `__getitem__()`, skipping all data preprocessing steps of the Answer Re-ranking based on Visual Entailment. 39 | * Each entry rebuilt from these Pickle files includes `image_features`, `image_spatials`, `top20_score`, `question_id`, `QA_text_ids`, `top20_label`, `answer_type`, `question_text`, `LMH_bias`, where `QA_text_ids` are the question-answer combination (R/C) ids preprocessed with the LXMERT tokenizer (a sketch of the two combination strategies is given after the training commands below). 40 | 41 | 42 | ## Training (Answer Re-ranking based on Visual Entailment) 43 | * Train Top12-SAR 44 | ```Bash 45 | CUDA_VISIBLE_DEVICES=0,1 python SAR_main.py --output saved_models_cp2/ --lp 0 --train_condi_ans_num 12 46 | ``` 47 | * Train Top20-SAR 48 | ```Bash 49 | CUDA_VISIBLE_DEVICES=0,1 python SAR_main.py --output saved_models_cp2/ --lp 0 --train_condi_ans_num 20 50 | ``` 51 | * Train Top12-SAR+SSL 52 | ```Bash 53 | CUDA_VISIBLE_DEVICES=0,1 python SAR_main.py --output saved_models_cp2/ --lp 1 --self_loss_weight 3 --train_condi_ans_num 12 54 | ``` 55 | * Train Top20-SAR+SSL 56 | ```Bash 57 | CUDA_VISIBLE_DEVICES=0,1 python SAR_main.py --output saved_models_cp2/ --lp 1 --self_loss_weight 3 --train_condi_ans_num 20 58 | ``` 59 | * Train Top12-SAR+LMH 60 | ```Bash 61 | CUDA_VISIBLE_DEVICES=0,1 python SAR_main.py --output saved_models_cp2/ --lp 2 --train_condi_ans_num 12 62 | ``` 63 | * Train Top20-SAR+LMH 64 | ```Bash 65 | CUDA_VISIBLE_DEVICES=0,1 python SAR_main.py --output saved_models_cp2/ --lp 2 --train_condi_ans_num 20 66 | ``` 67 | 68 | 69 | The function `evaluate()` in `SAR_train.py` is used to select the best model during training, without the QTD module yet. The trained QTD model is used in `SAR_test.py`, where we obtain the final test score.
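For reference, below is a minimal sketch (not part of the training pipeline) of how the two Question-Answer Combination strategies turn a question and one candidate answer into the declarative text that is tokenized into `QA_text_ids`. It mirrors the `tokenize()` logic of `SAR_replace_dataset_vqacp.py` (strategy `R`, max_length 15) and `SAR_concatenate_dataset_vqacp.py` (strategy `C`, max_length 18); the example question, question type, and answer are made up for illustration.

```Python
from transformers import LxmertTokenizer

def combine_R(question, question_type, answer):
    # Strategy R: replace the question-type prefix with the candidate answer,
    # falling back to concatenation when the prefix is not found.
    q = question.lower()
    if question_type in q:
        return q.replace(question_type, answer)[:-1]  # drop the trailing '?'
    return answer + " " + q

def combine_C(question, question_type, answer):
    # Strategy C: simply prepend the candidate answer to the question.
    return answer + " " + question.lower()

# Toy example (not taken from the dataset):
question, q_type, answer = "What color is the umbrella?", "what color is the", "red"
print(combine_R(question, q_type, answer))  # -> "red umbrella"
print(combine_C(question, q_type, answer))  # -> "red what color is the umbrella?"

# As in the dataset scripts, the combined text is run through the LXMERT tokenizer
# and padded/truncated to a fixed max_length to obtain the QA_text_ids of one candidate.
tokenizer = LxmertTokenizer.from_pretrained('unc-nlp/lxmert-base-uncased')
qa_ids = tokenizer(combine_R(question, q_type, answer))['input_ids']
```

Repeating this for all top-N candidate answers of a question yields the N QA texts that the re-ranker scores against the image.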
70 | 71 | ## Evaluation 72 | * Evaluate the trained SAR model 73 | ```Bash 74 | CUDA_VISIBLE_DEVICES=0 python SAR_test.py --checkpoint_path4test saved_models_cp2/SAR_top12_best_model.pth --output saved_models_cp2/result/ --lp 0 --QTD_N4yesno 1 --QTD_N4non_yesno 12 75 | ``` 76 | * Evaluate the trained SAR+SSL model 77 | ```Bash 78 | CUDA_VISIBLE_DEVICES=0 python SAR_test.py --checkpoint_path4test saved_models_cp2/SAR_SSL_top12_best_model.pth --output saved_models_cp2/result/ --lp 1 --QTD_N4yesno 1 --QTD_N4non_yesno 12 79 | ``` 80 | * Evaluate the trained SAR+LMH model 81 | ```Bash 82 | CUDA_VISIBLE_DEVICES=0 python SAR_test.py --checkpoint_path4test saved_models_cp2/SAR_LMH_top12_best_model.pth --output saved_models_cp2/result/ --lp 2 --QTD_N4yesno 2 --QTD_N4non_yesno 12 83 | ``` 84 | * Note that we mainly use the `R->C` Question-Answer Combination Strategy, which always achieves or rivals the best performance for SAR/SAR+SSL/SAR+LMH. Specifically, we first use strategy `R` (`SAR_replace_dataset_vqacp.py`) during training, which prevents the model from excessively focusing on the co-occurrence relation between question category and answer, and then use strategy `C` (`SAR_concatenate_dataset_vqacp.py`) at test time to introduce more information for inference. 85 | * Compute detailed accuracy for each answer type: 86 | ```bash 87 | python comput_score.py --input saved_models_cp2/result/XX.json --dataroot data/vqacp2/cache 88 | ``` 89 | ## Bugs or questions? 90 | If you have any questions related to the code or the paper, feel free to email Qingyi (`siqingyi@iie.ac.cn`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and more quickly! 91 | 92 | ## Reference 93 | If you find this code useful, please cite the following paper: 94 | ``` 95 | @inproceedings{si-etal-2021-check, 96 | title = "Check It Again:Progressive Visual Question Answering via Visual Entailment", 97 | author = "Si, Qingyi and 98 | Lin, Zheng and 99 | Zheng, Ming yu and 100 | Fu, Peng and 101 | Wang, Weiping", 102 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", 103 | month = aug, 104 | year = "2021", 105 | address = "Online", 106 | publisher = "Association for Computational Linguistics", 107 | url = "https://aclanthology.org/2021.acl-long.317", 108 | doi = "10.18653/v1/2021.acl-long.317", 109 | pages = "4101--4110", 110 | abstract = "While sophisticated neural-based models have achieved remarkable success in Visual Question Answering (VQA), these models tend to answer questions only according to superficial correlations between question and answer. Several recent approaches have been developed to address this language priors problem. However, most of them predict the correct answer according to one best output without checking the authenticity of answers. Besides, they only explore the interaction between image and question, ignoring the semantics of candidate answers. In this paper, we propose a select-and-rerank (SAR) progressive framework based on Visual Entailment. Specifically, we first select the candidate answers relevant to the question or the image, then we rerank the candidate answers by a visual entailment task, which verifies whether the image semantically entails the synthetic statement of the question and each candidate answer.
Experimental results show the effectiveness of our proposed framework, which establishes a new state-of-the-art accuracy on VQA-CP v2 with a 7.55{\%} improvement.", 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /SAR_concatenate_dataset_vqacp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import json 4 | import _pickle as cPickle 5 | import numpy as np 6 | import utils 7 | from transformers import LxmertTokenizer, LxmertModel 8 | import warnings 9 | 10 | with warnings.catch_warnings(): 11 | warnings.filterwarnings("ignore", category=FutureWarning) 12 | import h5py 13 | from xml.etree.ElementTree import parse 14 | import torch 15 | from torch.utils.data import Dataset 16 | import zarr 17 | import random 18 | import pickle 19 | COUNTING_ONLY = False 20 | 21 | 22 | def is_howmany(q, a, label2ans): 23 | if 'how many' in q.lower() or \ 24 | ('number of' in q.lower() and 'number of the' not in q.lower()) or \ 25 | 'amount of' in q.lower() or \ 26 | 'count of' in q.lower(): 27 | if a is None or answer_filter(a, label2ans): 28 | return True 29 | else: 30 | return False 31 | else: 32 | return False 33 | 34 | 35 | def answer_filter(answers, label2ans, max_num=10): 36 | for ans in answers['labels']: 37 | if label2ans[ans].isdigit() and max_num >= int(label2ans[ans]): 38 | return True 39 | return False 40 | 41 | 42 | class Dictionary(object): 43 | def __init__(self, word2idx=None, idx2word=None): 44 | if word2idx is None: 45 | word2idx = {} 46 | if idx2word is None: 47 | idx2word = [] 48 | self.word2idx = word2idx 49 | self.idx2word = idx2word 50 | 51 | @property 52 | def ntoken(self): 53 | return len(self.word2idx) 54 | 55 | @property 56 | def padding_idx(self): 57 | return len(self.word2idx) 58 | 59 | def tokenize(self, sentence, add_word): 60 | sentence = sentence.lower() 61 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 62 | words = sentence.split() 63 | tokens = [] 64 | if add_word: 65 | for w in words: 66 | tokens.append(self.add_word(w)) 67 | else: 68 | for w in words: 69 | # the least frequent word (`bebe`) as UNK for Visual Genome dataset 70 | tokens.append(self.word2idx.get(w, self.padding_idx - 1)) 71 | return tokens 72 | 73 | def dump_to_file(self, path): 74 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 75 | print('dictionary dumped to %s' % path) 76 | 77 | @classmethod 78 | def load_from_file(cls, path): 79 | print('loading dictionary from %s' % path) 80 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 81 | d = cls(word2idx, idx2word) 82 | return d 83 | 84 | def add_word(self, word): 85 | if word not in self.word2idx: 86 | self.idx2word.append(word) 87 | self.word2idx[word] = len(self.idx2word) - 1 88 | return self.word2idx[word] 89 | 90 | def __len__(self): 91 | return len(self.idx2word) 92 | 93 | 94 | def _create_entry(img, question, answer, ans4reranker, label2ans): 95 | 96 | if None != answer: 97 | answer.pop('image_id') 98 | answer.pop('question_id') 99 | ans4reranker.pop('image_id') 100 | ans4reranker.pop('question_id') 101 | if len(answer['labels']): 102 | answer['label_text'] = label2ans[answer['labels'][answer['scores'].index(max(answer['scores']))]] 103 | answer['label_all_text'] = ", ".join([label2ans[i] for i in answer['labels']] ) 104 | else: 105 | answer['label_text'] = None 106 | answer['label_all_text'] = None 107 | candi_ans = {} 108 | candi_ans['top20'] = 
ans4reranker['top20'] 109 | candi_ans['top20_scores'] = ans4reranker['top20_scores'] 110 | top20_text = [label2ans[i] for i in candi_ans['top20']] 111 | candi_ans['top20_text'] = top20_text 112 | 113 | 114 | entry = { 115 | 'question_id': question['question_id'], 116 | 'image_id': question['image_id'], 117 | 'image': img, 118 | 'question': question['question'], 119 | 'question_type': answer['question_type'], 120 | 'answer': answer, 121 | 'candi_ans' : candi_ans 122 | } 123 | return entry 124 | 125 | 126 | def _load_dataset(dataroot, name, label2ans,ratio=1.0): 127 | """Load entries 128 | 129 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 130 | dataroot: root path of dataset 131 | name: 'train', 'test' 132 | """ 133 | question_path = os.path.join(dataroot, 'vqacp_v2_%s_questions.json' % (name)) 134 | questions = sorted(json.load(open(question_path)), 135 | key=lambda x: x['question_id']) 136 | 137 | 138 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 139 | answers = cPickle.load(open(answer_path, 'rb')) 140 | answers = sorted(answers, key=lambda x: x['question_id'])[0:len(questions)] 141 | 142 | 143 | ans4reranker_path = os.path.join(dataroot, '%s_top20_candidates.json'%name) 144 | #ans4reranker_path = os.path.join('data4VE/%s_dataset4VE_demo.json'%name) 145 | ans4reranker = sorted(json.load(open(ans4reranker_path)), 146 | key=lambda x: x['question_id']) 147 | 148 | ans_mean_len = 0 149 | ques_num = 0 150 | for i in answers: 151 | ans_mean_len = ans_mean_len + len(i['labels']) 152 | ques_num = ques_num + 1 153 | utils.assert_eq(len(questions), len(answers)) 154 | utils.assert_eq(len(ans4reranker), len(answers)) 155 | 156 | if ratio < 1.0: 157 | index = random.sample(range(0,len(questions)), int(len(questions)*ratio)) 158 | questions_new = [questions[i] for i in index] 159 | answers_new = [answers[i] for i in index] 160 | ans4reranker_new = [ans4reranker[i] for i in index] 161 | else: 162 | questions_new = questions 163 | answers_new = answers 164 | ans4reranker_new = ans4reranker 165 | entries = [] 166 | tongji = {} 167 | tongji_ques = {} 168 | for question, answer, ans4reranker in zip(questions_new, answers_new, ans4reranker_new): 169 | utils.assert_eq(question['question_id'], answer['question_id']) 170 | utils.assert_eq(question['image_id'], answer['image_id']) 171 | utils.assert_eq(question['image_id'], ans4reranker['image_id']) 172 | utils.assert_eq(question['image_id'], ans4reranker['image_id']) 173 | img_id = question['image_id'] 174 | 175 | if not COUNTING_ONLY or is_howmany(question['question'], answer, label2ans): 176 | new_entry = _create_entry(img_id, question, answer, ans4reranker, label2ans) 177 | ans_word = new_entry['answer']['label_text'] 178 | if ans_word not in tongji.keys(): 179 | tongji[ans_word] = 1 180 | else: 181 | tongji[ans_word] = tongji[ans_word] + 1 182 | entries.append(new_entry) 183 | que_word = " ".join(new_entry['question'].split()[:2]) 184 | if que_word not in tongji_ques.keys(): 185 | tongji_ques[que_word] = 1 186 | else: 187 | tongji_ques[que_word] = tongji_ques[que_word] + 1 188 | 189 | 190 | return entries 191 | 192 | 193 | class VQAFeatureDataset(Dataset): 194 | def __init__(self, name, dictionary, dataroot, image_dataroot, ratio, adaptive=False, opt=None): 195 | super(VQAFeatureDataset, self).__init__() 196 | assert name in ['train', 'test'] 197 | ans2label_path = os.path.join(dataroot, 'cache', 'train_test_ans2label.pkl') 198 | label2ans_path = os.path.join(dataroot, 'cache', 
'train_test_label2ans.pkl') 199 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 200 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 201 | 202 | if name == "train": 203 | self.candi_ans_num = opt.train_candi_ans_num 204 | self.num_ans_candidates = opt.train_candi_ans_num 205 | elif name == "test": 206 | self.candi_ans_num = opt.test_candi_ans_num 207 | self.num_ans_candidates = opt.test_candi_ans_num 208 | 209 | self.dictionary = dictionary 210 | self.adaptive = adaptive 211 | 212 | print('loading image features and bounding boxes') 213 | # Load image features and bounding boxes 214 | self.features = zarr.open(os.path.join(image_dataroot, 'trainval.zarr'), mode='r') 215 | self.spatials = zarr.open(os.path.join(image_dataroot, 'trainval_boxes.zarr'), mode='r') 216 | 217 | 218 | 219 | self.v_dim = self.features[list(self.features.keys())[1]].shape[1] 220 | self.s_dim = self.spatials[list(self.spatials.keys())[1]].shape[1] 221 | is_exist = os.path.exists('data4VE/C_'+name+'_top20_densecaption_tokenizer_ids.pkl') 222 | if not is_exist: 223 | self.entries = _load_dataset(dataroot, name, self.label2ans, ratio) 224 | self.tokenize(max_length=18, candi_ans_num=self.candi_ans_num) 225 | self.tensorize(name) 226 | else: 227 | fp = open('data4VE/C_'+name+"_top20_densecaption_tokenizer_ids.pkl","rb+") 228 | self.entries = pickle.load(fp) 229 | def tokenize(self, max_length=18, candi_ans_num=5): 230 | tokenizer = LxmertTokenizer.from_pretrained('unc-nlp/lxmert-base-uncased') 231 | for entry in self.entries: 232 | q_a_text_top20 = [] 233 | question_text = entry['question'] 234 | question_type_text = entry['question_type'] 235 | ans_text_list = entry['candi_ans']['top20_text'] 236 | for ind, i in enumerate(ans_text_list): 237 | lower_question_text = question_text.lower() 238 | if question_type_text in lower_question_text : 239 | dense_caption = i+" "+lower_question_text 240 | else: 241 | dense_caption = i+" "+lower_question_text 242 | dense_caption_token_dict = tokenizer(dense_caption) 243 | qa_tokens = dense_caption_token_dict['input_ids'] 244 | if len(qa_tokens) > max_length : 245 | qa_tokens = qa_tokens[:max_length] 246 | else: 247 | padding = [tokenizer('[PAD]')['input_ids'][1:-1][0]]*(max_length - len(qa_tokens)) 248 | qa_tokens = qa_tokens + padding 249 | assert len(qa_tokens) == max_length 250 | q_a_tokens_tensor = torch.from_numpy(np.array([qa_tokens])) 251 | if ind == 0: 252 | q_a_tokens_top_20 = q_a_tokens_tensor 253 | else: 254 | q_a_tokens_top_20 = torch.cat([q_a_tokens_top_20, q_a_tokens_tensor]) 255 | entry['candi_ans']["20_qa_text"] = q_a_tokens_top_20 256 | 257 | 258 | def tensorize(self, name): 259 | for entry in self.entries: 260 | answer = entry['answer'] 261 | candi_ans = entry['candi_ans'] 262 | top20 = torch.from_numpy(np.array(candi_ans['top20'])) 263 | entry['candi_ans']['top20'] = top20 264 | top20_scores = torch.from_numpy(np.array(candi_ans['top20_scores'])) 265 | entry['candi_ans']['top20_scores'] = top20_scores 266 | with open('data4VE/C_'+name+'_top20_densecaption_tokenizer_ids.pkl', 'wb') as f: 267 | pickle.dump(self.entries, f) 268 | 269 | def __getitem__(self, index): 270 | entry = self.entries[index] 271 | if not self.adaptive: 272 | features = torch.from_numpy(np.array(self.features[entry['image']])) 273 | spatials = torch.from_numpy(np.array(self.spatials[entry['image']])) 274 | 275 | question_text = entry['question'] 276 | question_id = entry['question_id'] 277 | answer = entry['answer'] 278 | candi_ans = entry['condi_ans'] 279 | 280 | if None 
!= answer: 281 | labels = answer['labels'] 282 | scores = answer['scores'] 283 | ans_type = answer['answer_type'] 284 | target = candi_ans['top20_scores'][:self.candi_ans_num] 285 | qa_text = candi_ans['20_qa_text'][:self.candi_ans_num] 286 | topN_id = candi_ans['top20'][:self.candi_ans_num] 287 | LMH_bias = entry["bias"][:self.candi_ans_num] 288 | return features, spatials, target, question_id, qa_text, topN_id, ans_type, question_text, LMH_bias#entry["bias"] 289 | else: 290 | return features, spatials, question_id 291 | 292 | def __len__(self): 293 | return len(self.entries) 294 | 295 | 296 | if __name__ == '__main__': 297 | 298 | from torch.utils.data import DataLoader 299 | 300 | dataroot = './data/vqacp2/' 301 | img_root = './data/coco/' 302 | dictionary = Dictionary.load_from_file(dataroot + 'dictionary.pkl') 303 | print(dictionary) 304 | train_dset = VQAFeatureDataset('train', dictionary, dataroot, img_root, ratio=1.0, adaptive=False) 305 | 306 | loader = DataLoader(train_dset, 256, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 307 | 308 | for v, b, q, a, qid in loader: 309 | print(a.shape) 310 | -------------------------------------------------------------------------------- /SAR_main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import defaultdict, Counter 3 | from transformers import LxmertTokenizer, LxmertModel 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torch.nn.init as init 9 | import numpy as np 10 | #R 11 | from SAR_replace_dataset_vqacp import Dictionary, VQAFeatureDataset 12 | #C 13 | #from SAR_concatenate_dataset_vqacp import Dictionary, VQAFeatureDataset 14 | 15 | from LMH_lxmert_model import Model as LXM_Model 16 | from lxmert_model import Model 17 | import utils 18 | import opts_SAR as opts 19 | from SAR_train import train 20 | 21 | 22 | def weights_init_kn(m): 23 | if isinstance(m, nn.Linear): 24 | nn.init.kaiming_normal_(m.weight.data, a=0.01) 25 | 26 | 27 | if __name__ == '__main__': 28 | opt = opts.parse_opt() 29 | seed = 0 30 | if opt.seed == 0: 31 | seed = random.randint(1, 10000) 32 | random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed(opt.seed) 35 | else: 36 | seed = opt.seed 37 | random.seed(seed) 38 | torch.manual_seed(opt.seed) 39 | torch.cuda.manual_seed(opt.seed) 40 | torch.backends.cudnn.benchmark = True 41 | 42 | dictionary = Dictionary.load_from_file(opt.dataroot + 'dictionary.pkl') 43 | opt.ntokens = dictionary.ntoken 44 | if int(opt.lp) == 0: 45 | model = Model(opt) 46 | elif int(opt.lp) == 1: 47 | model = Model(opt) 48 | elif int(opt.lp) == 2: 49 | model = LXM_Model(opt) 50 | else: 51 | print("opt.lp has to be selected in [0,1,2]") 52 | assert 0 == 1 53 | model = model.cuda() 54 | train_dset = VQAFeatureDataset('train', dictionary, opt.dataroot, opt.img_root, ratio=opt.ratio, adaptive=False,opt=opt) # load labeld data 55 | eval_dset = VQAFeatureDataset('test', dictionary, opt.dataroot, opt.img_root,ratio=1.0, adaptive=False,opt=opt) 56 | answer_voc_size = opt.ans_dim# 57 | 58 | # Compute the bias: 59 | # The bias here is just the expected score for each answer/question type 60 | 61 | # question_type -> answer -> total score 62 | question_type_to_probs = defaultdict(Counter) 63 | # question_type -> num_occurances 64 | question_type_to_count = Counter() 65 | for ex in train_dset.entries: 66 | ans = ex["answer"] 67 | q_type = ans["question_type"] 68 | 
question_type_to_count[q_type] += 1 69 | if ans["labels"] is not None: 70 | for label, score in zip(ans["labels"], ans["scores"]): 71 | question_type_to_probs[q_type][label] += score 72 | 73 | question_type_to_prob_array = {} 74 | for q_type, count in question_type_to_count.items(): 75 | prob_array = np.zeros(answer_voc_size, np.float32) 76 | for label, total_score in question_type_to_probs[q_type].items(): 77 | prob_array[label] += total_score 78 | prob_array /= count 79 | question_type_to_prob_array[q_type] = prob_array 80 | 81 | # Now add a `bias` field to each example 82 | for ds in [train_dset, eval_dset]: 83 | for ex in ds.entries: 84 | q_type = ex["answer"]["question_type"] 85 | candi_top20_prob_array = np.zeros(20, np.float32) 86 | for i in range(len(candi_top20_prob_array)): 87 | candi_top20_prob_array[i] = question_type_to_prob_array[q_type][ex['condi_ans']['top20'][i]] 88 | ex['bias'] = candi_top20_prob_array 89 | 90 | train_loader = DataLoader(train_dset, opt.batch_size, shuffle=True, num_workers=0)#1, collate_fn=utils.trim_collate) 91 | opt.use_all = 1 92 | eval_loader = DataLoader(eval_dset, opt.batch_size, shuffle=False, num_workers=0)#1, collate_fn=utils.trim_collate) 93 | 94 | 95 | train(model, train_loader, eval_loader, opt) 96 | -------------------------------------------------------------------------------- /SAR_replace_dataset_vqacp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import json 4 | import _pickle as cPickle 5 | import numpy as np 6 | import utils 7 | from transformers import LxmertTokenizer, LxmertModel 8 | import warnings 9 | 10 | with warnings.catch_warnings(): 11 | warnings.filterwarnings("ignore", category=FutureWarning) 12 | import h5py 13 | from xml.etree.ElementTree import parse 14 | import torch 15 | from torch.utils.data import Dataset 16 | import zarr 17 | import random 18 | import pickle 19 | COUNTING_ONLY = False 20 | 21 | 22 | def is_howmany(q, a, label2ans): 23 | if 'how many' in q.lower() or \ 24 | ('number of' in q.lower() and 'number of the' not in q.lower()) or \ 25 | 'amount of' in q.lower() or \ 26 | 'count of' in q.lower(): 27 | if a is None or answer_filter(a, label2ans): 28 | return True 29 | else: 30 | return False 31 | else: 32 | return False 33 | 34 | 35 | def answer_filter(answers, label2ans, max_num=10): 36 | for ans in answers['labels']: 37 | if label2ans[ans].isdigit() and max_num >= int(label2ans[ans]): 38 | return True 39 | return False 40 | 41 | 42 | class Dictionary(object): 43 | def __init__(self, word2idx=None, idx2word=None): 44 | if word2idx is None: 45 | word2idx = {} 46 | if idx2word is None: 47 | idx2word = [] 48 | self.word2idx = word2idx 49 | self.idx2word = idx2word 50 | 51 | @property 52 | def ntoken(self): 53 | return len(self.word2idx) 54 | 55 | @property 56 | def padding_idx(self): 57 | return len(self.word2idx) 58 | 59 | def tokenize(self, sentence, add_word): 60 | sentence = sentence.lower() 61 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 62 | words = sentence.split() 63 | tokens = [] 64 | if add_word: 65 | for w in words: 66 | tokens.append(self.add_word(w)) 67 | else: 68 | for w in words: 69 | # the least frequent word (`bebe`) as UNK for Visual Genome dataset 70 | tokens.append(self.word2idx.get(w, self.padding_idx - 1)) 71 | return tokens 72 | 73 | def dump_to_file(self, path): 74 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 75 | print('dictionary dumped 
to %s' % path) 76 | 77 | @classmethod 78 | def load_from_file(cls, path): 79 | print('loading dictionary from %s' % path) 80 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 81 | d = cls(word2idx, idx2word) 82 | return d 83 | 84 | def add_word(self, word): 85 | if word not in self.word2idx: 86 | self.idx2word.append(word) 87 | self.word2idx[word] = len(self.idx2word) - 1 88 | return self.word2idx[word] 89 | 90 | def __len__(self): 91 | return len(self.idx2word) 92 | 93 | 94 | def _create_entry(img, question, answer, ans4reranker, label2ans): 95 | 96 | if None != answer: 97 | answer.pop('image_id') 98 | answer.pop('question_id') 99 | ans4reranker.pop('image_id') 100 | ans4reranker.pop('question_id') 101 | if len(answer['labels']): 102 | answer['label_text'] = label2ans[answer['labels'][answer['scores'].index(max(answer['scores']))]] 103 | answer['label_all_text'] = ", ".join([label2ans[i] for i in answer['labels']] ) 104 | else: 105 | answer['label_text'] = None 106 | answer['label_all_text'] = None 107 | candi_ans = {} 108 | candi_ans['top20'] = ans4reranker['top20'] 109 | candi_ans['top20_scores'] = ans4reranker['top20_scores'] 110 | top20_text = [label2ans[i] for i in candi_ans['top20']] 111 | candi_ans['top20_text'] = top20_text 112 | 113 | 114 | entry = { 115 | 'question_id': question['question_id'], 116 | 'image_id': question['image_id'], 117 | 'image': img, 118 | 'question': question['question'], 119 | 'question_type': answer['question_type'], 120 | 'answer': answer, 121 | 'candi_ans' : candi_ans 122 | } 123 | return entry 124 | 125 | 126 | def _load_dataset(dataroot, name, label2ans,ratio=1.0): 127 | """Load entries 128 | 129 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 130 | dataroot: root path of dataset 131 | name: 'train', 'test' 132 | """ 133 | question_path = os.path.join(dataroot, 'vqacp_v2_%s_questions.json' % (name)) 134 | questions = sorted(json.load(open(question_path)), 135 | key=lambda x: x['question_id']) 136 | 137 | 138 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 139 | answers = cPickle.load(open(answer_path, 'rb')) 140 | answers = sorted(answers, key=lambda x: x['question_id'])[0:len(questions)] 141 | 142 | 143 | ans4reranker_path = os.path.join(dataroot, '%s_top20_candidates.json'%name) 144 | #ans4reranker_path = os.path.join('data4VE/%s_dataset4VE_demo.json'%name) 145 | ans4reranker = sorted(json.load(open(ans4reranker_path)), 146 | key=lambda x: x['question_id']) 147 | 148 | ans_mean_len = 0 149 | ques_num = 0 150 | for i in answers: 151 | ans_mean_len = ans_mean_len + len(i['labels']) 152 | ques_num = ques_num + 1 153 | utils.assert_eq(len(questions), len(answers)) 154 | utils.assert_eq(len(ans4reranker), len(answers)) 155 | 156 | if ratio < 1.0: 157 | index = random.sample(range(0,len(questions)), int(len(questions)*ratio)) 158 | questions_new = [questions[i] for i in index] 159 | answers_new = [answers[i] for i in index] 160 | ans4reranker_new = [ans4reranker[i] for i in index] 161 | else: 162 | questions_new = questions 163 | answers_new = answers 164 | ans4reranker_new = ans4reranker 165 | entries = [] 166 | tongji = {} 167 | tongji_ques = {} 168 | for question, answer, ans4reranker in zip(questions_new, answers_new, ans4reranker_new): 169 | utils.assert_eq(question['question_id'], answer['question_id']) 170 | utils.assert_eq(question['image_id'], answer['image_id']) 171 | utils.assert_eq(question['image_id'], ans4reranker['image_id']) 172 | utils.assert_eq(question['image_id'], 
ans4reranker['image_id']) 173 | img_id = question['image_id'] 174 | 175 | if not COUNTING_ONLY or is_howmany(question['question'], answer, label2ans): 176 | new_entry = _create_entry(img_id, question, answer, ans4reranker, label2ans) 177 | ans_word = new_entry['answer']['label_text'] 178 | if ans_word not in tongji.keys(): 179 | tongji[ans_word] = 1 180 | else: 181 | tongji[ans_word] = tongji[ans_word] + 1 182 | entries.append(new_entry) 183 | que_word = " ".join(new_entry['question'].split()[:2]) 184 | if que_word not in tongji_ques.keys(): 185 | tongji_ques[que_word] = 1 186 | else: 187 | tongji_ques[que_word] = tongji_ques[que_word] + 1 188 | 189 | 190 | return entries 191 | 192 | 193 | class VQAFeatureDataset(Dataset): 194 | def __init__(self, name, dictionary, dataroot, image_dataroot, ratio, adaptive=False, opt=None): 195 | super(VQAFeatureDataset, self).__init__() 196 | assert name in ['train', 'test'] 197 | ans2label_path = os.path.join(dataroot, 'cache', 'train_test_ans2label.pkl') 198 | label2ans_path = os.path.join(dataroot, 'cache', 'train_test_label2ans.pkl') 199 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 200 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 201 | 202 | if name == "train": 203 | self.candi_ans_num = opt.train_candi_ans_num 204 | self.num_ans_candidates = opt.train_candi_ans_num 205 | elif name == "test": 206 | self.candi_ans_num = opt.test_candi_ans_num 207 | self.num_ans_candidates = opt.test_candi_ans_num 208 | 209 | self.dictionary = dictionary 210 | self.adaptive = adaptive 211 | 212 | print('loading image features and bounding boxes') 213 | # Load image features and bounding boxes 214 | self.features = zarr.open(os.path.join(image_dataroot, 'trainval.zarr'), mode='r') 215 | self.spatials = zarr.open(os.path.join(image_dataroot, 'trainval_boxes.zarr'), mode='r') 216 | 217 | 218 | 219 | self.v_dim = self.features[list(self.features.keys())[1]].shape[1] 220 | self.s_dim = self.spatials[list(self.spatials.keys())[1]].shape[1] 221 | is_exist = os.path.exists('data4VE/R_'+name+'_top20_densecaption_tokenizer_ids.pkl') 222 | if not is_exist: 223 | self.entries = _load_dataset(dataroot, name, self.label2ans, ratio) 224 | self.tokenize(max_length=15, candi_ans_num=self.candi_ans_num) 225 | self.tensorize(name) 226 | else: 227 | fp = open('data4VE/R_'+name+"_top20_densecaption_tokenizer_ids.pkl","rb+") 228 | self.entries = pickle.load(fp) 229 | def tokenize(self, max_length=15, candi_ans_num=5): 230 | tokenizer = LxmertTokenizer.from_pretrained('unc-nlp/lxmert-base-uncased') 231 | for entry in self.entries: 232 | q_a_text_top20 = [] 233 | question_text = entry['question'] 234 | question_type_text = entry['question_type'] 235 | ans_text_list = entry['candi_ans']['top20_text'] 236 | for ind, i in enumerate(ans_text_list): 237 | lower_question_text = question_text.lower() 238 | if question_type_text in lower_question_text : 239 | dense_caption = lower_question_text.replace(question_type_text,i)[:-1] 240 | else: 241 | dense_caption = i+" "+lower_question_text 242 | dense_caption_token_dict = tokenizer(dense_caption) 243 | qa_tokens = dense_caption_token_dict['input_ids'] 244 | if len(qa_tokens) > max_length : 245 | qa_tokens = qa_tokens[:max_length] 246 | else: 247 | padding = [tokenizer('[PAD]')['input_ids'][1:-1][0]]*(max_length - len(qa_tokens)) 248 | qa_tokens = qa_tokens + padding 249 | assert len(qa_tokens) == max_length 250 | q_a_tokens_tensor = torch.from_numpy(np.array([qa_tokens])) 251 | if ind == 0: 252 | q_a_tokens_top_20 = 
q_a_tokens_tensor 253 | else: 254 | q_a_tokens_top_20 = torch.cat([q_a_tokens_top_20, q_a_tokens_tensor]) 255 | entry['candi_ans']["20_qa_text"] = q_a_tokens_top_20 256 | 257 | 258 | def tensorize(self, name): 259 | for entry in self.entries: 260 | answer = entry['answer'] 261 | candi_ans = entry['candi_ans'] 262 | top20 = torch.from_numpy(np.array(candi_ans['top20'])) 263 | entry['candi_ans']['top20'] = top20 264 | top20_scores = torch.from_numpy(np.array(candi_ans['top20_scores'])) 265 | entry['candi_ans']['top20_scores'] = top20_scores 266 | with open('data4VE/R_'+name+'_top20_densecaption_tokenizer_ids.pkl', 'wb') as f: 267 | pickle.dump(self.entries, f) 268 | 269 | def __getitem__(self, index): 270 | entry = self.entries[index] 271 | if not self.adaptive: 272 | features = torch.from_numpy(np.array(self.features[entry['image']])) 273 | spatials = torch.from_numpy(np.array(self.spatials[entry['image']])) 274 | 275 | question_text = entry['question'] 276 | question_id = entry['question_id'] 277 | answer = entry['answer'] 278 | candi_ans = entry['condi_ans'] 279 | 280 | if None != answer: 281 | labels = answer['labels'] 282 | scores = answer['scores'] 283 | ans_type = answer['answer_type'] 284 | target = candi_ans['top20_scores'][:self.candi_ans_num] 285 | qa_text = candi_ans['20_qa_text'][:self.candi_ans_num] 286 | topN_id = candi_ans['top20'][:self.candi_ans_num] 287 | LMH_bias = entry["bias"][:self.candi_ans_num] 288 | return features, spatials, target, question_id, qa_text, topN_id, ans_type, question_text, LMH_bias#entry["bias"] 289 | else: 290 | return features, spatials, question_id 291 | 292 | def __len__(self): 293 | return len(self.entries) 294 | 295 | 296 | if __name__ == '__main__': 297 | 298 | from torch.utils.data import DataLoader 299 | 300 | dataroot = './data/vqacp2/' 301 | img_root = './data/coco/' 302 | dictionary = Dictionary.load_from_file(dataroot + 'dictionary.pkl') 303 | print(dictionary) 304 | train_dset = VQAFeatureDataset('train', dictionary, dataroot, img_root, ratio=1.0, adaptive=False) 305 | 306 | loader = DataLoader(train_dset, 256, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 307 | 308 | for v, b, q, a, qid in loader: 309 | print(a.shape) 310 | -------------------------------------------------------------------------------- /SAR_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import progressbar 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | 9 | from collections import defaultdict, Counter 10 | import opts_SAR as opts 11 | #R 12 | #from SAR_replace_dataset_vqacp import Dictionary, VQAFeatureDataset 13 | #C 14 | from SAR_concatenate_dataset_vqacp import Dictionary, VQAFeatureDataset 15 | 16 | from LMH_lxmert_model import Model as LXM_Model 17 | from lxmert_model import Model 18 | from QTD_model import Model as Model2 19 | import utils 20 | 21 | 22 | 23 | def get_question(q, dataloader): 24 | str = [] 25 | dictionary = dataloader.dataset.dictionary 26 | for i in range(q.size(0)): 27 | str.append(dictionary.idx2word[q[i]] if q[i] < len(dictionary.idx2word) else '_') 28 | return ' '.join(str) 29 | 30 | 31 | def get_answer(p, dataloader, topN_id, all_N): 32 | _m, idx = p[:all_N].max(0) 33 | idx = topN_id[idx] 34 | return dataloader.dataset.label2ans[idx.item()] 35 | 36 | 37 | @torch.no_grad() 38 | def get_logits(model, model2, dataloader, opt): 39 | N = len(dataloader.dataset) 40 | M = 
dataloader.dataset.num_ans_candidates 41 | K = 36 42 | pred = torch.FloatTensor(N, M).zero_() 43 | batch_topN_id = torch.IntTensor(N, M).zero_() 44 | qIds = torch.IntTensor(N).zero_() 45 | idx = 0 46 | all_N = [] 47 | bar = progressbar.ProgressBar(maxval=N or None).start() 48 | counting2stop = 0 49 | count_l = [0]*20 50 | for v, b, a, q_id,qa_text, topN_id, ans_type, ques_text,bias in iter(dataloader): 51 | batch_N = [] 52 | counting2stop=counting2stop+1 53 | bar.update(idx) 54 | batch_size = v.size(0) 55 | v = v.cuda() 56 | b = b.cuda() 57 | qa_text = qa_text.cuda() 58 | topN_id = topN_id.cuda() 59 | if opt.lp == 0: 60 | logits = model(qa_text, v, b, 0, 'test') 61 | elif opt.lp == 1: 62 | logits = model(qa_text, v, b, 0, 'test') 63 | elif opt.lp == 2: 64 | logits,_ = model(qa_text, v, b, 0, 'test',bias,a) 65 | mask = model2(ques_text) 66 | for i in mask: 67 | l = i.tolist() 68 | ind = l.index(max(l)) 69 | count_l[ind]=count_l[ind]+1 70 | if ind == 0: 71 | #N' for yes/no question 72 | batch_N.append(opt.QTD_N4yesno) 73 | elif ind == 1: 74 | #N' for non-yes/no question 75 | batch_N.append(opt.QTD_N4non_yesno) 76 | else: 77 | assert 1==0 78 | 79 | pred[idx:idx+batch_size,:].copy_(logits.data) 80 | qIds[idx:idx+batch_size].copy_(q_id) 81 | all_N = all_N + batch_N 82 | batch_topN_id[idx:idx+batch_size].copy_(topN_id.data) 83 | 84 | idx += batch_size 85 | bar.update(idx) 86 | return pred, qIds, batch_topN_id, all_N 87 | 88 | 89 | def make_json(logits, qIds, dataloader, topN_id, all_N): 90 | utils.assert_eq(logits.size(0), len(qIds)) 91 | 92 | results = [] 93 | for i in range(logits.size(0)): 94 | result = {} 95 | result['question_id'] = qIds[i].item() 96 | result['answer'] = get_answer(logits[i], dataloader, topN_id[i], all_N[i]) 97 | results.append(result) 98 | return results 99 | 100 | if __name__ == '__main__': 101 | opt = opts.parse_opt() 102 | 103 | torch.backends.cudnn.benchmark = True 104 | 105 | dictionary = Dictionary.load_from_file(opt.dataroot + 'dictionary.pkl') 106 | opt.ntokens = dictionary.ntoken 107 | eval_dset = VQAFeatureDataset('test', dictionary, opt.dataroot, opt.img_root,ratio=1.0, adaptive=False,opt=opt) 108 | train_dset = VQAFeatureDataset('train', dictionary, opt.dataroot, opt.img_root, ratio=opt.ratio, adaptive=False,opt=opt) 109 | 110 | 111 | n_device = torch.cuda.device_count() 112 | batch_size = opt.batch_size * n_device 113 | 114 | 115 | if int(opt.lp) == 0: 116 | model = Model(opt) 117 | elif int(opt.lp) == 1: 118 | model = Model(opt) 119 | elif int(opt.lp) == 2: 120 | model = LXM_Model(opt) 121 | model = model.cuda() 122 | model2 = Model2(opt) 123 | model2 = model2.cuda() 124 | answer_voc_size = opt.ans_dim#train_dset.num_ans_candidates 125 | # Compute the bias: 126 | # The bias here is just the expected score for each answer/question type 127 | 128 | # question_type -> answer -> total score 129 | question_type_to_probs = defaultdict(Counter) 130 | # question_type -> num_occurances 131 | question_type_to_count = Counter() 132 | for ex in train_dset.entries: 133 | ans = ex["answer"] 134 | q_type = ans["question_type"] 135 | question_type_to_count[q_type] += 1 136 | if ans["labels"] is not None: 137 | for label, score in zip(ans["labels"], ans["scores"]): 138 | question_type_to_probs[q_type][label] += score 139 | 140 | question_type_to_prob_array = {} 141 | for q_type, count in question_type_to_count.items(): 142 | prob_array = np.zeros(answer_voc_size, np.float32) 143 | for label, total_score in question_type_to_probs[q_type].items(): 144 | prob_array[label] += 
total_score 145 | prob_array /= count 146 | question_type_to_prob_array[q_type] = prob_array 147 | 148 | # Now add a `bias` field to each example 149 | for ds in [train_dset, eval_dset]: 150 | for ex in ds.entries: 151 | q_type = ex["answer"]["question_type"] 152 | candi_top20_prob_array = np.zeros(20, np.float32) 153 | for i in range(len(candi_top20_prob_array)): 154 | candi_top20_prob_array[i]=question_type_to_prob_array[q_type][ex['condi_ans']['top20'][i]] 155 | ex['bias'] = candi_top20_prob_array 156 | 157 | 158 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=1, collate_fn=utils.trim_collate) 159 | 160 | def process(args, model, model2, eval_loader): 161 | 162 | model_data = torch.load(opt.checkpoint_path4test) 163 | model2_data = torch.load(opt.checkpoint_path4test_QTDmodel) 164 | 165 | model.load_state_dict(model_data.get('model_state', model_data)) 166 | model2.load_state_dict(model2_data.get('model_state', model2_data)) 167 | model = nn.DataParallel(model).cuda() 168 | model2 = nn.DataParallel(model2).cuda() 169 | opt.s_epoch = model_data['epoch'] + 1 170 | 171 | model.train(False) 172 | model2.train(False) 173 | 174 | logits, qIds, topN_id, all_N = get_logits(model, model2, eval_loader, opt) 175 | results = make_json(logits, qIds, eval_loader, topN_id, all_N) 176 | model_label = opt.label 177 | 178 | if opt.logits: 179 | utils.create_dir('logits/'+model_label) 180 | torch.save(logits, 'logits/'+model_label+'/logits%d.pth' % opt.s_epoch) 181 | 182 | utils.create_dir(opt.output) 183 | if 0 <= opt.s_epoch: 184 | model_label += '_epoch%d' % opt.s_epoch 185 | 186 | if opt.lp == 0: 187 | test_type = "-SAR" 188 | elif opt.lp == 1: 189 | test_type = "-SAR+SSL" 190 | elif opt.lp == 2: 191 | test_type = "-SAR+LMH" 192 | with open(opt.output+'/top'+str(opt.QTD_N4yesno)+'_'+str(opt.QTD_N4non_yesno)+test_type+'_answers_test_%s.json' \ 193 | % (model_label), 'w') as f: 194 | json.dump(results, f) 195 | 196 | process(opt, model, model2, eval_loader) 197 | -------------------------------------------------------------------------------- /SAR_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import itertools 4 | from torch.autograd import Variable 5 | import random 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import utils 10 | import numpy as np 11 | import json 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | 14 | 15 | # standard cross-entropy loss 16 | def instance_bce(logits, labels): 17 | assert logits.dim() == 2 18 | cross_entropy_loss = nn.CrossEntropyLoss() 19 | prediction_ans_k, top_ans_ind = torch.topk(F.softmax(labels, dim=-1), k=1, dim=-1, sorted=False) 20 | ce_loss = cross_entropy_loss(logits, top_ans_ind.squeeze(-1)) 21 | 22 | return ce_loss 23 | 24 | # multi-label soft loss 25 | def instance_bce_with_logits(logits, labels, reduction='mean'): 26 | assert logits.dim() == 2 27 | loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels, reduction=reduction) 28 | if reduction == 'mean': 29 | loss *= labels.size(1) 30 | return loss 31 | 32 | def compute_score_with_logits(logits, labels): 33 | logits = torch.max(logits, 1)[1].data # argmax 34 | one_hots = torch.zeros(*labels.size()).cuda() 35 | one_hots.scatter_(1, logits.view(-1, 1), 1) 36 | scores = (one_hots * labels) 37 | return scores 38 | def compute_TopKscore_with_logits(logits, labels,n): 39 | prediction_ans_k, top_ans_ind = torch.topk(F.softmax(logits, dim=-1), k=n, 
dim=-1, sorted=True) 40 | logits_ind = top_ans_ind 41 | one_hots = torch.zeros(*labels.size()).cuda() 42 | one_hots.scatter_(1, logits_ind.view(-1, n), 1) 43 | scores = (one_hots * labels) 44 | scores = torch.max(scores, 1)[0].data 45 | return scores 46 | def compute_self_loss(logits_neg, a): 47 | prediction_ans_k, top_ans_ind = torch.topk(F.softmax(a, dim=-1), k=1, dim=-1, sorted=False) 48 | neg_top_k = torch.gather(F.softmax(logits_neg, dim=-1), 1, top_ans_ind).sum(1) 49 | qice_loss = neg_top_k.mean() 50 | return qice_loss 51 | 52 | 53 | 54 | 55 | def train(model, train_loader, eval_loader, opt): 56 | utils.create_dir(opt.output) 57 | optim = torch.optim.Adam(model.parameters(), lr=opt.learning_rate, betas=(0.9, 0.999), eps=1e-08, 58 | weight_decay=opt.weight_decay) 59 | 60 | 61 | logger = utils.Logger(os.path.join(opt.output, 'log.txt')) 62 | 63 | utils.print_model(model, logger) 64 | for param_group in optim.param_groups: 65 | param_group['lr'] = opt.learning_rate 66 | 67 | scheduler = MultiStepLR(optim, milestones=[100], gamma=0.8) 68 | 69 | scheduler.last_epoch = opt.s_epoch 70 | 71 | 72 | 73 | best_eval_score = 0 74 | for epoch in range(opt.s_epoch, opt.num_epochs): 75 | total_loss = 0 76 | total_norm = 0 77 | count_norm = 0 78 | train_score = 0 79 | t = time.time() 80 | N = len(train_loader.dataset) 81 | scheduler.step() 82 | 83 | for i, (v, b, a, _, qa_text, _, _, q_t, bias) in enumerate(train_loader): 84 | v = v.cuda() 85 | b = b.cuda() 86 | a = a.cuda() 87 | bias = bias.cuda() 88 | qa_text = qa_text.cuda() 89 | rand_index = random.sample(range(0, opt.train_candi_ans_num), opt.train_candi_ans_num) 90 | qa_text = qa_text[:,rand_index,:] 91 | a = a[:,rand_index] 92 | bias = bias[:,rand_index] 93 | 94 | if opt.lp == 0: 95 | logits = model(qa_text, v, b, epoch, 'train') 96 | loss = instance_bce_with_logits(logits, a, reduction='mean') 97 | elif opt.lp == 1: 98 | logits = model(qa_text, v, b, epoch, 'train') 99 | loss_pos = instance_bce_with_logits(logits, a, reduction='mean') 100 | index = random.sample(range(0, v.shape[0]), v.shape[0]) 101 | v_neg = v[index] 102 | b_neg = b[index] 103 | logits_neg = model(qa_text, v_neg, b_neg, epoch, 'train') 104 | self_loss = compute_self_loss(logits_neg, a) 105 | loss = loss_pos + opt.self_loss_weight * self_loss 106 | elif opt.lp == 2: 107 | logits, loss = model(qa_text, v, b, epoch, 'train', bias, a) 108 | else: 109 | assert 1==2 110 | 111 | loss.backward() 112 | 113 | total_norm += nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) 114 | count_norm += 1 115 | 116 | optim.step() 117 | optim.zero_grad() 118 | 119 | score = compute_score_with_logits(logits, a.data).sum() 120 | train_score += score.item() 121 | total_loss += loss.item() * v.size(0) 122 | 123 | if i != 0 and i % 100 == 0: 124 | print( 125 | 'training: %d/%d, train_loss: %.6f, train_acc: %.6f' % 126 | (i, len(train_loader), total_loss / (i * v.size(0)), 127 | 100 * train_score / (i * v.size(0)))) 128 | total_loss /= N 129 | if None != eval_loader: 130 | model.train(False) 131 | eval_score, bound = evaluate(model, eval_loader, opt) 132 | model.train(True) 133 | 134 | logger.write('\nlr: %.7f' % optim.param_groups[0]['lr']) 135 | logger.write('epoch %d, time: %.2f' % (epoch, time.time() - t)) 136 | logger.write( 137 | '\ttrain_loss: %.2f, norm: %.4f, score: %.2f' % (total_loss, total_norm / count_norm, train_score)) 138 | if eval_loader is not None: 139 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound)) 140 | 141 | 142 | if (eval_loader is not None 
and eval_score > best_eval_score): 143 | if opt.lp == 0: 144 | model_path = os.path.join(opt.output, 'SAR_top'+str(opt.train_candi_ans_num)+'_best_model.pth') 145 | elif opt.lp == 1: 146 | model_path = os.path.join(opt.output, 'SAR_SSL_top'+str(opt.train_candi_ans_num)+'_best_model.pth') 147 | elif opt.lp == 2: 148 | model_path = os.path.join(opt.output, 'SAR_LMH_top'+str(opt.train_candi_ans_num)+'_best_model.pth') 149 | utils.save_model(model_path, model, epoch, optim) 150 | if eval_loader is not None: 151 | best_eval_score = eval_score 152 | @torch.no_grad() 153 | def evaluate(model, dataloader, opt): 154 | score = 0 155 | 156 | score_ini_num_list=[] 157 | for num in range(opt.test_candi_ans_num): 158 | score_ini_num_list.append(0) 159 | upper_bound = 0 160 | num_data = 0 161 | entropy = 0 162 | for i, (v, b, a, q_id, qa_text, _, _, q_t, bias) in enumerate(dataloader): 163 | v = v.cuda() 164 | b = b.cuda() 165 | bias = bias.cuda() 166 | a = a.cuda() 167 | q_id = q_id.cuda() 168 | qa_text = qa_text.cuda() 169 | if opt.lp == 0: 170 | logits = model(qa_text, v, b, 0, 'test') 171 | elif opt.lp == 1: 172 | logits = model(qa_text, v, b, 0, 'test') 173 | elif opt.lp == 2: 174 | logits, _ = model(qa_text, v, b, 0, 'test', bias, a) 175 | pred = logits 176 | batch_score = compute_score_with_logits(pred, a.cuda()).sum() 177 | score += batch_score.item() 178 | for num in range(opt.test_candi_ans_num): 179 | batch_score_num = compute_TopKscore_with_logits(pred, a.cuda(), num+1).sum() 180 | score_ini_num_list[num] += batch_score_num.item() 181 | upper_bound += (a.max(1)[0]).sum().item() 182 | num_data += pred.size(0) 183 | 184 | score = score / len(dataloader.dataset) 185 | score_num_list = [] 186 | for score_num in score_ini_num_list: 187 | score_num = score_num / len(dataloader.dataset) 188 | score_num_list.append(score_num) 189 | upper_bound = upper_bound / len(dataloader.dataset) 190 | 191 | return score, upper_bound#, entropy 192 | 193 | 194 | def calc_entropy(att): # size(att) = [b x v x q] 195 | sizes = att.size() 196 | eps = 1e-8 197 | # att = att.unsqueeze(-1) 198 | p = att.view(-1, sizes[1] * sizes[2]) 199 | return (-p * (p + eps).log()).sum(1).sum(0) # g 200 | 201 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | from fc import FCNet, GTH, get_norm 5 | 6 | 7 | # Default concat, 1 layer, output layer 8 | class Att_0(nn.Module): 9 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 10 | super(Att_0, self).__init__() 11 | norm_layer = get_norm(norm) 12 | self.nonlinear = FCNet([v_dim + q_dim, num_hid], dropout= dropout, norm= norm, act= act) 13 | self.linear = norm_layer(nn.Linear(num_hid, 1), dim=None) 14 | 15 | def forward(self, v, q): 16 | """ 17 | v: [batch, k, vdim] 18 | q: [batch, qdim] 19 | """ 20 | logits = self.logits(v, q) 21 | w = nn.functional.softmax(logits, 1) 22 | return w 23 | 24 | def logits(self, v, q): 25 | num_objs = v.size(1) 26 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 27 | vq = torch.cat((v, q), 2) 28 | joint_repr = self.nonlinear(vq) 29 | logits = self.linear(joint_repr) 30 | return logits 31 | 32 | 33 | # concat, 2 layer, output layer 34 | class Att_1(nn.Module): 35 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 36 | super(Att_1, self).__init__() 37 | norm_layer = get_norm(norm) 38 | self.nonlinear = 
FCNet([v_dim + q_dim, num_hid, num_hid], dropout= dropout, norm= norm, act= act) 39 | self.linear = norm_layer(nn.Linear(num_hid, 1), dim=None) 40 | 41 | def forward(self, v, q): 42 | """ 43 | v: [batch, k, vdim] 44 | q: [batch, qdim] 45 | """ 46 | logits = self.logits(v, q) 47 | w = nn.functional.softmax(logits, 1) 48 | return w 49 | 50 | def logits(self, v, q): 51 | num_objs = v.size(1) 52 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 53 | vq = torch.cat((v, q), 2) 54 | joint_repr = self.nonlinear(vq) 55 | logits = self.linear(joint_repr) 56 | return logits 57 | 58 | 59 | # 1 layer seperate, element-wise *, output layer 60 | class Att_2(nn.Module): 61 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 62 | super(Att_2, self).__init__() 63 | norm_layer = get_norm(norm) 64 | self.v_proj = FCNet([v_dim, num_hid], dropout= dropout, norm= norm, act= act) 65 | self.q_proj = FCNet([q_dim, num_hid], dropout= dropout, norm= norm, act= act) 66 | self.linear = norm_layer(nn.Linear(q_dim, 1), dim=None) 67 | 68 | def forward(self, v, q): 69 | """ 70 | v: [batch, k, vdim] 71 | q: [batch, qdim] 72 | """ 73 | logits = self.logits(v, q) 74 | w = nn.functional.softmax(logits, 1) 75 | return w 76 | 77 | def logits(self, v, q): 78 | batch, k, _ = v.size() 79 | v_proj = self.v_proj(v) # [batch, k, num_hid] 80 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) # [batch, k, num_hid] 81 | joint_repr = v_proj * q_proj 82 | logits = self.linear(joint_repr) 83 | return logits 84 | 85 | 86 | # 1 layer seperate, element-wise *, 1 layer seperate, output layer 87 | class Att_3(nn.Module): 88 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 89 | super(Att_3, self).__init__() 90 | norm_layer = get_norm(norm) 91 | self.v_proj = FCNet([v_dim, num_hid], dropout= dropout, norm= norm, act= act) 92 | self.q_proj = FCNet([q_dim, num_hid], dropout= dropout, norm= norm, act= act) 93 | self.nonlinear = FCNet([num_hid, num_hid], dropout= dropout, norm= norm, act= act) 94 | self.linear = norm_layer(nn.Linear(num_hid, 1), dim=None) 95 | 96 | def forward(self, v, q): 97 | """ 98 | v: [batch, k, vdim] 99 | q: [batch, qdim] 100 | """ 101 | logits = self.logits(v, q) 102 | w = nn.functional.softmax(logits, 1) 103 | return w 104 | 105 | def logits(self, v, q): 106 | batch, k, _ = v.size() 107 | v_proj = self.v_proj(v) # [batch, k, num_hid] 108 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) # [batch, k, num_hid] 109 | joint_repr = v_proj * q_proj 110 | joint_repr = self.nonlinear(joint_repr) 111 | logits = self.linear(joint_repr) 112 | return logits 113 | 114 | # 1 layer seperate, element-wise *, 1 layer seperate, output layer 115 | class Att_3S(nn.Module): 116 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 117 | super(Att_3S, self).__init__() 118 | norm_layer = get_norm(norm) 119 | self.v_proj = FCNet([v_dim, num_hid], dropout=dropout, norm=norm, act=act) 120 | self.q_proj = FCNet([q_dim, num_hid], dropout=dropout, norm=norm, act=act) 121 | self.nonlinear = FCNet([num_hid, num_hid], dropout=dropout, norm=norm, act=act) 122 | self.linear = norm_layer(nn.Linear(num_hid, 1), dim=None) 123 | 124 | def forward(self, v, q): 125 | """ 126 | v: [batch, k, vdim] 127 | q: [batch, qdim] 128 | """ 129 | logits = self.logits(v, q) 130 | w = nn.functional.sigmoid(logits) 131 | #w = nn.functional.leaky_relu(logits) 132 | return w 133 | 134 | def logits(self, v, q): 135 | batch, k, _ = v.size() 136 | v_proj = self.v_proj(v) # [batch, k, num_hid] 137 | q_proj = 
self.q_proj(q).unsqueeze(1).repeat(1, k, 1) # [batch, k, num_hid] 138 | joint_repr = v_proj * q_proj 139 | joint_repr = self.nonlinear(joint_repr) 140 | logits = self.linear(joint_repr) 141 | return logits 142 | 143 | 144 | # concat w/ 2 layer seperate, element-wise *, output layer 145 | class Att_PD(nn.Module): 146 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 147 | super(Att_PD, self).__init__() 148 | norm_layer = get_norm(norm) 149 | self.nonlinear = FCNet([v_dim + q_dim, num_hid, num_hid], dropout= dropout, norm= norm, act= act) 150 | self.nonlinear_gate = FCNet([v_dim + q_dim, num_hid, num_hid], dropout= dropout, norm= norm, act= 'Sigmoid') 151 | self.linear = norm_layer(nn.Linear(num_hid, 1), dim=None) 152 | 153 | def forward(self, v, q): 154 | """ 155 | v: [batch, k, vdim] 156 | q: [batch, qdim] 157 | """ 158 | logits = self.logits(v, q) 159 | w = nn.functional.softmax(logits, 1) 160 | return w 161 | 162 | def logits(self, v, q): 163 | num_objs = v.size(1) 164 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 165 | vq = torch.cat((v, q), 2) 166 | joint_repr = self.nonlinear(vq) 167 | gate = self.nonlinear_gate(vq) 168 | logits = joint_repr*gate 169 | logits = self.linear(logits) 170 | return logits 171 | 172 | 173 | # concat w/ 1 layer seperate, element-wise *, output layer 174 | class Att_P(nn.Module): 175 | def __init__(self, v_dim, q_dim, num_hid, norm, act, dropout=0.0): 176 | super(Att_P, self).__init__() 177 | norm_layer = get_norm(norm) 178 | 179 | self.gated_tanh = GTH( in_dim= v_dim + q_dim, out_dim= num_hid, dropout= dropout, norm= norm, act= act) 180 | self.linear = norm_layer(nn.Linear(num_hid, 1), dim=None) 181 | 182 | def forward(self, v, q): 183 | """ 184 | v: [batch, k, vdim] 185 | q: [batch, qdim] 186 | """ 187 | logits = self.logits(v, q) 188 | w = nn.functional.softmax(logits, 1) 189 | return w 190 | 191 | def logits(self, v, q): 192 | num_objs = v.size(1) 193 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 194 | vq = torch.cat((v, q), 2) 195 | joint_repr = self.gated_tanh(vq) 196 | logits = self.linear(joint_repr) 197 | return logits 198 | -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.weight_norm import weight_norm 3 | from fc import FCNet, GTH, get_act, get_norm 4 | 5 | class SimpleClassifier(nn.Module): 6 | def __init__(self, in_dim, hid_dim, out_dim, norm, act, dropout=0.5): 7 | super(SimpleClassifier, self).__init__() 8 | 9 | norm_layer = get_norm(norm) 10 | act_layer = get_act(act) 11 | 12 | layers = [ 13 | norm_layer(nn.Linear(in_dim, hid_dim), dim=None), 14 | act_layer(), 15 | nn.Dropout(dropout, inplace=False), 16 | norm_layer(nn.Linear(hid_dim, out_dim), dim=None) 17 | ] 18 | self.main = nn.Sequential(*layers) 19 | 20 | def forward(self, x): 21 | logits = self.main(x) 22 | return logits 23 | 24 | 25 | 26 | class PaperClassifier(nn.Module): 27 | def __init__(self, in_dim, hid_dim_1, hid_dim_2, out_dim, norm, act, dropout=0.5): 28 | super(PaperClassifier, self).__init__() 29 | 30 | no_norm = lambda x, dim: x 31 | if norm == 'weight': 32 | norm_layer = weight_norm 33 | elif norm == 'batch': 34 | norm_layer = nn.BatchNorm1d 35 | elif norm == 'layer': 36 | norm_layer = nn.LayerNorm 37 | elif norm == 'none': 38 | norm_layer = no_norm 39 | else: 40 | print("Invalid Normalization") 41 | raise Exception("Invalid Normalization") 42 | 43 | 44 | self.gated_tanh_1 = 
GTH(in_dim=in_dim, out_dim=hid_dim_1, dropout=dropout, norm=norm, act=act) 45 | self.gated_tanh_2 = GTH(in_dim=in_dim, out_dim=hid_dim_2, dropout=dropout, norm=norm, act=act) 46 | 47 | self.linear_1 = norm_layer(nn.Linear(hid_dim_1, out_dim), dim=None) 48 | self.linear_2 = norm_layer(nn.Linear(hid_dim_2, out_dim), dim=None) 49 | 50 | def forward(self, x): 51 | v_1 = self.gated_tanh_1(x) 52 | v_2 = self.gated_tanh_2(x) 53 | 54 | v_1 = self.linear_1(v_1) 55 | v_2 = self.linear_2(v_2) 56 | 57 | logits = v_1 + v_2 58 | return logits 59 | class PaperClassifier1(nn.Module): 60 | def __init__(self, in_dim, hid_dim_1, hid_dim_2, out_dim, norm, act, dropout=0.5): 61 | super(PaperClassifier1, self).__init__() 62 | 63 | no_norm = lambda x, dim: x 64 | if norm == 'weight': 65 | norm_layer = weight_norm 66 | elif norm == 'batch': 67 | norm_layer = nn.BatchNorm1d 68 | elif norm == 'layer': 69 | norm_layer = nn.LayerNorm 70 | elif norm == 'none': 71 | norm_layer = no_norm 72 | else: 73 | print("Invalid Normalization") 74 | raise Exception("Invalid Normalization") 75 | 76 | 77 | self.gated_tanh_1 = FCNet([in_dim, hid_dim_1], dropout=dropout, norm=norm, act=act) 78 | self.gated_tanh_2 = FCNet([in_dim, hid_dim_2], dropout=dropout, norm=norm, act=act) 79 | 80 | self.linear_1 = norm_layer(nn.Linear(hid_dim_1, out_dim), dim=None) 81 | self.linear_2 = norm_layer(nn.Linear(hid_dim_2, out_dim), dim=None) 82 | 83 | def forward(self, x): 84 | v_1 = self.gated_tanh_1(x) 85 | v_2 = self.gated_tanh_2(x) 86 | 87 | v_1 = self.linear_1(v_1) 88 | v_2 = self.linear_2(v_2) 89 | 90 | logits = v_1 + v_2 91 | return logits 92 | 93 | 94 | 95 | class ImageClassifier(nn.Module): 96 | def __init__(self, in_dim, hid_dim, out_dim, norm, act, dropout=0.5): 97 | super(ImageClassifier, self).__init__() 98 | 99 | norm_layer = get_norm(norm) 100 | act_layer = get_act(act) 101 | 102 | layers = [ 103 | norm_layer(nn.Linear(in_dim, hid_dim), dim=None), 104 | act_layer(), 105 | nn.Dropout(dropout, inplace=False), 106 | norm_layer(nn.Linear(hid_dim, out_dim), dim=None) 107 | ] 108 | self.main = nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | logits = self.main(x) 112 | return logits -------------------------------------------------------------------------------- /comput_score.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | import json 3 | import os 4 | import torch 5 | import argparse 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input', type=str, default='saved_models/test_epochs_17.json') 10 | parser.add_argument('--name', type=str, default='test') 11 | parser.add_argument('--dataroot', type=str, default='../../SSL-VQA/data/vqacp2/cache') 12 | 13 | args = parser.parse_args() 14 | return args 15 | 16 | if __name__ == '__main__': 17 | 18 | args = parse_args() 19 | 20 | anno_path = osp.join(args.dataroot, '%s_target_count.pth'%(args.name)) 21 | annotations = torch.load(anno_path) 22 | 23 | annotations = sorted(annotations, key=lambda x: x['question_id']) 24 | print(annotations[0]) 25 | print(len(annotations)) 26 | predictions = sorted(json.load(open(args.input)), key=lambda x: x['question_id']) 27 | 28 | score = 0 29 | count = 0 30 | other_score = 0 31 | yes_no_score = 0 32 | num_score = 0 33 | yes_count = 0 34 | other_count = 0 35 | num_count = 0 36 | upper_bound = 0 37 | upper_bound_num = 0 38 | upper_bound_yes_no = 0 39 | upper_bound_other = 0 40 | 41 | for pred, anno in zip(predictions, annotations): 42 | if 
pred['question_id'] == anno['question_id']:
43 |             G_T= max(anno['answer_count'].values())
44 |             upper_bound += min(1, G_T / 3)
45 |             if pred['answer'] in anno['answers_word']:
46 |                 proba = anno['answer_count'][pred['answer']]
47 |                 score += min(1, proba / 3)
48 |                 count +=1
49 |                 if anno['answer_type'] == 'yes/no':
50 |                     yes_no_score += min(1, proba / 3)
51 |                     upper_bound_yes_no += min(1, G_T / 3)
52 |                     yes_count +=1
53 |                 if anno['answer_type'] == 'other':
54 |                     other_score += min(1, proba / 3)
55 |                     upper_bound_other += min(1, G_T / 3)
56 |                     other_count +=1
57 |                 if anno['answer_type'] == 'number':
58 |                     num_score += min(1, proba / 3)
59 |                     upper_bound_num += min(1, G_T / 3)
60 |                     num_count +=1
61 |             else:
62 |                 score += 0
63 |                 yes_no_score +=0
64 |                 other_score +=0
65 |                 num_score +=0
66 |                 if anno['answer_type'] == 'yes/no':
67 |                     upper_bound_yes_no += min(1, G_T / 3)
68 |                     yes_count +=1
69 |                 if anno['answer_type'] == 'other':
70 |                     upper_bound_other += min(1, G_T / 3)
71 |                     other_count +=1
72 |                 if anno['answer_type'] == 'number':
73 |                     upper_bound_num += min(1, G_T / 3)
74 |                     num_count +=1
75 |
76 |
77 |     print('count:', count, ' score:', round(score*100/len(annotations),2))
78 |     print('Yes/No:', round(100*yes_no_score/yes_count,2), 'Num:', round(100*num_score/num_count,2),
79 |           'other:', round(100*other_score/other_count,2))
80 |
81 |     print('count:', len(annotations), ' upper_bound:', round(100*upper_bound/len(annotations),2))
82 |     print('upper_bound_Yes/No:', round(100*upper_bound_yes_no/yes_count,2), 'upper_bound_Num:',
83 |           round(100 * upper_bound_num/num_count,2), 'upper_bound_other:', round(100*upper_bound_other/other_count,2))
84 |
85 |
-------------------------------------------------------------------------------- /data/create_dictionary.py: --------------------------------------------------------------------------------
1 | """
2 | This code is from Hengyuan Hu's repository.
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import argparse 8 | import sys 9 | import json 10 | import _pickle as cPickle 11 | import numpy as np 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | # from dataset import Dictionary 14 | # from utils import get_sent_data 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--task', type=str, default='vqacp2', help='vqacp2 or vqacp1') 20 | parser.add_argument('--dataroot', type=str, default='vqacp2/annotations/') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | class Dictionary(object): 26 | def __init__(self, word2idx=None, idx2word=None): 27 | if word2idx is None: 28 | word2idx = {} 29 | if idx2word is None: 30 | idx2word = [] 31 | self.word2idx = word2idx 32 | self.idx2word = idx2word 33 | 34 | @property 35 | def ntoken(self): 36 | return len(self.word2idx) 37 | 38 | @property 39 | def padding_idx(self): 40 | return len(self.word2idx) 41 | 42 | def tokenize(self, sentence, add_word): 43 | sentence = sentence.lower() 44 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 45 | words = sentence.split() 46 | tokens = [] 47 | if add_word: 48 | for w in words: 49 | tokens.append(self.add_word(w)) 50 | else: 51 | for w in words: 52 | # the least frequent word (`bebe`) as UNK for Visual Genome dataset 53 | tokens.append(self.word2idx.get(w, self.padding_idx - 1)) 54 | return tokens 55 | 56 | def dump_to_file(self, path): 57 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 58 | print('dictionary dumped to %s' % path) 59 | 60 | @classmethod 61 | def load_from_file(cls, path): 62 | print('loading dictionary from %s' % path) 63 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 64 | d = cls(word2idx, idx2word) 65 | return d 66 | 67 | def add_word(self, word): 68 | if word not in self.word2idx: 69 | self.idx2word.append(word) 70 | self.word2idx[word] = len(self.idx2word) - 1 71 | return self.word2idx[word] 72 | 73 | def __len__(self): 74 | return len(self.idx2word) 75 | 76 | def create_dictionary(dataroot, task='vqacp2'): 77 | dictionary = Dictionary() 78 | if task == 'vqacp2': 79 | files = [ 80 | 'vqacp_v2_test_questions.json', 81 | 'vqacp_v2_train_questions.json' 82 | ] 83 | for path in files: 84 | question_path = os.path.join(dataroot, path) 85 | qs = json.load(open(question_path)) 86 | for q in qs: 87 | dictionary.tokenize(q['question'], True) 88 | else: 89 | files = [ 90 | 'vqacp_v1_test_questions.json', 91 | 'vqacp_v1_train_questions.json' 92 | ] 93 | for path in files: 94 | question_path = os.path.join(dataroot, path) 95 | qs = json.load(open(question_path)) 96 | for q in qs: 97 | dictionary.tokenize(q['question'], True) 98 | 99 | return dictionary 100 | 101 | 102 | def create_glove_embedding_init(idx2word, glove_file): 103 | word2emb = {} 104 | with open(glove_file, 'r') as f: 105 | entries = f.readlines() 106 | emb_dim = len(entries[0].split(' ')) - 1 107 | print('embedding dim is %d' % emb_dim) 108 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 109 | 110 | for entry in entries: 111 | vals = entry.split(' ') 112 | word = vals[0] 113 | vals = list(map(float, vals[1:])) 114 | word2emb[word] = np.array(vals) 115 | for idx, word in enumerate(idx2word): 116 | if word not in word2emb: 117 | continue 118 | weights[idx] = word2emb[word] 119 | return weights, word2emb 120 | 121 | 122 | if __name__ == '__main__': 
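    # Usage sketch, assuming the defaults in parse_args() above and the GloVe
    # vectors fetched by data/download_data.sh (presumably run from the data/
    # directory so the relative glove/ path resolves):
    #   python create_dictionary.py --task vqacp2 --dataroot vqacp2/annotations/
    # This builds the question-word dictionary and writes dictionary.pkl and
    # glove6b_init_300d.npy into --dataroot.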
123 | args = parse_args() 124 | 125 | dictionary_path = os.path.join(args.dataroot, 'dictionary.pkl') 126 | 127 | d = create_dictionary(args.dataroot, args.task) 128 | d.dump_to_file(dictionary_path) 129 | 130 | d = Dictionary.load_from_file(dictionary_path) 131 | emb_dim = 300 132 | glove_file = 'glove/glove.6B.%dd.txt' % emb_dim 133 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 134 | np.save(os.path.join(args.dataroot, 'glove6b_init_%dd.npy' % emb_dim), weights) 135 | -------------------------------------------------------------------------------- /data/download_data.sh: -------------------------------------------------------------------------------- 1 | #download glove 2 | wget http://nlp.stanford.edu/data/glove.6B.zip 3 | unzip glove.6B.zip -d glove 4 | rm glove.6B.zip 5 | 6 | #download the image feature 7 | wget -P coco https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip 8 | unzip coco/trainval_36.zip -d coco/ 9 | rm coco/trainval_36.zip 10 | 11 | #download vqacp2 12 | wget -P vqacp2/ https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json 13 | wget -P vqacp2/ https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json 14 | wget -P vqacp2/ https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json 15 | wget -P vqacp2/ https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json 16 | -------------------------------------------------------------------------------- /data/preprocess_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Learning Conditioned Graph Structures for Interpretable Visual Question Answering 3 | Will Norcliffe-Brown and Efstathios Vafeias and Sarah Parisot 4 | https://arxiv.org/abs/1806.07243 5 | This code is written by Will Norcliffe-Brown. 
6 | """ 7 | from __future__ import division 8 | from __future__ import print_function 9 | from __future__ import absolute_import 10 | 11 | import os 12 | import argparse 13 | import base64 14 | import numpy as np 15 | import csv 16 | import sys 17 | import h5py 18 | import pandas as pd 19 | import zarr 20 | from tqdm import tqdm 21 | 22 | 23 | csv.field_size_limit(sys.maxsize) 24 | 25 | 26 | def features_to_zarr(phase): 27 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 28 | 'num_boxes', 'boxes', 'features'] 29 | 30 | if phase == 'trainval': 31 | infiles = [ 32 | 'coco/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv', 33 | ] 34 | elif phase == 'test': 35 | infiles = [ 36 | 'coco/test2015_36/test2015_resnet101_faster_rcnn_genome_36.tsv', 37 | ] 38 | else: 39 | raise SystemExit('Unrecognised phase') 40 | 41 | # Read the tsv and append to files 42 | boxes = zarr.open_group('coco/' + phase + '_boxes.zarr', mode='w') 43 | features = zarr.open_group('coco/' + phase + '.zarr', mode='w') 44 | image_size = {} 45 | for infile in infiles: 46 | with open(infile, "r") as tsv_in_file: 47 | reader = csv.DictReader( 48 | tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) 49 | print('Converting ' + infile + ' to zarr...') 50 | for item in tqdm(reader): 51 | item['image_id'] = str(item['image_id']) 52 | item['image_h'] = int(item['image_h']) 53 | item['image_w'] = int(item['image_w']) 54 | item['num_boxes'] = int(item['num_boxes']) 55 | for field in ['boxes', 'features']: 56 | encoded_str = base64.decodestring( 57 | item[field].encode('utf-8')) 58 | item[field] = np.frombuffer(encoded_str, 59 | dtype=np.float32).reshape((item['num_boxes'], -1)) 60 | # append to zarr files 61 | boxes.create_dataset(item['image_id'], data=item['boxes']) 62 | features.create_dataset(item['image_id'], data=item['features']) 63 | # image_size dict 64 | image_size[item['image_id']] = { 65 | 'image_h':item['image_h'], 66 | 'image_w':item['image_w'], 67 | } 68 | 69 | 70 | # convert dict to pandas dataframe 71 | 72 | 73 | # create image sizes csv 74 | print('Writing image sizes csv...') 75 | df = pd.DataFrame.from_dict(image_size) 76 | df = df.transpose() 77 | d = df.to_dict() 78 | dw = d['image_w'] 79 | dh = d['image_h'] 80 | d = [dw, dh] 81 | dwh = {} 82 | for k in dw.keys(): 83 | dwh[k] = np.array([d0[k] for d0 in d]) 84 | image_sizes = pd.DataFrame(dwh) 85 | image_sizes.to_csv(phase + '_image_size.csv') 86 | 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser( 90 | description='Preprocessing for VQA v2 image data') 91 | parser.add_argument('--data', nargs='+', help='trainval, and/or test, list of data phases to be processed', required=True) 92 | args, unparsed = parser.parse_known_args() 93 | if len(unparsed) != 0: 94 | raise SystemExit('Unknown argument: {}'.format(unparsed)) 95 | 96 | phase_list = args.data 97 | 98 | for phase in phase_list: 99 | # First download and extract 100 | 101 | if not os.path.exists(phase + '.zarr'): 102 | print('Converting features tsv to zarr file...') 103 | features_to_zarr(phase) 104 | 105 | print('Done') 106 | -------------------------------------------------------------------------------- /data/preprocess_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is slightly modified from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import sys 8 | import json 9 | import numpy as np 10 | import re 11 | import _pickle as cPickle 12 | 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | # from dataset import Dictionary 15 | import utils 16 | import argparse 17 | import torch 18 | 19 | 20 | contractions = { 21 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 22 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 23 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 24 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 25 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 26 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 27 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 28 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 29 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 30 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 31 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 32 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 33 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 34 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 35 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 36 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 37 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 38 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 39 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 40 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 41 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 42 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 43 | "someonell": "someone'll", "someones": "someone's", "somethingd": 44 | "something'd", "somethingd've": "something'd've", "something'dve": 45 | "something'd've", "somethingll": "something'll", "thats": 46 | "that's", "thered": "there'd", "thered've": "there'd've", 47 | "there'dve": "there'd've", "therere": "there're", "theres": 48 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 49 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 50 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 51 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 52 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 53 | "what's", "whatve": "what've", "whens": "when's", "whered": 54 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 55 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 56 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 57 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 58 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 59 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 60 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 61 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 62 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 63 | "you'll", "youre": "you're", "youve": "you've" 64 | } 65 | 66 | manual_map = { 'none': '0', 67 | 'zero': '0', 68 | 'one': '1', 69 | 'two': '2', 70 | 'three': '3', 71 | 'four': '4', 72 | 'five': '5', 73 | 'six': '6', 74 | 'seven': '7', 75 | 'eight': '8', 76 | 'nine': '9', 77 | 'ten': 
'10'} 78 | articles = ['a', 'an', 'the'] 79 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 80 | comma_strip = re.compile("(\d)(\,)(\d)") 81 | punct = [';', r"/", '[', ']', '"', '{', '}', 82 | '(', ')', '=', '+', '\\', '_', '-', 83 | '>', '<', '@', '`', ',', '?', '!'] 84 | 85 | # See http://visualqa.org/evaluation.html 86 | def get_score(occurences): 87 | return min(occurences/3, 1) 88 | 89 | 90 | def process_punctuation(inText): 91 | outText = inText 92 | for p in punct: 93 | if (p + ' ' in inText or ' ' + p in inText) \ 94 | or (re.search(comma_strip, inText) != None): 95 | outText = outText.replace(p, '') 96 | else: 97 | outText = outText.replace(p, ' ') 98 | outText = period_strip.sub("", outText, re.UNICODE) 99 | return outText 100 | 101 | 102 | def process_digit_article(inText): 103 | outText = [] 104 | tempText = inText.lower().split() 105 | for word in tempText: 106 | word = manual_map.setdefault(word, word) 107 | if word not in articles: 108 | outText.append(word) 109 | else: 110 | pass 111 | for wordId, word in enumerate(outText): 112 | if word in contractions: 113 | outText[wordId] = contractions[word] 114 | outText = ' '.join(outText) 115 | return outText 116 | 117 | 118 | def multiple_replace(text, wordDict): 119 | for key in wordDict: 120 | text = text.replace(key, wordDict[key]) 121 | return text 122 | 123 | 124 | def preprocess_answer(answer): 125 | answer = process_digit_article(process_punctuation(answer)) 126 | answer = answer.replace(',', '') 127 | return answer 128 | 129 | 130 | def filter_answers(answers_dset, min_occurence): 131 | """This will change the answer to preprocessed version 132 | """ 133 | occurence = {} 134 | 135 | for ans_entry in answers_dset: 136 | answers = ans_entry['answers'] 137 | gtruth = ans_entry['multiple_choice_answer'] 138 | gtruth = preprocess_answer(gtruth) 139 | if gtruth not in occurence: 140 | occurence[gtruth] = set() 141 | occurence[gtruth].add(ans_entry['question_id']) 142 | for answer in list(occurence): 143 | if len(occurence[answer]) < min_occurence: 144 | occurence.pop(answer) 145 | 146 | print('Num of answers that appear >= %d times: %d' % ( 147 | min_occurence, len(occurence))) 148 | return occurence 149 | 150 | 151 | def create_ans2label(occurence, name, cache_root='cache'): 152 | """Note that this will also create label2ans.pkl at the same time 153 | occurence: dict {answer -> whatever} 154 | name: prefix of the output file 155 | cache_root: str 156 | """ 157 | ans2label = {} 158 | label2ans = [] 159 | label = 0 160 | for answer in occurence: 161 | label2ans.append(answer) 162 | ans2label[answer] = label 163 | label += 1 164 | 165 | utils.create_dir(cache_root) 166 | 167 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 168 | cPickle.dump(ans2label, open(cache_file, 'wb')) 169 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 170 | cPickle.dump(label2ans, open(cache_file, 'wb')) 171 | return ans2label 172 | 173 | 174 | def compute_target_for_testing(answers_dset, ans2label, name, cache_root='/cache'): 175 | """compute the score for each answer 176 | Write result into a cache file 177 | """ 178 | target = [] 179 | for ans_entry in answers_dset: 180 | temp = {} 181 | answers = ans_entry['answers'] 182 | answer_count = {} 183 | for answer in answers: 184 | answer_ = answer['answer'] 185 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 186 | 187 | answers_id = [] 188 | scores = [] 189 | answers_word = [] 190 | for answer in answer_count: 191 | if answer not in ans2label: 192 | continue 193 | 
answers_id.append(ans2label[answer])
194 |             answers_word.append(answer)
195 |             scores.append(answer_count[answer])
196 |
197 |         temp['question_id'] = ans_entry['question_id']
198 |         temp['answer_type'] = ans_entry['answer_type']
199 |         temp['answers_word'] = answers_word
200 |         temp['answers_id'] = answers_id
201 |         temp['answers_count'] = scores
202 |         temp['answer_count'] = answer_count
203 |
204 |         target.append(temp)
205 |
206 |     # utils.create_dir(cache_root)
207 |     cache_file = os.path.join(cache_root, name+'_target_count.pth')
208 |     torch.save(target, cache_file)
209 |     return target
210 |
211 | def compute_target(answers_dset, ans2label, name, cache_root='cache'):
212 |     """Augment answers_dset with soft score as label
213 |     ***answers_dset should be preprocessed***
214 |     Write result into a cache file
215 |     """
216 |     target = []
217 |     for ans_entry in answers_dset:
218 |         answers = ans_entry['answers']
219 |         answer_count = {}
220 |         for answer in answers:
221 |             answer_ = answer['answer']
222 |             answer_count[answer_] = answer_count.get(answer_, 0) + 1
223 |
224 |         labels = []
225 |         scores = []
226 |         for answer in answer_count:
227 |             if answer not in ans2label:
228 |                 continue
229 |             labels.append(ans2label[answer])
230 |             score = get_score(answer_count[answer])
231 |             scores.append(score)
232 |
233 |         target.append({
234 |             'question_id': ans_entry['question_id'],
235 |             'image_id': ans_entry['image_id'],
236 |             'labels': labels,
237 |             'scores': scores
238 |         })
239 |
240 |     utils.create_dir(cache_root)
241 |     cache_file = os.path.join(cache_root, name+'_target.pkl')
242 |     cPickle.dump(target, open(cache_file, 'wb'))
243 |     return target
244 |
245 |
246 | def get_answer(qid, answers):
247 |     for ans in answers:
248 |         if ans['question_id'] == qid:
249 |             return ans
250 |
251 |
252 | def get_question(qid, questions):
253 |     for question in questions:
254 |         if question['question_id'] == qid:
255 |             return question
256 |
257 |
258 | def parse_args():
259 |     parser = argparse.ArgumentParser()
260 |     parser.add_argument('--version', type=str, default='v2')
261 |     parser.add_argument('--dataroot', type=str, default='vqacp2/')
262 |
263 |     args = parser.parse_args()
264 |     return args
265 |
266 | if __name__ == '__main__':
267 |     args = parse_args()
268 |
269 |     train_answer_file = args.dataroot + 'vqacp_%s_train_annotations.json' % (args.version)
270 |     train_answers = json.load(open(train_answer_file))
271 |
272 |     val_answer_file = args.dataroot + 'vqacp_%s_test_annotations.json' % (args.version)
273 |     val_answers = json.load(open(val_answer_file))
274 |
275 |     train_question_file = args.dataroot + 'vqacp_%s_train_questions.json' % (args.version)
276 |     train_questions = json.load(open(train_question_file))
277 |
278 |     val_question_file = args.dataroot + 'vqacp_%s_test_questions.json' % (args.version)
279 |     val_questions = json.load(open(val_question_file))
280 |
281 |     answers = train_answers
282 |     occurence = filter_answers(answers, 9)
283 |
284 |     cache_path = args.dataroot + 'cache/train_test_ans2label.pkl'
285 |     if os.path.isfile(cache_path):
286 |         print('found %s' % cache_path)
287 |         ans2label = cPickle.load(open(cache_path, 'rb'))
288 |     else:
289 |         ans2label = create_ans2label(occurence, 'train_test', args.dataroot + 'cache')
290 |
291 |     target_root = args.dataroot +'cache'
292 |     compute_target(train_answers, ans2label, 'train', target_root)
293 |     compute_target(val_answers, ans2label, 'test', target_root)
294 |     compute_target_for_testing(val_answers, ans2label, 'test', target_root)
295 |
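# Usage sketch, assuming the defaults in parse_args() above:
#   python preprocess_text.py --version v2 --dataroot vqacp2/
# On a fresh run this keeps answers occurring at least 9 times in the training
# annotations as the answer vocabulary, then writes train_target.pkl,
# test_target.pkl and test_target_count.pth into <dataroot>/cache;
# test_target_count.pth is the file comput_score.py later loads to score predictions.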
-------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import errno 4 | import os 5 | from PIL import Image 6 | import torch 7 | import torch.nn as nn 8 | import re 9 | 10 | import json 11 | import pickle as cPickle 12 | import numpy as np 13 | import utils 14 | import h5py 15 | import operator 16 | import functools 17 | from torch._six import string_classes 18 | import torch.nn.functional as F 19 | import collections 20 | 21 | #from pycocotools.coco import COCO 22 | # from scipy.sparse import coo_matrix 23 | # from sklearn.metrics.pairwise import cosine_similarity 24 | from torch.utils.data.dataloader import default_collate 25 | 26 | 27 | EPS = 1e-7 28 | 29 | 30 | def assert_eq(real, expected): 31 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 32 | 33 | 34 | def assert_array_eq(real, expected): 35 | assert (np.abs(real-expected) < EPS).all(), \ 36 | '%s (true) vs %s (expected)' % (real, expected) 37 | 38 | 39 | def load_folder(folder, suffix): 40 | imgs = [] 41 | for f in sorted(os.listdir(folder)): 42 | if f.endswith(suffix): 43 | imgs.append(os.path.join(folder, f)) 44 | return imgs 45 | 46 | 47 | def load_imageid(folder): 48 | images = load_folder(folder, 'jpg') 49 | img_ids = set() 50 | for img in images: 51 | img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1]) 52 | img_ids.add(img_id) 53 | return img_ids 54 | 55 | 56 | def pil_loader(path): 57 | with open(path, 'rb') as f: 58 | with Image.open(f) as img: 59 | return img.convert('RGB') 60 | 61 | 62 | def weights_init(m): 63 | """custom weights initialization.""" 64 | cname = m.__class__ 65 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 66 | m.weight.data.normal_(0.0, 0.02) 67 | elif cname == nn.BatchNorm2d: 68 | m.weight.data.normal_(1.0, 0.02) 69 | m.bias.data.fill_(0) 70 | else: 71 | print('%s is not initialized.' 
% cname) 72 | 73 | 74 | def init_net(net, net_file): 75 | if net_file: 76 | net.load_state_dict(torch.load(net_file)) 77 | else: 78 | net.apply(weights_init) 79 | 80 | 81 | def create_dir(path): 82 | if not os.path.exists(path): 83 | try: 84 | os.makedirs(path) 85 | except OSError as exc: 86 | if exc.errno != errno.EEXIST: 87 | raise 88 | 89 | 90 | class Logger(object): 91 | def __init__(self, output_name): 92 | dirname = os.path.dirname(output_name) 93 | if not os.path.exists(dirname): 94 | os.mkdir(dirname) 95 | 96 | self.log_file = open(output_name, 'w') 97 | self.infos = {} 98 | 99 | def append(self, key, val): 100 | vals = self.infos.setdefault(key, []) 101 | vals.append(val) 102 | 103 | def log(self, extra_msg=''): 104 | msgs = [extra_msg] 105 | for key, vals in self.infos.iteritems(): 106 | msgs.append('%s %.6f' % (key, np.mean(vals))) 107 | msg = '\n'.join(msgs) 108 | self.log_file.write(msg + '\n') 109 | self.log_file.flush() 110 | self.infos = {} 111 | return msg 112 | 113 | def write(self, msg): 114 | self.log_file.write(msg + '\n') 115 | self.log_file.flush() 116 | print(msg) 117 | 118 | def print_model(model, logger): 119 | print(model) 120 | nParams = 0 121 | for w in model.parameters(): 122 | nParams += functools.reduce(operator.mul, w.size(), 1) 123 | if logger: 124 | logger.write('nParams=\t'+str(nParams)) 125 | 126 | 127 | def save_model(path, model, epoch, optimizer=None): 128 | model_dict = { 129 | 'epoch': epoch, 130 | 'model_state': model.state_dict() 131 | } 132 | if optimizer is not None: 133 | model_dict['optimizer_state'] = optimizer.state_dict() 134 | 135 | torch.save(model_dict, path) 136 | 137 | def rho_select(pad, lengths): 138 | # Index of the last output for each sequence. 139 | idx_ = (lengths-1).view(-1,1).expand(pad.size(0), pad.size(2)).unsqueeze(1) 140 | extracted = pad.gather(1, idx_).squeeze(1) 141 | return extracted 142 | 143 | def trim_collate(batch): 144 | "Puts each data field into a tensor with outer dimension batch size" 145 | _use_shared_memory = True 146 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 147 | elem_type = type(batch[0]) 148 | if torch.is_tensor(batch[0]): 149 | out = None 150 | if 1 < batch[0].dim(): # image features 151 | max_num_boxes = max([x.size(0) for x in batch]) 152 | if _use_shared_memory: 153 | # If we're in a background process, concatenate directly into a 154 | # shared memory tensor to avoid an extra copy 155 | numel = len(batch) * max_num_boxes * batch[0].size(-1) 156 | storage = batch[0].storage()._new_shared(numel) 157 | out = batch[0].new(storage) 158 | # warning: F.pad returns Variable! 
159 | return torch.stack([F.pad(x, (0,0,0,max_num_boxes-x.size(0))).data for x in batch], 0, out=out) 160 | else: 161 | if _use_shared_memory: 162 | # If we're in a background process, concatenate directly into a 163 | # shared memory tensor to avoid an extra copy 164 | numel = sum([x.numel() for x in batch]) 165 | storage = batch[0].storage()._new_shared(numel) 166 | out = batch[0].new(storage) 167 | return torch.stack(batch, 0, out=out) 168 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 169 | and elem_type.__name__ != 'string_': 170 | elem = batch[0] 171 | if elem_type.__name__ == 'ndarray': 172 | # array of string classes and object 173 | if re.search('[SaUO]', elem.dtype.str) is not None: 174 | raise TypeError(error_msg.format(elem.dtype)) 175 | 176 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 177 | if elem.shape == (): # scalars 178 | py_type = float if elem.dtype.name.startswith('float') else int 179 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 180 | elif isinstance(batch[0], int): 181 | return torch.LongTensor(batch) 182 | elif isinstance(batch[0], float): 183 | return torch.DoubleTensor(batch) 184 | elif isinstance(batch[0], string_classes): 185 | return batch 186 | elif isinstance(batch[0], collections.Mapping): 187 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 188 | elif isinstance(batch[0], collections.Sequence): 189 | transposed = zip(*batch) 190 | return [trim_collate(samples) for samples in transposed] 191 | 192 | raise TypeError((error_msg.format(type(batch[0])))) 193 | 194 | 195 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 196 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 197 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 198 | indices = torch.from_numpy( 199 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 200 | values = torch.from_numpy(sparse_mx.data) 201 | shape = torch.Size(sparse_mx.shape) 202 | return torch.sparse.FloatTensor(indices, values, shape) 203 | 204 | 205 | def mask_softmax(x, lengths): # , dim=1) 206 | mask = torch.zeros_like(x).to(device=x.device, non_blocking=True) 207 | t_lengths = lengths[:, :, None].expand_as(mask) 208 | arange_id = torch.arange(mask.size(1)).to(device=x.device, non_blocking=True) 209 | arange_id = arange_id[None, :, None].expand_as(mask) 210 | 211 | mask[arange_id < t_lengths] = 1 212 | # https://stackoverflow.com/questions/42599498/numercially-stable-softmax 213 | # https://stackoverflow.com/questions/34968722/how-to-implement-the-softmax-function-in-python 214 | # exp(x - max(x)) instead of exp(x) is a trick 215 | # to improve the numerical stability while giving 216 | # the same outputs 217 | x2 = torch.exp(x - torch.max(x)) 218 | x3 = x2 * mask 219 | epsilon = 1e-5 220 | x3_sum = torch.sum(x3, dim=1, keepdim=True) + epsilon 221 | x4 = x3 / x3_sum.expand_as(x3) 222 | return x4 223 | 224 | 225 | class GradReverseMask(torch.autograd.Function): 226 | """ 227 | This layer is used to create an adversarial loss. 228 | """ 229 | 230 | @staticmethod 231 | def forward(ctx, x, mask, weight): 232 | """ 233 | The mask should be composed of 0 or 1. 234 | The '1' will get their gradient reversed.. 
235 | """ 236 | ctx.save_for_backward(mask) 237 | ctx.weight = weight 238 | return x.view_as(x) 239 | 240 | @staticmethod 241 | def backward(ctx, grad_output): 242 | mask, = ctx.saved_tensors 243 | mask_c = mask.clone().detach().float() 244 | mask_c[mask == 0] = 1.0 245 | mask_c[mask == 1] = - float(ctx.weight) 246 | return grad_output * mask_c[:, None].float(), None, None 247 | 248 | 249 | def grad_reverse_mask(x, mask, weight=1): 250 | return GradReverseMask.apply(x, mask, weight) 251 | 252 | 253 | class GradReverse(torch.autograd.Function): 254 | """ 255 | This layer is used to create an adversarial loss. 256 | """ 257 | 258 | @staticmethod 259 | def forward(ctx, x): 260 | return x.view_as(x) 261 | 262 | @staticmethod 263 | def backward(ctx, grad_output): 264 | return grad_output.neg() 265 | 266 | 267 | def grad_reverse(x): 268 | return GradReverse.apply(x) 269 | 270 | 271 | class GradMulConst(torch.autograd.Function): 272 | """ 273 | This layer is used to create an adversarial loss. 274 | """ 275 | 276 | @staticmethod 277 | def forward(ctx, x, const): 278 | ctx.const = const 279 | return x.view_as(x) 280 | 281 | @staticmethod 282 | def backward(ctx, grad_output): 283 | return grad_output * ctx.const, None 284 | 285 | 286 | def grad_mul_const(x, const): 287 | return GradMulConst.apply(x, const) 288 | -------------------------------------------------------------------------------- /data4VE/offline-QTD_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhoebusSi/SAR/a1934bf6b728edf7149aa8e7fa69233167e512dc/data4VE/offline-QTD_model.pth -------------------------------------------------------------------------------- /data4VE/test_dataset4VE_demo.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question_id": 9001, 4 | "image_id": 9, 5 | "top20": [ 6 | 985, 7 | 661, 8 | 1031, 9 | 240, 10 | 1982, 11 | 2043, 12 | 69, 13 | 714, 14 | 1015, 15 | 1110, 16 | 1286, 17 | 1794, 18 | 977, 19 | 288, 20 | 1939, 21 | 787, 22 | 281, 23 | 170, 24 | 921, 25 | 1997 26 | ], 27 | "top20_scores": [ 28 | 0.0, 29 | 0.0, 30 | 0.0, 31 | 0.0, 32 | 0.0, 33 | 0.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 0.0, 38 | 0.0, 39 | 0.0, 40 | 0.0, 41 | 0.0, 42 | 0.0, 43 | 0.0, 44 | 0.0, 45 | 0.0, 46 | 0.0, 47 | 0.0 48 | ] 49 | }, 50 | { 51 | "question_id": 25004, 52 | "image_id": 25, 53 | "top20": [ 54 | 574, 55 | 410, 56 | 840, 57 | 4, 58 | 745, 59 | 1413, 60 | 356, 61 | 624, 62 | 180, 63 | 1367, 64 | 576, 65 | 53, 66 | 1611, 67 | 261, 68 | 1773, 69 | 867, 70 | 1122, 71 | 1631, 72 | 674, 73 | 1648 74 | ], 75 | "top20_scores": [ 76 | 0.3333333432674408, 77 | 0.3333333432674408, 78 | 0.0, 79 | 0.3333333432674408, 80 | 0.0, 81 | 0.0, 82 | 0.0, 83 | 0.0, 84 | 0.3333333432674408, 85 | 0.0, 86 | 0.0, 87 | 0.0, 88 | 0.0, 89 | 0.0, 90 | 0.0, 91 | 0.0, 92 | 0.0, 93 | 0.0, 94 | 0.0, 95 | 0.0 96 | ] 97 | }, 98 | { 99 | "question_id": 25008, 100 | "image_id": 25, 101 | "top20": [ 102 | 577, 103 | 180, 104 | 436, 105 | 549, 106 | 1465, 107 | 956, 108 | 384, 109 | 1085, 110 | 624, 111 | 119, 112 | 1154, 113 | 674, 114 | 1122, 115 | 731, 116 | 1843, 117 | 306, 118 | 940, 119 | 576, 120 | 1141, 121 | 1611 122 | ], 123 | "top20_scores": [ 124 | 1.0, 125 | 1.0, 126 | 0.0, 127 | 0.3333333432674408, 128 | 0.0, 129 | 0.0, 130 | 0.0, 131 | 0.0, 132 | 0.0, 133 | 0.0, 134 | 0.0, 135 | 0.0, 136 | 0.0, 137 | 0.0, 138 | 0.0, 139 | 0.0, 140 | 0.0, 141 | 0.0, 142 | 0.0, 143 | 0.0 144 | ] 145 | }, 146 | { 147 | "question_id": 25011, 
148 | "image_id": 25, 149 | "top20": [ 150 | 180, 151 | 577, 152 | 1465, 153 | 549, 154 | 1085, 155 | 436, 156 | 384, 157 | 119, 158 | 956, 159 | 1154, 160 | 1391, 161 | 1141, 162 | 661, 163 | 365, 164 | 1737, 165 | 1359, 166 | 1509, 167 | 547, 168 | 484, 169 | 634 170 | ], 171 | "top20_scores": [ 172 | 1.0, 173 | 0.3333333432674408, 174 | 0.0, 175 | 0.0, 176 | 0.0, 177 | 0.0, 178 | 0.0, 179 | 0.0, 180 | 0.0, 181 | 0.0, 182 | 0.0, 183 | 0.0, 184 | 0.0, 185 | 0.0, 186 | 0.0, 187 | 0.0, 188 | 0.0, 189 | 0.0, 190 | 0.0, 191 | 0.0 192 | ] 193 | }, 194 | { 195 | "question_id": 25012, 196 | "image_id": 25, 197 | "top20": [ 198 | 577, 199 | 180, 200 | 391, 201 | 1085, 202 | 576, 203 | 867, 204 | 651, 205 | 956, 206 | 484, 207 | 549, 208 | 1990, 209 | 916, 210 | 119, 211 | 384, 212 | 1465, 213 | 900, 214 | 436, 215 | 290, 216 | 323, 217 | 940 218 | ], 219 | "top20_scores": [ 220 | 0.0, 221 | 1.0, 222 | 0.0, 223 | 0.0, 224 | 0.0, 225 | 0.0, 226 | 0.0, 227 | 0.0, 228 | 0.0, 229 | 0.0, 230 | 0.0, 231 | 0.0, 232 | 0.0, 233 | 0.0, 234 | 0.0, 235 | 0.0, 236 | 0.3333333432674408, 237 | 0.0, 238 | 0.0, 239 | 0.0 240 | ] 241 | }, 242 | { 243 | "question_id": 30001, 244 | "image_id": 30, 245 | "top20": [ 246 | 1271, 247 | 1725, 248 | 804, 249 | 337, 250 | 936, 251 | 1642, 252 | 1417, 253 | 1009, 254 | 655, 255 | 508, 256 | 221, 257 | 1092, 258 | 262, 259 | 1065, 260 | 81, 261 | 2040, 262 | 388, 263 | 624, 264 | 1196, 265 | 99 266 | ], 267 | "top20_scores": [ 268 | 0.0, 269 | 0.0, 270 | 0.0, 271 | 0.0, 272 | 0.0, 273 | 0.0, 274 | 0.0, 275 | 0.0, 276 | 0.0, 277 | 0.0, 278 | 0.3333333432674408, 279 | 0.0, 280 | 0.0, 281 | 0.0, 282 | 1.0, 283 | 0.0, 284 | 0.0, 285 | 0.0, 286 | 0.0, 287 | 0.0 288 | ] 289 | }, 290 | { 291 | "question_id": 30002, 292 | "image_id": 30, 293 | "top20": [ 294 | 1122, 295 | 674, 296 | 180, 297 | 450, 298 | 577, 299 | 654, 300 | 2228, 301 | 1337, 302 | 2085, 303 | 890, 304 | 676, 305 | 427, 306 | 759, 307 | 303, 308 | 845, 309 | 2110, 310 | 29, 311 | 384, 312 | 702, 313 | 1173 314 | ], 315 | "top20_scores": [ 316 | 1.0, 317 | 0.6666666865348816, 318 | 0.0, 319 | 0.0, 320 | 0.3333333432674408, 321 | 0.0, 322 | 0.0, 323 | 0.0, 324 | 0.0, 325 | 0.0, 326 | 0.0, 327 | 0.0, 328 | 0.0, 329 | 0.0, 330 | 0.0, 331 | 0.0, 332 | 0.0, 333 | 0.0, 334 | 0.0, 335 | 0.0 336 | ] 337 | }, 338 | { 339 | "question_id": 30005, 340 | "image_id": 30, 341 | "top20": [ 342 | 180, 343 | 577, 344 | 772, 345 | 985, 346 | 240, 347 | 549, 348 | 436, 349 | 384, 350 | 1465, 351 | 1031, 352 | 661, 353 | 119, 354 | 233, 355 | 1141, 356 | 170, 357 | 940, 358 | 956, 359 | 1154, 360 | 1737, 361 | 365 362 | ], 363 | "top20_scores": [ 364 | 0.3333333432674408, 365 | 1.0, 366 | 0.0, 367 | 0.0, 368 | 0.0, 369 | 0.0, 370 | 0.0, 371 | 0.0, 372 | 0.0, 373 | 0.0, 374 | 0.0, 375 | 0.0, 376 | 0.0, 377 | 0.0, 378 | 0.0, 379 | 0.0, 380 | 0.0, 381 | 0.0, 382 | 0.0, 383 | 0.0 384 | ] 385 | }, 386 | { 387 | "question_id": 34001, 388 | "image_id": 34, 389 | "top20": [ 390 | 576, 391 | 1151, 392 | 651, 393 | 349, 394 | 916, 395 | 946, 396 | 180, 397 | 290, 398 | 391, 399 | 376, 400 | 1209, 401 | 659, 402 | 444, 403 | 624, 404 | 986, 405 | 1004, 406 | 647, 407 | 591, 408 | 1732, 409 | 363 410 | ], 411 | "top20_scores": [ 412 | 1.0, 413 | 0.0, 414 | 0.0, 415 | 0.0, 416 | 0.0, 417 | 0.0, 418 | 0.0, 419 | 0.0, 420 | 0.0, 421 | 0.0, 422 | 0.0, 423 | 0.0, 424 | 0.0, 425 | 0.0, 426 | 0.0, 427 | 0.0, 428 | 0.0, 429 | 0.0, 430 | 0.0, 431 | 0.0 432 | ] 433 | }, 434 | { 435 | "question_id": 34002, 436 | "image_id": 34, 437 | "top20": [ 438 | 229, 439 
| 1929, 440 | 469, 441 | 667, 442 | 1465, 443 | 549, 444 | 2230, 445 | 1391, 446 | 436, 447 | 468, 448 | 1154, 449 | 2148, 450 | 1861, 451 | 384, 452 | 1010, 453 | 662, 454 | 1141, 455 | 1737, 456 | 446, 457 | 718 458 | ], 459 | "top20_scores": [ 460 | 0.3333333432674408, 461 | 0.0, 462 | 0.3333333432674408, 463 | 0.0, 464 | 0.0, 465 | 0.6666666865348816, 466 | 0.3333333432674408, 467 | 0.0, 468 | 0.0, 469 | 0.0, 470 | 0.0, 471 | 0.0, 472 | 0.0, 473 | 0.0, 474 | 0.0, 475 | 0.0, 476 | 0.0, 477 | 0.0, 478 | 0.0, 479 | 0.3333333432674408 480 | ] 481 | }, 482 | { 483 | "question_id": 42000, 484 | "image_id": 42, 485 | "top20": [ 486 | 987, 487 | 407, 488 | 235, 489 | 1359, 490 | 772, 491 | 1982, 492 | 1662, 493 | 1031, 494 | 714, 495 | 240, 496 | 806, 497 | 1846, 498 | 350, 499 | 1109, 500 | 905, 501 | 825, 502 | 1683, 503 | 1953, 504 | 1956, 505 | 1071 506 | ], 507 | "top20_scores": [ 508 | 0.0, 509 | 0.0, 510 | 0.0, 511 | 0.0, 512 | 1.0, 513 | 0.0, 514 | 0.3333333432674408, 515 | 0.0, 516 | 0.0, 517 | 0.0, 518 | 0.0, 519 | 0.0, 520 | 0.0, 521 | 0.0, 522 | 0.0, 523 | 0.0, 524 | 0.0, 525 | 0.0, 526 | 0.0, 527 | 0.0 528 | ] 529 | }, 530 | { 531 | "question_id": 42002, 532 | "image_id": 42, 533 | "top20": [ 534 | 235, 535 | 240, 536 | 1359, 537 | 987, 538 | 1211, 539 | 1499, 540 | 772, 541 | 714, 542 | 806, 543 | 1683, 544 | 407, 545 | 1031, 546 | 250, 547 | 661, 548 | 170, 549 | 1982, 550 | 1510, 551 | 1109, 552 | 853, 553 | 825 554 | ], 555 | "top20_scores": [ 556 | 0.3333333432674408, 557 | 1.0, 558 | 0.0, 559 | 0.0, 560 | 0.0, 561 | 0.0, 562 | 0.0, 563 | 0.0, 564 | 0.0, 565 | 0.0, 566 | 0.0, 567 | 0.0, 568 | 0.0, 569 | 0.0, 570 | 0.0, 571 | 0.0, 572 | 0.0, 573 | 0.0, 574 | 0.0, 575 | 0.0 576 | ] 577 | }, 578 | { 579 | "question_id": 49001, 580 | "image_id": 49, 581 | "top20": [ 582 | 772, 583 | 1359, 584 | 407, 585 | 235, 586 | 1662, 587 | 987, 588 | 1109, 589 | 1846, 590 | 1071, 591 | 985, 592 | 350, 593 | 806, 594 | 714, 595 | 240, 596 | 2134, 597 | 2061, 598 | 577, 599 | 1956, 600 | 905, 601 | 1705 602 | ], 603 | "top20_scores": [ 604 | 0.0, 605 | 0.0, 606 | 0.0, 607 | 1.0, 608 | 0.0, 609 | 0.0, 610 | 0.0, 611 | 0.0, 612 | 0.0, 613 | 0.0, 614 | 0.0, 615 | 0.0, 616 | 0.0, 617 | 0.0, 618 | 0.0, 619 | 0.0, 620 | 0.0, 621 | 0.0, 622 | 0.0, 623 | 0.0 624 | ] 625 | }, 626 | { 627 | "question_id": 61003, 628 | "image_id": 61, 629 | "top20": [ 630 | 772, 631 | 1031, 632 | 714, 633 | 240, 634 | 235, 635 | 987, 636 | 985, 637 | 1109, 638 | 1982, 639 | 1359, 640 | 661, 641 | 1683, 642 | 350, 643 | 185, 644 | 806, 645 | 384, 646 | 170, 647 | 2061, 648 | 1514, 649 | 1066 650 | ], 651 | "top20_scores": [ 652 | 0.3333333432674408, 653 | 0.3333333432674408, 654 | 0.0, 655 | 1.0, 656 | 0.0, 657 | 0.0, 658 | 0.0, 659 | 0.0, 660 | 0.0, 661 | 0.0, 662 | 0.0, 663 | 0.0, 664 | 0.0, 665 | 0.0, 666 | 0.0, 667 | 0.0, 668 | 0.0, 669 | 0.0, 670 | 0.0, 671 | 0.0 672 | ] 673 | }, 674 | { 675 | "question_id": 64003, 676 | "image_id": 64, 677 | "top20": [ 678 | 377, 679 | 749, 680 | 1083, 681 | 1526, 682 | 1457, 683 | 421, 684 | 773, 685 | 1804, 686 | 1597, 687 | 803, 688 | 198, 689 | 243, 690 | 864, 691 | 76, 692 | 978, 693 | 1448, 694 | 1530, 695 | 744, 696 | 283, 697 | 93 698 | ], 699 | "top20_scores": [ 700 | 1.0, 701 | 0.0, 702 | 0.0, 703 | 0.0, 704 | 0.0, 705 | 0.0, 706 | 0.0, 707 | 0.0, 708 | 0.0, 709 | 0.0, 710 | 0.0, 711 | 0.0, 712 | 0.0, 713 | 0.0, 714 | 0.0, 715 | 0.0, 716 | 0.0, 717 | 0.0, 718 | 0.0, 719 | 0.0 720 | ] 721 | }, 722 | { 723 | "question_id": 71002, 724 | "image_id": 71, 725 | "top20": [ 726 | 
1094, 727 | 180, 728 | 1136, 729 | 2145, 730 | 931, 731 | 1671, 732 | 303, 733 | 577, 734 | 1626, 735 | 384, 736 | 676, 737 | 854, 738 | 1482, 739 | 450, 740 | 487, 741 | 1856, 742 | 1877, 743 | 1187, 744 | 119, 745 | 1122 746 | ], 747 | "top20_scores": [ 748 | 0.0, 749 | 0.0, 750 | 0.3333333432674408, 751 | 0.0, 752 | 0.0, 753 | 0.0, 754 | 0.0, 755 | 0.0, 756 | 0.0, 757 | 0.0, 758 | 0.0, 759 | 0.0, 760 | 0.0, 761 | 0.0, 762 | 0.0, 763 | 0.0, 764 | 0.0, 765 | 0.0, 766 | 0.0, 767 | 0.0 768 | ] 769 | }, 770 | { 771 | "question_id": 72001, 772 | "image_id": 72, 773 | "top20": [ 774 | 391, 775 | 576, 776 | 916, 777 | 651, 778 | 867, 779 | 591, 780 | 1085, 781 | 1990, 782 | 1151, 783 | 323, 784 | 290, 785 | 1854, 786 | 659, 787 | 180, 788 | 2107, 789 | 272, 790 | 1414, 791 | 647, 792 | 577, 793 | 80 794 | ], 795 | "top20_scores": [ 796 | 1.0, 797 | 0.3333333432674408, 798 | 0.0, 799 | 0.0, 800 | 0.0, 801 | 0.0, 802 | 0.0, 803 | 0.0, 804 | 0.0, 805 | 0.0, 806 | 0.0, 807 | 0.0, 808 | 0.0, 809 | 0.0, 810 | 0.0, 811 | 0.0, 812 | 0.0, 813 | 0.0, 814 | 0.0, 815 | 0.0 816 | ] 817 | }, 818 | { 819 | "question_id": 73003, 820 | "image_id": 73, 821 | "top20": [ 822 | 1188, 823 | 1745, 824 | 1872, 825 | 783, 826 | 2158, 827 | 1348, 828 | 2221, 829 | 2140, 830 | 2161, 831 | 649, 832 | 1580, 833 | 1260, 834 | 731, 835 | 956, 836 | 2257, 837 | 384, 838 | 2052, 839 | 436, 840 | 1154, 841 | 1465 842 | ], 843 | "top20_scores": [ 844 | 0.0, 845 | 0.0, 846 | 0.0, 847 | 0.0, 848 | 0.0, 849 | 0.0, 850 | 0.0, 851 | 0.0, 852 | 0.0, 853 | 0.0, 854 | 0.0, 855 | 0.0, 856 | 0.0, 857 | 0.0, 858 | 0.0, 859 | 0.0, 860 | 0.0, 861 | 0.0, 862 | 0.0, 863 | 0.0 864 | ] 865 | }, 866 | { 867 | "question_id": 74000, 868 | "image_id": 74, 869 | "top20": [ 870 | 577, 871 | 180, 872 | 384, 873 | 119, 874 | 1141, 875 | 940, 876 | 436, 877 | 549, 878 | 365, 879 | 1737, 880 | 956, 881 | 731, 882 | 1454, 883 | 1465, 884 | 651, 885 | 714, 886 | 1032, 887 | 772, 888 | 1359, 889 | 240 890 | ], 891 | "top20_scores": [ 892 | 1.0, 893 | 1.0, 894 | 0.0, 895 | 0.0, 896 | 0.0, 897 | 0.0, 898 | 0.0, 899 | 0.0, 900 | 0.0, 901 | 0.0, 902 | 0.0, 903 | 0.0, 904 | 0.0, 905 | 0.0, 906 | 0.0, 907 | 0.0, 908 | 0.0, 909 | 0.0, 910 | 0.0, 911 | 0.0 912 | ] 913 | }, 914 | { 915 | "question_id": 74001, 916 | "image_id": 74, 917 | "top20": [ 918 | 93, 919 | 624, 920 | 198, 921 | 1467, 922 | 800, 923 | 1932, 924 | 211, 925 | 99, 926 | 1611, 927 | 262, 928 | 357, 929 | 596, 930 | 180, 931 | 686, 932 | 4, 933 | 1175, 934 | 1065, 935 | 1795, 936 | 888, 937 | 1380 938 | ], 939 | "top20_scores": [ 940 | 1.0, 941 | 0.3333333432674408, 942 | 0.3333333432674408, 943 | 0.0, 944 | 0.0, 945 | 0.0, 946 | 0.0, 947 | 0.0, 948 | 0.0, 949 | 0.0, 950 | 0.0, 951 | 0.0, 952 | 0.3333333432674408, 953 | 0.0, 954 | 0.3333333432674408, 955 | 0.0, 956 | 0.0, 957 | 0.0, 958 | 0.0, 959 | 0.0 960 | ] 961 | }, 962 | { 963 | "question_id": 74002, 964 | "image_id": 74, 965 | "top20": [ 966 | 653, 967 | 1736, 968 | 1320, 969 | 237, 970 | 354, 971 | 306, 972 | 123, 973 | 1217, 974 | 411, 975 | 395, 976 | 224, 977 | 1351, 978 | 1915, 979 | 1102, 980 | 651, 981 | 298, 982 | 1178, 983 | 630, 984 | 1175, 985 | 197 986 | ], 987 | "top20_scores": [ 988 | 0.6666666865348816, 989 | 0.0, 990 | 0.0, 991 | 0.3333333432674408, 992 | 1.0, 993 | 0.0, 994 | 0.0, 995 | 0.0, 996 | 0.0, 997 | 0.0, 998 | 0.0, 999 | 0.0, 1000 | 0.0, 1001 | 0.0, 1002 | 0.0, 1003 | 0.0, 1004 | 0.0, 1005 | 0.0, 1006 | 0.0, 1007 | 0.0 1008 | ] 1009 | }, 1010 | { 1011 | "question_id": 77000, 1012 | "image_id": 77, 1013 | "top20": [ 
1014 | 180, 1015 | 577, 1016 | 384, 1017 | 1, 1018 | 119, 1019 | 1141, 1020 | 732, 1021 | 888, 1022 | 436, 1023 | 1737, 1024 | 940, 1025 | 549, 1026 | 365, 1027 | 1465, 1028 | 1026, 1029 | 48, 1030 | 846, 1031 | 1154, 1032 | 240, 1033 | 1391 1034 | ], 1035 | "top20_scores": [ 1036 | 1.0, 1037 | 0.3333333432674408, 1038 | 0.0, 1039 | 0.0, 1040 | 0.0, 1041 | 0.0, 1042 | 0.0, 1043 | 0.0, 1044 | 0.0, 1045 | 0.0, 1046 | 0.0, 1047 | 0.0, 1048 | 0.0, 1049 | 0.0, 1050 | 0.0, 1051 | 0.0, 1052 | 0.0, 1053 | 0.0, 1054 | 0.0, 1055 | 0.0 1056 | ] 1057 | }, 1058 | { 1059 | "question_id": 81000, 1060 | "image_id": 81, 1061 | "top20": [ 1062 | 180, 1063 | 577, 1064 | 384, 1065 | 119, 1066 | 1141, 1067 | 1737, 1068 | 940, 1069 | 365, 1070 | 549, 1071 | 772, 1072 | 436, 1073 | 1465, 1074 | 956, 1075 | 2002, 1076 | 714, 1077 | 1391, 1078 | 1154, 1079 | 1032, 1080 | 731, 1081 | 1454 1082 | ], 1083 | "top20_scores": [ 1084 | 0.6666666865348816, 1085 | 1.0, 1086 | 0.0, 1087 | 0.0, 1088 | 0.0, 1089 | 0.0, 1090 | 0.0, 1091 | 0.0, 1092 | 0.0, 1093 | 0.0, 1094 | 0.0, 1095 | 0.3333333432674408, 1096 | 0.0, 1097 | 0.0, 1098 | 0.0, 1099 | 0.0, 1100 | 0.0, 1101 | 0.0, 1102 | 0.0, 1103 | 0.0 1104 | ] 1105 | }, 1106 | { 1107 | "question_id": 81003, 1108 | "image_id": 81, 1109 | "top20": [ 1110 | 2140, 1111 | 783, 1112 | 1740, 1113 | 1188, 1114 | 2257, 1115 | 2161, 1116 | 2221, 1117 | 2205, 1118 | 649, 1119 | 384, 1120 | 1260, 1121 | 2158, 1122 | 1745, 1123 | 940, 1124 | 956, 1125 | 1872, 1126 | 180, 1127 | 731, 1128 | 119, 1129 | 586 1130 | ], 1131 | "top20_scores": [ 1132 | 0.0, 1133 | 0.0, 1134 | 0.0, 1135 | 0.0, 1136 | 0.0, 1137 | 1.0, 1138 | 0.0, 1139 | 0.0, 1140 | 0.0, 1141 | 0.0, 1142 | 0.0, 1143 | 0.0, 1144 | 0.0, 1145 | 0.0, 1146 | 0.0, 1147 | 0.0, 1148 | 0.0, 1149 | 0.0, 1150 | 0.0, 1151 | 0.0 1152 | ] 1153 | }, 1154 | { 1155 | "question_id": 86002, 1156 | "image_id": 86, 1157 | "top20": [ 1158 | 752, 1159 | 1973, 1160 | 1753, 1161 | 1402, 1162 | 180, 1163 | 143, 1164 | 704, 1165 | 1640, 1166 | 896, 1167 | 531, 1168 | 1661, 1169 | 651, 1170 | 551, 1171 | 63, 1172 | 490, 1173 | 642, 1174 | 1027, 1175 | 664, 1176 | 955, 1177 | 1133 1178 | ], 1179 | "top20_scores": [ 1180 | 1.0, 1181 | 0.3333333432674408, 1182 | 0.0, 1183 | 0.0, 1184 | 0.0, 1185 | 1.0, 1186 | 0.0, 1187 | 0.0, 1188 | 0.0, 1189 | 0.0, 1190 | 0.0, 1191 | 0.0, 1192 | 0.0, 1193 | 0.0, 1194 | 0.0, 1195 | 0.0, 1196 | 0.0, 1197 | 0.0, 1198 | 0.0, 1199 | 0.0 1200 | ] 1201 | }, 1202 | { 1203 | "question_id": 92002, 1204 | "image_id": 92, 1205 | "top20": [ 1206 | 1154, 1207 | 1465, 1208 | 549, 1209 | 436, 1210 | 1391, 1211 | 1509, 1212 | 1461, 1213 | 1861, 1214 | 2, 1215 | 956, 1216 | 174, 1217 | 634, 1218 | 1779, 1219 | 2093, 1220 | 547, 1221 | 718, 1222 | 592, 1223 | 384, 1224 | 30, 1225 | 559 1226 | ], 1227 | "top20_scores": [ 1228 | 0.0, 1229 | 0.3333333432674408, 1230 | 0.0, 1231 | 1.0, 1232 | 0.0, 1233 | 0.0, 1234 | 0.0, 1235 | 0.0, 1236 | 0.0, 1237 | 0.0, 1238 | 0.0, 1239 | 0.0, 1240 | 0.0, 1241 | 0.0, 1242 | 0.0, 1243 | 0.0, 1244 | 0.0, 1245 | 0.0, 1246 | 0.0, 1247 | 0.0 1248 | ] 1249 | } 1250 | ] 1251 | -------------------------------------------------------------------------------- /data4VE/train_dataset4VE_demo.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question_id": 9000, 4 | "image_id": 9, 5 | "top20": [ 6 | 1465, 7 | 549, 8 | 436, 9 | 1154, 10 | 1391, 11 | 1509, 12 | 2, 13 | 956, 14 | 1461, 15 | 1779, 16 | 1861, 17 | 592, 18 | 174, 19 | 180, 20 | 2093, 21 | 577, 22 | 547, 23 | 289, 24 
| 30, 25 | 1001 26 | ], 27 | "top20_scores": [ 28 | 0.0, 29 | 1.0, 30 | 0.0, 31 | 0.0, 32 | 0.0, 33 | 0.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 0.0, 38 | 0.0, 39 | 0.0, 40 | 0.0, 41 | 0.0, 42 | 0.0, 43 | 0.0, 44 | 0.0, 45 | 0.0, 46 | 0.0, 47 | 0.0 48 | ] 49 | }, 50 | { 51 | "question_id": 9002, 52 | "image_id": 9, 53 | "top20": [ 54 | 611, 55 | 225, 56 | 480, 57 | 1801, 58 | 445, 59 | 1270, 60 | 378, 61 | 73, 62 | 1800, 63 | 423, 64 | 170, 65 | 1968, 66 | 180, 67 | 659, 68 | 1209, 69 | 1744, 70 | 1912, 71 | 1309, 72 | 1632, 73 | 533 74 | ], 75 | "top20_scores": [ 76 | 1.0, 77 | 0.0, 78 | 0.0, 79 | 0.0, 80 | 0.0, 81 | 0.0, 82 | 0.0, 83 | 0.0, 84 | 0.0, 85 | 0.0, 86 | 0.0, 87 | 0.0, 88 | 0.0, 89 | 0.0, 90 | 0.0, 91 | 0.0, 92 | 0.0, 93 | 0.0, 94 | 0.0, 95 | 0.0 96 | ] 97 | }, 98 | { 99 | "question_id": 25000, 100 | "image_id": 25, 101 | "top20": [ 102 | 867, 103 | 323, 104 | 2107, 105 | 1414, 106 | 1990, 107 | 1743, 108 | 391, 109 | 647, 110 | 576, 111 | 651, 112 | 745, 113 | 840, 114 | 957, 115 | 916, 116 | 1006, 117 | 1279, 118 | 1073, 119 | 1196, 120 | 283, 121 | 591 122 | ], 123 | "top20_scores": [ 124 | 1.0, 125 | 0.0, 126 | 0.0, 127 | 0.0, 128 | 0.6666666865348816, 129 | 0.0, 130 | 0.0, 131 | 0.0, 132 | 0.0, 133 | 0.0, 134 | 0.0, 135 | 0.0, 136 | 0.0, 137 | 0.0, 138 | 0.0, 139 | 0.0, 140 | 0.0, 141 | 0.0, 142 | 0.0, 143 | 0.0 144 | ] 145 | }, 146 | { 147 | "question_id": 25001, 148 | "image_id": 25, 149 | "top20": [ 150 | 900, 151 | 952, 152 | 1085, 153 | 916, 154 | 574, 155 | 484, 156 | 1394, 157 | 1306, 158 | 1990, 159 | 180, 160 | 356, 161 | 1277, 162 | 1006, 163 | 391, 164 | 1713, 165 | 448, 166 | 123, 167 | 840, 168 | 468, 169 | 908 170 | ], 171 | "top20_scores": [ 172 | 1.0, 173 | 0.0, 174 | 0.0, 175 | 0.0, 176 | 0.0, 177 | 0.0, 178 | 0.0, 179 | 0.0, 180 | 0.0, 181 | 0.0, 182 | 0.0, 183 | 0.0, 184 | 0.0, 185 | 0.0, 186 | 0.0, 187 | 0.0, 188 | 0.0, 189 | 0.0, 190 | 0.0, 191 | 0.0 192 | ] 193 | }, 194 | { 195 | "question_id": 25002, 196 | "image_id": 25, 197 | "top20": [ 198 | 577, 199 | 180, 200 | 1085, 201 | 384, 202 | 119, 203 | 1141, 204 | 549, 205 | 1737, 206 | 365, 207 | 940, 208 | 1465, 209 | 574, 210 | 349, 211 | 916, 212 | 867, 213 | 484, 214 | 436, 215 | 887, 216 | 1990, 217 | 900 218 | ], 219 | "top20_scores": [ 220 | 0.0, 221 | 1.0, 222 | 0.0, 223 | 0.0, 224 | 0.0, 225 | 0.0, 226 | 0.0, 227 | 0.0, 228 | 0.0, 229 | 0.0, 230 | 0.0, 231 | 0.0, 232 | 0.0, 233 | 0.0, 234 | 0.0, 235 | 0.0, 236 | 0.0, 237 | 0.0, 238 | 0.0, 239 | 0.0 240 | ] 241 | }, 242 | { 243 | "question_id": 25003, 244 | "image_id": 25, 245 | "top20": [ 246 | 180, 247 | 577, 248 | 391, 249 | 867, 250 | 1990, 251 | 576, 252 | 323, 253 | 1085, 254 | 651, 255 | 1414, 256 | 2107, 257 | 1743, 258 | 647, 259 | 916, 260 | 290, 261 | 1009, 262 | 840, 263 | 1800, 264 | 1006, 265 | 1279 266 | ], 267 | "top20_scores": [ 268 | 1.0, 269 | 0.0, 270 | 0.0, 271 | 0.0, 272 | 0.0, 273 | 0.0, 274 | 0.0, 275 | 0.0, 276 | 0.0, 277 | 0.0, 278 | 0.0, 279 | 0.0, 280 | 0.0, 281 | 0.0, 282 | 0.0, 283 | 0.0, 284 | 0.0, 285 | 0.0, 286 | 0.0, 287 | 0.0 288 | ] 289 | }, 290 | { 291 | "question_id": 25005, 292 | "image_id": 25, 293 | "top20": [ 294 | 180, 295 | 577, 296 | 1085, 297 | 436, 298 | 549, 299 | 1465, 300 | 956, 301 | 384, 302 | 1122, 303 | 119, 304 | 674, 305 | 1843, 306 | 1154, 307 | 1141, 308 | 1391, 309 | 484, 310 | 221, 311 | 53, 312 | 940, 313 | 4 314 | ], 315 | "top20_scores": [ 316 | 0.0, 317 | 1.0, 318 | 0.0, 319 | 0.0, 320 | 0.0, 321 | 0.0, 322 | 0.0, 323 | 0.0, 324 | 0.0, 325 | 0.0, 326 | 0.0, 327 | 0.0, 328 | 0.0, 329 
| 0.0, 330 | 0.0, 331 | 0.0, 332 | 0.0, 333 | 0.0, 334 | 0.0, 335 | 0.0 336 | ] 337 | }, 338 | { 339 | "question_id": 25006, 340 | "image_id": 25, 341 | "top20": [ 342 | 867, 343 | 1743, 344 | 1990, 345 | 1414, 346 | 323, 347 | 2107, 348 | 1006, 349 | 1132, 350 | 745, 351 | 840, 352 | 391, 353 | 647, 354 | 957, 355 | 1009, 356 | 766, 357 | 180, 358 | 376, 359 | 1279, 360 | 283, 361 | 1085 362 | ], 363 | "top20_scores": [ 364 | 1.0, 365 | 1.0, 366 | 0.3333333432674408, 367 | 0.0, 368 | 0.0, 369 | 0.0, 370 | 0.0, 371 | 0.0, 372 | 0.0, 373 | 0.0, 374 | 0.0, 375 | 0.0, 376 | 0.0, 377 | 0.0, 378 | 0.0, 379 | 0.0, 380 | 0.0, 381 | 0.0, 382 | 0.0, 383 | 0.0 384 | ] 385 | }, 386 | { 387 | "question_id": 25007, 388 | "image_id": 25, 389 | "top20": [ 390 | 180, 391 | 577, 392 | 391, 393 | 867, 394 | 1085, 395 | 576, 396 | 1990, 397 | 916, 398 | 119, 399 | 384, 400 | 840, 401 | 651, 402 | 290, 403 | 1009, 404 | 365, 405 | 940, 406 | 1141, 407 | 323, 408 | 484, 409 | 900 410 | ], 411 | "top20_scores": [ 412 | 1.0, 413 | 0.6666666865348816, 414 | 0.0, 415 | 0.0, 416 | 0.0, 417 | 0.0, 418 | 0.0, 419 | 0.0, 420 | 0.0, 421 | 0.0, 422 | 0.0, 423 | 0.0, 424 | 0.0, 425 | 0.0, 426 | 0.0, 427 | 0.0, 428 | 0.0, 429 | 0.0, 430 | 0.0, 431 | 0.0 432 | ] 433 | }, 434 | { 435 | "question_id": 25009, 436 | "image_id": 25, 437 | "top20": [ 438 | 577, 439 | 180, 440 | 1085, 441 | 119, 442 | 384, 443 | 549, 444 | 1141, 445 | 1737, 446 | 365, 447 | 484, 448 | 867, 449 | 1465, 450 | 574, 451 | 349, 452 | 940, 453 | 1990, 454 | 436, 455 | 840, 456 | 887, 457 | 916 458 | ], 459 | "top20_scores": [ 460 | 1.0, 461 | 1.0, 462 | 0.0, 463 | 0.0, 464 | 0.0, 465 | 0.0, 466 | 0.0, 467 | 0.0, 468 | 0.0, 469 | 0.0, 470 | 0.0, 471 | 0.0, 472 | 0.0, 473 | 0.0, 474 | 0.0, 475 | 0.0, 476 | 0.0, 477 | 0.0, 478 | 0.0, 479 | 0.0 480 | ] 481 | }, 482 | { 483 | "question_id": 25010, 484 | "image_id": 25, 485 | "top20": [ 486 | 867, 487 | 1743, 488 | 323, 489 | 957, 490 | 647, 491 | 1710, 492 | 1132, 493 | 1414, 494 | 2107, 495 | 766, 496 | 391, 497 | 1009, 498 | 1279, 499 | 1359, 500 | 1990, 501 | 283, 502 | 745, 503 | 576, 504 | 840, 505 | 376 506 | ], 507 | "top20_scores": [ 508 | 0.6666666865348816, 509 | 1.0, 510 | 0.0, 511 | 0.0, 512 | 0.0, 513 | 0.0, 514 | 0.0, 515 | 0.0, 516 | 0.0, 517 | 0.0, 518 | 0.0, 519 | 0.0, 520 | 0.0, 521 | 0.0, 522 | 0.3333333432674408, 523 | 0.0, 524 | 0.0, 525 | 0.0, 526 | 0.0, 527 | 0.0 528 | ] 529 | }, 530 | { 531 | "question_id": 25013, 532 | "image_id": 25, 533 | "top20": [ 534 | 577, 535 | 180, 536 | 1085, 537 | 436, 538 | 549, 539 | 384, 540 | 1465, 541 | 119, 542 | 1141, 543 | 1737, 544 | 956, 545 | 365, 546 | 484, 547 | 867, 548 | 940, 549 | 170, 550 | 1359, 551 | 731, 552 | 349, 553 | 887 554 | ], 555 | "top20_scores": [ 556 | 1.0, 557 | 0.0, 558 | 0.0, 559 | 0.0, 560 | 0.0, 561 | 0.0, 562 | 0.0, 563 | 0.0, 564 | 0.0, 565 | 0.0, 566 | 0.0, 567 | 0.0, 568 | 0.0, 569 | 0.0, 570 | 0.0, 571 | 0.0, 572 | 0.0, 573 | 0.0, 574 | 0.0, 575 | 0.0 576 | ] 577 | }, 578 | { 579 | "question_id": 25014, 580 | "image_id": 25, 581 | "top20": [ 582 | 180, 583 | 577, 584 | 384, 585 | 549, 586 | 1465, 587 | 119, 588 | 436, 589 | 1141, 590 | 1085, 591 | 940, 592 | 1737, 593 | 365, 594 | 1391, 595 | 956, 596 | 1154, 597 | 170, 598 | 1861, 599 | 2148, 600 | 1509, 601 | 731 602 | ], 603 | "top20_scores": [ 604 | 1.0, 605 | 1.0, 606 | 0.0, 607 | 0.0, 608 | 0.0, 609 | 0.0, 610 | 0.0, 611 | 0.0, 612 | 0.0, 613 | 0.0, 614 | 0.0, 615 | 0.0, 616 | 0.0, 617 | 0.0, 618 | 0.0, 619 | 0.0, 620 | 0.0, 621 | 0.0, 622 | 0.0, 623 | 0.0 624 
| ] 625 | }, 626 | { 627 | "question_id": 25015, 628 | "image_id": 25, 629 | "top20": [ 630 | 549, 631 | 436, 632 | 1465, 633 | 956, 634 | 1154, 635 | 180, 636 | 1391, 637 | 1085, 638 | 349, 639 | 384, 640 | 887, 641 | 577, 642 | 1141, 643 | 1359, 644 | 119, 645 | 2061, 646 | 365, 647 | 484, 648 | 900, 649 | 468 650 | ], 651 | "top20_scores": [ 652 | 1.0, 653 | 0.6666666865348816, 654 | 0.0, 655 | 0.0, 656 | 0.0, 657 | 0.0, 658 | 0.0, 659 | 0.0, 660 | 0.0, 661 | 0.0, 662 | 0.0, 663 | 0.0, 664 | 0.0, 665 | 0.0, 666 | 0.0, 667 | 0.0, 668 | 0.0, 669 | 0.0, 670 | 0.0, 671 | 0.0 672 | ] 673 | }, 674 | { 675 | "question_id": 25016, 676 | "image_id": 25, 677 | "top20": [ 678 | 577, 679 | 180, 680 | 384, 681 | 956, 682 | 1843, 683 | 119, 684 | 940, 685 | 549, 686 | 1141, 687 | 731, 688 | 1454, 689 | 576, 690 | 651, 691 | 365, 692 | 624, 693 | 436, 694 | 1737, 695 | 1465, 696 | 674, 697 | 867 698 | ], 699 | "top20_scores": [ 700 | 1.0, 701 | 0.0, 702 | 0.0, 703 | 0.0, 704 | 0.0, 705 | 0.0, 706 | 0.0, 707 | 0.0, 708 | 0.0, 709 | 0.0, 710 | 0.0, 711 | 0.0, 712 | 0.0, 713 | 0.0, 714 | 0.0, 715 | 0.0, 716 | 0.0, 717 | 0.0, 718 | 0.0, 719 | 0.0 720 | ] 721 | }, 722 | { 723 | "question_id": 25017, 724 | "image_id": 25, 725 | "top20": [ 726 | 549, 727 | 436, 728 | 1465, 729 | 1085, 730 | 956, 731 | 1154, 732 | 180, 733 | 1391, 734 | 349, 735 | 484, 736 | 887, 737 | 384, 738 | 2118, 739 | 1141, 740 | 119, 741 | 1627, 742 | 365, 743 | 577, 744 | 1359, 745 | 926 746 | ], 747 | "top20_scores": [ 748 | 1.0, 749 | 0.0, 750 | 0.0, 751 | 0.3333333432674408, 752 | 0.0, 753 | 0.0, 754 | 0.0, 755 | 0.0, 756 | 0.0, 757 | 0.0, 758 | 0.0, 759 | 0.0, 760 | 0.0, 761 | 0.0, 762 | 0.0, 763 | 0.0, 764 | 0.0, 765 | 0.0, 766 | 0.0, 767 | 0.0 768 | ] 769 | }, 770 | { 771 | "question_id": 30000, 772 | "image_id": 30, 773 | "top20": [ 774 | 180, 775 | 577, 776 | 231, 777 | 1465, 778 | 436, 779 | 549, 780 | 972, 781 | 1154, 782 | 777, 783 | 807, 784 | 1391, 785 | 1247, 786 | 1509, 787 | 1558, 788 | 956, 789 | 2, 790 | 1461, 791 | 592, 792 | 508, 793 | 1861 794 | ], 795 | "top20_scores": [ 796 | 1.0, 797 | 0.0, 798 | 0.0, 799 | 0.0, 800 | 0.0, 801 | 0.0, 802 | 0.0, 803 | 0.0, 804 | 0.0, 805 | 0.0, 806 | 0.0, 807 | 0.0, 808 | 0.0, 809 | 0.0, 810 | 0.0, 811 | 0.0, 812 | 0.0, 813 | 0.0, 814 | 0.0, 815 | 0.0 816 | ] 817 | }, 818 | { 819 | "question_id": 30003, 820 | "image_id": 30, 821 | "top20": [ 822 | 577, 823 | 180, 824 | 436, 825 | 1465, 826 | 549, 827 | 384, 828 | 956, 829 | 119, 830 | 1141, 831 | 940, 832 | 1154, 833 | 1391, 834 | 731, 835 | 1737, 836 | 365, 837 | 1292, 838 | 1122, 839 | 240, 840 | 1454, 841 | 772 842 | ], 843 | "top20_scores": [ 844 | 1.0, 845 | 1.0, 846 | 0.0, 847 | 0.0, 848 | 0.0, 849 | 0.0, 850 | 0.0, 851 | 0.0, 852 | 0.0, 853 | 0.0, 854 | 0.0, 855 | 0.0, 856 | 0.0, 857 | 0.0, 858 | 0.0, 859 | 0.0, 860 | 0.0, 861 | 0.0, 862 | 0.0, 863 | 0.0 864 | ] 865 | }, 866 | { 867 | "question_id": 30004, 868 | "image_id": 30, 869 | "top20": [ 870 | 772, 871 | 350, 872 | 233, 873 | 650, 874 | 987, 875 | 1109, 876 | 1956, 877 | 1071, 878 | 1662, 879 | 1031, 880 | 2146, 881 | 806, 882 | 1359, 883 | 1514, 884 | 407, 885 | 1846, 886 | 985, 887 | 967, 888 | 714, 889 | 384 890 | ], 891 | "top20_scores": [ 892 | 1.0, 893 | 0.0, 894 | 0.0, 895 | 0.0, 896 | 0.0, 897 | 0.0, 898 | 0.0, 899 | 0.0, 900 | 0.0, 901 | 0.0, 902 | 0.0, 903 | 0.0, 904 | 0.0, 905 | 0.0, 906 | 0.0, 907 | 0.0, 908 | 0.0, 909 | 0.0, 910 | 0.0, 911 | 0.0 912 | ] 913 | }, 914 | { 915 | "question_id": 34000, 916 | "image_id": 34, 917 | "top20": [ 918 | 180, 919 
| 577, 920 | 349, 921 | 549, 922 | 384, 923 | 1465, 924 | 436, 925 | 119, 926 | 1085, 927 | 772, 928 | 1141, 929 | 365, 930 | 940, 931 | 956, 932 | 8, 933 | 731, 934 | 1359, 935 | 1737, 936 | 407, 937 | 985 938 | ], 939 | "top20_scores": [ 940 | 0.0, 941 | 1.0, 942 | 0.0, 943 | 0.0, 944 | 0.0, 945 | 0.0, 946 | 0.0, 947 | 0.0, 948 | 0.0, 949 | 0.0, 950 | 0.0, 951 | 0.0, 952 | 0.0, 953 | 0.0, 954 | 0.0, 955 | 0.0, 956 | 0.0, 957 | 0.0, 958 | 0.0, 959 | 0.0 960 | ] 961 | }, 962 | { 963 | "question_id": 36000, 964 | "image_id": 36, 965 | "top20": [ 966 | 577, 967 | 180, 968 | 384, 969 | 119, 970 | 1141, 971 | 940, 972 | 365, 973 | 1737, 974 | 48, 975 | 514, 976 | 732, 977 | 731, 978 | 714, 979 | 651, 980 | 1454, 981 | 1587, 982 | 955, 983 | 956, 984 | 1032, 985 | 436 986 | ], 987 | "top20_scores": [ 988 | 0.0, 989 | 1.0, 990 | 0.0, 991 | 0.0, 992 | 0.0, 993 | 0.0, 994 | 0.0, 995 | 0.0, 996 | 0.0, 997 | 0.0, 998 | 0.0, 999 | 0.0, 1000 | 0.0, 1001 | 0.0, 1002 | 0.0, 1003 | 0.0, 1004 | 0.0, 1005 | 0.0, 1006 | 0.0, 1007 | 0.0 1008 | ] 1009 | }, 1010 | { 1011 | "question_id": 36001, 1012 | "image_id": 36, 1013 | "top20": [ 1014 | 1031, 1015 | 1794, 1016 | 240, 1017 | 1982, 1018 | 1066, 1019 | 2159, 1020 | 1110, 1021 | 1683, 1022 | 650, 1023 | 772, 1024 | 714, 1025 | 288, 1026 | 1776, 1027 | 661, 1028 | 281, 1029 | 977, 1030 | 1227, 1031 | 1211, 1032 | 437, 1033 | 1299 1034 | ], 1035 | "top20_scores": [ 1036 | 1.0, 1037 | 0.0, 1038 | 0.0, 1039 | 0.0, 1040 | 0.0, 1041 | 0.0, 1042 | 0.0, 1043 | 0.0, 1044 | 0.0, 1045 | 0.0, 1046 | 0.0, 1047 | 0.0, 1048 | 0.0, 1049 | 0.0, 1050 | 0.0, 1051 | 0.0, 1052 | 0.0, 1053 | 0.0, 1054 | 0.0, 1055 | 0.0 1056 | ] 1057 | }, 1058 | { 1059 | "question_id": 36002, 1060 | "image_id": 36, 1061 | "top20": [ 1062 | 1277, 1063 | 497, 1064 | 526, 1065 | 832, 1066 | 700, 1067 | 1822, 1068 | 180, 1069 | 2103, 1070 | 1875, 1071 | 2047, 1072 | 616, 1073 | 808, 1074 | 1408, 1075 | 2184, 1076 | 2036, 1077 | 1781, 1078 | 630, 1079 | 1255, 1080 | 1103, 1081 | 939 1082 | ], 1083 | "top20_scores": [ 1084 | 0.0, 1085 | 0.6666666865348816, 1086 | 0.0, 1087 | 0.0, 1088 | 0.0, 1089 | 0.0, 1090 | 0.0, 1091 | 0.0, 1092 | 0.0, 1093 | 0.0, 1094 | 0.0, 1095 | 0.0, 1096 | 0.0, 1097 | 0.0, 1098 | 0.0, 1099 | 0.0, 1100 | 0.0, 1101 | 0.0, 1102 | 0.0, 1103 | 0.0 1104 | ] 1105 | }, 1106 | { 1107 | "question_id": 42001, 1108 | "image_id": 42, 1109 | "top20": [ 1110 | 180, 1111 | 577, 1112 | 240, 1113 | 1122, 1114 | 674, 1115 | 235, 1116 | 1813, 1117 | 1843, 1118 | 384, 1119 | 119, 1120 | 436, 1121 | 1359, 1122 | 676, 1123 | 940, 1124 | 53, 1125 | 2016, 1126 | 1141, 1127 | 1292, 1128 | 1648, 1129 | 549 1130 | ], 1131 | "top20_scores": [ 1132 | 1.0, 1133 | 0.0, 1134 | 0.0, 1135 | 0.0, 1136 | 0.0, 1137 | 0.0, 1138 | 0.0, 1139 | 0.0, 1140 | 0.0, 1141 | 0.0, 1142 | 0.0, 1143 | 0.0, 1144 | 0.0, 1145 | 0.0, 1146 | 0.0, 1147 | 0.0, 1148 | 0.0, 1149 | 0.0, 1150 | 0.0, 1151 | 0.0 1152 | ] 1153 | }, 1154 | { 1155 | "question_id": 49000, 1156 | "image_id": 49, 1157 | "top20": [ 1158 | 180, 1159 | 577, 1160 | 972, 1161 | 436, 1162 | 1465, 1163 | 384, 1164 | 549, 1165 | 119, 1166 | 231, 1167 | 985, 1168 | 772, 1169 | 240, 1170 | 1154, 1171 | 1247, 1172 | 1141, 1173 | 956, 1174 | 1790, 1175 | 714, 1176 | 1391, 1177 | 940 1178 | ], 1179 | "top20_scores": [ 1180 | 1.0, 1181 | 0.0, 1182 | 0.0, 1183 | 0.0, 1184 | 0.0, 1185 | 0.0, 1186 | 0.0, 1187 | 0.0, 1188 | 0.0, 1189 | 0.0, 1190 | 0.0, 1191 | 0.0, 1192 | 0.0, 1193 | 0.0, 1194 | 0.0, 1195 | 0.0, 1196 | 0.0, 1197 | 0.0, 1198 | 0.0, 1199 | 0.0 1200 | ] 1201 | }, 1202 | 
{ 1203 | "question_id": 49002, 1204 | "image_id": 49, 1205 | "top20": [ 1206 | 577, 1207 | 180, 1208 | 384, 1209 | 119, 1210 | 1141, 1211 | 940, 1212 | 1737, 1213 | 549, 1214 | 956, 1215 | 365, 1216 | 731, 1217 | 436, 1218 | 1465, 1219 | 772, 1220 | 240, 1221 | 714, 1222 | 1454, 1223 | 1154, 1224 | 1032, 1225 | 1292 1226 | ], 1227 | "top20_scores": [ 1228 | 1.0, 1229 | 1.0, 1230 | 0.0, 1231 | 0.0, 1232 | 0.0, 1233 | 0.0, 1234 | 0.0, 1235 | 0.0, 1236 | 0.0, 1237 | 0.0, 1238 | 0.0, 1239 | 0.0, 1240 | 0.0, 1241 | 0.0, 1242 | 0.0, 1243 | 0.0, 1244 | 0.0, 1245 | 0.0, 1246 | 0.0, 1247 | 0.0 1248 | ] 1249 | } 1250 | ] 1251 | -------------------------------------------------------------------------------- /fc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | 5 | 6 | def get_norm(norm): 7 | no_norm = lambda x, dim: x 8 | if norm == 'weight': 9 | norm_layer = weight_norm 10 | elif norm == 'batch': 11 | norm_layer = nn.BatchNorm1d 12 | elif norm == 'layer': 13 | norm_layer = nn.LayerNorm 14 | elif norm == 'none': 15 | norm_layer = no_norm 16 | else: 17 | print("Invalid Normalization") 18 | raise Exception("Invalid Normalization") 19 | return norm_layer 20 | 21 | 22 | def get_act(act): 23 | if act == 'ReLU': 24 | act_layer = nn.ReLU 25 | elif act == 'LeakyReLU': 26 | act_layer = nn.LeakyReLU 27 | elif act == 'PReLU': 28 | act_layer = nn.PReLU 29 | elif act == 'RReLU': 30 | act_layer = nn.RReLU 31 | elif act == 'ELU': 32 | act_layer = nn.ELU 33 | elif act == 'SELU': 34 | act_layer = nn.SELU 35 | elif act == 'Tanh': 36 | act_layer = nn.Tanh 37 | elif act == 'Hardtanh': 38 | act_layer = nn.Hardtanh 39 | elif act == 'Sigmoid': 40 | act_layer = nn.Sigmoid 41 | else: 42 | print("Invalid activation function") 43 | raise Exception("Invalid activation function") 44 | return act_layer 45 | 46 | 47 | 48 | class FCNet(nn.Module): 49 | """Simple class for non-linear fully connect network 50 | """ 51 | def __init__(self, dims, dropout, norm, act): 52 | super(FCNet, self).__init__() 53 | 54 | norm_layer = get_norm(norm) 55 | act_layer = get_act(act) 56 | 57 | layers = [] 58 | for i in range(len(dims)-2): 59 | in_dim = dims[i] 60 | out_dim = dims[i+1] 61 | layers.append(norm_layer(nn.Linear(in_dim, out_dim), dim=None)) 62 | layers.append(act_layer()) 63 | layers.append(nn.Dropout(p=dropout)) 64 | layers.append(norm_layer(nn.Linear(dims[-2], dims[-1]), dim=None)) 65 | layers.append(act_layer()) 66 | layers.append(nn.Dropout(p=dropout)) 67 | 68 | self.main = nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | return self.main(x) 72 | 73 | 74 | 75 | class GTH(nn.Module): 76 | """Simple class for Gated Tanh 77 | """ 78 | def __init__(self, in_dim, out_dim, dropout, norm, act): 79 | super(GTH, self).__init__() 80 | 81 | self.nonlinear = FCNet([in_dim, out_dim], dropout= dropout, norm= norm, act= act) 82 | self.gate = FCNet([in_dim, out_dim], dropout= dropout, norm= norm, act= 'Sigmoid') 83 | 84 | def forward(self, x): 85 | x_proj = self.nonlinear(x) 86 | gate = self.gate(x) 87 | x_proj = x_proj*gate 88 | return x_proj 89 | 90 | 91 | if __name__ == '__main__': 92 | fc1 = FCNet([10, 20, 10]) 93 | print(fc1) 94 | 95 | print('============') 96 | fc2 = FCNet([10, 20]) 97 | print(fc2) -------------------------------------------------------------------------------- /language_model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class WordEmbedding(nn.Module): 7 | """Word Embedding 8 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 9 | with the definition in Dictionary. 10 | """ 11 | def __init__(self, ntoken, emb_dim, dropout): 12 | super(WordEmbedding, self).__init__() 13 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 14 | self.dropout = nn.Dropout(dropout) 15 | self.ntoken = ntoken 16 | self.emb_dim = emb_dim 17 | 18 | def init_embedding(self, np_file): 19 | weight_init = torch.from_numpy(np.load(np_file)) 20 | #assert weight_init.shape == (self.ntoken, self.emb_dim) 21 | self.emb.weight.data[:weight_init.shape[0]] = weight_init 22 | 23 | def forward(self, x): 24 | emb = self.emb(x) 25 | emb = self.dropout(emb) 26 | return emb 27 | 28 | 29 | class QuestionEmbedding(nn.Module): 30 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'): 31 | """Module for question embedding 32 | """ 33 | super(QuestionEmbedding, self).__init__() 34 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 35 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU 36 | 37 | self.rnn = rnn_cls( 38 | in_dim, num_hid, nlayers, 39 | bidirectional=bidirect, 40 | dropout=dropout, 41 | batch_first=True) 42 | 43 | self.in_dim = in_dim 44 | self.num_hid = num_hid 45 | self.nlayers = nlayers 46 | self.rnn_type = rnn_type 47 | self.ndirections = 1 + int(bidirect) 48 | 49 | def init_hidden(self, batch): 50 | # just to get the type of tensor 51 | weight = next(self.parameters()).data 52 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 53 | if self.rnn_type == 'LSTM': 54 | return (weight.new(*hid_shape).zero_(), 55 | weight.new(*hid_shape).zero_()) 56 | else: 57 | return weight.new(*hid_shape).zero_() 58 | 59 | def forward(self, x): 60 | # x: [batch, sequence, in_dim] 61 | batch = x.size(0) 62 | hidden = self.init_hidden(batch) 63 | # self.rnn.flatten_parameters() 64 | output, hidden = self.rnn(x, hidden) 65 | 66 | if self.ndirections == 1: 67 | return output[:, -1] 68 | 69 | forward_ = output[:, -1, :self.num_hid] 70 | backward = output[:, 0, self.num_hid:] 71 | return torch.cat((forward_, backward), dim=1) 72 | 73 | def forward_all(self, x): 74 | # x: [batch, sequence, in_dim] 75 | batch = x.size(0) 76 | hidden = self.init_hidden(batch) 77 | # self.rnn.flatten_parameters() 78 | output, hidden = self.rnn(x, hidden) 79 | return output 80 | -------------------------------------------------------------------------------- /lxmert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from transformers import LxmertTokenizer, LxmertModel 5 | import numpy as np 6 | from language_model import WordEmbedding, QuestionEmbedding 7 | from classifier import SimpleClassifier, PaperClassifier 8 | 9 | from torch.nn import functional as F 10 | from fc import FCNet, GTH 11 | from attention import Att_0, Att_1, Att_2, Att_3, Att_P, Att_PD, Att_3S 12 | import torch 13 | import random 14 | class Model(nn.Module): 15 | def __init__(self, opt): 16 | super(Model, self).__init__() 17 | self.opt = opt 18 | self.model = LxmertModel.from_pretrained('unc-nlp/lxmert-base-uncased', return_dict=True) 19 | self.model = nn.DataParallel(self.model)#.cuda() 20 | self.candi_ans_num = opt.train_candi_ans_num 21 | self.batchsize = 
opt.batch_size 22 | self.Linear_layer = nn.Linear(768, 1)#.cuda() 23 | norm = opt.norm#"weight" 24 | activation = opt.activation#Relu 25 | dropC = opt.dropC#0.5 26 | self.classifier = SimpleClassifier(in_dim=768, hid_dim=2 * 768, out_dim=1, 27 | dropout=dropC, norm=norm, act=activation) 28 | 29 | def forward(self, qa_text, v, b, epo, name): 30 | """ 31 | qa_text (btachsize, candi_ans_num, max_length) 32 | v (batchsize, obj_num, v_dim) 33 | b (batchsize, obj_num, b_dim) 34 | 35 | return: logits 36 | """ 37 | qa_text = qa_text.cuda() 38 | v= v.cuda() 39 | b= b.cuda() 40 | 41 | print("qa_text.shape",qa_text.shape) 42 | if name == 'train': 43 | self.candi_ans_num = self.opt.train_candi_ans_num 44 | elif name == 'test': 45 | self.candi_ans_num = self.opt.test_candi_ans_num 46 | qa_text_reshape = qa_text.reshape(qa_text.shape[0] * self.candi_ans_num, -1) 47 | v_repeat = v.repeat(1, self.candi_ans_num, 1) 48 | v_reshape = v_repeat.reshape( v.shape[0] * self.candi_ans_num,v.shape[1], v.shape[2] ) 49 | b_repeat = b.repeat(1, self.candi_ans_num , 1) 50 | b_reshape = b_repeat.reshape( b.shape[0] * self.candi_ans_num,b.shape[1], b.shape[2] ) 51 | 52 | outputs = self.model(qa_text_reshape, v_reshape, b_reshape) 53 | pool_out = outputs.pooled_output 54 | 55 | logits = self.classifier(pool_out) 56 | logits_reshape = logits.reshape(-1, self.candi_ans_num) 57 | 58 | return logits_reshape 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhoebusSi/SAR/a1934bf6b728edf7149aa8e7fa69233167e512dc/model.jpg -------------------------------------------------------------------------------- /opts_SAR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_opt(): 4 | parser = argparse.ArgumentParser() 5 | # Data input settings 6 | parser.add_argument('--rnn_size', type=int, default=1280, 7 | help='size of the rnn in number of hidden nodes in question gru') 8 | parser.add_argument('--num_hid', type=int, default=1280, 9 | help='size of the rnn in number of hidden nodes in question gru') 10 | parser.add_argument('--num_layers', type=int, default=2, 11 | help='number of GCN layers') 12 | parser.add_argument('--rnn_type', type=str, default='gru', 13 | help='rnn, gru, or lstm') 14 | parser.add_argument('--v_dim', type=int, default=2048, 15 | help='2048 for resnet, 4096 for vgg') 16 | parser.add_argument('--ans_dim', type=int, default=2274, 17 | help='3219 for VQA-CP v2, 2185 for VQA-CP v1') 18 | parser.add_argument('--logit_layers', type=int, default=1, 19 | help='number of layers in the RNN') 20 | parser.add_argument('--activation', type=str, default='ReLU', 21 | help='number of layers in the RNN') 22 | parser.add_argument('--norm', type=str, default='weight', 23 | help='number of layers in the RNN') 24 | parser.add_argument('--initializer', type=str, default='kaiming_normal', 25 | help='number of layers in the RNN') 26 | 27 | # Optimization: General 28 | parser.add_argument('--num_epochs', type=int, default=20, 29 | help='number of epochs') 30 | parser.add_argument('--train_candi_ans_num', type=int, default=20, 31 | help='number of candidate answers') 32 | 33 | parser.add_argument('--s_epoch', type=int, default=0, 34 | help='training from s epochs') 35 | parser.add_argument('--ratio', type=float, default=1, 36 | help='ratio of training set used') 37 | 
parser.add_argument('--batch_size', type=int, default=32, 38 | help='minibatch size') 39 | parser.add_argument('--grad_clip', type=float, default=0.25, 40 | help='clip gradients at this value') 41 | parser.add_argument('--dropC', type=float, default=0.5, 42 | help='strength of dropout in the Language Model RNN') 43 | parser.add_argument('--dropG', type=float, default=0.2, 44 | help='strength of dropout in the Language Model RNN') 45 | parser.add_argument('--dropL', type=float, default=0.1, 46 | help='strength of dropout in the Language Model RNN') 47 | parser.add_argument('--dropW', type=float, default=0.4, 48 | help='strength of dropout in the Language Model RNN') 49 | parser.add_argument('--dropout', type=float, default=0.2, 50 | help='strength of dropout in the Language Model RNN') 51 | 52 | #Optimization: for the Language Model 53 | parser.add_argument('--optimizer', type=str, default='adam', 54 | help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam') 55 | parser.add_argument('--learning_rate', type=float, default=1e-5, 56 | help='learning rate') 57 | parser.add_argument('--self_loss_weight', type=float, default=3, 58 | help='self-supervised loss weight') 59 | parser.add_argument('--self_sup', type=int, default=1, 60 | help='whether using self-sup processing') 61 | parser.add_argument('--optim_alpha', type=float, default=0.9, 62 | help='alpha for adam') 63 | parser.add_argument('--optim_beta', type=float, default=0.999, 64 | help='beta used for adam') 65 | parser.add_argument('--optim_epsilon', type=float, default=1e-8, 66 | help='epsilon that goes into denominator for smoothing') 67 | parser.add_argument('--weight_decay', type=float, default=0, 68 | help='weight_decay') 69 | parser.add_argument('--seed', type=int, default=1024, 70 | help='seed') 71 | parser.add_argument('--ntokens', type=int, default=777, 72 | help='ntokens') 73 | 74 | parser.add_argument('--dataroot', type=str, default='../../SSL-VQA/data/vqacp2/',help='dataroot') 75 | parser.add_argument('--img_root', type=str, default='../../SSL-VQA/data/coco/',help='image_root') 76 | 77 | parser.add_argument('--checkpoint_path4test', type=str, default='saved_models_cp2/base/SAR_LMH_top20_best_model.pth', 78 | help='directory to store checkpointed models4test, used for testing') 79 | parser.add_argument('--checkpoint_path4test_QTDmodel', type=str, default='data4VE/offline-QTD_model.pth', 80 | help='directory to store the QTDmodel, used for testing') 81 | parser.add_argument('--lp', type=int, default=0, #[0, 1, 2] 82 | help='the combination with Language-Priors method: 0-Non_LP; 1-SSL; 2-LMH') 83 | parser.add_argument('--test_candi_ans_num', type=int, default=12, 84 | help='number of candidate answers in test') 85 | parser.add_argument('--QTD_N4yesno', type=int, default=1, 86 | help='number for the candidate answers of yes/no question in test') 87 | parser.add_argument('--QTD_N4non_yesno', type=int, default=12, 88 | help='number for the candidate answers of non-yes/no question in test') 89 | 90 | 91 | 92 | parser.add_argument('--input', type=str, default=None) 93 | parser.add_argument('--output', type=str, default='saved_models_cp2/base/') 94 | parser.add_argument('--debug', action='store_true') 95 | parser.add_argument('--logits', action='store_true') 96 | parser.add_argument('--index', type=int, default=0) 97 | parser.add_argument('--label', type=str, default='best') 98 | 99 | args = parser.parse_args() 100 | 101 | return args 102 | -------------------------------------------------------------------------------- 
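A minimal sketch of how the QTD-related options above would typically be combined at test time, assuming a hypothetical qtd_model that classifies question types (class 0 taken to mean yes/no) and a batch of question_feats; only the option names (QTD_N4yesno, QTD_N4non_yesno, test_candi_ans_num) come from opts_SAR.py above.

import torch

def candidate_num_for_batch(qtd_model, question_feats, opt):
    # Pick how many candidate answers to re-rank for each question:
    # yes/no questions keep opt.QTD_N4yesno candidates (default 1),
    # all other question types keep opt.QTD_N4non_yesno (default 12).
    with torch.no_grad():
        logits = qtd_model(question_feats)   # (batch, num_question_types), hypothetical model
        qtype = logits.argmax(dim=-1)        # class 0 assumed to mean yes/no
    n_candidates = torch.where(
        qtype == 0,
        torch.full_like(qtype, opt.QTD_N4yesno),
        torch.full_like(qtype, opt.QTD_N4non_yesno),
    )
    # never request more candidates than the re-ranker was given
    return n_candidates.clamp(max=opt.test_candi_ans_num)

With the defaults above this keeps a single candidate for yes/no questions and twelve for all other questions. Each entry in the data4VE demo files earlier pairs a question with twenty candidate-answer label indices ("top20") and their soft ground-truth scores ("top20_scores"), which is where these candidates come from.
--------------------------------------------------------------------------------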
/saved_models_cp2/log.txt: -------------------------------------------------------------------------------- 1 | nParams= 209122055 2 | 3 | lr: 0.0000100 4 | epoch 0, time: 19661.28 5 | train_loss: 2.45, norm: 6.7790, score: 142375.67 6 | eval score: 54.89 (86.42) 7 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import errno 3 | import os 4 | from PIL import Image 5 | import torch 6 | import torch.nn as nn 7 | import re 8 | 9 | import json 10 | import pickle as cPickle 11 | import numpy as np 12 | import utils 13 | import h5py 14 | import operator 15 | import functools 16 | from torch._six import string_classes 17 | import torch.nn.functional as F 18 | import collections 19 | 20 | from torch.utils.data.dataloader import default_collate 21 | 22 | 23 | EPS = 1e-7 24 | 25 | 26 | def assert_eq(real, expected): 27 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 28 | 29 | 30 | def assert_array_eq(real, expected): 31 | assert (np.abs(real-expected) < EPS).all(), \ 32 | '%s (true) vs %s (expected)' % (real, expected) 33 | 34 | 35 | def load_folder(folder, suffix): 36 | imgs = [] 37 | for f in sorted(os.listdir(folder)): 38 | if f.endswith(suffix): 39 | imgs.append(os.path.join(folder, f)) 40 | return imgs 41 | 42 | 43 | def load_imageid(folder): 44 | images = load_folder(folder, 'jpg') 45 | img_ids = set() 46 | for img in images: 47 | img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1]) 48 | img_ids.add(img_id) 49 | return img_ids 50 | 51 | 52 | def pil_loader(path): 53 | with open(path, 'rb') as f: 54 | with Image.open(f) as img: 55 | return img.convert('RGB') 56 | 57 | 58 | def weights_init(m): 59 | """custom weights initialization.""" 60 | cname = m.__class__ 61 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 62 | m.weight.data.normal_(0.0, 0.02) 63 | elif cname == nn.BatchNorm2d: 64 | m.weight.data.normal_(1.0, 0.02) 65 | m.bias.data.fill_(0) 66 | else: 67 | print('%s is not initialized.' 
% cname) 68 | 69 | 70 | def init_net(net, net_file): 71 | if net_file: 72 | net.load_state_dict(torch.load(net_file)) 73 | else: 74 | net.apply(weights_init) 75 | 76 | 77 | def create_dir(path): 78 | if not os.path.exists(path): 79 | try: 80 | os.makedirs(path) 81 | except OSError as exc: 82 | if exc.errno != errno.EEXIST: 83 | raise 84 | 85 | 86 | class Logger(object): 87 | def __init__(self, output_name): 88 | dirname = os.path.dirname(output_name) 89 | if not os.path.exists(dirname): 90 | os.mkdir(dirname) 91 | 92 | self.log_file = open(output_name, 'w') 93 | self.infos = {} 94 | 95 | def append(self, key, val): 96 | vals = self.infos.setdefault(key, []) 97 | vals.append(val) 98 | 99 | def log(self, extra_msg=''): 100 | msgs = [extra_msg] 101 | for key, vals in self.infos.iteritems(): 102 | msgs.append('%s %.6f' % (key, np.mean(vals))) 103 | msg = '\n'.join(msgs) 104 | self.log_file.write(msg + '\n') 105 | self.log_file.flush() 106 | self.infos = {} 107 | return msg 108 | 109 | def write(self, msg): 110 | self.log_file.write(msg + '\n') 111 | self.log_file.flush() 112 | print(msg) 113 | 114 | def print_model(model, logger): 115 | print(model) 116 | nParams = 0 117 | for w in model.parameters(): 118 | nParams += functools.reduce(operator.mul, w.size(), 1) 119 | if logger: 120 | logger.write('nParams=\t'+str(nParams)) 121 | 122 | 123 | def save_model(path, model, epoch, optimizer=None): 124 | model_dict = { 125 | 'epoch': epoch, 126 | 'model_state': model.state_dict() 127 | } 128 | if optimizer is not None: 129 | model_dict['optimizer_state'] = optimizer.state_dict() 130 | 131 | torch.save(model_dict, path) 132 | 133 | def rho_select(pad, lengths): 134 | # Index of the last output for each sequence. 135 | idx_ = (lengths-1).view(-1,1).expand(pad.size(0), pad.size(2)).unsqueeze(1) 136 | extracted = pad.gather(1, idx_).squeeze(1) 137 | return extracted 138 | 139 | def trim_collate(batch): 140 | "Puts each data field into a tensor with outer dimension batch size" 141 | _use_shared_memory = True 142 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 143 | elem_type = type(batch[0]) 144 | if torch.is_tensor(batch[0]): 145 | out = None 146 | #print("batch[0].dim()",len(batch[0]),batch[0].dim(),batch[0]) 147 | if 1 < batch[0].dim(): # image features 148 | max_num_boxes = max([x.size(0) for x in batch]) 149 | if _use_shared_memory: 150 | # If we're in a background process, concatenate directly into a 151 | # shared memory tensor to avoid an extra copy 152 | numel = len(batch) * max_num_boxes * batch[0].size(-1) 153 | storage = batch[0].storage()._new_shared(numel) 154 | out = batch[0].new(storage) 155 | # warning: F.pad returns Variable! 
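# pad each 2D feature tensor with zero rows along dim 0 (number of boxes) up to max_num_boxes, then stack into a (batch, max_num_boxes, feat_dim) tensor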
156 | return torch.stack([F.pad(x, (0,0,0,max_num_boxes-x.size(0))).data for x in batch], 0, out=out) 157 | else: 158 | if _use_shared_memory: 159 | # If we're in a background process, concatenate directly into a 160 | # shared memory tensor to avoid an extra copy 161 | numel = sum([x.numel() for x in batch]) 162 | storage = batch[0].storage()._new_shared(numel) 163 | out = batch[0].new(storage) 164 | #print("batch",batch,"\n\n\n",len(batch)) 165 | return torch.stack(batch, 0, out=out) 166 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 167 | and elem_type.__name__ != 'string_': 168 | elem = batch[0] 169 | if elem_type.__name__ == 'ndarray': 170 | # array of string classes and object 171 | if re.search('[SaUO]', elem.dtype.str) is not None: 172 | raise TypeError(error_msg.format(elem.dtype)) 173 | 174 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 175 | if elem.shape == (): # scalars 176 | py_type = float if elem.dtype.name.startswith('float') else int 177 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 178 | elif isinstance(batch[0], int): 179 | return torch.LongTensor(batch) 180 | elif isinstance(batch[0], float): 181 | return torch.DoubleTensor(batch) 182 | elif isinstance(batch[0], string_classes): 183 | return batch 184 | elif isinstance(batch[0], collections.Mapping): 185 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 186 | elif isinstance(batch[0], collections.Sequence): 187 | transposed = zip(*batch) 188 | return [trim_collate(samples) for samples in transposed] 189 | 190 | raise TypeError((error_msg.format(type(batch[0])))) 191 | 192 | 193 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 194 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 195 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 196 | indices = torch.from_numpy( 197 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 198 | values = torch.from_numpy(sparse_mx.data) 199 | shape = torch.Size(sparse_mx.shape) 200 | return torch.sparse.FloatTensor(indices, values, shape) 201 | 202 | 203 | def mask_softmax(x, lengths): # , dim=1) 204 | mask = torch.zeros_like(x).to(device=x.device, non_blocking=True) 205 | t_lengths = lengths[:, :, None].expand_as(mask) 206 | arange_id = torch.arange(mask.size(1)).to(device=x.device, non_blocking=True) 207 | arange_id = arange_id[None, :, None].expand_as(mask) 208 | 209 | mask[arange_id < t_lengths] = 1 210 | # https://stackoverflow.com/questions/42599498/numercially-stable-softmax 211 | # https://stackoverflow.com/questions/34968722/how-to-implement-the-softmax-function-in-python 212 | # exp(x - max(x)) instead of exp(x) is a trick 213 | # to improve the numerical stability while giving 214 | # the same outputs 215 | x2 = torch.exp(x - torch.max(x)) 216 | x3 = x2 * mask 217 | epsilon = 1e-5 218 | x3_sum = torch.sum(x3, dim=1, keepdim=True) + epsilon 219 | x4 = x3 / x3_sum.expand_as(x3) 220 | return x4 221 | 222 | 223 | class GradReverseMask(torch.autograd.Function): 224 | """ 225 | This layer is used to create an adversarial loss. 226 | 227 | """ 228 | 229 | @staticmethod 230 | def forward(ctx, x, mask, weight): 231 | """ 232 | The mask should be composed of 0 or 1. 233 | The '1' will get their gradient reversed.. 
234 | """ 235 | ctx.save_for_backward(mask) 236 | ctx.weight = weight 237 | return x.view_as(x) 238 | 239 | @staticmethod 240 | def backward(ctx, grad_output): 241 | mask, = ctx.saved_tensors 242 | mask_c = mask.clone().detach().float() 243 | mask_c[mask == 0] = 1.0 244 | mask_c[mask == 1] = - float(ctx.weight) 245 | return grad_output * mask_c[:, None].float(), None, None 246 | 247 | 248 | def grad_reverse_mask(x, mask, weight=1): 249 | return GradReverseMask.apply(x, mask, weight) 250 | 251 | 252 | class GradReverse(torch.autograd.Function): 253 | """ 254 | This layer is used to create an adversarial loss. 255 | """ 256 | 257 | @staticmethod 258 | def forward(ctx, x): 259 | return x.view_as(x) 260 | 261 | @staticmethod 262 | def backward(ctx, grad_output): 263 | return grad_output.neg() 264 | 265 | 266 | def grad_reverse(x): 267 | return GradReverse.apply(x) 268 | 269 | 270 | class GradMulConst(torch.autograd.Function): 271 | """ 272 | This layer is used to create an adversarial loss. 273 | """ 274 | 275 | @staticmethod 276 | def forward(ctx, x, const): 277 | ctx.const = const 278 | return x.view_as(x) 279 | 280 | @staticmethod 281 | def backward(ctx, grad_output): 282 | return grad_output * ctx.const, None 283 | 284 | 285 | def grad_mul_const(x, const): 286 | return GradMulConst.apply(x, const) 287 | --------------------------------------------------------------------------------