├── .gitignore
├── .gitmodules
├── README.md
├── config.py
├── dan.py
├── data.py
├── logs
│   └── .dummy
├── model.py
├── preprocess-images.py
├── preprocess-vocab.py
├── train.py
├── utils.py
├── val_acc.png
└── view-log.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | vqa
4 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "resnet"]
2 | path = resnet
3 | url = https://github.com/Cyanogenoid/pytorch-resnet
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Dual Attention Networks for Visual Question Answering
2 | 
3 | This is a PyTorch implementation of [Dual Attention Networks for Multimodal Reasoning and Matching](https://arxiv.org/pdf/1611.00471.pdf). I forked the code from [Cyanogenoid](https://github.com/Cyanogenoid)'s [pytorch-vqa](https://github.com/Cyanogenoid/pytorch-vqa) and replaced the model with my implementation of Dual Attention Networks; reusing the fork avoids reimplementing the tedious data preprocessing and loading pipeline. Please see [pytorch-vqa](https://github.com/Cyanogenoid/pytorch-vqa) for how the data was preprocessed and the features extracted.
4 | 
5 | Differences between the paper and this model:
6 | - Learning rate decay: the original paper halved the learning rate after 30 epochs and trained for another 30 epochs. We keep the forked code's optimization and halve the learning rate every 50k iterations (see the snippet below).
7 | - Answer scoring: the original paper used only a single layer to score the answers against the memory vector. Our implementation uses a 2-layer network.
8 | - Pretrained word embeddings: the original paper used a word embedding dimension of 512. For the graph below, we used 300 and loaded pretrained GloVe vectors.
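For reference, the decay used here is the smooth halving schedule from `update_learning_rate` in `train.py`, with the constants from `config.py`; the snippet below is just a self-contained restatement of that formula, not extra functionality:

```python
initial_lr = 1e-3     # config.initial_lr (the default Adam learning rate)
lr_halflife = 50000   # config.lr_halflife, in iterations

def learning_rate(iteration):
    # same formula as update_learning_rate() in train.py
    return initial_lr * 0.5 ** (iteration / lr_halflife)
```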
9 | 
10 | Our implementation reaches around 61% validation accuracy after 20 epochs.
11 | ![Learning graph](val_acc.png)
12 | 
13 | ### Requirements
14 | 
15 | Python 3
16 | - h5py
17 | - torch
18 | - torchvision
19 | - tqdm
20 | - torchtext
21 | 
22 | Plotting
23 | - numpy
24 | - matplotlib
25 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # paths
2 | qa_path = 'vqa'  # directory containing the question and annotation jsons
3 | train_path = 'mscoco/train2014'  # directory of training images
4 | val_path = 'mscoco/val2014'  # directory of validation images
5 | test_path = 'mscoco/test2015'  # directory of test images
6 | preprocessed_path = './resnet/resnet-14x14.h5'  # path where preprocessed features are saved to and loaded from
7 | vocabulary_path = 'vocab.json'  # path where the used vocabularies for question and answers are saved to
8 | 
9 | task = 'OpenEnded'
10 | dataset = 'mscoco'
11 | 
12 | # preprocess config
13 | preprocess_batch_size = 64
14 | image_size = 448  # scale shorter end of image to this size and centre crop
15 | output_size = image_size // 32  # spatial size of the feature maps after processing through the network
16 | output_features = 2048  # number of feature channels of those feature maps
17 | central_fraction = 0.875  # only take this much of the centre when scaling and centre cropping
18 | 
19 | # model config
20 | pretrained = True
21 | embedding_dim = 300
22 | hidden_size = 512
23 | max_answers = 3000
24 | 
25 | # training config
26 | epochs = 50
27 | batch_size = 128
28 | initial_lr = 1e-3  # default Adam lr
29 | lr_halflife = 50000  # in iterations
30 | data_workers = 8
31 | run_number = 0  # used by train.py to name the directory where per-epoch model checkpoints are saved
32 | 
--------------------------------------------------------------------------------
/dan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import torchtext.vocab as vocab
6 | 
7 | from model import Classifier
8 | 
9 | class TextEncoder(nn.Module):
10 |     def __init__(self, num_embeddings, embedding_dim, hidden_size):
11 |         super(TextEncoder, self).__init__()
12 |         self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
13 |         self.bilstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
14 |                               batch_first=False,
15 |                               bidirectional=True,
16 |                               dropout=0.5)  # LSTM-internal dropout only acts between stacked layers, so it has no effect on this single-layer LSTM
17 | 
18 |         self.dropout = nn.Dropout(p=0.5)
19 | 
20 |     def forward(self, x):
21 |         embed_output = self.embed(x)
22 |         bilstm_output, _ = self.bilstm(self.dropout(embed_output))
23 |         return bilstm_output
24 | 
25 |     def load_pretrained(self, dictionary):
26 |         print("Loading pretrained weights...")
27 |         # Load pretrained vectors for the embedding layer
28 |         glove = vocab.GloVe(name='6B', dim=self.embed.embedding_dim)
29 | 
30 |         # Build the weight matrix here
31 |         pretrained_weight = self.embed.weight.data
32 |         for word, idx in dictionary.items():
33 |             if word.lower() in glove.stoi:
34 |                 vector = glove.vectors[glove.stoi[word.lower()]]
35 |                 pretrained_weight[idx] = vector
36 | 
37 |         self.embed.weight = nn.Parameter(pretrained_weight)
38 | 
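# Illustrative usage sketch for TextEncoder (added for clarity, not part of the original file;
# the constructor values are only examples, matching config.embedding_dim / config.hidden_size,
# while num_embeddings would normally come from the question vocabulary):
#
#   encoder = TextEncoder(num_embeddings=10000, embedding_dim=300, hidden_size=512)
#   tokens = Variable(torch.ones(20, 8).long())   # (seq_len=20, batch_size=8) of token indices
#   features = encoder(tokens)                    # -> (20, 8, 1024), i.e. one 2*hidden_size
#                                                 #    bidirectional-LSTM feature per token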
39 | class rDAN(nn.Module):
40 |     def __init__(self, num_embeddings, embedding_dim, hidden_size, answer_size, k=2):
41 |         super(rDAN, self).__init__()
42 | 
43 |         # Build Text Encoder
44 |         self.textencoder = TextEncoder(num_embeddings=num_embeddings,
45 |                                        embedding_dim=embedding_dim,
46 |                                        hidden_size=hidden_size)
47 | 
48 |         memory_size = 2 * hidden_size  # bidirectional
49 | 
50 |         # Visual Attention
51 |         self.Wv = nn.Linear(in_features=2048, out_features=hidden_size)
52 |         self.Wvm = nn.Linear(in_features=memory_size, out_features=hidden_size)
53 |         self.Wvh = nn.Linear(in_features=hidden_size, out_features=1)
54 |         self.P = nn.Linear(in_features=2048, out_features=memory_size)
55 | 
56 |         # Textual Attention
57 |         self.Wu = nn.Linear(in_features=2*hidden_size, out_features=hidden_size)
58 |         self.Wum = nn.Linear(in_features=memory_size, out_features=hidden_size)
59 |         self.Wuh = nn.Linear(in_features=hidden_size, out_features=1)
60 | 
61 |         self.Wans = nn.Linear(in_features=memory_size, out_features=answer_size)  # single-layer scorer from the paper; unused here, the 2-layer classifier below is used instead
62 | 
63 |         # Scoring Network
64 |         self.classifier = Classifier(memory_size, hidden_size, answer_size, 0.5)
65 | 
66 |         # Dropout
67 |         self.dropout = nn.Dropout(p=0.5)
68 | 
69 |         # Activations
70 |         self.tanh = nn.Tanh()
71 |         self.softmax = nn.Softmax(0)  # softmax over dim 0 (image regions / sequence positions)
72 | 
73 |         # Loops
74 |         self.k = k  # number of attention hops
75 | 
76 |     def forward(self, visual, text):
77 | 
78 |         batch_size = visual.shape[0]
79 | 
80 |         # Prepare Visual Features
81 |         visual = visual.view(batch_size, 2048, -1)
82 |         vns = visual.permute(2, 0, 1)  # (nregion, batch_size, dim)
83 | 
84 |         # Prepare Textual Features
85 |         text = text.permute(1, 0)
86 |         uts = self.textencoder.forward(text)  # (seq_len, batch_size, dim)
87 | 
88 |         # Initialize Memory
89 |         u = uts.mean(0)
90 |         v = self.tanh(self.P(vns.mean(0)))
91 |         memory = v * u
92 | 
93 |         # k is the number of attention hops
94 |         for k in range(self.k):
95 |             # Compute Visual Attention
96 |             hv = self.tanh(self.Wv(self.dropout(vns))) * self.tanh(self.Wvm(self.dropout(memory)))
97 |             # attention weights for every region
98 |             alphaV = self.softmax(self.Wvh(self.dropout(hv)))  # (nregion, batch_size, 1)
99 |             # Sum over regions
100 |             v = self.tanh(self.P(alphaV * vns)).sum(0)
101 | 
102 |             # Text
103 |             # (seq_len, batch_size, hidden) * (batch_size, hidden), broadcast over the sequence dimension
104 |             hu = self.tanh(self.Wu(self.dropout(uts))) * self.tanh(self.Wum(self.dropout(memory)))
105 |             # attention weights for text features
106 |             alphaU = self.softmax(self.Wuh(self.dropout(hu)))  # (seq_len, batch_size, 1)
107 |             # Sum over sequence
108 |             u = (alphaU * uts).sum(0)
109 | 
110 |             # Build Memory
111 |             memory = memory + u * v
112 | 
113 |         # We compute scores using a classifier
114 |         scores = self.classifier(memory)
115 | 
116 |         return scores
117 | 
118 | if __name__ == "__main__":
119 |     pass
120 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path
4 | import re
5 | 
6 | from PIL import Image
7 | import h5py
8 | import torch
9 | import torch.utils.data as data
10 | import torchvision.transforms as transforms
11 | 
12 | import config
13 | import utils
14 | 
15 | 
16 | def get_loader(train=False, val=False, test=False):
17 |     """ Returns a data loader for the desired split """
18 |     assert train + val + test == 1, 'need to set exactly one of {train, val, test} to True'
19 |     split = VQA(
20 |         utils.path_for(train=train, val=val, test=test, question=True),
21 |         utils.path_for(train=train, val=val, test=test, answer=True),
22 |         config.preprocessed_path,
23 |         answerable_only=train,
24 |     )
25 |     loader = torch.utils.data.DataLoader(
26 |         split,
27 |         batch_size=config.batch_size,
28 |         shuffle=train,  # only shuffle the data in training
29 |         pin_memory=True,
30 |         num_workers=config.data_workers,
31 |         collate_fn=collate_fn,
32 |     )
33 |     return loader
34 | 
35 | 
36 | def 
collate_fn(batch): 37 | # put question lengths in descending order so that we can use packed sequences later 38 | batch.sort(key=lambda x: x[-1], reverse=True) 39 | return data.dataloader.default_collate(batch) 40 | 41 | 42 | class VQA(data.Dataset): 43 | """ VQA dataset, open-ended """ 44 | def __init__(self, questions_path, answers_path, image_features_path, answerable_only=False): 45 | super(VQA, self).__init__() 46 | with open(questions_path, 'r') as fd: 47 | questions_json = json.load(fd) 48 | with open(answers_path, 'r') as fd: 49 | answers_json = json.load(fd) 50 | with open(config.vocabulary_path, 'r') as fd: 51 | vocab_json = json.load(fd) 52 | self._check_integrity(questions_json, answers_json) 53 | 54 | # vocab 55 | self.vocab = vocab_json 56 | self.token_to_index = self.vocab['question'] 57 | self.answer_to_index = self.vocab['answer'] 58 | 59 | # q and a 60 | self.questions = list(prepare_questions(questions_json)) 61 | self.answers = list(prepare_answers(answers_json)) 62 | self.questions = [self._encode_question(q) for q in self.questions] 63 | self.answers = [self._encode_answers(a) for a in self.answers] 64 | 65 | # v 66 | self.image_features_path = image_features_path 67 | self.coco_id_to_index = self._create_coco_id_to_index() 68 | self.coco_ids = [q['image_id'] for q in questions_json['questions']] 69 | 70 | # only use questions that have at least one answer? 71 | self.answerable_only = answerable_only 72 | if self.answerable_only: 73 | self.answerable = self._find_answerable() 74 | 75 | @property 76 | def max_question_length(self): 77 | if not hasattr(self, '_max_length'): 78 | self._max_length = max(map(len, self.questions)) 79 | return self._max_length 80 | 81 | @property 82 | def num_tokens(self): 83 | return len(self.token_to_index) + 1 # add 1 for token at index 0 84 | 85 | def _create_coco_id_to_index(self): 86 | """ Create a mapping from a COCO image id into the corresponding index into the h5 file """ 87 | with h5py.File(self.image_features_path, 'r') as features_file: 88 | coco_ids = features_file['ids'][()] 89 | coco_id_to_index = {id: i for i, id in enumerate(coco_ids)} 90 | return coco_id_to_index 91 | 92 | def _check_integrity(self, questions, answers): 93 | """ Verify that we are using the correct data """ 94 | qa_pairs = list(zip(questions['questions'], answers['annotations'])) 95 | assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers' 96 | assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match' 97 | assert questions['data_type'] == answers['data_type'], 'Mismatched data types' 98 | assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes' 99 | 100 | def _find_answerable(self): 101 | """ Create a list of indices into questions that will have at least one answer that is in the vocab """ 102 | answerable = [] 103 | for i, answers in enumerate(self.answers): 104 | answer_has_index = len(answers.nonzero()) > 0 105 | # store the indices of anything that is answerable 106 | if answer_has_index: 107 | answerable.append(i) 108 | return answerable 109 | 110 | def _encode_question(self, question): 111 | """ Turn a question into a vector of indices and a question length """ 112 | vec = torch.zeros(self.max_question_length).long() 113 | for i, token in enumerate(question): 114 | index = self.token_to_index.get(token, 0) 115 | vec[i] = index 116 | return vec, len(question) 117 | 118 | def _encode_answers(self, answers): 119 | """ 
Turn an answer into a vector """ 120 | # answer vec will be a vector of answer counts to determine which answers will contribute to the loss. 121 | # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up 122 | # to get the loss that is weighted by how many humans gave that answer 123 | answer_vec = torch.zeros(len(self.answer_to_index)) 124 | for answer in answers: 125 | index = self.answer_to_index.get(answer) 126 | if index is not None: 127 | answer_vec[index] += 1 128 | return answer_vec 129 | 130 | def _load_image(self, image_id): 131 | """ Load an image """ 132 | if not hasattr(self, 'features_file'): 133 | # Loading the h5 file has to be done here and not in __init__ because when the DataLoader 134 | # forks for multiple works, every child would use the same file object and fail 135 | # Having multiple readers using different file objects is fine though, so we just init in here. 136 | self.features_file = h5py.File(self.image_features_path, 'r') 137 | index = self.coco_id_to_index[image_id] 138 | dataset = self.features_file['features'] 139 | img = dataset[index].astype('float32') 140 | return torch.from_numpy(img) 141 | 142 | def __getitem__(self, item): 143 | if self.answerable_only: 144 | # change of indices to only address answerable questions 145 | item = self.answerable[item] 146 | 147 | q, q_length = self.questions[item] 148 | a = self.answers[item] 149 | image_id = self.coco_ids[item] 150 | v = self._load_image(image_id) 151 | # since batches are re-ordered for PackedSequence's, the original question order is lost 152 | # we return `item` so that the order of (v, q, a) triples can be restored if desired 153 | # without shuffling in the dataloader, these will be in the order that they appear in the q and a json's. 154 | return v, q, a, item, q_length 155 | 156 | def __len__(self): 157 | if self.answerable_only: 158 | return len(self.answerable) 159 | else: 160 | return len(self.questions) 161 | 162 | 163 | # this is used for normalizing questions 164 | _special_chars = re.compile('[^a-z0-9 ]*') 165 | 166 | # these try to emulate the original normalization scheme for answers 167 | _period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)') 168 | _comma_strip = re.compile(r'(\d)(,)(\d)') 169 | _punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!') 170 | _punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars))) 171 | _punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars)) 172 | 173 | 174 | def prepare_questions(questions_json): 175 | """ Tokenize and normalize questions from a given question json in the usual VQA format. """ 176 | questions = [q['question'] for q in questions_json['questions']] 177 | for question in questions: 178 | question = question.lower()[:-1] 179 | yield question.split(' ') 180 | 181 | 182 | def prepare_answers(answers_json): 183 | """ Normalize answers from a given answer json in the usual VQA format. """ 184 | answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']] 185 | # The only normalization that is applied to both machine generated answers as well as 186 | # ground truth answers is replacing most punctuation with space (see [0] and [1]). 187 | # Since potential machine generated answers are just taken from most common answers, applying the other 188 | # normalizations is not needed, assuming that the human answers are already normalized. 
189 | # [0]: http://visualqa.org/evaluation.html 190 | # [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96 191 | 192 | def process_punctuation(s): 193 | # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour 194 | # this version should be faster since we use re instead of repeated operations on str's 195 | if _punctuation.search(s) is None: 196 | return s 197 | s = _punctuation_with_a_space.sub('', s) 198 | if re.search(_comma_strip, s) is not None: 199 | s = s.replace(',', '') 200 | s = _punctuation.sub(' ', s) 201 | s = _period_strip.sub('', s) 202 | return s.strip() 203 | 204 | for answer_list in answers: 205 | yield list(map(process_punctuation, answer_list)) 206 | 207 | 208 | class CocoImages(data.Dataset): 209 | """ Dataset for MSCOCO images located in a folder on the filesystem """ 210 | def __init__(self, path, transform=None): 211 | super(CocoImages, self).__init__() 212 | self.path = path 213 | self.id_to_filename = self._find_images() 214 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 215 | print('found {} images in {}'.format(len(self), self.path)) 216 | self.transform = transform 217 | 218 | def _find_images(self): 219 | id_to_filename = {} 220 | for filename in os.listdir(self.path): 221 | if not filename.endswith('.jpg'): 222 | continue 223 | id_and_extension = filename.split('_')[-1] 224 | id = int(id_and_extension.split('.')[0]) 225 | id_to_filename[id] = filename 226 | return id_to_filename 227 | 228 | def __getitem__(self, item): 229 | id = self.sorted_ids[item] 230 | path = os.path.join(self.path, self.id_to_filename[id]) 231 | img = Image.open(path).convert('RGB') 232 | 233 | if self.transform is not None: 234 | img = self.transform(img) 235 | return id, img 236 | 237 | def __len__(self): 238 | return len(self.sorted_ids) 239 | 240 | 241 | class Composite(data.Dataset): 242 | """ Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset. 
""" 243 | def __init__(self, *datasets): 244 | self.datasets = datasets 245 | 246 | def __getitem__(self, item): 247 | current = self.datasets[0] 248 | for d in self.datasets: 249 | if item < len(d): 250 | return d[item] 251 | item -= len(d) 252 | else: 253 | raise IndexError('Index too large for composite dataset') 254 | 255 | def __len__(self): 256 | return sum(map(len, self.datasets)) 257 | -------------------------------------------------------------------------------- /logs/.dummy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzuhsial/pytorch-vqa-dan/4d42acecf47798b84daa13a1733700103a418586/logs/.dummy -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | import config 8 | 9 | 10 | class Net(nn.Module): 11 | """ Re-implementation of ``Show, Ask, Attend, and Answer: A Strong Baseline For Visual Question Answering'' [0] 12 | 13 | [0]: https://arxiv.org/abs/1704.03162 14 | """ 15 | 16 | def __init__(self, embedding_tokens): 17 | super(Net, self).__init__() 18 | question_features = 1024 19 | vision_features = config.output_features 20 | glimpses = 2 21 | 22 | self.text = TextProcessor( 23 | embedding_tokens=embedding_tokens, 24 | embedding_features=300, 25 | lstm_features=question_features, 26 | drop=0.5, 27 | ) 28 | self.attention = Attention( 29 | v_features=vision_features, 30 | q_features=question_features, 31 | mid_features=512, 32 | glimpses=2, 33 | drop=0.5, 34 | ) 35 | self.classifier = Classifier( 36 | in_features=glimpses * vision_features + question_features, 37 | mid_features=1024, 38 | out_features=config.max_answers, 39 | drop=0.5, 40 | ) 41 | 42 | for m in self.modules(): 43 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 44 | init.xavier_uniform(m.weight) 45 | if m.bias is not None: 46 | m.bias.data.zero_() 47 | 48 | def forward(self, v, q, q_len): 49 | q = self.text(q, list(q_len.data)) 50 | 51 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 52 | a = self.attention(v, q) 53 | v = apply_attention(v, a) 54 | 55 | combined = torch.cat([v, q], dim=1) 56 | answer = self.classifier(combined) 57 | return answer 58 | 59 | 60 | class Classifier(nn.Sequential): 61 | def __init__(self, in_features, mid_features, out_features, drop=0.0): 62 | super(Classifier, self).__init__() 63 | self.add_module('drop1', nn.Dropout(drop)) 64 | self.add_module('lin1', nn.Linear(in_features, mid_features)) 65 | self.add_module('relu', nn.ReLU()) 66 | self.add_module('drop2', nn.Dropout(drop)) 67 | self.add_module('lin2', nn.Linear(mid_features, out_features)) 68 | 69 | 70 | class TextProcessor(nn.Module): 71 | def __init__(self, embedding_tokens, embedding_features, lstm_features, drop=0.0): 72 | super(TextProcessor, self).__init__() 73 | self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0) 74 | self.drop = nn.Dropout(drop) 75 | self.tanh = nn.Tanh() 76 | self.lstm = nn.LSTM(input_size=embedding_features, 77 | hidden_size=lstm_features, 78 | num_layers=1) 79 | self.features = lstm_features 80 | 81 | self._init_lstm(self.lstm.weight_ih_l0) 82 | self._init_lstm(self.lstm.weight_hh_l0) 83 | self.lstm.bias_ih_l0.data.zero_() 84 | self.lstm.bias_hh_l0.data.zero_() 85 | 86 | 
init.xavier_uniform(self.embedding.weight) 87 | 88 | def _init_lstm(self, weight): 89 | for w in weight.chunk(4, 0): 90 | init.xavier_uniform(w) 91 | 92 | def forward(self, q, q_len): 93 | embedded = self.embedding(q) 94 | tanhed = self.tanh(self.drop(embedded)) 95 | packed = pack_padded_sequence(tanhed, q_len, batch_first=True) 96 | _, (_, c) = self.lstm(packed) 97 | return c.squeeze(0) 98 | 99 | 100 | class Attention(nn.Module): 101 | def __init__(self, v_features, q_features, mid_features, glimpses, drop=0.0): 102 | super(Attention, self).__init__() 103 | self.v_conv = nn.Conv2d(v_features, mid_features, 1, bias=False) # let self.lin take care of bias 104 | self.q_lin = nn.Linear(q_features, mid_features) 105 | self.x_conv = nn.Conv2d(mid_features, glimpses, 1) 106 | 107 | self.drop = nn.Dropout(drop) 108 | self.relu = nn.ReLU(inplace=True) 109 | 110 | def forward(self, v, q): 111 | v = self.v_conv(self.drop(v)) 112 | q = self.q_lin(self.drop(q)) 113 | q = tile_2d_over_nd(q, v) 114 | x = self.relu(v + q) 115 | x = self.x_conv(self.drop(x)) 116 | return x 117 | 118 | 119 | def apply_attention(input, attention): 120 | """ Apply any number of attention maps over the input. 121 | The attention map has to have the same size in all dimensions except dim=1. 122 | """ 123 | n, c = input.size()[:2] 124 | glimpses = attention.size(1) 125 | 126 | # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged 127 | input = input.view(n, c, -1) 128 | attention = attention.view(n, glimpses, -1) 129 | s = input.size(2) 130 | 131 | # apply a softmax to each attention map separately 132 | # since softmax only takes 2d inputs, we have to collapse the first two dimensions together 133 | # so that each glimpse is normalized separately 134 | attention = attention.view(n * glimpses, -1) 135 | attention = F.softmax(attention) 136 | 137 | # apply the weighting by creating a new dim to tile both tensors over 138 | target_size = [n, glimpses, c, s] 139 | input = input.view(n, 1, c, s).expand(*target_size) 140 | attention = attention.view(n, glimpses, 1, s).expand(*target_size) 141 | weighted = input * attention 142 | # sum over only the spatial dimension 143 | weighted_mean = weighted.sum(dim=3) 144 | # the shape at this point is (n, glimpses, c, 1) 145 | return weighted_mean.view(n, -1) 146 | 147 | 148 | def tile_2d_over_nd(feature_vector, feature_map): 149 | """ Repeat the same feature vector over all spatial positions of a given feature map. 150 | The feature vector should have the same batch size and number of features as the feature map. 
151 | """ 152 | n, c = feature_vector.size() 153 | spatial_size = feature_map.dim() - 2 154 | tiled = feature_vector.view(n, c, *([1] * spatial_size)).expand_as(feature_map) 155 | return tiled 156 | -------------------------------------------------------------------------------- /preprocess-images.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.backends.cudnn as cudnn 5 | import torch.utils.data 6 | import torchvision.models as models 7 | from tqdm import tqdm 8 | 9 | import config 10 | import data 11 | import utils 12 | from resnet import resnet as caffe_resnet 13 | 14 | 15 | class Net(nn.Module): 16 | def __init__(self): 17 | super(Net, self).__init__() 18 | self.model = caffe_resnet.resnet152(pretrained=True) 19 | 20 | def save_output(module, input, output): 21 | self.buffer = output 22 | self.model.layer4.register_forward_hook(save_output) 23 | 24 | def forward(self, x): 25 | self.model(x) 26 | return self.buffer 27 | 28 | 29 | def create_coco_loader(*paths): 30 | transform = utils.get_transform(config.image_size, config.central_fraction) 31 | datasets = [data.CocoImages(path, transform=transform) for path in paths] 32 | dataset = data.Composite(*datasets) 33 | data_loader = torch.utils.data.DataLoader( 34 | dataset, 35 | batch_size=config.preprocess_batch_size, 36 | num_workers=config.data_workers, 37 | shuffle=False, 38 | pin_memory=True, 39 | ) 40 | return data_loader 41 | 42 | 43 | def main(): 44 | cudnn.benchmark = True 45 | 46 | net = Net().cuda() 47 | net.eval() 48 | 49 | loader = create_coco_loader(config.train_path, config.val_path) 50 | features_shape = ( 51 | len(loader.dataset), 52 | config.output_features, 53 | config.output_size, 54 | config.output_size 55 | ) 56 | 57 | with h5py.File(config.preprocessed_path, libver='latest') as fd: 58 | features = fd.create_dataset('features', shape=features_shape, dtype='float16') 59 | coco_ids = fd.create_dataset('ids', shape=(len(loader.dataset),), dtype='int32') 60 | 61 | i = j = 0 62 | for ids, imgs in tqdm(loader): 63 | imgs = Variable(imgs.cuda(async=True), volatile=True) 64 | out = net(imgs) 65 | 66 | j = i + imgs.size(0) 67 | features[i:j, :, :] = out.data.cpu().numpy().astype('float16') 68 | coco_ids[i:j] = ids.numpy().astype('int32') 69 | i = j 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /preprocess-vocab.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import Counter 3 | import itertools 4 | 5 | import config 6 | import data 7 | import utils 8 | 9 | 10 | def extract_vocab(iterable, top_k=None, start=0): 11 | """ Turns an iterable of list of tokens into a vocabulary. 12 | These tokens could be single answers or word tokens in questions. 
13 | """ 14 | all_tokens = itertools.chain.from_iterable(iterable) 15 | counter = Counter(all_tokens) 16 | if top_k: 17 | most_common = counter.most_common(top_k) 18 | most_common = (t for t, c in most_common) 19 | else: 20 | most_common = counter.keys() 21 | # descending in count, then lexicographical order 22 | tokens = sorted(most_common, key=lambda x: (counter[x], x), reverse=True) 23 | vocab = {t: i for i, t in enumerate(tokens, start=start)} 24 | return vocab 25 | 26 | 27 | def main(): 28 | questions = utils.path_for(train=True, question=True) 29 | answers = utils.path_for(train=True, answer=True) 30 | 31 | with open(questions, 'r') as fd: 32 | questions = json.load(fd) 33 | with open(answers, 'r') as fd: 34 | answers = json.load(fd) 35 | 36 | questions = data.prepare_questions(questions) 37 | answers = data.prepare_answers(answers) 38 | 39 | question_vocab = extract_vocab(questions, start=1) 40 | answer_vocab = extract_vocab(answers, top_k=config.max_answers) 41 | 42 | vocabs = { 43 | 'question': question_vocab, 44 | 'answer': answer_vocab, 45 | } 46 | with open(config.vocabulary_path, 'w') as fd: 47 | json.dump(vocabs, fd) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import math 4 | import json 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | import torch.backends.cudnn as cudnn 11 | from tqdm import tqdm 12 | 13 | import config 14 | import data 15 | #import model 16 | from dan import TextEncoder, rDAN 17 | import utils 18 | 19 | 20 | def update_learning_rate(optimizer, iteration): 21 | lr = config.initial_lr * 0.5**(float(iteration) / config.lr_halflife) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | 25 | total_iterations = 0 26 | 27 | def run(net, loader, optimizer, tracker, train=False, prefix='', epoch=0): 28 | """ Run an epoch over the given loader """ 29 | if train: 30 | net.train() 31 | tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99} 32 | else: 33 | net.eval() 34 | tracker_class, tracker_params = tracker.MeanMonitor, {} 35 | answ = [] 36 | idxs = [] 37 | accs = [] 38 | 39 | tq = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0) 40 | loss_tracker = tracker.track('{}_loss'.format(prefix), tracker_class(**tracker_params)) 41 | acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 42 | 43 | log_softmax = nn.LogSoftmax().cuda() 44 | for v, q, a, idx, q_len in tq: 45 | var_params = { 46 | 'volatile': not train, 47 | 'requires_grad': False, 48 | } 49 | v = Variable(v.cuda(async=True), **var_params) 50 | q = Variable(q.cuda(async=True), **var_params) 51 | a = Variable(a.cuda(async=True), **var_params) 52 | q_len = Variable(q_len.cuda(async=False), **var_params) 53 | 54 | out = net(v, q) 55 | nll = -log_softmax(out) 56 | loss = (nll * a / 10).sum(dim=1).mean() 57 | acc = utils.batch_accuracy(out.data, a.data).cpu() 58 | 59 | if train: 60 | global total_iterations 61 | update_learning_rate(optimizer, total_iterations) 62 | 63 | optimizer.zero_grad() 64 | loss.backward() 65 | optimizer.step() 66 | 67 | total_iterations += 1 68 | else: 69 | # store information about evaluation of this minibatch 70 | _, answer = out.data.cpu().max(dim=1) 71 | answ.append(answer.view(-1)) 72 | 
accs.append(acc.view(-1)) 73 | idxs.append(idx.view(-1).clone()) 74 | 75 | loss_tracker.append(loss.data[0]) 76 | acc_tracker.append(acc.mean()) 77 | fmt = '{:.4f}'.format 78 | tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 79 | if not os.path.exists('model_'+str(config.run_number)): 80 | os.mkdir('model_'+str(config.run_number)) 81 | torch.save(net.state_dict(),'model_' + str(config.run_number)+'/model_path.' + str(config.run_number) + '_' +str(epoch)+ '.pkl') 82 | 83 | if not train: 84 | answ = list(torch.cat(answ, dim=0)) 85 | accs = list(torch.cat(accs, dim=0)) 86 | idxs = list(torch.cat(idxs, dim=0)) 87 | return answ, accs, idxs 88 | 89 | 90 | def main(): 91 | if len(sys.argv) > 1: 92 | name = ' '.join(sys.argv[1:]) 93 | else: 94 | from datetime import datetime 95 | name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 96 | target_name = os.path.join('logs', '{}.pth'.format(name)) 97 | print('will save to {}'.format(target_name)) 98 | 99 | cudnn.benchmark = True 100 | 101 | train_loader = data.get_loader(train=True) 102 | val_loader = data.get_loader(val=True) 103 | 104 | # Build Model 105 | vocab_size = train_loader.dataset.num_tokens 106 | model = rDAN(num_embeddings=vocab_size, 107 | embedding_dim=config.embedding_dim, 108 | hidden_size=config.hidden_size, 109 | answer_size=config.max_answers) 110 | if config.pretrained: 111 | model.textencoder.load_pretrained(train_loader.dataset.vocab['question']) 112 | net = nn.DataParallel(model).cuda() 113 | 114 | optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad]) 115 | 116 | tracker = utils.Tracker() 117 | config_as_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} 118 | 119 | for i in range(config.epochs): 120 | _ = run(net, train_loader, optimizer, tracker, train=True, prefix='train', epoch=i) 121 | r = run(net, val_loader, optimizer, tracker, train=False, prefix='val', epoch=i) 122 | 123 | results = { 124 | 'name': name, 125 | 'tracker': tracker.to_dict(), 126 | 'config': config_as_dict, 127 | 'weights': net.state_dict(), 128 | 'eval': { 129 | 'answers': r[0], 130 | 'accuracies': r[1], 131 | 'idx': r[2], 132 | }, 133 | 'vocab': train_loader.dataset.vocab, 134 | } 135 | torch.save(results, target_name) 136 | 137 | 138 | if __name__ == '__main__': 139 | main() 140 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as transforms 7 | 8 | import config 9 | 10 | 11 | def batch_accuracy(predicted, true): 12 | """ Compute the accuracies for a batch of predictions and answers """ 13 | _, predicted_index = predicted.max(dim=1, keepdim=True) 14 | agreeing = true.gather(dim=1, index=predicted_index) 15 | ''' 16 | Acc needs to be averaged over all 10 choose 9 subsets of human answers. 17 | While we could just use a loop, surely this can be done more efficiently (and indeed, it can). 
18 | There are two cases for the 1 chosen answer to be discarded: 19 | (1) the discarded answer is not the predicted answer => acc stays the same 20 | (2) the discarded answer is the predicted answer => we have to subtract 1 from the number of agreeing answers 21 | 22 | There are (10 - num_agreeing_answers) of case 1 and num_agreeing_answers of case 2, thus 23 | acc = ((10 - agreeing) * min( agreeing / 3, 1) 24 | + agreeing * min((agreeing - 1) / 3, 1)) / 10 25 | 26 | Let's do some more simplification: 27 | if num_agreeing_answers == 0: 28 | acc = 0 since the case 1 min term becomes 0 and case 2 weighting term is 0 29 | if num_agreeing_answers >= 4: 30 | acc = 1 since the min term in both cases is always 1 31 | The only cases left are for 1, 2, and 3 agreeing answers. 32 | In all of those cases, (agreeing - 1) / 3 < agreeing / 3 <= 1, so we can get rid of all the mins. 33 | By moving num_agreeing_answers from both cases outside the sum we get: 34 | acc = agreeing * ((10 - agreeing) + (agreeing - 1)) / 3 / 10 35 | which we can simplify to: 36 | acc = agreeing * 0.3 37 | Finally, we can combine all cases together with: 38 | min(agreeing * 0.3, 1) 39 | ''' 40 | return (agreeing * 0.3).clamp(max=1) 41 | 42 | 43 | def path_for(train=False, val=False, test=False, question=False, answer=False): 44 | assert train + val + test == 1 45 | assert question + answer == 1 46 | assert not (test and answer), 'loading answers from test split not supported' # if you want to eval on test, you need to implement loading of a VQA Dataset without given answers yourself 47 | if train: 48 | split = 'train2014' 49 | elif val: 50 | split = 'val2014' 51 | else: 52 | split = 'test2015' 53 | if question: 54 | fmt = '{0}_{1}_{2}_questions.json' 55 | else: 56 | fmt = '{1}_{2}_annotations.json' 57 | s = fmt.format(config.task, config.dataset, split) 58 | return os.path.join(config.qa_path, s) 59 | 60 | 61 | class Tracker: 62 | """ Keep track of results over time, while having access to monitors to display information about them. """ 63 | def __init__(self): 64 | self.data = {} 65 | 66 | def track(self, name, *monitors): 67 | """ Track a set of results with given monitors under some name (e.g. 'val_acc'). 68 | When appending to the returned list storage, use the monitors to retrieve useful information. 
69 | """ 70 | l = Tracker.ListStorage(monitors) 71 | self.data.setdefault(name, []).append(l) 72 | return l 73 | 74 | def to_dict(self): 75 | # turn list storages into regular lists 76 | return {k: list(map(list, v)) for k, v in self.data.items()} 77 | 78 | 79 | class ListStorage: 80 | """ Storage of data points that updates the given monitors """ 81 | def __init__(self, monitors=[]): 82 | self.data = [] 83 | self.monitors = monitors 84 | for monitor in self.monitors: 85 | setattr(self, monitor.name, monitor) 86 | 87 | def append(self, item): 88 | for monitor in self.monitors: 89 | monitor.update(item) 90 | self.data.append(item) 91 | 92 | def __iter__(self): 93 | return iter(self.data) 94 | 95 | class MeanMonitor: 96 | """ Take the mean over the given values """ 97 | name = 'mean' 98 | 99 | def __init__(self): 100 | self.n = 0 101 | self.total = 0 102 | 103 | def update(self, value): 104 | self.total += value 105 | self.n += 1 106 | 107 | @property 108 | def value(self): 109 | return self.total / self.n 110 | 111 | class MovingMeanMonitor: 112 | """ Take an exponentially moving mean over the given values """ 113 | name = 'mean' 114 | 115 | def __init__(self, momentum=0.9): 116 | self.momentum = momentum 117 | self.first = True 118 | self.value = None 119 | 120 | def update(self, value): 121 | if self.first: 122 | self.value = value 123 | self.first = False 124 | else: 125 | m = self.momentum 126 | self.value = m * self.value + (1 - m) * value 127 | 128 | 129 | def get_transform(target_size, central_fraction=1.0): 130 | return transforms.Compose([ 131 | transforms.Scale(int(target_size / central_fraction)), 132 | transforms.CenterCrop(target_size), 133 | transforms.ToTensor(), 134 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 135 | std=[0.229, 0.224, 0.225]), 136 | ]) 137 | -------------------------------------------------------------------------------- /val_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzuhsial/pytorch-vqa-dan/4d42acecf47798b84daa13a1733700103a418586/val_acc.png -------------------------------------------------------------------------------- /view-log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import matplotlib; matplotlib.use('agg') 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def plot_acc(acc, figname): 8 | plt.figure() 9 | plt.xlabel("Epochs") 10 | plt.ylabel("Accuracy") 11 | plt.axhline(y=0.643, color='r', linestyle='-',label='Paper test acc') 12 | plt.plot(acc,label='Our val acc') 13 | plt.legend() 14 | plt.savefig(figname) 15 | 16 | def main(): 17 | path = sys.argv[1] 18 | results = torch.load(path,map_location={'cuda:0': 'cpu'}) 19 | 20 | val_acc = torch.FloatTensor(results['tracker']['val_acc']) 21 | val_acc = val_acc.mean(dim=1).numpy() 22 | 23 | 24 | plot_acc(val_acc, 'val_acc.png') 25 | 26 | if __name__ == '__main__': 27 | main() 28 | --------------------------------------------------------------------------------