├── README.md ├── attention.py ├── base_model.py ├── classifier.py ├── dataset.py ├── eval.py ├── fc.py ├── language_model.py ├── main.py ├── rubi_base_model.py ├── rubi_main.py ├── rubi_train.py ├── tools ├── compute_softscore.py ├── create_dictionary.py ├── create_dictionary_v1.py ├── download.sh └── process.sh ├── train.py ├── util ├── qid2type_cpv1.json ├── qid2type_cpv2.json └── qid2type_v2.json ├── utils.py └── vqa_debias_loss_functions.py /README.md: -------------------------------------------------------------------------------- 1 | # CVPR2020 Counterfactual Samples Synthesizing for Robust VQA 2 | This repo contains code for our paper ["Counterfactual Samples Synthesizing for Robust Visual Question Answering"](https://arxiv.org/pdf/2003.06576.pdf) 3 | This repo contains code modified from [here](https://github.com/chrisc36/bottom-up-attention-vqa),many thanks! 4 | 5 | ### Prerequisites 6 | 7 | Make sure you are on a machine with a NVIDIA GPU and Python 2.7 with about 100 GB disk space.
8 | h5py==2.10.0
9 | pytorch==1.1.0
10 | Click==7.0
11 | numpy==1.16.5
12 | tqdm==4.35.0
13 | 14 | ### Data Setup 15 | You can use 16 | ``` 17 | bash tools/download.sh 18 | ``` 19 | to download the data
20 | and the rest of the data and trained model can be obtained from [BaiduYun](https://pan.baidu.com/s/1oHdwYDSJXC1mlmvu8cQhKw)(passwd:3jot) or [MEGADrive](https://mega.nz/folder/0JBzGBZD#YGgonKMnwqmeSZmoV7hjMg) 21 | unzip feature1.zip and feature2.zip and merge them into data/rcnn_feature/
22 | use 23 | ``` 24 | bash tools/process.sh 25 | ``` 26 | to process the data
27 | 28 | ### Training 29 | Run 30 | ``` 31 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset cpv2 --mode q_v_debias --debias learned_mixin --topq 1 --topv -1 --qvp 5 --output [] --seed 0 32 | ``` 33 | to train a model 34 | 35 | ### Testing 36 | Run 37 | ``` 38 | CUDA_VISIBLE_DEVICES=0 python eval.py --dataset cpv2 --debias learned_mixin --model_state [] 39 | ``` 40 | to eval a model 41 | 42 | 43 | 44 | ## Citation 45 | 46 | If you find this code useful, please cite the following paper: 47 | 48 | ``` 49 | @inproceedings{chen2020counterfactual, 50 | title={Counterfactual Samples Synthesizing for Robust Visual Question Answering}, 51 | author={Chen, Long and Yan, Xin and Xiao, Jun and Zhang, Hanwang and Pu, Shiliang and Zhuang, Yueting}, 52 | booktitle={CVPR}, 53 | year={2020} 54 | } 55 | ``` 56 | 57 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | from fc import FCNet 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, v_dim, q_dim, num_hid): 9 | super(Attention, self).__init__() 10 | self.nonlinear = FCNet([v_dim + q_dim, num_hid]) 11 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None) 12 | 13 | def forward(self, v, q): 14 | """ 15 | v: [batch, k, vdim] 16 | q: [batch, qdim] 17 | """ 18 | logits = self.logits(v, q) 19 | w = nn.functional.softmax(logits, 1) 20 | return w 21 | 22 | def logits(self, v, q): 23 | num_objs = v.size(1) 24 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 25 | vq = torch.cat((v, q), 2) 26 | joint_repr = self.nonlinear(vq) 27 | logits = self.linear(joint_repr) 28 | return logits 29 | 30 | 31 | class NewAttention(nn.Module): 32 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2): 33 | super(NewAttention, self).__init__() 34 | 35 | self.v_proj = FCNet([v_dim, num_hid]) 36 | self.q_proj = FCNet([q_dim, num_hid]) 37 | self.dropout = nn.Dropout(dropout) 38 | self.linear = weight_norm(nn.Linear(q_dim, 1), dim=None) 39 | 40 | def forward(self, v, q): 41 | """ 42 | v: [batch, k, vdim] 43 | q: [batch, qdim] 44 | """ 45 | logits = self.logits(v, q) 46 | # w = nn.functional.softmax(logits, 1) 47 | # return w 48 | return logits 49 | 50 | def logits(self, v, q): 51 | batch, k, _ = v.size() 52 | v_proj = self.v_proj(v) # [batch, k, qdim] 53 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) 54 | joint_repr = v_proj * q_proj 55 | joint_repr = self.dropout(joint_repr) 56 | logits = self.linear(joint_repr) 57 | return logits 58 | -------------------------------------------------------------------------------- /base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attention import Attention, NewAttention 4 | from language_model import WordEmbedding, QuestionEmbedding 5 | from classifier import SimpleClassifier 6 | from fc import FCNet 7 | import numpy as np 8 | 9 | def mask_softmax(x,mask): 10 | mask=mask.unsqueeze(2).float() 11 | x2=torch.exp(x-torch.max(x)) 12 | x3=x2*mask 13 | epsilon=1e-5 14 | x3_sum=torch.sum(x3,dim=1,keepdim=True)+epsilon 15 | x4=x3/x3_sum.expand_as(x3) 16 | return x4 17 | 18 | 19 | class BaseModel(nn.Module): 20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier): 21 | super(BaseModel, self).__init__() 22 | self.w_emb = w_emb 23 | self.q_emb = q_emb 24 | self.v_att = v_att 25 | self.q_net = q_net 26 | self.v_net = v_net 27 | self.classifier = classifier 28 | self.debias_loss_fn = None 29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2)) 30 | self.bias_lin = torch.nn.Linear(1024, 1) 31 | 32 | def forward(self, v, q, labels, bias,v_mask): 33 | """Forward 34 | 35 | v: [batch, num_objs, obj_dim] 36 | b: [batch, num_objs, b_dim] 37 | q: [batch_size, seq_length] 38 | 39 | return: logits, not probs 40 | """ 41 | w_emb = self.w_emb(q) 42 | q_emb = self.q_emb(w_emb) # [batch, q_dim] 43 | 44 | att = self.v_att(v, q_emb) 45 | if v_mask is None: 46 | att = nn.functional.softmax(att, 1) 47 | else: 48 | att= mask_softmax(att,v_mask) 49 | 50 | v_emb = (att * v).sum(1) # [batch, v_dim] 51 | 52 | q_repr = self.q_net(q_emb) 53 | v_repr = self.v_net(v_emb) 54 | joint_repr = q_repr * v_repr 55 | 56 | logits = self.classifier(joint_repr) 57 | 58 | if labels is not None: 59 | loss = self.debias_loss_fn(joint_repr, logits, bias, labels) 60 | 61 | else: 62 | loss = None 63 | return logits, loss,w_emb 64 | 65 | def build_baseline0(dataset, num_hid): 66 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 67 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 68 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid) 69 | q_net = FCNet([num_hid, num_hid]) 70 | v_net = FCNet([dataset.v_dim, num_hid]) 71 | classifier = SimpleClassifier( 72 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5) 73 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) 74 | 75 | 76 | def build_baseline0_newatt(dataset, num_hid): 77 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 78 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 79 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid) 80 | q_net = FCNet([q_emb.num_hid, num_hid]) 81 | v_net = FCNet([dataset.v_dim, num_hid]) 82 | classifier = SimpleClassifier( 83 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5) 84 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.weight_norm import weight_norm 3 | 4 | 5 | class SimpleClassifier(nn.Module): 6 | def __init__(self, in_dim, hid_dim, out_dim, dropout): 7 | super(SimpleClassifier, self).__init__() 8 | layers = [ 9 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None), 10 | nn.ReLU(), 11 | nn.Dropout(dropout, inplace=True), 12 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None) 13 | ] 14 | self.main = nn.Sequential(*layers) 15 | 16 | def forward(self, x): 17 | logits = self.main(x) 18 | return logits 19 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import unicode_literals 3 | 4 | import os 5 | import json 6 | import cPickle 7 | from collections import Counter 8 | 9 | import numpy as np 10 | import utils 11 | import h5py 12 | import torch 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | from random import choice 16 | 17 | class Dictionary(object): 18 | def __init__(self, word2idx=None, idx2word=None): 19 | if word2idx is None: 20 | word2idx = {} 21 | if idx2word is None: 22 | idx2word = [] 23 | self.word2idx = word2idx 24 | self.idx2word = idx2word 25 | 26 | @property 27 | def ntoken(self): 28 | return len(self.word2idx) 29 | 30 | @property 31 | def padding_idx(self): 32 | return len(self.word2idx) 33 | 34 | def tokenize(self, sentence, add_word): 35 | sentence = sentence.lower() 36 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s').replace('-', 37 | ' ').replace('.','').replace('"', '').replace('n\'t', ' not').replace('$', ' dollar ') 38 | words = sentence.split() 39 | tokens = [] 40 | if add_word: 41 | for w in words: 42 | tokens.append(self.add_word(w)) 43 | else: 44 | for w in words: 45 | if w in self.word2idx: 46 | tokens.append(self.word2idx[w]) 47 | else: 48 | tokens.append(len(self.word2idx)) 49 | return tokens 50 | 51 | def dump_to_file(self, path): 52 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 53 | print('dictionary dumped to %s' % path) 54 | 55 | @classmethod 56 | def load_from_file(cls, path): 57 | print('loading dictionary from %s' % path) 58 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 59 | d = cls(word2idx, idx2word) 60 | return d 61 | 62 | def add_word(self, word): 63 | if word not in self.word2idx: 64 | self.idx2word.append(word) 65 | self.word2idx[word] = len(self.idx2word) - 1 66 | return self.word2idx[word] 67 | 68 | def __len__(self): 69 | return len(self.idx2word) 70 | 71 | 72 | def _create_entry(img_idx, question, answer): 73 | answer.pop('image_id') 74 | answer.pop('question_id') 75 | entry = { 76 | 'question_id' : question['question_id'], 77 | 'image_id' : question['image_id'], 78 | 'image_idx' : img_idx, 79 | 'question' : question['question'], 80 | 'answer' : answer 81 | } 82 | return entry 83 | 84 | 85 | def _load_dataset(dataroot, name, img_id2val, dataset): 86 | """Load entries 87 | 88 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 89 | dataroot: root path of dataset 90 | name: 'train', 'val' 91 | """ 92 | if dataset=='cpv2': 93 | answer_path = os.path.join(dataroot, 'cp-cache', '%s_target.pkl' % name) 94 | name = "train" if name == "train" else "test" 95 | question_path = os.path.join(dataroot, 'vqacp_v2_%s_questions.json' % name) 96 | with open(question_path) as f: 97 | questions = json.load(f) 98 | elif dataset=='cpv1': 99 | answer_path = os.path.join(dataroot, 'cp-v1-cache', '%s_target.pkl' % name) 100 | name = "train" if name == "train" else "test" 101 | question_path = os.path.join(dataroot, 'vqacp_v1_%s_questions.json' % name) 102 | with open(question_path) as f: 103 | questions = json.load(f) 104 | elif dataset=='v2': 105 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 106 | question_path = os.path.join(dataroot, 'v2_OpenEnded_mscoco_%s2014_questions.json' % name) 107 | with open(question_path) as f: 108 | questions = json.load(f)["questions"] 109 | 110 | with open(answer_path, 'rb') as f: 111 | answers = cPickle.load(f) 112 | 113 | questions.sort(key=lambda x: x['question_id']) 114 | answers.sort(key=lambda x: x['question_id']) 115 | 116 | utils.assert_eq(len(questions), len(answers)) 117 | entries = [] 118 | for question, answer in zip(questions, answers): 119 | if answer["labels"] is None: 120 | raise ValueError() 121 | utils.assert_eq(question['question_id'], answer['question_id']) 122 | utils.assert_eq(question['image_id'], answer['image_id']) 123 | img_id = question['image_id'] 124 | img_idx = None 125 | if img_id2val: 126 | img_idx = img_id2val[img_id] 127 | 128 | entries.append(_create_entry(img_idx, question, answer)) 129 | return entries 130 | 131 | 132 | class VQAFeatureDataset(Dataset): 133 | def __init__(self, name, dictionary, dataroot='data', dataset='cpv2', 134 | use_hdf5=False, cache_image_features=False): 135 | super(VQAFeatureDataset, self).__init__() 136 | self.name=name 137 | if dataset=='cpv2': 138 | with open('data/train_cpv2_hintscore.json', 'r') as f: 139 | self.train_hintscore = json.load(f) 140 | with open('data/test_cpv2_hintscore.json', 'r') as f: 141 | self.test_hintsocre = json.load(f) 142 | with open('util/cpv2_type_mask.json', 'r') as f: 143 | self.type_mask = json.load(f) 144 | with open('util/cpv2_notype_mask.json', 'r') as f: 145 | self.notype_mask = json.load(f) 146 | 147 | elif dataset=='cpv1': 148 | with open('data/train_cpv1_hintscore.json', 'r') as f: 149 | self.train_hintscore = json.load(f) 150 | with open('data/test_cpv1_hintscore.json', 'r') as f: 151 | self.test_hintsocre = json.load(f) 152 | with open('util/cpv1_type_mask.json', 'r') as f: 153 | self.type_mask = json.load(f) 154 | with open('util/cpv1_notype_mask.json', 'r') as f: 155 | self.notype_mask = json.load(f) 156 | elif dataset=='v2': 157 | with open('data/train_v2_hintscore.json', 'r') as f: 158 | self.train_hintscore = json.load(f) 159 | with open('data/test_v2_hintscore.json', 'r') as f: 160 | self.test_hintsocre = json.load(f) 161 | with open('util/v2_type_mask.json', 'r') as f: 162 | self.type_mask = json.load(f) 163 | with open('util/v2_notype_mask.json', 'r') as f: 164 | self.notype_mask = json.load(f) 165 | 166 | assert name in ['train', 'val'] 167 | 168 | if dataset=='cpv2': 169 | ans2label_path = os.path.join(dataroot, 'cp-cache', 'trainval_ans2label.pkl') 170 | label2ans_path = os.path.join(dataroot, 'cp-cache', 'trainval_label2ans.pkl') 171 | elif dataset=='cpv1': 172 | ans2label_path = os.path.join(dataroot, 'cp-v1-cache', 'trainval_ans2label.pkl') 173 | label2ans_path = os.path.join(dataroot, 'cp-v1-cache', 'trainval_label2ans.pkl') 174 | elif dataset=='v2': 175 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 176 | label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl') 177 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 178 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 179 | self.num_ans_candidates = len(self.ans2label) 180 | 181 | self.dictionary = dictionary 182 | self.use_hdf5 = use_hdf5 183 | 184 | if use_hdf5: 185 | h5_path = os.path.join(dataroot, '%s36.hdf5'%name) 186 | self.hf = h5py.File(h5_path, 'r') 187 | self.features = self.hf.get('image_features') 188 | 189 | with open("util/%s36_imgid2img.pkl"%name, "rb") as f: 190 | imgid2idx = cPickle.load(f) 191 | else: 192 | imgid2idx = None 193 | 194 | self.entries = _load_dataset(dataroot, name, imgid2idx, dataset=dataset) 195 | if cache_image_features: 196 | image_to_fe = {} 197 | for entry in tqdm(self.entries, ncols=100, desc="caching-features"): 198 | img_id = entry["image_id"] 199 | if img_id not in image_to_fe: 200 | if use_hdf5: 201 | fe = np.array(self.features[imgid2idx[img_id]]) 202 | else: 203 | fe=torch.load('data/rcnn_feature/'+str(img_id)+'.pth')['image_feature'] 204 | image_to_fe[img_id]=fe 205 | self.image_to_fe = image_to_fe 206 | if use_hdf5: 207 | self.hf.close() 208 | else: 209 | self.image_to_fe = None 210 | 211 | self.tokenize() 212 | self.tensorize() 213 | 214 | self.v_dim = 2048 215 | 216 | def tokenize(self, max_length=14): 217 | """Tokenizes the questions. 218 | 219 | This will add q_token in each entry of the dataset. 220 | -1 represent nil, and should be treated as padding_idx in embedding 221 | """ 222 | for entry in tqdm(self.entries, ncols=100, desc="tokenize"): 223 | tokens = self.dictionary.tokenize(entry['question'], False) 224 | tokens = tokens[:max_length] 225 | if len(tokens) < max_length: 226 | # Note here we pad in front of the sentence 227 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 228 | padding_mask=[self.dictionary.padding_idx-1] * (max_length - len(tokens)) 229 | tokens_mask = padding_mask + tokens 230 | tokens = padding + tokens 231 | 232 | utils.assert_eq(len(tokens), max_length) 233 | entry['q_token'] = tokens 234 | entry['q_token_mask']=tokens_mask 235 | 236 | def tensorize(self): 237 | for entry in tqdm(self.entries, ncols=100, desc="tensorize"): 238 | question = torch.from_numpy(np.array(entry['q_token'])) 239 | question_mask = torch.from_numpy(np.array(entry['q_token_mask'])) 240 | 241 | entry['q_token'] = question 242 | entry['q_token_mask']=question_mask 243 | 244 | answer = entry['answer'] 245 | labels = np.array(answer['labels']) 246 | scores = np.array(answer['scores'], dtype=np.float32) 247 | if len(labels): 248 | labels = torch.from_numpy(labels) 249 | scores = torch.from_numpy(scores) 250 | entry['answer']['labels'] = labels 251 | entry['answer']['scores'] = scores 252 | else: 253 | entry['answer']['labels'] = None 254 | entry['answer']['scores'] = None 255 | 256 | def __getitem__(self, index): 257 | entry = self.entries[index] 258 | if self.image_to_fe is not None: 259 | features = self.image_to_fe[entry["image_id"]] 260 | elif self.use_hdf5: 261 | features = np.array(self.features[entry['image_idx']]) 262 | features = torch.from_numpy(features).view(36, 2048) 263 | else: 264 | features = torch.load('data/rcnn_feature/' + str(entry["image_id"]) + '.pth')['image_feature'] 265 | 266 | q_id=entry['question_id'] 267 | ques = entry['q_token'] 268 | ques_mask=entry['q_token_mask'] 269 | answer = entry['answer'] 270 | labels = answer['labels'] 271 | scores = answer['scores'] 272 | target = torch.zeros(self.num_ans_candidates) 273 | if labels is not None: 274 | target.scatter_(0, labels, scores) 275 | 276 | if self.name=='train': 277 | train_hint=torch.tensor(self.train_hintscore[str(q_id)]) 278 | type_mask=torch.tensor(self.type_mask[str(q_id)]) 279 | notype_mask=torch.tensor(self.notype_mask[str(q_id)]) 280 | if "bias" in entry: 281 | return features, ques, target,entry["bias"],train_hint,type_mask,notype_mask,ques_mask 282 | 283 | else: 284 | return features, ques,target, 0,train_hint 285 | else: 286 | test_hint=torch.tensor(self.test_hintsocre[str(q_id)]) 287 | if "bias" in entry: 288 | return features, ques, target, entry["bias"],q_id,test_hint 289 | else: 290 | return features, ques, target, 0,q_id,test_hint 291 | 292 | def __len__(self): 293 | return len(self.entries) 294 | 295 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cPickle 4 | from collections import defaultdict, Counter 5 | from os.path import dirname, join 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | import os 12 | 13 | # from new_dataset import Dictionary, VQAFeatureDataset 14 | from dataset import Dictionary, VQAFeatureDataset 15 | import base_model 16 | from train import train 17 | import utils 18 | 19 | from vqa_debias_loss_functions import * 20 | from tqdm import tqdm 21 | from torch.autograd import Variable 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 26 | 27 | # Arguments we added 28 | parser.add_argument( 29 | '--cache_features', default=True, 30 | help="Cache image features in RAM. Makes things much faster, " 31 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 32 | parser.add_argument( 33 | '--dataset', default='cpv2', help="Run on VQA-2.0 instead of VQA-CP 2.0") 34 | parser.add_argument( 35 | '-p', "--entropy_penalty", default=0.36, type=float, 36 | help="Entropy regularizer weight for the learned_mixin model") 37 | parser.add_argument( 38 | '--debias', default="learned_mixin", 39 | choices=["learned_mixin", "reweight", "bias_product", "none"], 40 | help="Kind of ensemble loss to use") 41 | # Arguments from the original model, we leave this default, except we 42 | # set --epochs to 15 since the model maxes out its performance on VQA 2.0 well before then 43 | parser.add_argument('--num_hid', type=int, default=1024) 44 | parser.add_argument('--model', type=str, default='baseline0_newatt') 45 | parser.add_argument('--batch_size', type=int, default=512) 46 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 47 | parser.add_argument('--model_state', type=str, default='logs/exp0/model.pth') 48 | args = parser.parse_args() 49 | return args 50 | 51 | def compute_score_with_logits(logits, labels): 52 | # logits = torch.max(logits, 1)[1].data # argmax 53 | logits = torch.argmax(logits,1) 54 | one_hots = torch.zeros(*labels.size()).cuda() 55 | one_hots.scatter_(1, logits.view(-1, 1), 1) 56 | scores = (one_hots * labels) 57 | return scores 58 | 59 | 60 | def evaluate(model,dataloader,qid2type): 61 | score = 0 62 | upper_bound = 0 63 | score_yesno = 0 64 | score_number = 0 65 | score_other = 0 66 | total_yesno = 0 67 | total_number = 0 68 | total_other = 0 69 | model.train(False) 70 | # import pdb;pdb.set_trace() 71 | 72 | 73 | for v, q, a, b,qids,hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 74 | v = Variable(v, requires_grad=False).cuda() 75 | q = Variable(q, requires_grad=False).cuda() 76 | pred, _ ,_= model(v, q, None, None,None) 77 | batch_score= compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1) 78 | score += batch_score.sum() 79 | upper_bound += (a.max(1)[0]).sum() 80 | qids = qids.detach().cpu().int().numpy() 81 | for j in range(len(qids)): 82 | qid=qids[j] 83 | typ = qid2type[str(qid)] 84 | if typ == 'yes/no': 85 | score_yesno += batch_score[j] 86 | total_yesno += 1 87 | elif typ == 'other': 88 | score_other += batch_score[j] 89 | total_other += 1 90 | elif typ == 'number': 91 | score_number += batch_score[j] 92 | total_number += 1 93 | else: 94 | print('Hahahahahahahahahahaha') 95 | score = score / len(dataloader.dataset) 96 | upper_bound = upper_bound / len(dataloader.dataset) 97 | score_yesno /= total_yesno 98 | score_other /= total_other 99 | score_number /= total_number 100 | print('\teval overall score: %.2f' % (100 * score)) 101 | print('\teval up_bound score: %.2f' % (100 * upper_bound)) 102 | print('\teval y/n score: %.2f' % (100 * score_yesno)) 103 | print('\teval other score: %.2f' % (100 * score_other)) 104 | print('\teval number score: %.2f' % (100 * score_number)) 105 | 106 | def evaluate_ai(model,dataloader,qid2type,label2ans): 107 | score=0 108 | upper_bound=0 109 | 110 | ai_top1=0 111 | ai_top2=0 112 | ai_top3=0 113 | 114 | for v, q, a, b, qids, hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 115 | v = Variable(v, requires_grad=False).cuda().float().requires_grad_() 116 | q = Variable(q, requires_grad=False).cuda() 117 | a=a.cuda() 118 | hintscore=hintscore.cuda().float() 119 | pred, _, _ = model(v, q, None, None, None) 120 | vqa_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] # [b , 36, 2048] 121 | 122 | vqa_grad_cam=vqa_grad.sum(2) 123 | sv_ind=torch.argmax(vqa_grad_cam,1) 124 | 125 | x_ind_top1=torch.topk(vqa_grad_cam,k=1)[1] 126 | x_ind_top2=torch.topk(vqa_grad_cam,k=2)[1] 127 | x_ind_top3=torch.topk(vqa_grad_cam,k=3)[1] 128 | 129 | y_score_top1 = hintscore.gather(1,x_ind_top1).sum(1)/1 130 | y_score_top2 = hintscore.gather(1,x_ind_top2).sum(1)/2 131 | y_score_top3 = hintscore.gather(1,x_ind_top3).sum(1)/3 132 | 133 | 134 | batch_score=compute_score_with_logits(pred,a.cuda()).cpu().numpy().sum(1) 135 | score+=batch_score.sum() 136 | upper_bound+=(a.max(1)[0]).sum() 137 | qids=qids.detach().cpu().int().numpy() 138 | for j in range(len(qids)): 139 | if batch_score[j]>0: 140 | ai_top1 += y_score_top1[j] 141 | ai_top2 += y_score_top2[j] 142 | ai_top3 += y_score_top3[j] 143 | 144 | 145 | 146 | score = score / len(dataloader.dataset) 147 | upper_bound = upper_bound / len(dataloader.dataset) 148 | ai_top1=(ai_top1.item() * 1.0) / len(dataloader.dataset) 149 | ai_top2=(ai_top2.item() * 1.0) / len(dataloader.dataset) 150 | ai_top3=(ai_top3.item() * 1.0) / len(dataloader.dataset) 151 | 152 | print('\teval overall score: %.2f' % (100 * score)) 153 | print('\teval up_bound score: %.2f' % (100 * upper_bound)) 154 | print('\ttop1_ai_score: %.2f' % (100 * ai_top1)) 155 | print('\ttop2_ai_score: %.2f' % (100 * ai_top2)) 156 | print('\ttop3_ai_score: %.2f' % (100 * ai_top3)) 157 | 158 | def main(): 159 | args = parse_args() 160 | dataset = args.dataset 161 | 162 | 163 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 164 | qid2type=json.load(f) 165 | 166 | if dataset=='cpv1': 167 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 168 | elif dataset=='cpv2' or dataset=='v2': 169 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 170 | 171 | print("Building test dataset...") 172 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 173 | cache_image_features=args.cache_features) 174 | 175 | # Build the model using the original constructor 176 | constructor = 'build_%s' % args.model 177 | model = getattr(base_model, constructor)(eval_dset, args.num_hid).cuda() 178 | 179 | if args.debias == "bias_product": 180 | model.debias_loss_fn = BiasProduct() 181 | elif args.debias == "none": 182 | model.debias_loss_fn = Plain() 183 | elif args.debias == "reweight": 184 | model.debias_loss_fn = ReweightByInvBias() 185 | elif args.debias == "learned_mixin": 186 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty) 187 | else: 188 | raise RuntimeError(args.mode) 189 | 190 | 191 | model_state = torch.load(args.model_state) 192 | model.load_state_dict(model_state) 193 | 194 | 195 | model = model.cuda() 196 | batch_size = args.batch_size 197 | 198 | torch.manual_seed(args.seed) 199 | torch.cuda.manual_seed(args.seed) 200 | torch.backends.cudnn.benchmark = True 201 | 202 | # The original version uses multiple workers, but that just seems slower on my setup 203 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 204 | 205 | 206 | 207 | print("Starting eval...") 208 | 209 | evaluate(model,eval_loader,qid2type) 210 | 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /fc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | 5 | 6 | class FCNet(nn.Module): 7 | """Simple class for non-linear fully connect network 8 | """ 9 | def __init__(self, dims): 10 | super(FCNet, self).__init__() 11 | 12 | layers = [] 13 | for i in range(len(dims)-2): 14 | in_dim = dims[i] 15 | out_dim = dims[i+1] 16 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 17 | layers.append(nn.ReLU()) 18 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 19 | layers.append(nn.ReLU()) 20 | 21 | self.main = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.main(x) 25 | 26 | 27 | if __name__ == '__main__': 28 | fc1 = FCNet([10, 20, 10]) 29 | print(fc1) 30 | 31 | print('============') 32 | fc2 = FCNet([10, 20]) 33 | print(fc2) 34 | -------------------------------------------------------------------------------- /language_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | class WordEmbedding(nn.Module): 8 | """Word Embedding 9 | 10 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 11 | with the definition in Dictionary. 12 | """ 13 | def __init__(self, ntoken, emb_dim, dropout): 14 | super(WordEmbedding, self).__init__() 15 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 16 | self.dropout = nn.Dropout(dropout) 17 | self.ntoken = ntoken 18 | self.emb_dim = emb_dim 19 | 20 | def init_embedding(self, np_file): 21 | weight_init = torch.from_numpy(np.load(np_file)) 22 | assert weight_init.shape == (self.ntoken, self.emb_dim) 23 | self.emb.weight.data[:self.ntoken] = weight_init 24 | 25 | def forward(self, x): 26 | emb = self.emb(x) 27 | emb = self.dropout(emb) 28 | return emb 29 | 30 | 31 | class QuestionEmbedding(nn.Module): 32 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'): 33 | """Module for question embedding 34 | """ 35 | super(QuestionEmbedding, self).__init__() 36 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 37 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU 38 | 39 | self.rnn = rnn_cls( 40 | in_dim, num_hid, nlayers, 41 | bidirectional=bidirect, 42 | dropout=dropout, 43 | batch_first=True) 44 | 45 | self.in_dim = in_dim 46 | self.num_hid = num_hid 47 | self.nlayers = nlayers 48 | self.rnn_type = rnn_type 49 | self.ndirections = 1 + int(bidirect) 50 | 51 | def init_hidden(self, batch): 52 | # just to get the type of tensor 53 | weight = next(self.parameters()).data 54 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 55 | if self.rnn_type == 'LSTM': 56 | return (Variable(weight.new(*hid_shape).zero_()), 57 | Variable(weight.new(*hid_shape).zero_())) 58 | else: 59 | return Variable(weight.new(*hid_shape).zero_()) 60 | 61 | def forward(self, x): 62 | # x: [batch, sequence, in_dim] 63 | batch = x.size(0) 64 | hidden = self.init_hidden(batch) 65 | self.rnn.flatten_parameters() 66 | output, hidden = self.rnn(x, hidden) 67 | 68 | if self.ndirections == 1: 69 | return output[:, -1] 70 | 71 | forward_ = output[:, -1, :self.num_hid] 72 | backward = output[:, 0, self.num_hid:] 73 | return torch.cat((forward_, backward), dim=1) 74 | 75 | def forward_all(self, x): 76 | # x: [batch, sequence, in_dim] 77 | batch = x.size(0) 78 | hidden = self.init_hidden(batch) 79 | self.rnn.flatten_parameters() 80 | output, hidden = self.rnn(x, hidden) 81 | return output 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cPickle as pickle 4 | from collections import defaultdict, Counter 5 | from os.path import dirname, join 6 | import os 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | 13 | from dataset import Dictionary, VQAFeatureDataset 14 | import base_model 15 | from train import train 16 | import utils 17 | import click 18 | 19 | from vqa_debias_loss_functions import * 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 24 | 25 | # Arguments we added 26 | parser.add_argument( 27 | '--cache_features', default=True, 28 | help="Cache image features in RAM. Makes things much faster, " 29 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 30 | parser.add_argument( 31 | '--dataset', default='cpv2', 32 | choices=["v2", "cpv2", "cpv1"], 33 | help="Run on VQA-2.0 instead of VQA-CP 2.0" 34 | ) 35 | parser.add_argument( 36 | '-p', "--entropy_penalty", default=0.36, type=float, 37 | help="Entropy regularizer weight for the learned_mixin model") 38 | parser.add_argument( 39 | '--mode', default="updn", 40 | choices=["updn", "q_debias","v_debias","q_v_debias"], 41 | help="Kind of ensemble loss to use") 42 | parser.add_argument( 43 | '--debias', default="learned_mixin", 44 | choices=["learned_mixin", "reweight", "bias_product", "none",'focal'], 45 | help="Kind of ensemble loss to use") 46 | parser.add_argument( 47 | '--topq', type=int,default=1, 48 | choices=[1,2,3], 49 | help="num of words to be masked in questio") 50 | parser.add_argument( 51 | '--keep_qtype', default=True, 52 | help="keep qtype or not") 53 | parser.add_argument( 54 | '--topv', type=int,default=1, 55 | choices=[1,3,5,-1], 56 | help="num of object bbox to be masked in image") 57 | parser.add_argument( 58 | '--top_hint',type=int, default=9, 59 | choices=[9,18,27,36], 60 | help="num of hint") 61 | parser.add_argument( 62 | '--qvp', type=int,default=0, 63 | choices=[0,1,2,3,4,5,6,7,8,9,10], 64 | help="ratio of q_bias and v_bias") 65 | parser.add_argument( 66 | '--eval_each_epoch', default=True, 67 | help="Evaluate every epoch, instead of at the end") 68 | 69 | # Arguments from the original model, we leave this default, except we 70 | # set --epochs to 30 since the model maxes out its performance on VQA 2.0 well before then 71 | parser.add_argument('--epochs', type=int, default=30) 72 | parser.add_argument('--num_hid', type=int, default=1024) 73 | parser.add_argument('--model', type=str, default='baseline0_newatt') 74 | parser.add_argument('--output', type=str, default='logs/exp0') 75 | parser.add_argument('--batch_size', type=int, default=512) 76 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 77 | args = parser.parse_args() 78 | return args 79 | 80 | def get_bias(train_dset,eval_dset): 81 | # Compute the bias: 82 | # The bias here is just the expected score for each answer/question type 83 | answer_voc_size = train_dset.num_ans_candidates 84 | 85 | # question_type -> answer -> total score 86 | question_type_to_probs = defaultdict(Counter) 87 | 88 | # question_type -> num_occurances 89 | question_type_to_count = Counter() 90 | for ex in train_dset.entries: 91 | ans = ex["answer"] 92 | q_type = ans["question_type"] 93 | question_type_to_count[q_type] += 1 94 | if ans["labels"] is not None: 95 | for label, score in zip(ans["labels"], ans["scores"]): 96 | question_type_to_probs[q_type][label] += score 97 | question_type_to_prob_array = {} 98 | 99 | for q_type, count in question_type_to_count.items(): 100 | prob_array = np.zeros(answer_voc_size, np.float32) 101 | for label, total_score in question_type_to_probs[q_type].items(): 102 | prob_array[label] += total_score 103 | prob_array /= count 104 | question_type_to_prob_array[q_type] = prob_array 105 | 106 | for ds in [train_dset,eval_dset]: 107 | for ex in ds.entries: 108 | q_type = ex["answer"]["question_type"] 109 | ex["bias"] = question_type_to_prob_array[q_type] 110 | 111 | 112 | def main(): 113 | args = parse_args() 114 | dataset=args.dataset 115 | args.output=os.path.join('logs',args.output) 116 | if not os.path.isdir(args.output): 117 | utils.create_dir(args.output) 118 | else: 119 | if click.confirm('Exp directory already exists in {}. Erase?' 120 | .format(args.output, default=False)): 121 | os.system('rm -r ' + args.output) 122 | utils.create_dir(args.output) 123 | 124 | else: 125 | os._exit(1) 126 | 127 | 128 | 129 | if dataset=='cpv1': 130 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 131 | elif dataset=='cpv2' or dataset=='v2': 132 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 133 | 134 | print("Building train dataset...") 135 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset, 136 | cache_image_features=args.cache_features) 137 | 138 | print("Building test dataset...") 139 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 140 | cache_image_features=args.cache_features) 141 | 142 | get_bias(train_dset,eval_dset) 143 | 144 | 145 | # Build the model using the original constructor 146 | constructor = 'build_%s' % args.model 147 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda() 148 | if dataset=='cpv1': 149 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy') 150 | elif dataset=='cpv2' or dataset=='v2': 151 | model.w_emb.init_embedding('data/glove6b_init_300d.npy') 152 | 153 | # Add the loss_fn based our arguments 154 | if args.debias == "bias_product": 155 | model.debias_loss_fn = BiasProduct() 156 | elif args.debias == "none": 157 | model.debias_loss_fn = Plain() 158 | elif args.debias == "reweight": 159 | model.debias_loss_fn = ReweightByInvBias() 160 | elif args.debias == "learned_mixin": 161 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty) 162 | elif args.debias=='focal': 163 | model.debias_loss_fn = Focal() 164 | else: 165 | raise RuntimeError(args.mode) 166 | 167 | 168 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 169 | qid2type=json.load(f) 170 | model=model.cuda() 171 | batch_size = args.batch_size 172 | 173 | torch.manual_seed(args.seed) 174 | torch.cuda.manual_seed(args.seed) 175 | torch.backends.cudnn.benchmark = True 176 | 177 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0) 178 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 179 | 180 | print("Starting training...") 181 | train(model, train_loader, eval_loader, args,qid2type) 182 | 183 | if __name__ == '__main__': 184 | main() 185 | 186 | 187 | 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /rubi_base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attention import Attention, NewAttention 4 | from language_model import WordEmbedding, QuestionEmbedding 5 | from classifier import SimpleClassifier 6 | from fc import FCNet 7 | from torch.nn import functional as F 8 | 9 | import numpy as np 10 | 11 | def mask_softmax(x,mask): 12 | mask=mask.unsqueeze(2).float() 13 | x2=torch.exp(x-torch.max(x)) 14 | x3=x2*mask 15 | epsilon=1e-5 16 | x3_sum=torch.sum(x3,dim=1,keepdim=True)+epsilon 17 | x4=x3/x3_sum.expand_as(x3) 18 | return x4 19 | 20 | 21 | class MLP(nn.Module): 22 | 23 | def __init__(self, 24 | input_dim, 25 | dimensions, 26 | activation='relu', 27 | dropout=0.): 28 | super(MLP, self).__init__() 29 | self.input_dim = input_dim 30 | self.dimensions = dimensions 31 | self.activation = activation 32 | self.dropout = dropout 33 | # Modules 34 | self.linears = nn.ModuleList([nn.Linear(input_dim, dimensions[0])]) 35 | for din, dout in zip(dimensions[:-1], dimensions[1:]): 36 | self.linears.append(nn.Linear(din, dout)) 37 | 38 | def forward(self, x): 39 | for i, lin in enumerate(self.linears): 40 | x = lin(x) 41 | if (i < len(self.linears) - 1): 42 | x = nn.functional.__dict__[self.activation](x) 43 | if self.dropout > 0: 44 | x = nn.functional.dropout(x, self.dropout, training=self.training) 45 | return x 46 | 47 | class BaseModel(nn.Module): 48 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier,c_1,c_2): 49 | super(BaseModel, self).__init__() 50 | self.w_emb = w_emb 51 | self.q_emb = q_emb 52 | self.v_att = v_att 53 | self.q_net = q_net 54 | self.v_net = v_net 55 | self.classifier = classifier 56 | self.debias_loss_fn = None 57 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2)) 58 | self.bias_lin = torch.nn.Linear(1024, 1) 59 | self.c_1=c_1 60 | self.c_2=c_2 61 | 62 | def forward(self, v, q, labels, bias,v_mask): 63 | """Forward 64 | 65 | v: [batch, num_objs, obj_dim] 66 | b: [batch, num_objs, b_dim] 67 | q: [batch_size, seq_length] 68 | 69 | return: logits, not probs 70 | """ 71 | w_emb = self.w_emb(q) 72 | q_emb = self.q_emb(w_emb) # [batch, q_dim] 73 | 74 | att = self.v_att(v, q_emb) 75 | if v_mask is None: 76 | att = nn.functional.softmax(att, 1) 77 | else: 78 | att= mask_softmax(att,v_mask) 79 | 80 | v_emb = (att * v).sum(1) # [batch, v_dim] 81 | 82 | q_repr = self.q_net(q_emb) 83 | v_repr = self.v_net(v_emb) 84 | joint_repr = q_repr * v_repr 85 | 86 | logits = self.classifier(joint_repr) 87 | 88 | q_pred=self.c_1(q_emb.detach()) 89 | 90 | q_out=self.c_2(q_pred) 91 | 92 | if labels is not None: 93 | rubi_logits=logits*torch.sigmoid(q_pred) 94 | loss=F.binary_cross_entropy_with_logits(rubi_logits, labels)+F.binary_cross_entropy_with_logits(q_out, labels) 95 | loss *= labels.size(1) 96 | 97 | else: 98 | loss = None 99 | return logits, loss,w_emb 100 | 101 | def build_baseline0(dataset, num_hid): 102 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 103 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 104 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid) 105 | q_net = FCNet([num_hid, num_hid]) 106 | v_net = FCNet([dataset.v_dim, num_hid]) 107 | classifier = SimpleClassifier( 108 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5) 109 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) 110 | 111 | 112 | def build_baseline0_newatt(dataset, num_hid): 113 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 114 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 115 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid) 116 | q_net = FCNet([q_emb.num_hid, num_hid]) 117 | v_net = FCNet([dataset.v_dim, num_hid]) 118 | c_1=MLP(input_dim=1024,dimensions=[1024,1024,dataset.num_ans_candidates]) 119 | c_2=nn.Linear(dataset.num_ans_candidates,dataset.num_ans_candidates) 120 | classifier = SimpleClassifier( 121 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5) 122 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier,c_1,c_2) -------------------------------------------------------------------------------- /rubi_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import cPickle as pickle 4 | from collections import defaultdict, Counter 5 | from os.path import dirname, join 6 | import os 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | 13 | from dataset import Dictionary, VQAFeatureDataset 14 | import rubi_base_model 15 | from rubi_train import train 16 | import utils 17 | import click 18 | 19 | from vqa_debias_loss_functions import * 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 24 | 25 | # Arguments we added 26 | parser.add_argument( 27 | '--cache_features', default=True, 28 | help="Cache image features in RAM. Makes things much faster, " 29 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 30 | parser.add_argument( 31 | '--dataset', default='cpv2', 32 | choices=["v2", "cpv2", "cpv1"], 33 | help="Run on VQA-2.0 instead of VQA-CP 2.0" 34 | ) 35 | parser.add_argument( 36 | '--mode', default="updn", 37 | choices=["updn", "q_debias","v_debias","q_v_debias"], 38 | help="Kind of ensemble loss to use") 39 | parser.add_argument( 40 | '--topq', type=int,default=1, 41 | choices=[1,2,3], 42 | help="num of q to mask") 43 | parser.add_argument( 44 | '--keep_qtype', default=True, 45 | help="keep qtype or not") 46 | parser.add_argument( 47 | '--topv', type=int,default=1, 48 | choices=[1,3,5,-1], 49 | help="num of v to mask") 50 | parser.add_argument( 51 | '--top_hint',type=int, default=9, 52 | choices=[9,18,27,36], 53 | help="num of hint") 54 | parser.add_argument( 55 | '--qvp', type=int,default=0, 56 | choices=[0,1,2,3,4,5,6,7,8,9,10], 57 | help="proportion of q/v") 58 | parser.add_argument( 59 | '--eval_each_epoch', default=True, 60 | help="Evaluate every epoch, instead of at the end") 61 | 62 | # Arguments from the original model, we leave this default, except we 63 | # set --epochs to 15 since the model maxes out its performance on VQA 2.0 well before then 64 | parser.add_argument('--epochs', type=int, default=30) 65 | parser.add_argument('--num_hid', type=int, default=1024) 66 | parser.add_argument('--model', type=str, default='baseline0_newatt') 67 | parser.add_argument('--output', type=str, default='logs/exp0') 68 | parser.add_argument('--batch_size', type=int, default=512) 69 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 70 | args = parser.parse_args() 71 | return args 72 | 73 | def get_bias(train_dset,eval_dset): 74 | # Compute the bias: 75 | # The bias here is just the expected score for each answer/question type 76 | answer_voc_size = train_dset.num_ans_candidates 77 | 78 | # question_type -> answer -> total score 79 | question_type_to_probs = defaultdict(Counter) 80 | 81 | # question_type -> num_occurances 82 | question_type_to_count = Counter() 83 | for ex in train_dset.entries: 84 | ans = ex["answer"] 85 | q_type = ans["question_type"] 86 | question_type_to_count[q_type] += 1 87 | if ans["labels"] is not None: 88 | for label, score in zip(ans["labels"], ans["scores"]): 89 | question_type_to_probs[q_type][label] += score 90 | question_type_to_prob_array = {} 91 | 92 | for q_type, count in question_type_to_count.items(): 93 | prob_array = np.zeros(answer_voc_size, np.float32) 94 | for label, total_score in question_type_to_probs[q_type].items(): 95 | prob_array[label] += total_score 96 | prob_array /= count 97 | question_type_to_prob_array[q_type] = prob_array 98 | 99 | for ds in [train_dset,eval_dset]: 100 | for ex in ds.entries: 101 | q_type = ex["answer"]["question_type"] 102 | ex["bias"] = question_type_to_prob_array[q_type] 103 | 104 | 105 | def main(): 106 | args = parse_args() 107 | dataset=args.dataset 108 | args.output=os.path.join('logs',args.output) 109 | if not os.path.isdir(args.output): 110 | utils.create_dir(args.output) 111 | else: 112 | if click.confirm('Exp directory already exists in {}. Erase?' 113 | .format(args.output, default=False)): 114 | os.system('rm -r ' + args.output) 115 | utils.create_dir(args.output) 116 | 117 | else: 118 | os._exit(1) 119 | 120 | 121 | 122 | if dataset=='cpv1': 123 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 124 | elif dataset=='cpv2' or dataset=='v2': 125 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 126 | 127 | print("Building train dataset...") 128 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset, 129 | cache_image_features=args.cache_features) 130 | 131 | print("Building test dataset...") 132 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 133 | cache_image_features=args.cache_features) 134 | 135 | get_bias(train_dset,eval_dset) 136 | 137 | 138 | # Build the model using the original constructor 139 | constructor = 'build_%s' % args.model 140 | model = getattr(rubi_base_model, constructor)(train_dset, args.num_hid).cuda() 141 | if dataset=='cpv1': 142 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy') 143 | elif dataset=='cpv2' or dataset=='v2': 144 | model.w_emb.init_embedding('data/glove6b_init_300d.npy') 145 | 146 | # Add the loss_fn based our arguments 147 | # model.debias_loss_fn = Focal() 148 | 149 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 150 | qid2type=json.load(f) 151 | model=model.cuda() 152 | batch_size = args.batch_size 153 | 154 | torch.manual_seed(args.seed) 155 | torch.cuda.manual_seed(args.seed) 156 | torch.backends.cudnn.benchmark = True 157 | 158 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0) 159 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 160 | 161 | print("Starting training...") 162 | train(model, train_loader, eval_loader, args,qid2type) 163 | 164 | if __name__ == '__main__': 165 | main() 166 | 167 | 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /rubi_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import time 5 | from os.path import join 6 | 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from tqdm import tqdm 13 | import random 14 | import copy 15 | 16 | 17 | def compute_score_with_logits(logits, labels): 18 | logits = torch.argmax(logits, 1) 19 | one_hots = torch.zeros(*labels.size()).cuda() 20 | one_hots.scatter_(1, logits.view(-1, 1), 1) 21 | scores = (one_hots * labels) 22 | return scores 23 | 24 | def train(model, train_loader, eval_loader,args,qid2type): 25 | num_epochs=args.epochs 26 | mode=args.mode 27 | run_eval=args.eval_each_epoch 28 | output=args.output 29 | optim = torch.optim.Adamax(model.parameters()) 30 | logger = utils.Logger(os.path.join(output, 'log.txt')) 31 | total_step = 0 32 | best_eval_score = 0 33 | 34 | 35 | 36 | if mode=='q_debias': 37 | topq=args.topq 38 | keep_qtype=args.keep_qtype 39 | elif mode=='v_debias': 40 | topv=args.topv 41 | top_hint=args.top_hint 42 | elif mode=='q_v_debias': 43 | topv=args.topv 44 | top_hint=args.top_hint 45 | topq=args.topq 46 | keep_qtype=args.keep_qtype 47 | qvp=args.qvp 48 | 49 | 50 | 51 | for epoch in range(num_epochs): 52 | total_loss = 0 53 | train_score = 0 54 | 55 | t = time.time() 56 | for i, (v, q, a, b, hintscore,type_mask,notype_mask,q_mask) in tqdm(enumerate(train_loader), ncols=100, 57 | desc="Epoch %d" % (epoch + 1), total=len(train_loader)): 58 | 59 | total_step += 1 60 | 61 | 62 | ######################################### 63 | v = Variable(v).cuda().requires_grad_() 64 | q = Variable(q).cuda() 65 | q_mask=Variable(q_mask).cuda() 66 | a = Variable(a).cuda() 67 | b = Variable(b).cuda() 68 | hintscore = Variable(hintscore).cuda() 69 | type_mask=Variable(type_mask).float().cuda() 70 | notype_mask=Variable(notype_mask).float().cuda() 71 | ######################################### 72 | 73 | if mode=='updn': 74 | pred, loss,_ = model(v, q, a, b, None) 75 | if (loss != loss).any(): 76 | raise ValueError("NaN loss") 77 | loss.backward() 78 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 79 | optim.step() 80 | optim.zero_grad() 81 | 82 | total_loss += loss.item() * q.size(0) 83 | batch_score = compute_score_with_logits(pred, a.data).sum() 84 | train_score += batch_score 85 | 86 | elif mode=='q_debias': 87 | if keep_qtype==True: 88 | sen_mask=type_mask 89 | else: 90 | sen_mask=notype_mask 91 | ## first train 92 | pred, loss,word_emb = model(v, q, a, b, None) 93 | 94 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0] 95 | 96 | if (loss != loss).any(): 97 | raise ValueError("NaN loss") 98 | loss.backward() 99 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 100 | optim.step() 101 | optim.zero_grad() 102 | 103 | total_loss += loss.item() * q.size(0) 104 | batch_score = compute_score_with_logits(pred, a.data).sum() 105 | train_score += batch_score 106 | 107 | ## second train 108 | 109 | word_grad_cam = word_grad.sum(2) 110 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000) 111 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask) 112 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask 113 | 114 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq] 115 | 116 | q2 = copy.deepcopy(q_mask) 117 | 118 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1] 119 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0] 120 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1] 121 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0] 122 | q2 = q2 * m2.long() + m3.long() 123 | 124 | pred, _, _ = model(v, q2, None, b, None) 125 | 126 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 127 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 128 | false_ans.scatter_(1, pred_ind, 0) 129 | a2 = a * false_ans 130 | q3 = copy.deepcopy(q) 131 | q3.scatter_(1, w_ind, 18455) 132 | 133 | ## third train 134 | 135 | pred, loss, _ = model(v, q3, a2, b, None) 136 | 137 | if (loss != loss).any(): 138 | raise ValueError("NaN loss") 139 | loss.backward() 140 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 141 | optim.step() 142 | optim.zero_grad() 143 | 144 | total_loss += loss.item() * q.size(0) 145 | 146 | elif mode=='v_debias': 147 | ## first train 148 | pred, loss, _ = model(v, q, a, b, None) 149 | visual_grad=torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] 150 | 151 | if (loss != loss).any(): 152 | raise ValueError("NaN loss") 153 | loss.backward() 154 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 155 | optim.step() 156 | optim.zero_grad() 157 | 158 | total_loss += loss.item() * q.size(0) 159 | batch_score = compute_score_with_logits(pred, a.data).sum() 160 | train_score += batch_score 161 | 162 | ##second train 163 | v_mask = torch.zeros(v.shape[0], 36).cuda() 164 | visual_grad_cam = visual_grad.sum(2) 165 | hint_sort, hint_ind = hintscore.sort(1, descending=True) 166 | v_ind = hint_ind[:, :top_hint] 167 | v_grad = visual_grad_cam.gather(1, v_ind) 168 | 169 | if topv==-1: 170 | v_grad_score,v_grad_ind=v_grad.sort(1,descending=True) 171 | v_grad_score=nn.functional.softmax(v_grad_score*10,dim=1) 172 | v_grad_sum=torch.cumsum(v_grad_score,dim=1) 173 | v_grad_mask=(v_grad_sum<=0.6).long() 174 | v_grad_mask[:,0] = 1 175 | 176 | v_mask_ind=v_grad_mask*v_ind 177 | for x in range(a.shape[0]): 178 | num=len(torch.nonzero(v_grad_mask[x])) 179 | v_mask[x].scatter_(0,v_mask_ind[x,:num],1) 180 | else: 181 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv] 182 | v_star = v_ind.gather(1, v_grad_ind) 183 | v_mask.scatter_(1, v_star, 1) 184 | 185 | 186 | pred, _, _ = model(v, q, None, b, v_mask) 187 | 188 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 189 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 190 | false_ans.scatter_(1, pred_ind, 0) 191 | a2 = a * false_ans 192 | 193 | v_mask = 1 - v_mask 194 | 195 | pred, loss, _ = model(v, q, a2, b, v_mask) 196 | 197 | if (loss != loss).any(): 198 | raise ValueError("NaN loss") 199 | loss.backward() 200 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 201 | optim.step() 202 | optim.zero_grad() 203 | 204 | total_loss += loss.item() * q.size(0) 205 | 206 | elif mode=='q_v_debias': 207 | random_num = random.randint(1, 10) 208 | if keep_qtype == True: 209 | sen_mask = type_mask 210 | else: 211 | sen_mask = notype_mask 212 | if random_num<=qvp: 213 | ## first train 214 | pred, loss, word_emb = model(v, q, a, b, None) 215 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0] 216 | 217 | if (loss != loss).any(): 218 | raise ValueError("NaN loss") 219 | loss.backward() 220 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 221 | optim.step() 222 | optim.zero_grad() 223 | 224 | total_loss += loss.item() * q.size(0) 225 | batch_score = compute_score_with_logits(pred, a.data).sum() 226 | train_score += batch_score 227 | 228 | ## second train 229 | 230 | word_grad_cam = word_grad.sum(2) 231 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask) 232 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask 233 | 234 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq] 235 | 236 | q2 = copy.deepcopy(q_mask) 237 | 238 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1] 239 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0] 240 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1] 241 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0] 242 | q2 = q2 * m2.long() + m3.long() 243 | 244 | pred, _, _ = model(v, q2, None, b, None) 245 | 246 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 247 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 248 | false_ans.scatter_(1, pred_ind, 0) 249 | a2 = a * false_ans 250 | q3 = copy.deepcopy(q) 251 | q3.scatter_(1, w_ind, 18455) 252 | 253 | ## third train 254 | 255 | pred, loss, _ = model(v, q3, a2, b, None) 256 | 257 | if (loss != loss).any(): 258 | raise ValueError("NaN loss") 259 | loss.backward() 260 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 261 | optim.step() 262 | optim.zero_grad() 263 | 264 | total_loss += loss.item() * q.size(0) 265 | 266 | 267 | else: 268 | ## first train 269 | pred, loss, _ = model(v, q, a, b, None) 270 | visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] 271 | 272 | if (loss != loss).any(): 273 | raise ValueError("NaN loss") 274 | loss.backward() 275 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 276 | optim.step() 277 | optim.zero_grad() 278 | 279 | total_loss += loss.item() * q.size(0) 280 | batch_score = compute_score_with_logits(pred, a.data).sum() 281 | train_score += batch_score 282 | 283 | ##second train 284 | v_mask = torch.zeros(v.shape[0], 36).cuda() 285 | visual_grad_cam = visual_grad.sum(2) 286 | hint_sort, hint_ind = hintscore.sort(1, descending=True) 287 | v_ind = hint_ind[:, :top_hint] 288 | v_grad = visual_grad_cam.gather(1, v_ind) 289 | 290 | if topv == -1: 291 | v_grad_score, v_grad_ind = v_grad.sort(1, descending=True) 292 | v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1) 293 | v_grad_sum = torch.cumsum(v_grad_score, dim=1) 294 | v_grad_mask = (v_grad_sum <= 0.65).long() 295 | v_grad_mask[:,0] = 1 296 | v_mask_ind = v_grad_mask * v_ind 297 | for x in range(a.shape[0]): 298 | num = len(torch.nonzero(v_grad_mask[x])) 299 | v_mask[x].scatter_(0, v_mask_ind[x,:num], 1) 300 | else: 301 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv] 302 | v_star = v_ind.gather(1, v_grad_ind) 303 | v_mask.scatter_(1, v_star, 1) 304 | 305 | pred, _, _ = model(v, q, None, b, v_mask) 306 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 307 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 308 | false_ans.scatter_(1, pred_ind, 0) 309 | a2 = a * false_ans 310 | 311 | v_mask = 1 - v_mask 312 | 313 | pred, loss, _ = model(v, q, a2, b, v_mask) 314 | 315 | if (loss != loss).any(): 316 | raise ValueError("NaN loss") 317 | loss.backward() 318 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 319 | optim.step() 320 | optim.zero_grad() 321 | 322 | total_loss += loss.item() * q.size(0) 323 | 324 | if mode=='updn': 325 | total_loss /= len(train_loader.dataset) 326 | else: 327 | total_loss /= len(train_loader.dataset) * 2 328 | train_score = 100 * train_score / len(train_loader.dataset) 329 | 330 | if run_eval: 331 | model.train(False) 332 | results = evaluate(model, eval_loader, qid2type) 333 | results["epoch"] = epoch + 1 334 | results["step"] = total_step 335 | results["train_loss"] = total_loss 336 | results["train_score"] = train_score 337 | 338 | model.train(True) 339 | 340 | eval_score = results["score"] 341 | bound = results["upper_bound"] 342 | yn = results['score_yesno'] 343 | other = results['score_other'] 344 | num = results['score_number'] 345 | 346 | logger.write('epoch %d, time: %.2f' % (epoch, time.time() - t)) 347 | logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score)) 348 | 349 | if run_eval: 350 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound)) 351 | logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num)) 352 | 353 | if eval_score > best_eval_score: 354 | model_path = os.path.join(output, 'model.pth') 355 | torch.save(model.state_dict(), model_path) 356 | best_eval_score = eval_score 357 | 358 | 359 | def evaluate(model, dataloader, qid2type): 360 | score = 0 361 | upper_bound = 0 362 | score_yesno = 0 363 | score_number = 0 364 | score_other = 0 365 | total_yesno = 0 366 | total_number = 0 367 | total_other = 0 368 | 369 | for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 370 | v = Variable(v, requires_grad=False).cuda() 371 | q = Variable(q, requires_grad=False).cuda() 372 | pred, _,_ = model(v, q, None, None, None) 373 | batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1) 374 | score += batch_score.sum() 375 | upper_bound += (a.max(1)[0]).sum() 376 | qids = qids.detach().cpu().int().numpy() 377 | for j in range(len(qids)): 378 | qid = qids[j] 379 | typ = qid2type[str(qid)] 380 | if typ == 'yes/no': 381 | score_yesno += batch_score[j] 382 | total_yesno += 1 383 | elif typ == 'other': 384 | score_other += batch_score[j] 385 | total_other += 1 386 | elif typ == 'number': 387 | score_number += batch_score[j] 388 | total_number += 1 389 | else: 390 | print('Hahahahahahahahahahaha') 391 | 392 | 393 | score = score / len(dataloader.dataset) 394 | upper_bound = upper_bound / len(dataloader.dataset) 395 | score_yesno /= total_yesno 396 | score_other /= total_other 397 | score_number /= total_number 398 | 399 | results = dict( 400 | score=score, 401 | upper_bound=upper_bound, 402 | score_yesno=score_yesno, 403 | score_other=score_other, 404 | score_number=score_number, 405 | ) 406 | return results 407 | -------------------------------------------------------------------------------- /tools/compute_softscore.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import json 7 | import numpy as np 8 | import re 9 | import cPickle 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | from dataset import Dictionary 13 | import utils 14 | 15 | 16 | contractions = { 17 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 18 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 19 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 20 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 21 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 22 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 23 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 24 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 25 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 26 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 27 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 28 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 29 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 30 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 31 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 32 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 33 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 34 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 35 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 36 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 37 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 38 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 39 | "someonell": "someone'll", "someones": "someone's", "somethingd": 40 | "something'd", "somethingd've": "something'd've", "something'dve": 41 | "something'd've", "somethingll": "something'll", "thats": 42 | "that's", "thered": "there'd", "thered've": "there'd've", 43 | "there'dve": "there'd've", "therere": "there're", "theres": 44 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 45 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 46 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 47 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 48 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 49 | "what's", "whatve": "what've", "whens": "when's", "whered": 50 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 51 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 52 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 53 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 54 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 55 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 56 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 57 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 58 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 59 | "you'll", "youre": "you're", "youve": "you've" 60 | } 61 | 62 | manual_map = { 'none': '0', 63 | 'zero': '0', 64 | 'one': '1', 65 | 'two': '2', 66 | 'three': '3', 67 | 'four': '4', 68 | 'five': '5', 69 | 'six': '6', 70 | 'seven': '7', 71 | 'eight': '8', 72 | 'nine': '9', 73 | 'ten': '10'} 74 | articles = ['a', 'an', 'the'] 75 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 76 | comma_strip = re.compile("(\d)(\,)(\d)") 77 | punct = [';', r"/", '[', ']', '"', '{', '}', 78 | '(', ')', '=', '+', '\\', '_', '-', 79 | '>', '<', '@', '`', ',', '?', '!'] 80 | 81 | 82 | def get_score(occurences): 83 | if occurences == 0: 84 | return 0 85 | elif occurences == 1: 86 | return 0.3 87 | elif occurences == 2: 88 | return 0.6 89 | elif occurences == 3: 90 | return 0.9 91 | else: 92 | return 1 93 | 94 | 95 | def process_punctuation(inText): 96 | outText = inText 97 | for p in punct: 98 | if (p + ' ' in inText or ' ' + p in inText) \ 99 | or (re.search(comma_strip, inText) != None): 100 | outText = outText.replace(p, '') 101 | else: 102 | outText = outText.replace(p, ' ') 103 | outText = period_strip.sub("", outText, re.UNICODE) 104 | return outText 105 | 106 | 107 | def process_digit_article(inText): 108 | outText = [] 109 | tempText = inText.lower().split() 110 | for word in tempText: 111 | word = manual_map.setdefault(word, word) 112 | if word not in articles: 113 | outText.append(word) 114 | else: 115 | pass 116 | for wordId, word in enumerate(outText): 117 | if word in contractions: 118 | outText[wordId] = contractions[word] 119 | outText = ' '.join(outText) 120 | return outText 121 | 122 | 123 | def multiple_replace(text, wordDict): 124 | for key in wordDict: 125 | text = text.replace(key, wordDict[key]) 126 | return text 127 | 128 | 129 | def preprocess_answer(answer): 130 | answer = process_digit_article(process_punctuation(answer)) 131 | answer = answer.replace(',', '') 132 | return answer 133 | 134 | 135 | def filter_answers(answers_dset, min_occurence): 136 | """This will change the answer to preprocessed version 137 | """ 138 | occurence = {} 139 | for ans_entry in answers_dset: 140 | gtruth = ans_entry['multiple_choice_answer'] 141 | gtruth = preprocess_answer(gtruth) 142 | if gtruth not in occurence: 143 | occurence[gtruth] = set() 144 | occurence[gtruth].add(ans_entry['question_id']) 145 | for answer in occurence.keys(): 146 | if len(occurence[answer]) < min_occurence: 147 | occurence.pop(answer) 148 | 149 | print('Num of answers that appear >= %d times: %d' % ( 150 | min_occurence, len(occurence))) 151 | return occurence 152 | 153 | 154 | 155 | 156 | 157 | 158 | def create_ans2label(occurence, name, cache_root): 159 | """Note that this will also create label2ans.pkl at the same time 160 | 161 | occurence: dict {answer -> whatever} 162 | name: prefix of the output file 163 | cache_root: str 164 | """ 165 | ans2label = {} 166 | label2ans = [] 167 | label = 0 168 | for answer in occurence: 169 | label2ans.append(answer) 170 | ans2label[answer] = label 171 | label += 1 172 | 173 | utils.create_dir(cache_root) 174 | 175 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 176 | cPickle.dump(ans2label, open(cache_file, 'wb')) 177 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 178 | cPickle.dump(label2ans, open(cache_file, 'wb')) 179 | return ans2label 180 | 181 | 182 | def compute_target(answers_dset, ans2label, name, cache_root): 183 | """Augment answers_dset with soft score as label 184 | 185 | ***answers_dset should be preprocessed*** 186 | 187 | Write result into a cache file 188 | """ 189 | target = [] 190 | for ans_entry in answers_dset: 191 | answers = ans_entry['answers'] 192 | answer_count = {} 193 | for answer in answers: 194 | answer_ = answer['answer'] 195 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 196 | 197 | labels = [] 198 | scores = [] 199 | for answer in answer_count: 200 | if answer not in ans2label: 201 | continue 202 | labels.append(ans2label[answer]) 203 | score = get_score(answer_count[answer]) 204 | scores.append(score) 205 | 206 | label_counts = {} 207 | for k, v in answer_count.items(): 208 | if k in ans2label: 209 | label_counts[ans2label[k]] = v 210 | 211 | target.append({ 212 | 'question_id': ans_entry['question_id'], 213 | 'question_type': ans_entry['question_type'], 214 | 'image_id': ans_entry['image_id'], 215 | 'label_counts': label_counts, 216 | 'labels': labels, 217 | 'scores': scores 218 | }) 219 | 220 | print(cache_root) 221 | utils.create_dir(cache_root) 222 | cache_file = os.path.join(cache_root, name+'_target.pkl') 223 | print(cache_file) 224 | with open(cache_file, 'wb') as f: 225 | cPickle.dump(target, f) 226 | return target 227 | 228 | 229 | 230 | def get_answer(qid, answers): 231 | for ans in answers: 232 | if ans['question_id'] == qid: 233 | return ans 234 | 235 | 236 | def get_question(qid, questions): 237 | for question in questions: 238 | if question['question_id'] == qid: 239 | return question 240 | 241 | 242 | def load_cp(): 243 | train_answer_file = "data/vqacp_v2_train_annotations.json" 244 | with open(train_answer_file) as f: 245 | train_answers = json.load(f) # ['annotations'] 246 | 247 | val_answer_file = "data/vqacp_v2_test_annotations.json" 248 | with open(val_answer_file) as f: 249 | val_answers = json.load(f) # ['annotations'] 250 | 251 | occurence = filter_answers(train_answers, 9) 252 | ans2label = create_ans2label(occurence, 'trainval', "data/cp-cache") 253 | compute_target(train_answers, ans2label, 'train', "data/cp-cache") 254 | compute_target(val_answers, ans2label, 'val', "data/cp-cache") 255 | 256 | def load_cp_v1(): 257 | train_answer_file = "data/vqacp_v1_train_annotations.json" 258 | with open(train_answer_file) as f: 259 | train_answers = json.load(f) # ['annotations'] 260 | 261 | val_answer_file = "data/vqacp_v1_test_annotations.json" 262 | with open(val_answer_file) as f: 263 | val_answers = json.load(f) # ['annotations'] 264 | 265 | occurence = filter_answers(train_answers, 9) 266 | ans2label = create_ans2label(occurence, 'trainval', "data/cp-v1-cache") 267 | compute_target(train_answers, ans2label, 'train', "data/cp-v1-cache") 268 | compute_target(val_answers, ans2label, 'val', "data/cp-v1-cache") 269 | 270 | 271 | def load_v2(): 272 | train_answer_file = 'data/v2_mscoco_train2014_annotations.json' 273 | with open(train_answer_file) as f: 274 | train_answers = json.load(f)['annotations'] 275 | 276 | val_answer_file = 'data/v2_mscoco_val2014_annotations.json' 277 | with open(val_answer_file) as f: 278 | val_answers = json.load(f)['annotations'] 279 | 280 | occurence = filter_answers(train_answers, 9) 281 | ans2label = create_ans2label(occurence, 'trainval', "data/cache") 282 | compute_target(train_answers, ans2label, 'train', "data/cache") 283 | compute_target(val_answers, ans2label, 'val', "data/cache") 284 | 285 | 286 | def main(): 287 | parser = argparse.ArgumentParser("Dataset preprocessing") 288 | parser.add_argument("dataset", choices=["cp_v2", "v2",'cp_v1']) 289 | args = parser.parse_args() 290 | if args.dataset == "v2": 291 | load_v2() 292 | elif args.dataset == "cp_v1": 293 | load_cp_v1() 294 | elif args.dataset=='cp_v2': 295 | load_cp() 296 | 297 | 298 | 299 | if __name__ == '__main__': 300 | main() 301 | -------------------------------------------------------------------------------- /tools/create_dictionary.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | 11 | 12 | def create_dictionary(dataroot): 13 | dictionary = Dictionary() 14 | questions = [] 15 | files = [ 16 | 'v2_OpenEnded_mscoco_train2014_questions.json', 17 | 'v2_OpenEnded_mscoco_val2014_questions.json', 18 | 'v2_OpenEnded_mscoco_test2015_questions.json', 19 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json' 20 | ] 21 | for path in files: 22 | question_path = os.path.join(dataroot, path) 23 | qs = json.load(open(question_path))['questions'] 24 | for q in qs: 25 | dictionary.tokenize(q['question'], True) 26 | dictionary.tokenize('wordmask',True) 27 | return dictionary 28 | 29 | 30 | def create_glove_embedding_init(idx2word, glove_file): 31 | word2emb = {} 32 | with open(glove_file, 'r') as f: 33 | entries = f.readlines() 34 | emb_dim = len(entries[0].split(' ')) - 1 35 | print('embedding dim is %d' % emb_dim) 36 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 37 | 38 | for entry in entries: 39 | vals = entry.split(' ') 40 | word = vals[0] 41 | vals = map(float, vals[1:]) 42 | word2emb[word] = np.array(vals) 43 | for idx, word in enumerate(idx2word): 44 | if word not in word2emb: 45 | continue 46 | weights[idx] = word2emb[word] 47 | return weights, word2emb 48 | 49 | 50 | if __name__ == '__main__': 51 | d = create_dictionary('data') 52 | d.dump_to_file('data/dictionary.pkl') 53 | 54 | d = Dictionary.load_from_file('data/dictionary.pkl') 55 | emb_dim = 300 56 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 57 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 58 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights) 59 | -------------------------------------------------------------------------------- /tools/create_dictionary_v1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | def create_dictionary(dataroot): 11 | dictionary = Dictionary() 12 | questions = [] 13 | files = [ 14 | 'OpenEnded_mscoco_train2014_questions.json', 15 | 'OpenEnded_mscoco_val2014_questions.json', 16 | 'OpenEnded_mscoco_test2015_questions.json', 17 | 'OpenEnded_mscoco_test-dev2015_questions.json' 18 | ] 19 | for path in files: 20 | question_path = os.path.join(dataroot, path) 21 | qs = json.load(open(question_path))['questions'] 22 | for q in qs: 23 | dictionary.tokenize(q['question'], True) 24 | return dictionary 25 | 26 | 27 | def create_glove_embedding_init(idx2word, glove_file): 28 | word2emb = {} 29 | with open(glove_file, 'r') as f: 30 | entries = f.readlines() 31 | emb_dim = len(entries[0].split(' ')) - 1 32 | print('embedding dim is %d' % emb_dim) 33 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 34 | 35 | for entry in entries: 36 | vals = entry.split(' ') 37 | word = vals[0] 38 | vals = map(float, vals[1:]) 39 | word2emb[word] = np.array(vals) 40 | for idx, word in enumerate(idx2word): 41 | if word not in word2emb: 42 | continue 43 | weights[idx] = word2emb[word] 44 | return weights, word2emb 45 | 46 | 47 | if __name__ == '__main__': 48 | d = create_dictionary('data') 49 | d.dump_to_file('data/dictionary_v1.pkl') 50 | 51 | d = Dictionary.load_from_file('data/dictionary_v1.pkl') 52 | emb_dim = 300 53 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 54 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 55 | np.save('data/glove6b_init_%dd_v1.npy' % emb_dim, weights) 56 | -------------------------------------------------------------------------------- /tools/download.sh: -------------------------------------------------------------------------------- 1 | ## Script for downloading data 2 | 3 | # GloVe Vectors 4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip 5 | unzip data/glove.6B.zip -d data/glove 6 | rm data/glove.6B.zip 7 | 8 | # VQA-CP2 9 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json 10 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json 11 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json 12 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json 13 | 14 | # VQA-CP1 15 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json 16 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json 17 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json 18 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json 19 | 20 | # VQA-V2 21 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip 22 | unzip data/v2_Questions_Train_mscoco.zip -d data 23 | rm data/v2_Questions_Train_mscoco.zip 24 | 25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip 26 | unzip data/v2_Questions_Val_mscoco.zip -d data 27 | rm data/v2_Questions_Val_mscoco.zip 28 | 29 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip 30 | unzip data/v2_Questions_Test_mscoco.zip -d data 31 | rm data/v2_Questions_Test_mscoco.zip 32 | 33 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip 34 | unzip data/v2_Annotations_Train_mscoco.zip -d data 35 | rm data/v2_Annotations_Train_mscoco.zip 36 | 37 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 38 | unzip data/v2_Annotations_Val_mscoco.zip -d data 39 | rm data/v2_Annotations_Val_mscoco.zip 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /tools/process.sh: -------------------------------------------------------------------------------- 1 | # Process data 2 | 3 | python tools/create_dictionary.py 4 | python tools/create_dictionary_v1.py 5 | python tools/compute_softscore.py v2 6 | python tools/compute_softscore.py cp_v1 7 | python tools/compute_softscore.py cp_v2 8 | 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import time 5 | from os.path import join 6 | 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from tqdm import tqdm 13 | import random 14 | import copy 15 | 16 | 17 | def compute_score_with_logits(logits, labels): 18 | logits = torch.argmax(logits, 1) 19 | one_hots = torch.zeros(*labels.size()).cuda() 20 | one_hots.scatter_(1, logits.view(-1, 1), 1) 21 | scores = (one_hots * labels) 22 | return scores 23 | 24 | def train(model, train_loader, eval_loader,args,qid2type): 25 | dataset=args.dataset 26 | num_epochs=args.epochs 27 | mode=args.mode 28 | run_eval=args.eval_each_epoch 29 | output=args.output 30 | optim = torch.optim.Adamax(model.parameters()) 31 | logger = utils.Logger(os.path.join(output, 'log.txt')) 32 | total_step = 0 33 | best_eval_score = 0 34 | 35 | 36 | 37 | if mode=='q_debias': 38 | topq=args.topq 39 | keep_qtype=args.keep_qtype 40 | elif mode=='v_debias': 41 | topv=args.topv 42 | top_hint=args.top_hint 43 | elif mode=='q_v_debias': 44 | topv=args.topv 45 | top_hint=args.top_hint 46 | topq=args.topq 47 | keep_qtype=args.keep_qtype 48 | qvp=args.qvp 49 | 50 | 51 | 52 | for epoch in range(num_epochs): 53 | total_loss = 0 54 | train_score = 0 55 | 56 | t = time.time() 57 | for i, (v, q, a, b, hintscore,type_mask,notype_mask,q_mask) in tqdm(enumerate(train_loader), ncols=100, 58 | desc="Epoch %d" % (epoch + 1), total=len(train_loader)): 59 | 60 | total_step += 1 61 | 62 | 63 | ######################################### 64 | v = Variable(v).cuda().requires_grad_() 65 | q = Variable(q).cuda() 66 | q_mask=Variable(q_mask).cuda() 67 | a = Variable(a).cuda() 68 | b = Variable(b).cuda() 69 | hintscore = Variable(hintscore).cuda() 70 | type_mask=Variable(type_mask).float().cuda() 71 | notype_mask=Variable(notype_mask).float().cuda() 72 | ######################################### 73 | 74 | if mode=='updn': 75 | pred, loss,_ = model(v, q, a, b, None) 76 | if (loss != loss).any(): 77 | raise ValueError("NaN loss") 78 | loss.backward() 79 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 80 | optim.step() 81 | optim.zero_grad() 82 | 83 | total_loss += loss.item() * q.size(0) 84 | batch_score = compute_score_with_logits(pred, a.data).sum() 85 | train_score += batch_score 86 | 87 | elif mode=='q_debias': 88 | if keep_qtype==True: 89 | sen_mask=type_mask 90 | else: 91 | sen_mask=notype_mask 92 | ## first train 93 | pred, loss,word_emb = model(v, q, a, b, None) 94 | 95 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0] 96 | 97 | if (loss != loss).any(): 98 | raise ValueError("NaN loss") 99 | loss.backward() 100 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 101 | optim.step() 102 | optim.zero_grad() 103 | 104 | total_loss += loss.item() * q.size(0) 105 | batch_score = compute_score_with_logits(pred, a.data).sum() 106 | train_score += batch_score 107 | 108 | ## second train 109 | 110 | word_grad_cam = word_grad.sum(2) 111 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000) 112 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask) 113 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask 114 | 115 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq] 116 | 117 | q2 = copy.deepcopy(q_mask) 118 | 119 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1] 120 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0] 121 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1] 122 | if dataset=='cpv1': 123 | m3=m1*18330 124 | else: 125 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0] 126 | q2 = q2 * m2.long() + m3.long() 127 | 128 | pred, _, _ = model(v, q2, None, b, None) 129 | 130 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 131 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 132 | false_ans.scatter_(1, pred_ind, 0) 133 | a2 = a * false_ans 134 | q3 = copy.deepcopy(q) 135 | if dataset=='cpv1': 136 | q3.scatter_(1, w_ind, 18330) 137 | else: 138 | q3.scatter_(1, w_ind, 18455) 139 | 140 | ## third train 141 | 142 | pred, loss, _ = model(v, q3, a2, b, None) 143 | 144 | if (loss != loss).any(): 145 | raise ValueError("NaN loss") 146 | loss.backward() 147 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 148 | optim.step() 149 | optim.zero_grad() 150 | 151 | total_loss += loss.item() * q.size(0) 152 | 153 | elif mode=='v_debias': 154 | ## first train 155 | pred, loss, _ = model(v, q, a, b, None) 156 | visual_grad=torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] 157 | 158 | if (loss != loss).any(): 159 | raise ValueError("NaN loss") 160 | loss.backward() 161 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 162 | optim.step() 163 | optim.zero_grad() 164 | 165 | total_loss += loss.item() * q.size(0) 166 | batch_score = compute_score_with_logits(pred, a.data).sum() 167 | train_score += batch_score 168 | 169 | ##second train 170 | v_mask = torch.zeros(v.shape[0], 36).cuda() 171 | visual_grad_cam = visual_grad.sum(2) 172 | hint_sort, hint_ind = hintscore.sort(1, descending=True) 173 | v_ind = hint_ind[:, :top_hint] 174 | v_grad = visual_grad_cam.gather(1, v_ind) 175 | 176 | if topv==-1: 177 | v_grad_score,v_grad_ind=v_grad.sort(1,descending=True) 178 | v_grad_score=nn.functional.softmax(v_grad_score*10,dim=1) 179 | v_grad_sum=torch.cumsum(v_grad_score,dim=1) 180 | v_grad_mask=(v_grad_sum<=0.65).long() 181 | v_grad_mask[:,0] = 1 182 | v_mask_ind=v_grad_mask*v_ind 183 | for x in range(a.shape[0]): 184 | num=len(torch.nonzero(v_grad_mask[x])) 185 | v_mask[x].scatter_(0,v_mask_ind[x,:num],1) 186 | else: 187 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv] 188 | v_star = v_ind.gather(1, v_grad_ind) 189 | v_mask.scatter_(1, v_star, 1) 190 | 191 | 192 | pred, _, _ = model(v, q, None, b, v_mask) 193 | 194 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 195 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 196 | false_ans.scatter_(1, pred_ind, 0) 197 | a2 = a * false_ans 198 | 199 | v_mask = 1 - v_mask 200 | 201 | pred, loss, _ = model(v, q, a2, b, v_mask) 202 | 203 | if (loss != loss).any(): 204 | raise ValueError("NaN loss") 205 | loss.backward() 206 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 207 | optim.step() 208 | optim.zero_grad() 209 | 210 | total_loss += loss.item() * q.size(0) 211 | 212 | elif mode=='q_v_debias': 213 | random_num = random.randint(1, 10) 214 | if keep_qtype == True: 215 | sen_mask = type_mask 216 | else: 217 | sen_mask = notype_mask 218 | if random_num<=qvp: 219 | ## first train 220 | pred, loss, word_emb = model(v, q, a, b, None) 221 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0] 222 | 223 | if (loss != loss).any(): 224 | raise ValueError("NaN loss") 225 | loss.backward() 226 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 227 | optim.step() 228 | optim.zero_grad() 229 | 230 | total_loss += loss.item() * q.size(0) 231 | batch_score = compute_score_with_logits(pred, a.data).sum() 232 | train_score += batch_score 233 | 234 | ## second train 235 | 236 | word_grad_cam = word_grad.sum(2) 237 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000) 238 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask) 239 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask 240 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq] 241 | 242 | q2 = copy.deepcopy(q_mask) 243 | 244 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1] 245 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0] 246 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1] 247 | if dataset=='cpv1': 248 | m3=m1*18330 249 | else: 250 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0] 251 | q2 = q2 * m2.long() + m3.long() 252 | 253 | pred, _, _ = model(v, q2, None, b, None) 254 | 255 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 256 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 257 | false_ans.scatter_(1, pred_ind, 0) 258 | a2 = a * false_ans 259 | q3 = copy.deepcopy(q) 260 | if dataset=='cpv1': 261 | q3.scatter_(1, w_ind, 18330) 262 | else: 263 | q3.scatter_(1, w_ind, 18455) 264 | 265 | ## third train 266 | 267 | pred, loss, _ = model(v, q3, a2, b, None) 268 | 269 | if (loss != loss).any(): 270 | raise ValueError("NaN loss") 271 | loss.backward() 272 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 273 | optim.step() 274 | optim.zero_grad() 275 | 276 | total_loss += loss.item() * q.size(0) 277 | 278 | 279 | else: 280 | ## first train 281 | pred, loss, _ = model(v, q, a, b, None) 282 | visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0] 283 | 284 | if (loss != loss).any(): 285 | raise ValueError("NaN loss") 286 | loss.backward() 287 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 288 | optim.step() 289 | optim.zero_grad() 290 | 291 | total_loss += loss.item() * q.size(0) 292 | batch_score = compute_score_with_logits(pred, a.data).sum() 293 | train_score += batch_score 294 | 295 | ##second train 296 | v_mask = torch.zeros(v.shape[0], 36).cuda() 297 | visual_grad_cam = visual_grad.sum(2) 298 | hint_sort, hint_ind = hintscore.sort(1, descending=True) 299 | v_ind = hint_ind[:, :top_hint] 300 | v_grad = visual_grad_cam.gather(1, v_ind) 301 | 302 | if topv == -1: 303 | v_grad_score, v_grad_ind = v_grad.sort(1, descending=True) 304 | v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1) 305 | v_grad_sum = torch.cumsum(v_grad_score, dim=1) 306 | v_grad_mask = (v_grad_sum <= 0.65).long() 307 | v_grad_mask[:,0] = 1 308 | v_mask_ind = v_grad_mask * v_ind 309 | for x in range(a.shape[0]): 310 | num = len(torch.nonzero(v_grad_mask[x])) 311 | v_mask[x].scatter_(0, v_mask_ind[x,:num], 1) 312 | else: 313 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv] 314 | v_star = v_ind.gather(1, v_grad_ind) 315 | v_mask.scatter_(1, v_star, 1) 316 | 317 | pred, _, _ = model(v, q, None, b, v_mask) 318 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5] 319 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda() 320 | false_ans.scatter_(1, pred_ind, 0) 321 | a2 = a * false_ans 322 | 323 | v_mask = 1 - v_mask 324 | 325 | pred, loss, _ = model(v, q, a2, b, v_mask) 326 | 327 | if (loss != loss).any(): 328 | raise ValueError("NaN loss") 329 | loss.backward() 330 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 331 | optim.step() 332 | optim.zero_grad() 333 | 334 | total_loss += loss.item() * q.size(0) 335 | 336 | if mode=='updn': 337 | total_loss /= len(train_loader.dataset) 338 | else: 339 | total_loss /= len(train_loader.dataset) * 2 340 | train_score = 100 * train_score / len(train_loader.dataset) 341 | 342 | if run_eval: 343 | model.train(False) 344 | results = evaluate(model, eval_loader, qid2type) 345 | results["epoch"] = epoch + 1 346 | results["step"] = total_step 347 | results["train_loss"] = total_loss 348 | results["train_score"] = train_score 349 | 350 | model.train(True) 351 | 352 | eval_score = results["score"] 353 | bound = results["upper_bound"] 354 | yn = results['score_yesno'] 355 | other = results['score_other'] 356 | num = results['score_number'] 357 | 358 | logger.write('epoch %d, time: %.2f' % (epoch, time.time() - t)) 359 | logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score)) 360 | 361 | if run_eval: 362 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound)) 363 | logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num)) 364 | 365 | if eval_score > best_eval_score: 366 | model_path = os.path.join(output, 'model.pth') 367 | torch.save(model.state_dict(), model_path) 368 | best_eval_score = eval_score 369 | 370 | 371 | def evaluate(model, dataloader, qid2type): 372 | score = 0 373 | upper_bound = 0 374 | score_yesno = 0 375 | score_number = 0 376 | score_other = 0 377 | total_yesno = 0 378 | total_number = 0 379 | total_other = 0 380 | 381 | for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 382 | v = Variable(v, requires_grad=False).cuda() 383 | q = Variable(q, requires_grad=False).cuda() 384 | pred, _,_ = model(v, q, None, None, None) 385 | batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1) 386 | score += batch_score.sum() 387 | upper_bound += (a.max(1)[0]).sum() 388 | qids = qids.detach().cpu().int().numpy() 389 | for j in range(len(qids)): 390 | qid = qids[j] 391 | typ = qid2type[str(qid)] 392 | if typ == 'yes/no': 393 | score_yesno += batch_score[j] 394 | total_yesno += 1 395 | elif typ == 'other': 396 | score_other += batch_score[j] 397 | total_other += 1 398 | elif typ == 'number': 399 | score_number += batch_score[j] 400 | total_number += 1 401 | else: 402 | print('Hahahahahahahahahahaha') 403 | 404 | 405 | score = score / len(dataloader.dataset) 406 | upper_bound = upper_bound / len(dataloader.dataset) 407 | score_yesno /= total_yesno 408 | score_other /= total_other 409 | score_number /= total_number 410 | 411 | results = dict( 412 | score=score, 413 | upper_bound=upper_bound, 414 | score_yesno=score_yesno, 415 | score_other=score_other, 416 | score_number=score_number, 417 | ) 418 | return results 419 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import errno 4 | import os 5 | import numpy as np 6 | # from PIL import Image 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | EPS = 1e-7 12 | 13 | 14 | def assert_eq(real, expected): 15 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 16 | 17 | 18 | def assert_array_eq(real, expected): 19 | assert (np.abs(real-expected) < EPS).all(), \ 20 | '%s (true) vs %s (expected)' % (real, expected) 21 | 22 | 23 | def load_folder(folder, suffix): 24 | imgs = [] 25 | for f in sorted(os.listdir(folder)): 26 | if f.endswith(suffix): 27 | imgs.append(os.path.join(folder, f)) 28 | return imgs 29 | 30 | 31 | # def load_imageid(folder): 32 | # images = load_folder(folder, 'jpg') 33 | # img_ids = set() 34 | # for img in images: 35 | # img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1]) 36 | # img_ids.add(img_id) 37 | # return img_ids 38 | 39 | 40 | # def pil_loader(path): 41 | # with open(path, 'rb') as f: 42 | # with Image.open(f) as img: 43 | # return img.convert('RGB') 44 | 45 | 46 | def weights_init(m): 47 | """custom weights initialization.""" 48 | cname = m.__class__ 49 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 50 | m.weight.data.normal_(0.0, 0.02) 51 | elif cname == nn.BatchNorm2d: 52 | m.weight.data.normal_(1.0, 0.02) 53 | m.bias.data.fill_(0) 54 | else: 55 | print('%s is not initialized.' % cname) 56 | 57 | 58 | def init_net(net, net_file): 59 | if net_file: 60 | net.load_state_dict(torch.load(net_file)) 61 | else: 62 | net.apply(weights_init) 63 | 64 | 65 | def create_dir(path): 66 | if not os.path.exists(path): 67 | try: 68 | os.makedirs(path) 69 | except OSError as exc: 70 | if exc.errno != errno.EEXIST: 71 | raise 72 | 73 | 74 | class Logger(object): 75 | def __init__(self, output_name): 76 | dirname = os.path.dirname(output_name) 77 | if not os.path.exists(dirname): 78 | os.mkdir(dirname) 79 | 80 | self.log_file = open(output_name, 'w') 81 | self.infos = {} 82 | 83 | def append(self, key, val): 84 | vals = self.infos.setdefault(key, []) 85 | vals.append(val) 86 | 87 | def log(self, extra_msg=''): 88 | msgs = [extra_msg] 89 | for key, vals in self.infos.iteritems(): 90 | msgs.append('%s %.6f' % (key, np.mean(vals))) 91 | msg = '\n'.join(msgs) 92 | self.log_file.write(msg + '\n') 93 | self.log_file.flush() 94 | self.infos = {} 95 | return msg 96 | 97 | def write(self, msg): 98 | self.log_file.write(msg + '\n') 99 | self.log_file.flush() 100 | print(msg) 101 | -------------------------------------------------------------------------------- /vqa_debias_loss_functions.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict, Counter 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import torch 7 | import inspect 8 | 9 | 10 | def convert_sigmoid_logits_to_binary_logprobs(logits): 11 | """computes log(sigmoid(logits)), log(1-sigmoid(logits))""" 12 | log_prob = -F.softplus(-logits) 13 | log_one_minus_prob = -logits + log_prob 14 | return log_prob, log_one_minus_prob 15 | 16 | 17 | def elementwise_logsumexp(a, b): 18 | """computes log(exp(x) + exp(b))""" 19 | return torch.max(a, b) + torch.log1p(torch.exp(-torch.abs(a - b))) 20 | 21 | 22 | def renormalize_binary_logits(a, b): 23 | """Normalize so exp(a) + exp(b) == 1""" 24 | norm = elementwise_logsumexp(a, b) 25 | return a - norm, b - norm 26 | 27 | 28 | class DebiasLossFn(nn.Module): 29 | """General API for our loss functions""" 30 | 31 | def forward(self, hidden, logits, bias, labels): 32 | """ 33 | :param hidden: [batch, n_hidden] hidden features from the last layer in the model 34 | :param logits: [batch, n_answers_options] sigmoid logits for each answer option 35 | :param bias: [batch, n_answers_options] 36 | bias probabilities for each answer option between 0 and 1 37 | :param labels: [batch, n_answers_options] 38 | scores for each answer option, between 0 and 1 39 | :return: Scalar loss 40 | """ 41 | raise NotImplementedError() 42 | 43 | def to_json(self): 44 | """Get a json representation of this loss function. 45 | 46 | We construct this by looking up the __init__ args 47 | """ 48 | cls = self.__class__ 49 | init = cls.__init__ 50 | if init is object.__init__: 51 | return [] # No init args 52 | 53 | init_signature = inspect.getargspec(init) 54 | if init_signature.varargs is not None: 55 | raise NotImplementedError("varags not supported") 56 | if init_signature.keywords is not None: 57 | raise NotImplementedError("keywords not supported") 58 | args = [x for x in init_signature.args if x != "self"] 59 | out = OrderedDict() 60 | out["name"] = cls.__name__ 61 | for key in args: 62 | out[key] = getattr(self, key) 63 | return out 64 | 65 | 66 | class Plain(DebiasLossFn): 67 | def forward(self, hidden, logits, bias, labels): 68 | loss = F.binary_cross_entropy_with_logits(logits, labels) 69 | 70 | loss *= labels.size(1) 71 | return loss 72 | 73 | 74 | class Focal(DebiasLossFn): 75 | def forward(self, hidden, logits, bias, labels): 76 | # import pdb;pdb.set_trace() 77 | focal_logits=torch.log(F.softmax(logits,dim=1)+1e-5) * ((1-F.softmax(bias,dim=1))*(1-F.softmax(bias,dim=1))) 78 | loss=F.binary_cross_entropy_with_logits(focal_logits,labels) 79 | loss*=labels.size(1) 80 | return loss 81 | 82 | class ReweightByInvBias(DebiasLossFn): 83 | def forward(self, hidden, logits, bias, labels): 84 | # Manually compute the binary cross entropy since the old version of torch always aggregates 85 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 86 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob) 87 | weights = (1 - bias) 88 | loss *= weights # Apply the weights 89 | return loss.sum() / weights.sum() 90 | 91 | 92 | class BiasProduct(DebiasLossFn): 93 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0): 94 | """ 95 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 96 | :param smooth_init: How to initialize `a` 97 | :param constant_smooth: Constant to add to the bias to smooth it 98 | """ 99 | super(BiasProduct, self).__init__() 100 | self.constant_smooth = constant_smooth 101 | self.smooth_init = smooth_init 102 | self.smooth = smooth 103 | if smooth: 104 | self.smooth_param = torch.nn.Parameter( 105 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 106 | else: 107 | self.smooth_param = None 108 | 109 | def forward(self, hidden, logits, bias, labels): 110 | smooth = self.constant_smooth 111 | if self.smooth: 112 | smooth += F.sigmoid(self.smooth_param) 113 | 114 | # Convert the bias into log-space, with a factor for both the 115 | # binary outputs for each answer option 116 | bias_lp = torch.log(bias + smooth) 117 | bias_l_inv = torch.log1p(-bias + smooth) 118 | 119 | # Convert the the logits into log-space with the same format 120 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 121 | # import pdb;pdb.set_trace() 122 | 123 | # Add the bias 124 | log_prob += bias_lp 125 | log_one_minus_prob += bias_l_inv 126 | 127 | # Re-normalize the factors in logspace 128 | log_prob, log_one_minus_prob = renormalize_binary_logits(log_prob, log_one_minus_prob) 129 | 130 | # Compute the binary cross entropy 131 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 132 | return loss 133 | 134 | 135 | class LearnedMixin(DebiasLossFn): 136 | def __init__(self, w, smooth=True, smooth_init=-1, constant_smooth=0.0): 137 | """ 138 | :param w: Weight of the entropy penalty 139 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it 140 | :param smooth_init: How to initialize `a` 141 | :param constant_smooth: Constant to add to the bias to smooth it 142 | """ 143 | super(LearnedMixin, self).__init__() 144 | self.w = w 145 | # self.w=0 146 | self.smooth_init = smooth_init 147 | self.constant_smooth = constant_smooth 148 | self.bias_lin = torch.nn.Linear(1024, 1) 149 | self.smooth = smooth 150 | if self.smooth: 151 | self.smooth_param = torch.nn.Parameter( 152 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32))) 153 | else: 154 | self.smooth_param = None 155 | 156 | def forward(self, hidden, logits, bias, labels): 157 | factor = self.bias_lin.forward(hidden) # [batch, 1] 158 | factor = F.softplus(factor) 159 | 160 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2] 161 | 162 | # Smooth 163 | bias += self.constant_smooth 164 | if self.smooth: 165 | soften_factor = F.sigmoid(self.smooth_param) 166 | bias = bias + soften_factor.unsqueeze(1) 167 | 168 | bias = torch.log(bias) # Convert to logspace 169 | 170 | # Scale by the factor 171 | # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2] 172 | bias = bias * factor.unsqueeze(1) 173 | 174 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits) 175 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2) 176 | 177 | # Add the bias in 178 | logits = bias + log_probs 179 | 180 | # Renormalize to get log probabilities 181 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1]) 182 | 183 | # Compute loss 184 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0) 185 | 186 | # Re-normalized version of the bias 187 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1]) 188 | bias_logprob = bias - bias_norm.unsqueeze(2) 189 | 190 | # Compute and add the entropy penalty 191 | entropy = -(torch.exp(bias_logprob) * bias_logprob).sum(2).mean() 192 | return loss + self.w * entropy --------------------------------------------------------------------------------