├── README.md ├── attention.py ├── base_model.py ├── classifier.py ├── dataset.py ├── eval.py ├── fc.py ├── language_model.py ├── main.py ├── tools │   ├── compute_softscore.py │   ├── create_dictionary.py │   ├── create_dictionary_v1.py │   ├── download.sh │   ├── download_v1.sh │   └── process.sh ├── train.py ├── utils.py └── vqa_debias_loss_functions.py /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering 2 | The source code for our paper [Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering](https://www.aclweb.org/anthology/2020.emnlp-main.265.pdf), published in EMNLP 2020. This repo contains code modified from [CSS-VQA](https://github.com/yanxinzju/CSS-VQA); many thanks for their efforts. 3 | 4 | ### Prerequisites 5 | 6 | Make sure you are on a machine with an NVIDIA GPU and Python 2.7, with about 100 GB of disk space.
7 | h5py==2.10.0
8 | pytorch==1.1.0
9 | Click==7.0
10 | numpy==1.16.5
11 | tqdm==4.35.0
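These pins can be installed in one step (a minimal sketch assuming pip; note that the pytorch pin is published on PyPI as `torch`, while conda users would instead install `pytorch=1.1.0` from the `pytorch` channel):

```
pip install h5py==2.10.0 torch==1.1.0 Click==7.0 numpy==1.16.5 tqdm==4.35.0
```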
12 | 13 | ### Data Setup 14 | For all data preprocessing and setup, please refer to [bottom-up-attention-vqa](https://github.com/hengyuan-hu/bottom-up-attention-vqa). 15 | 16 | 1. Run the following script to download the data. 17 | 18 | ``` 19 | bash tools/download.sh 20 | ``` 21 | 22 | 2. Click the link [HERE](https://drive.google.com/drive/folders/13e-b76otJukupbjfC-n1s05L202PaFKQ?usp=sharing) to download the rest of the data, which is kindly shared by [CSS-VQA](https://github.com/yanxinzju/CSS-VQA). 23 | 24 | 25 | 26 | ### Training 27 | 28 | All the arguments for running our code are preset in main.py. 29 | 30 | Run 31 | ``` 32 | CUDA_VISIBLE_DEVICES=0 python main.py 33 | ``` 34 | to train a model. 35 | 36 | ### Testing 37 | Run 38 | ``` 39 | CUDA_VISIBLE_DEVICES=0 python eval.py --dataset [] --debias [] --model_state [] 40 | ``` 41 | to evaluate a model. 42 | 43 | ## Citation 44 | 45 | If you find this paper helpful for your research, please consider citing it in your publications. 46 | 47 | ```BibTeX 48 | @inproceedings{liang2020learning, 49 | title={Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering}, 50 | author={Liang, Zujie and Jiang, Weitao and Hu, Haifeng and Zhu, Jiaying}, 51 | booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 52 | year={2020} 53 | } 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | from fc import FCNet 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, v_dim, q_dim, num_hid): 9 | super(Attention, self).__init__() 10 | self.nonlinear = FCNet([v_dim + q_dim, num_hid]) 11 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None) 12 | 13 | def forward(self, v, q): 14 | """ 15 | v: [batch, k, vdim] 16 | q: [batch, qdim] 17 | """ 18 | logits = self.logits(v, q) 19 | w = nn.functional.softmax(logits, 1) 20 | return w 21 | 22 | def logits(self, v, q): 23 | num_objs = v.size(1) 24 | q = q.unsqueeze(1).repeat(1, num_objs, 1) 25 | vq = torch.cat((v, q), 2) 26 | joint_repr = self.nonlinear(vq) 27 | logits = self.linear(joint_repr) 28 | return logits 29 | 30 | 31 | class NewAttention(nn.Module): 32 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2): 33 | super(NewAttention, self).__init__() 34 | 35 | self.v_proj = FCNet([v_dim, num_hid]) 36 | self.q_proj = FCNet([q_dim, num_hid]) 37 | self.dropout = nn.Dropout(dropout) 38 | self.linear = weight_norm(nn.Linear(q_dim, 1), dim=None) # assumes q_dim == num_hid, which holds for every model built in base_model.py 39 | 40 | def forward(self, v, q): 41 | """ 42 | v: [batch, k, vdim] 43 | q: [batch, qdim] 44 | """ 45 | logits = self.logits(v, q) 46 | # w = nn.functional.softmax(logits, 1) 47 | # return w 48 | return logits 49 | 50 | def logits(self, v, q): 51 | batch, k, _ = v.size() 52 | v_proj = self.v_proj(v) # [batch, k, num_hid] 53 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1) 54 | joint_repr = v_proj * q_proj 55 | joint_repr = self.dropout(joint_repr) 56 | logits = self.linear(joint_repr) 57 | return logits 58 | -------------------------------------------------------------------------------- /base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from attention import Attention, NewAttention 4 | from language_model import WordEmbedding, QuestionEmbedding 5 | 
from classifier import SimpleClassifier 6 | from fc import FCNet 7 | import numpy as np 8 | 9 | def mask_softmax(x,mask): 10 | mask=mask.unsqueeze(2).float() 11 | x2=torch.exp(x-torch.max(x)) 12 | x3=x2*mask 13 | epsilon=1e-5 14 | x3_sum=torch.sum(x3,dim=1,keepdim=True)+epsilon 15 | x4=x3/x3_sum.expand_as(x3) 16 | return x4 17 | 18 | 19 | class BaseModel(nn.Module): 20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier): 21 | super(BaseModel, self).__init__() 22 | self.w_emb = w_emb 23 | self.q_emb = q_emb 24 | self.v_att = v_att 25 | self.q_net = q_net 26 | self.v_net = v_net 27 | self.classifier = classifier 28 | self.debias_loss_fn = None 29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2)) 30 | self.bias_lin = torch.nn.Linear(1024, 1) 31 | self.is_contras = None 32 | 33 | def forward(self, v, q, labels, bias,v_mask): 34 | """Forward 35 | 36 | v: [batch, num_objs, obj_dim] 37 | b: [batch, num_objs, b_dim] 38 | q: [batch_size, seq_length] 39 | 40 | return: logits, not probs 41 | """ 42 | w_emb = self.w_emb(q) 43 | q_emb = self.q_emb(w_emb) # [batch, q_dim] 44 | 45 | 46 | att = self.v_att(v, q_emb) 47 | if v_mask is None: 48 | att = nn.functional.softmax(att, 1) 49 | else: 50 | att= mask_softmax(att,v_mask) 51 | 52 | v_emb = (att * v).sum(1) # [batch, v_dim] 53 | 54 | q_repr = self.q_net(q_emb) 55 | v_repr = self.v_net(v_emb) 56 | joint_repr = q_repr * v_repr 57 | 58 | logits = self.classifier(joint_repr) 59 | 60 | if labels is not None: 61 | loss = self.debias_loss_fn(joint_repr, logits, bias, labels) 62 | 63 | else: 64 | loss = None 65 | if (self.is_contras is True) and (self.training is True): 66 | return logits, loss, w_emb, joint_repr 67 | else: 68 | return logits, loss, w_emb 69 | 70 | def build_baseline0(dataset, num_hid): 71 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 72 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 73 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid) 74 | q_net = FCNet([num_hid, num_hid]) 75 | v_net = FCNet([dataset.v_dim, num_hid]) 76 | classifier = SimpleClassifier( 77 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5) 78 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) 79 | 80 | 81 | def build_baseline0_newatt(dataset, num_hid): 82 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0) 83 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0) 84 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid) 85 | q_net = FCNet([q_emb.num_hid, num_hid]) 86 | v_net = FCNet([dataset.v_dim, num_hid]) 87 | classifier = SimpleClassifier( 88 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5) 89 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier) -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.weight_norm import weight_norm 3 | 4 | 5 | class SimpleClassifier(nn.Module): 6 | def __init__(self, in_dim, hid_dim, out_dim, dropout): 7 | super(SimpleClassifier, self).__init__() 8 | layers = [ 9 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None), 10 | nn.ReLU(), 11 | nn.Dropout(dropout, inplace=True), 12 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None) 13 | ] 14 | self.main = nn.Sequential(*layers) 15 | 16 | def forward(self, x): 17 | logits = self.main(x) 18 | return logits 19 | 
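Taken together, attention.py, base_model.py, and classifier.py implement the bottom-up top-down (UpDn) pipeline. The following is a minimal shape-check sketch of the data flow inside `BaseModel.forward` using dummy inputs; the batch size, vocabulary size (100), and answer count (10) are toy values made up purely for illustration:

```python
import torch
from attention import NewAttention
from classifier import SimpleClassifier
from fc import FCNet
from language_model import WordEmbedding, QuestionEmbedding

ntoken, v_dim, num_hid, n_ans = 100, 2048, 1024, 10   # toy vocab/answer sizes
w_emb = WordEmbedding(ntoken, 300, 0.0)
q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
v_att = NewAttention(v_dim, num_hid, num_hid)
q_net, v_net = FCNet([num_hid, num_hid]), FCNet([v_dim, num_hid])
classifier = SimpleClassifier(num_hid, 2 * num_hid, n_ans, 0.5)

v = torch.randn(4, 36, v_dim)                          # [batch, num_objs, obj_dim]
q = torch.randint(0, ntoken, (4, 14))                  # [batch, seq_length] token ids
q_rep = q_emb(w_emb(q))                                # [4, 1024] question embedding
att = torch.nn.functional.softmax(v_att(v, q_rep), 1)  # [4, 36, 1] attention over objects
v_emb = (att * v).sum(1)                               # [4, 2048] attended image feature
joint = q_net(q_rep) * v_net(v_emb)                    # [4, 1024] fused representation
print(classifier(joint).shape)                         # torch.Size([4, 10]) answer logits
```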
-------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import unicode_literals 3 | 4 | import os 5 | import json 6 | # import cPickle 7 | import pickle as cPickle #python3 8 | from collections import Counter 9 | 10 | import numpy as np 11 | import utils 12 | import h5py 13 | import torch 14 | from torch.utils.data import Dataset 15 | from tqdm import tqdm 16 | from random import choice 17 | 18 | class Dictionary(object): 19 | def __init__(self, word2idx=None, idx2word=None): 20 | if word2idx is None: 21 | word2idx = {} 22 | if idx2word is None: 23 | idx2word = [] 24 | self.word2idx = word2idx 25 | self.idx2word = idx2word 26 | 27 | @property 28 | def ntoken(self): 29 | return len(self.word2idx) 30 | 31 | @property 32 | def padding_idx(self): 33 | return len(self.word2idx) 34 | 35 | def tokenize(self, sentence, add_word): 36 | sentence = sentence.lower() 37 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s').replace('-', 38 | ' ').replace('.','').replace('"', '').replace('n\'t', ' not').replace('$', ' dollar ') 39 | words = sentence.split() 40 | tokens = [] 41 | if add_word: 42 | for w in words: 43 | tokens.append(self.add_word(w)) 44 | else: 45 | for w in words: 46 | if w in self.word2idx: 47 | tokens.append(self.word2idx[w]) 48 | else: 49 | tokens.append(len(self.word2idx)) 50 | return tokens 51 | 52 | def dump_to_file(self, path): 53 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 54 | print('dictionary dumped to %s' % path) 55 | 56 | @classmethod 57 | def load_from_file(cls, path): 58 | print('loading dictionary from %s' % path) 59 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 60 | d = cls(word2idx, idx2word) 61 | return d 62 | 63 | def add_word(self, word): 64 | if word not in self.word2idx: 65 | self.idx2word.append(word) 66 | self.word2idx[word] = len(self.idx2word) - 1 67 | return self.word2idx[word] 68 | 69 | def __len__(self): 70 | return len(self.idx2word) 71 | 72 | 73 | def _create_entry(img_idx, question, answer): 74 | answer.pop('image_id') 75 | answer.pop('question_id') 76 | entry = { 77 | 'question_id' : question['question_id'], 78 | 'image_id' : question['image_id'], 79 | 'image_idx' : img_idx, 80 | 'question' : question['question'], 81 | 'answer' : answer 82 | } 83 | return entry 84 | 85 | 86 | def _load_dataset(dataroot, name, img_id2val, dataset): 87 | """Load entries 88 | 89 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 90 | dataroot: root path of dataset 91 | name: 'train', 'val' 92 | """ 93 | if dataset=='cpv2': 94 | answer_path = os.path.join(dataroot, 'cp-cache', '%s_target.pkl' % name) 95 | name = "train" if name == "train" else "test" 96 | question_path = os.path.join(dataroot, 'vqacp_v2_%s_questions.json' % name) 97 | with open(question_path) as f: 98 | questions = json.load(f) 99 | elif dataset=='cpv1': 100 | answer_path = os.path.join(dataroot, 'cp-v1-cache', '%s_target.pkl' % name) 101 | name = "train" if name == "train" else "test" 102 | question_path = os.path.join(dataroot, 'vqacp_v1_%s_questions.json' % name) 103 | with open(question_path) as f: 104 | questions = json.load(f) 105 | elif dataset=='v2': 106 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 107 | question_path = os.path.join(dataroot, 'v2_OpenEnded_mscoco_%s2014_questions.json' % name) 108 | with 
open(question_path) as f: 109 | questions = json.load(f)["questions"] 110 | 111 | with open(answer_path, 'rb') as f: 112 | answers = cPickle.load(f) 113 | 114 | questions.sort(key=lambda x: x['question_id']) 115 | answers.sort(key=lambda x: x['question_id']) 116 | 117 | utils.assert_eq(len(questions), len(answers)) 118 | entries = [] 119 | for question, answer in zip(questions, answers): 120 | if answer["labels"] is None: 121 | raise ValueError() 122 | utils.assert_eq(question['question_id'], answer['question_id']) 123 | utils.assert_eq(question['image_id'], answer['image_id']) 124 | img_id = question['image_id'] 125 | img_idx = None 126 | if img_id2val: 127 | img_idx = img_id2val[img_id] 128 | 129 | entries.append(_create_entry(img_idx, question, answer)) 130 | return entries 131 | 132 | 133 | class VQAFeatureDataset(Dataset): 134 | def __init__(self, name, dictionary, dataroot='data', dataset='cpv2', 135 | use_hdf5=False, cache_image_features=False): 136 | super(VQAFeatureDataset, self).__init__() 137 | self.name=name 138 | if dataset=='cpv2': 139 | with open('data/train_cpv2_hintscore.json', 'r') as f: 140 | self.train_hintscore = json.load(f) 141 | with open('data/test_cpv2_hintscore.json', 'r') as f: 142 | self.test_hintsocre = json.load(f) 143 | with open('util/cpv2_type_mask.json', 'r') as f: 144 | self.type_mask = json.load(f) 145 | with open('util/cpv2_notype_mask.json', 'r') as f: 146 | self.notype_mask = json.load(f) 147 | 148 | elif dataset=='cpv1': 149 | with open('data/train_cpv1_hintscore.json', 'r') as f: 150 | self.train_hintscore = json.load(f) 151 | with open('data/test_cpv1_hintscore.json', 'r') as f: 152 | self.test_hintsocre = json.load(f) 153 | with open('util/cpv1_type_mask.json', 'r') as f: 154 | self.type_mask = json.load(f) 155 | with open('util/cpv1_notype_mask.json', 'r') as f: 156 | self.notype_mask = json.load(f) 157 | elif dataset=='v2': 158 | with open('data/train_v2_hintscore.json', 'r') as f: 159 | self.train_hintscore = json.load(f) 160 | with open('data/test_v2_hintscore.json', 'r') as f: 161 | self.test_hintsocre = json.load(f) 162 | with open('util/v2_type_mask.json', 'r') as f: 163 | self.type_mask = json.load(f) 164 | with open('util/v2_notype_mask.json', 'r') as f: 165 | self.notype_mask = json.load(f) 166 | 167 | assert name in ['train', 'val'] 168 | 169 | if dataset=='cpv2': 170 | ans2label_path = os.path.join(dataroot, 'cp-cache', 'trainval_ans2label.pkl') 171 | label2ans_path = os.path.join(dataroot, 'cp-cache', 'trainval_label2ans.pkl') 172 | elif dataset=='cpv1': 173 | ans2label_path = os.path.join(dataroot, 'cp-v1-cache', 'trainval_ans2label.pkl') 174 | label2ans_path = os.path.join(dataroot, 'cp-v1-cache', 'trainval_label2ans.pkl') 175 | elif dataset=='v2': 176 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 177 | label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl') 178 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 179 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 180 | self.num_ans_candidates = len(self.ans2label) 181 | 182 | self.dictionary = dictionary 183 | self.use_hdf5 = use_hdf5 184 | 185 | if use_hdf5: 186 | h5_path = os.path.join(dataroot, '%s36.hdf5'%name) 187 | self.hf = h5py.File(h5_path, 'r') 188 | self.features = self.hf.get('image_features') 189 | 190 | with open("util/%s36_imgid2img.pkl"%name, "rb") as f: 191 | imgid2idx = cPickle.load(f) 192 | else: 193 | imgid2idx = None 194 | 195 | self.entries = _load_dataset(dataroot, name, imgid2idx, 
dataset=dataset) 196 | if cache_image_features is True: 197 | image_to_fe = {} 198 | for entry in tqdm(self.entries, ncols=100, desc="caching-features"): 199 | img_id = entry["image_id"] 200 | if img_id not in image_to_fe: 201 | if use_hdf5: 202 | fe = np.array(self.features[imgid2idx[img_id]]) 203 | else: 204 | # fe=torch.load('data/rcnn_feature/'+str(img_id)+'.pth')['image_feature'] 205 | fe = np.fromfile("data/trainval_features/" + str(img_id) + ".bin", np.float32) 206 | fe = torch.from_numpy(fe).view(36, 2048) 207 | image_to_fe[img_id]=fe 208 | self.image_to_fe = image_to_fe 209 | if use_hdf5: 210 | self.hf.close() 211 | else: 212 | self.image_to_fe = None 213 | 214 | self.tokenize() 215 | self.tensorize() 216 | 217 | self.v_dim = 2048 218 | 219 | def tokenize(self, max_length=14): 220 | """Tokenizes the questions. 221 | 222 | This will add q_token in each entry of the dataset. 223 | -1 represents nil and should be treated as padding_idx in the embedding 224 | """ 225 | for entry in tqdm(self.entries, ncols=100, desc="tokenize"): 226 | tokens = self.dictionary.tokenize(entry['question'], False) 227 | tokens = tokens[:max_length] 228 | if len(tokens) < max_length: 229 | # Note here we pad in front of the sentence 230 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 231 | padding_mask=[self.dictionary.padding_idx-1] * (max_length - len(tokens)) 232 | tokens_mask = padding_mask + tokens 233 | tokens = padding + tokens 234 | else: tokens_mask = tokens # no padding needed; also prevents the previous entry's mask from leaking in 235 | utils.assert_eq(len(tokens), max_length) 236 | entry['q_token'] = tokens 237 | entry['q_token_mask']=tokens_mask 238 | 239 | def tensorize(self): 240 | for entry in tqdm(self.entries, ncols=100, desc="tensorize"): 241 | question = torch.from_numpy(np.array(entry['q_token'])) 242 | question_mask = torch.from_numpy(np.array(entry['q_token_mask'])) 243 | 244 | entry['q_token'] = question 245 | entry['q_token_mask']=question_mask 246 | 247 | answer = entry['answer'] 248 | labels = np.array(answer['labels']) 249 | scores = np.array(answer['scores'], dtype=np.float32) 250 | if len(labels): 251 | labels = torch.from_numpy(labels) 252 | scores = torch.from_numpy(scores) 253 | entry['answer']['labels'] = labels 254 | entry['answer']['scores'] = scores 255 | else: 256 | entry['answer']['labels'] = None 257 | entry['answer']['scores'] = None 258 | 259 | def __getitem__(self, index): 260 | entry = self.entries[index] 261 | if self.image_to_fe is not None: 262 | features = self.image_to_fe[entry["image_id"]] 263 | elif self.use_hdf5: 264 | features = np.array(self.features[entry['image_idx']]) 265 | features = torch.from_numpy(features).view(36, 2048) 266 | else: 267 | # features = torch.load('data/rcnn_feature/' + str(entry["image_id"]) + '.pth')['image_feature'] 268 | features = np.fromfile("data/trainval_features/" + str(entry["image_id"]) + ".bin", np.float32) 269 | features = torch.from_numpy(features).view(36, 2048) 270 | 271 | q_id=entry['question_id'] 272 | ques = entry['q_token'] 273 | ques_mask=entry['q_token_mask'] 274 | answer = entry['answer'] 275 | labels = answer['labels'] 276 | scores = answer['scores'] 277 | target = torch.zeros(self.num_ans_candidates) 278 | if labels is not None: 279 | target.scatter_(0, labels, scores) 280 | 281 | if self.name=='train': 282 | train_hint=torch.tensor(self.train_hintscore[str(q_id)]) 283 | type_mask=torch.tensor(self.type_mask[str(q_id)]) 284 | notype_mask=torch.tensor(self.notype_mask[str(q_id)]) 285 | if "bias" in entry: 286 | return features, ques, 
target,entry["bias"],train_hint,type_mask,notype_mask,ques_mask 287 | 288 | else: 289 | return features, ques,target, 0,train_hint 290 | else: 291 | test_hint=torch.tensor(self.test_hintsocre[str(q_id)]) 292 | if "bias" in entry: 293 | return features, ques, target, entry["bias"],q_id,test_hint 294 | else: 295 | return features, ques, target, 0,q_id,test_hint 296 | 297 | def __len__(self): 298 | return len(self.entries) 299 | 300 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | # import cPickle 4 | import pickle as cPickle 5 | from collections import defaultdict, Counter 6 | from os.path import dirname, join 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | import os 13 | 14 | # from new_dataset import Dictionary, VQAFeatureDataset 15 | from dataset import Dictionary, VQAFeatureDataset 16 | import base_model 17 | from train import train 18 | import utils 19 | 20 | from vqa_debias_loss_functions import * 21 | from tqdm import tqdm 22 | from torch.autograd import Variable 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 27 | 28 | # Arguments we added 29 | parser.add_argument( 30 | '--cache_features', default=True, 31 | help="Cache image features in RAM. Makes things much faster, " 32 | "especially if the filesystem is slow, but requires at least 48gb of RAM") 33 | parser.add_argument( 34 | '--dataset', default='cpv2', help="Run on VQA-2.0 instead of VQA-CP 2.0") 35 | parser.add_argument( 36 | '-p', "--entropy_penalty", default=0.36, type=float, 37 | help="Entropy regularizer weight for the learned_mixin model") 38 | parser.add_argument( 39 | '--debias', default="learned_mixin", 40 | choices=["learned_mixin", "reweight", "bias_product", "none"], 41 | help="Kind of ensemble loss to use") 42 | # Arguments from the original model, we leave this default, except we 43 | # set --epochs to 15 since the model maxes out its performance on VQA 2.0 well before then 44 | parser.add_argument('--num_hid', type=int, default=1024) 45 | parser.add_argument('--model', type=str, default='baseline0_newatt') 46 | parser.add_argument('--batch_size', type=int, default=512) 47 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 48 | parser.add_argument('--model_state', type=str, default='logs/exp0/model.pth') 49 | args = parser.parse_args() 50 | return args 51 | 52 | def compute_score_with_logits(logits, labels): 53 | # logits = torch.max(logits, 1)[1].data # argmax 54 | logits = torch.argmax(logits,1) 55 | one_hots = torch.zeros(*labels.size()).cuda() 56 | one_hots.scatter_(1, logits.view(-1, 1), 1) 57 | scores = (one_hots * labels) 58 | return scores 59 | 60 | 61 | def evaluate(model,dataloader,qid2type): 62 | score = 0 63 | upper_bound = 0 64 | score_yesno = 0 65 | score_number = 0 66 | score_other = 0 67 | total_yesno = 0 68 | total_number = 0 69 | total_other = 0 70 | model.train(False) 71 | # import pdb;pdb.set_trace() 72 | 73 | 74 | for v, q, a, b,qids,hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"): 75 | v = Variable(v, requires_grad=False).cuda() 76 | q = Variable(q, requires_grad=False).cuda() 77 | pred, _ ,_= model(v, q, None, None,None) 78 | batch_score= compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1) 79 | score += 
batch_score.sum() 80 | upper_bound += (a.max(1)[0]).sum() 81 | qids = qids.detach().cpu().int().numpy() 82 | for j in range(len(qids)): 83 | qid=qids[j] 84 | typ = qid2type[str(qid)] 85 | if typ == 'yes/no': 86 | score_yesno += batch_score[j] 87 | total_yesno += 1 88 | elif typ == 'other': 89 | score_other += batch_score[j] 90 | total_other += 1 91 | elif typ == 'number': 92 | score_number += batch_score[j] 93 | total_number += 1 94 | else: 95 | print('unknown question type: %s' % typ) 96 | score = score / len(dataloader.dataset) 97 | upper_bound = upper_bound / len(dataloader.dataset) 98 | score_yesno /= total_yesno 99 | score_other /= total_other 100 | score_number /= total_number 101 | print('\teval overall score: %.2f' % (100 * score)) 102 | print('\teval up_bound score: %.2f' % (100 * upper_bound)) 103 | print('\teval y/n score: %.2f' % (100 * score_yesno)) 104 | print('\teval other score: %.2f' % (100 * score_other)) 105 | print('\teval number score: %.2f' % (100 * score_number)) 106 | 107 | 108 | def main(): 109 | args = parse_args() 110 | dataset = args.dataset 111 | 112 | 113 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 114 | qid2type=json.load(f) 115 | 116 | if dataset=='cpv1': 117 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 118 | elif dataset=='cpv2' or dataset=='v2': 119 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 120 | 121 | print("Building test dataset...") 122 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 123 | cache_image_features=args.cache_features) 124 | 125 | # Build the model using the original constructor 126 | constructor = 'build_%s' % args.model 127 | model = getattr(base_model, constructor)(eval_dset, args.num_hid).cuda() 128 | 129 | if args.debias == "bias_product": 130 | model.debias_loss_fn = BiasProduct() 131 | elif args.debias == "none": 132 | model.debias_loss_fn = Plain() 133 | elif args.debias == "reweight": 134 | model.debias_loss_fn = ReweightByInvBias() 135 | elif args.debias == "learned_mixin": 136 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty) 137 | else: 138 | raise RuntimeError(args.debias) 139 | 140 | 141 | model_state = torch.load(args.model_state) 142 | model.load_state_dict(model_state) 143 | 144 | 145 | model = model.cuda() 146 | batch_size = args.batch_size 147 | 148 | torch.manual_seed(args.seed) 149 | torch.cuda.manual_seed(args.seed) 150 | torch.backends.cudnn.benchmark = True 151 | 152 | # The original version uses multiple workers, but that just seems slower on my setup 153 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 154 | 155 | 156 | 157 | print("Starting eval...") 158 | 159 | evaluate(model,eval_loader,qid2type) 160 | 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /fc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | from torch.nn.utils.weight_norm import weight_norm 4 | 5 | 6 | class FCNet(nn.Module): 7 | """Simple class for a non-linear fully connected network 8 | """ 9 | def __init__(self, dims): 10 | super(FCNet, self).__init__() 11 | 12 | layers = [] 13 | for i in range(len(dims)-2): 14 | in_dim = dims[i] 15 | out_dim = dims[i+1] 16 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 17 | layers.append(nn.ReLU()) 18 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 19 | 
layers.append(nn.ReLU()) 20 | 21 | self.main = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.main(x) 25 | 26 | 27 | if __name__ == '__main__': 28 | fc1 = FCNet([10, 20, 10]) 29 | print(fc1) 30 | 31 | print('============') 32 | fc2 = FCNet([10, 20]) 33 | print(fc2) 34 | -------------------------------------------------------------------------------- /language_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | class WordEmbedding(nn.Module): 8 | """Word Embedding 9 | 10 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 11 | with the definition in Dictionary. 12 | """ 13 | def __init__(self, ntoken, emb_dim, dropout): 14 | super(WordEmbedding, self).__init__() 15 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 16 | self.dropout = nn.Dropout(dropout) 17 | self.ntoken = ntoken 18 | self.emb_dim = emb_dim 19 | 20 | def init_embedding(self, np_file): 21 | weight_init = torch.from_numpy(np.load(np_file)) 22 | assert weight_init.shape == (self.ntoken, self.emb_dim) 23 | self.emb.weight.data[:self.ntoken] = weight_init 24 | 25 | def forward(self, x): 26 | emb = self.emb(x) 27 | emb = self.dropout(emb) 28 | return emb 29 | 30 | 31 | class QuestionEmbedding(nn.Module): 32 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'): 33 | """Module for question embedding 34 | """ 35 | super(QuestionEmbedding, self).__init__() 36 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 37 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU 38 | 39 | self.rnn = rnn_cls( 40 | in_dim, num_hid, nlayers, 41 | bidirectional=bidirect, 42 | dropout=dropout, 43 | batch_first=True) 44 | 45 | self.in_dim = in_dim 46 | self.num_hid = num_hid 47 | self.nlayers = nlayers 48 | self.rnn_type = rnn_type 49 | self.ndirections = 1 + int(bidirect) 50 | 51 | def init_hidden(self, batch): 52 | # just to get the type of tensor 53 | weight = next(self.parameters()).data 54 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 55 | if self.rnn_type == 'LSTM': 56 | return (Variable(weight.new(*hid_shape).zero_()), 57 | Variable(weight.new(*hid_shape).zero_())) 58 | else: 59 | return Variable(weight.new(*hid_shape).zero_()) 60 | 61 | def forward(self, x): 62 | # x: [batch, sequence, in_dim] 63 | batch = x.size(0) 64 | hidden = self.init_hidden(batch) 65 | self.rnn.flatten_parameters() 66 | output, hidden = self.rnn(x, hidden) 67 | 68 | if self.ndirections == 1: 69 | return output[:, -1] 70 | 71 | forward_ = output[:, -1, :self.num_hid] 72 | backward = output[:, 0, self.num_hid:] 73 | return torch.cat((forward_, backward), dim=1) 74 | 75 | def forward_all(self, x): 76 | # x: [batch, sequence, in_dim] 77 | batch = x.size(0) 78 | hidden = self.init_hidden(batch) 79 | self.rnn.flatten_parameters() 80 | output, hidden = self.rnn(x, hidden) 81 | return output 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | # import cPickle as pickle 4 | import pickle as cPickle # python3 5 | from collections import defaultdict, Counter 6 | from os.path import dirname, join 7 | import os 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | 14 | from dataset import 
Dictionary, VQAFeatureDataset 15 | import base_model 16 | from train import train 17 | import utils 18 | import click 19 | 20 | from vqa_debias_loss_functions import * 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method") 25 | 26 | # Arguments we added 27 | parser.add_argument( 28 | '--cache_features', default=True, #True 29 | help="Cache image features in RAM. Makes things much faster, " 30 | "especially if the filesystem is slow, but requires at least 48 GB of RAM") 31 | parser.add_argument( 32 | '--dataset', default='cpv2', 33 | choices=["v2", "cpv2", "cpv1"], 34 | help="Dataset to run on: VQA v2 (v2), VQA-CP v2 (cpv2), or VQA-CP v1 (cpv1)" 35 | ) 36 | parser.add_argument( 37 | '-p', "--entropy_penalty", default=0.36, type=float, 38 | help="Entropy regularizer weight for the learned_mixin model") 39 | parser.add_argument( 40 | '--mode', default="q_v_debias_inner_contras", #updn 41 | choices=["updn", "q_debias","v_debias","q_v_debias", "v_debias_contras", "v_debias_ori_contras_eucl", "v_debias_ori_contras_cos", "q_v_debias_ori_contras_cos", "q_v_debias_inner_contras"], 42 | help="Training mode (debiasing strategy) to use") 43 | parser.add_argument( 44 | '--debias', default="learned_mixin", # learned_mixin 45 | choices=["learned_mixin", "reweight", "bias_product", "none",'focal'], 46 | help="Kind of ensemble loss to use") 47 | parser.add_argument( 48 | '--topq', type=int,default=1, 49 | choices=[1,2,3], 50 | help="number of words to be masked in the question") 51 | parser.add_argument( 52 | '--keep_qtype', default=True, 53 | help="keep qtype or not") 54 | parser.add_argument( 55 | '--topv', type=int,default=1, 56 | choices=[1,3,5,-1], 57 | help="number of object bboxes to be masked in the image") 58 | parser.add_argument( 59 | '--top_hint',type=int, default=9, 60 | choices=[9,18,27,36], 61 | help="number of hinted objects to consider") 62 | parser.add_argument( 63 | '--qvp', type=int,default=5, 64 | choices=[0,1,2,3,4,5,6,7,8,9,10], 65 | help="ratio of q_bias and v_bias") 66 | parser.add_argument( 67 | '--eval_each_epoch', default=True, 68 | help="Evaluate every epoch, instead of at the end") 69 | 70 | # Arguments from the original model; we leave these at their defaults, except we 71 | # set --epochs to 30 since the model maxes out its performance on VQA 2.0 well before then 72 | parser.add_argument('--margin', type=float, default=0.3, 73 | help="Margin of the original contrastive loss") 74 | parser.add_argument('--contras_loss_weight', type=int, default=2) 75 | parser.add_argument('--epochs', type=int, default=30) 76 | parser.add_argument('--num_hid', type=int, default=1024) 77 | parser.add_argument('--model', type=str, default='baseline0_newatt') 78 | parser.add_argument('--output', type=str, default='logs/q_v_debias_contras') 79 | parser.add_argument('--batch_size', type=int, default=512) 80 | parser.add_argument('--seed', type=int, default=0, help='random seed') 81 | args = parser.parse_args() 82 | return args 83 | 84 | def get_bias(train_dset,eval_dset): 85 | # Compute the bias: 86 | # The bias here is just the expected score for each answer/question type 87 | answer_voc_size = train_dset.num_ans_candidates 88 | 89 | # question_type -> answer -> total score 90 | question_type_to_probs = defaultdict(Counter) 91 | 92 | # question_type -> num_occurrences 93 | question_type_to_count = Counter() 94 | for ex in train_dset.entries: 95 | ans = ex["answer"] 96 | q_type = ans["question_type"] 97 | question_type_to_count[q_type] += 1 98 | if ans["labels"] is not None: 99 | for label, score in zip(ans["labels"], 
ans["scores"]): 100 | question_type_to_probs[q_type][label] += score 101 | question_type_to_prob_array = {} 102 | 103 | for q_type, count in question_type_to_count.items(): 104 | prob_array = np.zeros(answer_voc_size, np.float32) 105 | for label, total_score in question_type_to_probs[q_type].items(): 106 | prob_array[label] += total_score 107 | prob_array /= count 108 | question_type_to_prob_array[q_type] = prob_array 109 | 110 | for ds in [train_dset,eval_dset]: 111 | for ex in ds.entries: 112 | q_type = ex["answer"]["question_type"] 113 | ex["bias"] = question_type_to_prob_array[q_type] 114 | 115 | 116 | def main(): 117 | args = parse_args() 118 | dataset=args.dataset 119 | args.output=os.path.join('logs',args.output) 120 | if not os.path.isdir(args.output): 121 | utils.create_dir(args.output) 122 | else: 123 | if click.confirm('Exp directory already exists in {}. Erase?' 124 | .format(args.output, default=False)): 125 | os.system('rm -r ' + args.output) 126 | utils.create_dir(args.output) 127 | 128 | else: 129 | os._exit(1) 130 | 131 | 132 | 133 | if dataset=='cpv1': 134 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl') 135 | elif dataset=='cpv2' or dataset=='v2': 136 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 137 | 138 | print('margin of Contrasitive loss:', args.margin) 139 | 140 | print("Building train dataset...") 141 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset, 142 | cache_image_features=args.cache_features) 143 | 144 | print("Building test dataset...") 145 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset, 146 | cache_image_features=args.cache_features) 147 | 148 | get_bias(train_dset,eval_dset) 149 | 150 | 151 | # Build the model using the original constructor 152 | constructor = 'build_%s' % args.model 153 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda() 154 | if dataset=='cpv1': 155 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy') 156 | elif dataset=='cpv2' or dataset=='v2': 157 | model.w_emb.init_embedding('data/glove6b_init_300d.npy') 158 | 159 | # Add the loss_fn based our arguments 160 | if args.debias == "bias_product": 161 | model.debias_loss_fn = BiasProduct() 162 | elif args.debias == "none": 163 | model.debias_loss_fn = Plain() 164 | elif args.debias == "reweight": 165 | model.debias_loss_fn = ReweightByInvBias() 166 | elif args.debias == "learned_mixin": 167 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty) 168 | elif args.debias=='focal': 169 | model.debias_loss_fn = Focal() 170 | else: 171 | raise RuntimeError(args.mode) 172 | 173 | if args.mode =='v_debias_ori_contras_cos' or args.mode=='q_v_debias_ori_contras_cos' or args.mode=='q_v_debias_inner_contras': 174 | model.is_contras = True 175 | 176 | 177 | with open('util/qid2type_%s.json'%args.dataset,'r') as f: 178 | qid2type=json.load(f) 179 | 180 | model=model.cuda() 181 | batch_size = args.batch_size 182 | 183 | torch.manual_seed(args.seed) 184 | torch.cuda.manual_seed(args.seed) 185 | torch.backends.cudnn.benchmark = True 186 | 187 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0) 188 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0) 189 | 190 | print("Starting training...") 191 | train(model, train_loader, eval_loader, args,qid2type) 192 | 193 | if __name__ == '__main__': 194 | main() 195 | 196 | 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- 
/tools/compute_softscore.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import json 7 | import numpy as np 8 | import re 9 | import cPickle 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | from dataset import Dictionary 13 | import utils 14 | 15 | 16 | contractions = { 17 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 18 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 19 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 20 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 21 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 22 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 23 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 24 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 25 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 26 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 27 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 28 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 29 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 30 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 31 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 32 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 33 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 34 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 35 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 36 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 37 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 38 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 39 | "someonell": "someone'll", "someones": "someone's", "somethingd": 40 | "something'd", "somethingd've": "something'd've", "something'dve": 41 | "something'd've", "somethingll": "something'll", "thats": 42 | "that's", "thered": "there'd", "thered've": "there'd've", 43 | "there'dve": "there'd've", "therere": "there're", "theres": 44 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 45 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 46 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 47 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 48 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 49 | "what's", "whatve": "what've", "whens": "when's", "whered": 50 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 51 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 52 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 53 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 54 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 55 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 56 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 57 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 58 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 59 | "you'll", "youre": "you're", "youve": "you've" 60 | } 61 | 62 | manual_map = { 'none': '0', 63 | 'zero': '0', 64 | 'one': '1', 65 | 'two': '2', 66 | 'three': '3', 67 | 'four': '4', 68 | 'five': '5', 69 | 'six': '6', 70 | 'seven': '7', 71 | 'eight': '8', 72 | 'nine': '9', 73 
| 'ten': '10'} 74 | articles = ['a', 'an', 'the'] 75 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 76 | comma_strip = re.compile("(\d)(\,)(\d)") 77 | punct = [';', r"/", '[', ']', '"', '{', '}', 78 | '(', ')', '=', '+', '\\', '_', '-', 79 | '>', '<', '@', '`', ',', '?', '!'] 80 | 81 | 82 | def get_score(occurences): 83 | if occurences == 0: 84 | return 0 85 | elif occurences == 1: 86 | return 0.3 87 | elif occurences == 2: 88 | return 0.6 89 | elif occurences == 3: 90 | return 0.9 91 | else: 92 | return 1 93 | 94 | 95 | def process_punctuation(inText): 96 | outText = inText 97 | for p in punct: 98 | if (p + ' ' in inText or ' ' + p in inText) \ 99 | or (re.search(comma_strip, inText) != None): 100 | outText = outText.replace(p, '') 101 | else: 102 | outText = outText.replace(p, ' ') 103 | outText = period_strip.sub("", outText, re.UNICODE) 104 | return outText 105 | 106 | 107 | def process_digit_article(inText): 108 | outText = [] 109 | tempText = inText.lower().split() 110 | for word in tempText: 111 | word = manual_map.setdefault(word, word) 112 | if word not in articles: 113 | outText.append(word) 114 | else: 115 | pass 116 | for wordId, word in enumerate(outText): 117 | if word in contractions: 118 | outText[wordId] = contractions[word] 119 | outText = ' '.join(outText) 120 | return outText 121 | 122 | 123 | def multiple_replace(text, wordDict): 124 | for key in wordDict: 125 | text = text.replace(key, wordDict[key]) 126 | return text 127 | 128 | 129 | def preprocess_answer(answer): 130 | answer = process_digit_article(process_punctuation(answer)) 131 | answer = answer.replace(',', '') 132 | return answer 133 | 134 | 135 | def filter_answers(answers_dset, min_occurence): 136 | """This will change the answer to preprocessed version 137 | """ 138 | occurence = {} 139 | for ans_entry in answers_dset: 140 | gtruth = ans_entry['multiple_choice_answer'] 141 | gtruth = preprocess_answer(gtruth) 142 | if gtruth not in occurence: 143 | occurence[gtruth] = set() 144 | occurence[gtruth].add(ans_entry['question_id']) 145 | for answer in occurence.keys(): 146 | if len(occurence[answer]) < min_occurence: 147 | occurence.pop(answer) 148 | 149 | print('Num of answers that appear >= %d times: %d' % ( 150 | min_occurence, len(occurence))) 151 | return occurence 152 | 153 | 154 | 155 | 156 | 157 | 158 | def create_ans2label(occurence, name, cache_root): 159 | """Note that this will also create label2ans.pkl at the same time 160 | 161 | occurence: dict {answer -> whatever} 162 | name: prefix of the output file 163 | cache_root: str 164 | """ 165 | ans2label = {} 166 | label2ans = [] 167 | label = 0 168 | for answer in occurence: 169 | label2ans.append(answer) 170 | ans2label[answer] = label 171 | label += 1 172 | 173 | utils.create_dir(cache_root) 174 | 175 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 176 | cPickle.dump(ans2label, open(cache_file, 'wb')) 177 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 178 | cPickle.dump(label2ans, open(cache_file, 'wb')) 179 | return ans2label 180 | 181 | 182 | def compute_target(answers_dset, ans2label, name, cache_root): 183 | """Augment answers_dset with soft score as label 184 | 185 | ***answers_dset should be preprocessed*** 186 | 187 | Write result into a cache file 188 | """ 189 | target = [] 190 | for ans_entry in answers_dset: 191 | answers = ans_entry['answers'] 192 | answer_count = {} 193 | for answer in answers: 194 | answer_ = answer['answer'] 195 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 196 | 197 | 
labels = [] 198 | scores = [] 199 | for answer in answer_count: 200 | if answer not in ans2label: 201 | continue 202 | labels.append(ans2label[answer]) 203 | score = get_score(answer_count[answer]) 204 | scores.append(score) 205 | 206 | label_counts = {} 207 | for k, v in answer_count.items(): 208 | if k in ans2label: 209 | label_counts[ans2label[k]] = v 210 | 211 | target.append({ 212 | 'question_id': ans_entry['question_id'], 213 | 'question_type': ans_entry['question_type'], 214 | 'image_id': ans_entry['image_id'], 215 | 'label_counts': label_counts, 216 | 'labels': labels, 217 | 'scores': scores 218 | }) 219 | 220 | print(cache_root) 221 | utils.create_dir(cache_root) 222 | cache_file = os.path.join(cache_root, name+'_target.pkl') 223 | print(cache_file) 224 | with open(cache_file, 'wb') as f: 225 | cPickle.dump(target, f) 226 | return target 227 | 228 | 229 | 230 | def get_answer(qid, answers): 231 | for ans in answers: 232 | if ans['question_id'] == qid: 233 | return ans 234 | 235 | 236 | def get_question(qid, questions): 237 | for question in questions: 238 | if question['question_id'] == qid: 239 | return question 240 | 241 | 242 | def load_cp(): 243 | train_answer_file = "data/vqacp_v2_train_annotations.json" 244 | with open(train_answer_file) as f: 245 | train_answers = json.load(f) # ['annotations'] 246 | 247 | val_answer_file = "data/vqacp_v2_test_annotations.json" 248 | with open(val_answer_file) as f: 249 | val_answers = json.load(f) # ['annotations'] 250 | 251 | occurence = filter_answers(train_answers, 9) 252 | ans2label = create_ans2label(occurence, 'trainval', "data/cp-cache") 253 | compute_target(train_answers, ans2label, 'train', "data/cp-cache") 254 | compute_target(val_answers, ans2label, 'val', "data/cp-cache") 255 | 256 | def load_cp_v1(): 257 | train_answer_file = "data/vqacp_v1_train_annotations.json" 258 | with open(train_answer_file) as f: 259 | train_answers = json.load(f) # ['annotations'] 260 | 261 | val_answer_file = "data/vqacp_v1_test_annotations.json" 262 | with open(val_answer_file) as f: 263 | val_answers = json.load(f) # ['annotations'] 264 | 265 | occurence = filter_answers(train_answers, 9) 266 | ans2label = create_ans2label(occurence, 'trainval', "data/cp-v1-cache") 267 | compute_target(train_answers, ans2label, 'train', "data/cp-v1-cache") 268 | compute_target(val_answers, ans2label, 'val', "data/cp-v1-cache") 269 | 270 | 271 | def load_v2(): 272 | train_answer_file = 'data/v2_mscoco_train2014_annotations.json' 273 | with open(train_answer_file) as f: 274 | train_answers = json.load(f)['annotations'] 275 | 276 | val_answer_file = 'data/v2_mscoco_val2014_annotations.json' 277 | with open(val_answer_file) as f: 278 | val_answers = json.load(f)['annotations'] 279 | 280 | occurence = filter_answers(train_answers, 9) 281 | ans2label = create_ans2label(occurence, 'trainval', "data/cache") 282 | compute_target(train_answers, ans2label, 'train', "data/cache") 283 | compute_target(val_answers, ans2label, 'val', "data/cache") 284 | 285 | 286 | def main(): 287 | parser = argparse.ArgumentParser("Dataset preprocessing") 288 | parser.add_argument("dataset", choices=["cp_v2", "v2",'cp_v1']) 289 | args = parser.parse_args() 290 | if args.dataset == "v2": 291 | load_v2() 292 | elif args.dataset == "cp_v1": 293 | load_cp_v1() 294 | elif args.dataset=='cp_v2': 295 | load_cp() 296 | 297 | 298 | 299 | if __name__ == '__main__': 300 | main() 301 | -------------------------------------------------------------------------------- /tools/create_dictionary.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | 11 | 12 | def create_dictionary(dataroot): 13 | dictionary = Dictionary() 14 | questions = [] 15 | files = [ 16 | 'v2_OpenEnded_mscoco_train2014_questions.json', 17 | 'v2_OpenEnded_mscoco_val2014_questions.json', 18 | 'v2_OpenEnded_mscoco_test2015_questions.json', 19 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json' 20 | ] 21 | for path in files: 22 | question_path = os.path.join(dataroot, path) 23 | qs = json.load(open(question_path))['questions'] 24 | for q in qs: 25 | dictionary.tokenize(q['question'], True) 26 | dictionary.tokenize('wordmask',True) 27 | return dictionary 28 | 29 | 30 | def create_glove_embedding_init(idx2word, glove_file): 31 | word2emb = {} 32 | with open(glove_file, 'r') as f: 33 | entries = f.readlines() 34 | emb_dim = len(entries[0].split(' ')) - 1 35 | print('embedding dim is %d' % emb_dim) 36 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 37 | 38 | for entry in entries: 39 | vals = entry.split(' ') 40 | word = vals[0] 41 | vals = map(float, vals[1:]) 42 | word2emb[word] = np.array(vals) 43 | for idx, word in enumerate(idx2word): 44 | if word not in word2emb: 45 | continue 46 | weights[idx] = word2emb[word] 47 | return weights, word2emb 48 | 49 | 50 | if __name__ == '__main__': 51 | d = create_dictionary('data') 52 | d.dump_to_file('data/dictionary.pkl') 53 | 54 | d = Dictionary.load_from_file('data/dictionary.pkl') 55 | emb_dim = 300 56 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 57 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 58 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights) 59 | -------------------------------------------------------------------------------- /tools/create_dictionary_v1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import numpy as np 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from dataset import Dictionary 8 | 9 | 10 | def create_dictionary(dataroot): 11 | dictionary = Dictionary() 12 | questions = [] 13 | files = [ 14 | 'OpenEnded_mscoco_train2014_questions.json', 15 | 'OpenEnded_mscoco_val2014_questions.json', 16 | 'OpenEnded_mscoco_test2015_questions.json', 17 | 'OpenEnded_mscoco_test-dev2015_questions.json' 18 | ] 19 | for path in files: 20 | question_path = os.path.join(dataroot, path) 21 | qs = json.load(open(question_path))['questions'] 22 | for q in qs: 23 | dictionary.tokenize(q['question'], True) 24 | return dictionary 25 | 26 | 27 | def create_glove_embedding_init(idx2word, glove_file): 28 | word2emb = {} 29 | with open(glove_file, 'r') as f: 30 | entries = f.readlines() 31 | emb_dim = len(entries[0].split(' ')) - 1 32 | print('embedding dim is %d' % emb_dim) 33 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 34 | 35 | for entry in entries: 36 | vals = entry.split(' ') 37 | word = vals[0] 38 | vals = map(float, vals[1:]) 39 | word2emb[word] = np.array(vals) 40 | for idx, word in enumerate(idx2word): 41 | if word not in word2emb: 42 | continue 43 | weights[idx] = word2emb[word] 44 | return weights, word2emb 45 | 46 | 47 | if __name__ == '__main__': 48 | d = 
create_dictionary('data') 49 | d.dump_to_file('data/dictionary_v1.pkl') 50 | 51 | d = Dictionary.load_from_file('data/dictionary_v1.pkl') 52 | emb_dim = 300 53 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 54 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 55 | np.save('data/glove6b_init_%dd_v1.npy' % emb_dim, weights) 56 | -------------------------------------------------------------------------------- /tools/download.sh: -------------------------------------------------------------------------------- 1 | ## Script for downloading data 2 | 3 | # GloVe Vectors 4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip 5 | unzip data/glove.6B.zip -d data/glove 6 | rm data/glove.6B.zip 7 | 8 | # VQA-CP2 9 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json 10 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json 11 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json 12 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json 13 | 14 | # VQA-CP1 15 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json 16 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json 17 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json 18 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json 19 | 20 | # VQA-V2 21 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip 22 | unzip data/v2_Questions_Train_mscoco.zip -d data 23 | rm data/v2_Questions_Train_mscoco.zip 24 | 25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip 26 | unzip data/v2_Questions_Val_mscoco.zip -d data 27 | rm data/v2_Questions_Val_mscoco.zip 28 | 29 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip 30 | unzip data/v2_Questions_Test_mscoco.zip -d data 31 | rm data/v2_Questions_Test_mscoco.zip 32 | 33 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip 34 | unzip data/v2_Annotations_Train_mscoco.zip -d data 35 | rm data/v2_Annotations_Train_mscoco.zip 36 | 37 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 38 | unzip data/v2_Annotations_Val_mscoco.zip -d data 39 | rm data/v2_Annotations_Val_mscoco.zip 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /tools/download_v1.sh: -------------------------------------------------------------------------------- 1 | 2 | # VQA-CP1 3 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json 4 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json 5 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json 6 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json 7 | -------------------------------------------------------------------------------- /tools/process.sh: -------------------------------------------------------------------------------- 1 | # Process data 2 | 3 | python tools/create_dictionary.py 4 | python tools/create_dictionary_v1.py 5 | python tools/compute_softscore.py v2 6 | python tools/compute_softscore.py cp_v1 7 | python tools/compute_softscore.py cp_v2 8 | 9 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import time 5 | from os.path import join 6 | 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from tqdm import tqdm 13 | import random 14 | import copy 15 | 16 | def compute_score_with_logits(logits, labels): 17 | logits = torch.argmax(logits, 1) 18 | one_hots = torch.zeros(*labels.size()).cuda() 19 | one_hots.scatter_(1, logits.view(-1, 1), 1) 20 | scores = (one_hots * labels) 21 | return scores 22 | 23 | def train(model, train_loader, eval_loader,args,qid2type): 24 | 25 | dataset=args.dataset 26 | num_epochs=args.epochs 27 | mode=args.mode 28 | run_eval=args.eval_each_epoch 29 | output=args.output 30 | optim = torch.optim.Adamax(model.parameters()) 31 | logger = utils.Logger(os.path.join(output, 'log.txt')) 32 | total_step = 0 33 | best_eval_score = 0 34 | 35 | if mode=='q_debias': 36 | topq=args.topq 37 | keep_qtype=args.keep_qtype 38 | elif mode=='v_debias': 39 | topv=args.topv 40 | top_hint=args.top_hint 41 | elif mode=='v_debias_contras' or mode=='v_debias_ori_contras_cos' or mode == 'v_debias_ori_contras_eucl': 42 | topv=args.topv 43 | top_hint=args.top_hint 44 | elif mode=='q_v_debias' or mode=='q_v_debias_ori_contras_cos' or mode=='q_v_debias_inner_contras': 45 | topv=args.topv 46 | top_hint=args.top_hint 47 | topq=args.topq 48 | keep_qtype=args.keep_qtype 49 | qvp=args.qvp 50 | 51 | 52 | for epoch in range(num_epochs): 53 | total_loss = 0 54 | train_score = 0 55 | 56 | t = time.time() 57 | for i, (v, q, a, b, hintscore,type_mask,notype_mask,q_mask) in tqdm(enumerate(train_loader), ncols=100, 58 | desc="Epoch %d" % (epoch + 1), total=len(train_loader)): 59 | 60 | total_step += 1 61 | 62 | 63 | ######################################### 64 | v = Variable(v).cuda().requires_grad_() 65 | q = Variable(q).cuda() 66 | q_mask=Variable(q_mask).cuda() 67 | a = Variable(a).cuda() 68 | b = Variable(b).cuda() 69 | hintscore = Variable(hintscore).cuda() 70 | type_mask=Variable(type_mask).float().cuda() 71 | notype_mask=Variable(notype_mask).float().cuda() 72 | ######################################### 73 | 74 | if mode=='updn': 75 | pred, loss,_ = model(v, q, a, b, None) 76 | if (loss != loss).any(): 77 | raise ValueError("NaN loss") 78 | loss.backward() 79 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 80 | optim.step() 81 | optim.zero_grad() 82 | 83 | total_loss += loss.item() * q.size(0) 84 | batch_score = compute_score_with_logits(pred, a.data).sum() 85 | train_score += batch_score 86 | 87 | elif mode=='q_debias': 88 | if keep_qtype==True: 89 | sen_mask=type_mask 90 | else: 91 | sen_mask=notype_mask 92 | ## first train 93 | pred, loss,word_emb = model(v, q, a, b, None) 94 | 95 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0] 96 | 97 | if (loss != loss).any(): 98 | raise ValueError("NaN loss") 99 | loss.backward() 100 | nn.utils.clip_grad_norm_(model.parameters(), 0.25) 101 | optim.step() 102 | optim.zero_grad() 103 | 104 | total_loss += loss.item() * q.size(0) 105 | batch_score = compute_score_with_logits(pred, a.data).sum() 106 | train_score += batch_score 107 | 108 | ## second train 109 | 110 | word_grad_cam = word_grad.sum(2) 111 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000) 112 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask) 113 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask 114 
                w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]

                q2 = copy.deepcopy(q_mask)

                m1 = copy.deepcopy(sen_mask)  ## [0,0,0...0,1,1,1,1]
                m1.scatter_(1, w_ind, 0)  ## [0,0,0...0,0,1,1,0]
                m2 = 1 - m1  ## [1,1,1...1,1,0,0,1]
                if dataset == 'cpv1':
                    m3 = m1 * 18330
                else:
                    m3 = m1 * 18455  ## [0,0,0...0,0,18455,18455,0]
                q2 = q2 * m2.long() + m3.long()

                pred, _, _ = model(v, q2, None, b, None)

                pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                false_ans.scatter_(1, pred_ind, 0)
                a2 = a * false_ans
                q3 = copy.deepcopy(q)
                if dataset == 'cpv1':
                    q3.scatter_(1, w_ind, 18330)
                else:
                    q3.scatter_(1, w_ind, 18455)

                ## third train

                pred, loss, _ = model(v, q3, a2, b, None)

                if (loss != loss).any():
                    raise ValueError("NaN loss")
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optim.step()
                optim.zero_grad()

                total_loss += loss.item() * q.size(0)

            elif mode == 'v_debias':
                ## first train
                pred, loss, _ = model(v, q, a, b, None)
                visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]

                if (loss != loss).any():
                    raise ValueError("NaN loss")
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optim.step()
                optim.zero_grad()

                total_loss += loss.item() * q.size(0)
                batch_score = compute_score_with_logits(pred, a.data).sum()
                train_score += batch_score

                ## second train
                v_mask = torch.zeros(v.shape[0], 36).cuda()
                visual_grad_cam = visual_grad.sum(2)
                hint_sort, hint_ind = hintscore.sort(1, descending=True)
                v_ind = hint_ind[:, :top_hint]
                v_grad = visual_grad_cam.gather(1, v_ind)

                if topv == -1:
                    v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
                    v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
                    v_grad_sum = torch.cumsum(v_grad_score, dim=1)
                    v_grad_mask = (v_grad_sum <= 0.65).long()
                    v_grad_mask[:, 0] = 1
                    v_mask_ind = v_grad_mask * v_ind
                    for x in range(a.shape[0]):
                        num = len(torch.nonzero(v_grad_mask[x]))
                        v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
                else:
                    v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
                    v_star = v_ind.gather(1, v_grad_ind)
                    v_mask.scatter_(1, v_star, 1)

                pred, _, _ = model(v, q, None, b, v_mask)  # P = VQA(I+,Q)

                pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                false_ans.scatter_(1, pred_ind, 0)
                a2 = a * false_ans

                v_mask = 1 - v_mask

                pred, loss, _ = model(v, q, a2, b, v_mask)  # P = VQA(I-,Q)

                if (loss != loss).any():
                    raise ValueError("NaN loss")
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optim.step()
                optim.zero_grad()

                total_loss += loss.item() * q.size(0)

            elif mode == 'v_debias_contras':
                ## first train
                pred, loss, _, feature_ori = model(v, q, a, b, None)
                visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]

                if (loss != loss).any():
                    raise ValueError("NaN loss")
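                # retain_graph=True is needed here (unlike in plain v_debias):
                # feature_ori from this forward pass feeds the contrastive loss
                # below, so the graph must survive this intermediate backward.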
                loss.backward(retain_graph=True)
                nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optim.step()
                optim.zero_grad()

                total_loss += loss.item() * q.size(0)
                batch_score = compute_score_with_logits(pred, a.data).sum()
                train_score += batch_score

                ## second train
                v_mask = torch.zeros(v.shape[0], 36).cuda()
                visual_grad_cam = visual_grad.sum(2)
                hint_sort, hint_ind = hintscore.sort(1, descending=True)
                v_ind = hint_ind[:, :top_hint]
                v_grad = visual_grad_cam.gather(1, v_ind)

                if topv == -1:
                    v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
                    v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
                    v_grad_sum = torch.cumsum(v_grad_score, dim=1)
                    v_grad_mask = (v_grad_sum <= 0.65).long()
                    v_grad_mask[:, 0] = 1
                    v_mask_ind = v_grad_mask * v_ind
                    for x in range(a.shape[0]):
                        num = len(torch.nonzero(v_grad_mask[x]))
                        v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
                else:
                    v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
                    v_star = v_ind.gather(1, v_grad_ind)
                    v_mask.scatter_(1, v_star, 1)

                pred, _, _, feature_pos = model(v, q, None, b, v_mask)  # P = VQA(I+,Q)

                pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                false_ans.scatter_(1, pred_ind, 0)
                a2 = a * false_ans

                v_mask = 1 - v_mask

                pred, loss, _, feature_neg = model(v, q, a2, b, v_mask)  # P = VQA(I-,Q)

                # Compute the contras loss
                # inner product
                # pos = (feature_ori * feature_pos).sum(1)
                # neg = (feature_ori * feature_neg).sum(1)
                # cosine_similarity
                pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
                neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)

                logit = torch.stack((pos, neg), 1)  # [b, 2]
                softmax_logit = nn.functional.softmax(logit, 1)  # [b, 2]

                contras_loss = - torch.log(softmax_logit[:, 0])
                contras_loss = contras_loss.mean()

                loss += contras_loss

                if (loss != loss).any():
                    raise ValueError("NaN loss")
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optim.step()
                optim.zero_grad()

                total_loss += loss.item() * q.size(0)

            elif mode == 'v_debias_ori_contras_cos' or mode == 'v_debias_ori_contras_eucl' or mode == 'q_v_debias_ori_contras_cos':
                if mode != 'q_v_debias_ori_contras_cos':
                    ## first train
                    pred, loss, _, feature_ori = model(v, q, a, b, None)
                    visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward(retain_graph=True)
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)
                    batch_score = compute_score_with_logits(pred, a.data).sum()
                    train_score += batch_score

                    ## second train
                    v_mask = torch.zeros(v.shape[0], 36).cuda()
                    visual_grad_cam = visual_grad.sum(2)
                    hint_sort, hint_ind = hintscore.sort(1, descending=True)
                    v_ind = hint_ind[:, :top_hint]
                    v_grad = visual_grad_cam.gather(1, v_ind)

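                    # With topv == -1 the number of critical objects is picked
                    # adaptively: the sorted gradient scores are sharpened with
                    # a softmax (scaled by 10), then the smallest prefix whose
                    # cumulative mass stays within 0.65 is kept, always
                    # retaining at least the top-1 object.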
                    if topv == -1:
                        v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
                        v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
                        v_grad_sum = torch.cumsum(v_grad_score, dim=1)
                        v_grad_mask = (v_grad_sum <= 0.65).long()
                        v_grad_mask[:, 0] = 1
                        v_mask_ind = v_grad_mask * v_ind
                        for x in range(a.shape[0]):
                            num = len(torch.nonzero(v_grad_mask[x]))
                            v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
                    else:
                        v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
                        v_star = v_ind.gather(1, v_grad_ind)
                        v_mask.scatter_(1, v_star, 1)

                    pred, _, _, feature_pos = model(v, q, None, b, v_mask)  # P = VQA(I+,Q)

                    pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                    false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                    false_ans.scatter_(1, pred_ind, 0)
                    a2 = a * false_ans

                    v_mask = 1 - v_mask

                    pred, loss, _, feature_neg = model(v, q, a2, b, v_mask)  # P = VQA(I-,Q)

                    # Compute the contras loss
                    # inner product
                    # pos = (feature_ori * feature_pos).sum(1)
                    # neg = (feature_ori * feature_neg).sum(1)

                    # cosine_similarity
                    if mode == 'v_debias_ori_contras_cos':
                        pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
                        neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
                        pos_dis = 1 - pos
                        neg_dis = 1 - neg
                        contras_loss_pos = pos_dis
                        margin = args.margin
                        contras_loss_neg = torch.clamp(margin - neg_dis, min=0.0)
                    # euclidean_distance
                    elif mode == 'v_debias_ori_contras_eucl':
                        pos_dis = nn.functional.pairwise_distance(feature_ori, feature_pos, p=2)
                        neg_dis = nn.functional.pairwise_distance(feature_ori, feature_neg, p=2)
                        contras_loss_pos = torch.pow(pos_dis, 2)
                        margin = args.margin
                        contras_loss_neg = torch.pow(torch.clamp(margin - neg_dis, min=0.0), 2)

                    contras_loss = contras_loss_pos + contras_loss_neg
                    contras_loss = contras_loss.mean()

                    loss += contras_loss

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)

                elif mode == 'q_v_debias_ori_contras_cos':
                    random_num = random.randint(1, 10)
                    if keep_qtype:
                        sen_mask = type_mask
                    else:
                        sen_mask = notype_mask
                    if random_num <= qvp:
                        ## first train
                        pred, loss, word_emb, feature_ori = model(v, q, a, b, None)
                        word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]

                        if (loss != loss).any():
                            raise ValueError("NaN loss")
                        loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                        optim.step()
                        optim.zero_grad()

                        total_loss += loss.item() * q.size(0)
                        batch_score = compute_score_with_logits(pred, a.data).sum()
                        train_score += batch_score

                        ## second train

                        word_grad_cam = word_grad.sum(2)
                        # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
                        word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
                        word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
                        w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]

                        q2 = copy.deepcopy(q_mask)

                        m1 = copy.deepcopy(sen_mask)  ## [0,0,0...0,1,1,1,1]
                        m1.scatter_(1, w_ind, 0)  ## [0,0,0...0,0,1,1,0]
                        m2 = 1 - m1  ## [1,1,1...1,1,0,0,1]
                        if dataset == 'cpv1':
                            m3 = m1 * 18330
                        else:
                            m3 = m1 * 18455  ## [0,0,0...0,0,18455,18455,0]
                        q2 = q2 * m2.long() + m3.long()

                        pred, _, _, feature_pos = model(v, q2, None, b, None)

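                        # q2 (Q+) keeps the critical words and masks the rest,
                        # while q3 (Q-) below masks the critical words instead;
                        # a2 zeroes out the model's current top-5 answers so the
                        # counterfactual sample is supervised with answers the
                        # model did not already favor.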
                        pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                        false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                        false_ans.scatter_(1, pred_ind, 0)
                        a2 = a * false_ans
                        q3 = copy.deepcopy(q)
                        if dataset == 'cpv1':
                            q3.scatter_(1, w_ind, 18330)
                        else:
                            q3.scatter_(1, w_ind, 18455)

                        ## third train

                        pred, loss, _, feature_neg = model(v, q3, a2, b, None)

                        # cosine_similarity
                        pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
                        neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
                        pos_dis = 1 - pos
                        neg_dis = 1 - neg
                        contras_loss_pos = pos_dis
                        margin = args.margin
                        contras_loss_neg = torch.clamp(margin - neg_dis, min=0.0)

                        contras_loss = contras_loss_pos + contras_loss_neg
                        contras_loss = contras_loss.mean()

                        loss += args.contras_loss_weight * contras_loss  # with CSS
                        # loss = contras_loss  # only_contras_loss

                        if (loss != loss).any():
                            raise ValueError("NaN loss")
                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                        optim.step()
                        optim.zero_grad()

                        total_loss += loss.item() * q.size(0)

                    else:
                        ## first train
                        pred, loss, _, feature_ori = model(v, q, a, b, None)
                        visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]

                        if (loss != loss).any():
                            raise ValueError("NaN loss")
                        loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                        optim.step()
                        optim.zero_grad()

                        total_loss += loss.item() * q.size(0)
                        batch_score = compute_score_with_logits(pred, a.data).sum()
                        train_score += batch_score

                        ## second train
                        v_mask = torch.zeros(v.shape[0], 36).cuda()
                        visual_grad_cam = visual_grad.sum(2)
                        hint_sort, hint_ind = hintscore.sort(1, descending=True)
                        v_ind = hint_ind[:, :top_hint]
                        v_grad = visual_grad_cam.gather(1, v_ind)

                        if topv == -1:
                            v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
                            v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
                            v_grad_sum = torch.cumsum(v_grad_score, dim=1)
                            v_grad_mask = (v_grad_sum <= 0.65).long()
                            v_grad_mask[:, 0] = 1
                            v_mask_ind = v_grad_mask * v_ind
                            for x in range(a.shape[0]):
                                num = len(torch.nonzero(v_grad_mask[x]))
                                v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
                        else:
                            v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
                            v_star = v_ind.gather(1, v_grad_ind)
                            v_mask.scatter_(1, v_star, 1)

                        pred, _, _, feature_pos = model(v, q, None, b, v_mask)
                        pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                        false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                        false_ans.scatter_(1, pred_ind, 0)
                        a2 = a * false_ans

                        v_mask = 1 - v_mask

                        pred, loss, _, feature_neg = model(v, q, a2, b, v_mask)

                        # cosine_similarity
                        pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
                        neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
                        pos_dis = 1 - pos
                        neg_dis = 1 - neg
                        contras_loss_pos = pos_dis
                        margin = args.margin
                        contras_loss_neg = torch.clamp(margin - neg_dis, min=0.0)

                        contras_loss = contras_loss_pos + contras_loss_neg
                        contras_loss = contras_loss.mean()

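                        # Margin-based contrastive objective: the positive pair
                        # is penalized by its cosine distance (1 - cos), and the
                        # negative pair contributes max(0, margin - neg_dis),
                        # i.e. only while it is still closer than the margin.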
                        loss += args.contras_loss_weight * contras_loss  # with CSS
                        # loss = contras_loss  # only_contras_loss

                        if (loss != loss).any():
                            raise ValueError("NaN loss")
                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                        optim.step()
                        optim.zero_grad()

                        total_loss += loss.item() * q.size(0)

            elif mode == 'q_v_debias':
                random_num = random.randint(1, 10)
                if keep_qtype:
                    sen_mask = type_mask
                else:
                    sen_mask = notype_mask
                if random_num <= qvp:
                    ## first train
                    pred, loss, word_emb = model(v, q, a, b, None)
                    word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)
                    batch_score = compute_score_with_logits(pred, a.data).sum()
                    train_score += batch_score

                    ## second train

                    word_grad_cam = word_grad.sum(2)
                    # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
                    word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
                    word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
                    w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]

                    q2 = copy.deepcopy(q_mask)

                    m1 = copy.deepcopy(sen_mask)  ## [0,0,0...0,1,1,1,1]
                    m1.scatter_(1, w_ind, 0)  ## [0,0,0...0,0,1,1,0]
                    m2 = 1 - m1  ## [1,1,1...1,1,0,0,1]
                    if dataset == 'cpv1':
                        m3 = m1 * 18330
                    else:
                        m3 = m1 * 18455  ## [0,0,0...0,0,18455,18455,0]
                    q2 = q2 * m2.long() + m3.long()

                    pred, _, _ = model(v, q2, None, b, None)

                    pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                    false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                    false_ans.scatter_(1, pred_ind, 0)
                    a2 = a * false_ans
                    q3 = copy.deepcopy(q)
                    if dataset == 'cpv1':
                        q3.scatter_(1, w_ind, 18330)
                    else:
                        q3.scatter_(1, w_ind, 18455)

                    ## third train

                    pred, loss, _ = model(v, q3, a2, b, None)

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)

                else:
                    ## first train
                    pred, loss, _ = model(v, q, a, b, None)
                    visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)
                    batch_score = compute_score_with_logits(pred, a.data).sum()
                    train_score += batch_score

                    ## second train
                    v_mask = torch.zeros(v.shape[0], 36).cuda()
                    visual_grad_cam = visual_grad.sum(2)
                    hint_sort, hint_ind = hintscore.sort(1, descending=True)
                    v_ind = hint_ind[:, :top_hint]
                    v_grad = visual_grad_cam.gather(1, v_ind)

                    if topv == -1:
                        v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
                        v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
                        v_grad_sum = torch.cumsum(v_grad_score, dim=1)
                        v_grad_mask = (v_grad_sum <= 0.65).long()
                        v_grad_mask[:, 0] = 1
                        v_mask_ind = v_grad_mask * v_ind
                        for x in range(a.shape[0]):
                            num = len(torch.nonzero(v_grad_mask[x]))
                            v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
                    else:
                        v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
                        v_star = v_ind.gather(1, v_grad_ind)
                        v_mask.scatter_(1, v_star, 1)

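                    # v_mask now flags the critical objects: the model first
                    # sees them kept (I+) to mine hard negative answers, then
                    # the complementary mask (I-) with pseudo labels a2, which
                    # forms the CSS-style counterfactual training pair.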
                    pred, _, _ = model(v, q, None, b, v_mask)
                    pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
                    false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
                    false_ans.scatter_(1, pred_ind, 0)
                    a2 = a * false_ans

                    v_mask = 1 - v_mask

                    pred, loss, _ = model(v, q, a2, b, v_mask)

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)

            elif mode == 'q_v_debias_inner_contras':
                random_num = random.randint(1, 10)
                if keep_qtype:
                    sen_mask = type_mask
                else:
                    sen_mask = notype_mask
                if random_num <= qvp:
                    ## first train
                    pred, loss, word_emb, feature_ori = model(v, q, a, b, None)
                    word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward(retain_graph=True)
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)
                    batch_score = compute_score_with_logits(pred, a.data).sum()
                    train_score += batch_score

                    ## second train

                    word_grad_cam = word_grad.sum(2)
                    # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
                    word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
                    word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
                    w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]

                    q2 = copy.deepcopy(q_mask)

                    m1 = copy.deepcopy(sen_mask)  ## [0,0,0...0,1,1,1,1]
                    m1.scatter_(1, w_ind, 0)  ## [0,0,0...0,0,1,1,0]
                    m2 = 1 - m1  ## [1,1,1...1,1,0,0,1]
                    if dataset == 'cpv1':
                        m3 = m1 * 18330
                    else:
                        m3 = m1 * 18455  ## [0,0,0...0,0,18455,18455,0]
                    q2 = q2 * m2.long() + m3.long()

                    pred_pos, _, _, feature_pos = model(v, q2, None, b, None)  # P = VQA(I,Q+)
                    pred_ind = torch.argsort(pred_pos, 1, descending=True)[:, :5]
                    false_ans = torch.ones(pred_pos.shape[0], pred_pos.shape[1]).cuda()
                    false_ans.scatter_(1, pred_ind, 0)
                    a2 = a * false_ans
                    q3 = copy.deepcopy(q)
                    if dataset == 'cpv1':
                        q3.scatter_(1, w_ind, 18330)
                    else:
                        q3.scatter_(1, w_ind, 18455)

                    ## third train

                    pred_neg, loss_neg, _, feature_neg = model(v, q3, a2, b, None)  # P = VQA(I,Q-)

                    # Compute the contras loss
                    # inner product
                    # pos = (feature_ori * feature_pos).sum(1)
                    # neg = (feature_ori * feature_neg).sum(1)
                    # cosine_similarity
                    pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
                    neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)

                    logit = torch.stack((pos, neg), 1)  # [b, 2]
                    softmax_logit = nn.functional.softmax(logit, 1)  # [b, 2]

                    contras_loss = - torch.log(softmax_logit[:, 0])
                    # contras_loss += torch.log(softmax_logit[:, 1])  # add contras_neg
                    contras_loss = contras_loss.mean()

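                    # A 2-way InfoNCE-style term: softmax over the (pos, neg)
                    # cosine similarities followed by the NLL of the positive
                    # slot pulls the anchor toward the factual sample and away
                    # from the counterfactual one.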
                    loss = loss_neg + args.contras_loss_weight * contras_loss

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)

                else:
                    ## first train
                    pred, loss, _, feature_ori = model(v, q, a, b, None)

                    visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward(retain_graph=True)
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)
                    batch_score = compute_score_with_logits(pred, a.data).sum()
                    train_score += batch_score

                    ## second train
                    v_mask = torch.zeros(v.shape[0], 36).cuda()
                    visual_grad_cam = visual_grad.sum(2)
                    hint_sort, hint_ind = hintscore.sort(1, descending=True)
                    v_ind = hint_ind[:, :top_hint]
                    v_grad = visual_grad_cam.gather(1, v_ind)

                    if topv == -1:
                        v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
                        v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
                        v_grad_sum = torch.cumsum(v_grad_score, dim=1)
                        v_grad_mask = (v_grad_sum <= 0.65).long()
                        v_grad_mask[:, 0] = 1
                        v_mask_ind = v_grad_mask * v_ind
                        for x in range(a.shape[0]):
                            num = len(torch.nonzero(v_grad_mask[x]))
                            v_mask[x].scatter_(0, v_mask_ind[x, :num], 1)
                    else:
                        v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
                        v_star = v_ind.gather(1, v_grad_ind)
                        v_mask.scatter_(1, v_star, 1)

                    pred_pos, _, _, feature_pos = model(v, q, None, b, v_mask)  # P = VQA(I+,Q)
                    v_mask_pos = v_mask

                    pred_ind = torch.argsort(pred_pos, 1, descending=True)[:, :5]
                    false_ans = torch.ones(pred_pos.shape[0], pred_pos.shape[1]).cuda()
                    false_ans.scatter_(1, pred_ind, 0)
                    a2 = a * false_ans

                    v_mask_neg = 1 - v_mask_pos

                    pred_neg, loss_neg, _, feature_neg = model(v, q, a2, b, v_mask_neg)  # P = VQA(I-,Q)

                    # Compute the contras loss
                    # inner product
                    # pos = (feature_ori * feature_pos).sum(1)
                    # neg = (feature_ori * feature_neg).sum(1)
                    # cosine_similarity
                    pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
                    neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)

                    logit = torch.stack((pos, neg), 1)  # [b, 2]
                    softmax_logit = nn.functional.softmax(logit, 1)  # [b, 2]

                    contras_loss = - torch.log(softmax_logit[:, 0])
                    # contras_loss += torch.log(softmax_logit[:, 1])  # add contras_neg
                    contras_loss = contras_loss.mean()

                    loss = loss_neg + args.contras_loss_weight * contras_loss

                    if (loss != loss).any():
                        raise ValueError("NaN loss")
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optim.step()
                    optim.zero_grad()

                    total_loss += loss.item() * q.size(0)

        if mode == 'updn':
            total_loss /= len(train_loader.dataset)
        else:
            total_loss /= len(train_loader.dataset) * 2
        train_score = 100 * train_score / len(train_loader.dataset)

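        # Dividing by 2x the dataset size accounts for the debiasing modes
        # running two optimization passes per batch, each adding
        # loss.item() * batch_size to total_loss, so this is roughly a
        # per-sample, per-pass average.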
        if run_eval:
            model.train(False)
            results = evaluate(model, eval_loader, qid2type)
            results["epoch"] = epoch + 1
            results["step"] = total_step
            results["train_loss"] = total_loss
            results["train_score"] = train_score

            model.train(True)

            eval_score = results["score"]
            bound = results["upper_bound"]
            yn = results['score_yesno']
            other = results['score_other']
            num = results['score_number']

        logger.write('epoch %d, time: %.2f' % (epoch + 1, time.time() - t))
        logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))

        if run_eval:
            logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))
            logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num))

            if eval_score > best_eval_score:
                model_path = os.path.join(output, 'model.pth')
                torch.save(model.state_dict(), model_path)
                best_eval_score = eval_score


def evaluate(model, dataloader, qid2type):
    score = 0
    upper_bound = 0
    score_yesno = 0
    score_number = 0
    score_other = 0
    total_yesno = 0
    total_number = 0
    total_other = 0

    for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
        v = Variable(v, requires_grad=False).cuda()
        q = Variable(q, requires_grad=False).cuda()
        pred, _, _ = model(v, q, None, None, None)
        batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1)
        score += batch_score.sum()
        upper_bound += (a.max(1)[0]).sum()
        qids = qids.detach().cpu().int().numpy()
        for j in range(len(qids)):
            qid = qids[j]
            typ = qid2type[str(qid)]
            if typ == 'yes/no':
                score_yesno += batch_score[j]
                total_yesno += 1
            elif typ == 'other':
                score_other += batch_score[j]
                total_other += 1
            elif typ == 'number':
                score_number += batch_score[j]
                total_number += 1
            else:
                print('Unknown question type: %s' % typ)

    score = score / len(dataloader.dataset)
    upper_bound = upper_bound / len(dataloader.dataset)
    score_yesno /= total_yesno
    score_other /= total_other
    score_number /= total_number

    results = dict(
        score=score,
        upper_bound=upper_bound,
        score_yesno=score_yesno,
        score_other=score_other,
        score_number=score_number,
    )
    return results
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import errno
import os
import numpy as np
# from PIL import Image
import torch
import torch.nn as nn


EPS = 1e-7


def assert_eq(real, expected):
    assert real == expected, '%s (true) vs %s (expected)' % (real, expected)


def assert_array_eq(real, expected):
    assert (np.abs(real - expected) < EPS).all(), \
        '%s (true) vs %s (expected)' % (real, expected)


def load_folder(folder, suffix):
    imgs = []
    for f in sorted(os.listdir(folder)):
        if f.endswith(suffix):
            imgs.append(os.path.join(folder, f))
    return imgs


# def load_imageid(folder):
#     images = load_folder(folder, 'jpg')
#     img_ids = set()
#     for img in images:
#         img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1])
#         img_ids.add(img_id)
#     return img_ids


# def pil_loader(path):
#     with open(path, 'rb') as f:
#         with Image.open(f) as img:
#             return img.convert('RGB')

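# weights_init is meant to be applied module-by-module through
# nn.Module.apply, e.g. net.apply(weights_init) as init_net() does below;
# the N(0.0, 0.02) scheme matches the common DCGAN-style initialization.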
def weights_init(m):
    """custom weights initialization."""
    cname = m.__class__
    if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d:
        m.weight.data.normal_(0.0, 0.02)
    elif cname == nn.BatchNorm2d:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    else:
        print('%s is not initialized.' % cname)


def init_net(net, net_file):
    if net_file:
        net.load_state_dict(torch.load(net_file))
    else:
        net.apply(weights_init)


def create_dir(path):
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as exc:
            if exc.errno != errno.EEXIST:
                raise


class Logger(object):
    def __init__(self, output_name):
        dirname = os.path.dirname(output_name)
        if not os.path.exists(dirname):
            os.mkdir(dirname)

        self.log_file = open(output_name, 'w')
        self.infos = {}

    def append(self, key, val):
        vals = self.infos.setdefault(key, [])
        vals.append(val)

    def log(self, extra_msg=''):
        msgs = [extra_msg]
        for key, vals in self.infos.iteritems():
            msgs.append('%s %.6f' % (key, np.mean(vals)))
        msg = '\n'.join(msgs)
        self.log_file.write(msg + '\n')
        self.log_file.flush()
        self.infos = {}
        return msg

    def write(self, msg):
        self.log_file.write(msg + '\n')
        self.log_file.flush()
        print(msg)
--------------------------------------------------------------------------------
/vqa_debias_loss_functions.py:
--------------------------------------------------------------------------------
from collections import OrderedDict, defaultdict, Counter

from torch import nn
from torch.nn import functional as F
import numpy as np
import torch
import inspect


def convert_sigmoid_logits_to_binary_logprobs(logits):
    """computes log(sigmoid(logits)), log(1-sigmoid(logits))"""
    log_prob = -F.softplus(-logits)
    log_one_minus_prob = -logits + log_prob
    return log_prob, log_one_minus_prob


def elementwise_logsumexp(a, b):
    """computes log(exp(a) + exp(b))"""
    return torch.max(a, b) + torch.log1p(torch.exp(-torch.abs(a - b)))


def renormalize_binary_logits(a, b):
    """Normalize so exp(a) + exp(b) == 1"""
    norm = elementwise_logsumexp(a, b)
    return a - norm, b - norm

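# Identities behind the helpers above: since 1 - sigmoid(x) = sigmoid(-x)
# and log(sigmoid(x)) = -softplus(-x), it follows that
# log(1 - sigmoid(x)) = -x - softplus(-x), which is exactly
# `-logits + log_prob`. Similarly, log(exp(a) + exp(b)) =
# max(a, b) + log1p(exp(-|a - b|)) is the overflow-safe log-sum-exp.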
class DebiasLossFn(nn.Module):
    """General API for our loss functions"""

    def forward(self, hidden, logits, bias, labels):
        """
        :param hidden: [batch, n_hidden] hidden features from the last layer in the model
        :param logits: [batch, n_answers_options] sigmoid logits for each answer option
        :param bias: [batch, n_answers_options]
            bias probabilities for each answer option between 0 and 1
        :param labels: [batch, n_answers_options]
            scores for each answer option, between 0 and 1
        :return: Scalar loss
        """
        raise NotImplementedError()

    def to_json(self):
        """Get a json representation of this loss function.

        We construct this by looking up the __init__ args
        """
        cls = self.__class__
        init = cls.__init__
        if init is object.__init__:
            return []  # No init args

        init_signature = inspect.getargspec(init)
        if init_signature.varargs is not None:
            raise NotImplementedError("varargs not supported")
        if init_signature.keywords is not None:
            raise NotImplementedError("keywords not supported")
        args = [x for x in init_signature.args if x != "self"]
        out = OrderedDict()
        out["name"] = cls.__name__
        for key in args:
            out[key] = getattr(self, key)
        return out


class Plain(DebiasLossFn):
    def forward(self, hidden, logits, bias, labels):
        loss = F.binary_cross_entropy_with_logits(logits, labels)

        loss *= labels.size(1)
        return loss


class Focal(DebiasLossFn):
    def forward(self, hidden, logits, bias, labels):
        # import pdb;pdb.set_trace()
        focal_logits = torch.log(F.softmax(logits, dim=1) + 1e-5) * ((1 - F.softmax(bias, dim=1)) * (1 - F.softmax(bias, dim=1)))
        loss = F.binary_cross_entropy_with_logits(focal_logits, labels)
        loss *= labels.size(1)
        return loss


class ReweightByInvBias(DebiasLossFn):
    def forward(self, hidden, logits, bias, labels):
        # Manually compute the binary cross entropy since the old version of torch always aggregates
        log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
        loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob)
        weights = (1 - bias)
        loss *= weights  # Apply the weights
        return loss.sum() / weights.sum()


class BiasProduct(DebiasLossFn):
    def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0):
        """
        :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
        :param smooth_init: How to initialize `a`
        :param constant_smooth: Constant to add to the bias to smooth it
        """
        super(BiasProduct, self).__init__()
        self.constant_smooth = constant_smooth
        self.smooth_init = smooth_init
        self.smooth = smooth
        if smooth:
            self.smooth_param = torch.nn.Parameter(
                torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
        else:
            self.smooth_param = None

    def forward(self, hidden, logits, bias, labels):
        smooth = self.constant_smooth
        if self.smooth:
            smooth += F.sigmoid(self.smooth_param)

        # Convert the bias into log-space, with a factor for both the
        # binary outputs for each answer option
        bias_lp = torch.log(bias + smooth)
        bias_l_inv = torch.log1p(-bias + smooth)

        # Convert the logits into log-space with the same format
        log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
        # import pdb;pdb.set_trace()

        # Add the bias
        log_prob += bias_lp
        log_one_minus_prob += bias_l_inv

        # Re-normalize the factors in logspace
        log_prob, log_one_minus_prob = renormalize_binary_logits(log_prob, log_one_minus_prob)

        # Compute the binary cross entropy
        loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)
        return loss

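# LearnedMixin mixes the bias into the model's predictions in log space,
# scaled by a learned, example-dependent softplus factor, and adds an
# entropy penalty on the smoothed bias weighted by `w`; this appears to
# follow the LearnedMixin(+H) formulation of Clark et al. (2019).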
class LearnedMixin(DebiasLossFn):
    def __init__(self, w, smooth=True, smooth_init=-1, constant_smooth=0.0):
        """
        :param w: Weight of the entropy penalty
        :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
        :param smooth_init: How to initialize `a`
        :param constant_smooth: Constant to add to the bias to smooth it
        """
        super(LearnedMixin, self).__init__()
        self.w = w
        # self.w = 0
        self.smooth_init = smooth_init
        self.constant_smooth = constant_smooth
        self.bias_lin = torch.nn.Linear(1024, 1)
        self.smooth = smooth
        if self.smooth:
            self.smooth_param = torch.nn.Parameter(
                torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
        else:
            self.smooth_param = None

    def forward(self, hidden, logits, bias, labels):
        factor = self.bias_lin.forward(hidden)  # [batch, 1]
        factor = F.softplus(factor)

        bias = torch.stack([bias, 1 - bias], 2)  # [batch, n_answers, 2]

        # Smooth
        bias += self.constant_smooth
        if self.smooth:
            soften_factor = F.sigmoid(self.smooth_param)
            bias = bias + soften_factor.unsqueeze(1)

        bias = torch.log(bias)  # Convert to logspace

        # Scale by the factor
        # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2]
        bias = bias * factor.unsqueeze(1)

        log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
        log_probs = torch.stack([log_prob, log_one_minus_prob], 2)

        # Add the bias in
        logits = bias + log_probs

        # Renormalize to get log probabilities
        log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1])

        # Compute loss
        loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)

        # Re-normalized version of the bias
        bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1])
        bias_logprob = bias - bias_norm.unsqueeze(2)

        # Compute and add the entropy penalty
        entropy = -(torch.exp(bias_logprob) * bias_logprob).sum(2).mean()
        return loss + self.w * entropy
--------------------------------------------------------------------------------