├── README.md
├── attention.py
├── base_model.py
├── classifier.py
├── dataset.py
├── eval.py
├── fc.py
├── language_model.py
├── main.py
├── tools
│   ├── compute_softscore.py
│   ├── create_dictionary.py
│   ├── create_dictionary_v1.py
│   ├── download.sh
│   ├── download_v1.sh
│   └── process.sh
├── train.py
├── utils.py
└── vqa_debias_loss_functions.py
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering
2 | The source code for our paper [Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering](https://www.aclweb.org/anthology/2020.emnlp-main.265.pdf), published at EMNLP 2020. This repo contains code modified from [CSS-VQA](https://github.com/yanxinzju/CSS-VQA); many thanks for their efforts.
3 |
4 | ### Prerequisites
5 |
6 | Make sure you are on a machine with an NVIDIA GPU and Python 2.7, with about 100 GB of free disk space. The required packages are:
7 | - h5py==2.10.0
8 | - pytorch==1.1.0
9 | - Click==7.0
10 | - numpy==1.16.5
11 | - tqdm==4.35.0
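These can be installed with pip, for example (note that the PyTorch 1.1.0 package is named `torch` on PyPI):

```
pip install h5py==2.10.0 torch==1.1.0 Click==7.0 numpy==1.16.5 tqdm==4.35.0
```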
12 |
13 | ### Data Setup
14 | For all data preprocessing and setup, please refer to [bottom-up-attention-vqa](https://github.com/hengyuan-hu/bottom-up-attention-vqa).
15 |
16 | 1. Please run the following script to download the data:
17 |
18 | ```
19 | bash tools/download.sh
20 | ```
21 |
22 | 2. Please click the link [HERE](https://drive.google.com/drive/folders/13e-b76otJukupbjfC-n1s05L202PaFKQ?usp=sharing) to download the rest of the data, which is kindly shared by [CSS-VQA](https://github.com/yanxinzju/CSS-VQA).
23 |
24 |
25 |
26 | ### Training
27 |
28 | All the arguments for running our code are preset in main.py.
29 |
30 | Run
31 | ```
32 | CUDA_VISIBLE_DEVICES=0 python main.py
33 | ```
34 | to train a model.
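The preset arguments can also be overridden on the command line. For example, using flag values taken from the choices defined in main.py (`my_exp` is an illustrative output name; output directories are created under logs/):

```
CUDA_VISIBLE_DEVICES=0 python main.py --dataset cpv2 --mode q_v_debias_inner_contras --debias learned_mixin --output my_exp
```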
35 |
36 | ### Testing
37 | Run
38 | ```
39 | CUDA_VISIBLE_DEVICES=0 python eval.py --dataset [] --debias [] --model_state []
40 | ```
41 | to evaluate a model.
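For example, to evaluate a checkpoint trained on VQA-CP v2 with the learned_mixin debiasing loss (the model path here is just the default from eval.py):

```
CUDA_VISIBLE_DEVICES=0 python eval.py --dataset cpv2 --debias learned_mixin --model_state logs/exp0/model.pth
```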
42 |
43 | ## Citation
44 |
45 | If you find this paper helpful for your research, please consider citing it in your publications.
46 |
47 | ```BibTeX
48 | @inproceedings{liang2020learning,
49 | title={Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering},
50 | author={Liang, Zujie and Jiang, Weitao and Hu, Haifeng and Zhu, Jiaying},
51 | booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
52 | year={2020}
53 | }
54 | ```
55 |
56 |
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.weight_norm import weight_norm
4 | from fc import FCNet
5 |
6 |
7 | class Attention(nn.Module):
8 | def __init__(self, v_dim, q_dim, num_hid):
9 | super(Attention, self).__init__()
10 | self.nonlinear = FCNet([v_dim + q_dim, num_hid])
11 | self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)
12 |
13 | def forward(self, v, q):
14 | """
15 | v: [batch, k, vdim]
16 | q: [batch, qdim]
17 | """
18 | logits = self.logits(v, q)
19 | w = nn.functional.softmax(logits, 1)
20 | return w
21 |
22 | def logits(self, v, q):
23 | num_objs = v.size(1)
24 | q = q.unsqueeze(1).repeat(1, num_objs, 1)
25 | vq = torch.cat((v, q), 2)
26 | joint_repr = self.nonlinear(vq)
27 | logits = self.linear(joint_repr)
28 | return logits
29 |
30 |
31 | class NewAttention(nn.Module):
32 | def __init__(self, v_dim, q_dim, num_hid, dropout=0.2):
33 | super(NewAttention, self).__init__()
34 |
35 | self.v_proj = FCNet([v_dim, num_hid])
36 | self.q_proj = FCNet([q_dim, num_hid])
37 | self.dropout = nn.Dropout(dropout)
38 | self.linear = weight_norm(nn.Linear(q_dim, 1), dim=None)
39 |
40 | def forward(self, v, q):
41 | """
42 | v: [batch, k, vdim]
43 | q: [batch, qdim]
44 | """
45 | logits = self.logits(v, q)
46 | # w = nn.functional.softmax(logits, 1)
47 | # return w
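# Unlike Attention above, this returns raw logits; the caller (BaseModel) applies softmax or mask_softmax over the objects.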
48 | return logits
49 |
50 | def logits(self, v, q):
51 | batch, k, _ = v.size()
52 | v_proj = self.v_proj(v) # [batch, k, qdim]
53 | q_proj = self.q_proj(q).unsqueeze(1).repeat(1, k, 1)
54 | joint_repr = v_proj * q_proj
55 | joint_repr = self.dropout(joint_repr)
56 | logits = self.linear(joint_repr)
57 | return logits
58 |
--------------------------------------------------------------------------------
/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from attention import Attention, NewAttention
4 | from language_model import WordEmbedding, QuestionEmbedding
5 | from classifier import SimpleClassifier
6 | from fc import FCNet
7 | import numpy as np
8 |
9 | def mask_softmax(x,mask):
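"""Softmax over dim 1 restricted to positions where mask == 1; the epsilon guards against division by zero when a mask is all zeros."""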
10 | mask=mask.unsqueeze(2).float()
11 | x2=torch.exp(x-torch.max(x))
12 | x3=x2*mask
13 | epsilon=1e-5
14 | x3_sum=torch.sum(x3,dim=1,keepdim=True)+epsilon
15 | x4=x3/x3_sum.expand_as(x3)
16 | return x4
17 |
18 |
19 | class BaseModel(nn.Module):
20 | def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier):
21 | super(BaseModel, self).__init__()
22 | self.w_emb = w_emb
23 | self.q_emb = q_emb
24 | self.v_att = v_att
25 | self.q_net = q_net
26 | self.v_net = v_net
27 | self.classifier = classifier
28 | self.debias_loss_fn = None
29 | # self.bias_scale = torch.nn.Parameter(torch.from_numpy(np.ones((1, ), dtype=np.float32)*1.2))
30 | self.bias_lin = torch.nn.Linear(1024, 1)
31 | self.is_contras = None
32 |
33 | def forward(self, v, q, labels, bias,v_mask):
34 | """Forward
35 |
36 | v: [batch, num_objs, obj_dim]; q: [batch_size, seq_length]
37 | labels: soft answer targets [batch, num_ans], or None at inference; bias: per-question-type answer prior [batch, num_ans]
38 | v_mask: binary object mask [batch, num_objs], or None for plain softmax attention
39 |
40 | return: logits, not probs
41 | """
42 | w_emb = self.w_emb(q)
43 | q_emb = self.q_emb(w_emb) # [batch, q_dim]
44 |
45 |
46 | att = self.v_att(v, q_emb)
47 | if v_mask is None:
48 | att = nn.functional.softmax(att, 1)
49 | else:
50 | att= mask_softmax(att,v_mask)
51 |
52 | v_emb = (att * v).sum(1) # [batch, v_dim]
53 |
54 | q_repr = self.q_net(q_emb)
55 | v_repr = self.v_net(v_emb)
56 | joint_repr = q_repr * v_repr
57 |
58 | logits = self.classifier(joint_repr)
59 |
60 | if labels is not None:
61 | loss = self.debias_loss_fn(joint_repr, logits, bias, labels)
62 |
63 | else:
64 | loss = None
65 | if (self.is_contras is True) and (self.training is True):
66 | return logits, loss, w_emb, joint_repr
67 | else:
68 | return logits, loss, w_emb
69 |
70 | def build_baseline0(dataset, num_hid):
71 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
72 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
73 | v_att = Attention(dataset.v_dim, q_emb.num_hid, num_hid)
74 | q_net = FCNet([num_hid, num_hid])
75 | v_net = FCNet([dataset.v_dim, num_hid])
76 | classifier = SimpleClassifier(
77 | num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5)
78 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier)
79 |
80 |
81 | def build_baseline0_newatt(dataset, num_hid):
82 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
83 | q_emb = QuestionEmbedding(300, num_hid, 1, False, 0.0)
84 | v_att = NewAttention(dataset.v_dim, q_emb.num_hid, num_hid)
85 | q_net = FCNet([q_emb.num_hid, num_hid])
86 | v_net = FCNet([dataset.v_dim, num_hid])
87 | classifier = SimpleClassifier(
88 | num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5)
89 | return BaseModel(w_emb, q_emb, v_att, q_net, v_net, classifier)
--------------------------------------------------------------------------------
/classifier.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.utils.weight_norm import weight_norm
3 |
4 |
5 | class SimpleClassifier(nn.Module):
6 | def __init__(self, in_dim, hid_dim, out_dim, dropout):
7 | super(SimpleClassifier, self).__init__()
8 | layers = [
9 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
10 | nn.ReLU(),
11 | nn.Dropout(dropout, inplace=True),
12 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
13 | ]
14 | self.main = nn.Sequential(*layers)
15 |
16 | def forward(self, x):
17 | logits = self.main(x)
18 | return logits
19 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import unicode_literals
3 |
4 | import os
5 | import json
6 | # import cPickle
7 | import pickle as cPickle #python3
8 | from collections import Counter
9 |
10 | import numpy as np
11 | import utils
12 | import h5py
13 | import torch
14 | from torch.utils.data import Dataset
15 | from tqdm import tqdm
16 | from random import choice
17 |
18 | class Dictionary(object):
19 | def __init__(self, word2idx=None, idx2word=None):
20 | if word2idx is None:
21 | word2idx = {}
22 | if idx2word is None:
23 | idx2word = []
24 | self.word2idx = word2idx
25 | self.idx2word = idx2word
26 |
27 | @property
28 | def ntoken(self):
29 | return len(self.word2idx)
30 |
31 | @property
32 | def padding_idx(self):
33 | return len(self.word2idx)
34 |
35 | def tokenize(self, sentence, add_word):
36 | sentence = sentence.lower()
37 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s').replace('-',
38 | ' ').replace('.','').replace('"', '').replace('n\'t', ' not').replace('$', ' dollar ')
39 | words = sentence.split()
40 | tokens = []
41 | if add_word:
42 | for w in words:
43 | tokens.append(self.add_word(w))
44 | else:
45 | for w in words:
46 | if w in self.word2idx:
47 | tokens.append(self.word2idx[w])
48 | else:
49 | tokens.append(len(self.word2idx))
50 | return tokens
51 |
52 | def dump_to_file(self, path):
53 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb'))
54 | print('dictionary dumped to %s' % path)
55 |
56 | @classmethod
57 | def load_from_file(cls, path):
58 | print('loading dictionary from %s' % path)
59 | word2idx, idx2word = cPickle.load(open(path, 'rb'))
60 | d = cls(word2idx, idx2word)
61 | return d
62 |
63 | def add_word(self, word):
64 | if word not in self.word2idx:
65 | self.idx2word.append(word)
66 | self.word2idx[word] = len(self.idx2word) - 1
67 | return self.word2idx[word]
68 |
69 | def __len__(self):
70 | return len(self.idx2word)
71 |
72 |
73 | def _create_entry(img_idx, question, answer):
74 | answer.pop('image_id')
75 | answer.pop('question_id')
76 | entry = {
77 | 'question_id' : question['question_id'],
78 | 'image_id' : question['image_id'],
79 | 'image_idx' : img_idx,
80 | 'question' : question['question'],
81 | 'answer' : answer
82 | }
83 | return entry
84 |
85 |
86 | def _load_dataset(dataroot, name, img_id2val, dataset):
87 | """Load entries
88 |
89 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features
90 | dataroot: root path of dataset
91 | name: 'train', 'val'
92 | """
93 | if dataset=='cpv2':
94 | answer_path = os.path.join(dataroot, 'cp-cache', '%s_target.pkl' % name)
95 | name = "train" if name == "train" else "test"
96 | question_path = os.path.join(dataroot, 'vqacp_v2_%s_questions.json' % name)
97 | with open(question_path) as f:
98 | questions = json.load(f)
99 | elif dataset=='cpv1':
100 | answer_path = os.path.join(dataroot, 'cp-v1-cache', '%s_target.pkl' % name)
101 | name = "train" if name == "train" else "test"
102 | question_path = os.path.join(dataroot, 'vqacp_v1_%s_questions.json' % name)
103 | with open(question_path) as f:
104 | questions = json.load(f)
105 | elif dataset=='v2':
106 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name)
107 | question_path = os.path.join(dataroot, 'v2_OpenEnded_mscoco_%s2014_questions.json' % name)
108 | with open(question_path) as f:
109 | questions = json.load(f)["questions"]
110 |
111 | with open(answer_path, 'rb') as f:
112 | answers = cPickle.load(f)
113 |
114 | questions.sort(key=lambda x: x['question_id'])
115 | answers.sort(key=lambda x: x['question_id'])
116 |
117 | utils.assert_eq(len(questions), len(answers))
118 | entries = []
119 | for question, answer in zip(questions, answers):
120 | if answer["labels"] is None:
121 | raise ValueError()
122 | utils.assert_eq(question['question_id'], answer['question_id'])
123 | utils.assert_eq(question['image_id'], answer['image_id'])
124 | img_id = question['image_id']
125 | img_idx = None
126 | if img_id2val:
127 | img_idx = img_id2val[img_id]
128 |
129 | entries.append(_create_entry(img_idx, question, answer))
130 | return entries
131 |
132 |
133 | class VQAFeatureDataset(Dataset):
134 | def __init__(self, name, dictionary, dataroot='data', dataset='cpv2',
135 | use_hdf5=False, cache_image_features=False):
136 | super(VQAFeatureDataset, self).__init__()
137 | self.name=name
138 | if dataset=='cpv2':
139 | with open('data/train_cpv2_hintscore.json', 'r') as f:
140 | self.train_hintscore = json.load(f)
141 | with open('data/test_cpv2_hintscore.json', 'r') as f:
142 | self.test_hintscore = json.load(f)
143 | with open('util/cpv2_type_mask.json', 'r') as f:
144 | self.type_mask = json.load(f)
145 | with open('util/cpv2_notype_mask.json', 'r') as f:
146 | self.notype_mask = json.load(f)
147 |
148 | elif dataset=='cpv1':
149 | with open('data/train_cpv1_hintscore.json', 'r') as f:
150 | self.train_hintscore = json.load(f)
151 | with open('data/test_cpv1_hintscore.json', 'r') as f:
152 | self.test_hintscore = json.load(f)
153 | with open('util/cpv1_type_mask.json', 'r') as f:
154 | self.type_mask = json.load(f)
155 | with open('util/cpv1_notype_mask.json', 'r') as f:
156 | self.notype_mask = json.load(f)
157 | elif dataset=='v2':
158 | with open('data/train_v2_hintscore.json', 'r') as f:
159 | self.train_hintscore = json.load(f)
160 | with open('data/test_v2_hintscore.json', 'r') as f:
161 | self.test_hintscore = json.load(f)
162 | with open('util/v2_type_mask.json', 'r') as f:
163 | self.type_mask = json.load(f)
164 | with open('util/v2_notype_mask.json', 'r') as f:
165 | self.notype_mask = json.load(f)
166 |
167 | assert name in ['train', 'val']
168 |
169 | if dataset=='cpv2':
170 | ans2label_path = os.path.join(dataroot, 'cp-cache', 'trainval_ans2label.pkl')
171 | label2ans_path = os.path.join(dataroot, 'cp-cache', 'trainval_label2ans.pkl')
172 | elif dataset=='cpv1':
173 | ans2label_path = os.path.join(dataroot, 'cp-v1-cache', 'trainval_ans2label.pkl')
174 | label2ans_path = os.path.join(dataroot, 'cp-v1-cache', 'trainval_label2ans.pkl')
175 | elif dataset=='v2':
176 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl')
177 | label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl')
178 | self.ans2label = cPickle.load(open(ans2label_path, 'rb'))
179 | self.label2ans = cPickle.load(open(label2ans_path, 'rb'))
180 | self.num_ans_candidates = len(self.ans2label)
181 |
182 | self.dictionary = dictionary
183 | self.use_hdf5 = use_hdf5
184 |
185 | if use_hdf5:
186 | h5_path = os.path.join(dataroot, '%s36.hdf5'%name)
187 | self.hf = h5py.File(h5_path, 'r')
188 | self.features = self.hf.get('image_features')
189 |
190 | with open("util/%s36_imgid2img.pkl"%name, "rb") as f:
191 | imgid2idx = cPickle.load(f)
192 | else:
193 | imgid2idx = None
194 |
195 | self.entries = _load_dataset(dataroot, name, imgid2idx, dataset=dataset)
196 | if cache_image_features is True:
197 | image_to_fe = {}
198 | for entry in tqdm(self.entries, ncols=100, desc="caching-features"):
199 | img_id = entry["image_id"]
200 | if img_id not in image_to_fe:
201 | if use_hdf5:
202 | fe = np.array(self.features[imgid2idx[img_id]])
203 | else:
204 | # fe=torch.load('data/rcnn_feature/'+str(img_id)+'.pth')['image_feature']
205 | fe = np.fromfile("data/trainval_features/" + str(img_id) + ".bin", np.float32)
206 | fe = torch.from_numpy(fe).view(36, 2048)
207 | image_to_fe[img_id]=fe
208 | self.image_to_fe = image_to_fe
209 | if use_hdf5:
210 | self.hf.close()
211 | else:
212 | self.image_to_fe = None
213 |
214 | self.tokenize()
215 | self.tensorize()
216 |
217 | self.v_dim = 2048
218 |
219 | def tokenize(self, max_length=14):
220 | """Tokenizes the questions.
221 |
222 | This will add q_token (and q_token_mask) to each entry of the dataset.
223 | Padded positions use dictionary.padding_idx, which maps to the padding vector in the embedding.
224 | """
225 | for entry in tqdm(self.entries, ncols=100, desc="tokenize"):
226 | tokens = self.dictionary.tokenize(entry['question'], False)
227 | tokens = tokens_mask = tokens[:max_length]  # tokens_mask falls back to tokens when no padding is needed
228 | if len(tokens) < max_length:
229 | # Note here we pad in front of the sentence
230 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens))
231 | padding_mask=[self.dictionary.padding_idx-1] * (max_length - len(tokens))
232 | tokens_mask = padding_mask + tokens
233 | tokens = padding + tokens
234 |
235 | utils.assert_eq(len(tokens), max_length)
236 | entry['q_token'] = tokens
237 | entry['q_token_mask']=tokens_mask
238 |
239 | def tensorize(self):
240 | for entry in tqdm(self.entries, ncols=100, desc="tensorize"):
241 | question = torch.from_numpy(np.array(entry['q_token']))
242 | question_mask = torch.from_numpy(np.array(entry['q_token_mask']))
243 |
244 | entry['q_token'] = question
245 | entry['q_token_mask']=question_mask
246 |
247 | answer = entry['answer']
248 | labels = np.array(answer['labels'])
249 | scores = np.array(answer['scores'], dtype=np.float32)
250 | if len(labels):
251 | labels = torch.from_numpy(labels)
252 | scores = torch.from_numpy(scores)
253 | entry['answer']['labels'] = labels
254 | entry['answer']['scores'] = scores
255 | else:
256 | entry['answer']['labels'] = None
257 | entry['answer']['scores'] = None
258 |
259 | def __getitem__(self, index):
260 | entry = self.entries[index]
261 | if self.image_to_fe is not None:
262 | features = self.image_to_fe[entry["image_id"]]
263 | elif self.use_hdf5:
264 | features = np.array(self.features[entry['image_idx']])
265 | features = torch.from_numpy(features).view(36, 2048)
266 | else:
267 | # features = torch.load('data/rcnn_feature/' + str(entry["image_id"]) + '.pth')['image_feature']
268 | features = np.fromfile("/data/trainval_features/" + str(entry["image_id"]) + ".bin", np.float32)
269 | features = torch.from_numpy(features).view(36, 2048)
270 |
271 | q_id=entry['question_id']
272 | ques = entry['q_token']
273 | ques_mask=entry['q_token_mask']
274 | answer = entry['answer']
275 | labels = answer['labels']
276 | scores = answer['scores']
277 | target = torch.zeros(self.num_ans_candidates)
278 | if labels is not None:
279 | target.scatter_(0, labels, scores)
280 |
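# Training items additionally return hint scores and question-type masks; eval items return the question id and test hint scores instead.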
281 | if self.name=='train':
282 | train_hint=torch.tensor(self.train_hintscore[str(q_id)])
283 | type_mask=torch.tensor(self.type_mask[str(q_id)])
284 | notype_mask=torch.tensor(self.notype_mask[str(q_id)])
285 | if "bias" in entry:
286 | return features, ques, target,entry["bias"],train_hint,type_mask,notype_mask,ques_mask
287 |
288 | else:
289 | return features, ques,target, 0,train_hint
290 | else:
291 | test_hint=torch.tensor(self.test_hintscore[str(q_id)])
292 | if "bias" in entry:
293 | return features, ques, target, entry["bias"],q_id,test_hint
294 | else:
295 | return features, ques, target, 0,q_id,test_hint
296 |
297 | def __len__(self):
298 | return len(self.entries)
299 |
300 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | # import cPickle
4 | import pickle as cPickle
5 | from collections import defaultdict, Counter
6 | from os.path import dirname, join
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | import numpy as np
12 | import os
13 |
14 | # from new_dataset import Dictionary, VQAFeatureDataset
15 | from dataset import Dictionary, VQAFeatureDataset
16 | import base_model
17 | from train import train
18 | import utils
19 |
20 | from vqa_debias_loss_functions import *
21 | from tqdm import tqdm
22 | from torch.autograd import Variable
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method")
27 |
28 | # Arguments we added
29 | parser.add_argument(
30 | '--cache_features', default=True,
31 | help="Cache image features in RAM. Makes things much faster, "
32 | "especially if the filesystem is slow, but requires at least 48gb of RAM")
33 | parser.add_argument(
34 | '--dataset', default='cpv2', choices=["v2", "cpv2", "cpv1"], help="Which dataset to evaluate on")
35 | parser.add_argument(
36 | '-p', "--entropy_penalty", default=0.36, type=float,
37 | help="Entropy regularizer weight for the learned_mixin model")
38 | parser.add_argument(
39 | '--debias', default="learned_mixin",
40 | choices=["learned_mixin", "reweight", "bias_product", "none"],
41 | help="Kind of ensemble loss to use")
42 | # Arguments from the original model; we leave these at their defaults
44 | parser.add_argument('--num_hid', type=int, default=1024)
45 | parser.add_argument('--model', type=str, default='baseline0_newatt')
46 | parser.add_argument('--batch_size', type=int, default=512)
47 | parser.add_argument('--seed', type=int, default=1111, help='random seed')
48 | parser.add_argument('--model_state', type=str, default='logs/exp0/model.pth')
49 | args = parser.parse_args()
50 | return args
51 |
52 | def compute_score_with_logits(logits, labels):
53 | # logits = torch.max(logits, 1)[1].data # argmax
54 | logits = torch.argmax(logits,1)
55 | one_hots = torch.zeros(*labels.size()).cuda()
56 | one_hots.scatter_(1, logits.view(-1, 1), 1)
57 | scores = (one_hots * labels)
58 | return scores
59 |
60 |
61 | def evaluate(model,dataloader,qid2type):
62 | score = 0
63 | upper_bound = 0
64 | score_yesno = 0
65 | score_number = 0
66 | score_other = 0
67 | total_yesno = 0
68 | total_number = 0
69 | total_other = 0
70 | model.train(False)
71 | # import pdb;pdb.set_trace()
72 |
73 |
74 | for v, q, a, b,qids,hintscore in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
75 | v = Variable(v, requires_grad=False).cuda()
76 | q = Variable(q, requires_grad=False).cuda()
77 | pred, _ ,_= model(v, q, None, None,None)
78 | batch_score= compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1)
79 | score += batch_score.sum()
80 | upper_bound += (a.max(1)[0]).sum()
81 | qids = qids.detach().cpu().int().numpy()
82 | for j in range(len(qids)):
83 | qid=qids[j]
84 | typ = qid2type[str(qid)]
85 | if typ == 'yes/no':
86 | score_yesno += batch_score[j]
87 | total_yesno += 1
88 | elif typ == 'other':
89 | score_other += batch_score[j]
90 | total_other += 1
91 | elif typ == 'number':
92 | score_number += batch_score[j]
93 | total_number += 1
94 | else:
95 | print('Unknown question type: %s' % typ)
96 | score = score / len(dataloader.dataset)
97 | upper_bound = upper_bound / len(dataloader.dataset)
98 | score_yesno /= total_yesno
99 | score_other /= total_other
100 | score_number /= total_number
101 | print('\teval overall score: %.2f' % (100 * score))
102 | print('\teval up_bound score: %.2f' % (100 * upper_bound))
103 | print('\teval y/n score: %.2f' % (100 * score_yesno))
104 | print('\teval other score: %.2f' % (100 * score_other))
105 | print('\teval number score: %.2f' % (100 * score_number))
106 |
107 |
108 | def main():
109 | args = parse_args()
110 | dataset = args.dataset
111 |
112 |
113 | with open('util/qid2type_%s.json'%args.dataset,'r') as f:
114 | qid2type=json.load(f)
115 |
116 | if dataset=='cpv1':
117 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl')
118 | elif dataset=='cpv2' or dataset=='v2':
119 | dictionary = Dictionary.load_from_file('data/dictionary.pkl')
120 |
121 | print("Building test dataset...")
122 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset,
123 | cache_image_features=args.cache_features)
124 |
125 | # Build the model using the original constructor
126 | constructor = 'build_%s' % args.model
127 | model = getattr(base_model, constructor)(eval_dset, args.num_hid).cuda()
128 |
129 | if args.debias == "bias_product":
130 | model.debias_loss_fn = BiasProduct()
131 | elif args.debias == "none":
132 | model.debias_loss_fn = Plain()
133 | elif args.debias == "reweight":
134 | model.debias_loss_fn = ReweightByInvBias()
135 | elif args.debias == "learned_mixin":
136 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty)
137 | else:
138 | raise RuntimeError("Unknown debias mode: %s" % args.debias)
139 |
140 |
141 | model_state = torch.load(args.model_state)
142 | model.load_state_dict(model_state)
143 |
144 |
145 | model = model.cuda()
146 | batch_size = args.batch_size
147 |
148 | torch.manual_seed(args.seed)
149 | torch.cuda.manual_seed(args.seed)
150 | torch.backends.cudnn.benchmark = True
151 |
152 | # The original version uses multiple workers, but that just seems slower on my setup
153 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0)
154 |
155 |
156 |
157 | print("Starting eval...")
158 |
159 | evaluate(model,eval_loader,qid2type)
160 |
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
--------------------------------------------------------------------------------
/fc.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.nn as nn
3 | from torch.nn.utils.weight_norm import weight_norm
4 |
5 |
6 | class FCNet(nn.Module):
7 | """Simple class for non-linear fully connect network
8 | """
9 | def __init__(self, dims):
10 | super(FCNet, self).__init__()
11 |
12 | layers = []
13 | for i in range(len(dims)-2):
14 | in_dim = dims[i]
15 | out_dim = dims[i+1]
16 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
17 | layers.append(nn.ReLU())
18 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
19 | layers.append(nn.ReLU())
20 |
21 | self.main = nn.Sequential(*layers)
22 |
23 | def forward(self, x):
24 | return self.main(x)
25 |
26 |
27 | if __name__ == '__main__':
28 | fc1 = FCNet([10, 20, 10])
29 | print(fc1)
30 |
31 | print('============')
32 | fc2 = FCNet([10, 20])
33 | print(fc2)
34 |
--------------------------------------------------------------------------------
/language_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 |
6 |
7 | class WordEmbedding(nn.Module):
8 | """Word Embedding
9 |
10 | The ntoken-th index is used as padding_idx, which agrees *implicitly*
11 | with the definition in Dictionary.
12 | """
13 | def __init__(self, ntoken, emb_dim, dropout):
14 | super(WordEmbedding, self).__init__()
15 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken)
16 | self.dropout = nn.Dropout(dropout)
17 | self.ntoken = ntoken
18 | self.emb_dim = emb_dim
19 |
20 | def init_embedding(self, np_file):
21 | weight_init = torch.from_numpy(np.load(np_file))
22 | assert weight_init.shape == (self.ntoken, self.emb_dim)
23 | self.emb.weight.data[:self.ntoken] = weight_init
24 |
25 | def forward(self, x):
26 | emb = self.emb(x)
27 | emb = self.dropout(emb)
28 | return emb
29 |
30 |
31 | class QuestionEmbedding(nn.Module):
32 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'):
33 | """Module for question embedding
34 | """
35 | super(QuestionEmbedding, self).__init__()
36 | assert rnn_type == 'LSTM' or rnn_type == 'GRU'
37 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
38 |
39 | self.rnn = rnn_cls(
40 | in_dim, num_hid, nlayers,
41 | bidirectional=bidirect,
42 | dropout=dropout,
43 | batch_first=True)
44 |
45 | self.in_dim = in_dim
46 | self.num_hid = num_hid
47 | self.nlayers = nlayers
48 | self.rnn_type = rnn_type
49 | self.ndirections = 1 + int(bidirect)
50 |
51 | def init_hidden(self, batch):
52 | # just to get the type of tensor
53 | weight = next(self.parameters()).data
54 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
55 | if self.rnn_type == 'LSTM':
56 | return (Variable(weight.new(*hid_shape).zero_()),
57 | Variable(weight.new(*hid_shape).zero_()))
58 | else:
59 | return Variable(weight.new(*hid_shape).zero_())
60 |
61 | def forward(self, x):
62 | # x: [batch, sequence, in_dim]
63 | batch = x.size(0)
64 | hidden = self.init_hidden(batch)
65 | self.rnn.flatten_parameters()
66 | output, hidden = self.rnn(x, hidden)
67 |
68 | if self.ndirections == 1:
69 | return output[:, -1]
70 |
71 | forward_ = output[:, -1, :self.num_hid]
72 | backward = output[:, 0, self.num_hid:]
73 | return torch.cat((forward_, backward), dim=1)
74 |
75 | def forward_all(self, x):
76 | # x: [batch, sequence, in_dim]
77 | batch = x.size(0)
78 | hidden = self.init_hidden(batch)
79 | self.rnn.flatten_parameters()
80 | output, hidden = self.rnn(x, hidden)
81 | return output
82 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | # import cPickle as pickle
4 | import pickle as cPickle # python3
5 | from collections import defaultdict, Counter
6 | from os.path import dirname, join
7 | import os
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import DataLoader
12 | import numpy as np
13 |
14 | from dataset import Dictionary, VQAFeatureDataset
15 | import base_model
16 | from train import train
17 | import utils
18 | import click
19 |
20 | from vqa_debias_loss_functions import *
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser("Train the BottomUpTopDown model with a de-biasing method")
25 |
26 | # Arguments we added
27 | parser.add_argument(
28 | '--cache_features', default=True, #True
29 | help="Cache image features in RAM. Makes things much faster, "
30 | "especially if the filesystem is slow, but requires at least 48gb of RAM")
31 | parser.add_argument(
32 | '--dataset', default='cpv2',
33 | choices=["v2", "cpv2", "cpv1"],
34 | help="Run on VQA-2.0 instead of VQA-CP 2.0"
35 | )
36 | parser.add_argument(
37 | '-p', "--entropy_penalty", default=0.36, type=float,
38 | help="Entropy regularizer weight for the learned_mixin model")
39 | parser.add_argument(
40 | '--mode', default="q_v_debias_inner_contras", #updn
41 | choices=["updn", "q_debias","v_debias","q_v_debias", "v_debias_contras", "v_debias_ori_contras_eucl", "v_debias_ori_contras_cos", "q_v_debias_ori_contras_cos", "q_v_debias_inner_contras"],
42 | help="Kind of ensemble loss to use")
43 | parser.add_argument(
44 | '--debias', default="learned_mixin", # learned_mixin
45 | choices=["learned_mixin", "reweight", "bias_product", "none",'focal'],
46 | help="Kind of ensemble loss to use")
47 | parser.add_argument(
48 | '--topq', type=int,default=1,
49 | choices=[1,2,3],
50 | help="num of words to be masked in questio")
51 | parser.add_argument(
52 | '--keep_qtype', default=True,
53 | help="keep qtype or not")
54 | parser.add_argument(
55 | '--topv', type=int,default=1,
56 | choices=[1,3,5,-1],
57 | help="num of object bbox to be masked in image")
58 | parser.add_argument(
59 | '--top_hint',type=int, default=9,
60 | choices=[9,18,27,36],
61 | help="num of hint")
62 | parser.add_argument(
63 | '--qvp', type=int,default=5,
64 | choices=[0,1,2,3,4,5,6,7,8,9,10],
65 | help="ratio of q_bias and v_bias")
66 | parser.add_argument(
67 | '--eval_each_epoch', default=True,
68 | help="Evaluate every epoch, instead of at the end")
69 |
70 | # Arguments from the original model, we leave this default, except we
71 | # set --epochs to 30 since the model maxes out its performance on VQA 2.0 well before then
72 | parser.add_argument('--margin', type=float, default=0.3,
73 | help="Margin of the original contrastive loss")
74 | parser.add_argument('--contras_loss_weight', type=int, default=2)
75 | parser.add_argument('--epochs', type=int, default=30)
76 | parser.add_argument('--num_hid', type=int, default=1024)
77 | parser.add_argument('--model', type=str, default='baseline0_newatt')
78 | parser.add_argument('--output', type=str, default='q_v_debias_contras')  # joined under 'logs/' in main()
79 | parser.add_argument('--batch_size', type=int, default=512)
80 | parser.add_argument('--seed', type=int, default=0, help='random seed')
81 | args = parser.parse_args()
82 | return args
83 |
84 | def get_bias(train_dset,eval_dset):
85 | # Compute the bias:
86 | # The bias here is just the expected score for each answer/question type
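# e.g. if the answer "white" has an average soft score of 0.5 over all "what color" training questions, its bias is 0.5 for every "what color" question.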
87 | answer_voc_size = train_dset.num_ans_candidates
88 |
89 | # question_type -> answer -> total score
90 | question_type_to_probs = defaultdict(Counter)
91 |
92 | # question_type -> num_occurrences
93 | question_type_to_count = Counter()
94 | for ex in train_dset.entries:
95 | ans = ex["answer"]
96 | q_type = ans["question_type"]
97 | question_type_to_count[q_type] += 1
98 | if ans["labels"] is not None:
99 | for label, score in zip(ans["labels"], ans["scores"]):
100 | question_type_to_probs[q_type][label] += score
101 | question_type_to_prob_array = {}
102 |
103 | for q_type, count in question_type_to_count.items():
104 | prob_array = np.zeros(answer_voc_size, np.float32)
105 | for label, total_score in question_type_to_probs[q_type].items():
106 | prob_array[label] += total_score
107 | prob_array /= count
108 | question_type_to_prob_array[q_type] = prob_array
109 |
110 | for ds in [train_dset,eval_dset]:
111 | for ex in ds.entries:
112 | q_type = ex["answer"]["question_type"]
113 | ex["bias"] = question_type_to_prob_array[q_type]
114 |
115 |
116 | def main():
117 | args = parse_args()
118 | dataset=args.dataset
119 | args.output=os.path.join('logs',args.output)
120 | if not os.path.isdir(args.output):
121 | utils.create_dir(args.output)
122 | else:
123 | if click.confirm('Exp directory already exists in {}. Erase?'
124 | .format(args.output), default=False):
125 | os.system('rm -r ' + args.output)
126 | utils.create_dir(args.output)
127 |
128 | else:
129 | os._exit(1)
130 |
131 |
132 |
133 | if dataset=='cpv1':
134 | dictionary = Dictionary.load_from_file('data/dictionary_v1.pkl')
135 | elif dataset=='cpv2' or dataset=='v2':
136 | dictionary = Dictionary.load_from_file('data/dictionary.pkl')
137 |
138 | print('margin of contrastive loss:', args.margin)
139 |
140 | print("Building train dataset...")
141 | train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset,
142 | cache_image_features=args.cache_features)
143 |
144 | print("Building test dataset...")
145 | eval_dset = VQAFeatureDataset('val', dictionary, dataset=dataset,
146 | cache_image_features=args.cache_features)
147 |
148 | get_bias(train_dset,eval_dset)
149 |
150 |
151 | # Build the model using the original constructor
152 | constructor = 'build_%s' % args.model
153 | model = getattr(base_model, constructor)(train_dset, args.num_hid).cuda()
154 | if dataset=='cpv1':
155 | model.w_emb.init_embedding('data/glove6b_init_300d_v1.npy')
156 | elif dataset=='cpv2' or dataset=='v2':
157 | model.w_emb.init_embedding('data/glove6b_init_300d.npy')
158 |
159 | # Add the loss_fn based our arguments
160 | if args.debias == "bias_product":
161 | model.debias_loss_fn = BiasProduct()
162 | elif args.debias == "none":
163 | model.debias_loss_fn = Plain()
164 | elif args.debias == "reweight":
165 | model.debias_loss_fn = ReweightByInvBias()
166 | elif args.debias == "learned_mixin":
167 | model.debias_loss_fn = LearnedMixin(args.entropy_penalty)
168 | elif args.debias=='focal':
169 | model.debias_loss_fn = Focal()
170 | else:
171 | raise RuntimeError("Unknown debias mode: %s" % args.debias)
172 |
173 | if args.mode =='v_debias_ori_contras_cos' or args.mode=='q_v_debias_ori_contras_cos' or args.mode=='q_v_debias_inner_contras':
174 | model.is_contras = True
175 |
176 |
177 | with open('util/qid2type_%s.json'%args.dataset,'r') as f:
178 | qid2type=json.load(f)
179 |
180 | model=model.cuda()
181 | batch_size = args.batch_size
182 |
183 | torch.manual_seed(args.seed)
184 | torch.cuda.manual_seed(args.seed)
185 | torch.backends.cudnn.benchmark = True
186 |
187 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=0)
188 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=0)
189 |
190 | print("Starting training...")
191 | train(model, train_loader, eval_loader, args,qid2type)
192 |
193 | if __name__ == '__main__':
194 | main()
195 |
196 |
197 |
198 |
199 |
200 |
201 |
--------------------------------------------------------------------------------
/tools/compute_softscore.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import sys
6 | import json
7 | import numpy as np
8 | import re
9 | import pickle as cPickle  # python3
10 |
11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12 | from dataset import Dictionary
13 | import utils
14 |
15 |
16 | contractions = {
17 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
18 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
19 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
20 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
21 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
22 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
23 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
24 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
25 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
26 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
27 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
28 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
29 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
30 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
31 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
32 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
33 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
34 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
35 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
36 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
37 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
38 | "someoned've": "someone'd've", "someone'dve": "someone'd've",
39 | "someonell": "someone'll", "someones": "someone's", "somethingd":
40 | "something'd", "somethingd've": "something'd've", "something'dve":
41 | "something'd've", "somethingll": "something'll", "thats":
42 | "that's", "thered": "there'd", "thered've": "there'd've",
43 | "there'dve": "there'd've", "therere": "there're", "theres":
44 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
45 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
46 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
47 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
48 | "weren't", "whatll": "what'll", "whatre": "what're", "whats":
49 | "what's", "whatve": "what've", "whens": "when's", "whered":
50 | "where'd", "wheres": "where's", "whereve": "where've", "whod":
51 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
52 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
53 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
54 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
55 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
56 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
57 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
58 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
59 | "you'll", "youre": "you're", "youve": "you've"
60 | }
61 |
62 | manual_map = { 'none': '0',
63 | 'zero': '0',
64 | 'one': '1',
65 | 'two': '2',
66 | 'three': '3',
67 | 'four': '4',
68 | 'five': '5',
69 | 'six': '6',
70 | 'seven': '7',
71 | 'eight': '8',
72 | 'nine': '9',
73 | 'ten': '10'}
74 | articles = ['a', 'an', 'the']
75 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)")
76 | comma_strip = re.compile("(\d)(\,)(\d)")
77 | punct = [';', r"/", '[', ']', '"', '{', '}',
78 | '(', ')', '=', '+', '\\', '_', '-',
79 | '>', '<', '@', '`', ',', '?', '!']
80 |
81 |
82 | def get_score(occurences):
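# Soft target score in the style of the VQA accuracy metric: the more annotators gave an answer, the higher its score, capped at 1.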
83 | if occurences == 0:
84 | return 0
85 | elif occurences == 1:
86 | return 0.3
87 | elif occurences == 2:
88 | return 0.6
89 | elif occurences == 3:
90 | return 0.9
91 | else:
92 | return 1
93 |
94 |
95 | def process_punctuation(inText):
96 | outText = inText
97 | for p in punct:
98 | if (p + ' ' in inText or ' ' + p in inText) \
99 | or (re.search(comma_strip, inText) != None):
100 | outText = outText.replace(p, '')
101 | else:
102 | outText = outText.replace(p, ' ')
103 | outText = period_strip.sub("", outText, re.UNICODE)
104 | return outText
105 |
106 |
107 | def process_digit_article(inText):
108 | outText = []
109 | tempText = inText.lower().split()
110 | for word in tempText:
111 | word = manual_map.setdefault(word, word)
112 | if word not in articles:
113 | outText.append(word)
114 | else:
115 | pass
116 | for wordId, word in enumerate(outText):
117 | if word in contractions:
118 | outText[wordId] = contractions[word]
119 | outText = ' '.join(outText)
120 | return outText
121 |
122 |
123 | def multiple_replace(text, wordDict):
124 | for key in wordDict:
125 | text = text.replace(key, wordDict[key])
126 | return text
127 |
128 |
129 | def preprocess_answer(answer):
130 | answer = process_digit_article(process_punctuation(answer))
131 | answer = answer.replace(',', '')
132 | return answer
133 |
134 |
135 | def filter_answers(answers_dset, min_occurence):
136 | """This will change the answer to preprocessed version
137 | """
138 | occurence = {}
139 | for ans_entry in answers_dset:
140 | gtruth = ans_entry['multiple_choice_answer']
141 | gtruth = preprocess_answer(gtruth)
142 | if gtruth not in occurence:
143 | occurence[gtruth] = set()
144 | occurence[gtruth].add(ans_entry['question_id'])
145 | for answer in list(occurence.keys()):  # copy the keys so we can pop while iterating (python3-safe)
146 | if len(occurence[answer]) < min_occurence:
147 | occurence.pop(answer)
148 |
149 | print('Num of answers that appear >= %d times: %d' % (
150 | min_occurence, len(occurence)))
151 | return occurence
152 |
153 |
154 |
155 |
156 |
157 |
158 | def create_ans2label(occurence, name, cache_root):
159 | """Note that this will also create label2ans.pkl at the same time
160 |
161 | occurence: dict {answer -> whatever}
162 | name: prefix of the output file
163 | cache_root: str
164 | """
165 | ans2label = {}
166 | label2ans = []
167 | label = 0
168 | for answer in occurence:
169 | label2ans.append(answer)
170 | ans2label[answer] = label
171 | label += 1
172 |
173 | utils.create_dir(cache_root)
174 |
175 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
176 | cPickle.dump(ans2label, open(cache_file, 'wb'))
177 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
178 | cPickle.dump(label2ans, open(cache_file, 'wb'))
179 | return ans2label
180 |
181 |
182 | def compute_target(answers_dset, ans2label, name, cache_root):
183 | """Augment answers_dset with soft score as label
184 |
185 | ***answers_dset should be preprocessed***
186 |
187 | Write result into a cache file
188 | """
189 | target = []
190 | for ans_entry in answers_dset:
191 | answers = ans_entry['answers']
192 | answer_count = {}
193 | for answer in answers:
194 | answer_ = answer['answer']
195 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
196 |
197 | labels = []
198 | scores = []
199 | for answer in answer_count:
200 | if answer not in ans2label:
201 | continue
202 | labels.append(ans2label[answer])
203 | score = get_score(answer_count[answer])
204 | scores.append(score)
205 |
206 | label_counts = {}
207 | for k, v in answer_count.items():
208 | if k in ans2label:
209 | label_counts[ans2label[k]] = v
210 |
211 | target.append({
212 | 'question_id': ans_entry['question_id'],
213 | 'question_type': ans_entry['question_type'],
214 | 'image_id': ans_entry['image_id'],
215 | 'label_counts': label_counts,
216 | 'labels': labels,
217 | 'scores': scores
218 | })
219 |
220 | print(cache_root)
221 | utils.create_dir(cache_root)
222 | cache_file = os.path.join(cache_root, name+'_target.pkl')
223 | print(cache_file)
224 | with open(cache_file, 'wb') as f:
225 | cPickle.dump(target, f)
226 | return target
227 |
228 |
229 |
230 | def get_answer(qid, answers):
231 | for ans in answers:
232 | if ans['question_id'] == qid:
233 | return ans
234 |
235 |
236 | def get_question(qid, questions):
237 | for question in questions:
238 | if question['question_id'] == qid:
239 | return question
240 |
241 |
242 | def load_cp():
243 | train_answer_file = "data/vqacp_v2_train_annotations.json"
244 | with open(train_answer_file) as f:
245 | train_answers = json.load(f) # ['annotations']
246 |
247 | val_answer_file = "data/vqacp_v2_test_annotations.json"
248 | with open(val_answer_file) as f:
249 | val_answers = json.load(f) # ['annotations']
250 |
251 | occurence = filter_answers(train_answers, 9)
252 | ans2label = create_ans2label(occurence, 'trainval', "data/cp-cache")
253 | compute_target(train_answers, ans2label, 'train', "data/cp-cache")
254 | compute_target(val_answers, ans2label, 'val', "data/cp-cache")
255 |
256 | def load_cp_v1():
257 | train_answer_file = "data/vqacp_v1_train_annotations.json"
258 | with open(train_answer_file) as f:
259 | train_answers = json.load(f) # ['annotations']
260 |
261 | val_answer_file = "data/vqacp_v1_test_annotations.json"
262 | with open(val_answer_file) as f:
263 | val_answers = json.load(f) # ['annotations']
264 |
265 | occurence = filter_answers(train_answers, 9)
266 | ans2label = create_ans2label(occurence, 'trainval', "data/cp-v1-cache")
267 | compute_target(train_answers, ans2label, 'train', "data/cp-v1-cache")
268 | compute_target(val_answers, ans2label, 'val', "data/cp-v1-cache")
269 |
270 |
271 | def load_v2():
272 | train_answer_file = 'data/v2_mscoco_train2014_annotations.json'
273 | with open(train_answer_file) as f:
274 | train_answers = json.load(f)['annotations']
275 |
276 | val_answer_file = 'data/v2_mscoco_val2014_annotations.json'
277 | with open(val_answer_file) as f:
278 | val_answers = json.load(f)['annotations']
279 |
280 | occurence = filter_answers(train_answers, 9)
281 | ans2label = create_ans2label(occurence, 'trainval', "data/cache")
282 | compute_target(train_answers, ans2label, 'train', "data/cache")
283 | compute_target(val_answers, ans2label, 'val', "data/cache")
284 |
285 |
286 | def main():
287 | parser = argparse.ArgumentParser("Dataset preprocessing")
288 | parser.add_argument("dataset", choices=["cp_v2", "v2",'cp_v1'])
289 | args = parser.parse_args()
290 | if args.dataset == "v2":
291 | load_v2()
292 | elif args.dataset == "cp_v1":
293 | load_cp_v1()
294 | elif args.dataset=='cp_v2':
295 | load_cp()
296 |
297 |
298 |
299 | if __name__ == '__main__':
300 | main()
301 |
--------------------------------------------------------------------------------
/tools/create_dictionary.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import numpy as np
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | from dataset import Dictionary
8 |
9 |
10 |
11 |
12 | def create_dictionary(dataroot):
13 | dictionary = Dictionary()
14 | questions = []
15 | files = [
16 | 'v2_OpenEnded_mscoco_train2014_questions.json',
17 | 'v2_OpenEnded_mscoco_val2014_questions.json',
18 | 'v2_OpenEnded_mscoco_test2015_questions.json',
19 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json'
20 | ]
21 | for path in files:
22 | question_path = os.path.join(dataroot, path)
23 | qs = json.load(open(question_path))['questions']
24 | for q in qs:
25 | dictionary.tokenize(q['question'], True)
26 | dictionary.tokenize('wordmask',True)
27 | return dictionary
28 |
29 |
30 | def create_glove_embedding_init(idx2word, glove_file):
31 | word2emb = {}
32 | with open(glove_file, 'r') as f:
33 | entries = f.readlines()
34 | emb_dim = len(entries[0].split(' ')) - 1
35 | print('embedding dim is %d' % emb_dim)
36 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
37 |
38 | for entry in entries:
39 | vals = entry.split(' ')
40 | word = vals[0]
41 | vals = list(map(float, vals[1:]))  # list() so np.array gets concrete values under python3
42 | word2emb[word] = np.array(vals)
43 | for idx, word in enumerate(idx2word):
44 | if word not in word2emb:
45 | continue
46 | weights[idx] = word2emb[word]
47 | return weights, word2emb
48 |
49 |
50 | if __name__ == '__main__':
51 | d = create_dictionary('data')
52 | d.dump_to_file('data/dictionary.pkl')
53 |
54 | d = Dictionary.load_from_file('data/dictionary.pkl')
55 | emb_dim = 300
56 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim
57 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file)
58 | np.save('data/glove6b_init_%dd.npy' % emb_dim, weights)
59 |
--------------------------------------------------------------------------------
/tools/create_dictionary_v1.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import json
5 | import numpy as np
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | from dataset import Dictionary
8 |
9 |
10 | def create_dictionary(dataroot):
11 | dictionary = Dictionary()
12 | questions = []
13 | files = [
14 | 'OpenEnded_mscoco_train2014_questions.json',
15 | 'OpenEnded_mscoco_val2014_questions.json',
16 | 'OpenEnded_mscoco_test2015_questions.json',
17 | 'OpenEnded_mscoco_test-dev2015_questions.json'
18 | ]
19 | for path in files:
20 | question_path = os.path.join(dataroot, path)
21 | qs = json.load(open(question_path))['questions']
22 | for q in qs:
23 | dictionary.tokenize(q['question'], True)
24 | return dictionary
25 |
26 |
27 | def create_glove_embedding_init(idx2word, glove_file):
28 | word2emb = {}
29 | with open(glove_file, 'r') as f:
30 | entries = f.readlines()
31 | emb_dim = len(entries[0].split(' ')) - 1
32 | print('embedding dim is %d' % emb_dim)
33 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32)
34 |
35 | for entry in entries:
36 | vals = entry.split(' ')
37 | word = vals[0]
38 | vals = list(map(float, vals[1:]))  # list() so np.array gets concrete values under python3
39 | word2emb[word] = np.array(vals)
40 | for idx, word in enumerate(idx2word):
41 | if word not in word2emb:
42 | continue
43 | weights[idx] = word2emb[word]
44 | return weights, word2emb
45 |
46 |
47 | if __name__ == '__main__':
48 | d = create_dictionary('data')
49 | d.dump_to_file('data/dictionary_v1.pkl')
50 |
51 | d = Dictionary.load_from_file('data/dictionary_v1.pkl')
52 | emb_dim = 300
53 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim
54 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file)
55 | np.save('data/glove6b_init_%dd_v1.npy' % emb_dim, weights)
56 |
--------------------------------------------------------------------------------
/tools/download.sh:
--------------------------------------------------------------------------------
1 | ## Script for downloading data
2 |
3 | # GloVe Vectors
4 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip
5 | unzip data/glove.6B.zip -d data/glove
6 | rm data/glove.6B.zip
7 |
8 | # VQA-CP2
9 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_annotations.json
10 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_annotations.json
11 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_train_questions.json
12 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v2_test_questions.json
13 |
14 | # VQA-CP1
15 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json
16 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json
17 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json
18 | wget -P data https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json
19 |
20 | # VQA-V2
21 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
22 | unzip data/v2_Questions_Train_mscoco.zip -d data
23 | rm data/v2_Questions_Train_mscoco.zip
24 |
25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
26 | unzip data/v2_Questions_Val_mscoco.zip -d data
27 | rm data/v2_Questions_Val_mscoco.zip
28 |
29 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip
30 | unzip data/v2_Questions_Test_mscoco.zip -d data
31 | rm data/v2_Questions_Test_mscoco.zip
32 |
33 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
34 | unzip data/v2_Annotations_Train_mscoco.zip -d data
35 | rm data/v2_Annotations_Train_mscoco.zip
36 |
37 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
38 | unzip data/v2_Annotations_Val_mscoco.zip -d data
39 | rm data/v2_Annotations_Val_mscoco.zip
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/tools/download_v1.sh:
--------------------------------------------------------------------------------
1 |
2 | # VQA-CP1
3 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_annotations.json
4 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_annotations.json
5 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_train_questions.json
6 | wget https://computing.ece.vt.edu/~aish/vqacp/vqacp_v1_test_questions.json
7 |
--------------------------------------------------------------------------------
/tools/process.sh:
--------------------------------------------------------------------------------
1 | # Process data
2 |
3 | python tools/create_dictionary.py
4 | python tools/create_dictionary_v1.py
5 | python tools/compute_softscore.py v2
6 | python tools/compute_softscore.py cp_v1
7 | python tools/compute_softscore.py cp_v2
8 |
9 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | import time
5 | from os.path import join
6 |
7 | import torch
8 | import torch.nn as nn
9 | import utils
10 | from torch.autograd import Variable
11 | import numpy as np
12 | from tqdm import tqdm
13 | import random
14 | import copy
15 |
16 | def compute_score_with_logits(logits, labels):
17 | logits = torch.argmax(logits, 1)
18 | one_hots = torch.zeros(*labels.size()).cuda()
19 | one_hots.scatter_(1, logits.view(-1, 1), 1)
20 | scores = (one_hots * labels)
21 | return scores
22 |
23 | def train(model, train_loader, eval_loader,args,qid2type):
24 |
25 | dataset=args.dataset
26 | num_epochs=args.epochs
27 | mode=args.mode
28 | run_eval=args.eval_each_epoch
29 | output=args.output
30 | optim = torch.optim.Adamax(model.parameters())
31 | logger = utils.Logger(os.path.join(output, 'log.txt'))
32 | total_step = 0
33 | best_eval_score = 0
34 |
35 | if mode=='q_debias':
36 | topq=args.topq
37 | keep_qtype=args.keep_qtype
38 | elif mode=='v_debias':
39 | topv=args.topv
40 | top_hint=args.top_hint
41 | elif mode=='v_debias_contras' or mode=='v_debias_ori_contras_cos' or mode == 'v_debias_ori_contras_eucl':
42 | topv=args.topv
43 | top_hint=args.top_hint
44 | elif mode=='q_v_debias' or mode=='q_v_debias_ori_contras_cos' or mode=='q_v_debias_inner_contras':
45 | topv=args.topv
46 | top_hint=args.top_hint
47 | topq=args.topq
48 | keep_qtype=args.keep_qtype
49 | qvp=args.qvp
50 |
51 |
52 | for epoch in range(num_epochs):
53 | total_loss = 0
54 | train_score = 0
55 |
56 | t = time.time()
57 | for i, (v, q, a, b, hintscore,type_mask,notype_mask,q_mask) in tqdm(enumerate(train_loader), ncols=100,
58 | desc="Epoch %d" % (epoch + 1), total=len(train_loader)):
59 |
60 | total_step += 1
61 |
62 |
63 | #########################################
64 | v = Variable(v).cuda().requires_grad_()
65 | q = Variable(q).cuda()
66 | q_mask=Variable(q_mask).cuda()
67 | a = Variable(a).cuda()
68 | b = Variable(b).cuda()
69 | hintscore = Variable(hintscore).cuda()
70 | type_mask=Variable(type_mask).float().cuda()
71 | notype_mask=Variable(notype_mask).float().cuda()
72 | #########################################
73 |
74 | if mode=='updn':
75 | pred, loss,_ = model(v, q, a, b, None)
76 | if (loss != loss).any():
77 | raise ValueError("NaN loss")
78 | loss.backward()
79 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
80 | optim.step()
81 | optim.zero_grad()
82 |
83 | total_loss += loss.item() * q.size(0)
84 | batch_score = compute_score_with_logits(pred, a.data).sum()
85 | train_score += batch_score
86 |
87 | elif mode=='q_debias':
88 | if keep_qtype==True:
89 | sen_mask=type_mask
90 | else:
91 | sen_mask=notype_mask
92 | ## first train
93 | pred, loss,word_emb = model(v, q, a, b, None)
94 |
95 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]
96 |
97 | if (loss != loss).any():
98 | raise ValueError("NaN loss")
99 | loss.backward()
100 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
101 | optim.step()
102 | optim.zero_grad()
103 |
104 | total_loss += loss.item() * q.size(0)
105 | batch_score = compute_score_with_logits(pred, a.data).sum()
106 | train_score += batch_score
107 |
108 | ## second train
109 |
110 | word_grad_cam = word_grad.sum(2)
111 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
112 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
113 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
114 |
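   |                 # Rank words by their (exponentiated) answer-gradient within the
   |                 # maskable region; the topq highest-impact words are treated as
   |                 # critical for the answer.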
115 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]
116 |
117 | q2 = copy.deepcopy(q_mask)
118 |
119 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1]
120 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0]
121 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1]
122 | if dataset=='cpv1':
123 | m3=m1*18330
124 | else:
125 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0]
126 | q2 = q2 * m2.long() + m3.long()
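   |                 # 18330 (cpv1) and 18455 (the other dictionaries) are presumably the
   |                 # padding/"null word" index of the corresponding dictionary, so the
   |                 # critical words are overwritten rather than merely down-weighted.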
127 |
128 | pred, _, _ = model(v, q2, None, b, None)
129 |
130 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
131 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
132 | false_ans.scatter_(1, pred_ind, 0)
133 | a2 = a * false_ans
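   |                 # Dynamic answer assignment (CSS-style): answers in the model's top-5
   |                 # predictions for the positive sample are removed from the ground
   |                 # truth, and the remaining soft scores supervise the counterfactual
   |                 # sample built below.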
134 | q3 = copy.deepcopy(q)
135 | if dataset=='cpv1':
136 | q3.scatter_(1, w_ind, 18330)
137 | else:
138 | q3.scatter_(1, w_ind, 18455)
139 |
140 | ## third train
141 |
142 | pred, loss, _ = model(v, q3, a2, b, None)
143 |
144 | if (loss != loss).any():
145 | raise ValueError("NaN loss")
146 | loss.backward()
147 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
148 | optim.step()
149 | optim.zero_grad()
150 |
151 | total_loss += loss.item() * q.size(0)
152 |
153 | elif mode=='v_debias':
154 | ## first train
155 | pred, loss, _ = model(v, q, a, b, None)
156 | visual_grad=torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]
157 |
158 | if (loss != loss).any():
159 | raise ValueError("NaN loss")
160 | loss.backward()
161 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
162 | optim.step()
163 | optim.zero_grad()
164 |
165 | total_loss += loss.item() * q.size(0)
166 | batch_score = compute_score_with_logits(pred, a.data).sum()
167 | train_score += batch_score
168 |
169 | ##second train
170 | v_mask = torch.zeros(v.shape[0], 36).cuda()
171 | visual_grad_cam = visual_grad.sum(2)
172 | hint_sort, hint_ind = hintscore.sort(1, descending=True)
173 | v_ind = hint_ind[:, :top_hint]
174 | v_grad = visual_grad_cam.gather(1, v_ind)
175 |
176 | if topv==-1:
177 | v_grad_score,v_grad_ind=v_grad.sort(1,descending=True)
178 | v_grad_score=nn.functional.softmax(v_grad_score*10,dim=1)
179 | v_grad_sum=torch.cumsum(v_grad_score,dim=1)
180 | v_grad_mask=(v_grad_sum<=0.65).long()
181 | v_grad_mask[:,0] = 1
182 | v_mask_ind=v_grad_mask*v_ind
183 | for x in range(a.shape[0]):
184 | num=len(torch.nonzero(v_grad_mask[x]))
185 | v_mask[x].scatter_(0,v_mask_ind[x,:num],1)
186 | else:
187 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
188 | v_star = v_ind.gather(1, v_grad_ind)
189 | v_mask.scatter_(1, v_star, 1)
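   |                     # topv == -1 keeps a variable number of objects whose softmax-
   |                     # normalized gradient scores cover 65% of the mass; otherwise the
   |                     # topv objects with the largest gradients form the critical set.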
190 |
191 |
192 | pred, _, _ = model(v, q, None, b, v_mask) # P = VQA(I+,Q)
193 |
194 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
195 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
196 | false_ans.scatter_(1, pred_ind, 0)
197 | a2 = a * false_ans
198 |
199 | v_mask = 1 - v_mask
200 |
201 | pred, loss, _ = model(v, q, a2, b, v_mask) # P = VQA(I-,Q)
202 |
203 | if (loss != loss).any():
204 | raise ValueError("NaN loss")
205 | loss.backward()
206 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
207 | optim.step()
208 | optim.zero_grad()
209 |
210 | total_loss += loss.item() * q.size(0)
211 |
212 | elif mode=='v_debias_contras':
213 | ## first train
214 | pred, loss, _, feature_ori = model(v, q, a, b, None)
215 | visual_grad=torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]
216 |
217 | if (loss != loss).any():
218 | raise ValueError("NaN loss")
219 | loss.backward(retain_graph=True)
220 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
221 | optim.step()
222 | optim.zero_grad()
223 |
224 | total_loss += loss.item() * q.size(0)
225 | batch_score = compute_score_with_logits(pred, a.data).sum()
226 | train_score += batch_score
227 |
228 | ##second train
229 | v_mask = torch.zeros(v.shape[0], 36).cuda()
230 | visual_grad_cam = visual_grad.sum(2)
231 | hint_sort, hint_ind = hintscore.sort(1, descending=True)
232 | v_ind = hint_ind[:, :top_hint]
233 | v_grad = visual_grad_cam.gather(1, v_ind)
234 |
235 | if topv==-1:
236 | v_grad_score,v_grad_ind=v_grad.sort(1,descending=True)
237 | v_grad_score=nn.functional.softmax(v_grad_score*10,dim=1)
238 | v_grad_sum=torch.cumsum(v_grad_score,dim=1)
239 | v_grad_mask=(v_grad_sum<=0.65).long()
240 | v_grad_mask[:,0] = 1
241 | v_mask_ind=v_grad_mask*v_ind
242 | for x in range(a.shape[0]):
243 | num=len(torch.nonzero(v_grad_mask[x]))
244 | v_mask[x].scatter_(0,v_mask_ind[x,:num],1)
245 | else:
246 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
247 | v_star = v_ind.gather(1, v_grad_ind)
248 | v_mask.scatter_(1, v_star, 1)
249 |
250 |
251 | pred, _, _, feature_pos = model(v, q, None, b, v_mask) # P = VQA(I+,Q)
252 |
253 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
254 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
255 | false_ans.scatter_(1, pred_ind, 0)
256 | a2 = a * false_ans
257 |
258 | v_mask = 1 - v_mask
259 |
260 | pred, loss, _, feature_neg = model(v, q, a2, b, v_mask) # P = VQA(I-,Q)
261 |
262 |                     # Compute the contrastive loss
263 | # inner product
264 | # pos = (feature_ori * feature_pos).sum(1)
265 | # neg = (feature_ori * feature_neg).sum(1)
266 | # cosine_similarity
267 | pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
268 | neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
269 |
270 | logit = torch.stack((pos, neg), 1) # [b, 2]
271 | softmax_logit = nn.functional.softmax(logit, 1) #[b, 2]
272 |
273 | contras_loss = - torch.log(softmax_logit[:, 0])
274 | contras_loss = contras_loss.mean()
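   |                     # Equivalent to an InfoNCE-style objective with one positive and
   |                     # one negative candidate: -log(exp(pos) / (exp(pos) + exp(neg))).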
275 |
276 | loss += contras_loss
277 |
278 | if (loss != loss).any():
279 | raise ValueError("NaN loss")
280 | loss.backward()
281 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
282 | optim.step()
283 | optim.zero_grad()
284 |
285 | total_loss += loss.item() * q.size(0)
286 |
287 |
288 | elif mode=='v_debias_ori_contras_cos' or mode=='v_debias_ori_contras_eucl' or mode == 'q_v_debias_ori_contras_cos':
289 |                 if mode != 'q_v_debias_ori_contras_cos':
290 | ## first train
291 | pred, loss, _, feature_ori = model(v, q, a, b, None)
292 | visual_grad=torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]
293 |
294 | if (loss != loss).any():
295 | raise ValueError("NaN loss")
296 | loss.backward(retain_graph=True)
297 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
298 | optim.step()
299 | optim.zero_grad()
300 |
301 | total_loss += loss.item() * q.size(0)
302 | batch_score = compute_score_with_logits(pred, a.data).sum()
303 | train_score += batch_score
304 |
305 | ##second train
306 | v_mask = torch.zeros(v.shape[0], 36).cuda()
307 | visual_grad_cam = visual_grad.sum(2)
308 | hint_sort, hint_ind = hintscore.sort(1, descending=True)
309 | v_ind = hint_ind[:, :top_hint]
310 | v_grad = visual_grad_cam.gather(1, v_ind)
311 |
312 | if topv==-1:
313 | v_grad_score,v_grad_ind=v_grad.sort(1,descending=True)
314 | v_grad_score=nn.functional.softmax(v_grad_score*10,dim=1)
315 | v_grad_sum=torch.cumsum(v_grad_score,dim=1)
316 | v_grad_mask=(v_grad_sum<=0.65).long()
317 | v_grad_mask[:,0] = 1
318 | v_mask_ind=v_grad_mask*v_ind
319 | for x in range(a.shape[0]):
320 | num=len(torch.nonzero(v_grad_mask[x]))
321 | v_mask[x].scatter_(0,v_mask_ind[x,:num],1)
322 | else:
323 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
324 | v_star = v_ind.gather(1, v_grad_ind)
325 | v_mask.scatter_(1, v_star, 1)
326 |
327 |
328 | pred, _, _, feature_pos = model(v, q, None, b, v_mask) # P = VQA(I+,Q)
329 |
330 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
331 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
332 | false_ans.scatter_(1, pred_ind, 0)
333 | a2 = a * false_ans
334 |
335 | v_mask = 1 - v_mask
336 |
337 | pred, loss, _, feature_neg = model(v, q, a2, b, v_mask) # P = VQA(I-,Q)
338 |
339 |                     # Compute the contrastive loss
340 | # inner product
341 | # pos = (feature_ori * feature_pos).sum(1)
342 | # neg = (feature_ori * feature_neg).sum(1)
343 |
344 | # cosine_similarity
345 | if mode=='v_debias_ori_contras_cos':
346 | pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
347 | neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
348 | pos_dis = 1 - pos
349 | neg_dis = 1 - neg
350 | contras_loss_pos = pos_dis
351 | margin = args.margin
352 | contras_loss_neg = torch.clamp(margin - neg_dis, min=0.0)
353 | # euclidean_distance
354 | elif mode=='v_debias_ori_contras_eucl':
355 | pos_dis = nn.functional.pairwise_distance(feature_ori, feature_pos, p=2)
356 | neg_dis = nn.functional.pairwise_distance(feature_ori, feature_neg, p=2)
357 | contras_loss_pos = torch.pow(pos_dis, 2)
358 | margin = args.margin
359 | contras_loss_neg = torch.pow(torch.clamp(margin - neg_dis, min=0.0), 2)
360 |
361 | contras_loss = contras_loss_pos + contras_loss_neg
362 | contras_loss = contras_loss.mean()
363 |
364 | loss += contras_loss
365 |
366 | if (loss != loss).any():
367 | raise ValueError("NaN loss")
368 | loss.backward()
369 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
370 | optim.step()
371 | optim.zero_grad()
372 |
373 | total_loss += loss.item() * q.size(0)
374 |
375 | elif mode=='q_v_debias_ori_contras_cos':
376 | random_num = random.randint(1, 10)
377 |                 if keep_qtype:
378 | sen_mask = type_mask
379 | else:
380 | sen_mask = notype_mask
381 | if random_num<=qvp:
382 | ## first train
383 | pred, loss, word_emb, feature_ori = model(v, q, a, b, None)
384 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]
385 |
386 | if (loss != loss).any():
387 | raise ValueError("NaN loss")
388 | loss.backward(retain_graph=True)
389 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
390 | optim.step()
391 | optim.zero_grad()
392 |
393 | total_loss += loss.item() * q.size(0)
394 | batch_score = compute_score_with_logits(pred, a.data).sum()
395 | train_score += batch_score
396 |
397 | ## second train
398 |
399 | word_grad_cam = word_grad.sum(2)
400 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
401 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
402 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
403 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]
404 |
405 | q2 = copy.deepcopy(q_mask)
406 |
407 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1]
408 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0]
409 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1]
410 | if dataset=='cpv1':
411 | m3=m1*18330
412 | else:
413 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0]
414 | q2 = q2 * m2.long() + m3.long()
415 |
416 | pred, _, _, feature_pos = model(v, q2, None, b, None)
417 |
418 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
419 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
420 | false_ans.scatter_(1, pred_ind, 0)
421 | a2 = a * false_ans
422 | q3 = copy.deepcopy(q)
423 | if dataset=='cpv1':
424 | q3.scatter_(1, w_ind, 18330)
425 | else:
426 | q3.scatter_(1, w_ind, 18455)
427 |
428 | ## third train
429 |
430 | pred, loss, _, feature_neg = model(v, q3, a2, b, None)
431 |
432 | # cosine_similarity
433 | pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
434 | neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
435 | pos_dis = 1 - pos
436 | neg_dis = 1 - neg
437 | contras_loss_pos = pos_dis
438 | margin = args.margin
439 | contras_loss_neg = torch.clamp(margin - neg_dis, min=0.0)
440 |
441 |
442 | contras_loss = contras_loss_pos + contras_loss_neg
443 | contras_loss = contras_loss.mean()
444 |
445 | loss += args.contras_loss_weight * contras_loss # with CSS
446 | # loss = contras_loss # only_contras_loss
447 |
448 | if (loss != loss).any():
449 | raise ValueError("NaN loss")
450 | loss.backward()
451 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
452 | optim.step()
453 | optim.zero_grad()
454 |
455 | total_loss += loss.item() * q.size(0)
456 |
457 |
458 | else:
459 | ## first train
460 | pred, loss, _, feature_ori = model(v, q, a, b, None)
461 | visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]
462 |
463 | if (loss != loss).any():
464 | raise ValueError("NaN loss")
465 | loss.backward(retain_graph=True)
466 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
467 | optim.step()
468 | optim.zero_grad()
469 |
470 | total_loss += loss.item() * q.size(0)
471 | batch_score = compute_score_with_logits(pred, a.data).sum()
472 | train_score += batch_score
473 |
474 | ##second train
475 | v_mask = torch.zeros(v.shape[0], 36).cuda()
476 | visual_grad_cam = visual_grad.sum(2)
477 | hint_sort, hint_ind = hintscore.sort(1, descending=True)
478 | v_ind = hint_ind[:, :top_hint]
479 | v_grad = visual_grad_cam.gather(1, v_ind)
480 |
481 | if topv == -1:
482 | v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
483 | v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
484 | v_grad_sum = torch.cumsum(v_grad_score, dim=1)
485 | v_grad_mask = (v_grad_sum <= 0.65).long()
486 | v_grad_mask[:,0] = 1
487 | v_mask_ind = v_grad_mask * v_ind
488 | for x in range(a.shape[0]):
489 | num = len(torch.nonzero(v_grad_mask[x]))
490 | v_mask[x].scatter_(0, v_mask_ind[x,:num], 1)
491 | else:
492 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
493 | v_star = v_ind.gather(1, v_grad_ind)
494 | v_mask.scatter_(1, v_star, 1)
495 |
496 | pred, _, _, feature_pos = model(v, q, None, b, v_mask)
497 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
498 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
499 | false_ans.scatter_(1, pred_ind, 0)
500 | a2 = a * false_ans
501 |
502 | v_mask = 1 - v_mask
503 |
504 | pred, loss, _, feature_neg = model(v, q, a2, b, v_mask)
505 |
506 | # cosine_similarity
507 | pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
508 | neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
509 | pos_dis = 1 - pos
510 | neg_dis = 1 - neg
511 | contras_loss_pos = pos_dis
512 | margin = args.margin
513 | contras_loss_neg = torch.clamp(margin - neg_dis, min=0.0)
514 |
515 |
516 | contras_loss = contras_loss_pos + contras_loss_neg
517 | contras_loss = contras_loss.mean()
518 |
519 | loss += args.contras_loss_weight * contras_loss # with CSS
520 | # loss = contras_loss # only_contras_loss
521 |
522 | if (loss != loss).any():
523 | raise ValueError("NaN loss")
524 | loss.backward()
525 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
526 | optim.step()
527 | optim.zero_grad()
528 |
529 | total_loss += loss.item() * q.size(0)
530 |
531 |
532 | elif mode=='q_v_debias':
533 | random_num = random.randint(1, 10)
534 |                 if keep_qtype:
535 | sen_mask = type_mask
536 | else:
537 | sen_mask = notype_mask
538 | if random_num<=qvp:
539 | ## first train
540 | pred, loss, word_emb = model(v, q, a, b, None)
541 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]
542 |
543 | if (loss != loss).any():
544 | raise ValueError("NaN loss")
545 | loss.backward()
546 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
547 | optim.step()
548 | optim.zero_grad()
549 |
550 | total_loss += loss.item() * q.size(0)
551 | batch_score = compute_score_with_logits(pred, a.data).sum()
552 | train_score += batch_score
553 |
554 | ## second train
555 |
556 | word_grad_cam = word_grad.sum(2)
557 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
558 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
559 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
560 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]
561 |
562 | q2 = copy.deepcopy(q_mask)
563 |
564 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1]
565 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0]
566 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1]
567 | if dataset=='cpv1':
568 | m3=m1*18330
569 | else:
570 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0]
571 | q2 = q2 * m2.long() + m3.long()
572 |
573 | pred, _, _ = model(v, q2, None, b, None)
574 |
575 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
576 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
577 | false_ans.scatter_(1, pred_ind, 0)
578 | a2 = a * false_ans
579 | q3 = copy.deepcopy(q)
580 | if dataset=='cpv1':
581 | q3.scatter_(1, w_ind, 18330)
582 | else:
583 | q3.scatter_(1, w_ind, 18455)
584 |
585 | ## third train
586 |
587 | pred, loss, _ = model(v, q3, a2, b, None)
588 |
589 | if (loss != loss).any():
590 | raise ValueError("NaN loss")
591 | loss.backward()
592 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
593 | optim.step()
594 | optim.zero_grad()
595 |
596 | total_loss += loss.item() * q.size(0)
597 |
598 |
599 | else:
600 | ## first train
601 | pred, loss, _ = model(v, q, a, b, None)
602 | visual_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]
603 |
604 | if (loss != loss).any():
605 | raise ValueError("NaN loss")
606 | loss.backward()
607 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
608 | optim.step()
609 | optim.zero_grad()
610 |
611 | total_loss += loss.item() * q.size(0)
612 | batch_score = compute_score_with_logits(pred, a.data).sum()
613 | train_score += batch_score
614 |
615 | ##second train
616 | v_mask = torch.zeros(v.shape[0], 36).cuda()
617 | visual_grad_cam = visual_grad.sum(2)
618 | hint_sort, hint_ind = hintscore.sort(1, descending=True)
619 | v_ind = hint_ind[:, :top_hint]
620 | v_grad = visual_grad_cam.gather(1, v_ind)
621 |
622 | if topv == -1:
623 | v_grad_score, v_grad_ind = v_grad.sort(1, descending=True)
624 | v_grad_score = nn.functional.softmax(v_grad_score * 10, dim=1)
625 | v_grad_sum = torch.cumsum(v_grad_score, dim=1)
626 | v_grad_mask = (v_grad_sum <= 0.65).long()
627 | v_grad_mask[:,0] = 1
628 | v_mask_ind = v_grad_mask * v_ind
629 | for x in range(a.shape[0]):
630 | num = len(torch.nonzero(v_grad_mask[x]))
631 | v_mask[x].scatter_(0, v_mask_ind[x,:num], 1)
632 | else:
633 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
634 | v_star = v_ind.gather(1, v_grad_ind)
635 | v_mask.scatter_(1, v_star, 1)
636 |
637 | pred, _, _ = model(v, q, None, b, v_mask)
638 | pred_ind = torch.argsort(pred, 1, descending=True)[:, :5]
639 | false_ans = torch.ones(pred.shape[0], pred.shape[1]).cuda()
640 | false_ans.scatter_(1, pred_ind, 0)
641 | a2 = a * false_ans
642 |
643 | v_mask = 1 - v_mask
644 |
645 | pred, loss, _ = model(v, q, a2, b, v_mask)
646 |
647 | if (loss != loss).any():
648 | raise ValueError("NaN loss")
649 | loss.backward()
650 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
651 | optim.step()
652 | optim.zero_grad()
653 |
654 | total_loss += loss.item() * q.size(0)
655 |
656 | elif mode=='q_v_debias_inner_contras':
657 | random_num = random.randint(1, 10)
658 |                 if keep_qtype:
659 | sen_mask = type_mask
660 | else:
661 | sen_mask = notype_mask
662 | if random_num<=qvp:
663 | ## first train
664 | pred, loss, word_emb, feature_ori = model(v, q, a, b, None)
665 | word_grad = torch.autograd.grad((pred * (a > 0).float()).sum(), word_emb, create_graph=True)[0]
666 |
667 | if (loss != loss).any():
668 | raise ValueError("NaN loss")
669 | loss.backward(retain_graph=True)
670 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
671 | optim.step()
672 | optim.zero_grad()
673 |
674 | total_loss += loss.item() * q.size(0)
675 | batch_score = compute_score_with_logits(pred, a.data).sum()
676 | train_score += batch_score
677 |
678 | ## second train
679 |
680 | word_grad_cam = word_grad.sum(2)
681 | # word_grad_cam_sigmoid = torch.sigmoid(word_grad_cam * 1000)
682 | word_grad_cam_sigmoid = torch.exp(word_grad_cam * sen_mask)
683 | word_grad_cam_sigmoid = word_grad_cam_sigmoid * sen_mask
684 | w_ind = word_grad_cam_sigmoid.sort(1, descending=True)[1][:, :topq]
685 |
686 | q2 = copy.deepcopy(q_mask)
687 |
688 | m1 = copy.deepcopy(sen_mask) ##[0,0,0...0,1,1,1,1]
689 | m1.scatter_(1, w_ind, 0) ##[0,0,0...0,0,1,1,0]
690 | m2 = 1 - m1 ##[1,1,1...1,1,0,0,1]
691 | if dataset=='cpv1':
692 | m3=m1*18330
693 | else:
694 | m3 = m1 * 18455 ##[0,0,0...0,0,18455,18455,0]
695 | q2 = q2 * m2.long() + m3.long()
696 |
697 | pred_pos, _, _, feature_pos = model(v, q2, None, b, None) # P = VQA(I,Q+, )
698 | pred_ind = torch.argsort(pred_pos, 1, descending=True)[:, :5]
699 | false_ans = torch.ones(pred_pos.shape[0], pred_pos.shape[1]).cuda()
700 | false_ans.scatter_(1, pred_ind, 0)
701 | a2 = a * false_ans
702 | q3 = copy.deepcopy(q)
703 | if dataset=='cpv1':
704 | q3.scatter_(1, w_ind, 18330)
705 | else:
706 | q3.scatter_(1, w_ind, 18455)
707 |
708 | ## third train
709 |
710 | pred_neg, loss_neg, _, feature_neg = model(v, q3, a2, b, None) # P = VQA(I,Q-, )
711 |
712 |
713 |                     # Compute the contrastive loss
714 | # inner product
715 | # pos = (feature_ori * feature_pos).sum(1)
716 | # neg = (feature_ori * feature_neg).sum(1)
717 | # cosine_similarity
718 | pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
719 | neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
720 |
721 | logit = torch.stack((pos, neg), 1) # [b, 2]
722 | softmax_logit = nn.functional.softmax(logit, 1) #[b, 2]
723 |
724 | contras_loss = - torch.log(softmax_logit[:, 0])
725 | # contras_loss += torch.log(softmax_logit[:, 1]) # add contras_neg
726 | contras_loss = contras_loss.mean()
727 |
728 |
729 | loss = loss_neg + args.contras_loss_weight * contras_loss
730 |
731 | if (loss != loss).any():
732 | raise ValueError("NaN loss")
733 | loss.backward()
734 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
735 | optim.step()
736 | optim.zero_grad()
737 |
738 | total_loss += loss.item() * q.size(0)
739 |
740 |
741 | else:
742 | ## first train
743 | pred, loss, _, feature_ori = model(v, q, a, b, None)
744 |
745 | visual_grad=torch.autograd.grad((pred * (a > 0).float()).sum(), v, create_graph=True)[0]
746 |
747 | if (loss != loss).any():
748 | raise ValueError("NaN loss")
749 | loss.backward(retain_graph=True)
750 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
751 | optim.step()
752 | optim.zero_grad()
753 |
754 | total_loss += loss.item() * q.size(0)
755 | batch_score = compute_score_with_logits(pred, a.data).sum()
756 | train_score += batch_score
757 |
758 | ##second train
759 | v_mask = torch.zeros(v.shape[0], 36).cuda()
760 | visual_grad_cam = visual_grad.sum(2)
761 | hint_sort, hint_ind = hintscore.sort(1, descending=True)
762 | v_ind = hint_ind[:, :top_hint]
763 | v_grad = visual_grad_cam.gather(1, v_ind)
764 |
765 | if topv==-1:
766 | v_grad_score,v_grad_ind=v_grad.sort(1,descending=True)
767 | v_grad_score=nn.functional.softmax(v_grad_score*10,dim=1)
768 | v_grad_sum=torch.cumsum(v_grad_score,dim=1)
769 | v_grad_mask=(v_grad_sum<=0.65).long()
770 | v_grad_mask[:,0] = 1
771 | v_mask_ind=v_grad_mask*v_ind
772 | for x in range(a.shape[0]):
773 | num=len(torch.nonzero(v_grad_mask[x]))
774 | v_mask[x].scatter_(0,v_mask_ind[x,:num],1)
775 | else:
776 | v_grad_ind = v_grad.sort(1, descending=True)[1][:, :topv]
777 | v_star = v_ind.gather(1, v_grad_ind)
778 | v_mask.scatter_(1, v_star, 1)
779 |
780 | pred_pos, _, _, feature_pos = model(v, q, None, b, v_mask) # P = VQA(I+,Q)
781 | v_mask_pos = v_mask
782 |
783 | pred_ind = torch.argsort(pred_pos, 1, descending=True)[:, :5]
784 | false_ans = torch.ones(pred_pos.shape[0], pred_pos.shape[1]).cuda()
785 | false_ans.scatter_(1, pred_ind, 0)
786 | a2 = a * false_ans
787 |
788 | v_mask_neg = 1 - v_mask_pos
789 |
790 | pred_neg, loss_neg, _, feature_neg = model(v, q, a2, b, v_mask_neg) # P = VQA(I-,Q)
791 |
792 |                     # Compute the contrastive loss
793 | # inner product
794 | # pos = (feature_ori * feature_pos).sum(1)
795 | # neg = (feature_ori * feature_neg).sum(1)
796 | # cosine_similarity
797 | pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)
798 | neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)
799 |
800 | logit = torch.stack((pos, neg), 1) # [b, 2]
801 | softmax_logit = nn.functional.softmax(logit, 1) #[b, 2]
802 |
803 | contras_loss = - torch.log(softmax_logit[:, 0])
804 | # contras_loss += torch.log(softmax_logit[:, 1]) # add contras_neg
805 | contras_loss = contras_loss.mean()
806 |
807 | loss = loss_neg + args.contras_loss_weight * contras_loss
808 |
809 | if (loss != loss).any():
810 | raise ValueError("NaN loss")
811 | loss.backward()
812 | nn.utils.clip_grad_norm_(model.parameters(), 0.25)
813 | optim.step()
814 | optim.zero_grad()
815 |
816 | total_loss += loss.item() * q.size(0)
817 |
818 |
819 | if mode=='updn':
820 | total_loss /= len(train_loader.dataset)
821 | else:
822 | total_loss /= len(train_loader.dataset) * 2
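   |         # Non-updn modes take two optimizer steps per batch (original plus
   |         # counterfactual sample), hence the 2x normalizer.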
823 | train_score = 100 * train_score / len(train_loader.dataset)
824 |
825 | if run_eval:
826 | model.train(False)
827 | results = evaluate(model, eval_loader, qid2type)
828 | results["epoch"] = epoch + 1
829 | results["step"] = total_step
830 | results["train_loss"] = total_loss
831 | results["train_score"] = train_score
832 |
833 | model.train(True)
834 |
835 | eval_score = results["score"]
836 | bound = results["upper_bound"]
837 | yn = results['score_yesno']
838 | other = results['score_other']
839 | num = results['score_number']
840 |
841 |         logger.write('epoch %d, time: %.2f' % (epoch + 1, time.time() - t))
842 | logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))
843 |
844 | if run_eval:
845 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))
846 | logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num))
847 |
848 | if eval_score > best_eval_score:
849 | model_path = os.path.join(output, 'model.pth')
850 | torch.save(model.state_dict(), model_path)
851 | best_eval_score = eval_score
852 |
853 |
854 | def evaluate(model, dataloader, qid2type):
855 | score = 0
856 | upper_bound = 0
857 | score_yesno = 0
858 | score_number = 0
859 | score_other = 0
860 | total_yesno = 0
861 | total_number = 0
862 | total_other = 0
863 |
864 | for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
865 | v = Variable(v, requires_grad=False).cuda()
866 | q = Variable(q, requires_grad=False).cuda()
867 |         pred, _, _ = model(v, q, None, None, None)
868 | batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1)
869 | score += batch_score.sum()
870 | upper_bound += (a.max(1)[0]).sum()
871 | qids = qids.detach().cpu().int().numpy()
872 | for j in range(len(qids)):
873 | qid = qids[j]
874 | typ = qid2type[str(qid)]
875 | if typ == 'yes/no':
876 | score_yesno += batch_score[j]
877 | total_yesno += 1
878 | elif typ == 'other':
879 | score_other += batch_score[j]
880 | total_other += 1
881 | elif typ == 'number':
882 | score_number += batch_score[j]
883 | total_number += 1
884 | else:
885 |                 print('Unexpected question type: %s' % typ)
886 |
887 |
888 | score = score / len(dataloader.dataset)
889 | upper_bound = upper_bound / len(dataloader.dataset)
890 | score_yesno /= total_yesno
891 | score_other /= total_other
892 | score_number /= total_number
893 |
894 | results = dict(
895 | score=score,
896 | upper_bound=upper_bound,
897 | score_yesno=score_yesno,
898 | score_other=score_other,
899 | score_number=score_number,
900 | )
901 | return results
902 |
--------------------------------------------------------------------------------
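
As a readability aid, here is a minimal, self-contained sketch of the contrastive term used by the `*_contras` branches of `train.py` above. It mirrors the cosine-similarity/softmax formulation from the code (`log_softmax` is the numerically stable equivalent of the softmax followed by log used there); the random tensors are only placeholders for the joint features (`feature_ori`, `feature_pos`, `feature_neg`) returned by the model:

```python
import torch
import torch.nn as nn

def contrastive_loss(feature_ori, feature_pos, feature_neg):
    """-log softmax over (cos(f, f+), cos(f, f-)), averaged over the batch."""
    pos = torch.cosine_similarity(feature_ori, feature_pos, dim=1)  # [b]
    neg = torch.cosine_similarity(feature_ori, feature_neg, dim=1)  # [b]
    logit = torch.stack((pos, neg), 1)                              # [b, 2]
    return -nn.functional.log_softmax(logit, 1)[:, 0].mean()

b, d = 4, 1024  # illustrative batch size and feature width
f = torch.randn(b, d)
print(contrastive_loss(f, torch.randn(b, d), torch.randn(b, d)))  # scalar tensor
```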
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import errno
4 | import os
5 | import numpy as np
6 | # from PIL import Image
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | EPS = 1e-7
12 |
13 |
14 | def assert_eq(real, expected):
15 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected)
16 |
17 |
18 | def assert_array_eq(real, expected):
19 | assert (np.abs(real-expected) < EPS).all(), \
20 | '%s (true) vs %s (expected)' % (real, expected)
21 |
22 |
23 | def load_folder(folder, suffix):
24 | imgs = []
25 | for f in sorted(os.listdir(folder)):
26 | if f.endswith(suffix):
27 | imgs.append(os.path.join(folder, f))
28 | return imgs
29 |
30 |
31 | # def load_imageid(folder):
32 | # images = load_folder(folder, 'jpg')
33 | # img_ids = set()
34 | # for img in images:
35 | # img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1])
36 | # img_ids.add(img_id)
37 | # return img_ids
38 |
39 |
40 | # def pil_loader(path):
41 | # with open(path, 'rb') as f:
42 | # with Image.open(f) as img:
43 | # return img.convert('RGB')
44 |
45 |
46 | def weights_init(m):
47 | """custom weights initialization."""
48 | cname = m.__class__
49 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d:
50 | m.weight.data.normal_(0.0, 0.02)
51 | elif cname == nn.BatchNorm2d:
52 | m.weight.data.normal_(1.0, 0.02)
53 | m.bias.data.fill_(0)
54 | else:
55 | print('%s is not initialized.' % cname)
56 |
57 |
58 | def init_net(net, net_file):
59 | if net_file:
60 | net.load_state_dict(torch.load(net_file))
61 | else:
62 | net.apply(weights_init)
63 |
64 |
65 | def create_dir(path):
66 | if not os.path.exists(path):
67 | try:
68 | os.makedirs(path)
69 | except OSError as exc:
70 | if exc.errno != errno.EEXIST:
71 | raise
72 |
73 |
74 | class Logger(object):
75 | def __init__(self, output_name):
76 | dirname = os.path.dirname(output_name)
77 | if not os.path.exists(dirname):
78 |             os.makedirs(dirname)  # the output directory may be nested
79 |
80 | self.log_file = open(output_name, 'w')
81 | self.infos = {}
82 |
83 | def append(self, key, val):
84 | vals = self.infos.setdefault(key, [])
85 | vals.append(val)
86 |
87 | def log(self, extra_msg=''):
88 | msgs = [extra_msg]
89 |         for key, vals in self.infos.items():  # Python 2/3 compatible
90 | msgs.append('%s %.6f' % (key, np.mean(vals)))
91 | msg = '\n'.join(msgs)
92 | self.log_file.write(msg + '\n')
93 | self.log_file.flush()
94 | self.infos = {}
95 | return msg
96 |
97 | def write(self, msg):
98 | self.log_file.write(msg + '\n')
99 | self.log_file.flush()
100 | print(msg)
101 |
--------------------------------------------------------------------------------
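
A short usage sketch of the `Logger` above (the path is illustrative):

```python
from utils import Logger

logger = Logger('saved_models/exp0/log.txt')  # creates the log directory if needed
logger.write('epoch 1, score: 42.00')         # written to the file and to stdout
logger.append('loss', 0.5)
logger.append('loss', 0.3)
logger.log('end of epoch')                    # writes the running means, e.g. 'loss 0.400000'
```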
/vqa_debias_loss_functions.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, defaultdict, Counter
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import torch
7 | import inspect
8 |
9 |
10 | def convert_sigmoid_logits_to_binary_logprobs(logits):
11 | """computes log(sigmoid(logits)), log(1-sigmoid(logits))"""
12 | log_prob = -F.softplus(-logits)
13 | log_one_minus_prob = -logits + log_prob
14 | return log_prob, log_one_minus_prob
15 |
16 |
17 | def elementwise_logsumexp(a, b):
18 |     """computes log(exp(a) + exp(b))"""
19 | return torch.max(a, b) + torch.log1p(torch.exp(-torch.abs(a - b)))
20 |
21 |
22 | def renormalize_binary_logits(a, b):
23 | """Normalize so exp(a) + exp(b) == 1"""
24 | norm = elementwise_logsumexp(a, b)
25 | return a - norm, b - norm
26 |
27 |
28 | class DebiasLossFn(nn.Module):
29 | """General API for our loss functions"""
30 |
31 | def forward(self, hidden, logits, bias, labels):
32 | """
33 | :param hidden: [batch, n_hidden] hidden features from the last layer in the model
34 | :param logits: [batch, n_answers_options] sigmoid logits for each answer option
35 | :param bias: [batch, n_answers_options]
36 | bias probabilities for each answer option between 0 and 1
37 | :param labels: [batch, n_answers_options]
38 | scores for each answer option, between 0 and 1
39 | :return: Scalar loss
40 | """
41 | raise NotImplementedError()
42 |
43 | def to_json(self):
44 | """Get a json representation of this loss function.
45 |
46 | We construct this by looking up the __init__ args
47 | """
48 | cls = self.__class__
49 | init = cls.__init__
50 | if init is object.__init__:
51 | return [] # No init args
52 |
53 | init_signature = inspect.getargspec(init)
54 | if init_signature.varargs is not None:
55 |             raise NotImplementedError("varargs not supported")
56 | if init_signature.keywords is not None:
57 | raise NotImplementedError("keywords not supported")
58 | args = [x for x in init_signature.args if x != "self"]
59 | out = OrderedDict()
60 | out["name"] = cls.__name__
61 | for key in args:
62 | out[key] = getattr(self, key)
63 | return out
64 |
65 |
66 | class Plain(DebiasLossFn):
67 | def forward(self, hidden, logits, bias, labels):
68 | loss = F.binary_cross_entropy_with_logits(logits, labels)
69 |
70 | loss *= labels.size(1)
71 | return loss
72 |
73 |
74 | class Focal(DebiasLossFn):
75 | def forward(self, hidden, logits, bias, labels):
76 | # import pdb;pdb.set_trace()
77 | focal_logits=torch.log(F.softmax(logits,dim=1)+1e-5) * ((1-F.softmax(bias,dim=1))*(1-F.softmax(bias,dim=1)))
78 | loss=F.binary_cross_entropy_with_logits(focal_logits,labels)
79 | loss*=labels.size(1)
80 | return loss
81 |
82 | class ReweightByInvBias(DebiasLossFn):
83 | def forward(self, hidden, logits, bias, labels):
84 | # Manually compute the binary cross entropy since the old version of torch always aggregates
85 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
86 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob)
87 | weights = (1 - bias)
88 | loss *= weights # Apply the weights
89 | return loss.sum() / weights.sum()
90 |
91 |
92 | class BiasProduct(DebiasLossFn):
93 | def __init__(self, smooth=True, smooth_init=-1, constant_smooth=0.0):
94 | """
95 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
96 | :param smooth_init: How to initialize `a`
97 | :param constant_smooth: Constant to add to the bias to smooth it
98 | """
99 | super(BiasProduct, self).__init__()
100 | self.constant_smooth = constant_smooth
101 | self.smooth_init = smooth_init
102 | self.smooth = smooth
103 | if smooth:
104 | self.smooth_param = torch.nn.Parameter(
105 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
106 | else:
107 | self.smooth_param = None
108 |
109 | def forward(self, hidden, logits, bias, labels):
110 | smooth = self.constant_smooth
111 | if self.smooth:
112 |             smooth += torch.sigmoid(self.smooth_param)
113 |
114 | # Convert the bias into log-space, with a factor for both the
115 | # binary outputs for each answer option
116 | bias_lp = torch.log(bias + smooth)
117 | bias_l_inv = torch.log1p(-bias + smooth)
118 |
119 |     # Convert the logits into log-space with the same format
120 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
121 | # import pdb;pdb.set_trace()
122 |
123 | # Add the bias
124 | log_prob += bias_lp
125 | log_one_minus_prob += bias_l_inv
126 |
127 | # Re-normalize the factors in logspace
128 | log_prob, log_one_minus_prob = renormalize_binary_logits(log_prob, log_one_minus_prob)
129 |
130 | # Compute the binary cross entropy
131 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)
132 | return loss
133 |
134 |
135 | class LearnedMixin(DebiasLossFn):
136 | def __init__(self, w, smooth=True, smooth_init=-1, constant_smooth=0.0):
137 | """
138 | :param w: Weight of the entropy penalty
139 | :param smooth: Add a learned sigmoid(a) factor to the bias to smooth it
140 | :param smooth_init: How to initialize `a`
141 | :param constant_smooth: Constant to add to the bias to smooth it
142 | """
143 | super(LearnedMixin, self).__init__()
144 | self.w = w
145 | # self.w=0
146 | self.smooth_init = smooth_init
147 | self.constant_smooth = constant_smooth
148 | self.bias_lin = torch.nn.Linear(1024, 1)
149 | self.smooth = smooth
150 | if self.smooth:
151 | self.smooth_param = torch.nn.Parameter(
152 | torch.from_numpy(np.full((1,), smooth_init, dtype=np.float32)))
153 | else:
154 | self.smooth_param = None
155 |
156 | def forward(self, hidden, logits, bias, labels):
157 | factor = self.bias_lin.forward(hidden) # [batch, 1]
158 | factor = F.softplus(factor)
159 |
160 | bias = torch.stack([bias, 1 - bias], 2) # [batch, n_answers, 2]
161 |
162 | # Smooth
163 | bias += self.constant_smooth
164 | if self.smooth:
165 |             soften_factor = torch.sigmoid(self.smooth_param)
166 | bias = bias + soften_factor.unsqueeze(1)
167 |
168 | bias = torch.log(bias) # Convert to logspace
169 |
170 | # Scale by the factor
171 | # [batch, n_answers, 2] * [batch, 1, 1] -> [batch, n_answers, 2]
172 | bias = bias * factor.unsqueeze(1)
173 |
174 | log_prob, log_one_minus_prob = convert_sigmoid_logits_to_binary_logprobs(logits)
175 | log_probs = torch.stack([log_prob, log_one_minus_prob], 2)
176 |
177 | # Add the bias in
178 | logits = bias + log_probs
179 |
180 | # Renormalize to get log probabilities
181 | log_prob, log_one_minus_prob = renormalize_binary_logits(logits[:, :, 0], logits[:, :, 1])
182 |
183 | # Compute loss
184 | loss = -(log_prob * labels + (1 - labels) * log_one_minus_prob).sum(1).mean(0)
185 |
186 | # Re-normalized version of the bias
187 | bias_norm = elementwise_logsumexp(bias[:, :, 0], bias[:, :, 1])
188 | bias_logprob = bias - bias_norm.unsqueeze(2)
189 |
190 | # Compute and add the entropy penalty
191 | entropy = -(torch.exp(bias_logprob) * bias_logprob).sum(2).mean()
192 | return loss + self.w * entropy
--------------------------------------------------------------------------------
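
A quick numerical check of the two log-space identities defined at the top of this file; it assumes the snippet is run from the repository root so the module imports:

```python
import torch
from vqa_debias_loss_functions import (
    convert_sigmoid_logits_to_binary_logprobs, elementwise_logsumexp)

logits = torch.randn(2, 5)
log_p, log_1mp = convert_sigmoid_logits_to_binary_logprobs(logits)
assert torch.allclose(log_p, torch.log(torch.sigmoid(logits)), atol=1e-6)
assert torch.allclose(log_1mp, torch.log(1 - torch.sigmoid(logits)), atol=1e-5)

a, b = torch.randn(3), torch.randn(3)
assert torch.allclose(elementwise_logsumexp(a, b), torch.log(a.exp() + b.exp()), atol=1e-6)
```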