├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── counting.py ├── data.py ├── data ├── co_occour_count.npy └── download.sh ├── eval-acc.py ├── gen_tree_net.py ├── logs └── .dummy ├── model.py ├── preprocess-features.py ├── preprocess-vocab.py ├── q_type_module.py ├── train.py ├── tree_def.py ├── tree_feature.py ├── tree_lstm.py ├── tree_utils.py ├── utils.py └── view-log.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.pth 4 | val_acc.png 5 | results.json 6 | vocab.json 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yan Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VCTree-Visual-Question-Answering 2 | Code for the VQA part of the CVPR 2019 oral paper "[Learning to Compose Dynamic Tree Structures for Visual Contexts][0]". For the Scene Graph Generation part of the paper, please refer to [KaihuaTang/VCTree-Scene-Graph-Generation][7]. 3 | 4 | UGLY CODE WARNING! UGLY CODE WARNING! UGLY CODE WARNING! 5 | 6 | The code is directly modified from the project [Cyanogenoid/vqa-counting][1]. We mainly modified model.py, train.py, and config.py, and added several files for our VCTree model, such as all the tree_*.py files and gen_tree_net.py. Before we arrived at our final model, we tried lots of different tree structures, so you may find some strange code such as config.gen_tree_mode and the corresponding choices in tree_feature.py. Just ignore them. (I'm too lazy to purge the code, sorry about that) 7 | 8 | ## Dependencies 9 | This code was confirmed to run with the following environment: 10 | 11 | - Python 3.6 12 | - torch 0.4 13 | - torchvision 0.2 14 | - h5py 2.7 15 | - tqdm 4.19 16 | 17 | # Prepare data 18 | Please follow the [instructions][4] to prepare the data. 19 | 20 | - In the `data` directory, execute `./download.sh` to download VQA v2 [questions, answers, and bottom-up features][4]. 21 | - For experimenting, using 36 fixed proposals is faster, at the expense of a bit of accuracy. Uncomment the relevant lines in `download.sh` and change the paths in `config.py` accordingly.
Don't forget to set `output_size` in there to 36 to actually get the speed-up. 22 | - Prepare the data by running 23 | ``` 24 | python preprocess-features.py 25 | python preprocess-vocab.py 26 | ``` 27 | This creates an `h5py` database (95 GiB) containing the object proposal features and a vocabulary for questions and answers at the locations specified in `config.py`. 28 | - Download the pretrained object correlation score models 29 | - The proposed VCTree requires a pretrained model to generate the object correlation scores f(xi, xj) mentioned in Section 3.1. Such a pretrained model can be downloaded from [vgrel-19 (for 10-100 bounding box features)](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21620273&authkey=AKFuFsQ90tQO4q0) or [vgrel-29 (for 36 bounding box features)](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21620229&authkey=APSqYLYmGyfl3Mg). Since there are two different types of bottom-up-top-down features, we provide two corresponding pretrained object correlation score models. The object correlation score model is trained on top of the (fixed) Faster R-CNN model from the [bottom-up-top-down model][2] and the code from [zjuchenlong/faster-rcnn.pytorch][3] (./vqa/feature_extractor/bottom_up_origin.py, to be more specific). 30 | - Put the corresponding model under ./data, depending on your config.output_size 31 | 32 | # Train your model 33 | Note that the proposed hybrid learning strategy requires manually alternating config.use_rl between False and True and using -resume to load the model from the previous stage (which is admittedly clumsy). Start with config.use_rl = False. 34 | 35 | The remaining instructions are similar to the original project [Cyanogenoid/vqa-counting][1]. 36 | 37 | - Train the model in `model.py` with: 38 | ``` 39 | python train.py [optional-name] 40 | ``` 41 | This will alternate between one epoch of training on the train split and one epoch of validation on the validation split while printing the current training progress to stdout and saving logs in the `logs` directory. 42 | The logs contain the name of the model, training statistics, contents of `config.py`, model weights, evaluation information (per-question answer and accuracy), and question and answer vocabularies. 43 | - To view the training progression of a model that is currently training or has finished training, run 44 | ``` 45 | python view-log.py 46 | ``` 47 | 48 | - To evaluate accuracy (VQA accuracy and balanced pair accuracy; see paper for details) in various categories, you can run 49 | ``` 50 | python eval-acc.py [ ...] 51 | ``` 52 | If you pass in multiple paths as arguments, this gives you standard deviations as well. 53 | To customise what categories are shown, you can modify the "accept conditions" for categories in `eval-acc.py`. 54 | 55 | 56 | # Some Things You Need To Know 57 | - Currently, the default setting is the one I used to train the model reported in [Learning to Compose Dynamic Tree Structures for Visual Contexts][0]. However, the model takes many epochs (about 80-100) to converge, so training takes a long time and I did not try many hyperparameter settings. After the CVPR deadline, I found that using a larger hidden dimension in some places may further improve the performance a little bit. 58 | - The current training strategy of our VQA model follows the paper [Learning to Count Objects in Natural Images for Visual Question Answering][5], i.e., a simple Linear classifier + optim.Adam + continuous decay at every batch + a large number of epochs. However, we found that an alternative strategy (WeightNorm Linear + optim.Adamax + lr warm-up) takes no more than 15 epochs to converge, so you can try it if you want; a minimal sketch is given after this list. You may also check my [other project][6] for reimplementations of some recent (2018) state-of-the-art VQA models using this strategy.
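A minimal, illustrative sketch of that alternative recipe follows. This is not the configuration used for the reported results: the 1024 -> 3000 classifier head only mirrors `mid_features`/`max_answers` from model.py and config.py, while the base learning rate, warm-up length, and variable names below are placeholder assumptions, and the per-epoch training loop is omitted.

```
# Illustrative only: WeightNorm Linear + optim.Adamax + learning-rate warm-up.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

classifier = weight_norm(nn.Linear(1024, 3000))   # WeightNorm Linear classifier head
optimizer = torch.optim.Adamax(classifier.parameters(), lr=2e-3)

base_lr = 2e-3        # placeholder value
warmup_epochs = 4     # placeholder value
for epoch in range(15):  # the alternative strategy reportedly converges within ~15 epochs
    # linearly warm the learning rate up over the first few epochs, then hold it constant
    scale = min(1.0, (epoch + 1) / warmup_epochs)
    for group in optimizer.param_groups:
        group['lr'] = base_lr * scale
    # ... run one epoch of training and validation here ...
```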
59 | 60 | # If this paper/project inspires your work, please cite our work: 61 | ``` 62 | @inproceedings{tang2018learning, 63 | title={Learning to Compose Dynamic Tree Structures for Visual Contexts}, 64 | author={Tang, Kaihua and Zhang, Hanwang and Wu, Baoyuan and Luo, Wenhan and Liu, Wei}, 65 | booktitle= "Conference on Computer Vision and Pattern Recognition", 66 | year={2019} 67 | } 68 | ``` 69 | 70 | [0]: https://arxiv.org/abs/1812.01880 71 | [1]: https://github.com/Cyanogenoid/vqa-counting 72 | [2]: https://github.com/peteanderson80/bottom-up-attention 73 | [3]: https://github.com/zjuchenlong/faster-rcnn.pytorch 74 | [4]: https://github.com/Cyanogenoid/vqa-counting/tree/master/vqa-v2 75 | [5]: https://openreview.net/forum?id=B12Js_yRb 76 | [6]: https://github.com/KaihuaTang/VQA2.0-Recent-Approachs-2018.pytorch 77 | [7]: https://github.com/KaihuaTang/VCTree-Scene-Graph-Generation 78 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # paths 2 | qa_path = 'data' # directory containing the question and annotation jsons 3 | bottom_up_trainval_path = 'data/trainval' # directory containing the .tsv file(s) with bottom up features 4 | bottom_up_test_path = 'data/test2015' # directory containing the .tsv file(s) with bottom up features 5 | preprocessed_trainval_path = 'genome-trainval.h5' # path where preprocessed features from the trainval split are saved to and loaded from 6 | #preprocessed_trainval_path = '/media/tangkaihua/Disk1/vqa/genome-trainval36.h5' # path where preprocessed features from the trainval split are saved to and loaded from 7 | preprocessed_test_path = 'genome-test.h5' # path where preprocessed features from the test split are saved to and loaded from 8 | #preprocessed_test_path = '/media/tangkaihua/Disk1/vqa/genome-test.h5' # path where preprocessed features from the test split are saved to and loaded from 9 | vocabulary_path = 'vocab.json' # path where the used vocabularies for questions and answers are saved to 10 | 11 | task = 'OpenEnded' 12 | dataset = 'mscoco' 13 | 14 | test_split = 'test2015' # either 'test-dev2015' or 'test2015' 15 | 16 | # preprocess config 17 | output_size = 100 # max number of object proposals per image 18 | output_features = 2048 # number of features in each object proposal 19 | 20 | # training config 21 | epochs = 200 22 | batch_size = 256 23 | initial_lr = 1.5e-3 24 | lr_halflife = 50000 # in iterations 25 | data_workers = 4 26 | max_answers = 3000 27 | 28 | # model config 29 | # the method we used to generate trees 30 | #gen_tree_mode = "overlap_tree" 31 | gen_tree_mode = "arbitrary_trees_transfer" 32 | # tree lstm hidden poolout mode 33 | #poolout_mode = "sigmoid" 34 | poolout_mode = "softmax" 35 | use_rl = True 36 | log_softmax = False -------------------------------------------------------------------------------- /counting.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from utils import PiecewiseLin 6 | 7 | 8 | class Counter(nn.Module): 9 | """ Counting module as proposed in [1]. 10 | Count the number of objects from a set of bounding boxes and a set of scores for each bounding box. 11 | This produces (self.objects + 1) number of count features. 12 | 13 | [1]: Yan Zhang, Jonathon Hare, Adam Prügel-Bennett: Learning to Count Objects in Natural Images for Visual Question Answering. 14 | https://openreview.net/forum?id=B12Js_yRb 15 | """ 16 | def __init__(self, objects, already_sigmoided=False): 17 | super().__init__() 18 | self.objects = objects 19 | self.already_sigmoided = already_sigmoided 20 | self.f = nn.ModuleList([PiecewiseLin(16) for _ in range(16)]) 21 | 22 | def forward(self, boxes, attention): 23 | """ Forward propagation of attention weights and bounding boxes to produce count features. 24 | `boxes` has to be a tensor of shape (n, 4, m) with the 4 channels containing the x and y coordinates of the top left corner and the x and y coordinates of the bottom right corner in this order. 25 | `attention` has to be a tensor of shape (n, m). Each value should be in [0, 1] if already_sigmoided is set to True, but there are no restrictions if already_sigmoided is set to False. This value should be close to 1 if the corresponding boundign box is relevant and close to 0 if it is not. 26 | n is the batch size, m is the number of bounding boxes per image. 27 | """ 28 | # only care about the highest scoring object proposals 29 | # the ones with low score will have a low impact on the count anyway 30 | boxes, attention = self.filter_most_important(self.objects, boxes, attention) 31 | # normalise the attention weights to be in [0, 1] 32 | if not self.already_sigmoided: 33 | attention = F.sigmoid(attention) 34 | 35 | relevancy = self.outer_product(attention) 36 | distance = 1 - self.iou(boxes, boxes) 37 | 38 | # intra-object dedup 39 | score = self.f[0](relevancy) * self.f[1](distance) 40 | 41 | # inter-object dedup 42 | dedup_score = self.f[3](relevancy) * self.f[4](distance) 43 | dedup_per_entry, dedup_per_row = self.deduplicate(dedup_score, attention) 44 | score = score / dedup_per_entry 45 | 46 | # aggregate the score 47 | # can skip putting this on the diagonal since we're just summing over it anyway 48 | correction = self.f[0](attention * attention) / dedup_per_row 49 | score = score.sum(dim=2).sum(dim=1, keepdim=True) + correction.sum(dim=1, keepdim=True) 50 | score = (score + 1e-20).sqrt() 51 | one_hot = self.to_one_hot(score) 52 | 53 | att_conf = (self.f[5](attention) - 0.5).abs() 54 | dist_conf = (self.f[6](distance) - 0.5).abs() 55 | conf = self.f[7](att_conf.mean(dim=1, keepdim=True) + dist_conf.mean(dim=2).mean(dim=1, keepdim=True)) 56 | 57 | return one_hot * conf 58 | 59 | def deduplicate(self, dedup_score, att): 60 | # using outer-diffs 61 | att_diff = self.outer_diff(att) 62 | score_diff = self.outer_diff(dedup_score) 63 | sim = self.f[2](1 - score_diff).prod(dim=1) * self.f[2](1 - att_diff) 64 | # similarity for each row 65 | row_sims = sim.sum(dim=2) 66 | # similarity for each entry 67 | all_sims = self.outer_product(row_sims) 68 | return all_sims, row_sims 69 | 70 | def to_one_hot(self, scores): 71 | """ Turn a bunch of non-negative scalar values into a one-hot encoding. 72 | E.g. with self.objects = 3, 0 -> [1 0 0 0], 2.75 -> [0 0 0.25 0.75]. 
73 | """ 74 | # sanity check, I don't think this ever does anything (it certainly shouldn't) 75 | scores = scores.clamp(min=0, max=self.objects) 76 | # compute only on the support 77 | i = scores.long().data 78 | f = scores.frac() 79 | # target_l is the one-hot if the score is rounded down 80 | # target_r is the one-hot if the score is rounded up 81 | target_l = scores.data.new(i.size(0), self.objects + 1).fill_(0) 82 | target_r = scores.data.new(i.size(0), self.objects + 1).fill_(0) 83 | 84 | target_l.scatter_(dim=1, index=i.clamp(max=self.objects), value=1) 85 | target_r.scatter_(dim=1, index=(i + 1).clamp(max=self.objects), value=1) 86 | # interpolate between these with the fractional part of the score 87 | return (1 - f) * Variable(target_l) + f * Variable(target_r) 88 | 89 | def filter_most_important(self, n, boxes, attention): 90 | """ Only keep top-n object proposals, scored by attention weight """ 91 | attention, idx = attention.topk(n, dim=1, sorted=False) 92 | idx = idx.unsqueeze(dim=1).expand(boxes.size(0), boxes.size(1), idx.size(1)) 93 | boxes = boxes.gather(2, idx) 94 | return boxes, attention 95 | 96 | def outer(self, x): 97 | size = tuple(x.size()) + (x.size()[-1],) 98 | a = x.unsqueeze(dim=-1).expand(*size) 99 | b = x.unsqueeze(dim=-2).expand(*size) 100 | return a, b 101 | 102 | def outer_product(self, x): 103 | # Y_ij = x_i * x_j 104 | a, b = self.outer(x) 105 | return a * b 106 | 107 | def outer_diff(self, x): 108 | # like outer products, except taking the absolute difference instead 109 | # Y_ij = | x_i - x_j | 110 | a, b = self.outer(x) 111 | return (a - b).abs() 112 | 113 | def iou(self, a, b): 114 | # this is just the usual way to IoU from bounding boxes 115 | inter = self.intersection(a, b) 116 | area_a = self.area(a).unsqueeze(2).expand_as(inter) 117 | area_b = self.area(b).unsqueeze(1).expand_as(inter) 118 | return inter / (area_a + area_b - inter + 1e-12) 119 | 120 | def area(self, box): 121 | x = (box[:, 2, :] - box[:, 0, :]).clamp(min=0) 122 | y = (box[:, 3, :] - box[:, 1, :]).clamp(min=0) 123 | return x * y 124 | 125 | def intersection(self, a, b): 126 | size = (a.size(0), 2, a.size(2), b.size(2)) 127 | min_point = torch.max( 128 | a[:, :2, :].unsqueeze(dim=3).expand(*size), 129 | b[:, :2, :].unsqueeze(dim=2).expand(*size), 130 | ) 131 | max_point = torch.min( 132 | a[:, 2:, :].unsqueeze(dim=3).expand(*size), 133 | b[:, 2:, :].unsqueeze(dim=2).expand(*size), 134 | ) 135 | inter = (max_point - min_point).clamp(min=0) 136 | area = inter[:, 0, :, :] * inter[:, 1, :, :] 137 | return area 138 | 139 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path 4 | import re 5 | 6 | from PIL import Image 7 | import h5py 8 | import torch 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | import numpy as np 12 | 13 | import config 14 | import utils 15 | 16 | 17 | preloaded_vocab = None 18 | question_type_dict = {'what animal is': 0, 'what type of': 1, 'which': 2, 'is he': 3, 'how many': 4, 'are the': 5, 'how many people are': 6, 'what room is': 7, 'why': 8, 'has': 9, 'where is the': 10, 'is this person': 11, 'is that a': 12, 'is the man': 13, 'what is the': 14, 'is there': 15, 'what is': 16, 'what color are the': 17, 'what is the color of the': 18, 'is this': 19, 'was': 20, 'can you': 21, 'are': 22, 'do you': 23, 'could': 24, 'none of the above': 25, 'what color is the': 
26, 'what is the person': 27, 'are these': 28, 'what color': 29, 'how': 30, 'what sport is': 31, 'is the': 32, 'what color is': 33, 'are they': 34, 'what does the': 35, 'what is on the': 36, 'what': 37, 'what kind of': 38, 'is the woman': 39, 'what are the': 40, 'what is the woman': 41, 'do': 42, 'what is in the': 43, 'who is': 44, 'are there any': 45, 'what brand': 46, 'where are the': 47, 'why is the': 48, 'what is this': 49, 'what is the name': 50, 'does the': 51, 'how many people are in': 52, 'what are': 53, 'is': 54, 'is it': 55, 'is the person': 56, 'what time': 57, 'is there a': 58, 'what number is': 59, 'is this a': 60, 'does this': 61, 'what is the man': 62, 'is this an': 63, 'are there': 64} 19 | 20 | 21 | 22 | def get_loader(train=False, val=False, test=False): 23 | """ Returns a data loader for the desired split """ 24 | split = VQA( 25 | utils.path_for(train=train, val=val, test=test, question=True), 26 | utils.path_for(train=train, val=val, test=test, answer=True), 27 | config.preprocessed_trainval_path if not test else config.preprocessed_test_path, 28 | answerable_only=train, 29 | dummy_answers=test, 30 | ) 31 | loader = torch.utils.data.DataLoader( 32 | split, 33 | batch_size=config.batch_size, 34 | shuffle=train, # only shuffle the data in training 35 | pin_memory=True, 36 | num_workers=config.data_workers, 37 | collate_fn=collate_fn, 38 | ) 39 | return loader 40 | 41 | 42 | def collate_fn(batch): 43 | # put question lengths in descending order so that we can use packed sequences later 44 | batch.sort(key=lambda x: x[-1], reverse=True) 45 | return data.dataloader.default_collate(batch) 46 | 47 | 48 | class VQA(data.Dataset): 49 | """ VQA dataset, open-ended """ 50 | def __init__(self, questions_path, answers_path, image_features_path, answerable_only=False, dummy_answers=False): 51 | super(VQA, self).__init__() 52 | with open(questions_path, 'r') as fd: 53 | questions_json = json.load(fd) 54 | with open(answers_path, 'r') as fd: 55 | answers_json = json.load(fd) 56 | if preloaded_vocab: 57 | vocab_json = preloaded_vocab 58 | else: 59 | with open(config.vocabulary_path, 'r') as fd: 60 | vocab_json = json.load(fd) 61 | 62 | self.question_ids = [q['question_id'] for q in questions_json['questions']] 63 | 64 | # vocab 65 | self.vocab = vocab_json 66 | self.token_to_index = self.vocab['question'] 67 | self.answer_to_index = self.vocab['answer'] 68 | 69 | # q and a 70 | self.questions = list(prepare_questions(questions_json)) 71 | self.answers = list(prepare_answers(answers_json)) 72 | self.questions_type = list(prepare_questions_type(answers_json)) 73 | self.questions = [self._encode_question(q) for q in self.questions] 74 | self.answers = [self._encode_answers(a) for a in self.answers] 75 | 76 | # v 77 | self.image_features_path = image_features_path 78 | self.coco_id_to_index = self._create_coco_id_to_index() 79 | self.coco_ids = [q['image_id'] for q in questions_json['questions']] 80 | 81 | self.dummy_answers= dummy_answers 82 | 83 | # only use questions that have at least one answer? 
84 | self.answerable_only = answerable_only 85 | if self.answerable_only: 86 | self.answerable = self._find_answerable(not self.answerable_only) 87 | 88 | @property 89 | def max_question_length(self): 90 | if not hasattr(self, '_max_length'): 91 | self._max_length = max(map(len, self.questions)) 92 | return self._max_length 93 | 94 | @property 95 | def num_tokens(self): 96 | return len(self.token_to_index) + 1 # add 1 for token at index 0 97 | 98 | def _create_coco_id_to_index(self): 99 | """ Create a mapping from a COCO image id into the corresponding index into the h5 file """ 100 | with h5py.File(self.image_features_path, 'r') as features_file: 101 | coco_ids = features_file['ids'][()] 102 | coco_id_to_index = {id: i for i, id in enumerate(coco_ids)} 103 | return coco_id_to_index 104 | 105 | def _check_integrity(self, questions, answers): 106 | """ Verify that we are using the correct data """ 107 | qa_pairs = list(zip(questions['questions'], answers['annotations'])) 108 | assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers' 109 | assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match' 110 | assert questions['data_type'] == answers['data_type'], 'Mismatched data types' 111 | assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes' 112 | 113 | def _find_answerable(self, count=False): 114 | """ Create a list of indices into questions that will have at least one answer that is in the vocab """ 115 | answerable = [] 116 | if count: 117 | number_indices = torch.LongTensor([self.answer_to_index[str(i)] for i in range(0, 8)]) 118 | for i, answers in enumerate(self.answers): 119 | # store the indices of anything that is answerable 120 | if count: 121 | answers = answers[number_indices] 122 | answer_has_index = len(answers.nonzero()) > 0 123 | if answer_has_index: 124 | answerable.append(i) 125 | return answerable 126 | 127 | def _encode_question(self, question): 128 | """ Turn a question into a vector of indices and a question length """ 129 | vec = torch.zeros(self.max_question_length).long() 130 | for i, token in enumerate(question): 131 | index = self.token_to_index.get(token, 0) 132 | vec[i] = index 133 | return vec, len(question) 134 | 135 | def _encode_answers(self, answers): 136 | """ Turn an answer into a vector """ 137 | # answer vec will be a vector of answer counts to determine which answers will contribute to the loss. 138 | # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up 139 | # to get the loss that is weighted by how many humans gave that answer 140 | answer_vec = torch.zeros(len(self.answer_to_index)) 141 | for answer in answers: 142 | index = self.answer_to_index.get(answer) 143 | if index is not None: 144 | answer_vec[index] += 1 145 | return answer_vec 146 | 147 | def _load_image(self, image_id): 148 | """ Load an image """ 149 | if not hasattr(self, 'features_file'): 150 | # Loading the h5 file has to be done here and not in __init__ because when the DataLoader 151 | # forks for multiple works, every child would use the same file object and fail 152 | # Having multiple readers using different file objects is fine though, so we just init in here. 
153 | self.features_file = h5py.File(self.image_features_path, 'r') 154 | index = self.coco_id_to_index[image_id] 155 | img = self.features_file['features'][index] 156 | boxes = self.features_file['boxes'][index] 157 | widths = self.features_file['widths'][index] 158 | heights = self.features_file['heights'][index] 159 | boxes[0] = boxes[0] / widths 160 | boxes[2] = boxes[2] / widths 161 | boxes[1] = boxes[1] / heights 162 | boxes[3] = boxes[3] / heights 163 | return torch.from_numpy(img).unsqueeze(1), torch.from_numpy(boxes) 164 | 165 | def __getitem__(self, item): 166 | if self.answerable_only: 167 | item = self.answerable[item] 168 | q, q_length = self.questions[item] 169 | q_type = self.questions_type[item] 170 | if not self.dummy_answers: 171 | a = self.answers[item] 172 | else: 173 | # just return a dummy answer, it's not going to be used anyway 174 | a = 0 175 | image_id = self.coco_ids[item] 176 | v, b = self._load_image(image_id) 177 | # since batches are re-ordered for PackedSequence's, the original question order is lost 178 | # we return `item` so that the order of (v, q, a) triples can be restored if desired 179 | # without shuffling in the dataloader, these will be in the order that they appear in the q and a json's. 180 | return v, q, a, b, q_type, item, q_length 181 | 182 | def __len__(self): 183 | if self.answerable_only: 184 | return len(self.answerable) 185 | else: 186 | return len(self.questions) 187 | 188 | 189 | # this is used for normalizing questions 190 | _special_chars = re.compile('[^a-z0-9 ]*') 191 | 192 | # these try to emulate the original normalization scheme for answers 193 | _period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)') 194 | _comma_strip = re.compile(r'(\d)(,)(\d)') 195 | _punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!') 196 | _punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars))) 197 | _punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars)) 198 | 199 | 200 | def prepare_questions(questions_json): 201 | """ Tokenize and normalize questions from a given question json in the usual VQA format. """ 202 | questions = [q['question'] for q in questions_json['questions']] 203 | for question in questions: 204 | question = question.lower()[:-1] 205 | question = _special_chars.sub('', question) 206 | yield question.split(' ') 207 | 208 | def prepare_questions_type(answers_json): 209 | """ Get 65 Question Type for further analysis """ 210 | type_que = [a['question_type'] for a in answers_json['annotations']] 211 | type_ind = [question_type_dict[tp] for tp in type_que] 212 | return type_ind 213 | 214 | def prepare_answers(answers_json): 215 | """ Normalize answers from a given answer json in the usual VQA format. """ 216 | answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']] 217 | # The only normalization that is applied to both machine generated answers as well as 218 | # ground truth answers is replacing most punctuation with space (see [0] and [1]). 219 | # Since potential machine generated answers are just taken from most common answers, applying the other 220 | # normalizations is not needed, assuming that the human answers are already normalized. 
221 | # [0]: http://visualqa.org/evaluation.html 222 | # [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96 223 | 224 | def process_punctuation(s): 225 | # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour 226 | # this version should be faster since we use re instead of repeated operations on str's 227 | if _punctuation.search(s) is None: 228 | return s 229 | s = _punctuation_with_a_space.sub('', s) 230 | if re.search(_comma_strip, s) is not None: 231 | s = s.replace(',', '') 232 | s = _punctuation.sub(' ', s) 233 | s = _period_strip.sub('', s) 234 | return s.strip() 235 | 236 | for answer_list in answers: 237 | yield list(map(process_punctuation, answer_list)) 238 | 239 | 240 | class CocoImages(data.Dataset): 241 | """ Dataset for MSCOCO images located in a folder on the filesystem """ 242 | def __init__(self, path, transform=None): 243 | super(CocoImages, self).__init__() 244 | self.path = path 245 | self.id_to_filename = self._find_images() 246 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 247 | print('found {} images in {}'.format(len(self), self.path)) 248 | self.transform = transform 249 | 250 | def _find_images(self): 251 | id_to_filename = {} 252 | for filename in os.listdir(self.path): 253 | if not filename.endswith('.jpg'): 254 | continue 255 | id_and_extension = filename.split('_')[-1] 256 | id = int(id_and_extension.split('.')[0]) 257 | id_to_filename[id] = filename 258 | return id_to_filename 259 | 260 | def __getitem__(self, item): 261 | id = self.sorted_ids[item] 262 | path = os.path.join(self.path, self.id_to_filename[id]) 263 | img = Image.open(path).convert('RGB') 264 | 265 | if self.transform is not None: 266 | img = self.transform(img) 267 | return id, img 268 | 269 | def __len__(self): 270 | return len(self.sorted_ids) 271 | -------------------------------------------------------------------------------- /data/co_occour_count.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Visual-Question-Answering/b6b0a8bdb01d45d36de3bded91db42544ad6a593/data/co_occour_count.npy -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # questions 4 | wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip 5 | 6 | # answers 7 | wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip 8 | 9 | # balanced pairs 10 | wget http://visualqa.org/data/mscoco/vqa/v2_Complementary_Pairs_Train_mscoco.zip http://visualqa.org/data/mscoco/vqa/v2_Complementary_Pairs_Val_mscoco.zip 11 | 12 | # bottom up features (https://github.com/peteanderson80/bottom-up-attention) 13 | #wget https://imagecaption.blob.core.windows.net/imagecaption/trainval.zip https://imagecaption.blob.core.windows.net/imagecaption/test2015.zip 14 | ## alternative bottom-up features: 36 fixed proposals per image instead of 10--100 adaptive proposals per image. 
15 | #wget https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip https://imagecaption.blob.core.windows.net/imagecaption/test2015_36.zip 16 | 17 | unzip "*.zip" 18 | -------------------------------------------------------------------------------- /eval-acc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os.path 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import utils 10 | import config 11 | 12 | 13 | q_path = utils.path_for(val=True, question=True) 14 | with open(q_path, 'r') as fd: 15 | q_json = json.load(fd) 16 | a_path = utils.path_for(val=True, answer=True) 17 | with open(a_path, 'r') as fd: 18 | a_json = json.load(fd) 19 | with open(os.path.join(config.qa_path, 'v2_mscoco_val2014_complementary_pairs.json')) as fd: 20 | pairs = json.load(fd) 21 | 22 | question_list = q_json['questions'] 23 | question_ids = [q['question_id'] for q in question_list] 24 | questions = [q['question'] for q in question_list] 25 | answer_list = a_json['annotations'] 26 | categories = [a['answer_type'] for a in answer_list] # {'yes/no', 'other', 'number'} 27 | accept_condition = { 28 | 'yes/no': (lambda x: id_to_cat[x] == 'yes/no'), 29 | 'other': (lambda x: id_to_cat[x] == 'other'), 30 | 'number': (lambda x: id_to_cat[x] == 'number'), 31 | 'count': (lambda x: id_to_question[x].lower().startswith('how many')), 32 | 'all': (lambda x: True), 33 | } 34 | 35 | statistics = defaultdict(list) 36 | for path in sys.argv[1:]: 37 | log = torch.load(path) 38 | ans = log['eval'] 39 | d = [(acc, ans) for (acc, ans, _) in sorted(zip(ans['accuracies'], ans['answers'], ans['idx']), key=lambda x: x[-1])] 40 | accs = map(lambda x: x[0], d) 41 | id_to_cat = dict(zip(question_ids, categories)) 42 | id_to_acc = dict(zip(question_ids, accs)) 43 | id_to_question = dict(zip(question_ids, questions)) 44 | 45 | for name, f in accept_condition.items(): 46 | for on_pairs in [False, True]: 47 | acc = [] 48 | if on_pairs: 49 | for a, b in pairs: 50 | if not (f(a) and f(b)): 51 | continue 52 | if id_to_acc[a] == id_to_acc[b] == 1: 53 | acc.append(1) 54 | else: 55 | acc.append(0) 56 | else: 57 | for x in question_ids: 58 | if not f(x): 59 | continue 60 | acc.append(id_to_acc[x]) 61 | acc = np.mean(acc) 62 | statistics[name, 'pair' if on_pairs else 'single'].append(acc) 63 | 64 | for (name, pairness), accs in statistics.items(): 65 | mean = np.mean(accs) 66 | std = np.std(accs, ddof=1) 67 | print('{} ({})\t: {:.2f}% +- {}'.format(name, pairness, 100 * mean, 100 * std)) 68 | -------------------------------------------------------------------------------- /gen_tree_net.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | class GenTreeModule(nn.Module): 10 | """ 11 | Calculate Scores to Generate Trees 12 | """ 13 | def __init__(self): 14 | super().__init__() 15 | # score fc: get 151 class distribution from bbox feature 16 | self.score_fc = nn.Linear(2048, 151) 17 | init.xavier_uniform_(self.score_fc.weight) 18 | self.score_fc.bias.data.zero_() 19 | # context for calculating gen-tree score 20 | self.context = LinearizedContext() 21 | 22 | def forward(self, visual_feat, bbox): 23 | """ 24 | visual_feat: [batch, 2048, num_box] 25 | bbox: [batch, 4, num_box] (x1,y1,x2,y2) 26 | """ 27 | 
batch_size, feat_size, box_num = visual_feat.shape 28 | visual_feat = torch.transpose(visual_feat, 1, 2).contiguous() # [batch, num_box, feat_size] 29 | assert(visual_feat.shape[2] == feat_size) 30 | visual_feat = visual_feat.view(-1, feat_size) # [batch * num_box, feat_size] 31 | # prepare obj distribution 32 | obj_predict = self.score_fc(visual_feat) 33 | #obj_distrib = F.softmax(obj_predict, dim=1)[:, 1:] 34 | assert(obj_predict.shape[1] == 151) # [batch * num_box, 150] 35 | # prepare bbox feature 36 | bbox_trans = torch.transpose(bbox, 1, 2).contiguous() 37 | assert(bbox_trans.shape[2] == 4) 38 | bbox_feat = bbox_trans.view(-1, 4) 39 | bbox_embed = get_box_info(bbox_feat) # [batch * num_box, 8] 40 | #print('bbox_embed', bbox_embed) 41 | # prepare overlap feature 42 | overlab_embed = get_overlap_info(bbox_trans) 43 | #print('overlab_embed: ', overlab_embed) 44 | 45 | return self.context(visual_feat, obj_predict, bbox_embed, overlab_embed, batch_size, box_num) 46 | 47 | def get_label(self, visual_feat): 48 | """ 49 | visual_feat: [batch, 2048, num_box] 50 | output[0]: object distribution, [batch_size, box_num, 151] 51 | output[1]: object label, [batch_size, box_num] 52 | """ 53 | batch_size, feat_size, box_num = visual_feat.shape 54 | visual_feat = torch.transpose(visual_feat, 1, 2).contiguous() # [batch, num_box, feat_size] 55 | assert(visual_feat.shape[2] == feat_size) 56 | visual_feat = visual_feat.view(-1, feat_size) # [batch * num_box, feat_size] 57 | obj_predict = self.score_fc(visual_feat) 58 | obj_distrib = F.softmax(obj_predict, dim=1) 59 | obj_label = obj_distrib.max(1)[1] 60 | return obj_distrib.view(batch_size, box_num, -1), obj_label.view(batch_size, box_num) 61 | 62 | 63 | class LinearizedContext(nn.Module): 64 | """ 65 | The name is meaningless, we just need to maintain the same structure to load transferred model. 
66 | """ 67 | def __init__(self): 68 | super().__init__() 69 | self.num_classes = 151 70 | self.embed_dim = 200 71 | self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim) 72 | #self.virtual_node_embed = nn.Embedding(1, self.embed_dim) 73 | 74 | self.co_occour = np.load('data/co_occour_count.npy') 75 | self.co_occour = self.co_occour / self.co_occour.sum() 76 | 77 | self.rl_input_size = 256 78 | self.rl_hidden_size = 256 79 | self.feat_preprocess_net = RLFeatPreprocessNet(feat_size=2048, embed_size=self.embed_dim, bbox_size=8, overlap_size=6, output_size=self.rl_input_size) 80 | 81 | self.rl_sub = nn.Linear(self.rl_input_size, self.rl_hidden_size) 82 | self.rl_obj = nn.Linear(self.rl_input_size, self.rl_hidden_size) 83 | self.rl_scores = nn.Linear(self.rl_hidden_size * 3 + 3, 1) # (left child score, right child score) 84 | 85 | init.xavier_uniform_(self.rl_sub.weight) 86 | init.xavier_uniform_(self.rl_obj.weight) 87 | init.xavier_uniform_(self.rl_scores.weight) 88 | 89 | self.rl_sub.bias.data.zero_() 90 | self.rl_obj.bias.data.zero_() 91 | self.rl_scores.bias.data.zero_() 92 | 93 | def forward(self, visual_feat, obj_predict, bbox_embed, overlap_embed, batch_size, box_num): 94 | """ 95 | total = batch_size * box_num 96 | visual_feat: [total, 2048] 97 | obj_predict: [total, 151] 98 | bbox_embed: [total, 8] 99 | overlap_embed: [total, 6] 100 | """ 101 | # object label embed and prediction 102 | num_class = 150 103 | obj_embed = F.softmax(obj_predict, dim=1) @ self.obj_embed.weight 104 | obj_distrib = F.softmax(obj_predict, dim=1)[:,1:].view(batch_size, box_num, num_class) 105 | # co_occour 106 | cooccour_matrix = Variable(torch.from_numpy(self.co_occour).float().cuda()) 107 | class_scores = cooccour_matrix.sum(1).view(-1) 108 | 109 | # preprocessed features 110 | prepro_feat = self.feat_preprocess_net(visual_feat, obj_embed, bbox_embed, overlap_embed) 111 | rl_sub_feat = self.rl_sub(prepro_feat) 112 | rl_obj_feat = self.rl_obj(prepro_feat) 113 | rl_sub_feat = F.relu(rl_sub_feat).view(batch_size, box_num, -1) 114 | rl_obj_feat = F.relu(rl_obj_feat).view(batch_size, box_num, -1) 115 | 116 | # score matrix generation 117 | hidden_size = self.rl_hidden_size 118 | tree_matrix = Variable(torch.FloatTensor(batch_size, box_num * box_num).zero_().cuda()) 119 | for i in range(batch_size): 120 | sliced_sub_feat = rl_sub_feat[i].view(1, box_num, hidden_size).expand(box_num, box_num, hidden_size) 121 | sliced_obj_feat = rl_obj_feat[i].view(box_num, 1, hidden_size).expand(box_num, box_num, hidden_size) 122 | sliced_sub_dist = obj_distrib[i].view(1, box_num, num_class).expand(box_num, box_num, num_class).contiguous().view(-1, num_class) 123 | sliced_obj_dist = obj_distrib[i].view(box_num, 1, num_class).expand(box_num, box_num, num_class).contiguous().view(-1, num_class) 124 | sliced_dot_dist = sliced_sub_dist.view(-1, num_class, 1) @ sliced_obj_dist.view(-1, 1, num_class) # [num_pair, 150, 150] 125 | sliced_dot_score = sliced_dot_dist * cooccour_matrix # [num_pair, 150, 150] 126 | 127 | sliced_pair_score = sliced_dot_score.view(box_num * box_num, num_class * num_class).sum(1).view(box_num, box_num, 1) 128 | sliced_sub_score = (sliced_sub_dist * class_scores).sum(1).view(box_num, box_num, 1) 129 | sliced_obj_score = (sliced_obj_dist * class_scores).sum(1).view(box_num, box_num, 1) 130 | sliced_pair_feat = torch.cat((sliced_sub_feat * sliced_obj_feat, sliced_sub_feat, sliced_obj_feat, sliced_pair_score, sliced_sub_score, sliced_obj_score), 2) 131 | 132 | sliced_pair_output = 
self.rl_scores(sliced_pair_feat.view(-1, hidden_size * 3 + 3)) 133 | sliced_pair_gates = F.sigmoid(sliced_pair_output).view(-1,1) # (relation prob) 134 | sliced_rel_scores = (sliced_pair_score.view(-1,1) * sliced_pair_gates).view(-1) 135 | 136 | tree_matrix[i, :] = sliced_rel_scores 137 | 138 | return tree_matrix.view(batch_size, box_num, box_num) 139 | 140 | class RLFeatPreprocessNet(nn.Module): 141 | """ 142 | Preprocess Features 143 | 1. visual feature 144 | 2. label prediction embed feature 145 | 3. box embed 146 | 4. overlap embed 147 | """ 148 | def __init__(self, feat_size, embed_size, bbox_size, overlap_size, output_size): 149 | super(RLFeatPreprocessNet, self).__init__() 150 | self.feature_size = feat_size 151 | self.embed_size = embed_size 152 | self.box_info_size = bbox_size 153 | self.overlap_info_size = overlap_size 154 | self.output_size = output_size 155 | 156 | # linear layers 157 | self.resize_feat = nn.Linear(self.feature_size, int(output_size / 4)) 158 | self.resize_embed = nn.Linear(self.embed_size, int(output_size / 4)) 159 | self.resize_box = nn.Linear(self.box_info_size, int(output_size / 4)) 160 | self.resize_overlap = nn.Linear(self.overlap_info_size, int(output_size / 4)) 161 | 162 | # init 163 | self.resize_feat.weight.data.normal_(0, 0.001) 164 | self.resize_embed.weight.data.normal_(0, 0.01) 165 | self.resize_box.weight.data.normal_(0, 1) 166 | self.resize_overlap.weight.data.normal_(0, 1) 167 | self.resize_feat.bias.data.zero_() 168 | self.resize_embed.bias.data.zero_() 169 | self.resize_box.bias.data.zero_() 170 | self.resize_overlap.bias.data.zero_() 171 | 172 | def forward(self, obj_feat, obj_embed, box_info, overlap_info): 173 | resized_obj = self.resize_feat(obj_feat) 174 | resized_embed = self.resize_embed(obj_embed) 175 | resized_box = self.resize_box(box_info) 176 | resized_overlap = self.resize_overlap(overlap_info) 177 | 178 | output_feat = torch.cat((resized_obj, resized_embed, resized_box, resized_overlap), 1) 179 | return output_feat 180 | 181 | def get_box_info(boxes): 182 | """ 183 | input: [batch_size, (x1,y1,x2,y2)] 184 | output: [batch_size, (x1,y1,x2,y2,cx,cy,w,h)] 185 | """ 186 | return torch.cat((boxes, center_size(boxes)), 1) 187 | 188 | def center_size(boxes): 189 | """ Convert prior_boxes to (cx, cy, w, h) 190 | representation for comparison to center-size form ground truth data. 191 | Args: 192 | boxes: (tensor) point_form boxes 193 | Return: 194 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 
195 | """ 196 | wh = boxes[:, 2:] - boxes[:, :2] 197 | return torch.cat((boxes[:, :2] + 0.5 * wh, wh), 1) 198 | 199 | def get_overlap_info(bbox): 200 | """ 201 | input: 202 | box_priors: [batch_size, number_obj, 4] 203 | output: [number_object, 6] 204 | number of overlapped obj (self not included) 205 | sum of all intersection area (self not included) 206 | sum of IoU (Intersection over Union) 207 | average of all intersection area (self not included) 208 | average of IoU (Intersection over Union) 209 | roi area 210 | """ 211 | batch_size, num_obj, bsize = bbox.shape 212 | # generate input feat 213 | overlap_info = Variable(torch.FloatTensor(batch_size, num_obj, 6).zero_().cuda()) # each obj has how many overlaped objects 214 | reverse_eye = Variable(1.0 - torch.eye(num_obj).float().cuda()) # removed diagonal elements 215 | for i in range(batch_size): 216 | sliced_bbox = bbox[i].view(num_obj, bsize) 217 | sliced_intersection = bbox_intersections(sliced_bbox, sliced_bbox) 218 | sliced_overlap = bbox_overlaps(sliced_bbox, sliced_bbox, sliced_intersection) 219 | sliced_area = bbox_area(sliced_bbox) 220 | # removed diagonal elements 221 | sliced_intersection = sliced_intersection * reverse_eye 222 | sliced_overlap = sliced_overlap * reverse_eye 223 | # assign value 224 | overlap_info[i, :, 0] = (sliced_intersection > 0.0).float().sum(1) 225 | overlap_info[i, :, 1] = sliced_intersection.sum(1) 226 | overlap_info[i, :, 2] = sliced_overlap.sum(1) 227 | overlap_info[i, :, 3] = overlap_info[i, :, 1] / (overlap_info[i, :, 0] + 1e-9) 228 | overlap_info[i, :, 4] = overlap_info[i, :, 2] / (overlap_info[i, :, 0] + 1e-9) 229 | overlap_info[i, :, 5] = sliced_area 230 | 231 | return overlap_info.view(batch_size * num_obj, 6) 232 | 233 | def bbox_area(bbox): 234 | """ 235 | bbox: (K, 4) ndarray of float 236 | area: (k) 237 | """ 238 | K = bbox.size(0) 239 | bbox_area = ((bbox[:,2] - bbox[:,0]) * (bbox[:,3] - bbox[:,1])).view(K) 240 | return bbox_area 241 | 242 | def bbox_intersections(box_a, box_b): 243 | """ 244 | Args: 245 | box_a: (tensor) bounding boxes, Shape: [A,4]. 246 | box_b: (tensor) bounding boxes, Shape: [B,4]. 247 | Return: 248 | (tensor) intersection area, Shape: [A,B]. 249 | """ 250 | A = box_a.size(0) 251 | B = box_b.size(0) 252 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 253 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 254 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 255 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 256 | inter = torch.clamp((max_xy - min_xy), min=0) 257 | return inter[:, :, 0] * inter[:, :, 1] 258 | 259 | def bbox_overlaps(box_a, box_b, inter=None): 260 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 261 | is simply the intersection over union of two boxes. Here we operate on 262 | ground truth boxes and default boxes. 
263 | E.g.: 264 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 265 | Args: 266 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 267 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 268 | Return: 269 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 270 | """ 271 | if inter is None: 272 | inter = bbox_intersections(box_a, box_b) 273 | area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 274 | area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 275 | union = area_a + area_b - inter 276 | return inter / union # [A,B] -------------------------------------------------------------------------------- /logs/.dummy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaihuaTang/VCTree-Visual-Question-Answering/b6b0a8bdb01d45d36de3bded91db42544ad6a593/logs/.dummy -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.autograd import Variable 6 | from torch.nn.utils import weight_norm 7 | from torch.nn.utils.rnn import pack_padded_sequence 8 | 9 | import config 10 | import counting 11 | import tree_feature 12 | 13 | 14 | class Net(nn.Module): 15 | """ Based on ``Show, Ask, Attend, and Answer: A Strong Baseline For Visual Question Answering'' [0] 16 | 17 | [0]: https://arxiv.org/abs/1704.03162 18 | """ 19 | 20 | def __init__(self, embedding_tokens): 21 | super(Net, self).__init__() 22 | question_features = 1024 23 | vision_features = config.output_features 24 | glimpses = 2 25 | self.num_models = 2 26 | objects = 10 27 | tree_hidden_size = 1024 28 | 29 | self.text = TextProcessor( 30 | embedding_tokens=embedding_tokens, 31 | embedding_features=300, 32 | lstm_features=question_features, 33 | drop=0.5, 34 | ) 35 | self.attention = Attention( 36 | v_features=vision_features, 37 | q_features=question_features, 38 | mid_features=512, 39 | glimpses=glimpses, 40 | drop=0.5, 41 | ) 42 | 43 | self.tree_attention = Attention( 44 | v_features=tree_hidden_size, 45 | q_features=question_features, 46 | mid_features=512, 47 | glimpses=glimpses, 48 | drop=0.5, 49 | ) 50 | 51 | self.model_attention = ModelAttention( 52 | q_features=question_features, 53 | mid_features=1024, 54 | q_type_num=65, 55 | num_models=self.num_models, 56 | drop=0.5, 57 | ) 58 | self.classifier = Classifier( 59 | in_features=(glimpses * vision_features, question_features), 60 | mid_features=1024, 61 | num_module=self.num_models, 62 | out_features=config.max_answers, 63 | tree_features=tree_hidden_size * glimpses, 64 | count_features=objects + 1, 65 | drop=0.5, 66 | ) 67 | self.counter = counting.Counter(objects) 68 | self.tree_lstm = tree_feature.TreeFeature(objects, vision_features, tree_hidden_size) 69 | 70 | for m in self.modules(): 71 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 72 | init.xavier_uniform_(m.weight) 73 | if m.bias is not None: 74 | m.bias.data.zero_() 75 | 76 | def forward(self, v_origin, b, q, q_len, q_type): 77 | """ 78 | input V: [batch, 2048, 1, 36] 79 | input B: [batch, 4, 36] 80 | input Q: [batch, 23] 81 | input Q len: [batch] 82 | input Q type: [batch] 83 | """ 84 | # question embedding 85 | q = self.text(q, list(q_len.data)) # 
[batch, 1024] 86 | # normalized visual feature 87 | v_norm = v_origin / (v_origin.norm(p=2, dim=1, keepdim=True) + 1e-12).expand_as(v_origin) # [batch, 2048, 1, 36] 88 | # attention 89 | a = self.attention(v_norm, q) # [batch, num_glimpse, 1, 36] 90 | v = apply_attention(v_norm, a) # [batch, 4096] 91 | # model attention 92 | model_att = self.model_attention(q, q_type) # [batch, num_model] 93 | assert(model_att.shape[1] == self.num_models * 1024) 94 | 95 | # this is where the counting component is used 96 | # pick out the first attention map 97 | a1 = a[:, 0, :, :].contiguous().view(a.size(0), -1) # [batch, 36] 98 | #a2 = a[:, 1, :, :].contiguous().view(a.size(0), -1) # [batch, 36] 99 | # give it and the bounding boxes to the component 100 | tree_feat, rl_loss, entropy_loss = self.tree_lstm(b, a1, v_norm, v_origin, q_type) # [batch, 512, 1, 10] 101 | #print('tree_feat: ', tree_feat.shape) 102 | tree_att = self.tree_attention(tree_feat, q) 103 | if config.poolout_mode == "softmax": 104 | att_t_f = apply_attention(tree_feat, tree_att) 105 | elif config.poolout_mode == "sigmoid": 106 | att_t_f = apply_attention(tree_feat, tree_att, use_softmax=False) 107 | else: 108 | print('Error') 109 | 110 | #count = self.counter(b, a2) # [batch, 11] 111 | 112 | answer = self.classifier(v, q, att_t_f, model_att) # [batch, 3000] 113 | return answer, rl_loss, entropy_loss 114 | 115 | class Fusion(nn.Module): 116 | """ Crazy multi-modal fusion: negative squared difference minus relu'd sum 117 | """ 118 | def __init__(self): 119 | super().__init__() 120 | 121 | def forward(self, x, y): 122 | # found through grad student descent ;) 123 | return - (x - y)**2 + F.relu(x + y) 124 | 125 | class Classifier(nn.Sequential): 126 | def __init__(self, in_features, mid_features, num_module, tree_features, count_features, out_features, drop=0.0): 127 | super(Classifier, self).__init__() 128 | self.drop = nn.Dropout(drop) 129 | self.relu = nn.ReLU() 130 | self.fusion = Fusion() 131 | self.lin11 = nn.Linear(in_features[0], mid_features) 132 | self.lin12 = nn.Linear(in_features[1], mid_features) 133 | self.lin_t1 = nn.Linear(tree_features, mid_features) 134 | self.lin_t2 = nn.Linear(in_features[1], mid_features) 135 | self.lin_c = nn.Linear(count_features, mid_features) 136 | self.lin2 = nn.Linear(mid_features * num_module, out_features) 137 | 138 | self.bn1 = nn.BatchNorm1d(mid_features) 139 | self.bn2 = nn.BatchNorm1d(mid_features) 140 | self.bn3 = nn.BatchNorm1d(mid_features) 141 | 142 | def forward(self, x, y, t, model_att): 143 | x = self.fusion(self.lin11(self.drop(x)), self.lin12(self.drop(y))) 144 | t = self.fusion(self.lin_t1(self.drop(t)), self.lin_t2(self.drop(y))) 145 | #c = self.relu(self.lin_c(c)) 146 | #out = self.bn1(x) * model_att[:,0].view(-1, 1) + self.bn2(t) * model_att[:,1].view(-1, 1) + self.bn3(c) * model_att[:,2].view(-1, 1) 147 | out = torch.cat((self.bn1(x), self.bn2(t)), 1) * model_att 148 | out = self.lin2(self.drop(out)) 149 | return out 150 | 151 | 152 | class TextProcessor(nn.Module): 153 | def __init__(self, embedding_tokens, embedding_features, lstm_features, drop=0.0): 154 | super(TextProcessor, self).__init__() 155 | self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0) 156 | self.drop = nn.Dropout(drop) 157 | self.tanh = nn.Tanh() 158 | self.lstm = nn.GRU(input_size=embedding_features, 159 | hidden_size=lstm_features, 160 | num_layers=1) 161 | self.features = lstm_features 162 | 163 | self._init_lstm(self.lstm.weight_ih_l0) 164 | 
self._init_lstm(self.lstm.weight_hh_l0) 165 | self.lstm.bias_ih_l0.data.zero_() 166 | self.lstm.bias_hh_l0.data.zero_() 167 | 168 | init.xavier_uniform_(self.embedding.weight) 169 | 170 | def _init_lstm(self, weight): 171 | for w in weight.chunk(3, 0): 172 | init.xavier_uniform_(w) 173 | 174 | def forward(self, q, q_len): 175 | embedded = self.embedding(q) 176 | tanhed = self.tanh(self.drop(embedded)) 177 | packed = pack_padded_sequence(tanhed, q_len, batch_first=True) 178 | _, h = self.lstm(packed) 179 | return h.squeeze(0) 180 | 181 | 182 | class Attention(nn.Module): 183 | def __init__(self, v_features, q_features, mid_features, glimpses, drop=0.0): 184 | super(Attention, self).__init__() 185 | self.v_conv = nn.Conv2d(v_features, mid_features, 1, bias=False) # let self.lin take care of bias 186 | self.q_lin = nn.Linear(q_features, mid_features) 187 | self.x_conv = nn.Conv2d(mid_features, glimpses, 1) 188 | 189 | self.drop = nn.Dropout(drop) 190 | self.relu = nn.ReLU(inplace=True) 191 | self.fusion = Fusion() 192 | 193 | def forward(self, v, q): 194 | #q_in = q 195 | v = self.v_conv(self.drop(v)) 196 | q = self.q_lin(self.drop(q)) 197 | q = tile_2d_over_nd(q, v) 198 | x = self.fusion(v, q) 199 | x = self.x_conv(self.drop(x)) 200 | return x 201 | 202 | class ModelAttention(nn.Module): 203 | def __init__(self, q_features, mid_features, q_type_num, num_models, drop=0.0): 204 | super(ModelAttention, self).__init__() 205 | self.q_lin_1 = nn.Linear(q_features, 256) 206 | 207 | self.q_type_1 = nn.Embedding(q_type_num, 256) 208 | self.q_type_2 = nn.Linear(256, 256) 209 | 210 | self.lin_fuse = nn.Linear(256, num_models*mid_features) 211 | self.bn1 = nn.BatchNorm1d(256) 212 | #self.bn2 = nn.BatchNorm1d(256) 213 | 214 | self.drop = nn.Dropout(drop) 215 | self.relu = nn.ReLU() 216 | self.fusion = Fusion() 217 | 218 | init.xavier_uniform_(self.q_lin_1.weight) 219 | #init.xavier_uniform_(self.q_lin_2.weight) 220 | init.xavier_uniform_(self.q_type_1.weight) 221 | init.xavier_uniform_(self.q_type_2.weight) 222 | 223 | self.q_lin_1.bias.data.zero_() 224 | #self.q_lin_2.bias.data.zero_() 225 | self.q_type_2.bias.data.zero_() 226 | 227 | 228 | def forward(self, q, q_type): 229 | q = self.q_lin_1(self.drop(q)) # [batch, 256] 230 | 231 | q_t = self.q_type_1(q_type) 232 | q_t = self.q_type_2(self.drop(q_t)) 233 | 234 | fused_q = self.fusion(q, q_t) 235 | fused_q = self.lin_fuse(self.drop(self.bn1(fused_q))) 236 | 237 | att = F.sigmoid(fused_q) 238 | return att 239 | 240 | def apply_attention(input, attention, use_softmax=True): 241 | """ Apply any number of attention maps over the input. 242 | The attention map has to have the same size in all dimensions except dim=1. 
243 | """ 244 | n, c = input.size()[:2] 245 | glimpses = attention.size(1) 246 | 247 | # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged 248 | input = input.view(n, c, -1) 249 | attention = attention.view(n, glimpses, -1) 250 | s = input.size(2) 251 | 252 | # apply a softmax to each attention map separately 253 | # since softmax only takes 2d inputs, we have to collapse the first two dimensions together 254 | # so that each glimpse is normalized separately 255 | attention = attention.view(n * glimpses, -1) 256 | if use_softmax: 257 | attention = F.softmax(attention, dim=1) 258 | else: 259 | attention = F.sigmoid(attention) 260 | 261 | # apply the weighting by creating a new dim to tile both tensors over 262 | target_size = [n, glimpses, c, s] 263 | input = input.view(n, 1, c, s).expand(*target_size) 264 | attention = attention.view(n, glimpses, 1, s).expand(*target_size) 265 | weighted = input * attention 266 | # sum over only the spatial dimension 267 | weighted_mean = weighted.sum(dim=3, keepdim=True) 268 | # the shape at this point is (n, glimpses, c, 1) 269 | return weighted_mean.view(n, -1) 270 | 271 | 272 | def tile_2d_over_nd(feature_vector, feature_map): 273 | """ Repeat the same feature vector over all spatial positions of a given feature map. 274 | The feature vector should have the same batch size and number of features as the feature map. 275 | """ 276 | n, c = feature_vector.size() 277 | spatial_sizes = feature_map.size()[2:] 278 | tiled = feature_vector.view(n, c, *([1] * len(spatial_sizes))).expand(n, c, *spatial_sizes) 279 | return tiled 280 | -------------------------------------------------------------------------------- /preprocess-features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import base64 4 | import os 5 | import csv 6 | import itertools 7 | 8 | csv.field_size_limit(sys.maxsize) 9 | 10 | import h5py 11 | import torch.utils.data 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | import config 16 | import data 17 | import utils 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--test', action='store_true') 23 | args = parser.parse_args() 24 | 25 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] 26 | 27 | features_shape = ( 28 | 82783 + 40504 if not args.test else 81434, # number of images in trainval or in test 29 | config.output_features, 30 | config.output_size, 31 | ) 32 | boxes_shape = ( 33 | features_shape[0], 34 | 4, 35 | config.output_size, 36 | ) 37 | 38 | if not args.test: 39 | path = config.preprocessed_trainval_path 40 | else: 41 | path = config.preprocessed_test_path 42 | with h5py.File(path, libver='latest') as fd: 43 | features = fd.create_dataset('features', shape=features_shape, dtype='float32') 44 | boxes = fd.create_dataset('boxes', shape=boxes_shape, dtype='float32') 45 | coco_ids = fd.create_dataset('ids', shape=(features_shape[0],), dtype='int32') 46 | widths = fd.create_dataset('widths', shape=(features_shape[0],), dtype='int32') 47 | heights = fd.create_dataset('heights', shape=(features_shape[0],), dtype='int32') 48 | 49 | readers = [] 50 | if not args.test: 51 | path = config.bottom_up_trainval_path 52 | else: 53 | path = config.bottom_up_test_path 54 | for filename in os.listdir(path): 55 | if not '.tsv' in filename: 56 | continue 57 | full_filename = os.path.join(path, filename) 58 | fd = open(full_filename, 'r') 59 | 
reader = csv.DictReader(fd, delimiter='\t', fieldnames=FIELDNAMES) 60 | readers.append(reader) 61 | 62 | reader = itertools.chain.from_iterable(readers) 63 | for i, item in enumerate(tqdm(reader, total=features_shape[0])): 64 | coco_ids[i] = int(item['image_id']) 65 | widths[i] = int(item['image_w']) 66 | heights[i] = int(item['image_h']) 67 | 68 | buf = base64.decodestring(item['features'].encode('utf8')) 69 | array = np.frombuffer(buf, dtype='float32') 70 | array = array.reshape((-1, config.output_features)).transpose() 71 | features[i, :, :array.shape[1]] = array 72 | 73 | buf = base64.decodestring(item['boxes'].encode('utf8')) 74 | array = np.frombuffer(buf, dtype='float32') 75 | array = array.reshape((-1, 4)).transpose() 76 | boxes[i, :, :array.shape[1]] = array 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /preprocess-vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from collections import Counter 4 | import itertools 5 | 6 | import config 7 | import data 8 | import utils 9 | 10 | 11 | def extract_vocab(iterable, top_k=None, start=0): 12 | """ Turns an iterable of list of tokens into a vocabulary. 13 | These tokens could be single answers or word tokens in questions. 14 | """ 15 | all_tokens = itertools.chain.from_iterable(iterable) 16 | counter = Counter(all_tokens) 17 | if top_k: 18 | most_common = counter.most_common(top_k) 19 | most_common = (t for t, c in most_common) 20 | else: 21 | most_common = counter.keys() 22 | # descending in count, then lexicographical order 23 | tokens = sorted(most_common, key=lambda x: (counter[x], x), reverse=True) 24 | vocab = {t: i for i, t in enumerate(tokens, start=start)} 25 | return vocab 26 | 27 | 28 | def main(): 29 | questions = utils.path_for(train=True, question=True) 30 | answers = utils.path_for(train=True, answer=True) 31 | 32 | with open(questions, 'r') as fd: 33 | questions = json.load(fd) 34 | with open(answers, 'r') as fd: 35 | answers = json.load(fd) 36 | 37 | questions = list(data.prepare_questions(questions)) 38 | answers = list(data.prepare_answers(answers)) 39 | 40 | question_vocab = extract_vocab(questions, start=1) 41 | answer_vocab = extract_vocab(answers, top_k=config.max_answers) 42 | 43 | vocabs = { 44 | 'question': question_vocab, 45 | 'answer': answer_vocab, 46 | } 47 | with open(config.vocabulary_path, 'w') as fd: 48 | json.dump(vocabs, fd) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /q_type_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import itertools 5 | from torch.autograd import Variable 6 | import tree_def, tree_def 7 | 8 | class AccumulatingModule(nn.Module): 9 | """ 10 | Accumulating Module 11 | According to different detailed question types (65 types in total), each question type will have an individual [num_ot, num_ot] matrix. 12 | The parameters of [num_ot, num_ot] matrix are updated by statistical accumulating (not backpropagation). 
13 | 14 | num_qt: number of question types -> 65 15 | num_ot: number of object types -> 151 16 | """ 17 | def __init__(self, num_qt, num_ot): 18 | super().__init__() 19 | self.num_qt = num_qt 20 | self.num_ot = num_ot 21 | self.pair_num = 90 22 | self.score_matrix = nn.Parameter(torch.zeros(num_qt, self.pair_num, num_ot, num_ot).float().fill_(1e-12)) 23 | 24 | def batch_update_matrix(self, obj_label, qus_type, attention): 25 | """ 26 | obj_label: [batch_size, box_num] 27 | attention: [batch_size, box_num] 28 | qus_type: [batch_size] 29 | """ 30 | obj_label = obj_label.data 31 | qus_type = qus_type.data 32 | attention = attention.detach().data 33 | 34 | batch_size, box_num = obj_label.shape 35 | ol1 = obj_label.view(batch_size, 1, box_num).expand(batch_size, box_num, box_num).contiguous() 36 | ol2 = obj_label.view(batch_size, box_num, 1).expand(batch_size, box_num, box_num).contiguous() 37 | eye = torch.eye(box_num).cuda().long().view(1, box_num, box_num).expand(batch_size, box_num, box_num).contiguous() 38 | ol1 = ol1.view(-1)[torch.nonzero((1-eye).view(-1))].view(-1) 39 | ol2 = ol2.view(-1)[torch.nonzero((1-eye).view(-1))].view(-1) 40 | assert ol1.shape[0] == batch_size * box_num * (box_num - 1) 41 | assert ol2.shape[0] == batch_size * box_num * (box_num - 1) 42 | qt = qus_type.view(batch_size, 1).expand(batch_size, box_num * (box_num - 1)).contiguous().view(-1) 43 | ra = torch.range(0, self.pair_num-1).cuda().long().view(1, self.pair_num).expand(batch_size, self.pair_num).contiguous().view(-1) 44 | # score 45 | at1 = attention.view(batch_size, 1, box_num).expand(batch_size, box_num, box_num).contiguous() 46 | at2 = attention.view(batch_size, box_num, 1).expand(batch_size, box_num, box_num).contiguous() 47 | at1 = at1.view(-1)[torch.nonzero((1-eye).view(-1))].view(-1) 48 | at2 = at2.view(-1)[torch.nonzero((1-eye).view(-1))].view(-1) 49 | at = at1 * at2 50 | # update 51 | self.score_matrix.data[qt, ra, ol1, ol2] += at.data 52 | 53 | 54 | def update_matrix(self, obj_label, qus_type, attention): 55 | """ 56 | obj_label: [box_num] 57 | attention: [box_num] 58 | qus_type: 1 59 | """ 60 | box_num = obj_label.shape[0] 61 | for i in range(box_num): 62 | for j in range(box_num): 63 | if i != j: 64 | # make sure it is sysmetrical metrix 65 | self.score_matrix[int(qus_type), int(obj_label[i]), int(obj_label[j])] = self.score_matrix[int(qus_type), int(obj_label[i]), int(obj_label[j])] + float(attention[i] * attention[j]) 66 | self.score_matrix[int(qus_type), int(obj_label[j]), int(obj_label[i])] = self.score_matrix[int(qus_type), int(obj_label[j]), int(obj_label[i])] + float(attention[i] * attention[j]) 67 | 68 | def get_matrix(self, obj_label, qus_type): 69 | """ 70 | obj_label: [box_num] 71 | qus_type: 1 72 | """ 73 | box_num = obj_label.shape[0] 74 | sliced_matrix = self.score_matrix[int(qus_type)].sum(0).data 75 | normed_matrix = sliced_matrix / sliced_matrix.max() 76 | ol1 = obj_label.view(box_num, 1).expand(box_num, box_num) 77 | ol2 = obj_label.view(1, box_num).expand(box_num, box_num) 78 | output = normed_matrix[ol1, ol2].clone() 79 | return output + output.transpose(0, 1) 80 | 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import argparse 4 | import math 5 | import json 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.optim.lr_scheduler as 
lr_scheduler 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | from tqdm import tqdm 15 | 16 | import config 17 | import data 18 | import model 19 | import utils 20 | 21 | from tensorboardX import SummaryWriter 22 | exp_setting = input("What's new in this experiment......") 23 | print('Experiment Setting: ', exp_setting) 24 | writer = SummaryWriter('runs/'+exp_setting) 25 | 26 | def run(net, loader, optimizer, scheduler, tracker, train=False, has_answers=True, prefix='', epoch=0): 27 | """ Run an epoch over the given loader """ 28 | assert not (train and not has_answers) 29 | if train: 30 | net.train() 31 | tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99} 32 | else: 33 | net.eval() 34 | tracker_class, tracker_params = tracker.MeanMonitor, {} 35 | answ = [] 36 | idxs = [] 37 | accs = [] 38 | 39 | loader = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0) 40 | loss_tracker = tracker.track('{}_loss'.format(prefix), tracker_class(**tracker_params)) 41 | acc_tracker = tracker.track('{}_acc'.format(prefix), tracker_class(**tracker_params)) 42 | batch_count = 0 43 | batch_max = len(loader) 44 | for v, q, a, b, q_type, idx, q_len in loader: 45 | var_params = { 46 | 'volatile': not train, 47 | 'requires_grad': False, 48 | } 49 | v = Variable(v.cuda(async=True), **var_params) 50 | q = Variable(q.cuda(async=True), **var_params) 51 | a = Variable(a.cuda(async=True), **var_params) 52 | b = Variable(b.cuda(async=True), **var_params) 53 | q_len = Variable(q_len.cuda(async=True), **var_params) 54 | q_type = Variable(q_type.cuda(async=True), **var_params) 55 | 56 | if config.use_rl and train: 57 | net.eval() 58 | out, _, _ = net(v, b, q, q_len, q_type) 59 | acc = utils.batch_accuracy(out.data, a.data).cpu() 60 | baseline = [] 61 | for i in range(acc.shape[0]): 62 | baseline.append(float(acc[i])) 63 | #float(acc.mean()) 64 | net.train() 65 | utils.fix_batchnorm(net) 66 | out, rl_ls, _ = net(v, b, q, q_len, q_type) 67 | acc = utils.batch_accuracy(out.data, a.data).cpu() 68 | current = [] 69 | for i in range(acc.shape[0]): 70 | current.append(float(acc[i])) 71 | #float(acc.mean()) 72 | #print(baseline - current) 73 | rl_loss = [] 74 | assert len(rl_ls) == len(baseline) 75 | for i in range(len(rl_ls)): 76 | rl_loss.append((baseline[i] - current[i]) * rl_ls[i]) 77 | #(baseline - current) * sum(rl_ls) / len(rl_ls) 78 | #entropy_loss = sum(entropy_ls) / len(entropy_ls) * 1e-4 79 | loss = sum(rl_loss) #+ entropy_loss 80 | else: 81 | out, _, _ = net(v, b, q, q_len, q_type) 82 | if has_answers: 83 | nll = -F.log_softmax(out, dim=1) 84 | loss = (nll * a / 10).sum(dim=1).mean() 85 | acc = utils.batch_accuracy(out.data, a.data).cpu() 86 | 87 | if train: 88 | scheduler.step() 89 | optimizer.zero_grad() 90 | loss.backward() 91 | optimizer.step() 92 | else: 93 | # store information about evaluation of this minibatch 94 | _, answer = out.data.cpu().max(dim=1) 95 | answ.append(answer.view(-1)) 96 | if has_answers: 97 | accs.append(acc.view(-1)) 98 | idxs.append(idx.view(-1).clone()) 99 | 100 | if has_answers: 101 | loss_tracker.append(loss.item()) 102 | acc_tracker.append(acc.mean()) 103 | fmt = '{:.4f}'.format 104 | loader.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 105 | 106 | if train: 107 | writer.add_scalar('train/loss', loss.item(), epoch * batch_max + batch_count) 108 | writer.add_scalar('train/accu', acc.mean(), epoch * batch_max + batch_count) 109 | #writer.export_scalars_to_json("./log_board.json") 
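# The scalars above are logged with global step = epoch * batch_max + batch_count,
# i.e. a running batch index across epochs.
# When config.use_rl is True, the loss computed earlier in this loop is self-critical:
# for each sample it is (greedy-baseline accuracy - sampled accuracy) * log-probability
# of the sampled tree, so minimising it raises the probability of tree structures whose
# accuracy beats the greedy baseline.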
110 | else: 111 | writer.add_scalar('val/loss', loss.item(), epoch * batch_max + batch_count) 112 | writer.add_scalar('val/accu', acc.mean(), epoch * batch_max + batch_count) 113 | #writer.export_scalars_to_json("./log_board.json") 114 | batch_count += 1 115 | 116 | if not train: 117 | answ = list(torch.cat(answ, dim=0)) 118 | if has_answers: 119 | accs = list(torch.cat(accs, dim=0)) 120 | else: 121 | accs = [] 122 | idxs = list(torch.cat(idxs, dim=0)) 123 | return answ, accs, idxs 124 | 125 | 126 | def main(): 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('name', nargs='*') 129 | parser.add_argument('--eval', dest='eval_only', action='store_true') 130 | parser.add_argument('--test', action='store_true') 131 | parser.add_argument('--resume', nargs='*') 132 | args = parser.parse_args() 133 | 134 | if args.test: 135 | args.eval_only = True 136 | src = open('model.py').read() 137 | if args.name: 138 | name = ' '.join(args.name) 139 | else: 140 | from datetime import datetime 141 | name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 142 | target_name = os.path.join('logs', '{}'.format(name)) 143 | writer.add_text('Log Name:', name) 144 | if not args.test: 145 | # target_name won't be used in test mode 146 | print('will save to {}'.format(target_name)) 147 | if args.resume: 148 | logs = torch.load(' '.join(args.resume)) 149 | # hacky way to tell the VQA classes that they should use the vocab without passing more params around 150 | #data.preloaded_vocab = logs['vocab'] 151 | 152 | cudnn.benchmark = True 153 | 154 | if not args.eval_only: 155 | train_loader = data.get_loader(train=True) 156 | if not args.test: 157 | val_loader = data.get_loader(val=True) 158 | else: 159 | val_loader = data.get_loader(test=True) 160 | 161 | net = model.Net(val_loader.dataset.num_tokens).cuda() 162 | # restore transfer learning 163 | # 'data/vgrel-29.tar' for 36 164 | # 'data/vgrel-19.tar' for 10-100 165 | if config.output_size == 36: 166 | print("load data/vgrel-29(transfer36).tar") 167 | ckpt = torch.load('data/vgrel-29(transfer36).tar') 168 | else: 169 | print("load data/vgrel-19(transfer110).tar") 170 | ckpt = torch.load('data/vgrel-19(transfer110).tar') 171 | 172 | utils.optimistic_restore(net.tree_lstm.gen_tree_net, ckpt['state_dict']) 173 | 174 | if config.use_rl: 175 | for p in net.parameters(): 176 | p.requires_grad = False 177 | for p in net.tree_lstm.gen_tree_net.parameters(): 178 | p.requires_grad = True 179 | 180 | optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad], lr=config.initial_lr) 181 | scheduler = lr_scheduler.ExponentialLR(optimizer, 0.5**(1 / config.lr_halflife)) 182 | start_epoch = 0 183 | if args.resume: 184 | net.load_state_dict(logs['weights']) 185 | #optimizer.load_state_dict(logs['optimizer']) 186 | start_epoch = int(logs['epoch']) + 1 187 | 188 | tracker = utils.Tracker() 189 | config_as_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} 190 | print(config_as_dict) 191 | best_accuracy = -1 192 | 193 | for i in range(start_epoch, config.epochs): 194 | if not args.eval_only: 195 | run(net, train_loader, optimizer, scheduler, tracker, train=True, prefix='train', epoch=i) 196 | if i % 1 != 0 or (i > 0 and i <20): 197 | r = [[-1], [-1], [-1]] 198 | else: 199 | r = run(net, val_loader, optimizer, scheduler, tracker, train=False, prefix='val', epoch=i, has_answers=not args.test) 200 | 201 | if not args.test: 202 | results = { 203 | 'name': name, 204 | 'tracker': tracker.to_dict(), 205 | 'config': config_as_dict, 206 | 'weights': 
net.state_dict(), 207 | 'optimizer': optimizer.state_dict(), 208 | 'epoch': i, 209 | 'eval': { 210 | 'answers': r[0], 211 | 'accuracies': r[1], 212 | 'idx': r[2], 213 | }, 214 | 'vocab': val_loader.dataset.vocab, 215 | 'src': src, 216 | 'setting': exp_setting, 217 | } 218 | current_ac = sum(r[1]) / len(r[1]) 219 | if current_ac > best_accuracy: 220 | best_accuracy = current_ac 221 | print('update best model, current: ', current_ac) 222 | torch.save(results, target_name + '_best.pth') 223 | if i % 1 == 0: 224 | torch.save(results, target_name + '_' + str(i) + '.pth') 225 | 226 | else: 227 | # in test mode, save a results file in the format accepted by the submission server 228 | answer_index_to_string = {a: s for s, a in val_loader.dataset.answer_to_index.items()} 229 | results = [] 230 | for answer, index in zip(r[0], r[2]): 231 | answer = answer_index_to_string[answer.item()] 232 | qid = val_loader.dataset.question_ids[index] 233 | entry = { 234 | 'question_id': qid, 235 | 'answer': answer, 236 | } 237 | results.append(entry) 238 | with open('results.json', 'w') as fd: 239 | json.dump(results, fd) 240 | 241 | if args.eval_only: 242 | break 243 | 244 | 245 | if __name__ == '__main__': 246 | main() 247 | writer.close() 248 | writer.export_scalars_to_json("./log_board.json") 249 | -------------------------------------------------------------------------------- /tree_def.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import errno 4 | import os 5 | import numpy as np 6 | from PIL import Image 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class BasicBiTree(object): 12 | def __init__(self, idx, is_root=False): 13 | self.index = int(idx) 14 | self.is_root = is_root 15 | self.left_child = None 16 | self.right_child = None 17 | self.parent = None 18 | self.num_child = 0 19 | 20 | def set_root(self): 21 | self.is_root = True 22 | 23 | def add_left_child(self, child): 24 | if self.left_child is not None: 25 | print('Left child already exist') 26 | return 27 | child.parent = self 28 | self.num_child += 1 29 | self.left_child = child 30 | 31 | def add_right_child(self, child): 32 | if self.right_child is not None: 33 | print('Right child already exist') 34 | return 35 | child.parent = self 36 | self.num_child += 1 37 | self.right_child = child 38 | 39 | def get_total_child(self): 40 | sum = 0 41 | sum += self.num_child 42 | if self.left_child is not None: 43 | sum += self.left_child.get_total_child() 44 | if self.right_child is not None: 45 | sum += self.right_child.get_total_child() 46 | return sum 47 | 48 | def depth(self): 49 | if hasattr(self, '_depth'): 50 | return self._depth 51 | if self.parent is None: 52 | count = 1 53 | else: 54 | count = self.parent.depth() + 1 55 | self._depth = count 56 | return self._depth 57 | 58 | def max_depth(self): 59 | if hasattr(self, '_max_depth'): 60 | return self._max_depth 61 | count = 0 62 | if self.left_child is not None: 63 | left_depth = self.left_child.max_depth() 64 | if left_depth > count: 65 | count = left_depth 66 | if self.right_child is not None: 67 | right_depth = self.right_child.max_depth() 68 | if right_depth > count: 69 | count = right_depth 70 | count += 1 71 | self._max_depth = count 72 | return self._max_depth 73 | 74 | class ArbitraryTree(object): 75 | def __init__(self, idx, im_idx=-1, is_root=False): 76 | self.index = int(idx) 77 | self.is_root = is_root 78 | self.children = [] 79 | self.im_idx = int(im_idx) # which image it comes from 80 | 
self.parent = None 81 | 82 | def generate_bi_tree(self): 83 | # generate a BiTree node, parent/child relationship are not inherited 84 | return BiTree(self.index, im_idx=self.im_idx, is_root=self.is_root) 85 | 86 | def add_child(self, child): 87 | child.parent = self 88 | self.children.append(child) 89 | 90 | def print(self): 91 | print('====================') 92 | print('is root: ', self.is_root) 93 | print('index: ', self.index) 94 | print('num of child: ', len(self.children)) 95 | for node in self.children: 96 | node.print() 97 | 98 | def find_node_by_index(self, index, result_node): 99 | if self.index == index: 100 | result_node = self 101 | elif len(self.children) > 0: 102 | for i in range(len(self.children)): 103 | result_node = self.children[i].find_node_by_index(index, result_node) 104 | 105 | return result_node 106 | 107 | def search_best_insert(self, matrix_score, insert_node, best_score, best_depend_node, best_insert_node): 108 | # virtual node will not be considerred 109 | if self.is_root: 110 | pass 111 | elif float(matrix_score[self.index, insert_node.index]) > float(best_score): 112 | best_score = matrix_score[self.index, insert_node.index] 113 | best_depend_node = self 114 | best_insert_node = insert_node 115 | 116 | # iteratively search child 117 | for i in range(self.get_child_num()): 118 | best_score, best_depend_node, best_insert_node = \ 119 | self.children[i].search_best_insert(matrix_score, insert_node, best_score, best_depend_node, best_insert_node) 120 | 121 | return best_score, best_depend_node, best_insert_node 122 | 123 | def get_child_num(self): 124 | return len(self.children) 125 | 126 | def get_total_child(self): 127 | sum = 0 128 | num_current_child = self.get_child_num() 129 | sum += num_current_child 130 | for i in range(num_current_child): 131 | sum += self.children[i].get_total_child() 132 | return sum 133 | 134 | # only support binary tree 135 | class BiTree(BasicBiTree): 136 | def __init__(self, idx, im_idx, is_root=False, node_score=0.0, center_x=0.0): 137 | super(BiTree, self).__init__(idx, is_root) 138 | self.node_score = float(node_score) 139 | self.center_x = float(center_x) 140 | self.im_idx = int(im_idx) # which image it comes from 141 | -------------------------------------------------------------------------------- /tree_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import tree_def, tree_lstm, tree_utils, gen_tree_net, q_type_module 6 | from utils import PiecewiseLin 7 | 8 | import config 9 | 10 | class TreeFeature(nn.Module): 11 | def __init__(self, objects, visual_dim, hidden_dim): 12 | super().__init__() 13 | """ overlap_tree | arbitrary_trees_attention | arbitrary_trees_transfer | accumulate_module_trees """ 14 | """ sigmoid | softmax """ 15 | self.gen_tree_mode = config.gen_tree_mode 16 | self.poolout_mode = config.poolout_mode 17 | if config.use_rl: 18 | self.dropout = 0.0 19 | else: 20 | self.dropout = 0.5 21 | self.objects = objects 22 | self.visual_dim = visual_dim 23 | self.hidden_dim = hidden_dim 24 | 25 | self.tree_lstm = tree_lstm.BidirectionalTreeLSTM(visual_dim, hidden_dim) 26 | 27 | self.gen_tree_net = gen_tree_net.GenTreeModule() 28 | 29 | if self.gen_tree_mode == "accumulate_module_trees": 30 | self.accumulate_module = q_type_module.AccumulatingModule(65, 151) 31 | #self.f = nn.ModuleList([PiecewiseLin(16) for _ in range(2)]) 32 | 33 | def forward(self, 
boxes, attention_orig, visual_feature, v_origin, que_type): 34 | # only care about the highest scoring object proposals 35 | # the ones with low score will have a low impact on the count anyway 36 | boxes, attention_orig, visual_feature, v_origin = self.filter_most_important(self.objects, boxes, attention_orig, self.resize_visual_feature(visual_feature), self.resize_visual_feature(v_origin)) 37 | 38 | if self.gen_tree_mode == "overlap_tree": 39 | #only use box info to generate overlap tree 40 | forest = tree_utils.generate_tree(torch.transpose(boxes, 1, 2), "overlap_tree") 41 | elif self.gen_tree_mode == "arbitrary_trees_transfer": 42 | #use transfered tree-parser network to generate score matrix 43 | attention = F.sigmoid(attention_orig) # [batch_size, num_obj] 44 | relevancy = self.outer_product(attention) 45 | scores = self.gen_tree_net(v_origin, boxes) 46 | forest, rl_loss, entropy_loss = tree_utils.generate_tree((scores * relevancy, self.training), "arbitrary_trees") 47 | elif self.gen_tree_mode == "accumulate_module_trees": 48 | #use accumulate module to generate score matrix 49 | attention = F.sigmoid(attention_orig) # [batch_size, num_obj] 50 | #relevancy = self.outer_product(attention) # [batch_size, 10, 10] 51 | bbox_sim = self.iou(boxes, boxes) 52 | _, obj_label = self.gen_tree_net.get_label(v_origin) #[batch_size, 10, 151], [batch_size, 10] 53 | #print('obj_dist: ', obj_dist.shape) 54 | #print('obj_label: ', obj_label.data.cpu().numpy()) 55 | packed_input = (self.accumulate_module, attention, obj_label, que_type, bbox_sim, self.training) 56 | forest = tree_utils.generate_tree(packed_input, "accumulate_module_trees") 57 | else: 58 | print('Error: Please select a proper gen-tree method') 59 | 60 | if self.training: 61 | visual_hidden = self.tree_lstm(forest, torch.transpose(visual_feature, 1, 2), self.objects, self.dropout) 62 | else: 63 | visual_hidden = self.tree_lstm(forest, torch.transpose(visual_feature, 1, 2), self.objects, 0.0) # [batch_size, num_obj, hidden_size] 64 | 65 | del forest 66 | 67 | batch_size, num_obj, hidden_size = visual_hidden.shape 68 | return torch.transpose(visual_hidden, 1, 2).contiguous().view(batch_size, hidden_size, 1, num_obj), rl_loss, entropy_loss 69 | 70 | def filter_most_important(self, n, boxes, attention, visual_feature, v_origin): 71 | """ Only keep top-n object proposals, scored by attention weight """ 72 | attention, idx = attention.topk(n, dim=1, sorted=False) 73 | idx_box = idx.unsqueeze(dim=1).expand(boxes.size(0), boxes.size(1), idx.size(1)) 74 | boxes = boxes.gather(2, idx_box) 75 | idx_feat = idx.unsqueeze(dim=1).expand(visual_feature.size(0), visual_feature.size(1), idx.size(1)) 76 | visual_feature = visual_feature.gather(2, idx_feat) 77 | v_origin = v_origin.gather(2, idx_feat) 78 | return boxes, attention, visual_feature, v_origin 79 | 80 | def resize_visual_feature(self, visual_feature): 81 | batch_size, feature_dim, _, num_obj = visual_feature.shape 82 | return visual_feature.view(batch_size, feature_dim, num_obj) 83 | 84 | def outer(self, x): 85 | size = tuple(x.size()) + (x.size()[-1],) 86 | a = x.unsqueeze(dim=-1).expand(*size) 87 | b = x.unsqueeze(dim=-2).expand(*size) 88 | return a, b 89 | 90 | def outer_product(self, x): 91 | # Y_ij = x_i * x_j 92 | a, b = self.outer(x) 93 | return a * b 94 | 95 | def outer_diff(self, x): 96 | # like outer products, except taking the absolute difference instead 97 | # Y_ij = | x_i - x_j | 98 | a, b = self.outer(x) 99 | return (a - b).abs() 100 | 101 | def iou(self, a, b): 102 | # this is 
just the usual way to IoU from bounding boxes 103 | inter = self.intersection(a, b) 104 | area_a = self.area(a).unsqueeze(2).expand_as(inter) 105 | area_b = self.area(b).unsqueeze(1).expand_as(inter) 106 | return inter / (area_a + area_b - inter + 1e-12) 107 | 108 | def area(self, box): 109 | x = (box[:, 2, :] - box[:, 0, :]).clamp(min=0) 110 | y = (box[:, 3, :] - box[:, 1, :]).clamp(min=0) 111 | return x * y 112 | 113 | def intersection(self, a, b): 114 | size = (a.size(0), 2, a.size(2), b.size(2)) 115 | min_point = torch.max( 116 | a[:, :2, :].unsqueeze(dim=3).expand(*size), 117 | b[:, :2, :].unsqueeze(dim=2).expand(*size), 118 | ) 119 | max_point = torch.min( 120 | a[:, 2:, :].unsqueeze(dim=3).expand(*size), 121 | b[:, 2:, :].unsqueeze(dim=2).expand(*size), 122 | ) 123 | inter = (max_point - min_point).clamp(min=0) 124 | area = inter[:, 0, :, :] * inter[:, 1, :, :] 125 | return area -------------------------------------------------------------------------------- /tree_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import PackedSequence 6 | import numpy as np 7 | 8 | import tree_utils 9 | from tree_utils import block_orthogonal 10 | 11 | 12 | class MultiLayer_BTreeLSTM(nn.Module): 13 | """ 14 | Multilayer Bidirectional Tree LSTM 15 | Each layer contains one forward lstm(leaves to root) and one backward lstm(root to leaves) 16 | """ 17 | def __init__(self, in_dim, out_dim, num_layer): 18 | super(MultiLayer_BTreeLSTM, self).__init__() 19 | self.num_layer = num_layer 20 | layers = [] 21 | layers.append(BidirectionalTreeLSTM(in_dim, out_dim)) 22 | for i in range(num_layer - 1): 23 | layers.append(BidirectionalTreeLSTM(out_dim, out_dim)) 24 | self.multi_layer_lstm = nn.ModuleList(layers) 25 | 26 | def forward(self, forest, features, num_obj, dropout=0.0): 27 | for i in range(self.num_layer): 28 | features = self.multi_layer_lstm[i](forest, features, num_obj, dropout) 29 | return features 30 | 31 | 32 | class BidirectionalTreeLSTM(nn.Module): 33 | """ 34 | Bidirectional Tree LSTM 35 | Contains one forward lstm(leaves to root) and one backward lstm(root to leaves) 36 | Dropout mask will be generated one time for all trees in the forest, to make sure the consistancy 37 | """ 38 | def __init__(self, in_dim, out_dim): 39 | super(BidirectionalTreeLSTM, self).__init__() 40 | self.out_dim = out_dim 41 | self.treeLSTM_foreward = OneDirectionalTreeLSTM(in_dim, int(out_dim / 2), 'foreward') 42 | self.treeLSTM_backward = OneDirectionalTreeLSTM(in_dim, int(out_dim / 2), 'backward') 43 | 44 | def forward(self, forest, features, num_obj, dropout=0.0): 45 | foreward_output = self.treeLSTM_foreward(forest, features, num_obj, dropout) 46 | backward_output = self.treeLSTM_backward(forest, features, num_obj, dropout) 47 | 48 | final_output = torch.cat((foreward_output, backward_output), 2) 49 | 50 | return final_output 51 | 52 | class RootCentricTreeLSTM(nn.Module): 53 | """ 54 | From leaves node to root node 55 | """ 56 | def __init__(self, in_dim, out_dim): 57 | super(RootCentricTreeLSTM, self).__init__() 58 | self.out_dim = out_dim 59 | self.treeLSTM = BiTreeLSTM_Foreward(in_dim, out_dim) 60 | 61 | def forward(self, forest, features, num_obj, dropout=0.0): 62 | # calc dropout mask, same for all 63 | if dropout > 0.0: 64 | dropout_mask = get_dropout_mask(dropout, self.out_dim) 65 | else: 66 | dropout_mask = None 67 | 68 | # 
tree lstm input 69 | final_output = None 70 | lstm_io = tree_utils.TreeLSTM_IO(num_obj, dropout_mask) 71 | 72 | # run tree lstm forward (leaves to root) 73 | for idx in range(len(forest)): 74 | _, sliced_h = self.treeLSTM(forest[idx], features[idx], lstm_io, idx) 75 | sliced_output = sliced_h.view(1, self.out_dim) 76 | if final_output is None: 77 | final_output = sliced_output 78 | else: 79 | final_output = torch.cat((final_output, sliced_output), 0) 80 | # Reset hidden 81 | lstm_io.reset() 82 | 83 | return final_output 84 | 85 | class OneDirectionalTreeLSTM(nn.Module): 86 | """ 87 | One Way Tree LSTM 88 | direction = foreward | backward 89 | """ 90 | def __init__(self, in_dim, out_dim, direction): 91 | super(OneDirectionalTreeLSTM, self).__init__() 92 | self.out_dim = out_dim 93 | self.direction = direction 94 | if direction == 'foreward': 95 | self.treeLSTM = BiTreeLSTM_Foreward(in_dim, out_dim) 96 | elif direction == 'backward': 97 | self.treeLSTM = BiTreeLSTM_Backward(in_dim, out_dim) 98 | else: 99 | print('Error Tree LSTM Direction') 100 | 101 | def forward(self, forest, features, num_obj, dropout=0.0): 102 | # calc dropout mask, same for all 103 | if dropout > 0.0: 104 | dropout_mask = get_dropout_mask(dropout, self.out_dim) 105 | else: 106 | dropout_mask = None 107 | 108 | # tree lstm input 109 | final_output = None 110 | lstm_io = tree_utils.TreeLSTM_IO(num_obj, dropout_mask) 111 | # run tree lstm forward (leaves to root) 112 | for idx in range(len(forest)): 113 | if self.direction == 'foreward': 114 | self.treeLSTM(forest[idx], features[idx], lstm_io, idx) 115 | elif self.direction == 'backward': 116 | root_c = torch.FloatTensor(self.out_dim).cuda().fill_(0.0) 117 | root_h = torch.FloatTensor(self.out_dim).cuda().fill_(0.0) 118 | self.treeLSTM(forest[idx], features[idx], lstm_io, idx, root_c, root_h) 119 | else: 120 | print('Error Tree LSTM Direction') 121 | sliced_output = torch.index_select(lstm_io.hidden, 0, lstm_io.order.long()).view(1, num_obj, self.out_dim) 122 | if final_output is None: 123 | final_output = sliced_output 124 | else: 125 | final_output = torch.cat((final_output, sliced_output), 0) 126 | # Reset hidden 127 | lstm_io.reset() 128 | 129 | return final_output 130 | 131 | 132 | class BiTreeLSTM_Foreward(nn.Module): 133 | """ 134 | From leaves to root 135 | """ 136 | def __init__(self, feat_dim, h_dim): 137 | super(BiTreeLSTM_Foreward, self).__init__() 138 | self.feat_dim = feat_dim 139 | self.h_dim = h_dim 140 | 141 | self.ioffux = nn.Linear(self.feat_dim, 5 * self.h_dim) 142 | self.ioffuh_left = nn.Linear(self.h_dim, 5 * self.h_dim) 143 | self.ioffuh_right = nn.Linear(self.h_dim, 5 * self.h_dim) 144 | #self.px = nn.Linear(self.feat_dim, self.h_dim) 145 | 146 | # init parameter 147 | #block_orthogonal(self.px.weight.data, [self.h_dim, self.feat_dim]) 148 | block_orthogonal(self.ioffux.weight.data, [self.h_dim, self.feat_dim]) 149 | block_orthogonal(self.ioffuh_left.weight.data, [self.h_dim, self.h_dim]) 150 | block_orthogonal(self.ioffuh_right.weight.data, [self.h_dim, self.h_dim]) 151 | 152 | #self.px.bias.data.fill_(0.0) 153 | self.ioffux.bias.data.fill_(0.0) 154 | self.ioffuh_left.bias.data.fill_(0.0) 155 | self.ioffuh_right.bias.data.fill_(0.0) 156 | # Initialize forget gate biases to 1.0 as per An Empirical 157 | # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015). 
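# The fused gate output of ioffux/ioffuh_* is laid out as [i, o, f_l, f_r, u]
# (see the torch.split in node_forward below), so the slice [2*h_dim : 4*h_dim]
# covers the two forget gates. Each child projection gets a bias of 0.5, so when the
# left and right contributions are summed the effective forget bias is roughly the
# 1.0 recommended above.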
158 | self.ioffuh_left.bias.data[2 * self.h_dim:4 * self.h_dim].fill_(0.5) 159 | self.ioffuh_right.bias.data[2 * self.h_dim:4 * self.h_dim].fill_(0.5) 160 | 161 | 162 | def node_forward(self, feat_inp, left_c, right_c, left_h, right_h, dropout_mask, has_left, has_right): 163 | #projected_x = self.px(feat_inp) 164 | if has_left and has_right: 165 | ioffu = self.ioffux(feat_inp) + self.ioffuh_left(left_h) + self.ioffuh_right(right_h) 166 | elif has_left and (not has_right): 167 | ioffu = self.ioffux(feat_inp) + self.ioffuh_left(left_h) 168 | elif has_right and (not has_left): 169 | ioffu = self.ioffux(feat_inp) + self.ioffuh_right(right_h) 170 | else: 171 | ioffu = self.ioffux(feat_inp) 172 | 173 | i, o, f_l, f_r, u = torch.split(ioffu, ioffu.size(1) // 5, dim=1) 174 | i, o, f_l, f_r, u = F.sigmoid(i), F.sigmoid(o), F.sigmoid(f_l), F.sigmoid(f_r), F.tanh(u) #, F.sigmoid(r) 175 | 176 | c = torch.mul(i, u) + torch.mul(f_l, left_c) + torch.mul(f_r, right_c) 177 | h = torch.mul(o, F.tanh(c)) 178 | #h_final = torch.mul(r, h) + torch.mul((1 - r), projected_x) 179 | # Only do dropout if the dropout prob is > 0.0 and we are in training mode. 180 | if dropout_mask is not None and self.training: 181 | h = torch.mul(h, dropout_mask) 182 | return c, h 183 | 184 | def forward(self, tree, features, treelstm_io, batch_idx): 185 | """ 186 | tree: The root for a tree 187 | features: [num_obj, featuresize] 188 | treelstm_io.hidden: init as None, cat until it covers all objects as [num_obj, hidden_size] 189 | treelstm_io.order: init as 0 for all [num_obj], update for recovering original order 190 | """ 191 | # recursively search child 192 | if tree.left_child is not None: 193 | has_left = True 194 | left_c, left_h = self.forward(tree.left_child, features, treelstm_io, batch_idx) 195 | else: 196 | has_left = False 197 | left_c = torch.FloatTensor(self.h_dim).cuda().fill_(0.0) 198 | left_h = torch.FloatTensor(self.h_dim).cuda().fill_(0.0) 199 | 200 | if tree.right_child is not None: 201 | has_right = True 202 | right_c, right_h = self.forward(tree.right_child, features, treelstm_io, batch_idx) 203 | else: 204 | has_right = False 205 | right_c = torch.FloatTensor(self.h_dim).cuda().fill_(0.0) 206 | right_h = torch.FloatTensor(self.h_dim).cuda().fill_(0.0) 207 | 208 | # calc 209 | next_feature = features[tree.index].view(1, -1) 210 | 211 | c, h = self.node_forward(next_feature, left_c, right_c, left_h, right_h, treelstm_io.dropout_mask, has_left, has_right) 212 | 213 | # record hidden state 214 | if treelstm_io.hidden is None: 215 | treelstm_io.hidden = h.view(1, -1) 216 | else: 217 | treelstm_io.hidden = torch.cat((treelstm_io.hidden, h.view(1, -1)), 0) 218 | 219 | treelstm_io.order[tree.index] = treelstm_io.order_count 220 | treelstm_io.order_count += 1 221 | 222 | return c, h 223 | 224 | 225 | class BiTreeLSTM_Backward(nn.Module): 226 | """ 227 | from root to leaves 228 | """ 229 | def __init__(self, feat_dim, h_dim): 230 | super(BiTreeLSTM_Backward, self).__init__() 231 | self.feat_dim = feat_dim 232 | self.h_dim = h_dim 233 | 234 | self.iofux = nn.Linear(self.feat_dim, 4 * self.h_dim) 235 | self.iofuh = nn.Linear(self.h_dim, 4 * self.h_dim) 236 | #self.px = nn.Linear(self.feat_dim, self.h_dim) 237 | 238 | # init parameter 239 | #block_orthogonal(self.px.weight.data, [self.h_dim, self.feat_dim]) 240 | block_orthogonal(self.iofux.weight.data, [self.h_dim, self.feat_dim]) 241 | block_orthogonal(self.iofuh.weight.data, [self.h_dim, self.h_dim]) 242 | 243 | #self.px.bias.data.fill_(0.0) 244 | 
self.iofux.bias.data.fill_(0.0) 245 | self.iofuh.bias.data.fill_(0.0) 246 | # Initialize forget gate biases to 1.0 as per An Empirical 247 | # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015). 248 | self.iofuh.bias.data[2 * self.h_dim:3 * self.h_dim].fill_(1.0) 249 | 250 | def node_backward(self, feat_inp, root_c, root_h, dropout_mask): 251 | 252 | #projected_x = self.px(feat_inp) 253 | iofu = self.iofux(feat_inp) + self.iofuh(root_h) 254 | i, o, f, u = torch.split(iofu, iofu.size(1) // 4, dim=1) 255 | i, o, f, u = F.sigmoid(i), F.sigmoid(o), F.sigmoid(f), F.tanh(u) #, F.sigmoid(r) 256 | 257 | c = torch.mul(i, u) + torch.mul(f, root_c) 258 | h = torch.mul(o, F.tanh(c)) 259 | #h_final = torch.mul(r, h) + torch.mul((1 - r), projected_x) 260 | # Only do dropout if the dropout prob is > 0.0 and we are in training mode. 261 | if dropout_mask is not None and self.training: 262 | h = torch.mul(h, dropout_mask) 263 | return c, h 264 | 265 | def forward(self, tree, features, treelstm_io, batch_idx, root_c, root_h): 266 | """ 267 | tree: The root for a tree 268 | features: [num_obj, featuresize] 269 | treelstm_io.hidden: init as None, cat until it covers all objects as [num_obj, hidden_size] 270 | treelstm_io.order: init as 0 for all [num_obj], update for recovering original order 271 | """ 272 | next_features = features[tree.index].view(1, -1) 273 | 274 | c, h = self.node_backward(next_features, root_c, root_h, treelstm_io.dropout_mask) 275 | 276 | # record hidden state 277 | if treelstm_io.hidden is None: 278 | treelstm_io.hidden = h.view(1, -1) 279 | else: 280 | treelstm_io.hidden = torch.cat((treelstm_io.hidden, h.view(1, -1)), 0) 281 | 282 | treelstm_io.order[tree.index] = treelstm_io.order_count 283 | treelstm_io.order_count += 1 284 | 285 | # recursively update from root to leaves 286 | if tree.left_child is not None: 287 | self.forward(tree.left_child, features, treelstm_io, batch_idx, c, h) 288 | if tree.right_child is not None: 289 | self.forward(tree.right_child, features, treelstm_io, batch_idx, c, h) 290 | 291 | return 292 | 293 | def get_dropout_mask(dropout_probability, h_dim): 294 | """ 295 | Computes and returns an element-wise dropout mask for a given tensor, where 296 | each element in the mask is dropped out with probability dropout_probability. 297 | Note that the mask is NOT applied to the tensor - the tensor is passed to retain 298 | the correct CUDA tensor type for the mask. 299 | 300 | Parameters 301 | ---------- 302 | dropout_probability : float, required. 303 | Probability of dropping a dimension of the input. 304 | tensor_for_masking : torch.Variable, required. 305 | 306 | 307 | Returns 308 | ------- 309 | A torch.FloatTensor consisting of the binary mask scaled by 1/ (1 - dropout_probability). 310 | This scaling ensures expected values and variances of the output of applying this mask 311 | and the original tensor are the same. 312 | """ 313 | binary_mask = Variable(torch.FloatTensor(h_dim).cuda().fill_(0.0)) 314 | binary_mask.data.copy_(torch.rand(h_dim) > dropout_probability) 315 | # Scale mask by 1/keep_prob to preserve output statistics. 
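# This is the usual "inverted dropout" scaling: surviving units are scaled by 1/(1-p)
# so that the expected value of the masked activation matches the unmasked one.
# The mask is sampled once and reused for every node of every tree in the forest
# (see the BidirectionalTreeLSTM docstring), similar in spirit to variational dropout.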
316 | dropout_mask = binary_mask.float().div(1.0 - dropout_probability) 317 | return dropout_mask 318 | -------------------------------------------------------------------------------- /tree_utils.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import itertools 9 | from torch.autograd import Variable 10 | import tree_def, tree_utils, tree_def 11 | import config 12 | 13 | def generate_tree(gen_tree_input, tree_type="overlap_tree"): 14 | """ 15 | """ 16 | if tree_type == "overlap_tree": 17 | # gen_tree_input => bbox_shape : [batch, num_objs, b_dim] (x1,y1,x2,y2) 18 | return generate_overlap_tree(gen_tree_input) 19 | elif tree_type == "arbitrary_trees": 20 | # gen_tree_input => scores : [batch_size, 10, 10] 21 | return generate_arbitrary_trees(gen_tree_input) 22 | else: 23 | print("Invalid Tree Type") 24 | return None 25 | 26 | 27 | def generate_arbitrary_trees(inputpack): 28 | """ Generate arbiraty trees according to the bbox scores """ 29 | scores, is_training = inputpack 30 | trees = [] 31 | batch_size, num_obj, _ = scores.size() 32 | sym_score = scores + scores.transpose(1,2) 33 | rl_loss = [] 34 | entropy_loss = [] 35 | for i in range(batch_size): 36 | slice_container, index_list = return_tree_contrainer(num_obj, i) 37 | slice_root_score = (sym_score[i].sum(1) - sym_score[i].diag()) / 100.0 38 | slice_bitree = ArTree_to_BiTree(return_tree(sym_score[i], slice_root_score, slice_container, index_list, rl_loss, entropy_loss, is_training)) 39 | trees.append(slice_bitree) 40 | 41 | #trees = [create_single_tree(sym_score[i], i, num_obj) for i in range(batch_size)] 42 | return trees, rl_loss, entropy_loss 43 | 44 | #def create_single_tree(single_sym_score, i, num_obj): 45 | # slice_container = [tree_def.ArbitraryTree(index, im_idx=i) for index in range(num_obj)] 46 | # index_list = [index for index in range(num_obj)] 47 | # single_root_score = single_sym_score.sum(1) - single_sym_score.diag() 48 | # slice_bitree = ArTree_to_BiTree(return_tree(single_sym_score, single_root_score, slice_container, index_list)) 49 | # return slice_bitree 50 | 51 | 52 | def return_tree(matrix_score, root_score, node_containter, remain_list, gen_tree_loss_per_batch, entropy_loss, is_training): 53 | """ Generate An Arbitrary Tree by Scores """ 54 | virtual_root = tree_def.ArbitraryTree(-1, im_idx=-1) 55 | virtual_root.is_root = True 56 | 57 | start_idx = int(root_score.argmax()) 58 | start_node = node_containter[start_idx] 59 | virtual_root.add_child(start_node) 60 | assert(start_node.index == start_idx) 61 | select_list = [] 62 | selected_node = [] 63 | select_list.append(start_idx) 64 | selected_node.append(start_node) 65 | remain_list.remove(start_idx) 66 | node_containter.remove(start_node) 67 | 68 | not_sampled = True 69 | 70 | while(len(node_containter) > 0): 71 | wid = len(remain_list) 72 | 73 | select_index_var = Variable(torch.LongTensor(select_list).cuda()) 74 | remain_index_var = Variable(torch.LongTensor(remain_list).cuda()) 75 | select_score_map = torch.index_select( torch.index_select(matrix_score, 0, select_index_var), 1, remain_index_var ).contiguous().view(-1) 76 | 77 | #select_score_map = matrix_score[select_list][:,remain_list].contiguous().view(-1) 78 | if config.use_rl and is_training and not_sampled: 79 | dist = F.softmax(select_score_map, 0) 80 | greedy_id = select_score_map.max(0)[1] 81 | best_id = 
torch.multinomial(dist, 1)[0] 82 | if int(greedy_id) != int(best_id): 83 | not_sampled = False 84 | if config.log_softmax: 85 | prob = dist[best_id] + 1e-20 86 | else: 87 | prob = select_score_map[best_id] + 1e-20 88 | gen_tree_loss_per_batch.append(prob.log()) 89 | #neg_entropy = dist * (dist + 1e-20).log() 90 | #entropy_loss.append(neg_entropy.sum()) 91 | else: 92 | _, best_id = select_score_map.max(0) 93 | #_, best_id = select_score_map.max(0) 94 | depend_id = int(best_id) // wid 95 | insert_id = int(best_id) % wid 96 | 97 | best_depend_node = selected_node[depend_id] 98 | best_insert_node = node_containter[insert_id] 99 | best_depend_node.add_child(best_insert_node) 100 | 101 | selected_node.append(best_insert_node) 102 | select_list.append(best_insert_node.index) 103 | node_containter.remove(best_insert_node) 104 | remain_list.remove(best_insert_node.index) 105 | if not_sampled: 106 | gen_tree_loss_per_batch.append(Variable(torch.FloatTensor([0]).zero_().cuda())) 107 | return virtual_root 108 | 109 | def return_tree_contrainer(num_nodes, batch_id): 110 | """ Return number of tree nodes """ 111 | container = [] 112 | index_list= [] 113 | for i in range(num_nodes): 114 | container.append(tree_def.ArbitraryTree(i, im_idx=batch_id)) 115 | index_list.append(i) 116 | return container, index_list 117 | 118 | def ArTree_to_BiTree(arTree): 119 | root_node = arTree.generate_bi_tree() 120 | arNode_to_biNode(arTree, root_node) 121 | assert(root_node.index == -1) 122 | assert(root_node.right_child is None) 123 | assert(root_node.left_child is not None) 124 | return root_node.left_child 125 | 126 | def arNode_to_biNode(arNode, biNode): 127 | if arNode.get_child_num() >= 1: 128 | new_bi_node = arNode.children[0].generate_bi_tree() 129 | biNode.add_left_child(new_bi_node) 130 | arNode_to_biNode(arNode.children[0], biNode.left_child) 131 | 132 | if arNode.get_child_num() > 1: 133 | current_bi_node = biNode.left_child 134 | for i in range(arNode.get_child_num() - 1): 135 | new_bi_node = arNode.children[i+1].generate_bi_tree() 136 | current_bi_node.add_right_child(new_bi_node) 137 | current_bi_node = current_bi_node.right_child 138 | arNode_to_biNode(arNode.children[i+1], current_bi_node) 139 | 140 | def generate_overlap_tree(bbox_shape): 141 | """ 142 | bbox_shape : [batch, num_objs, b_dim] (x1,y1,x2,y2,w,h) 143 | Method: 144 | Iteratively generate a tree: 145 | 1) Select a node with the largest score: num_overlap + /lambda * node_area (/lambda = 1e-10) 146 | 2) According to the bbox center, separate the rest node into left/right subtree 147 | """ 148 | batch_size, num_objs, _ = bbox_shape.size() 149 | # calculate information required by generation procedure 150 | bbox_overlap = overlap(bbox_shape) 151 | bbox_area = (bbox_shape[:, :, 2] - bbox_shape[:, :, 0]) * (bbox_shape[:, :, 3] - bbox_shape[:, :, 1]) 152 | bbox_center = (bbox_shape[:, :, 0] + bbox_shape[:, :, 2]) / 2.0 153 | 154 | forest = [] 155 | for i in range(batch_size): 156 | node_container = [] 157 | score_list = [] 158 | # Overlap Matrix -> Binary Matrix -> Num of Overlaped Objects 159 | overlap_slice = (bbox_overlap[i].view(num_objs, num_objs) > 0).sum(1) 160 | for j in range(num_objs): 161 | # Node score = num_overlap + /lambda * node_area (/lambda = 1e-10) 162 | node_score = float(overlap_slice[j]) + float(bbox_area[i,j]) * 1e-10 163 | node_container.append(tree_def.BiTree(j, im_idx=i, node_score=node_score, center_x=float(bbox_center[i,j]))) 164 | score_list.append(node_score) 165 | 166 | root = return_best_node(node_container, 
score_list) 167 | root.set_root() 168 | iterative_gen_tree(node_container, score_list, root) 169 | forest.append(root) 170 | return forest 171 | 172 | 173 | def iterative_gen_tree(node_container, score_list, root): 174 | """ 175 | Iterativly generate a tree 176 | (1) Select a root 177 | (2) Separete the rest nodes into two parts 178 | (3) Running step one for each parts 179 | """ 180 | if len(node_container) == 0: 181 | return 182 | left_container, left_score, right_container, right_score = seperate_container_by_root(node_container, score_list, root) 183 | left_root = return_best_node(left_container, left_score) 184 | right_root = return_best_node(right_container, right_score) 185 | if left_root is not None: 186 | root.add_left_child(left_root) 187 | iterative_gen_tree(left_container, left_score, left_root) 188 | if right_root is not None: 189 | root.add_right_child(right_root) 190 | iterative_gen_tree(right_container, right_score, right_root) 191 | return 192 | 193 | def return_best_node(node_container, score_list): 194 | """ 195 | Given a list of nodes 196 | (1)Find the node with the largest score 197 | (2)Remove the selected node 198 | """ 199 | if len(node_container) == 0: 200 | return None 201 | scoreList = torch.FloatTensor(score_list) 202 | ind = int(scoreList.max(0)[1]) 203 | best_node = node_container[ind] 204 | return best_node 205 | 206 | def seperate_container_by_root(node_container, score_list, root): 207 | """ 208 | Given a list of nodes 209 | (1) Seperate the container in to two by root node 210 | (2) Return left/right container 211 | """ 212 | left_container = [] 213 | left_score = [] 214 | right_container = [] 215 | right_score = [] 216 | for i in range(len(node_container)): 217 | if node_container[i].index == root.index: 218 | continue 219 | elif node_container[i].center_x < root.center_x: 220 | left_container.append(node_container[i]) 221 | left_score.append(score_list[i]) 222 | else: 223 | right_container.append(node_container[i]) 224 | right_score.append(score_list[i]) 225 | return left_container, left_score, right_container, right_score 226 | 227 | 228 | def overlap(bbox_shape): 229 | """ 230 | bbox_shape : [batch, num_objs, b_dim] (x1,y1,x2,y2,w,h) 231 | """ 232 | batch_size, num_objs, _ = bbox_shape.size() 233 | 234 | min_max_xy = torch.min(bbox_shape[:, :, 2:4].unsqueeze(2).expand(batch_size, num_objs, num_objs, 2), 235 | bbox_shape[:, :, 2:4].unsqueeze(1).expand(batch_size, num_objs, num_objs, 2)) 236 | max_min_xy = torch.max(bbox_shape[:, :, :2].unsqueeze(2).expand(batch_size, num_objs, num_objs, 2), 237 | bbox_shape[:, :, :2].unsqueeze(1).expand(batch_size, num_objs, num_objs, 2)) 238 | inter = torch.clamp((min_max_xy - max_min_xy), min=0) 239 | return inter[:, :, :, 0] * inter[:, :, :, 1] 240 | 241 | class TreeLSTM_IO(object): 242 | def __init__(self, num_obj, dropout_mask): 243 | self.num_obj = num_obj 244 | self.hidden = None # Float tensor [num_obj, self.out_dim] 245 | self.order = Variable(torch.LongTensor(num_obj).zero_().cuda()) # Long tensor [num_obj] 246 | self.order_count = 0 # int 247 | self.dropout_mask = dropout_mask 248 | 249 | def reset(self): 250 | self.hidden = None # Float tensor [num_obj, self.out_dim] 251 | self.order = Variable(torch.LongTensor(self.num_obj).zero_().cuda()) # Long tensor [num_obj] 252 | self.order_count = 0 # int 253 | 254 | def block_orthogonal(tensor, split_sizes, gain=1.0): 255 | """ 256 | An initializer which allows initializing model parameters in "blocks". 
This is helpful 257 | in the case of recurrent models which use multiple gates applied to linear projections, 258 | which can be computed efficiently if they are concatenated together. However, they are 259 | separate parameters which should be initialized independently. 260 | Parameters 261 | ---------- 262 | tensor : ``torch.Tensor``, required. 263 | A tensor to initialize. 264 | split_sizes : List[int], required. 265 | A list of length ``tensor.ndim()`` specifying the size of the 266 | blocks along that particular dimension. E.g. ``[10, 20]`` would 267 | result in the tensor being split into chunks of size 10 along the 268 | first dimension and 20 along the second. 269 | gain : float, optional (default = 1.0) 270 | The gain (scaling) applied to the orthogonal initialization. 271 | """ 272 | sizes = list(tensor.size()) 273 | if any([a % b != 0 for a, b in zip(sizes, split_sizes)]): 274 | raise ValueError("tensor dimensions must be divisible by their respective " 275 | "split_sizes. Found size: {} and split_sizes: {}".format(sizes, split_sizes)) 276 | indexes = [list(range(0, max_size, split)) 277 | for max_size, split in zip(sizes, split_sizes)] 278 | # Iterate over all possible blocks within the tensor. 279 | for block_start_indices in itertools.product(*indexes): 280 | # A list of tuples containing the index to start at for this block 281 | # and the appropriate step size (i.e split_size[i] for dimension i). 282 | index_and_step_tuples = zip(block_start_indices, split_sizes) 283 | # This is a tuple of slices corresponding to: 284 | # tensor[index: index + step_size, ...]. This is 285 | # required because we could have an arbitrary number 286 | # of dimensions. The actual slices we need are the 287 | # start_index: start_index + step for each dimension in the tensor. 
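# Note: despite the general description above, the block-filling code below only
# supports 2-D weight matrices (hence the assert on len(block_slice) == 2); each block
# is filled by orthogonally initialising a square matrix of side max(block dims) and
# cropping it to the block shape.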
288 | block_slice = tuple([slice(start_index, start_index + step) 289 | for start_index, step in index_and_step_tuples]) 290 | 291 | # let's not initialize empty things to 0s because THAT SOUNDS REALLY BAD 292 | assert len(block_slice) == 2 293 | sizes = [x.stop - x.start for x in block_slice] 294 | tensor_copy = tensor.new(max(sizes), max(sizes)) 295 | torch.nn.init.orthogonal_(tensor_copy, gain=gain) 296 | tensor[block_slice] = tensor_copy[0:sizes[0], 0:sizes[1]].data 297 | 298 | 299 | def print_tree(tree): 300 | if tree is None: 301 | return 302 | if(tree.left_child is not None): 303 | print_node(tree.left_child) 304 | if(tree.right_child is not None): 305 | print_node(tree.right_child) 306 | 307 | print_tree(tree.left_child) 308 | print_tree(tree.right_child) 309 | 310 | return 311 | 312 | 313 | def print_node(tree): 314 | print(' depth: ', tree.depth(), end="") 315 | print(' score: ', tree.node_score, end="") 316 | print(' child: ', tree.get_total_child()) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.transforms as transforms 8 | from torch.autograd import Variable 9 | 10 | import config 11 | 12 | 13 | def fix_batchnorm(model): 14 | if isinstance(model, list): 15 | for m in model: 16 | fix_batchnorm(m) 17 | else: 18 | for m in model.modules(): 19 | if isinstance(m, nn.BatchNorm1d): 20 | #print('Fix BatchNorm1d') 21 | m.eval() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | #print('Fix BatchNorm2d') 24 | m.eval() 25 | elif isinstance(m, nn.BatchNorm3d): 26 | #print('Fix BatchNorm3d') 27 | m.eval() 28 | elif isinstance(m, nn.Dropout): 29 | #print('Fix Dropout') 30 | m.eval() 31 | elif isinstance(m, nn.AlphaDropout): 32 | #print('Fix AlphaDropout') 33 | m.eval() 34 | 35 | 36 | def optimistic_restore(network, state_dict): 37 | mismatch = False 38 | own_state = network.state_dict() 39 | for name, param in state_dict.items(): 40 | if name not in own_state: 41 | #print("Unexpected key {} in state_dict with size {}".format(name, param.size())) 42 | mismatch = True 43 | elif param.size() == own_state[name].size(): 44 | own_state[name].copy_(param) 45 | else: 46 | print("Network has {} with size {}, ckpt has {}".format(name, 47 | own_state[name].size(), 48 | param.size())) 49 | mismatch = True 50 | 51 | missing = set(own_state.keys()) - set(state_dict.keys()) 52 | if len(missing) > 0: 53 | print("We couldn't find {}".format(','.join(missing))) 54 | mismatch = True 55 | return not mismatch 56 | 57 | class PiecewiseLin(nn.Module): 58 | def __init__(self, n): 59 | super().__init__() 60 | self.n = n 61 | self.weight = nn.Parameter(torch.ones(n + 1)) 62 | # the first weight here is always 0 with a 0 gradient 63 | self.weight.data[0] = 0 64 | 65 | def forward(self, x): 66 | # all weights are positive -> function is monotonically increasing 67 | w = self.weight.abs() 68 | # make weights sum to one -> f(1) = 1 69 | w = w / w.sum() 70 | w = w.view([self.n + 1] + [1] * x.dim()) 71 | # keep cumulative sum for O(1) time complexity 72 | csum = w.cumsum(dim=0) 73 | csum = csum.expand((self.n + 1,) + tuple(x.size())) 74 | w = w.expand_as(csum) 75 | 76 | # figure out which part of the function the input lies on 77 | y = self.n * x.unsqueeze(0) 78 | idx = Variable(y.long().data) 79 | f = y.frac() 80 | 81 | # contribution of the linear parts left of the 
input 82 | x = csum.gather(0, idx.clamp(max=self.n)) 83 | # contribution within the linear segment the input falls into 84 | x = x + f * w.gather(0, (idx + 1).clamp(max=self.n)) 85 | return x.squeeze(0) 86 | 87 | 88 | def batch_accuracy(predicted, true): 89 | """ Compute the accuracies for a batch of predictions and answers """ 90 | _, predicted_index = predicted.max(dim=1, keepdim=True) 91 | agreeing = true.gather(dim=1, index=predicted_index) 92 | ''' 93 | Acc needs to be averaged over all 10 choose 9 subsets of human answers. 94 | While we could just use a loop, surely this can be done more efficiently (and indeed, it can). 95 | There are two cases for the 1 chosen answer to be discarded: 96 | (1) the discarded answer is not the predicted answer => acc stays the same 97 | (2) the discarded answer is the predicted answer => we have to subtract 1 from the number of agreeing answers 98 | 99 | There are (10 - num_agreeing_answers) of case 1 and num_agreeing_answers of case 2, thus 100 | acc = ((10 - agreeing) * min( agreeing / 3, 1) 101 | + agreeing * min((agreeing - 1) / 3, 1)) / 10 102 | 103 | Let's do some more simplification: 104 | if num_agreeing_answers == 0: 105 | acc = 0 since the case 1 min term becomes 0 and case 2 weighting term is 0 106 | if num_agreeing_answers >= 4: 107 | acc = 1 since the min term in both cases is always 1 108 | The only cases left are for 1, 2, and 3 agreeing answers. 109 | In all of those cases, (agreeing - 1) / 3 < agreeing / 3 <= 1, so we can get rid of all the mins. 110 | By moving num_agreeing_answers from both cases outside the sum we get: 111 | acc = agreeing * ((10 - agreeing) + (agreeing - 1)) / 3 / 10 112 | which we can simplify to: 113 | acc = agreeing * 0.3 114 | Finally, we can combine all cases together with: 115 | min(agreeing * 0.3, 1) 116 | ''' 117 | return (agreeing * 0.3).clamp(max=1) 118 | 119 | 120 | def path_for(train=False, val=False, test=False, question=False, answer=False): 121 | assert train + val + test == 1 122 | assert question + answer == 1 123 | 124 | if train: 125 | split = 'train2014' 126 | elif val: 127 | split = 'val2014' 128 | else: 129 | split = config.test_split 130 | 131 | if question: 132 | fmt = 'v2_{0}_{1}_{2}_questions.json' 133 | else: 134 | if test: 135 | # just load validation data in the test=answer=True case, will be ignored anyway 136 | split = 'val2014' 137 | fmt = 'v2_{1}_{2}_annotations.json' 138 | s = fmt.format(config.task, config.dataset, split) 139 | return os.path.join(config.qa_path, s) 140 | 141 | 142 | class Tracker: 143 | """ Keep track of results over time, while having access to monitors to display information about them. """ 144 | def __init__(self): 145 | self.data = {} 146 | 147 | def track(self, name, *monitors): 148 | """ Track a set of results with given monitors under some name (e.g. 'val_acc'). 149 | When appending to the returned list storage, use the monitors to retrieve useful information. 
150 | """ 151 | l = Tracker.ListStorage(monitors) 152 | self.data.setdefault(name, []).append(l) 153 | return l 154 | 155 | def to_dict(self): 156 | # turn list storages into regular lists 157 | return {k: list(map(list, v)) for k, v in self.data.items()} 158 | 159 | 160 | class ListStorage: 161 | """ Storage of data points that updates the given monitors """ 162 | def __init__(self, monitors=[]): 163 | self.data = [] 164 | self.monitors = monitors 165 | for monitor in self.monitors: 166 | setattr(self, monitor.name, monitor) 167 | 168 | def append(self, item): 169 | for monitor in self.monitors: 170 | monitor.update(item) 171 | self.data.append(item) 172 | 173 | def __iter__(self): 174 | return iter(self.data) 175 | 176 | class MeanMonitor: 177 | """ Take the mean over the given values """ 178 | name = 'mean' 179 | 180 | def __init__(self): 181 | self.n = 0 182 | self.total = 0 183 | 184 | def update(self, value): 185 | self.total += value 186 | self.n += 1 187 | 188 | @property 189 | def value(self): 190 | return self.total / self.n 191 | 192 | class MovingMeanMonitor: 193 | """ Take an exponentially moving mean over the given values """ 194 | name = 'mean' 195 | 196 | def __init__(self, momentum=0.9): 197 | self.momentum = momentum 198 | self.first = True 199 | self.value = None 200 | 201 | def update(self, value): 202 | if self.first: 203 | self.value = value 204 | self.first = False 205 | else: 206 | m = self.momentum 207 | self.value = m * self.value + (1 - m) * value 208 | -------------------------------------------------------------------------------- /view-log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import matplotlib; matplotlib.use('agg') 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def main(): 8 | path = sys.argv[1] 9 | results = torch.load(path) 10 | 11 | val_acc = torch.FloatTensor(results['tracker']['val_acc']) 12 | val_acc = val_acc.mean(dim=1).numpy() 13 | for i, v in enumerate(val_acc): 14 | print(i, v) 15 | 16 | plt.figure() 17 | plt.plot(val_acc) 18 | plt.savefig('val_acc.png') 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | --------------------------------------------------------------------------------