├── LICENSE ├── README.md ├── attention.py ├── base_model.py ├── bc.py ├── classifier.py ├── counting.py ├── data └── flickr30k │ ├── test_ids.pkl │ ├── train_ids.pkl │ └── val_ids.pkl ├── dataset.py ├── evaluate.py ├── fc.py ├── language_model.py ├── main.py ├── misc └── ban_overview.png ├── test.py ├── tools ├── adaptive_detection_features_converter.py ├── compute_softscore.py ├── create_dictionary.py ├── create_embedding.py ├── detection_features_converter.py ├── detection_features_converter_target.py ├── download.sh ├── download_data.sh ├── download_flickr.sh ├── grad_check.py ├── process.sh └── process_flickr.sh ├── train.py ├── train_flickr.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jin-Hwa Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bilinear Attention Networks 2 | 3 | ⚠️ **Regrettably, I cannot perform maintenance due to the loss of the materials. I'm archiving this repository for reference** 4 | 5 | This repository is the implementation of [Bilinear Attention Networks](http://arxiv.org/abs/1805.07932) for the visual question answering and Flickr30k Entities tasks. 6 | 7 | For the visual question answering task, our single model achieved **70.35** and an ensemble of 15 models achieved **71.84** (Test-standard, VQA 2.0). 8 | For the Flickr30k Entities task, our single model achieved **69.88 / 84.39 / 86.40** for Recall@1, 5, and 10, respectively (slightly better than the original paper). 9 | For the detail, please refer to our [technical report](http://arxiv.org/abs/1805.07932). 10 | 11 | This repository is based on and inspired by @hengyuan-hu's [work](https://github.com/hengyuan-hu/bottom-up-attention-vqa). We sincerely thank for their sharing of the codes. 12 | 13 | ![Overview of bilinear attention networks](misc/ban_overview.png) 14 | 15 | ### Updates 16 | 17 | * Bilinear attention networks using `torch.einsum`, backward-compatible. ([12 Mar 2019](https://github.com/jnhwkim/ban-vqa/issues/15#issuecomment-471864594)) 18 | * Now compatible with PyTorch v1.0.1. ([12 Mar 2019](https://github.com/jnhwkim/ban-vqa/pull/22)) 19 | 20 | ### Prerequisites 21 | 22 | You may need a machine with 4 GPUs, 64GB memory, and PyTorch v1.0.1 for Python 3. 23 | 24 | 1. 
Install [PyTorch](http://pytorch.org/) with CUDA and Python 3.6. 25 | 2. Install [h5py](http://docs.h5py.org/en/latest/build.html). 26 | 27 | *WARNING: do not use PyTorch v1.0.0 due to [a bug](https://github.com/pytorch/pytorch/issues/15602) which degrades performance.* 28 | 29 | ## VQA 30 | ### Preprocessing 31 | 32 | Our implementation uses the pretrained features from [bottom-up-attention](https://github.com/peteanderson80/bottom-up-attention), the adaptive 10-100 features per image, together with the GloVe vectors. For simplicity, the scripts below take care of downloading them. 33 | 34 | All data should be downloaded to a `data/` directory in the root directory of this repository. 35 | 36 | The easiest way to download the data is to run the provided script `tools/download.sh` from the repository root. If the script does not work, examine it and adapt the steps outlined in it to your needs. Then run `tools/process.sh` from the repository root to process the data into the correct format. 37 | 38 | For now, you need to manually download the following resources (used in our best single model). 39 | 40 | We use a part of the Visual Genome dataset for data augmentation. The [image meta data](https://visualgenome.org/static/data/dataset/image_data.json.zip) and the [question answers](https://visualgenome.org/static/data/dataset/question_answers.json.zip) of Version 1.2 need to be placed in `data/`. 41 | 42 | We use MS COCO captions, along with the questions of VQA 2.0 and Visual Genome, to extract semantically connected words for the extended word embeddings. You can download them [here](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). Since the contribution of these captions is minor, you can skip processing them by removing the `cap` element from the `target` option in this [line](https://github.com/jnhwkim/ban-vqa/blob/master/dataset.py#L393). 43 | 44 | The counting module ([Zhang et al., 2018](https://openreview.net/forum?id=B12Js_yRb)) is integrated into this repository as `counting.py` for your convenience. The source repository can be found at @Cyanogenoid's [vqa-counting](https://github.com/Cyanogenoid/vqa-counting). 45 | 46 | ### Training 47 | 48 | ``` 49 | $ python3 main.py --use_both True --use_vg True 50 | ``` 51 | to start training (the options enable training on both the train/val splits and on Visual Genome, respectively). The training and validation scores will be printed every epoch, and the best model will be saved under the directory "saved_models". The default hyperparameters should give you the best single-model result, which is around **70.04** for the test-dev split. 52 | 53 | ### Validation 54 | 55 | If you trained a model with the training split using 56 | ``` 57 | $ python3 main.py 58 | ``` 59 | then you can run `evaluate.py` with appropriate options to evaluate its score on the validation split. 60 | 61 | ### Pretrained model 62 | 63 | We provide the pretrained model reported as the best single model in the paper (70.04 for test-dev, 70.35 for test-standard). 64 | 65 | Please download the model from this [link](https://drive.google.com/uc?export=download&id=1OGYxF5WY4uYc_6UobDjhrJIHkl2UGNct) and move it to `saved_models/ban/model_epoch12.pth` (you may encounter a redirection page asking you to confirm). The training log can be found [here](https://drive.google.com/uc?export=download&id=1sEa5bTMOFv_Xjo_A0xeNw379_Sljg9R_).
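If you want to sanity-check the downloaded snapshot before running the test script, the minimal sketch below prints its contents. It is an illustration only, assuming the snapshot format written by this repository's training loop (a dict with `model_state`, `optimizer_state`, and `epoch` keys, which is also what the loaders in `evaluate.py` and `main.py` expect).

```
import torch

# load the checkpoint on CPU just to inspect it
model_data = torch.load('saved_models/ban/model_epoch12.pth', map_location='cpu')
print(list(model_data.keys()))       # expected keys: 'epoch', 'model_state', 'optimizer_state'
print('epoch:', model_data.get('epoch'))

# evaluate.py restores the weights into a DataParallel-wrapped model like this:
#   model = nn.DataParallel(model).cuda()
#   model.load_state_dict(model_data.get('model_state', model_data))
```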
66 | To generate the result file with this checkpoint, run 67 | ``` 68 | $ python3 test.py --label mytest 69 | ``` 70 | 71 | The result JSON file will be written to the `results/` directory. 72 | 73 | ### Without Visual Genome augmentation 74 | 75 | Without the Visual Genome augmentation, we get **69.50** (the average of 8 models, with a standard deviation of **0.096**) for the test-dev split. We use the 8-glimpse model with a learning rate starting at **0.001** (please see this change for better results), 13 epochs, and a batch size of 256. 76 | 77 | ## Flickr30k Entities 78 | ### Preprocessing 79 | You have to manually download the [Annotation and Sentence](https://github.com/BryanPlummer/flickr30k_entities/blob/master/annotations.zip) files to `data/flickr30k/Flickr30kEntities.tar.gz`. Then run the provided scripts `tools/download_flickr.sh` and `tools/process_flickr.sh` from the root of this repository, similarly to the VQA case. Note that the image features of Flickr30k were generated using the [bottom-up-attention pretrained model](https://github.com/peteanderson80/bottom-up-attention.git). 80 | 81 | ### Training 82 | 83 | ``` 84 | $ python3 main.py --task flickr --out saved_models/flickr 85 | ``` 86 | to start training. The `--gamma` option is not applied here. The default hyperparameters should give you approximately **69.6** Recall@1 on the test split. 87 | 88 | 89 | ### Validation 90 | Please download the model from this [link](https://drive.google.com/uc?export=download&id=1xiVVRPsbabipyHes25iE0uj2YkdKWv3K) and move it to `saved_models/flickr/model_epoch5.pth` (you may encounter a redirection page asking you to confirm). 91 | 92 | ``` 93 | $ python3 evaluate.py --task flickr --input saved_models/flickr --epoch 5 94 | ``` 95 | to evaluate the scores on the test split. 96 | 97 | 98 | 99 | 100 | ### Troubleshooting 101 | 102 | Please check the [troubleshooting wiki](https://github.com/jnhwkim/ban-vqa/wiki/Troubleshooting) and the [previous issue history](https://github.com/jnhwkim/ban-vqa/issues?utf8=✓&q=is%3Aissue). 103 | 104 | ### Citation 105 | 106 | If you use this code as part of any published research, we'd really appreciate it if you could cite the following paper: 107 | 108 | ``` 109 | @inproceedings{Kim2018, 110 | author = {Kim, Jin-Hwa and Jun, Jaehyun and Zhang, Byoung-Tak}, 111 | booktitle = {Advances in Neural Information Processing Systems 31}, 112 | title = {{Bilinear Attention Networks}}, 113 | pages = {1571--1581}, 114 | year = {2018} 115 | } 116 | ``` 117 | 118 | ### License 119 | 120 | MIT License 121 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bilinear Attention Networks 3 | Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang 4 | https://arxiv.org/abs/1805.07932 5 | 6 | This code is written by Jin-Hwa Kim.
7 | """ 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.utils.weight_norm import weight_norm 11 | from fc import FCNet 12 | from bc import BCNet 13 | 14 | 15 | class BiAttention(nn.Module): 16 | def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2,.5]): 17 | super(BiAttention, self).__init__() 18 | 19 | self.glimpse = glimpse 20 | self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3), \ 21 | name='h_mat', dim=None) 22 | 23 | def forward(self, v, q, v_mask=True): 24 | """ 25 | v: [batch, k, vdim] 26 | q: [batch, qdim] 27 | """ 28 | p, logits = self.forward_all(v, q, v_mask) 29 | return p, logits 30 | 31 | def forward_all(self, v, q, v_mask=True, logit=False, mask_with=-float('inf')): 32 | v_num = v.size(1) 33 | q_num = q.size(1) 34 | logits = self.logits(v,q) # b x g x v x q 35 | 36 | if v_mask: 37 | mask = (0 == v.abs().sum(2)).unsqueeze(1).unsqueeze(3).expand(logits.size()) 38 | logits.data.masked_fill_(mask.data, mask_with) 39 | 40 | if not logit: 41 | p = nn.functional.softmax(logits.view(-1, self.glimpse, v_num * q_num), 2) 42 | return p.view(-1, self.glimpse, v_num, q_num), logits 43 | 44 | return logits 45 | -------------------------------------------------------------------------------- /base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bilinear Attention Networks 3 | Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang 4 | https://arxiv.org/abs/1805.07932 5 | 6 | This code is written by Jin-Hwa Kim. 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.utils.weight_norm import weight_norm 12 | import utils 13 | from attention import BiAttention 14 | from language_model import WordEmbedding, QuestionEmbedding 15 | from classifier import SimpleClassifier 16 | from fc import FCNet 17 | from bc import BCNet 18 | from counting import Counter 19 | 20 | 21 | class BanModel(nn.Module): 22 | def __init__(self, dataset, w_emb, q_emb, v_att, b_net, q_prj, c_prj, classifier, counter, op, glimpse): 23 | super(BanModel, self).__init__() 24 | self.dataset = dataset 25 | self.op = op 26 | self.glimpse = glimpse 27 | self.w_emb = w_emb 28 | self.q_emb = q_emb 29 | self.v_att = v_att 30 | self.b_net = nn.ModuleList(b_net) 31 | self.q_prj = nn.ModuleList(q_prj) 32 | self.c_prj = nn.ModuleList(c_prj) 33 | self.classifier = classifier 34 | self.counter = counter 35 | self.drop = nn.Dropout(.5) 36 | self.tanh = nn.Tanh() 37 | 38 | def forward(self, v, b, q, labels): 39 | """Forward 40 | 41 | v: [batch, num_objs, obj_dim] 42 | b: [batch, num_objs, b_dim] 43 | q: [batch_size, seq_length] 44 | 45 | return: logits, not probs 46 | """ 47 | w_emb = self.w_emb(q) 48 | q_emb = self.q_emb.forward_all(w_emb) # [batch, q_len, q_dim] 49 | boxes = b[:,:,:4].transpose(1,2) 50 | 51 | b_emb = [0] * self.glimpse 52 | att, logits = self.v_att.forward_all(v, q_emb) # b x g x v x q 53 | 54 | for g in range(self.glimpse): 55 | b_emb[g] = self.b_net[g].forward_with_weights(v, q_emb, att[:,g,:,:]) # b x l x h 56 | 57 | atten, _ = logits[:,g,:,:].max(2) 58 | embed = self.counter(boxes, atten) 59 | 60 | q_emb = self.q_prj[g](b_emb[g].unsqueeze(1)) + q_emb 61 | q_emb = q_emb + self.c_prj[g](embed).unsqueeze(1) 62 | 63 | logits = self.classifier(q_emb.sum(1)) 64 | 65 | return logits, att 66 | 67 | class BanModel_flickr(nn.Module): 68 | def __init__(self, w_emb, q_emb, v_att, op, glimpse): 69 | super(BanModel_flickr, self).__init__() 70 | self.op = op 71 | self.glimpse = glimpse 
72 | self.w_emb = w_emb 73 | self.q_emb = q_emb 74 | self.v_att = v_att 75 | self.alpha = torch.Tensor([1.]*(glimpse)) 76 | 77 | # features, spatials, sentence, e_pos, target 78 | def forward(self, v, b, q, e, labels): 79 | """Forward 80 | 81 | v: [batch, num_objs, obj_dim] 82 | b: [batch, num_objs, b_dim] 83 | q: [batch, seq_length] 84 | e: [batch, num_entities] 85 | 86 | return: logits, not probs 87 | """ 88 | assert q.size(1) > e.data.max(), 'len(q)=%d > e_pos.max()=%d' % (q.size(1), e.data.max()) 89 | MINUS_INFINITE = -99 90 | if 's' in self.op: 91 | v = torch.cat([v, b], 2) 92 | w_emb = self.w_emb(q) 93 | q_emb = self.q_emb.forward_all(w_emb) # [batch, q_len, q_dim] 94 | # entity positions 95 | q_emb = utils.batched_index_select(q_emb, 1, e) 96 | 97 | att = self.v_att.forward_all(v, q_emb, True, True, MINUS_INFINITE) # b x g x v x q 98 | mask = (e == 0).unsqueeze(1).unsqueeze(2).expand(att.size()) 99 | mask[:, :, :, 0].data.fill_(0) # at least one entity per sentence 100 | att.data.masked_fill_(mask.data, MINUS_INFINITE) 101 | 102 | return None, att 103 | 104 | 105 | def build_ban(dataset, num_hid, op='', gamma=4, task='vqa'): 106 | w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, .0, op) 107 | q_emb = QuestionEmbedding(300 if 'c' not in op else 600, num_hid, 1, False, .0) 108 | v_att = BiAttention(dataset.v_dim, num_hid, num_hid, gamma) 109 | if task == 'vqa': 110 | b_net = [] 111 | q_prj = [] 112 | c_prj = [] 113 | objects = 10 # minimum number of boxes 114 | for i in range(gamma): 115 | b_net.append(BCNet(dataset.v_dim, num_hid, num_hid, None, k=1)) 116 | q_prj.append(FCNet([num_hid, num_hid], '', .2)) 117 | c_prj.append(FCNet([objects + 1, num_hid], 'ReLU', .0)) 118 | classifier = SimpleClassifier( 119 | num_hid, num_hid * 2, dataset.num_ans_candidates, .5) 120 | counter = Counter(objects) 121 | return BanModel(dataset, w_emb, q_emb, v_att, b_net, q_prj, c_prj, classifier, counter, op, gamma) 122 | elif task == 'flickr': 123 | return BanModel_flickr(w_emb, q_emb, v_att, op, gamma) 124 | -------------------------------------------------------------------------------- /bc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bilinear Attention Networks 3 | Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang 4 | https://arxiv.org/abs/1805.07932 5 | 6 | This code is written by Jin-Hwa Kim. 
7 | """ 8 | from __future__ import print_function 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.utils.weight_norm import weight_norm 13 | from fc import FCNet 14 | 15 | 16 | class BCNet(nn.Module): 17 | """Simple class for non-linear bilinear connect network 18 | """ 19 | def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=[.2,.5], k=3): 20 | super(BCNet, self).__init__() 21 | 22 | self.c = 32 23 | self.k = k 24 | self.v_dim = v_dim; self.q_dim = q_dim 25 | self.h_dim = h_dim; self.h_out = h_out 26 | 27 | self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0]) 28 | self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0]) 29 | self.dropout = nn.Dropout(dropout[1]) # attention 30 | if 1 < k: 31 | self.p_net = nn.AvgPool1d(self.k, stride=self.k) 32 | 33 | if None == h_out: 34 | pass 35 | elif h_out <= self.c: 36 | self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()) 37 | self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_()) 38 | else: 39 | self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None) 40 | 41 | def forward(self, v, q): 42 | if None == self.h_out: 43 | v_ = self.v_net(v) 44 | q_ = self.q_net(q) 45 | logits = torch.einsum('bvk,bqk->bvqk', (v_, q_)) 46 | return logits 47 | 48 | # low-rank bilinear pooling using einsum 49 | elif self.h_out <= self.c: 50 | v_ = self.dropout(self.v_net(v)) 51 | q_ = self.q_net(q) 52 | logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias 53 | return logits # b x h_out x v x q 54 | 55 | # batch outer product, linear projection 56 | # memory efficient but slow computation 57 | else: 58 | v_ = self.dropout(self.v_net(v)).transpose(1,2).unsqueeze(3) 59 | q_ = self.q_net(q).transpose(1,2).unsqueeze(2) 60 | d_ = torch.matmul(v_, q_) # b x h_dim x v x q 61 | logits = self.h_net(d_.transpose(1,2).transpose(2,3)) # b x v x q x h_out 62 | return logits.transpose(2,3).transpose(1,2) # b x h_out x v x q 63 | 64 | def forward_with_weights(self, v, q, w): 65 | v_ = self.v_net(v) # b x v x d 66 | q_ = self.q_net(q) # b x q x d 67 | logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_)) 68 | if 1 < self.k: 69 | logits = logits.unsqueeze(1) # b x 1 x d 70 | logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling 71 | return logits 72 | 73 | 74 | if __name__=='__main__': 75 | net = BCNet(1024,1024,1024,1024).cuda() 76 | x = torch.Tensor(512,36,1024).cuda() 77 | y = torch.Tensor(512,14,1024).cuda() 78 | out = net.forward(x,y) 79 | -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | from torch.nn.utils.weight_norm import weight_norm 9 | 10 | 11 | class SimpleClassifier(nn.Module): 12 | def __init__(self, in_dim, hid_dim, out_dim, dropout): 13 | super(SimpleClassifier, self).__init__() 14 | layers = [ 15 | weight_norm(nn.Linear(in_dim, hid_dim), dim=None), 16 | nn.ReLU(), 17 | nn.Dropout(dropout, inplace=True), 18 | weight_norm(nn.Linear(hid_dim, out_dim), dim=None) 19 | ] 20 | self.main = nn.Sequential(*layers) 21 | 22 | def forward(self, x): 23 | logits = self.main(x) 24 | return logits 25 | -------------------------------------------------------------------------------- /counting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Learning to Count Objects in Natural Images for Visual Question Answering 3 | Yan Zhang, Jonathon Hare, Adam Prügel-Bennett 4 | ICLR 2018 5 | 6 | This code is from Yan Zhang's repository. 7 | https://github.com/Cyanogenoid/vqa-counting/blob/master/vqa-v2/counting.py 8 | MIT License 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class Counter(nn.Module): 15 | """ Counting module as proposed in [1]. 16 | Count the number of objects from a set of bounding boxes and a set of scores for each bounding box. 17 | This produces (self.objects + 1) number of count features. 18 | 19 | [1]: Yan Zhang, Jonathon Hare, Adam Prügel-Bennett: Learning to Count Objects in Natural Images for Visual Question Answering. 20 | https://openreview.net/forum?id=B12Js_yRb 21 | """ 22 | def __init__(self, objects, already_sigmoided=False): 23 | super().__init__() 24 | self.objects = objects 25 | self.already_sigmoided = already_sigmoided 26 | self.f = nn.ModuleList([PiecewiseLin(16) for _ in range(16)]) 27 | 28 | def forward(self, boxes, attention): 29 | """ Forward propagation of attention weights and bounding boxes to produce count features. 30 | `boxes` has to be a tensor of shape (n, 4, m) with the 4 channels containing the x and y coordinates of the top left corner and the x and y coordinates of the bottom right corner in this order. 31 | `attention` has to be a tensor of shape (n, m). Each value should be in [0, 1] if already_sigmoided is set to True, but there are no restrictions if already_sigmoided is set to False. This value should be close to 1 if the corresponding boundign box is relevant and close to 0 if it is not. 32 | n is the batch size, m is the number of bounding boxes per image. 
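For example, with objects=10 (the value used in base_model.py), the returned tensor has shape (n, 11).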
33 | """ 34 | # only care about the highest scoring object proposals 35 | # the ones with low score will have a low impact on the count anyway 36 | boxes, attention = self.filter_most_important(self.objects, boxes, attention) 37 | # normalise the attention weights to be in [0, 1] 38 | if not self.already_sigmoided: 39 | attention = torch.sigmoid(attention) 40 | 41 | relevancy = self.outer_product(attention) 42 | distance = 1 - self.iou(boxes, boxes) 43 | 44 | # intra-object dedup 45 | score = self.f[0](relevancy) * self.f[1](distance) 46 | 47 | # inter-object dedup 48 | dedup_score = self.f[3](relevancy) * self.f[4](distance) 49 | dedup_per_entry, dedup_per_row = self.deduplicate(dedup_score, attention) 50 | score = score / dedup_per_entry 51 | 52 | # aggregate the score 53 | # can skip putting this on the diagonal since we're just summing over it anyway 54 | correction = self.f[0](attention * attention) / dedup_per_row 55 | score = score.sum(dim=2).sum(dim=1, keepdim=True) + correction.sum(dim=1, keepdim=True) 56 | score = (score + 1e-20).sqrt() 57 | one_hot = self.to_one_hot(score) 58 | 59 | att_conf = (self.f[5](attention) - 0.5).abs() 60 | dist_conf = (self.f[6](distance) - 0.5).abs() 61 | conf = self.f[7](att_conf.mean(dim=1, keepdim=True) + dist_conf.mean(dim=2).mean(dim=1, keepdim=True)) 62 | 63 | return one_hot * conf 64 | 65 | def deduplicate(self, dedup_score, att): 66 | # using outer-diffs 67 | att_diff = self.outer_diff(att) 68 | score_diff = self.outer_diff(dedup_score) 69 | sim = self.f[2](1 - score_diff).prod(dim=1) * self.f[2](1 - att_diff) 70 | # similarity for each row 71 | row_sims = sim.sum(dim=2) 72 | # similarity for each entry 73 | all_sims = self.outer_product(row_sims) 74 | return all_sims, row_sims 75 | 76 | def to_one_hot(self, scores): 77 | """ Turn a bunch of non-negative scalar values into a one-hot encoding. 78 | E.g. with self.objects = 3, 0 -> [1 0 0 0], 2.75 -> [0 0 0.25 0.75]. 
79 | """ 80 | # sanity check, I don't think this ever does anything (it certainly shouldn't) 81 | scores = scores.clamp(min=0, max=self.objects) 82 | # compute only on the support 83 | i = scores.long().data 84 | f = scores.frac() 85 | # target_l is the one-hot if the score is rounded down 86 | # target_r is the one-hot if the score is rounded up 87 | target_l = scores.data.new(i.size(0), self.objects + 1).fill_(0) 88 | target_r = scores.data.new(i.size(0), self.objects + 1).fill_(0) 89 | 90 | target_l.scatter_(dim=1, index=i.clamp(max=self.objects), value=1) 91 | target_r.scatter_(dim=1, index=(i + 1).clamp(max=self.objects), value=1) 92 | # interpolate between these with the fractional part of the score 93 | return (1 - f) * target_l + f * target_r 94 | 95 | def filter_most_important(self, n, boxes, attention): 96 | """ Only keep top-n object proposals, scored by attention weight """ 97 | attention, idx = attention.topk(n, dim=1, sorted=False) 98 | idx = idx.unsqueeze(dim=1).expand(boxes.size(0), boxes.size(1), idx.size(1)) 99 | boxes = boxes.gather(2, idx) 100 | return boxes, attention 101 | 102 | def outer(self, x): 103 | size = tuple(x.size()) + (x.size()[-1],) 104 | a = x.unsqueeze(dim=-1).expand(*size) 105 | b = x.unsqueeze(dim=-2).expand(*size) 106 | return a, b 107 | 108 | def outer_product(self, x): 109 | # Y_ij = x_i * x_j 110 | a, b = self.outer(x) 111 | return a * b 112 | 113 | def outer_diff(self, x): 114 | # like outer products, except taking the absolute difference instead 115 | # Y_ij = | x_i - x_j | 116 | a, b = self.outer(x) 117 | return (a - b).abs() 118 | 119 | def iou(self, a, b): 120 | # this is just the usual way to IoU from bounding boxes 121 | inter = self.intersection(a, b) 122 | area_a = self.area(a).unsqueeze(2).expand_as(inter) 123 | area_b = self.area(b).unsqueeze(1).expand_as(inter) 124 | return inter / (area_a + area_b - inter + 1e-12) 125 | 126 | def area(self, box): 127 | x = (box[:, 2, :] - box[:, 0, :]).clamp(min=0) 128 | y = (box[:, 3, :] - box[:, 1, :]).clamp(min=0) 129 | return x * y 130 | 131 | def intersection(self, a, b): 132 | size = (a.size(0), 2, a.size(2), b.size(2)) 133 | min_point = torch.max( 134 | a[:, :2, :].unsqueeze(dim=3).expand(*size), 135 | b[:, :2, :].unsqueeze(dim=2).expand(*size), 136 | ) 137 | max_point = torch.min( 138 | a[:, 2:, :].unsqueeze(dim=3).expand(*size), 139 | b[:, 2:, :].unsqueeze(dim=2).expand(*size), 140 | ) 141 | inter = (max_point - min_point).clamp(min=0) 142 | area = inter[:, 0, :, :] * inter[:, 1, :, :] 143 | return area 144 | 145 | 146 | class PiecewiseLin(nn.Module): 147 | def __init__(self, n): 148 | super().__init__() 149 | self.n = n 150 | self.weight = nn.Parameter(torch.ones(n + 1)) 151 | # the first weight here is always 0 with a 0 gradient 152 | self.weight.data[0] = 0 153 | 154 | def forward(self, x): 155 | # all weights are positive -> function is monotonically increasing 156 | w = self.weight.abs() 157 | # make weights sum to one -> f(1) = 1 158 | w = w / w.sum() 159 | w = w.view([self.n + 1] + [1] * x.dim()) 160 | # keep cumulative sum for O(1) time complexity 161 | csum = w.cumsum(dim=0) 162 | csum = csum.expand((self.n + 1,) + tuple(x.size())) 163 | w = w.expand_as(csum) 164 | 165 | # figure out which part of the function the input lies on 166 | y = self.n * x.unsqueeze(0) 167 | idx = y.long().data 168 | f = y.frac() 169 | 170 | # contribution of the linear parts left of the input 171 | x = csum.gather(0, idx.clamp(max=self.n)) 172 | # contribution within the linear segment the input falls into 
173 | x = x + f * w.gather(0, (idx + 1).clamp(max=self.n)) 174 | return x.squeeze(0) 175 | -------------------------------------------------------------------------------- /data/flickr30k/test_ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jnhwkim/ban-vqa/ef4cc36d63f6b31344301ac1bd91afd807d132e5/data/flickr30k/test_ids.pkl -------------------------------------------------------------------------------- /data/flickr30k/train_ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jnhwkim/ban-vqa/ef4cc36d63f6b31344301ac1bd91afd807d132e5/data/flickr30k/train_ids.pkl -------------------------------------------------------------------------------- /data/flickr30k/val_ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jnhwkim/ban-vqa/ef4cc36d63f6b31344301ac1bd91afd807d132e5/data/flickr30k/val_ids.pkl -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import json 8 | import _pickle as cPickle 9 | import numpy as np 10 | import utils 11 | import warnings 12 | with warnings.catch_warnings(): 13 | warnings.filterwarnings("ignore",category=FutureWarning) 14 | import h5py 15 | from xml.etree.ElementTree import parse 16 | import torch 17 | from torch.utils.data import Dataset 18 | import tools.compute_softscore 19 | import itertools 20 | import re 21 | 22 | COUNTING_ONLY = False 23 | 24 | # Following Trott et al. 
(ICLR 2018) 25 | # Interpretable Counting for Visual Question Answering 26 | def is_howmany(q, a, label2ans): 27 | if 'how many' in q.lower() or \ 28 | ('number of' in q.lower() and 'number of the' not in q.lower()) or \ 29 | 'amount of' in q.lower() or \ 30 | 'count of' in q.lower(): 31 | if a is None or answer_filter(a, label2ans): 32 | return True 33 | else: 34 | return False 35 | else: 36 | return False 37 | 38 | 39 | def answer_filter(answers, label2ans, max_num=10): 40 | for ans in answers['labels']: 41 | if label2ans[ans].isdigit() and max_num >= int(label2ans[ans]): 42 | return True 43 | return False 44 | 45 | 46 | class Dictionary(object): 47 | def __init__(self, word2idx=None, idx2word=None): 48 | if word2idx is None: 49 | word2idx = {} 50 | if idx2word is None: 51 | idx2word = [] 52 | self.word2idx = word2idx 53 | self.idx2word = idx2word 54 | 55 | @property 56 | def ntoken(self): 57 | return len(self.word2idx) 58 | 59 | @property 60 | def padding_idx(self): 61 | return len(self.word2idx) 62 | 63 | def tokenize(self, sentence, add_word): 64 | sentence = sentence.lower() 65 | sentence = sentence.replace(',', '').replace('?', '').replace('\'s', ' \'s') 66 | words = sentence.split() 67 | tokens = [] 68 | if add_word: 69 | for w in words: 70 | tokens.append(self.add_word(w)) 71 | else: 72 | for w in words: 73 | # the least frequent word (`bebe`) as UNK for Visual Genome dataset 74 | tokens.append(self.word2idx.get(w, self.padding_idx-1)) 75 | return tokens 76 | 77 | def dump_to_file(self, path): 78 | cPickle.dump([self.word2idx, self.idx2word], open(path, 'wb')) 79 | print('dictionary dumped to %s' % path) 80 | 81 | @classmethod 82 | def load_from_file(cls, path): 83 | print('loading dictionary from %s' % path) 84 | word2idx, idx2word = cPickle.load(open(path, 'rb')) 85 | d = cls(word2idx, idx2word) 86 | return d 87 | 88 | def add_word(self, word): 89 | if word not in self.word2idx: 90 | self.idx2word.append(word) 91 | self.word2idx[word] = len(self.idx2word) - 1 92 | return self.word2idx[word] 93 | 94 | def __len__(self): 95 | return len(self.idx2word) 96 | 97 | 98 | def _create_entry(img, question, answer): 99 | if None!=answer: 100 | answer.pop('image_id') 101 | answer.pop('question_id') 102 | entry = { 103 | 'question_id' : question['question_id'], 104 | 'image_id' : question['image_id'], 105 | 'image' : img, 106 | 'question' : question['question'], 107 | 'answer' : answer} 108 | return entry 109 | 110 | 111 | def _load_dataset(dataroot, name, img_id2val, label2ans): 112 | """Load entries 113 | 114 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 115 | dataroot: root path of dataset 116 | name: 'train', 'val', 'test-dev2015', test2015' 117 | """ 118 | question_path = os.path.join( 119 | dataroot, 'v2_OpenEnded_mscoco_%s_questions.json' % \ 120 | (name + '2014' if 'test'!=name[:4] else name)) 121 | questions = sorted(json.load(open(question_path))['questions'], 122 | key=lambda x: x['question_id']) 123 | if 'test'!=name[:4]: # train, val 124 | answer_path = os.path.join(dataroot, 'cache', '%s_target.pkl' % name) 125 | answers = cPickle.load(open(answer_path, 'rb')) 126 | answers = sorted(answers, key=lambda x: x['question_id']) 127 | 128 | utils.assert_eq(len(questions), len(answers)) 129 | entries = [] 130 | for question, answer in zip(questions, answers): 131 | utils.assert_eq(question['question_id'], answer['question_id']) 132 | utils.assert_eq(question['image_id'], answer['image_id']) 133 | img_id = question['image_id'] 134 | if not 
COUNTING_ONLY or is_howmany(question['question'], answer, label2ans): 135 | entries.append(_create_entry(img_id2val[img_id], question, answer)) 136 | else: # test2015 137 | entries = [] 138 | for question in questions: 139 | img_id = question['image_id'] 140 | if not COUNTING_ONLY or is_howmany(question['question'], None, None): 141 | entries.append(_create_entry(img_id2val[img_id], question, None)) 142 | 143 | return entries 144 | 145 | 146 | def _load_visualgenome(dataroot, name, img_id2val, label2ans, adaptive=True): 147 | """Load entries 148 | 149 | img_id2val: dict {img_id -> val} val can be used to retrieve image or features 150 | dataroot: root path of dataset 151 | name: 'train', 'val' 152 | """ 153 | question_path = os.path.join(dataroot, 'question_answers.json') 154 | image_data_path = os.path.join(dataroot, 'image_data.json') 155 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 156 | cache_path = os.path.join(dataroot, 'cache', 'vg_%s%s_target.pkl' % (name, '_adaptive' if adaptive else '')) 157 | 158 | if os.path.isfile(cache_path): 159 | entries = cPickle.load(open(cache_path, 'rb')) 160 | else: 161 | entries = [] 162 | ans2label = cPickle.load(open(ans2label_path, 'rb')) 163 | vgq = json.load(open(question_path, 'r')) 164 | _vgv = json.load(open(image_data_path, 'r')) #108,077 165 | vgv = {} 166 | for _v in _vgv: 167 | if None != _v['coco_id']: 168 | vgv[_v['image_id']] = _v['coco_id'] 169 | counts = [0, 0, 0, 0] # used image, used question, total question, out-of-split 170 | for vg in vgq: 171 | coco_id = vgv.get(vg['id'], None) 172 | if None != coco_id: 173 | counts[0] += 1 174 | img_idx = img_id2val.get(coco_id, None) 175 | if None == img_idx: 176 | counts[3] += 1 177 | for q in vg['qas']: 178 | counts[2] += 1 179 | _answer = tools.compute_softscore.preprocess_answer(q['answer']) 180 | label = ans2label.get(_answer, None) 181 | if None != label and None != img_idx: 182 | counts[1] += 1 183 | answer = { 184 | 'labels': [label], 185 | 'scores': [1.]} 186 | entry = { 187 | 'question_id' : q['id'], 188 | 'image_id' : coco_id, 189 | 'image' : img_idx, 190 | 'question' : q['question'], 191 | 'answer' : answer} 192 | if not COUNTING_ONLY or is_howmany(q['question'], answer, label2ans): 193 | entries.append(entry) 194 | 195 | print('Loading VisualGenome %s' % name) 196 | print('\tUsed COCO images: %d/%d (%.4f)' % \ 197 | (counts[0], len(_vgv), counts[0]/len(_vgv))) 198 | print('\tOut-of-split COCO images: %d/%d (%.4f)' % \ 199 | (counts[3], counts[0], counts[3]/counts[0])) 200 | print('\tUsed VG questions: %d/%d (%.4f)' % \ 201 | (counts[1], counts[2], counts[1]/counts[2])) 202 | with open(cache_path, 'wb') as f: 203 | cPickle.dump(entries, open(cache_path, 'wb')) 204 | 205 | return entries 206 | 207 | 208 | def _find_coco_id(vgv, vgv_id): 209 | for v in vgv: 210 | if v['id']==vgv_id: 211 | return v['coco_id'] 212 | return None 213 | 214 | 215 | def _load_flickr30k(dataroot, img_id2idx, bbox, pos_boxes): 216 | """Load entries 217 | 218 | img_id2idx: dict {img_id -> val} val can be used to retrieve image or features 219 | dataroot: root path of dataset 220 | name: 'train', 'val', 'test-dev2015', test2015' 221 | """ 222 | pattern_phrase = r'\[(.*?)\]' 223 | pattern_no = r'\/EN\#(\d+)' 224 | 225 | missing_entity_count = dict() 226 | multibox_entity_count = 0 227 | 228 | entries = [] 229 | for image_id, idx in img_id2idx.items(): 230 | 231 | phrase_file = os.path.join(dataroot, 'Flickr30kEntities/Sentences/%d.txt' % image_id) 232 | anno_file = 
os.path.join(dataroot, 'Flickr30kEntities/Annotations/%d.xml' % image_id) 233 | 234 | with open(phrase_file, 'r', encoding='utf-8') as f: 235 | sents = [x.strip() for x in f] 236 | 237 | # Parse Annotation 238 | root = parse(anno_file).getroot() 239 | obj_elems = root.findall('./object') 240 | pos_box = pos_boxes[idx] 241 | bboxes = bbox[pos_box[0]:pos_box[1]] 242 | target_bboxes = {} 243 | 244 | for elem in obj_elems: 245 | if elem.find('bndbox') == None or len(elem.find('bndbox')) == 0: 246 | continue 247 | left = int(elem.findtext('./bndbox/xmin')) 248 | top = int(elem.findtext('./bndbox/ymin')) 249 | right = int(elem.findtext('./bndbox/xmax')) 250 | bottom = int(elem.findtext('./bndbox/ymax')) 251 | assert 0 < left and 0 < top 252 | 253 | for name in elem.findall('name'): 254 | entity_id = int(name.text) 255 | assert 0 < entity_id 256 | if not entity_id in target_bboxes: 257 | target_bboxes[entity_id] = [] 258 | else: 259 | multibox_entity_count += 1 260 | target_bboxes[entity_id].append([left, top, right, bottom]) 261 | 262 | # Parse Sentence 263 | for sent_id, sent in enumerate(sents): 264 | sentence = utils.remove_annotations(sent) 265 | entities = re.findall(pattern_phrase, sent) 266 | entity_indices = [] 267 | target_indices = [] 268 | entity_ids = [] 269 | entity_types = [] 270 | 271 | for entity_i, entity in enumerate(entities): 272 | info, phrase = entity.split(' ', 1) 273 | entity_id = int(re.findall(pattern_no, info)[0]) 274 | entity_type = info.split('/')[2:] 275 | 276 | entity_idx = utils.find_sublist(sentence.split(' '), phrase.split(' ')) 277 | assert 0 <= entity_idx 278 | 279 | if not entity_id in target_bboxes: 280 | if entity_id >= 0: 281 | missing_entity_count[entity_type[0]] = missing_entity_count.get(entity_type[0], 0) + 1 282 | continue 283 | 284 | assert 0 < entity_id 285 | 286 | entity_ids.append(entity_id) 287 | entity_types.append(entity_type) 288 | 289 | target_idx = utils.get_match_index(target_bboxes[entity_id], bboxes) 290 | entity_indices.append(entity_idx) 291 | target_indices.append(target_idx) 292 | 293 | if 0 == len(entity_ids): 294 | continue 295 | 296 | entries.append( 297 | _create_flickr_entry(idx, sentence, entity_indices, target_indices, entity_ids, entity_types)) 298 | 299 | if 0 < len(missing_entity_count.keys()): 300 | print('missing_entity_count=') 301 | print(missing_entity_count) 302 | print('multibox_entity_count=%d' % multibox_entity_count) 303 | 304 | return entries 305 | 306 | 307 | # idx, sentence, entity_indices, target_indices, entity_ids, entity_types 308 | def _create_flickr_entry(img, sentence, entity_indices, target_indices, entity_ids, entity_types): 309 | type_map = {'people':0,'clothing':1,'bodyparts':2,'animals':3,'vehicles':4,'instruments':5,'scene':6,'other':7} 310 | MAX_TYPE_NUM = 3 311 | for i, entity_type in enumerate(entity_types): 312 | assert MAX_TYPE_NUM >= len(entity_type) 313 | entity_types[i] = list(type_map[x] for x in entity_type) 314 | entity_types[i] += [-1] * (MAX_TYPE_NUM-len(entity_type)) 315 | entry = { 316 | 'image' : img, 317 | 'sentence' : sentence, 318 | 'entity_indices' : entity_indices, 319 | 'target_indices' : target_indices, 320 | 'entity_ids' : entity_ids, 321 | 'entity_types' : entity_types, 322 | 'entity_num' : len(entity_ids)} 323 | return entry 324 | 325 | 326 | 327 | class VQAFeatureDataset(Dataset): 328 | def __init__(self, name, dictionary, dataroot='data', adaptive=False): 329 | super(VQAFeatureDataset, self).__init__() 330 | assert name in ['train', 'val', 'test-dev2015', 'test2015'] 331 
| 332 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 333 | label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl') 334 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 335 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 336 | self.num_ans_candidates = len(self.ans2label) 337 | 338 | self.dictionary = dictionary 339 | self.adaptive = adaptive 340 | 341 | self.img_id2idx = cPickle.load( 342 | open(os.path.join(dataroot, '%s%s_imgid2idx.pkl' % (name, '' if self.adaptive else '36')), 'rb')) 343 | 344 | h5_path = os.path.join(dataroot, '%s%s.hdf5' % (name, '' if self.adaptive else '36')) 345 | 346 | print('loading features from h5 file') 347 | with h5py.File(h5_path, 'r') as hf: 348 | self.features = np.array(hf.get('image_features')) 349 | self.spatials = np.array(hf.get('spatial_features')) 350 | if self.adaptive: 351 | self.pos_boxes = np.array(hf.get('pos_boxes')) 352 | 353 | self.entries = _load_dataset(dataroot, name, self.img_id2idx, self.label2ans) 354 | self.tokenize() 355 | self.tensorize() 356 | self.v_dim = self.features.size(1 if self.adaptive else 2) 357 | self.s_dim = self.spatials.size(1 if self.adaptive else 2) 358 | 359 | def tokenize(self, max_length=14): 360 | """Tokenizes the questions. 361 | 362 | This will add q_token in each entry of the dataset. 363 | -1 represent nil, and should be treated as padding_idx in embedding 364 | """ 365 | for entry in self.entries: 366 | tokens = self.dictionary.tokenize(entry['question'], False) 367 | tokens = tokens[:max_length] 368 | if len(tokens) < max_length: 369 | # Note here we pad in front of the sentence 370 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 371 | tokens = tokens + padding 372 | utils.assert_eq(len(tokens), max_length) 373 | entry['q_token'] = tokens 374 | 375 | def tensorize(self): 376 | self.features = torch.from_numpy(self.features) 377 | self.spatials = torch.from_numpy(self.spatials) 378 | 379 | for entry in self.entries: 380 | question = torch.from_numpy(np.array(entry['q_token'])) 381 | entry['q_token'] = question 382 | 383 | answer = entry['answer'] 384 | if None!=answer: 385 | labels = np.array(answer['labels']) 386 | scores = np.array(answer['scores'], dtype=np.float32) 387 | if len(labels): 388 | labels = torch.from_numpy(labels) 389 | scores = torch.from_numpy(scores) 390 | entry['answer']['labels'] = labels 391 | entry['answer']['scores'] = scores 392 | else: 393 | entry['answer']['labels'] = None 394 | entry['answer']['scores'] = None 395 | 396 | def __getitem__(self, index): 397 | entry = self.entries[index] 398 | if not self.adaptive: 399 | features = self.features[entry['image']] 400 | spatials = self.spatials[entry['image']] 401 | else: 402 | features = self.features[self.pos_boxes[entry['image']][0]:self.pos_boxes[entry['image']][1],:] 403 | spatials = self.spatials[self.pos_boxes[entry['image']][0]:self.pos_boxes[entry['image']][1],:] 404 | 405 | question = entry['q_token'] 406 | question_id = entry['question_id'] 407 | answer = entry['answer'] 408 | if None!=answer: 409 | labels = answer['labels'] 410 | scores = answer['scores'] 411 | target = torch.zeros(self.num_ans_candidates) 412 | if labels is not None: 413 | target.scatter_(0, labels, scores) 414 | return features, spatials, question, target 415 | else: 416 | return features, spatials, question, question_id 417 | 418 | def __len__(self): 419 | return len(self.entries) 420 | 421 | 422 | class VisualGenomeFeatureDataset(Dataset): 423 | def 
__init__(self, name, features, spatials, dictionary, dataroot='data', adaptive=False, pos_boxes=None): 424 | super(VisualGenomeFeatureDataset, self).__init__() 425 | # do not use test split images! 426 | assert name in ['train', 'val'] 427 | 428 | ans2label_path = os.path.join(dataroot, 'cache', 'trainval_ans2label.pkl') 429 | label2ans_path = os.path.join(dataroot, 'cache', 'trainval_label2ans.pkl') 430 | self.ans2label = cPickle.load(open(ans2label_path, 'rb')) 431 | self.label2ans = cPickle.load(open(label2ans_path, 'rb')) 432 | self.num_ans_candidates = len(self.ans2label) 433 | 434 | self.dictionary = dictionary 435 | self.adaptive = adaptive 436 | 437 | self.img_id2idx = cPickle.load( 438 | open(os.path.join(dataroot, '%s%s_imgid2idx.pkl' % (name, '' if self.adaptive else '36')), 'rb')) 439 | 440 | self.features = features 441 | self.spatials = spatials 442 | if self.adaptive: 443 | self.pos_boxes = pos_boxes 444 | 445 | self.entries = _load_visualgenome(dataroot, name, self.img_id2idx, self.label2ans) 446 | self.tokenize() 447 | self.tensorize() 448 | self.v_dim = self.features.size(1 if self.adaptive else 2) 449 | self.s_dim = self.spatials.size(1 if self.adaptive else 2) 450 | 451 | def tokenize(self, max_length=14): 452 | """Tokenizes the questions. 453 | 454 | This will add q_token in each entry of the dataset. 455 | -1 represent nil, and should be treated as padding_idx in embedding 456 | """ 457 | for entry in self.entries: 458 | tokens = self.dictionary.tokenize(entry['question'], False) 459 | tokens = tokens[:max_length] 460 | if len(tokens) < max_length: 461 | # Note here we pad in front of the sentence 462 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 463 | tokens = tokens + padding 464 | utils.assert_eq(len(tokens), max_length) 465 | entry['q_token'] = tokens 466 | 467 | def tensorize(self): 468 | for entry in self.entries: 469 | question = torch.from_numpy(np.array(entry['q_token'])) 470 | entry['q_token'] = question 471 | 472 | answer = entry['answer'] 473 | labels = np.array(answer['labels']) 474 | scores = np.array(answer['scores'], dtype=np.float32) 475 | if len(labels): 476 | labels = torch.from_numpy(labels) 477 | scores = torch.from_numpy(scores) 478 | entry['answer']['labels'] = labels 479 | entry['answer']['scores'] = scores 480 | else: 481 | entry['answer']['labels'] = None 482 | entry['answer']['scores'] = None 483 | 484 | def __getitem__(self, index): 485 | entry = self.entries[index] 486 | if not self.adaptive: 487 | features = self.features[entry['image']] 488 | spatials = self.spatials[entry['image']] 489 | else: 490 | features = self.features[self.pos_boxes[entry['image']][0]:self.pos_boxes[entry['image']][1],:] 491 | spatials = self.spatials[self.pos_boxes[entry['image']][0]:self.pos_boxes[entry['image']][1],:] 492 | 493 | question = entry['q_token'] 494 | question_id = entry['question_id'] 495 | answer = entry['answer'] 496 | labels = answer['labels'] 497 | scores = answer['scores'] 498 | target = torch.zeros(self.num_ans_candidates) 499 | if labels is not None: 500 | target.scatter_(0, labels, scores) 501 | return features, spatials, question, target 502 | 503 | def __len__(self): 504 | return len(self.entries) 505 | 506 | 507 | class Flickr30kFeatureDataset(Dataset): 508 | def __init__(self, name, dictionary, dataroot='data/flickr30k/'): 509 | super(Flickr30kFeatureDataset, self).__init__() 510 | 511 | self.num_ans_candidates = 100 512 | 513 | self.dictionary = dictionary 514 | 515 | self.img_id2idx = cPickle.load( 516 | 
open(os.path.join(dataroot, '%s_imgid2idx.pkl' % name), 'rb')) 517 | 518 | h5_path = os.path.join(dataroot, '%s.hdf5' % name) 519 | 520 | print('loading features from h5 file') 521 | with h5py.File(h5_path, 'r') as hf: 522 | self.features = np.array(hf.get('image_features')) 523 | self.spatials = np.array(hf.get('spatial_features')) 524 | self.bbox = np.array(hf.get('image_bb')) 525 | self.pos_boxes = np.array(hf.get('pos_boxes')) 526 | 527 | self.entries = _load_flickr30k(dataroot, self.img_id2idx, self.bbox, self.pos_boxes) 528 | self.tokenize() 529 | self.tensorize(self.num_ans_candidates) 530 | self.v_dim = self.features.size(1) 531 | self.s_dim = self.spatials.size(1) 532 | 533 | def tokenize(self, max_length=82): 534 | """Tokenizes the questions. 535 | 536 | This will add q_token in each entry of the dataset. 537 | -1 represent nil, and should be treated as padding_idx in embedding 538 | """ 539 | for entry in self.entries: 540 | tokens = self.dictionary.tokenize(entry['sentence'], False) 541 | tokens = tokens[:max_length] 542 | if len(tokens) < max_length: 543 | # Note here we pad in front of the sentence 544 | padding = [self.dictionary.padding_idx] * (max_length - len(tokens)) 545 | tokens = tokens + padding 546 | utils.assert_eq(len(tokens), max_length) 547 | entry['p_token'] = tokens 548 | 549 | def tensorize(self, max_box=100, max_entities=16, max_length=82): 550 | self.features = torch.from_numpy(self.features) 551 | self.spatials = torch.from_numpy(self.spatials) 552 | 553 | for entry in self.entries: 554 | phrase = torch.from_numpy(np.array(entry['p_token'])) 555 | entry['p_token'] = phrase 556 | 557 | assert len(entry['target_indices']) == entry['entity_num'] 558 | assert len(entry['entity_indices']) == entry['entity_num'] 559 | 560 | target_tensors = [] 561 | for i in range(entry['entity_num']): 562 | target_tensor = torch.zeros(1, max_box) 563 | if len(entry['target_indices'][i]) > 0: 564 | target_idx = torch.from_numpy(np.array(entry['target_indices'][i])) 565 | target_tensor = torch.zeros(max_box).scatter_(0, target_idx, 1).unsqueeze(0) 566 | target_tensors.append(target_tensor) 567 | assert len(target_tensors) <= max_entities, '> %d entities!' 
% max_entities 568 | for i in range(max_entities - len(target_tensors)): 569 | target_tensor = torch.zeros(1, max_box) 570 | target_tensors.append(target_tensor) 571 | entry['entity_ids'].append(0) 572 | # padding entity_indices with non-overlapping indices 573 | entry['entity_indices'] += [x for x in range(max_length) if x not in entry['entity_indices']] 574 | entry['entity_indices'] = entry['entity_indices'][:max_entities] 575 | entry['target'] = torch.cat(target_tensors, 0) 576 | # entity positions in (e) tensor 577 | entry['e_pos'] = torch.LongTensor(entry['entity_indices']) 578 | entry['e_num'] = torch.LongTensor([entry['entity_num']]) 579 | entry['entity_ids'] = torch.LongTensor(entry['entity_ids']) 580 | entry['entity_types'] = torch.LongTensor(entry['entity_types']) 581 | 582 | def __getitem__(self, index): 583 | entry = self.entries[index] 584 | features = self.features[self.pos_boxes[entry['image']][0]:self.pos_boxes[entry['image']][1], :] 585 | spatials = self.spatials[self.pos_boxes[entry['image']][0]:self.pos_boxes[entry['image']][1], :] 586 | 587 | sentence = entry['p_token'] 588 | e_pos = entry['e_pos'] 589 | e_num = entry['e_num'] 590 | target = entry['target'] 591 | entity_ids = entry['entity_ids'] 592 | entity_types = entry['entity_types'] 593 | 594 | return features, spatials, sentence, e_pos, e_num, target, entity_ids, entity_types 595 | 596 | def __len__(self): 597 | return len(self.entries) 598 | 599 | 600 | 601 | 602 | def tfidf_from_questions(names, dictionary, dataroot='data', target=['vqa', 'vg', 'cap', 'flickr']): 603 | inds = [[], []] # rows, cols for uncoalesce sparse matrix 604 | df = dict() 605 | N = len(dictionary) 606 | 607 | def populate(inds, df, text): 608 | tokens = dictionary.tokenize(text, True) 609 | for t in tokens: 610 | df[t] = df.get(t, 0) + 1 611 | combin = list(itertools.combinations(tokens, 2)) 612 | for c in combin: 613 | if c[0] < N: 614 | inds[0].append(c[0]); inds[1].append(c[1]) 615 | if c[1] < N: 616 | inds[0].append(c[1]); inds[1].append(c[0]) 617 | 618 | if 'vqa' in target: # VQA 2.0 619 | for name in names: 620 | assert name in ['train', 'val', 'test-dev2015', 'test2015'] 621 | question_path = os.path.join( 622 | dataroot, 'v2_OpenEnded_mscoco_%s_questions.json' % \ 623 | (name + '2014' if 'test'!=name[:4] else name)) 624 | questions = json.load(open(question_path))['questions'] 625 | 626 | for question in questions: 627 | populate(inds, df, question['question']) 628 | 629 | if 'vg' in target: # Visual Genome 630 | question_path = os.path.join(dataroot, 'question_answers.json') 631 | vgq = json.load(open(question_path, 'r')) 632 | for vg in vgq: 633 | for q in vg['qas']: 634 | populate(inds, df, q['question']) 635 | 636 | if 'cap' in target: # MSCOCO Caption 637 | for split in ['train2017', 'val2017']: 638 | captions = json.load(open('data/annotations/captions_%s.json' % split, 'r')) 639 | for caps in captions['annotations']: 640 | populate(inds, df, caps['caption']) 641 | 642 | # TF-IDF 643 | vals = [1] * len(inds[1]) 644 | for idx, col in enumerate(inds[1]): 645 | assert df[col] >= 1, 'document frequency should be greater than zero!' 
646 | vals[col] /= df[col] 647 | 648 | # Make stochastic matrix 649 | def normalize(inds, vals): 650 | z = dict() 651 | for row, val in zip(inds[0], vals): 652 | z[row] = z.get(row, 0) + val 653 | for idx, row in enumerate(inds[0]): 654 | vals[idx] /= z[row] 655 | return vals 656 | 657 | vals = normalize(inds, vals) 658 | 659 | tfidf = torch.sparse.FloatTensor(torch.LongTensor(inds), torch.FloatTensor(vals)) 660 | tfidf = tfidf.coalesce() 661 | 662 | # Latent word embeddings 663 | emb_dim = 300 664 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 665 | weights, word2emb = utils.create_glove_embedding_init(dictionary.idx2word[N:], glove_file) 666 | print('tf-idf stochastic matrix (%d x %d) is generated.' % (tfidf.size(0), tfidf.size(1))) 667 | 668 | return tfidf, weights 669 | 670 | 671 | if __name__=='__main__': 672 | dictionary = Dictionary.load_from_file('data/flickr30k/dictionary.pkl') 673 | tfidf, weights = tfidf_from_questions(['train', 'val', 'test2015'], dictionary) 674 | 675 | if __name__=='__main2__': 676 | from torch.utils.data import DataLoader 677 | 678 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 679 | train_dset = VQAFeatureDataset('val', dictionary, adaptive=True) 680 | # name = 'train' 681 | # eval_dset = VQAFeatureDataset(name, dictionary) 682 | # vg_dset = VisualGenomeFeatureDataset(name, eval_dset.features, eval_dset.spatials, dictionary) 683 | 684 | # train_loader = DataLoader(vg_dset, 10, shuffle=True, num_workers=1) 685 | 686 | loader = DataLoader(train_dset, 10, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 687 | for i, (v, b, q, a) in enumerate(loader): 688 | print(v.size()) 689 | 690 | # VisualGenome Train 691 | # Used COCO images: 51487/108077 (0.4764) 692 | # Out-of-split COCO images: 17464/51487 (0.3392) 693 | # Used VG questions: 325311/726932 (0.4475) 694 | 695 | # VisualGenome Val 696 | # Used COCO images: 51487/108077 (0.4764) 697 | # Out-of-split COCO images: 34023/51487 (0.6608) 698 | # Used VG questions: 166409/726932 (0.2289) 699 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | import argparse 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | 12 | from dataset import Dictionary, VQAFeatureDataset, Flickr30kFeatureDataset 13 | import base_model 14 | import utils 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--task', type=str, default='vqa', help='vqa or flickr') 20 | parser.add_argument('--num_hid', type=int, default=1280) 21 | parser.add_argument('--model', type=str, default='ban') 22 | parser.add_argument('--op', type=str, default='c') 23 | parser.add_argument('--gamma', type=int, default=8) 24 | parser.add_argument('--input', type=str, default='saved_models/ban') 25 | parser.add_argument('--epoch', type=int, default=12) 26 | parser.add_argument('--batch_size', type=int, default=256) 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | if __name__ == '__main__': 32 | print('Evaluate a given model optimized by training split using validation split.') 33 | args = parse_args() 34 | 35 | torch.backends.cudnn.benchmark = True 36 | 37 | if args.task == 'vqa': 38 | from train import evaluate 39 | dict_path = 'data/dictionary.pkl' 40 | dictionary = Dictionary.load_from_file(dict_path) 41 | eval_dset = VQAFeatureDataset('val', dictionary, adaptive=True) 42 | 43 | elif args.task == 'flickr': 44 | from train_flickr import evaluate 45 | dict_path = 'data/flickr30k/dictionary.pkl' 46 | dictionary = Dictionary.load_from_file(dict_path) 47 | eval_dset = Flickr30kFeatureDataset('test', dictionary) 48 | args.op = '' 49 | args.gamma = 1 50 | 51 | n_device = torch.cuda.device_count() 52 | batch_size = args.batch_size * n_device 53 | 54 | constructor = 'build_%s' % args.model 55 | model = getattr(base_model, constructor)(eval_dset, args.num_hid, args.op, args.gamma, args.task).cuda() 56 | model_data = torch.load(args.input+'/model'+('_epoch%d' % args.epoch if 0 < args.epoch else '')+'.pth') 57 | 58 | model = nn.DataParallel(model).cuda() 59 | model.load_state_dict(model_data.get('model_state', model_data)) 60 | 61 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 62 | model.train(False) 63 | 64 | eval_score, bound, entropy = evaluate(model, eval_loader) 65 | if args.task == 'vqa': 66 | print('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound)) 67 | elif args.task == 'flickr': 68 | print('\teval score: %.2f/%.2f/%.2f (%.2f)' % ( 69 | 100 * eval_score[0], 100 * eval_score[1], 100 * eval_score[2], 100 * bound)) 70 | 71 | 72 | -------------------------------------------------------------------------------- /fc.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import torch.nn as nn 7 | from torch.nn.utils.weight_norm import weight_norm 8 | 9 | 10 | class FCNet(nn.Module): 11 | """Simple class for non-linear fully connect network 12 | """ 13 | def __init__(self, dims, act='ReLU', dropout=0): 14 | super(FCNet, self).__init__() 15 | 16 | layers = [] 17 | for i in range(len(dims)-2): 18 | in_dim = dims[i] 19 | out_dim = dims[i+1] 20 | if 0 < dropout: 21 | layers.append(nn.Dropout(dropout)) 22 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 23 | if ''!=act: 24 | layers.append(getattr(nn, act)()) 25 | if 0 < dropout: 26 | layers.append(nn.Dropout(dropout)) 27 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 28 | if ''!=act: 29 | layers.append(getattr(nn, act)()) 30 | 31 | self.main = nn.Sequential(*layers) 32 | 33 | def forward(self, x): 34 | return self.main(x) 35 | 36 | 37 | if __name__ == '__main__': 38 | fc1 = FCNet([10, 20, 10]) 39 | print(fc1) 40 | 41 | print('============') 42 | fc2 = FCNet([10, 20]) 43 | print(fc2) 44 | -------------------------------------------------------------------------------- /language_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | 10 | class WordEmbedding(nn.Module): 11 | """Word Embedding 12 | 13 | The ntoken-th dim is used for padding_idx, which agrees *implicitly* 14 | with the definition in Dictionary. 15 | """ 16 | def __init__(self, ntoken, emb_dim, dropout, op=''): 17 | super(WordEmbedding, self).__init__() 18 | self.op = op 19 | self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 20 | if 'c' in op: 21 | self.emb_ = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken) 22 | self.emb_.weight.requires_grad = False # fixed 23 | self.dropout = nn.Dropout(dropout) 24 | self.ntoken = ntoken 25 | self.emb_dim = emb_dim 26 | 27 | def init_embedding(self, np_file, tfidf=None, tfidf_weights=None): 28 | weight_init = torch.from_numpy(np.load(np_file)) 29 | assert weight_init.shape == (self.ntoken, self.emb_dim) 30 | self.emb.weight.data[:self.ntoken] = weight_init 31 | if tfidf is not None: 32 | if 0 < tfidf_weights.size: 33 | weight_init = torch.cat([weight_init, torch.from_numpy(tfidf_weights)], 0) 34 | weight_init = tfidf.matmul(weight_init) # (N x N') x (N', F) 35 | if 'c' in self.op: 36 | self.emb_.weight.requires_grad = True 37 | if 'c' in self.op: 38 | self.emb_.weight.data[:self.ntoken] = weight_init.clone() 39 | 40 | def forward(self, x): 41 | emb = self.emb(x) 42 | if 'c' in self.op: 43 | emb = torch.cat((emb, self.emb_(x)), 2) 44 | emb = self.dropout(emb) 45 | return emb 46 | 47 | 48 | class QuestionEmbedding(nn.Module): 49 | def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'): 50 | """Module for question embedding 51 | """ 52 | super(QuestionEmbedding, self).__init__() 53 | assert rnn_type == 'LSTM' or rnn_type == 'GRU' 54 | rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU if rnn_type == 'GRU' else None 55 | 56 | self.rnn = rnn_cls( 57 | in_dim, num_hid, nlayers, 58 | bidirectional=bidirect, 59 | dropout=dropout, 60 | batch_first=True) 61 | 62 | self.in_dim = in_dim 63 | self.num_hid = num_hid 64 | self.nlayers = nlayers 65 | self.rnn_type = rnn_type 66 | self.ndirections = 1 
+ int(bidirect) 67 | 68 | def init_hidden(self, batch): 69 | # just to get the type of tensor 70 | weight = next(self.parameters()).data 71 | hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid) 72 | if self.rnn_type == 'LSTM': 73 | return (weight.new(*hid_shape).zero_(), 74 | weight.new(*hid_shape).zero_()) 75 | else: 76 | return weight.new(*hid_shape).zero_() 77 | 78 | def forward(self, x): 79 | # x: [batch, sequence, in_dim] 80 | batch = x.size(0) 81 | hidden = self.init_hidden(batch) 82 | output, hidden = self.rnn(x, hidden) 83 | 84 | if self.ndirections == 1: 85 | return output[:, -1] 86 | 87 | forward_ = output[:, -1, :self.num_hid] 88 | backward = output[:, 0, self.num_hid:] 89 | return torch.cat((forward_, backward), dim=1) 90 | 91 | def forward_all(self, x): 92 | # x: [batch, sequence, in_dim] 93 | batch = x.size(0) 94 | hidden = self.init_hidden(batch) 95 | output, hidden = self.rnn(x, hidden) 96 | return output 97 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | import os 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader, ConcatDataset 10 | import numpy as np 11 | 12 | from dataset import Dictionary, VQAFeatureDataset, VisualGenomeFeatureDataset, Flickr30kFeatureDataset 13 | import base_model 14 | import utils 15 | from utils import trim_collate 16 | from dataset import tfidf_from_questions 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--task', type=str, default='vqa', help='vqa or flickr') 22 | parser.add_argument('--epochs', type=int, default=13) 23 | parser.add_argument('--num_hid', type=int, default=1280) 24 | parser.add_argument('--model', type=str, default='ban') 25 | parser.add_argument('--op', type=str, default='c') 26 | parser.add_argument('--gamma', type=int, default=8, help='glimpse') 27 | parser.add_argument('--use_both', action='store_true', help='use both train/val datasets to train?') 28 | parser.add_argument('--use_vg', action='store_true', help='use visual genome dataset to train?') 29 | parser.add_argument('--tfidf', action='store_false', help='tfidf word embedding?') 30 | parser.add_argument('--input', type=str, default=None) 31 | parser.add_argument('--output', type=str, default='saved_models/ban') 32 | parser.add_argument('--batch_size', type=int, default=256) 33 | parser.add_argument('--seed', type=int, default=1204, help='random seed') 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | if __name__ == '__main__': 39 | args = parse_args() 40 | 41 | torch.manual_seed(args.seed) 42 | torch.cuda.manual_seed_all(args.seed) 43 | torch.backends.cudnn.benchmark = True 44 | 45 | if args.task == 'vqa': 46 | from train import train 47 | dict_path = 'data/dictionary.pkl' 48 | dictionary = Dictionary.load_from_file(dict_path) 49 | train_dset = VQAFeatureDataset('train', dictionary, adaptive=True) 50 | val_dset = VQAFeatureDataset('val', dictionary, adaptive=True) 51 | w_emb_path = 'data/glove6b_init_300d.npy' 52 | 53 | elif args.task == 'flickr': 54 | from train_flickr import train 55 | dict_path = 'data/flickr30k/dictionary.pkl' 56 | dictionary = Dictionary.load_from_file(dict_path) 57 | train_dset = Flickr30kFeatureDataset('train', dictionary) 58 | val_dset = 
Flickr30kFeatureDataset('val', dictionary) 59 | w_emb_path = 'data/flickr30k/glove6b_init_300d.npy' 60 | args.op = '' 61 | args.gamma = 1 62 | args.tfidf = False 63 | 64 | utils.create_dir(args.output) 65 | logger = utils.Logger(os.path.join(args.output, 'args.txt')) 66 | logger.write(args.__repr__()) 67 | 68 | batch_size = args.batch_size 69 | 70 | constructor = 'build_%s' % args.model 71 | model = getattr(base_model, constructor)(train_dset, args.num_hid, args.op, args.gamma, args.task).cuda() 72 | 73 | 74 | tfidf = None 75 | weights = None 76 | 77 | if args.tfidf: 78 | dict = Dictionary.load_from_file(dict_path) 79 | tfidf, weights = tfidf_from_questions(['train', 'val', 'test2015'], dict) 80 | 81 | model.w_emb.init_embedding(w_emb_path, tfidf, weights) 82 | 83 | model = nn.DataParallel(model).cuda() 84 | 85 | optim = None 86 | epoch = 0 87 | 88 | # load snapshot 89 | if args.input is not None: 90 | print('loading %s' % args.input) 91 | model_data = torch.load(args.input) 92 | model.load_state_dict(model_data.get('model_state', model_data)) 93 | optim = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters())) 94 | optim.load_state_dict(model_data.get('optimizer_state', model_data)) 95 | epoch = model_data['epoch'] + 1 96 | 97 | if args.task == 'vqa': 98 | if args.use_both: # use train & val splits to optimize 99 | if args.use_vg: # use a portion of Visual Genome dataset 100 | vg_dsets = [ 101 | VisualGenomeFeatureDataset('train', \ 102 | train_dset.features, train_dset.spatials, dictionary, adaptive=True, pos_boxes=train_dset.pos_boxes), 103 | VisualGenomeFeatureDataset('val', \ 104 | val_dset.features, val_dset.spatials, dictionary, adaptive=True, pos_boxes=val_dset.pos_boxes)] 105 | trainval_dset = ConcatDataset([train_dset, val_dset]+vg_dsets) 106 | else: 107 | trainval_dset = ConcatDataset([train_dset, val_dset]) 108 | train_loader = DataLoader(trainval_dset, batch_size, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 109 | eval_loader = None 110 | else: 111 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 112 | eval_loader = DataLoader(val_dset, batch_size, shuffle=False, num_workers=1, collate_fn=utils.trim_collate) 113 | 114 | elif args.task == 'flickr': 115 | train_loader = DataLoader(train_dset, batch_size, shuffle=True, num_workers=1, collate_fn=utils.trim_collate) 116 | eval_loader = DataLoader(val_dset, batch_size, shuffle=False, num_workers=1, collate_fn=utils.trim_collate) 117 | 118 | train(model, train_loader, eval_loader, args.epochs, args.output, optim, epoch) 119 | -------------------------------------------------------------------------------- /misc/ban_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jnhwkim/ban-vqa/ef4cc36d63f6b31344301ac1bd91afd807d132e5/misc/ban_overview.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | import argparse 6 | import json 7 | import progressbar 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | 13 | from dataset import Dictionary, VQAFeatureDataset 14 | import base_model 15 | import utils 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--num_hid', type=int, default=1280) 21 | parser.add_argument('--model', type=str, default='ban') 22 | parser.add_argument('--op', type=str, default='c') 23 | parser.add_argument('--label', type=str, default='') 24 | parser.add_argument('--gamma', type=int, default=8) 25 | parser.add_argument('--split', type=str, default='test2015') 26 | parser.add_argument('--input', type=str, default='saved_models/ban') 27 | parser.add_argument('--output', type=str, default='results') 28 | parser.add_argument('--batch_size', type=int, default=256) 29 | parser.add_argument('--debug', action='store_true') 30 | parser.add_argument('--logits', action='store_true') 31 | parser.add_argument('--index', type=int, default=0) 32 | parser.add_argument('--epoch', type=int, default=12) 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def get_question(q, dataloader): 38 | str = [] 39 | dictionary = dataloader.dataset.dictionary 40 | for i in range(q.size(0)): 41 | str.append(dictionary.idx2word[q[i]] if q[i] < len(dictionary.idx2word) else '_') 42 | return ' '.join(str) 43 | 44 | 45 | def get_answer(p, dataloader): 46 | _m, idx = p.max(0) 47 | return dataloader.dataset.label2ans[idx.item()] 48 | 49 | 50 | @torch.no_grad() 51 | def get_logits(model, dataloader): 52 | N = len(dataloader.dataset) 53 | M = dataloader.dataset.num_ans_candidates 54 | pred = torch.FloatTensor(N, M).zero_() 55 | qIds = torch.IntTensor(N).zero_() 56 | idx = 0 57 | bar = progressbar.ProgressBar(max_value=N) 58 | for v, b, q, i in iter(dataloader): 59 | bar.update(idx) 60 | batch_size = v.size(0) 61 | v = v.cuda() 62 | b = b.cuda() 63 | q = q.cuda() 64 | logits, att = model(v, b, q, None) 65 | pred[idx:idx+batch_size,:].copy_(logits.data) 66 | qIds[idx:idx+batch_size].copy_(i) 67 | idx += batch_size 68 | if args.debug: 69 | print(get_question(q.data[0], dataloader)) 70 | print(get_answer(logits.data[0], dataloader)) 71 | bar.update(idx) 72 | return pred, qIds 73 | 74 | 75 | def make_json(logits, qIds, dataloader): 76 | utils.assert_eq(logits.size(0), len(qIds)) 77 | results = [] 78 | for i in range(logits.size(0)): 79 | result = {} 80 | result['question_id'] = qIds[i].item() 81 | result['answer'] = get_answer(logits[i], dataloader) 82 | results.append(result) 83 | return results 84 | 85 | if __name__ == '__main__': 86 | args = parse_args() 87 | 88 | torch.backends.cudnn.benchmark = True 89 | 90 | dictionary = Dictionary.load_from_file('data/dictionary.pkl') 91 | eval_dset = VQAFeatureDataset(args.split, dictionary, adaptive=True) 92 | 93 | n_device = torch.cuda.device_count() 94 | batch_size = args.batch_size * n_device 95 | 96 | constructor = 'build_%s' % args.model 97 | model = getattr(base_model, constructor)(eval_dset, args.num_hid, args.op, args.gamma).cuda() 98 | eval_loader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=1, collate_fn=utils.trim_collate) 99 | 100 | def process(args, model, eval_loader): 101 | model_path = args.input+'/model%s.pth' % \ 102 | ('' if 0 > args.epoch else '_epoch%d' % args.epoch) 103 | 104 | print('loading %s' % model_path) 105 | model_data = 
torch.load(model_path) 106 | 107 | model = nn.DataParallel(model).cuda() 108 | model.load_state_dict(model_data.get('model_state', model_data)) 109 | 110 | model.train(False) 111 | 112 | logits, qIds = get_logits(model, eval_loader) 113 | results = make_json(logits, qIds, eval_loader) 114 | model_label = '%s%s%d_%s' % (args.model, args.op, args.num_hid, args.label) 115 | 116 | if args.logits: 117 | utils.create_dir('logits/'+model_label) 118 | torch.save(logits, 'logits/'+model_label+'/logits%d.pth' % args.index) 119 | 120 | utils.create_dir(args.output) 121 | if 0 <= args.epoch: 122 | model_label += '_epoch%d' % args.epoch 123 | 124 | with open(args.output+'/%s_%s.json' \ 125 | % (args.split, model_label), 'w') as f: 126 | json.dump(results, f) 127 | 128 | process(args, model, eval_loader) 129 | -------------------------------------------------------------------------------- /tools/adaptive_detection_features_converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | 5 | Reads in a tsv file with pre-trained bottom up attention features 6 | of the adaptive number of boxes and stores it in HDF5 format. 7 | Also store {image_id: feature_idx} as a pickle file. 8 | 9 | Hierarchy of HDF5 file: 10 | 11 | { 'image_features': num_boxes x 2048 12 | 'image_bb': num_boxes x 4 13 | 'spatial_features': num_boxes x 6 14 | 'pos_boxes': num_images x 2 } 15 | """ 16 | from __future__ import print_function 17 | 18 | import os 19 | import argparse 20 | import sys 21 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 22 | 23 | import base64 24 | import csv 25 | import h5py 26 | import _pickle as cPickle 27 | import numpy as np 28 | import utils 29 | 30 | csv.field_size_limit(sys.maxsize) 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--task', type=str, default='vqa', help='vqa or flickr') 36 | args = parser.parse_args() 37 | return args 38 | 39 | def extract(split, infiles, task='vqa'): 40 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features'] 41 | if task == 'vqa': 42 | data_file = { 43 | 'train': 'data/train.hdf5', 44 | 'val': 'data/val.hdf5', 45 | 'test': 'data/test2015.hdf5'} 46 | indices_file = { 47 | 'train': 'data/train_imgid2idx.pkl', 48 | 'val': 'data/val_imgid2idx.pkl', 49 | 'test': 'data/test2015_imgid2idx.pkl'} 50 | ids_file = { 51 | 'train': 'data/train_ids.pkl', 52 | 'val': 'data/val_ids.pkl', 53 | 'test': 'data/test2015_ids.pkl'} 54 | path_imgs = { 55 | 'train': 'data/train2014', 56 | 'val': 'data/val2014', 57 | 'test': 'data/test2015'} 58 | known_num_boxes = { 59 | 'train': 2643089, 60 | 'val': 1281164, 61 | 'test': 2566887,} 62 | 63 | elif task == 'flickr': 64 | data_file = { 65 | 'train': 'data/flickr30k/train.hdf5', 66 | 'val': 'data/flickr30k/val.hdf5', 67 | 'test': 'data/flickr30k/test.hdf5'} 68 | indices_file = { 69 | 'train': 'data/flickr30k/train_imgid2idx.pkl', 70 | 'val': 'data/flickr30k/val_imgid2idx.pkl', 71 | 'test': 'data/flickr30k/test_imgid2idx.pkl'} 72 | ids_file = { 73 | 'train': 'data/flickr30k/train_ids.pkl', 74 | 'val': 'data/flickr30k/val_ids.pkl', 75 | 'test': 'data/flickr30k/test_ids.pkl'} 76 | path_imgs = { 77 | 'train': 'data/flickr30k/flickr30k_images', 78 | 'val': 'data/flickr30k/flickr30k_images', 79 | 'test': 'data/flickr30k/flickr30k_images'} 80 | known_num_boxes = { 81 | 'train': 903500, 82 | 
'val': 30722, 83 | 'test': 30648,} 84 | 85 | feature_length = 2048 86 | min_fixed_boxes = 10 87 | max_fixed_boxes = 100 88 | 89 | if os.path.exists(ids_file[split]): 90 | imgids = cPickle.load(open(ids_file[split], 'rb')) 91 | else: 92 | imgids = utils.load_imageid(path_imgs[split]) 93 | cPickle.dump(imgids, open(ids_file[split], 'wb')) 94 | 95 | h = h5py.File(data_file[split], 'w') 96 | 97 | if known_num_boxes[split] is None: 98 | num_boxes = 0 99 | for infile in infiles: 100 | print("reading tsv...%s" % infile) 101 | with open(infile, "r+") as tsv_in_file: 102 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) 103 | for item in reader: 104 | item['num_boxes'] = int(item['num_boxes']) 105 | image_id = int(item['image_id']) 106 | if image_id in imgids: 107 | num_boxes += item['num_boxes'] 108 | else: 109 | num_boxes = known_num_boxes[split] 110 | 111 | print('num_boxes=%d' % num_boxes) 112 | 113 | img_features = h.create_dataset( 114 | 'image_features', (num_boxes, feature_length), 'f') 115 | img_bb = h.create_dataset( 116 | 'image_bb', (num_boxes, 4), 'f') 117 | spatial_img_features = h.create_dataset( 118 | 'spatial_features', (num_boxes, 6), 'f') 119 | pos_boxes = h.create_dataset( 120 | 'pos_boxes', (len(imgids), 2), dtype='int32') 121 | 122 | counter = 0 123 | num_boxes = 0 124 | indices = {} 125 | 126 | for infile in infiles: 127 | unknown_ids = [] 128 | print("reading tsv...%s" % infile) 129 | with open(infile, "r+") as tsv_in_file: 130 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) 131 | for item in reader: 132 | item['num_boxes'] = int(item['num_boxes']) 133 | item['boxes'] = bytes(item['boxes'], 'utf') 134 | item['features'] = bytes(item['features'], 'utf') 135 | image_id = int(item['image_id']) 136 | image_w = float(item['image_w']) 137 | image_h = float(item['image_h']) 138 | bboxes = np.frombuffer( 139 | base64.decodestring(item['boxes']), 140 | dtype=np.float32).reshape((item['num_boxes'], -1)) 141 | 142 | box_width = bboxes[:, 2] - bboxes[:, 0] 143 | box_height = bboxes[:, 3] - bboxes[:, 1] 144 | scaled_width = box_width / image_w 145 | scaled_height = box_height / image_h 146 | scaled_x = bboxes[:, 0] / image_w 147 | scaled_y = bboxes[:, 1] / image_h 148 | 149 | box_width = box_width[..., np.newaxis] 150 | box_height = box_height[..., np.newaxis] 151 | scaled_width = scaled_width[..., np.newaxis] 152 | scaled_height = scaled_height[..., np.newaxis] 153 | scaled_x = scaled_x[..., np.newaxis] 154 | scaled_y = scaled_y[..., np.newaxis] 155 | 156 | spatial_features = np.concatenate( 157 | (scaled_x, 158 | scaled_y, 159 | scaled_x + scaled_width, 160 | scaled_y + scaled_height, 161 | scaled_width, 162 | scaled_height), 163 | axis=1) 164 | 165 | if image_id in imgids: 166 | imgids.remove(image_id) 167 | indices[image_id] = counter 168 | pos_boxes[counter,:] = np.array([num_boxes, num_boxes + item['num_boxes']]) 169 | img_bb[num_boxes:num_boxes+item['num_boxes'], :] = bboxes 170 | img_features[num_boxes:num_boxes+item['num_boxes'], :] = np.frombuffer( 171 | base64.decodestring(item['features']), 172 | dtype=np.float32).reshape((item['num_boxes'], -1)) 173 | spatial_img_features[num_boxes:num_boxes+item['num_boxes'], :] = spatial_features 174 | counter += 1 175 | num_boxes += item['num_boxes'] 176 | else: 177 | unknown_ids.append(image_id) 178 | 179 | print('%d unknown_ids...' % len(unknown_ids)) 180 | print('%d image_ids left...' 
% len(imgids)) 181 | 182 | if len(imgids) != 0: 183 | print('Warning: %s_image_ids is not empty' % split) 184 | 185 | cPickle.dump(indices, open(indices_file[split], 'wb')) 186 | h.close() 187 | print("done!") 188 | 189 | if __name__ == '__main__': 190 | args = parse_args() 191 | 192 | if args.task == 'vqa': 193 | infiles = ['data/trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', 194 | 'data/trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', 195 | 'data/trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1', 196 | 'data/trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv'] 197 | extract('train', infiles, args.task) 198 | extract('val', infiles, args.task) 199 | infiles = ['data/test2015/test2015_resnet101_faster_rcnn_genome.tsv'] 200 | extract('test', infiles, args.task) 201 | elif args.task == 'flickr': 202 | infiles = ['data/flickr30k/train_flickr30k_resnet101_faster_rcnn_genome.tsv.1', 203 | 'data/flickr30k/train_flickr30k_resnet101_faster_rcnn_genome.tsv.2'] 204 | extract('train', infiles, args.task) 205 | infiles = ['data/flickr30k/val_flickr30k_resnet101_faster_rcnn_genome.tsv.3'] 206 | extract('val', infiles, args.task) 207 | infiles = ['data/flickr30k/test_flickr30k_resnet101_faster_rcnn_genome.tsv.3'] 208 | extract('test', infiles, args.task) 209 | -------------------------------------------------------------------------------- /tools/compute_softscore.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is slightly modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import sys 8 | import json 9 | import numpy as np 10 | import re 11 | import _pickle as cPickle 12 | 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | # from dataset import Dictionary 15 | import utils 16 | 17 | 18 | contractions = { 19 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 20 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 21 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 22 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 23 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 24 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 25 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 26 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 27 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 28 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 29 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 30 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 31 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 32 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 33 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 34 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 35 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 36 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 37 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 38 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 39 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 40 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 41 | "someonell": "someone'll", "someones": "someone's", "somethingd": 42 | "something'd", "somethingd've": 
"something'd've", "something'dve": 43 | "something'd've", "somethingll": "something'll", "thats": 44 | "that's", "thered": "there'd", "thered've": "there'd've", 45 | "there'dve": "there'd've", "therere": "there're", "theres": 46 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 47 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 48 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 49 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 50 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 51 | "what's", "whatve": "what've", "whens": "when's", "whered": 52 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 53 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 54 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 55 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 56 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 57 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 58 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 59 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 60 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 61 | "you'll", "youre": "you're", "youve": "you've" 62 | } 63 | 64 | manual_map = { 'none': '0', 65 | 'zero': '0', 66 | 'one': '1', 67 | 'two': '2', 68 | 'three': '3', 69 | 'four': '4', 70 | 'five': '5', 71 | 'six': '6', 72 | 'seven': '7', 73 | 'eight': '8', 74 | 'nine': '9', 75 | 'ten': '10'} 76 | articles = ['a', 'an', 'the'] 77 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 78 | comma_strip = re.compile("(\d)(\,)(\d)") 79 | punct = [';', r"/", '[', ']', '"', '{', '}', 80 | '(', ')', '=', '+', '\\', '_', '-', 81 | '>', '<', '@', '`', ',', '?', '!'] 82 | 83 | 84 | # Notice that VQA score is the average of 10 choose 9 candidate answers cases 85 | # See http://visualqa.org/evaluation.html 86 | def get_score(occurences): 87 | if occurences == 0: 88 | return .0 89 | elif occurences == 1: 90 | return .3 91 | elif occurences == 2: 92 | return .6 93 | elif occurences == 3: 94 | return .9 95 | else: 96 | return 1. 
97 | 98 | 99 | def process_punctuation(inText): 100 | outText = inText 101 | for p in punct: 102 | if (p + ' ' in inText or ' ' + p in inText) \ 103 | or (re.search(comma_strip, inText) != None): 104 | outText = outText.replace(p, '') 105 | else: 106 | outText = outText.replace(p, ' ') 107 | outText = period_strip.sub("", outText, re.UNICODE) 108 | return outText 109 | 110 | 111 | def process_digit_article(inText): 112 | outText = [] 113 | tempText = inText.lower().split() 114 | for word in tempText: 115 | word = manual_map.setdefault(word, word) 116 | if word not in articles: 117 | outText.append(word) 118 | else: 119 | pass 120 | for wordId, word in enumerate(outText): 121 | if word in contractions: 122 | outText[wordId] = contractions[word] 123 | outText = ' '.join(outText) 124 | return outText 125 | 126 | 127 | def multiple_replace(text, wordDict): 128 | for key in wordDict: 129 | text = text.replace(key, wordDict[key]) 130 | return text 131 | 132 | 133 | def preprocess_answer(answer): 134 | answer = process_digit_article(process_punctuation(answer)) 135 | answer = answer.replace(',', '') 136 | return answer 137 | 138 | 139 | def filter_answers(answers_dset, min_occurence): 140 | """This will change the answer to preprocessed version 141 | """ 142 | occurence = {} 143 | 144 | for ans_entry in answers_dset: 145 | answers = ans_entry['answers'] 146 | gtruth = ans_entry['multiple_choice_answer'] 147 | gtruth = preprocess_answer(gtruth) 148 | if gtruth not in occurence: 149 | occurence[gtruth] = set() 150 | occurence[gtruth].add(ans_entry['question_id']) 151 | for answer in list(occurence): 152 | if len(occurence[answer]) < min_occurence: 153 | occurence.pop(answer) 154 | 155 | print('Num of answers that appear >= %d times: %d' % ( 156 | min_occurence, len(occurence))) 157 | return occurence 158 | 159 | 160 | def create_ans2label(occurence, name, cache_root='data/cache'): 161 | """Note that this will also create label2ans.pkl at the same time 162 | 163 | occurence: dict {answer -> whatever} 164 | name: prefix of the output file 165 | cache_root: str 166 | """ 167 | ans2label = {} 168 | label2ans = [] 169 | label = 0 170 | for answer in occurence: 171 | label2ans.append(answer) 172 | ans2label[answer] = label 173 | label += 1 174 | 175 | utils.create_dir(cache_root) 176 | 177 | cache_file = os.path.join(cache_root, name+'_ans2label.pkl') 178 | cPickle.dump(ans2label, open(cache_file, 'wb')) 179 | cache_file = os.path.join(cache_root, name+'_label2ans.pkl') 180 | cPickle.dump(label2ans, open(cache_file, 'wb')) 181 | return ans2label 182 | 183 | 184 | def compute_target(answers_dset, ans2label, name, cache_root='data/cache'): 185 | """Augment answers_dset with soft score as label 186 | 187 | ***answers_dset should be preprocessed*** 188 | 189 | Write result into a cache file 190 | """ 191 | target = [] 192 | for ans_entry in answers_dset: 193 | answers = ans_entry['answers'] 194 | answer_count = {} 195 | for answer in answers: 196 | answer_ = answer['answer'] 197 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 198 | 199 | labels = [] 200 | scores = [] 201 | for answer in answer_count: 202 | if answer not in ans2label: 203 | continue 204 | labels.append(ans2label[answer]) 205 | score = get_score(answer_count[answer]) 206 | scores.append(score) 207 | 208 | target.append({ 209 | 'question_id': ans_entry['question_id'], 210 | 'image_id': ans_entry['image_id'], 211 | 'labels': labels, 212 | 'scores': scores 213 | }) 214 | 215 | utils.create_dir(cache_root) 216 | cache_file = 
os.path.join(cache_root, name+'_target.pkl') 217 | cPickle.dump(target, open(cache_file, 'wb')) 218 | return target 219 | 220 | 221 | def get_answer(qid, answers): 222 | for ans in answers: 223 | if ans['question_id'] == qid: 224 | return ans 225 | 226 | 227 | def get_question(qid, questions): 228 | for question in questions: 229 | if question['question_id'] == qid: 230 | return question 231 | 232 | 233 | if __name__ == '__main__': 234 | train_answer_file = 'data/v2_mscoco_train2014_annotations.json' 235 | train_answers = json.load(open(train_answer_file))['annotations'] 236 | 237 | val_answer_file = 'data/v2_mscoco_val2014_annotations.json' 238 | val_answers = json.load(open(val_answer_file))['annotations'] 239 | 240 | train_question_file = 'data/v2_OpenEnded_mscoco_train2014_questions.json' 241 | train_questions = json.load(open(train_question_file))['questions'] 242 | 243 | val_question_file = 'data/v2_OpenEnded_mscoco_val2014_questions.json' 244 | val_questions = json.load(open(val_question_file))['questions'] 245 | 246 | answers = train_answers + val_answers 247 | occurence = filter_answers(answers, 9) 248 | 249 | cache_path = 'data/cache/trainval_ans2label.pkl' 250 | if os.path.isfile(cache_path): 251 | print('found %s' % cache_path) 252 | ans2label = cPickle.load(open(cache_path, 'rb')) 253 | else: 254 | ans2label = create_ans2label(occurence, 'trainval') 255 | compute_target(train_answers, ans2label, 'train') 256 | compute_target(val_answers, ans2label, 'val') 257 | -------------------------------------------------------------------------------- /tools/create_dictionary.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import argparse 8 | import sys 9 | import json 10 | import _pickle as cPickle 11 | import numpy as np 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from dataset import Dictionary 14 | from utils import get_sent_data 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--task', type=str, default='vqa', help='vqa or flickr') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | 25 | def create_dictionary(dataroot, task='vqa'): 26 | dictionary = Dictionary() 27 | if task == 'vqa': 28 | files = [ 29 | 'v2_OpenEnded_mscoco_train2014_questions.json', 30 | 'v2_OpenEnded_mscoco_val2014_questions.json', 31 | 'v2_OpenEnded_mscoco_test2015_questions.json', 32 | 'v2_OpenEnded_mscoco_test-dev2015_questions.json' 33 | ] 34 | for path in files: 35 | question_path = os.path.join(dataroot, path) 36 | qs = json.load(open(question_path))['questions'] 37 | for q in qs: 38 | dictionary.tokenize(q['question'], True) 39 | 40 | elif task == 'flickr': 41 | files = [ 42 | 'train_ids.pkl', 43 | 'val_ids.pkl', 44 | 'test_ids.pkl', 45 | ] 46 | sentence_dir = os.path.join(dataroot, 'Flickr30kEntities/Sentences') 47 | 48 | for path in files: 49 | ids_file = os.path.join(dataroot, path) 50 | 51 | with open(ids_file, 'rb') as f: 52 | imgids = cPickle.load(f) 53 | 54 | for image_id in imgids: 55 | question_path = os.path.join(sentence_dir, '%d.txt' % image_id) 56 | phrases = get_sent_data(question_path) 57 | for phrase in phrases: 58 | dictionary.tokenize(phrase, True) 59 | return dictionary 60 | 61 | 62 | def create_glove_embedding_init(idx2word, glove_file): 63 | word2emb = {} 64 | with 
open(glove_file, 'r') as f: 65 | entries = f.readlines() 66 | emb_dim = len(entries[0].split(' ')) - 1 67 | print('embedding dim is %d' % emb_dim) 68 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 69 | 70 | for entry in entries: 71 | vals = entry.split(' ') 72 | word = vals[0] 73 | vals = list(map(float, vals[1:])) 74 | word2emb[word] = np.array(vals) 75 | for idx, word in enumerate(idx2word): 76 | if word not in word2emb: 77 | continue 78 | weights[idx] = word2emb[word] 79 | return weights, word2emb 80 | 81 | 82 | if __name__ == '__main__': 83 | args = parse_args() 84 | dataroot = 'data' if args.task == 'vqa' else 'data/flickr30k' 85 | 86 | dictionary_path = os.path.join(dataroot, 'dictionary.pkl') 87 | 88 | d = create_dictionary(dataroot, args.task) 89 | d.dump_to_file(dictionary_path) 90 | 91 | d = Dictionary.load_from_file(dictionary_path) 92 | emb_dim = 300 93 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 94 | weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file) 95 | np.save(os.path.join(dataroot, 'glove6b_init_%dd.npy' % emb_dim), weights) 96 | -------------------------------------------------------------------------------- /tools/create_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | import os 7 | import sys 8 | import json 9 | import functools 10 | import operator 11 | import numpy as np 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from dataset import Dictionary 14 | import _pickle as cPickle 15 | 16 | 17 | def create_glove_embedding_init(idx2word, glove_file): 18 | word2emb = {} 19 | with open(glove_file, 'r') as f: 20 | entries = f.readlines() 21 | emb_dim = len(entries[0].split(' ')) - 1 22 | print('embedding dim is %d' % emb_dim) 23 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 24 | 25 | for entry in entries: 26 | vals = entry.split(' ') 27 | word = vals[0] 28 | vals = list(map(float, vals[1:])) 29 | word2emb[word] = np.array(vals) 30 | count = 0 31 | for idx, word in enumerate(idx2word): 32 | if word not in word2emb: 33 | updates = 0 34 | for w in word.split(' '): 35 | if w not in word2emb: 36 | continue 37 | weights[idx] += word2emb[w] 38 | updates += 1 39 | if updates == 0: 40 | count+= 1 41 | continue 42 | weights[idx] = word2emb[word] 43 | return weights, word2emb 44 | 45 | 46 | if __name__ == '__main__': 47 | emb_dims = [50, 100, 200, 300] 48 | weights = [0] * len(emb_dims) 49 | label2ans = cPickle.load(open('data/cache/trainval_label2ans.pkl', 'rb')) 50 | 51 | for idx, emb_dim in enumerate(emb_dims): # available embedding sizes 52 | glove_file = 'data/glove/glove.6B.%dd.txt' % emb_dim 53 | weights[idx], word2emb = create_glove_embedding_init(label2ans, glove_file) 54 | np.save('data/glove6b_emb_%dd.npy' % functools.reduce(operator.add, emb_dims), np.hstack(weights)) 55 | -------------------------------------------------------------------------------- /tools/detection_features_converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | 5 | Reads in a tsv file with pre-trained bottom up attention features and 6 | stores it in HDF5 format. Also store {image_id: feature_idx} 7 | as a pickle file. 
8 | 9 | Hierarchy of HDF5 file: 10 | 11 | { 'image_features': num_images x num_boxes x 2048 array of features 12 | 'image_bb': num_images x num_boxes x 4 array of bounding boxes } 13 | """ 14 | from __future__ import print_function 15 | 16 | import os 17 | import sys 18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 19 | 20 | import base64 21 | import csv 22 | import h5py 23 | import _pickle as cPickle 24 | import numpy as np 25 | import utils 26 | 27 | 28 | csv.field_size_limit(sys.maxsize) 29 | 30 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features'] 31 | infile = 'data/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv' 32 | train_data_file = 'data/train36.hdf5' 33 | val_data_file = 'data/val36.hdf5' 34 | train_indices_file = 'data/train36_imgid2idx.pkl' 35 | val_indices_file = 'data/val36_imgid2idx.pkl' 36 | train_ids_file = 'data/train_ids.pkl' 37 | val_ids_file = 'data/val_ids.pkl' 38 | 39 | feature_length = 2048 40 | num_fixed_boxes = 36 41 | 42 | 43 | if __name__ == '__main__': 44 | h_train = h5py.File(train_data_file, "w") 45 | h_val = h5py.File(val_data_file, "w") 46 | 47 | if os.path.exists(train_ids_file) and os.path.exists(val_ids_file): 48 | train_imgids = cPickle.load(open(train_ids_file, 'rb')) 49 | val_imgids = cPickle.load(open(val_ids_file, 'rb')) 50 | else: 51 | train_imgids = utils.load_imageid('data/train2014') 52 | val_imgids = utils.load_imageid('data/val2014') 53 | cPickle.dump(train_imgids, open(train_ids_file, 'wb')) 54 | cPickle.dump(val_imgids, open(val_ids_file, 'wb')) 55 | 56 | train_indices = {} 57 | val_indices = {} 58 | 59 | train_img_features = h_train.create_dataset( 60 | 'image_features', (len(train_imgids), num_fixed_boxes, feature_length), 'f') 61 | train_img_bb = h_train.create_dataset( 62 | 'image_bb', (len(train_imgids), num_fixed_boxes, 4), 'f') 63 | train_spatial_img_features = h_train.create_dataset( 64 | 'spatial_features', (len(train_imgids), num_fixed_boxes, 6), 'f') 65 | 66 | val_img_bb = h_val.create_dataset( 67 | 'image_bb', (len(val_imgids), num_fixed_boxes, 4), 'f') 68 | val_img_features = h_val.create_dataset( 69 | 'image_features', (len(val_imgids), num_fixed_boxes, feature_length), 'f') 70 | val_spatial_img_features = h_val.create_dataset( 71 | 'spatial_features', (len(val_imgids), num_fixed_boxes, 6), 'f') 72 | 73 | train_counter = 0 74 | val_counter = 0 75 | 76 | print("reading tsv...") 77 | with open(infile, "r+") as tsv_in_file: 78 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) 79 | for item in reader: 80 | item['num_boxes'] = int(item['num_boxes']) 81 | item['boxes'] = bytes(item['boxes'], 'utf') 82 | item['features'] = bytes(item['features'], 'utf') 83 | image_id = int(item['image_id']) 84 | image_w = float(item['image_w']) 85 | image_h = float(item['image_h']) 86 | bboxes = np.frombuffer( 87 | base64.decodestring(item['boxes']), 88 | dtype=np.float32).reshape((item['num_boxes'], -1)) 89 | 90 | box_width = bboxes[:, 2] - bboxes[:, 0] 91 | box_height = bboxes[:, 3] - bboxes[:, 1] 92 | scaled_width = box_width / image_w 93 | scaled_height = box_height / image_h 94 | scaled_x = bboxes[:, 0] / image_w 95 | scaled_y = bboxes[:, 1] / image_h 96 | 97 | box_width = box_width[..., np.newaxis] 98 | box_height = box_height[..., np.newaxis] 99 | scaled_width = scaled_width[..., np.newaxis] 100 | scaled_height = scaled_height[..., np.newaxis] 101 | scaled_x = scaled_x[..., np.newaxis] 102 | scaled_y = scaled_y[..., np.newaxis] 103 | 104 | 
spatial_features = np.concatenate( 105 | (scaled_x, 106 | scaled_y, 107 | scaled_x + scaled_width, 108 | scaled_y + scaled_height, 109 | scaled_width, 110 | scaled_height), 111 | axis=1) 112 | 113 | if image_id in train_imgids: 114 | train_imgids.remove(image_id) 115 | train_indices[image_id] = train_counter 116 | train_img_bb[train_counter, :, :] = bboxes 117 | train_img_features[train_counter, :, :] = np.frombuffer( 118 | base64.decodestring(item['features']), 119 | dtype=np.float32).reshape((item['num_boxes'], -1)) 120 | train_spatial_img_features[train_counter, :, :] = spatial_features 121 | train_counter += 1 122 | elif image_id in val_imgids: 123 | val_imgids.remove(image_id) 124 | val_indices[image_id] = val_counter 125 | val_img_bb[val_counter, :, :] = bboxes 126 | val_img_features[val_counter, :, :] = np.frombuffer( 127 | base64.decodestring(item['features']), 128 | dtype=np.float32).reshape((item['num_boxes'], -1)) 129 | val_spatial_img_features[val_counter, :, :] = spatial_features 130 | val_counter += 1 131 | else: 132 | assert False, 'Unknown image id: %d' % image_id 133 | 134 | if len(train_imgids) != 0: 135 | print('Warning: train_image_ids is not empty') 136 | 137 | if len(val_imgids) != 0: 138 | print('Warning: val_image_ids is not empty') 139 | 140 | cPickle.dump(train_indices, open(train_indices_file, 'wb')) 141 | cPickle.dump(val_indices, open(val_indices_file, 'wb')) 142 | h_train.close() 143 | h_val.close() 144 | print("done!") 145 | -------------------------------------------------------------------------------- /tools/detection_features_converter_target.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | 5 | Reads in a tsv file with pre-trained bottom up attention features and 6 | stores it in HDF5 format. Also store {image_id: feature_idx} 7 | as a pickle file. 
8 | 9 | Hierarchy of HDF5 file: 10 | 11 | { 'image_features': num_images x num_boxes x 2048 array of features 12 | 'image_bb': num_images x num_boxes x 4 array of bounding boxes } 13 | """ 14 | from __future__ import print_function 15 | 16 | import os 17 | import sys 18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 19 | 20 | import base64 21 | import csv 22 | import h5py 23 | import _pickle as cPickle 24 | import numpy as np 25 | import utils 26 | 27 | target = 'test2015' 28 | 29 | csv.field_size_limit(sys.maxsize) 30 | 31 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features'] 32 | infile = 'data/%s_36/%s_resnet101_faster_rcnn_genome_36.tsv' % (target, target) 33 | data_file = 'data/%s36.hdf5' % target 34 | indices_file = 'data/%s36_imgid2idx.pkl' % target 35 | ids_file = 'data/%s_ids.pkl' % target 36 | 37 | feature_length = 2048 38 | num_fixed_boxes = 36 39 | 40 | 41 | if __name__ == '__main__': 42 | h = h5py.File(data_file, "w") 43 | 44 | if os.path.exists(ids_file): 45 | imgids = cPickle.load(open(ids_file, 'rb')) 46 | else: 47 | imgids = utils.load_imageid('data/%s' % target) 48 | cPickle.dump(imgids, open(ids_file, 'wb')) 49 | 50 | indices = {} 51 | 52 | img_bb = h.create_dataset( 53 | 'image_bb', (len(imgids), num_fixed_boxes, 4), 'f') 54 | img_features = h.create_dataset( 55 | 'image_features', (len(imgids), num_fixed_boxes, feature_length), 'f') 56 | spatial_img_features = h.create_dataset( 57 | 'spatial_features', (len(imgids), num_fixed_boxes, 6), 'f') 58 | 59 | counter = 0 60 | 61 | print("reading tsv...") 62 | with open(infile, "r+") as tsv_in_file: 63 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) 64 | for item in reader: 65 | item['num_boxes'] = int(item['num_boxes']) 66 | item['boxes'] = bytes(item['boxes'], 'utf') 67 | item['features'] = bytes(item['features'], 'utf') 68 | image_id = int(item['image_id']) 69 | image_w = float(item['image_w']) 70 | image_h = float(item['image_h']) 71 | bboxes = np.frombuffer( 72 | base64.decodestring(item['boxes']), 73 | dtype=np.float32).reshape((item['num_boxes'], -1)) 74 | 75 | box_width = bboxes[:, 2] - bboxes[:, 0] 76 | box_height = bboxes[:, 3] - bboxes[:, 1] 77 | scaled_width = box_width / image_w 78 | scaled_height = box_height / image_h 79 | scaled_x = bboxes[:, 0] / image_w 80 | scaled_y = bboxes[:, 1] / image_h 81 | 82 | box_width = box_width[..., np.newaxis] 83 | box_height = box_height[..., np.newaxis] 84 | scaled_width = scaled_width[..., np.newaxis] 85 | scaled_height = scaled_height[..., np.newaxis] 86 | scaled_x = scaled_x[..., np.newaxis] 87 | scaled_y = scaled_y[..., np.newaxis] 88 | 89 | spatial_features = np.concatenate( 90 | (scaled_x, 91 | scaled_y, 92 | scaled_x + scaled_width, 93 | scaled_y + scaled_height, 94 | scaled_width, 95 | scaled_height), 96 | axis=1) 97 | 98 | if image_id in imgids: 99 | imgids.remove(image_id) 100 | indices[image_id] = counter 101 | img_bb[counter, :, :] = bboxes 102 | img_features[counter, :, :] = np.frombuffer( 103 | base64.decodestring(item['features']), 104 | dtype=np.float32).reshape((item['num_boxes'], -1)) 105 | spatial_img_features[counter, :, :] = spatial_features 106 | counter += 1 107 | else: 108 | assert False, 'Unknown image id: %d' % image_id 109 | 110 | if len(imgids) != 0: 111 | print('Warning: image_ids is not empty') 112 | 113 | cPickle.dump(indices, open(indices_file, 'wb')) 114 | h.close() 115 | print("done!") 116 | 
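# Worked example of one row of the 6-d 'spatial_features' dataset built above
# (illustrative numbers, not taken from a real image): with image_w=100, image_h=200
# and a box [x1, y1, x2, y2] = [10, 20, 60, 120],
#   scaled_x = 10/100 = 0.1, scaled_y = 20/200 = 0.1,
#   scaled_width = (60-10)/100 = 0.5, scaled_height = (120-20)/200 = 0.5,
# so the stored row is [0.1, 0.1, 0.6, 0.6, 0.5, 0.5]: the normalized top-left
# corner, bottom-right corner, width and height of the bounding box.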
-------------------------------------------------------------------------------- /tools/download.sh: -------------------------------------------------------------------------------- 1 | ## This code is modified from Hengyuan Hu's repository. 2 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 3 | 4 | ## Script for downloading data 5 | 6 | # GloVe Vectors 7 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip 8 | unzip data/glove.6B.zip -d data/glove 9 | rm data/glove.6B.zip 10 | 11 | # Questions 12 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip 13 | unzip data/v2_Questions_Train_mscoco.zip -d data 14 | rm data/v2_Questions_Train_mscoco.zip 15 | 16 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip 17 | unzip data/v2_Questions_Val_mscoco.zip -d data 18 | rm data/v2_Questions_Val_mscoco.zip 19 | 20 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip 21 | unzip data/v2_Questions_Test_mscoco.zip -d data 22 | rm data/v2_Questions_Test_mscoco.zip 23 | 24 | # Annotations 25 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip 26 | unzip data/v2_Annotations_Train_mscoco.zip -d data 27 | rm data/v2_Annotations_Train_mscoco.zip 28 | 29 | wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 30 | unzip data/v2_Annotations_Val_mscoco.zip -d data 31 | rm data/v2_Annotations_Val_mscoco.zip 32 | 33 | # Image Features 34 | wget -P data https://imagecaption.blob.core.windows.net/imagecaption/trainval.zip 35 | wget -P data https://imagecaption.blob.core.windows.net/imagecaption/test2014.zip 36 | wget -P data https://imagecaption.blob.core.windows.net/imagecaption/test2015.zip 37 | unzip data/trainval.zip -d data 38 | unzip data/test2014.zip -d data 39 | unzip data/test2015.zip -d data 40 | rm data/trainval.zip 41 | rm data/test2014.zip 42 | rm data/test2015.zip 43 | 44 | # Download Pickle caches for the pretrained model from 45 | # https://drive.google.com/file/d/1m5pL9gOkcnLZ_NuANmnDFIcil3NQVmZc/view?usp=sharing 46 | # and extract pkl files under data/cache/. 47 | mkdir -p data/cache 48 | wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1m5pL9gOkcnLZ_NuANmnDFIcil3NQVmZc' -O data/cache/cache.pkl.tgz 49 | tar xvf data/cache/cache.pkl.tgz -C data/cache/ 50 | rm data/cache/cache.pkl.tgz 51 | 52 | -------------------------------------------------------------------------------- /tools/download_data.sh: -------------------------------------------------------------------------------- 1 | ## This code is from Hengyuan Hu's repository. 2 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 3 | 4 | ## Script for downloading data 5 | 6 | # VQA Input Images 7 | 8 | wget -P data http://msvocds.blob.core.windows.net/coco2014/train2014.zip 9 | unzip data/train2014.zip -d data/ 10 | rm data/train2014.zip 11 | 12 | wget -P data http://msvocds.blob.core.windows.net/coco2014/val2014.zip 13 | unzip data/val2014.zip -d data/ 14 | rm data/val2014.zip 15 | 16 | wget -P data http://msvocds.blob.core.windows.net/coco2015/test2015.zip 17 | unzip data/test2015.zip -d data/ 18 | rm data/test2015.zip 19 | 20 | -------------------------------------------------------------------------------- /tools/download_flickr.sh: -------------------------------------------------------------------------------- 1 | ## This code is modified from Hengyuan Hu's repository. 
2 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 3 | 4 | ## Script for downloading data 5 | 6 | 7 | # Download Flickr30k Entities data manually into data/flickr30k 8 | # https://github.com/BryanPlummer/flickr30k_entities/blob/master/annotations.zip 9 | # Unpack Flickr30kEntities.tar.gz 10 | tar xvf data/flickr30k/Flickr30kEntities.tar.gz -C data/flickr30k 11 | rm data/flickr30k/Flickr30kEntities.tar.gz 12 | 13 | 14 | # Download Flickr30k images and captions data 15 | # https://drive.google.com/file/d/0B_PL6p-5reUAZEM4MmRQQ2VVSlk/view?usp=sharing 16 | mkdir -p data/flickr30k 17 | 18 | wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=0B_PL6p-5reUAZEM4MmRQQ2VVSlk' -O tmp.html 19 | wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(cat tmp.html | sed -En 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=0B_PL6p-5reUAZEM4MmRQQ2VVSlk" -O data/flickr30k/flickr30k_images.tar.gz 20 | rm -rf /tmp/cookies.txt tmp.html 21 | 22 | tar xvf data/flickr30k/flickr30k_images.tar.gz -C data/flickr30k 23 | rm data/flickr30k/flickr30k_images.tar.gz 24 | 25 | 26 | # Flickr30k Image Features 27 | # Image features generated using https://github.com/peteanderson80/bottom-up-attention.git 28 | # Please download via https://drive.google.com/file/d/1BmcxeY1kXzMZv54d4wMtl7HGc8Cs9zgO/view?usp=sharing 29 | # The direct download below is blocked. 30 | # 31 | # wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=11OD_qq7ITBarJwWZfi0bWIRw3HPEaHwE' -O tmp.html 32 | # wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(cat tmp.html | sed -En 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=11OD_qq7ITBarJwWZfi0bWIRw3HPEaHwE" -O data/flickr30k/flickr30k_features.zip 33 | # rm -rf /tmp/cookies.txt tmp.html 34 | 35 | unzip data/flickr30k/flickr30k_features.zip -d data/flickr30k 36 | rm data/flickr30k/flickr30k_features.zip 37 | 38 | 39 | # GloVe Vectors 40 | wget -P data http://nlp.stanford.edu/data/glove.6B.zip 41 | unzip data/glove.6B.zip -d data/glove 42 | rm data/glove.6B.zip 43 | -------------------------------------------------------------------------------- /tools/grad_check.py: -------------------------------------------------------------------------------- 1 | ## This code is from Hengyuan Hu's repository.
2 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as f 6 | from torch.autograd import Variable 7 | 8 | q = torch.Tensor([[1,2,3]]) # 1x3 9 | v = torch.Tensor([[[2,1,3],[3,2,1],[1,2,3]]]) #1x3x3 10 | 11 | q = Variable(q, requires_grad=True) 12 | v = Variable(v, requires_grad=True) 13 | 14 | q_ = q.unsqueeze(1).repeat(1, 3, 1) 15 | qv = (q_ * v).sum(2) 16 | a = f.softmax(qv, 1) 17 | 18 | v_ = (a.unsqueeze(2).expand(1,3,3) * v).sum(1) 19 | output = (q * v_).sum(1) 20 | 21 | output.backward() 22 | print('autograd:') 23 | print(q.grad.data) 24 | 25 | # autograd 26 | # q.grad = [1.0136 1.9155 3.0709] 27 | 28 | s = torch.Tensor(3).fill_(0) 29 | for i in range(3): 30 | dp = v.data[0][0].clone().fill_(0) 31 | p_i = a.data[0][i] 32 | for j in range(3): 33 | p_j = a.data[0][j] 34 | if i==j: 35 | dp = dp + p_i * (1 - p_i) * v.data[0][i] 36 | else: 37 | dp = dp - p_i * (p_j) * v.data[0][j] 38 | s = s + dp * (v.data[0][i] * q.data[0]).sum() 39 | 40 | dq = v_.data + s.unsqueeze(0) 41 | print('for-loop:') 42 | print(dq) 43 | 44 | # for-loop 45 | # dq = [1.0136 1.9155 3.0709] 46 | 47 | p = a.data[0] 48 | pp = torch.ger(p, p) 49 | diag_p = torch.diag(p) 50 | 51 | V = v.data[0] # V^T 52 | T = torch.matmul((diag_p - pp), V) 53 | qV = torch.matmul(V, q.data[0]) 54 | s = torch.matmul(qV, T) 55 | 56 | dq = v_.data + s 57 | print('linear algebra:') # (q^T V) (diag(p)-pp^T) V^T 58 | print(dq) 59 | 60 | T = torch.matmul(V.t(), diag_p - pp) 61 | U = torch.matmul(T, V) 62 | s = torch.matmul(U, q.data[0]) 63 | 64 | dq = v_.data + s 65 | print('linear algebra:') # V (diag(p)-pp^T) V^T q 66 | print(dq) 67 | 68 | # linear algebra 69 | # dq = [1.0136 1.9155 3.0709] 70 | -------------------------------------------------------------------------------- /tools/process.sh: -------------------------------------------------------------------------------- 1 | ## This code is modified from Hengyuan Hu's repository. 2 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 3 | 4 | ## Process data 5 | ## Notice that 10-100 adaptive bottom-up attention features are used. 6 | 7 | python3 tools/create_dictionary.py 8 | python3 tools/compute_softscore.py 9 | python3 tools/adaptive_detection_features_converter.py 10 | -------------------------------------------------------------------------------- /tools/process_flickr.sh: -------------------------------------------------------------------------------- 1 | ## This code is modified from Hengyuan Hu's repository. 2 | ## https://github.com/hengyuan-hu/bottom-up-attention-vqa 3 | 4 | ## Process data 5 | ## Notice that 10-100 adaptive bottom-up attention features are used. 6 | 7 | python3 tools/create_dictionary.py --task flickr 8 | python3 tools/adaptive_detection_features_converter.py --task flickr -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | import os 6 | import time 7 | import itertools 8 | import torch 9 | import torch.nn as nn 10 | import utils 11 | import torch.optim.lr_scheduler as lr_scheduler 12 | 13 | 14 | def instance_bce_with_logits(logits, labels, reduction='mean'): 15 | assert logits.dim() == 2 16 | 17 | loss = nn.functional.binary_cross_entropy_with_logits(logits, labels, reduction=reduction) 18 | if reduction == 'mean': 19 | loss *= labels.size(1) 20 | return loss 21 | 22 | 23 | def compute_score_with_logits(logits, labels): 24 | logits = torch.max(logits, 1)[1].data # argmax 25 | one_hots = torch.zeros(*labels.size()).cuda() 26 | one_hots.scatter_(1, logits.view(-1, 1), 1) 27 | scores = (one_hots * labels) 28 | return scores 29 | 30 | 31 | def train(model, train_loader, eval_loader, num_epochs, output, opt=None, s_epoch=0): 32 | lr_default = 1e-3 if eval_loader is not None else 7e-4 33 | lr_decay_step = 2 34 | lr_decay_rate = .25 35 | lr_decay_epochs = range(10,20,lr_decay_step) if eval_loader is not None else range(10,20,lr_decay_step) 36 | gradual_warmup_steps = [0.5 * lr_default, 1.0 * lr_default, 1.5 * lr_default, 2.0 * lr_default] 37 | saving_epoch = 3 38 | grad_clip = .25 39 | 40 | utils.create_dir(output) 41 | optim = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_default) \ 42 | if opt is None else opt 43 | logger = utils.Logger(os.path.join(output, 'log.txt')) 44 | best_eval_score = 0 45 | 46 | utils.print_model(model, logger) 47 | logger.write('optim: adamax lr=%.4f, decay_step=%d, decay_rate=%.2f, grad_clip=%.2f' % \ 48 | (lr_default, lr_decay_step, lr_decay_rate, grad_clip)) 49 | 50 | for epoch in range(s_epoch, num_epochs): 51 | total_loss = 0 52 | train_score = 0 53 | total_norm = 0 54 | count_norm = 0 55 | t = time.time() 56 | N = len(train_loader.dataset) 57 | if epoch < len(gradual_warmup_steps): 58 | optim.param_groups[0]['lr'] = gradual_warmup_steps[epoch] 59 | logger.write('gradual warmup lr: %.4f' % optim.param_groups[0]['lr']) 60 | elif epoch in lr_decay_epochs: 61 | optim.param_groups[0]['lr'] *= lr_decay_rate 62 | logger.write('decreased lr: %.4f' % optim.param_groups[0]['lr']) 63 | else: 64 | logger.write('lr: %.4f' % optim.param_groups[0]['lr']) 65 | 66 | for i, (v, b, q, a) in enumerate(train_loader): 67 | v = v.cuda() 68 | b = b.cuda() 69 | q = q.cuda() 70 | a = a.cuda() 71 | 72 | pred, att = model(v, b, q, a) 73 | loss = instance_bce_with_logits(pred, a) 74 | loss.backward() 75 | total_norm += nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 76 | count_norm += 1 77 | optim.step() 78 | optim.zero_grad() 79 | 80 | batch_score = compute_score_with_logits(pred, a.data).sum() 81 | total_loss += loss.item() * v.size(0) 82 | train_score += batch_score.item() 83 | 84 | total_loss /= N 85 | train_score = 100 * train_score / N 86 | if None != eval_loader: 87 | model.train(False) 88 | eval_score, bound, entropy = evaluate(model, eval_loader) 89 | model.train(True) 90 | 91 | logger.write('epoch %d, time: %.2f' % (epoch, time.time()-t)) 92 | logger.write('\ttrain_loss: %.2f, norm: %.4f, score: %.2f' % (total_loss, total_norm/count_norm, train_score)) 93 | if eval_loader is not None: 94 | logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound)) 95 | 96 | if eval_loader is not None and entropy is not None: 97 | info = '' 98 | for i in range(entropy.size(0)): 99 | info = info + ' %.2f' % entropy[i] 100 | logger.write('\tentropy: ' + info) 101 | 102 | if (eval_loader is 
not None and eval_score > best_eval_score) or (eval_loader is None and epoch >= saving_epoch): 103 | model_path = os.path.join(output, 'model_epoch%d.pth' % epoch) 104 | utils.save_model(model_path, model, epoch, optim) 105 | if eval_loader is not None: 106 | best_eval_score = eval_score 107 | 108 | 109 | @torch.no_grad() 110 | def evaluate(model, dataloader): 111 | score = 0 112 | upper_bound = 0 113 | num_data = 0 114 | entropy = None 115 | if hasattr(model.module, 'glimpse'): 116 | entropy = torch.Tensor(model.module.glimpse).zero_().cuda() 117 | for i, (v, b, q, a) in enumerate(dataloader): 118 | v = v.cuda() 119 | b = b.cuda() 120 | q = q.cuda() 121 | pred, att = model(v, b, q, None) 122 | batch_score = compute_score_with_logits(pred, a.cuda()).sum() 123 | score += batch_score.item() 124 | upper_bound += (a.max(1)[0]).sum().item() 125 | num_data += pred.size(0) 126 | if att is not None and 0 < model.module.glimpse: 127 | entropy += calc_entropy(att.data)[:model.module.glimpse] 128 | 129 | score = score / len(dataloader.dataset) 130 | upper_bound = upper_bound / len(dataloader.dataset) 131 | 132 | if entropy is not None: 133 | entropy = entropy / len(dataloader.dataset) 134 | 135 | return score, upper_bound, entropy 136 | 137 | def calc_entropy(att): # size(att) = [b x g x v x q] 138 | sizes = att.size() 139 | eps = 1e-8 140 | p = att.view(-1, sizes[1], sizes[2] * sizes[3]) 141 | return (-p * (p+eps).log()).sum(2).sum(0) # g -------------------------------------------------------------------------------- /train_flickr.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from Hengyuan Hu's repository. 3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | 6 | import os 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | import utils 11 | from train import instance_bce_with_logits, compute_score_with_logits 12 | 13 | 14 | def compute_recall_with_logits(logits, labels): 15 | logits = torch.sort(logits, 1, descending=True)[1].data 16 | scores = [0]*3 17 | for i,r in enumerate([1,5,10]): 18 | one_hots = torch.zeros(*labels.size()).cuda() 19 | one_hots.scatter_(1, logits[:,:r], 1) 20 | scores[i] = ((one_hots * labels).sum(1)>=1).float().sum() 21 | return scores 22 | 23 | 24 | def train(model, train_loader, eval_loader, num_epochs, output, opt=None, s_epoch=0): 25 | lr_default = 1e-3 if eval_loader is not None else 7e-4 26 | lr_decay_step = 2 27 | lr_decay_rate = .25 28 | lr_decay_epochs = range(10,20,lr_decay_step) if eval_loader is not None else range(10,20,lr_decay_step) 29 | gradual_warmup_steps = [0.5 * lr_default, 1.0 * lr_default, 1.5 * lr_default, 2.0 * lr_default] 30 | saving_epoch = 3 31 | grad_clip = .25 32 | 33 | utils.create_dir(output) 34 | optim = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_default) \ 35 | if opt is None else opt 36 | logger = utils.Logger(os.path.join(output, 'log.txt')) 37 | best_eval_score = 0 38 | 39 | utils.print_model(model, logger) 40 | logger.write('optim: adamax lr=%.4f, decay_step=%d, decay_rate=%.2f, grad_clip=%.2f' % \ 41 | (lr_default, lr_decay_step, lr_decay_rate, grad_clip)) 42 | 43 | for epoch in range(s_epoch, num_epochs): 44 | total_loss = 0 45 | train_score = 0 46 | total_norm = 0 47 | count_norm = 0 48 | t = time.time() 49 | N = 0 50 | 51 | if epoch < len(gradual_warmup_steps): 52 | optim.param_groups[0]['lr'] = gradual_warmup_steps[epoch] 53 | logger.write('gradual warmup lr: %.4f' % optim.param_groups[0]['lr']) 
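        # Spelled out, the learning-rate schedule handled above and below amounts to
        # (with lr_default = 1e-3, the value used when eval_loader is not None):
        #   epochs 0-3            -> lr = 5e-4, 1e-3, 1.5e-3, 2e-3 (gradual warmup)
        #   epochs 10,12,14,16,18 -> lr is multiplied by 0.25      (step decay)
        #   all other epochs      -> lr keeps its previous value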
54 | elif epoch in lr_decay_epochs: 55 | optim.param_groups[0]['lr'] *= lr_decay_rate 56 | logger.write('decreased lr: %.4f' % optim.param_groups[0]['lr']) 57 | else: 58 | logger.write('lr: %.4f' % optim.param_groups[0]['lr']) 59 | 60 | for i, (v, b, p, e, n, a, idx, types) in enumerate(train_loader): 61 | v = v.cuda() 62 | b = b.cuda() 63 | p = p.cuda() 64 | e = e.cuda() 65 | a = a.cuda() 66 | 67 | _, logits = model(v, b, p, e, a) 68 | n_obj = logits.size(2) 69 | logits.squeeze_() 70 | 71 | merged_logit = torch.cat(tuple(logits[j, :, :n[j][0]] for j in range(n.size(0))), -1).permute(1, 0) 72 | merged_a = torch.cat(tuple(a[j, :n[j][0], :n_obj] for j in range(n.size(0))), 0) 73 | 74 | loss = instance_bce_with_logits(merged_logit, merged_a, 'sum') / v.size(0) 75 | N += n.sum().float() 76 | 77 | batch_score = compute_score_with_logits(merged_logit, merged_a.data).sum() 78 | 79 | loss.backward() 80 | total_norm += nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 81 | count_norm += 1 82 | optim.step() 83 | optim.zero_grad() 84 | total_loss += loss.item() * v.size(0) 85 | train_score += batch_score.item() 86 | 87 | total_loss /= N 88 | train_score = 100 * train_score / N 89 | if None != eval_loader: 90 | model.train(False) 91 | eval_score, bound, entropy = evaluate(model, eval_loader) 92 | model.train(True) 93 | 94 | logger.write('epoch %d, time: %.2f' % (epoch, time.time()-t)) 95 | logger.write('\ttrain_loss: %.2f, norm: %.4f, score: %.2f' % (total_loss, total_norm/count_norm, train_score)) 96 | if eval_loader is not None: 97 | logger.write('\teval score: %.2f/%.2f/%.2f (%.2f)' % ( 98 | 100 * eval_score[0], 100 * eval_score[1], 100 * eval_score[2], 100 * bound)) 99 | eval_score = eval_score[0] 100 | 101 | if eval_loader is not None and entropy is not None: 102 | info = '' 103 | for i in range(entropy.size(0)): 104 | info = info + ' %.2f' % entropy[i] 105 | logger.write('\tentropy: ' + info) 106 | 107 | if (eval_loader is not None and eval_score > best_eval_score) or (eval_loader is None and epoch >= saving_epoch): 108 | model_path = os.path.join(output, 'model_epoch%d.pth' % epoch) 109 | utils.save_model(model_path, model, epoch, optim) 110 | if eval_loader is not None: 111 | best_eval_score = eval_score 112 | 113 | 114 | @torch.no_grad() 115 | def evaluate(model, dataloader): 116 | upper_bound = 0 117 | entropy = None 118 | score = [0] * 3 119 | N = 0 120 | for i, (v, b, p, e, n, a, idx, types) in enumerate(dataloader): 121 | v = v.cuda() 122 | b = b.cuda() 123 | p = p.cuda() 124 | e = e.cuda() 125 | a = a.cuda() 126 | _, logits = model(v, b, p, e, None) 127 | n_obj = logits.size(2) 128 | logits.squeeze_() 129 | 130 | merged_logits = torch.cat(tuple(logits[j, :, :n[j][0]] for j in range(n.size(0))), -1).permute(1, 0) 131 | merged_a = torch.cat(tuple(a[j, :n[j][0], :n_obj] for j in range(n.size(0))), 0) 132 | 133 | recall = compute_recall_with_logits(merged_logits, merged_a.data) 134 | for r_idx, r in enumerate(recall): 135 | score[r_idx] += r 136 | N += n.sum().float() 137 | upper_bound += merged_a.max(-1, False)[0].sum().item() 138 | 139 | for i in range(3): 140 | score[i] = score[i] / N 141 | upper_bound = upper_bound / N 142 | 143 | return score, upper_bound, entropy 144 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is extended from Hengyuan Hu's repository. 
3 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 4 | """ 5 | from __future__ import print_function 6 | 7 | import errno 8 | import os 9 | import re 10 | import collections 11 | import numpy as np 12 | import operator 13 | import functools 14 | from PIL import Image 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch._six import string_classes 19 | from torch.utils.data.dataloader import default_collate 20 | 21 | 22 | EPS = 1e-7 23 | 24 | 25 | def assert_eq(real, expected): 26 | assert real == expected, '%s (true) vs %s (expected)' % (real, expected) 27 | 28 | 29 | def assert_array_eq(real, expected): 30 | assert (np.abs(real-expected) < EPS).all(), \ 31 | '%s (true) vs %s (expected)' % (real, expected) 32 | 33 | def assert_tensor_eq(real, expected, eps=EPS): 34 | assert (torch.abs(real-expected) < eps).all(), \ 35 | '%s (true) vs %s (expected)' % (real, expected) 36 | 37 | 38 | def load_folder(folder, suffix): 39 | imgs = [] 40 | for f in sorted(os.listdir(folder)): 41 | if f.endswith(suffix): 42 | imgs.append(os.path.join(folder, f)) 43 | return imgs 44 | 45 | 46 | def load_imageid(folder): 47 | images = load_folder(folder, 'jpg') 48 | img_ids = set() 49 | for img in images: 50 | img_id = int(img.split('/')[-1].split('.')[0].split('_')[-1]) 51 | img_ids.add(img_id) 52 | return img_ids 53 | 54 | 55 | def pil_loader(path): 56 | with open(path, 'rb') as f: 57 | with Image.open(f) as img: 58 | return img.convert('RGB') 59 | 60 | 61 | def weights_init(m): 62 | """custom weights initialization.""" 63 | cname = m.__class__ 64 | if cname == nn.Linear or cname == nn.Conv2d or cname == nn.ConvTranspose2d: 65 | m.weight.data.normal_(0.0, 0.02) 66 | elif cname == nn.BatchNorm2d: 67 | m.weight.data.normal_(1.0, 0.02) 68 | m.bias.data.fill_(0) 69 | else: 70 | print('%s is not initialized.' % cname) 71 | 72 | 73 | def init_net(net, net_file): 74 | if net_file: 75 | net.load_state_dict(torch.load(net_file)) 76 | else: 77 | net.apply(weights_init) 78 | 79 | 80 | def create_dir(path): 81 | if not os.path.exists(path): 82 | try: 83 | os.makedirs(path) 84 | except OSError as exc: 85 | if exc.errno != errno.EEXIST: 86 | raise 87 | 88 | 89 | def print_model(model, logger): 90 | print(model) 91 | nParams = 0 92 | for w in model.parameters(): 93 | nParams += functools.reduce(operator.mul, w.size(), 1) 94 | if logger: 95 | logger.write('nParams=\t'+str(nParams)) 96 | 97 | 98 | def save_model(path, model, epoch, optimizer=None): 99 | model_dict = { 100 | 'epoch': epoch, 101 | 'model_state': model.state_dict() 102 | } 103 | if optimizer is not None: 104 | model_dict['optimizer_state'] = optimizer.state_dict() 105 | 106 | torch.save(model_dict, path) 107 | 108 | 109 | # Select the indices given by `lengths` in the second dimension 110 | # As a result, # of dimensions is shrinked by one 111 | # @param pad(Tensor) 112 | # @param len(list[int]) 113 | def rho_select(pad, lengths): 114 | # Index of the last output for each sequence. 
115 | idx_ = (lengths-1).view(-1,1).expand(pad.size(0), pad.size(2)).unsqueeze(1) 116 | extracted = pad.gather(1, idx_).squeeze(1) 117 | return extracted 118 | 119 | 120 | def trim_collate(batch): 121 | "Puts each data field into a tensor with outer dimension batch size" 122 | _use_shared_memory = True 123 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 124 | elem_type = type(batch[0]) 125 | if torch.is_tensor(batch[0]): 126 | out = None 127 | if 1 < batch[0].dim(): # image features 128 | max_num_boxes = max([x.size(0) for x in batch]) 129 | if _use_shared_memory: 130 | # If we're in a background process, concatenate directly into a 131 | # shared memory tensor to avoid an extra copy 132 | numel = len(batch) * max_num_boxes * batch[0].size(-1) 133 | storage = batch[0].storage()._new_shared(numel) 134 | out = batch[0].new(storage) 135 | # warning: F.pad returns Variable! 136 | return torch.stack([F.pad(x, (0,0,0,max_num_boxes-x.size(0))).data for x in batch], 0, out=out) 137 | else: 138 | if _use_shared_memory: 139 | # If we're in a background process, concatenate directly into a 140 | # shared memory tensor to avoid an extra copy 141 | numel = sum([x.numel() for x in batch]) 142 | storage = batch[0].storage()._new_shared(numel) 143 | out = batch[0].new(storage) 144 | return torch.stack(batch, 0, out=out) 145 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 146 | and elem_type.__name__ != 'string_': 147 | elem = batch[0] 148 | if elem_type.__name__ == 'ndarray': 149 | # array of string classes and object 150 | if re.search('[SaUO]', elem.dtype.str) is not None: 151 | raise TypeError(error_msg.format(elem.dtype)) 152 | 153 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 154 | if elem.shape == (): # scalars 155 | py_type = float if elem.dtype.name.startswith('float') else int 156 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 157 | elif isinstance(batch[0], int): 158 | return torch.LongTensor(batch) 159 | elif isinstance(batch[0], float): 160 | return torch.DoubleTensor(batch) 161 | elif isinstance(batch[0], string_classes): 162 | return batch 163 | elif isinstance(batch[0], collections.Mapping): 164 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 165 | elif isinstance(batch[0], collections.Sequence): 166 | transposed = zip(*batch) 167 | return [trim_collate(samples) for samples in transposed] 168 | 169 | raise TypeError((error_msg.format(type(batch[0])))) 170 | 171 | 172 | class Logger(object): 173 | def __init__(self, output_name): 174 | dirname = os.path.dirname(output_name) 175 | if not os.path.exists(dirname): 176 | os.mkdir(dirname) 177 | 178 | self.log_file = open(output_name, 'w') 179 | self.infos = {} 180 | 181 | def append(self, key, val): 182 | vals = self.infos.setdefault(key, []) 183 | vals.append(val) 184 | 185 | def log(self, extra_msg=''): 186 | msgs = [extra_msg] 187 | for key, vals in self.infos.items(): 188 | msgs.append('%s %.6f' % (key, np.mean(vals))) 189 | msg = '\n'.join(msgs) 190 | self.log_file.write(msg + '\n') 191 | self.log_file.flush() 192 | self.infos = {} 193 | return msg 194 | 195 | def write(self, msg): 196 | self.log_file.write(msg + '\n') 197 | self.log_file.flush() 198 | print(msg) 199 | 200 | 201 | def create_glove_embedding_init(idx2word, glove_file): 202 | word2emb = {} 203 | with open(glove_file, 'r', encoding='utf-8') as f: 204 | entries = f.readlines() 205 | emb_dim = len(entries[0].split(' ')) - 1 206 | print('embedding dim is 
%d' % emb_dim) 207 | weights = np.zeros((len(idx2word), emb_dim), dtype=np.float32) 208 | 209 | for entry in entries: 210 | vals = entry.split(' ') 211 | word = vals[0] 212 | vals = list(map(float, vals[1:])) 213 | word2emb[word] = np.array(vals) 214 | for idx, word in enumerate(idx2word): 215 | if word not in word2emb: 216 | continue 217 | weights[idx] = word2emb[word] 218 | return weights, word2emb 219 | 220 | # Remove Flickr30K Entity annotations in a string 221 | def remove_annotations(s): 222 | return re.sub(r'\[[^ ]+ ','',s).replace(']', '') 223 | 224 | def get_sent_data(file_path): 225 | phrases = [] 226 | 227 | with open(file_path, 'r', encoding='utf-8') as f: 228 | for sent in f: 229 | phrase = remove_annotations(sent.strip()) 230 | phrases.append(phrase) 231 | 232 | return phrases 233 | 234 | 235 | # Find position of a given sublist 236 | # return the index of the last token 237 | def find_sublist(arr, sub): 238 | sublen = len(sub) 239 | first = sub[0] 240 | indx = -1 241 | while True: 242 | try: 243 | indx = arr.index(first, indx + 1) 244 | except ValueError: 245 | break 246 | if sub == arr[indx: indx + sublen]: 247 | return indx + sublen - 1 248 | return -1 249 | 250 | 251 | def calculate_iou(obj1, obj2): 252 | area1 = calculate_area(obj1) 253 | area2 = calculate_area(obj2) 254 | intersection = get_intersection(obj1, obj2) 255 | area_int = calculate_area(intersection) 256 | return area_int / (area1 + area2 - area_int) 257 | 258 | def calculate_area(obj): 259 | return (obj[2] - obj[0]) * (obj[3] - obj[1]) 260 | 261 | def get_intersection(obj1, obj2): 262 | left = obj1[0] if obj1[0] > obj2[0] else obj2[0] 263 | top = obj1[1] if obj1[1] > obj2[1] else obj2[1] 264 | right = obj1[2] if obj1[2] < obj2[2] else obj2[2] 265 | bottom = obj1[3] if obj1[3] < obj2[3] else obj2[3] 266 | if left > right or top > bottom: 267 | return [0, 0, 0, 0] 268 | return [left, top, right, bottom] 269 | 270 | 271 | def get_match_index(src_bboxes, dst_bboxes): 272 | indices = set() 273 | for src_bbox in src_bboxes: 274 | for i, dst_bbox in enumerate(dst_bboxes): 275 | iou = calculate_iou(src_bbox, dst_bbox) 276 | if iou >= 0.5: 277 | indices.add(i) 278 | return list(indices) 279 | 280 | # Batched index_select 281 | def batched_index_select(t, dim, inds): 282 | dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2)) 283 | out = t.gather(dim, dummy) # b x e x f 284 | return out 285 | --------------------------------------------------------------------------------
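For reference, the box and phrase helpers at the end of `utils.py` (`calculate_iou`, `get_match_index`, `find_sublist`) are plain Python and can be sanity-checked in isolation. Below is a minimal sketch, assuming the repository root is on the Python path and the module-level dependencies of `utils.py` (PyTorch, Pillow) are installed; the toy boxes and tokens are invented for illustration only. Boxes follow the `[x1, y1, x2, y2]` convention implied by `calculate_area` and `get_intersection`.

```
# Illustrative only: toy values, not part of the repository.
from utils import calculate_iou, get_match_index, find_sublist

# One ground-truth phrase box and three detected proposals, in [x1, y1, x2, y2] order.
src_bboxes = [[0, 0, 10, 10]]
dst_bboxes = [[0, 0, 10, 10], [5, 5, 15, 15], [20, 20, 30, 30]]

print(calculate_iou(src_bboxes[0], dst_bboxes[1]))  # ~0.14, below the 0.5 threshold
print(get_match_index(src_bboxes, dst_bboxes))      # [0]: only the first proposal overlaps enough

# find_sublist returns the index of the *last* token of the phrase within the sentence.
tokens = ['a', 'man', 'in', 'a', 'blue', 'shirt']
print(find_sublist(tokens, ['blue', 'shirt']))      # 5
```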