├── README.md ├── cfgs ├── __pycache__ │ ├── base_cfgs.cpython-36.pyc │ └── path_cfgs.cpython-36.pyc ├── base_cfgs.py ├── large_model.yml ├── mca_model.yml ├── mfb_model.yml ├── mfh_model.yml ├── path_cfgs.py └── small_model.yml ├── core ├── __pycache__ │ └── exec.cpython-36.pyc ├── data │ ├── __pycache__ │ │ ├── ans_punct.cpython-36.pyc │ │ ├── data_utils.cpython-36.pyc │ │ └── load_data.cpython-36.pyc │ ├── ans_punct.py │ ├── answer_dict.json │ ├── data_utils.py │ └── load_data.py ├── exec.py └── model │ ├── __pycache__ │ ├── mca.cpython-36.pyc │ ├── mfb.cpython-36.pyc │ ├── net.cpython-36.pyc │ ├── net_mfb.cpython-36.pyc │ ├── net_utils.cpython-36.pyc │ ├── optim.cpython-36.pyc │ ├── tv2d_layer_2.cpython-36.pyc │ ├── tv2d_layer_batch.cpython-36.pyc │ └── tv2d_numba.cpython-36.pyc │ ├── basis_functions.py │ ├── continuous_softmax.py │ ├── continuous_sparsemax.py │ ├── mca.py │ ├── mfb.py │ ├── net.py │ ├── net_mfb.py │ ├── net_utils.py │ ├── optim.py │ ├── tv2d_layer_2.py │ └── tv2d_numba.py ├── extract_features.py ├── run.py └── utils ├── __pycache__ ├── vqa.cpython-36.pyc └── vqaEval.cpython-36.pyc ├── proc_ansdict.py ├── vqa.py └── vqaEval.py
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse and Continuous Attention Mechanisms - experiments on VQA with continuous attention
2 | PyTorch implementation of the Deep Modular Co-Attention Networks (MCAN) with continuous attention. Follow this procedure to replicate the results reported in our paper [Sparse and Continuous Attention Mechanisms](https://arxiv.org/abs/2006.07214) [1]. Note: we added the files `basis_functions.py`, `continuous_softmax.py` and `continuous_sparsemax.py` and modified `net.py` to work with continuous attention.
3 |
4 | ## Requirements
5 |
6 | We recommend following the procedure in the official [MCAN](https://github.com/MILVLG/mcan-vqa) repository regarding software and hardware requirements. We also use the same setup - see there how to organize the `datasets` folders. The only difference is that we don't use bottom-up features; instead, you can download the images from [VQA-v2](https://visualqa.org/download.html) and place them in `./train2014`, `./val2014` and `./test2015`. Then, you can run:
7 | ```features
8 | python3 extract_features.py
9 | ```
10 | to extract 14x14 grid features generated by a ResNet pretrained on ImageNet (a minimal example of such an extractor is sketched at the end of this README). You should repeat the procedure for each of `./train2014`, `./val2014` and `./test2015`. The extracted features will be saved in `./features/train`, `./features/val` and `./features/test`, respectively.
11 |
12 | You will also need to run:
13 |
14 | ```entmax
15 | pip install entmax
16 | ```
17 | to install the entmax package.
18 |
19 | ## Training
20 |
21 | To train the models in the paper, run this command:
22 |
23 | ```train
24 | python3 run.py --RUN='train' --MAX_EPOCH=13 --M='mca' --gen_func='softmax' --SPLIT='train' --attention= --VERSION=
25 | ```
26 | with ```={'discrete', 'cont-softmax', 'cont-sparsemax'}``` to train the model with discrete, 2D continuous softmax or 2D continuous sparsemax attention. Note that you should include ```gen_func='softmax'``` for the discrete softmax model as well as for the continuous softmax and sparsemax models (for the results reported in the paper, we used the mean and variance of the discrete softmax attention probabilities to obtain the attention density parameters for continuous attention; a sketch of this computation is shown below). This will load all the default hyperparameters. You can assign a name to your model by doing ```='name'```. You can add ```--SEED=87415123``` to reproduce the results reported in the paper.
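For illustration, the snippet below is a minimal sketch of how the attention density parameters mentioned above can be derived from discrete softmax probabilities over the 14x14 feature grid: the mean is the probability-weighted average of the grid coordinates, and the covariance is the corresponding second central moment. This is not the code in `core/model/net.py`; the function name, the `[0, 1]` coordinate convention, and the small `eps` jitter added for numerical stability are assumptions made for this example.

```python
import torch

def gaussian_params_from_discrete(p, H=14, W=14, eps=1e-6):
    """p: [batch, H*W] attention probabilities (each row sums to 1).
    Returns mu with shape [batch, 2, 1] and Sigma with shape [batch, 2, 2]."""
    device, dtype = p.device, p.dtype
    # (x, y) coordinates of the H*W grid cells, normalized to [0, 1]
    ys = torch.linspace(0, 1, H, device=device, dtype=dtype).repeat_interleave(W)
    xs = torch.linspace(0, 1, W, device=device, dtype=dtype).repeat(H)
    coords = torch.stack([xs, ys], dim=-1)                   # [H*W, 2]
    # mean: E_p[t]
    mu = p @ coords                                          # [batch, 2]
    # covariance: E_p[t t^T] - mu mu^T, plus a small jitter for invertibility
    second_moment = torch.einsum("bn,ni,nj->bij", p, coords, coords)
    Sigma = second_moment - mu.unsqueeze(-1) @ mu.unsqueeze(-2)
    Sigma = Sigma + eps * torch.eye(2, device=device, dtype=dtype)
    return mu.unsqueeze(-1), Sigma

# usage: probabilities from a discrete softmax over the 196 grid positions
p = torch.softmax(torch.randn(4, 14 * 14), dim=-1)
mu, Sigma = gaussian_params_from_discrete(p)
print(mu.shape, Sigma.shape)  # torch.Size([4, 2, 1]) torch.Size([4, 2, 2])
```
Up to an extra query dimension, these `mu` and `Sigma` correspond to the `Mu` and `Sigma` arguments consumed by `GaussianBasisFunctions` in `core/model/basis_functions.py`.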
27 |
28 | ## Evaluation
29 |
30 | Evaluation on the VQA 2.0 *test-dev* and *test-std* splits is run as follows:
31 |
32 | ```eval
33 | python3 run.py --RUN='test' --CKPT_V= --CKPT_E=13 --M='mca' --gen_func='softmax' --attention=
34 |
35 | ```
36 | The result file is stored in ```results/result_test/result_run_<'PATH+random number' or 'VERSION+EPOCH'>.json```. The obtained JSON file can be uploaded to [Eval AI](https://evalai.cloudcv.org/web/challenges/challenge-page/163/overview) to evaluate the scores on the *test-dev* and *test-std* splits.
37 |
38 | ## Results
39 |
40 | Following these steps, you should be able to reproduce the results in the paper. The performance of the three models (discrete attention baseline, 2D continuous softmax attention, and 2D continuous sparsemax attention) on the *test-dev* split is as follows:
41 |
42 | _Model_ | Overall | Yes/No | Number | Other
43 | :-: | :-: | :-: | :-: | :-:
44 | _Discrete attention_ | 65.83 | 83.40 | 43.59 | 55.91 |
45 | _2D continuous softmax_ | **65.96** | 83.40 | 44.80 | 55.88 |
46 | _2D continuous sparsemax_ | 65.79 | 83.10 | 44.12 | 55.95 |
47 |
48 | On the *test-std* split we report:
49 |
50 | _Model_ | Overall | Yes/No | Number | Other
51 | :-: | :-: | :-: | :-: | :-:
52 | _Discrete attention_ | 66.13 | 83.47 | 42.99 | 56.33 |
53 | _2D continuous softmax_ | **66.27** | 83.79 | 44.33 | 56.04 |
54 | _2D continuous sparsemax_ | 66.10 | 83.38 | 43.91 | 56.14 |
55 |
56 | ## References
57 |
58 | [1] André F. T. Martins, António Farinhas, Marcos Treviso, Vlad Niculae, Pedro M. Q. Aguiar, and Mário A. T. Figueiredo. [Sparse and Continuous Attention Mechanisms](https://arxiv.org/abs/2006.07214). NeurIPS 2020.
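As mentioned in the Requirements section, `extract_features.py` produces 14x14 grid features from an ImageNet-pretrained ResNet. The snippet below is only an illustrative sketch of such an extractor, not a copy of `extract_features.py`: the ResNet-101 variant, the 448x448 input size, the example file name, and the saved array layout are assumptions, so adjust them to whatever `extract_features.py` and `core/data/load_data.py` actually expect.

```python
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

# ResNet-101 without the average-pooling and classification layers; on a
# 448x448 input the last convolutional block yields a 14x14 grid of 2048-d vectors.
resnet = models.resnet101(pretrained=True)  # on newer torchvision, pass weights=... instead
extractor = torch.nn.Sequential(*list(resnet.children())[:-2]).eval()

preprocess = T.Compose([
    T.Resize((448, 448)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open("COCO_train2014_000000000009.jpg").convert("RGB"))
with torch.no_grad():
    feat = extractor(img.unsqueeze(0))              # [1, 2048, 14, 14]
grid = feat.squeeze(0).permute(1, 2, 0).reshape(196, 2048).numpy()
# e.g. np.savez("./features/train/COCO_train2014_000000000009.npz", grid)
```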
59 | -------------------------------------------------------------------------------- /cfgs/__pycache__/base_cfgs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/cfgs/__pycache__/base_cfgs.cpython-36.pyc -------------------------------------------------------------------------------- /cfgs/__pycache__/path_cfgs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/cfgs/__pycache__/path_cfgs.cpython-36.pyc -------------------------------------------------------------------------------- /cfgs/base_cfgs.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | from cfgs.path_cfgs import PATH 8 | # from path_cfgs import PATH 9 | 10 | import os, torch, random 11 | import numpy as np 12 | from types import MethodType 13 | 14 | 15 | class Cfgs(PATH): 16 | def __init__(self): 17 | super(Cfgs, self).__init__() 18 | 19 | # Set Devices 20 | # If use multi-gpu training, set e.g.'0, 1, 2' instead 21 | self.GPU = '0' 22 | 23 | # Set RNG For CPU And GPUs 24 | self.SEED = random.randint(0, 99999999) 25 | 26 | # ------------------------- 27 | # ---- Version Control ---- 28 | # ------------------------- 29 | 30 | # Define a specific name to start new training 31 | # self.VERSION = 'Anonymous_' + str(self.SEED) 32 | self.VERSION = str(self.SEED) 33 | 34 | # Resume training 35 | self.RESUME = False 36 | 37 | # Used in Resume training and testing 38 | self.CKPT_VERSION = self.VERSION 39 | self.CKPT_EPOCH = 0 40 | 41 | # Absolutely checkpoint path, 'CKPT_VERSION' and 'CKPT_EPOCH' will be overridden 42 | self.CKPT_PATH = None 43 | 44 | # Print loss every step 45 | self.VERBOSE = True 46 | 47 | 48 | # ------------------------------ 49 | # ---- Data Provider Params ---- 50 | # ------------------------------ 51 | 52 | # {'train', 'val', 'test'} 53 | self.RUN_MODE = 'train' 54 | 55 | # Set True to evaluate offline 56 | self.EVAL_EVERY_EPOCH = True 57 | 58 | # Set True to save the prediction vector (Ensemble) 59 | self.TEST_SAVE_PRED = False 60 | 61 | # Pre-load the features into memory to increase the I/O speed 62 | self.PRELOAD = False 63 | 64 | # Define the 'train' 'val' 'test' data split 65 | # (EVAL_EVERY_EPOCH triggered when set {'train': 'train'}) 66 | self.SPLIT = { 67 | 'train': '', 68 | 'val': 'val', 69 | 'test': 'test', 70 | } 71 | 72 | # A external method to set train split 73 | self.TRAIN_SPLIT = 'train+val+vg' 74 | 75 | # Set True to use pretrained word embedding 76 | # (GloVe: spaCy https://spacy.io/) 77 | self.USE_GLOVE = True 78 | 79 | # Word embedding matrix size 80 | # (token size x WORD_EMBED_SIZE) 81 | self.WORD_EMBED_SIZE = 300 82 | 83 | # Max length of question sentences 84 | self.MAX_TOKEN = 14 85 | 86 | # Filter the answer by occurrence 87 | # self.ANS_FREQ = 8 88 | 89 | # Max length of extracted faster-rcnn 2048D features 90 | # (bottom-up and Top-down: https://github.com/peteanderson80/bottom-up-attention) 91 | self.IMG_FEAT_PAD_SIZE = 100 92 | 
93 | # Faster-rcnn 2048D features 94 | self.IMG_FEAT_SIZE = 2048 95 | 96 | # Default training batch size: 64 97 | self.BATCH_SIZE = 64 98 | 99 | # Multi-thread I/O 100 | self.NUM_WORKERS = 8 101 | 102 | # Use pin memory 103 | # (Warning: pin memory can accelerate GPU loading but may 104 | # increase the CPU memory usage when NUM_WORKS is large) 105 | self.PIN_MEM = True 106 | 107 | # Large model can not training with batch size 64 108 | # Gradient accumulate can split batch to reduce gpu memory usage 109 | # (Warning: BATCH_SIZE should be divided by GRAD_ACCU_STEPS) 110 | self.GRAD_ACCU_STEPS = 1 111 | 112 | # Set 'external': use external shuffle method to implement training shuffle 113 | # Set 'internal': use pytorch dataloader default shuffle method 114 | self.SHUFFLE_MODE = 'external' 115 | 116 | 117 | # ------------------------ 118 | # ---- Network Params ---- 119 | # ------------------------ 120 | 121 | # Model deeps 122 | # (Encoder and Decoder will be same deeps) 123 | self.LAYER = 6 124 | 125 | # Model hidden size 126 | # (512 as default, bigger will be a sharp increase of gpu memory usage) 127 | self.HIDDEN_SIZE = 512 128 | 129 | # Multi-head number in MCA layers 130 | # (Warning: HIDDEN_SIZE should be divided by MULTI_HEAD) 131 | self.MULTI_HEAD = 8 132 | 133 | # Dropout rate for all dropout layers 134 | # (dropout can prevent overfitting: [Dropout: a simple way to prevent neural networks from overfitting]) 135 | self.DROPOUT_R = 0.1 136 | 137 | # MLP size in flatten layers 138 | self.FLAT_MLP_SIZE = 512 139 | 140 | # Flatten the last hidden to vector with {n} attention glimpses 141 | self.FLAT_GLIMPSES = 1 142 | self.FLAT_OUT_SIZE = 1024 143 | 144 | 145 | # -------------------------- 146 | # ---- Optimizer Params ---- 147 | # -------------------------- 148 | 149 | # The base learning rate 150 | self.LR_BASE = 0.0001 151 | 152 | # Learning rate decay ratio 153 | self.LR_DECAY_R = 0.2 154 | 155 | # Learning rate decay at {x, y, z...} epoch 156 | self.LR_DECAY_LIST = [10, 12] 157 | 158 | # Max training epoch 159 | self.MAX_EPOCH = 13 160 | 161 | # Gradient clip 162 | # (default: -1 means not using) 163 | self.GRAD_NORM_CLIP = -1 164 | 165 | # Adam optimizer betas and eps 166 | self.OPT_BETAS = (0.9, 0.98) 167 | self.OPT_EPS = 1e-9 168 | 169 | 170 | def parse_to_dict(self, args): 171 | args_dict = {} 172 | for arg in dir(args): 173 | if not arg.startswith('_') and not isinstance(getattr(args, arg), MethodType): 174 | if getattr(args, arg) is not None: 175 | args_dict[arg] = getattr(args, arg) 176 | 177 | return args_dict 178 | 179 | 180 | def add_args(self, args_dict): 181 | for arg in args_dict: 182 | setattr(self, arg, args_dict[arg]) 183 | 184 | 185 | def proc(self): 186 | assert self.RUN_MODE in ['train', 'val', 'test'] 187 | 188 | # ------------ Devices setup 189 | os.environ['CUDA_VISIBLE_DEVICES'] = self.GPU 190 | self.N_GPU = len(self.GPU.split(',')) 191 | self.DEVICES = [_ for _ in range(self.N_GPU)] 192 | torch.set_num_threads(2) 193 | 194 | 195 | # ------------ Seed setup 196 | # fix pytorch seed 197 | torch.manual_seed(self.SEED) 198 | if self.N_GPU < 2: 199 | torch.cuda.manual_seed(self.SEED) 200 | else: 201 | torch.cuda.manual_seed_all(self.SEED) 202 | torch.backends.cudnn.deterministic = True 203 | 204 | # fix numpy seed 205 | np.random.seed(self.SEED) 206 | 207 | # fix random seed 208 | random.seed(self.SEED) 209 | 210 | if self.CKPT_PATH is not None: 211 | print('Warning: you are now using CKPT_PATH args, ' 212 | 'CKPT_VERSION and CKPT_EPOCH will not work') 213 | 
self.CKPT_VERSION = self.CKPT_PATH.split('/')[-1] + '_' + str(random.randint(0, 99999999)) 214 | 215 | 216 | # ------------ Split setup 217 | self.SPLIT['train'] = self.TRAIN_SPLIT 218 | if 'val' in self.SPLIT['train'].split('+') or self.RUN_MODE not in ['train']: 219 | self.EVAL_EVERY_EPOCH = False 220 | 221 | if self.RUN_MODE not in ['test']: 222 | self.TEST_SAVE_PRED = False 223 | 224 | 225 | # ------------ Gradient accumulate setup 226 | assert self.BATCH_SIZE % self.GRAD_ACCU_STEPS == 0 227 | self.SUB_BATCH_SIZE = int(self.BATCH_SIZE / self.GRAD_ACCU_STEPS) 228 | 229 | # Use a small eval batch will reduce gpu memory usage 230 | self.EVAL_BATCH_SIZE = int(self.SUB_BATCH_SIZE / 2) 231 | 232 | 233 | # ------------ Networks setup 234 | # FeedForwardNet size in every MCA layer 235 | self.FF_SIZE = int(self.HIDDEN_SIZE * 4) 236 | 237 | # A pipe line hidden size in attention compute 238 | assert self.HIDDEN_SIZE % self.MULTI_HEAD == 0 239 | self.HIDDEN_SIZE_HEAD = int(self.HIDDEN_SIZE / self.MULTI_HEAD) 240 | 241 | 242 | def __str__(self): 243 | for attr in dir(self): 244 | if not attr.startswith('__') and not isinstance(getattr(self, attr), MethodType): 245 | print('{ %-17s }->' % attr, getattr(self, attr)) 246 | 247 | return '' 248 | 249 | # 250 | # 251 | # if __name__ == '__main__': 252 | # __C = Cfgs() 253 | # __C.proc() 254 | 255 | 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /cfgs/large_model.yml: -------------------------------------------------------------------------------- 1 | LAYER: 6 2 | HIDDEN_SIZE: 1024 3 | MULTI_HEAD: 8 4 | DROPOUT_R: 0.1 5 | FLAT_MLP_SIZE: 512 6 | FLAT_GLIMPSES: 1 7 | FLAT_OUT_SIZE: 2048 8 | LR_BASE: 0.00005 9 | LR_DECAY_R: 0.2 10 | GRAD_ACCU_STEPS: 2 11 | CKPT_VERSION: 'large' 12 | CKPT_EPOCH: 13 -------------------------------------------------------------------------------- /cfgs/mca_model.yml: -------------------------------------------------------------------------------- 1 | LAYER: 6 2 | HIDDEN_SIZE: 512 3 | MULTI_HEAD: 8 4 | DROPOUT_R: 0.1 5 | FLAT_MLP_SIZE: 512 6 | FLAT_GLIMPSES: 1 7 | FLAT_OUT_SIZE: 1024 8 | LR_BASE: 0.0001 9 | LR_DECAY_R: 0.2 10 | GRAD_ACCU_STEPS: 1 11 | CKPT_VERSION: 'small' 12 | CKPT_EPOCH: 13 -------------------------------------------------------------------------------- /cfgs/mfb_model.yml: -------------------------------------------------------------------------------- 1 | MFB_K: 5 2 | MFB_O: 1000 3 | LSTM_OUT_SIZE: 1024 4 | DROPOUT_R: 0.1 5 | I_GLIMPSES: 2 6 | Q_GLIMPSES: 2 7 | HIGH_ORDER: False # True for MFH, False for MFB 8 | HIDDEN_SIZE: 512 -------------------------------------------------------------------------------- /cfgs/mfh_model.yml: -------------------------------------------------------------------------------- 1 | MFB_K: 5 2 | MFB_O: 1000 3 | LSTM_OUT_SIZE: 1024 4 | DROPOUT_R: 0.1 5 | I_GLIMPSES: 2 6 | Q_GLIMPSES: 2 7 | HIGH_ORDER: True # True for MFH, False for MFB 8 | HIDDEN_SIZE: 512 -------------------------------------------------------------------------------- /cfgs/path_cfgs.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | 9 | class PATH: 10 | def __init__(self): 11 | 12 | # vqav2 dataset root 
path 13 | self.DATASET_PATH = './datasets/vqa/' 14 | 15 | # bottom up features root path 16 | self.FEATURE_PATH = './datasets/coco_extract/' 17 | 18 | self.init_path() 19 | 20 | 21 | def init_path(self): 22 | 23 | self.IMG_FEAT_PATH = { 24 | 'train': self.FEATURE_PATH + 'train/', 25 | 'val': self.FEATURE_PATH + 'val/', 26 | 'test': self.FEATURE_PATH + 'test/', 27 | } 28 | 29 | self.QUESTION_PATH = { 30 | 'train': self.DATASET_PATH + 'v2_OpenEnded_mscoco_train2014_questions.json', 31 | 'val': self.DATASET_PATH + 'v2_OpenEnded_mscoco_val2014_questions.json', 32 | 'test': self.DATASET_PATH + 'v2_OpenEnded_mscoco_test2015_questions.json', 33 | 'vg': self.DATASET_PATH + 'VG_questions.json', 34 | } 35 | 36 | self.ANSWER_PATH = { 37 | 'train': self.DATASET_PATH + 'v2_mscoco_train2014_annotations.json', 38 | 'val': self.DATASET_PATH + 'v2_mscoco_val2014_annotations.json', 39 | 'vg': self.DATASET_PATH + 'VG_annotations.json', 40 | } 41 | 42 | self.RESULT_PATH = './results/result_test/' 43 | self.PRED_PATH = './results/pred/' 44 | self.CACHE_PATH = './results/cache/' 45 | self.LOG_PATH = './results/log/' 46 | self.CKPTS_PATH = './ckpts/' 47 | 48 | if 'result_test' not in os.listdir('./results'): 49 | os.mkdir('./results/result_test') 50 | 51 | if 'pred' not in os.listdir('./results'): 52 | os.mkdir('./results/pred') 53 | 54 | if 'cache' not in os.listdir('./results'): 55 | os.mkdir('./results/cache') 56 | 57 | if 'log' not in os.listdir('./results'): 58 | os.mkdir('./results/log') 59 | 60 | if 'ckpts' not in os.listdir('./'): 61 | os.mkdir('./ckpts') 62 | 63 | 64 | def check_path(self): 65 | print('Checking dataset ...') 66 | 67 | for mode in self.IMG_FEAT_PATH: 68 | if not os.path.exists(self.IMG_FEAT_PATH[mode]): 69 | print(self.IMG_FEAT_PATH[mode] + 'NOT EXIST') 70 | exit(-1) 71 | 72 | for mode in self.QUESTION_PATH: 73 | if not os.path.exists(self.QUESTION_PATH[mode]): 74 | print(self.QUESTION_PATH[mode] + 'NOT EXIST') 75 | exit(-1) 76 | 77 | for mode in self.ANSWER_PATH: 78 | if not os.path.exists(self.ANSWER_PATH[mode]): 79 | print(self.ANSWER_PATH[mode] + 'NOT EXIST') 80 | exit(-1) 81 | 82 | print('Finished') 83 | print('') 84 | 85 | -------------------------------------------------------------------------------- /cfgs/small_model.yml: -------------------------------------------------------------------------------- 1 | LAYER: 6 2 | HIDDEN_SIZE: 512 3 | MULTI_HEAD: 8 4 | DROPOUT_R: 0.1 5 | FLAT_MLP_SIZE: 512 6 | FLAT_GLIMPSES: 1 7 | FLAT_OUT_SIZE: 1024 8 | LR_BASE: 0.0001 9 | LR_DECAY_R: 0.2 10 | GRAD_ACCU_STEPS: 1 11 | CKPT_VERSION: 'small' 12 | CKPT_EPOCH: 13 -------------------------------------------------------------------------------- /core/__pycache__/exec.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/__pycache__/exec.cpython-36.pyc -------------------------------------------------------------------------------- /core/data/__pycache__/ans_punct.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/data/__pycache__/ans_punct.cpython-36.pyc -------------------------------------------------------------------------------- /core/data/__pycache__/data_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/data/__pycache__/data_utils.cpython-36.pyc -------------------------------------------------------------------------------- /core/data/__pycache__/load_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/data/__pycache__/load_data.cpython-36.pyc -------------------------------------------------------------------------------- /core/data/ans_punct.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # based on VQA Evaluation Code 6 | # -------------------------------------------------------- 7 | 8 | import re 9 | 10 | contractions = { 11 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 12 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 13 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 14 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 15 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 16 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 17 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 18 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 19 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 20 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 21 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 22 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 23 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 24 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 25 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 26 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 27 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 28 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 29 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 30 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 31 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 32 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 33 | "someonell": "someone'll", "someones": "someone's", "somethingd": 34 | "something'd", "somethingd've": "something'd've", "something'dve": 35 | "something'd've", "somethingll": "something'll", "thats": 36 | "that's", "thered": "there'd", "thered've": "there'd've", 37 | "there'dve": "there'd've", "therere": "there're", "theres": 38 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": 39 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 40 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 41 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 42 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 43 | "what's", "whatve": "what've", "whens": "when's", "whered": 44 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 45 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 46 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 47 | "whyre": "why're", "whys": 
"why's", "wont": "won't", "wouldve": 48 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 49 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 50 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 51 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 52 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 53 | "you'll", "youre": "you're", "youve": "you've" 54 | } 55 | 56 | manual_map = { 'none': '0', 57 | 'zero': '0', 58 | 'one': '1', 59 | 'two': '2', 60 | 'three': '3', 61 | 'four': '4', 62 | 'five': '5', 63 | 'six': '6', 64 | 'seven': '7', 65 | 'eight': '8', 66 | 'nine': '9', 67 | 'ten': '10'} 68 | articles = ['a', 'an', 'the'] 69 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 70 | comma_strip = re.compile("(\d)(\,)(\d)") 71 | punct = [';', r"/", '[', ']', '"', '{', '}', 72 | '(', ')', '=', '+', '\\', '_', '-', 73 | '>', '<', '@', '`', ',', '?', '!'] 74 | 75 | def process_punctuation(inText): 76 | outText = inText 77 | for p in punct: 78 | if (p + ' ' in inText or ' ' + p in inText) \ 79 | or (re.search(comma_strip, inText) != None): 80 | outText = outText.replace(p, '') 81 | else: 82 | outText = outText.replace(p, ' ') 83 | outText = period_strip.sub("", outText, re.UNICODE) 84 | return outText 85 | 86 | 87 | def process_digit_article(inText): 88 | outText = [] 89 | tempText = inText.lower().split() 90 | for word in tempText: 91 | word = manual_map.setdefault(word, word) 92 | if word not in articles: 93 | outText.append(word) 94 | else: 95 | pass 96 | for wordId, word in enumerate(outText): 97 | if word in contractions: 98 | outText[wordId] = contractions[word] 99 | outText = ' '.join(outText) 100 | return outText 101 | 102 | 103 | def prep_ans(answer): 104 | answer = process_digit_article(process_punctuation(answer)) 105 | answer = answer.replace(',', '') 106 | return answer 107 | -------------------------------------------------------------------------------- /core/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | from core.data.ans_punct import prep_ans 8 | import numpy as np 9 | import en_vectors_web_lg, random, re, json 10 | 11 | 12 | def shuffle_list(ans_list): 13 | random.shuffle(ans_list) 14 | 15 | 16 | # ------------------------------ 17 | # ---- Initialization Utils ---- 18 | # ------------------------------ 19 | 20 | def img_feat_path_load(path_list): 21 | iid_to_path = {} 22 | 23 | for ix, path in enumerate(path_list): 24 | iid = str(int(path.split('/')[-1].split('_')[-1].split('.')[0])) 25 | iid_to_path[iid] = path 26 | 27 | return iid_to_path 28 | 29 | 30 | def img_feat_load(path_list): 31 | iid_to_feat = {} 32 | 33 | for ix, path in enumerate(path_list): 34 | iid = str(int(path.split('/')[-1].split('_')[-1].split('.')[0])) 35 | a = np.load(path) 36 | #img_feat_x = img_feat['x'].transpose((1, 0)) 37 | img_feat_x = img_feat = [a[k] for k in a][0] 38 | iid_to_feat[iid] = img_feat_x 39 | print('\rPre-Loading: [{} | {}] '.format(ix, path_list.__len__()), end=' ') 40 | 41 | return iid_to_feat 42 | 43 | 44 | def ques_load(ques_list): 45 | qid_to_ques = {} 46 | 47 | for ques in ques_list: 48 | qid = str(ques['question_id']) 49 | qid_to_ques[qid] = ques 50 | 51 | 
return qid_to_ques 52 | 53 | 54 | def tokenize(stat_ques_list, use_glove): 55 | token_to_ix = { 56 | 'PAD': 0, 57 | 'UNK': 1, 58 | } 59 | 60 | spacy_tool = None 61 | pretrained_emb = [] 62 | if use_glove: 63 | spacy_tool = en_vectors_web_lg.load() 64 | pretrained_emb.append(spacy_tool('PAD').vector) 65 | pretrained_emb.append(spacy_tool('UNK').vector) 66 | 67 | for ques in stat_ques_list: 68 | words = re.sub( 69 | r"([.,'!?\"()*#:;])", 70 | '', 71 | ques['question'].lower() 72 | ).replace('-', ' ').replace('/', ' ').split() 73 | 74 | for word in words: 75 | if word not in token_to_ix: 76 | token_to_ix[word] = len(token_to_ix) 77 | if use_glove: 78 | pretrained_emb.append(spacy_tool(word).vector) 79 | 80 | pretrained_emb = np.array(pretrained_emb) 81 | 82 | return token_to_ix, pretrained_emb 83 | 84 | 85 | # def ans_stat(stat_ans_list, ans_freq): 86 | # ans_to_ix = {} 87 | # ix_to_ans = {} 88 | # ans_freq_dict = {} 89 | # 90 | # for ans in stat_ans_list: 91 | # ans_proc = prep_ans(ans['multiple_choice_answer']) 92 | # if ans_proc not in ans_freq_dict: 93 | # ans_freq_dict[ans_proc] = 1 94 | # else: 95 | # ans_freq_dict[ans_proc] += 1 96 | # 97 | # ans_freq_filter = ans_freq_dict.copy() 98 | # for ans in ans_freq_dict: 99 | # if ans_freq_dict[ans] <= ans_freq: 100 | # ans_freq_filter.pop(ans) 101 | # 102 | # for ans in ans_freq_filter: 103 | # ix_to_ans[ans_to_ix.__len__()] = ans 104 | # ans_to_ix[ans] = ans_to_ix.__len__() 105 | # 106 | # return ans_to_ix, ix_to_ans 107 | 108 | 109 | def ans_stat(json_file): 110 | ans_to_ix, ix_to_ans = json.load(open(json_file, 'r')) 111 | 112 | return ans_to_ix, ix_to_ans 113 | 114 | 115 | # ------------------------------------ 116 | # ---- Real-Time Processing Utils ---- 117 | # ------------------------------------ 118 | 119 | def proc_img_feat(img_feat, img_feat_pad_size): 120 | if img_feat.shape[0] > img_feat_pad_size: 121 | img_feat = img_feat[:img_feat_pad_size] 122 | 123 | img_feat = np.pad( 124 | img_feat, 125 | ((0, img_feat_pad_size - img_feat.shape[0]), (0, 0)), 126 | mode='constant', 127 | constant_values=0 128 | ) 129 | 130 | return img_feat 131 | 132 | 133 | def proc_ques(ques, token_to_ix, max_token): 134 | ques_ix = np.zeros(max_token, np.int64) 135 | 136 | words = re.sub( 137 | r"([.,'!?\"()*#:;])", 138 | '', 139 | ques['question'].lower() 140 | ).replace('-', ' ').replace('/', ' ').split() 141 | 142 | for ix, word in enumerate(words): 143 | if word in token_to_ix: 144 | ques_ix[ix] = token_to_ix[word] 145 | else: 146 | ques_ix[ix] = token_to_ix['UNK'] 147 | 148 | if ix + 1 == max_token: 149 | break 150 | 151 | return ques_ix 152 | 153 | 154 | def get_score(occur): 155 | if occur == 0: 156 | return .0 157 | elif occur == 1: 158 | return .3 159 | elif occur == 2: 160 | return .6 161 | elif occur == 3: 162 | return .9 163 | else: 164 | return 1. 
165 | 166 | 167 | def proc_ans(ans, ans_to_ix): 168 | ans_score = np.zeros(ans_to_ix.__len__(), np.float32) 169 | ans_prob_dict = {} 170 | 171 | for ans_ in ans['answers']: 172 | ans_proc = prep_ans(ans_['answer']) 173 | if ans_proc not in ans_prob_dict: 174 | ans_prob_dict[ans_proc] = 1 175 | else: 176 | ans_prob_dict[ans_proc] += 1 177 | 178 | for ans_ in ans_prob_dict: 179 | if ans_ in ans_to_ix: 180 | ans_score[ans_to_ix[ans_]] = get_score(ans_prob_dict[ans_]) 181 | 182 | return ans_score 183 | 184 | -------------------------------------------------------------------------------- /core/data/load_data.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | from core.data.data_utils import img_feat_path_load, img_feat_load, ques_load, tokenize, ans_stat 8 | from core.data.data_utils import proc_img_feat, proc_ques, proc_ans 9 | 10 | import numpy as np 11 | import glob, json, torch, time 12 | import torch.utils.data as Data 13 | import torch 14 | 15 | class DataSet(Data.Dataset): 16 | def __init__(self, __C): 17 | self.__C = __C 18 | 19 | 20 | # -------------------------- 21 | # ---- Raw data loading ---- 22 | # -------------------------- 23 | 24 | # Loading all image paths 25 | # if self.__C.PRELOAD: 26 | self.img_feat_path_list = [] 27 | split_list = __C.SPLIT[__C.RUN_MODE].split('+') 28 | for split in split_list: 29 | if split in ['train', 'val', 'test']: 30 | self.img_feat_path_list += glob.glob(__C.IMG_FEAT_PATH[split]+'*.npz' ) 31 | 32 | # if __C.EVAL_EVERY_EPOCH and __C.RUN_MODE in ['train']: 33 | # self.img_feat_path_list += glob.glob(__C.IMG_FEAT_PATH['val'] + '*.npz') 34 | 35 | # else: 36 | # self.img_feat_path_list = \ 37 | # glob.glob(__C.IMG_FEAT_PATH['train'] + '*.npz') + \ 38 | # glob.glob(__C.IMG_FEAT_PATH['val'] + '*.npz') + \ 39 | # glob.glob(__C.IMG_FEAT_PATH['test'] + '*.npz') 40 | 41 | # Loading question word list 42 | self.stat_ques_list = \ 43 | json.load(open(__C.QUESTION_PATH['train'], 'r'))['questions'] + \ 44 | json.load(open(__C.QUESTION_PATH['val'], 'r'))['questions'] + \ 45 | json.load(open(__C.QUESTION_PATH['test'], 'r'))['questions'] + \ 46 | json.load(open(__C.QUESTION_PATH['vg'], 'r'))['questions'] 47 | 48 | # Loading answer word list 49 | # self.stat_ans_list = \ 50 | # json.load(open(__C.ANSWER_PATH['train'], 'r'))['annotations'] + \ 51 | # json.load(open(__C.ANSWER_PATH['val'], 'r'))['annotations'] 52 | 53 | # Loading question and answer list 54 | self.ques_list = [] 55 | self.ans_list = [] 56 | 57 | split_list = __C.SPLIT[__C.RUN_MODE].split('+') 58 | for split in split_list: 59 | self.ques_list += json.load(open(__C.QUESTION_PATH[split], 'r'))['questions'] 60 | if __C.RUN_MODE in ['train']: 61 | self.ans_list += json.load(open(__C.ANSWER_PATH[split], 'r'))['annotations'] 62 | 63 | # Define run data size 64 | if __C.RUN_MODE in ['train']: 65 | self.data_size = self.ans_list.__len__() 66 | else: 67 | self.data_size = self.ques_list.__len__() 68 | 69 | print('== Dataset size:', self.data_size) 70 | 71 | 72 | # ------------------------ 73 | # ---- Data statistic ---- 74 | # ------------------------ 75 | 76 | # {image id} -> {image feature absolutely path} 77 | if self.__C.PRELOAD: 78 | print('==== Pre-Loading features ...') 
79 | time_start = time.time() 80 | self.iid_to_img_feat = img_feat_load(self.img_feat_path_list) 81 | time_end = time.time() 82 | print('==== Finished in {}s'.format(int(time_end-time_start))) 83 | else: 84 | self.iid_to_img_feat_path = img_feat_path_load(self.img_feat_path_list) 85 | 86 | # {question id} -> {question} 87 | self.qid_to_ques = ques_load(self.ques_list) 88 | 89 | # Tokenize 90 | self.token_to_ix, self.pretrained_emb = tokenize(self.stat_ques_list, __C.USE_GLOVE) 91 | self.token_size = self.token_to_ix.__len__() 92 | print('== Question token vocab size:', self.token_size) 93 | 94 | # Answers statistic 95 | # Make answer dict during training does not guarantee 96 | # the same order of {ans_to_ix}, so we published our 97 | # answer dict to ensure that our pre-trained model 98 | # can be adapted on each machine. 99 | 100 | # Thanks to Licheng Yu (https://github.com/lichengunc) 101 | # for finding this bug and providing the solutions. 102 | 103 | # self.ans_to_ix, self.ix_to_ans = ans_stat(self.stat_ans_list, __C.ANS_FREQ) 104 | self.ans_to_ix, self.ix_to_ans = ans_stat('core/data/answer_dict.json') 105 | self.ans_size = self.ans_to_ix.__len__() 106 | print('== Answer vocab size (occurr more than {} times):'.format(8), self.ans_size) 107 | print('Finished!') 108 | print('') 109 | 110 | 111 | def __getitem__(self, idx): 112 | 113 | # For code safety 114 | img_feat_iter = np.zeros(1) 115 | ques_ix_iter = np.zeros(1) 116 | ans_iter = np.zeros(1) 117 | 118 | # Process ['train'] and ['val', 'test'] respectively 119 | if self.__C.RUN_MODE in ['train']: 120 | # Load the run data from list 121 | ans = self.ans_list[idx] 122 | ques = self.qid_to_ques[str(ans['question_id'])] 123 | 124 | # Process image feature from (.npz) file 125 | if self.__C.PRELOAD: 126 | img_feat_x = self.iid_to_img_feat[str(ans['image_id'])] 127 | else: 128 | a = np.load(self.iid_to_img_feat_path[str(ans['image_id'])]) 129 | #img_feat_x = img_feat['x'].transpose((1, 0)) 130 | img_feat = [a[k] for k in a][0] 131 | img_feat_iter = img_feat#proc_img_feat(img_feat_x, self.__C.IMG_FEAT_PAD_SIZE) 132 | # Process question 133 | ques_ix_iter = proc_ques(ques, self.token_to_ix, self.__C.MAX_TOKEN) 134 | 135 | # Process answer 136 | ans_iter = proc_ans(ans, self.ans_to_ix) 137 | 138 | else: 139 | # Load the run data from list 140 | ques = self.ques_list[idx] 141 | 142 | # # Process image feature from (.npz) file 143 | # img_feat = np.load(self.iid_to_img_feat_path[str(ques['image_id'])]) 144 | # img_feat_x = img_feat['x'].transpose((1, 0)) 145 | # Process image feature from (.npz) file 146 | if self.__C.PRELOAD: 147 | img_feat_x = self.iid_to_img_feat[str(ques['image_id'])] 148 | else: 149 | a = np.load(self.iid_to_img_feat_path[str(ques['image_id'])]) 150 | img_feat = [a[k] for k in a][0] 151 | #img_feat_x = img_feat['x'].transpose((1, 0)) 152 | img_feat_iter = img_feat#proc_img_feat(img_feat_x, self.__C.IMG_FEAT_PAD_SIZE) 153 | 154 | # Process question 155 | ques_ix_iter = proc_ques(ques, self.token_to_ix, self.__C.MAX_TOKEN) 156 | 157 | 158 | return torch.from_numpy(img_feat_iter), \ 159 | torch.from_numpy(ques_ix_iter), \ 160 | torch.from_numpy(ans_iter) 161 | 162 | 163 | def __len__(self): 164 | return self.data_size 165 | 166 | 167 | -------------------------------------------------------------------------------- /core/exec.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 
3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | from core.data.load_data import DataSet 8 | from core.model.net import Net 9 | from core.model.net_mfb import Net_mfb 10 | from core.model.optim import get_optim, adjust_lr 11 | from core.data.data_utils import shuffle_list 12 | from utils.vqa import VQA 13 | from utils.vqaEval import VQAEval 14 | 15 | import os, json, torch, datetime, pickle, copy, shutil, time 16 | import numpy as np 17 | import torch.nn as nn 18 | import torch.utils.data as Data 19 | 20 | from entmax import sparsemax 21 | from functools import partial 22 | 23 | 24 | 25 | class Execution: 26 | def __init__(self, __C): 27 | self.__C = __C 28 | 29 | print('Loading training set ........') 30 | self.dataset = DataSet(__C) 31 | 32 | self.dataset_eval = None 33 | if __C.EVAL_EVERY_EPOCH: 34 | __C_eval = copy.deepcopy(__C) 35 | setattr(__C_eval, 'RUN_MODE', 'val') 36 | 37 | print('Loading validation set for per-epoch evaluation ........') 38 | self.dataset_eval = DataSet(__C_eval) 39 | 40 | 41 | def train(self, dataset, dataset_eval=None): 42 | 43 | # Obtain needed information 44 | data_size = dataset.data_size 45 | token_size = dataset.token_size 46 | ans_size = dataset.ans_size 47 | pretrained_emb = dataset.pretrained_emb 48 | patience=0 49 | best_acc=0 50 | 51 | gen_funcs = {"softmax": torch.softmax,"sparsemax": partial(sparsemax, k=512),"tvmax": "tvmax"} 52 | 53 | gen_func = gen_funcs[self.__C.gen_func] 54 | 55 | # Define the MCAN model 56 | if self.__C.MODEL=='mca': 57 | net = Net(self.__C,pretrained_emb,token_size,ans_size,gen_func=gen_func) 58 | else: 59 | net = Net_mfb(self.__C,pretrained_emb,token_size,ans_size,gen_func=gen_func) 60 | 61 | net.cuda() 62 | net.train() 63 | 64 | # Define the multi-gpu training if needed 65 | if self.__C.N_GPU > 1: 66 | net = nn.DataParallel(net, device_ids=self.__C.DEVICES) 67 | 68 | # Define the binary cross entropy loss 69 | # loss_fn = torch.nn.BCELoss(size_average=False).cuda() 70 | if self.__C.MODEL=='mca': 71 | loss_fn = torch.nn.BCELoss(reduction='sum').cuda() 72 | else: 73 | loss_fn = torch.nn.KLDivLoss(reduction='sum').cuda() 74 | # Load checkpoint if resume training 75 | if self.__C.RESUME: 76 | print(' ========== Resume training') 77 | 78 | if self.__C.CKPT_PATH is not None: 79 | print('Warning: you are now using CKPT_PATH args, ' 80 | 'CKPT_VERSION and CKPT_EPOCH will not work') 81 | 82 | path = self.__C.CKPT_PATH 83 | else: 84 | path = self.__C.CKPTS_PATH + 'ckpt_' + self.__C.CKPT_VERSION + '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl' 85 | 86 | # Load the network parameters 87 | print('Loading ckpt {}'.format(path)) 88 | ckpt = torch.load(path) 89 | print('Finish!') 90 | net.load_state_dict(ckpt['state_dict']) 91 | 92 | # Load the optimizer paramters 93 | optim = get_optim(self.__C, net, data_size, ckpt['lr_base']) 94 | optim._step = int(data_size / self.__C.BATCH_SIZE * self.__C.CKPT_EPOCH) 95 | optim.optimizer.load_state_dict(ckpt['optimizer']) 96 | 97 | start_epoch = self.__C.CKPT_EPOCH 98 | 99 | else: 100 | if ('ckpt_' + self.__C.VERSION) in os.listdir(self.__C.CKPTS_PATH): 101 | shutil.rmtree(self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION) 102 | 103 | os.mkdir(self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION) 104 | 105 | optim = get_optim(self.__C, net, data_size) 106 | start_epoch = 0 107 | 108 | loss_sum = 0 109 | named_params = list(net.named_parameters()) 110 | grad_norm = 
np.zeros(len(named_params)) 111 | 112 | # Define multi-thread dataloader 113 | if self.__C.SHUFFLE_MODE in ['external']: 114 | dataloader = Data.DataLoader(dataset,batch_size=self.__C.BATCH_SIZE,shuffle=False,num_workers=self.__C.NUM_WORKERS,pin_memory=self.__C.PIN_MEM,drop_last=True 115 | ) 116 | else: 117 | dataloader = Data.DataLoader(dataset,batch_size=self.__C.BATCH_SIZE,shuffle=True,num_workers=self.__C.NUM_WORKERS,pin_memory=self.__C.PIN_MEM,drop_last=True 118 | ) 119 | 120 | # Training script 121 | for epoch in range(start_epoch, self.__C.MAX_EPOCH): 122 | 123 | # Save log information 124 | logfile = open( 125 | self.__C.LOG_PATH + 126 | 'log_run_' + self.__C.VERSION + '.txt', 127 | 'a+') 128 | logfile.write( 129 | 'nowTime: ' + 130 | datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + 131 | '\n') 132 | logfile.close() 133 | 134 | # Learning Rate Decay 135 | if epoch in self.__C.LR_DECAY_LIST: 136 | #if patience==1: 137 | patience=0 138 | adjust_lr(optim, self.__C.LR_DECAY_R) 139 | 140 | # Externally shuffle 141 | if self.__C.SHUFFLE_MODE == 'external': 142 | shuffle_list(dataset.ans_list) 143 | 144 | time_start = time.time() 145 | # Iteration 146 | for step, (img_feat_iter,ques_ix_iter,ans_iter) in enumerate(dataloader): 147 | 148 | optim.zero_grad() 149 | 150 | img_feat_iter = img_feat_iter.cuda() 151 | ques_ix_iter = ques_ix_iter.cuda() 152 | ans_iter = ans_iter.cuda() 153 | 154 | for accu_step in range(self.__C.GRAD_ACCU_STEPS): 155 | 156 | sub_img_feat_iter = img_feat_iter[accu_step * self.__C.SUB_BATCH_SIZE:(accu_step + 1) * self.__C.SUB_BATCH_SIZE] 157 | sub_ques_ix_iter = ques_ix_iter[accu_step * self.__C.SUB_BATCH_SIZE:(accu_step + 1) * self.__C.SUB_BATCH_SIZE] 158 | sub_ans_iter = ans_iter[accu_step * self.__C.SUB_BATCH_SIZE:(accu_step + 1) * self.__C.SUB_BATCH_SIZE] 159 | 160 | 161 | pred = net(sub_img_feat_iter,sub_ques_ix_iter) 162 | if self.__C.MODEL=='mca': 163 | loss = loss_fn(pred, sub_ans_iter) 164 | else: 165 | pred = torch.log_softmax(pred, dim=-1) 166 | loss = loss_fn(pred, sub_ans_iter) 167 | 168 | loss /= self.__C.GRAD_ACCU_STEPS 169 | loss.backward() 170 | loss_sum += loss.cpu().data.numpy() * self.__C.GRAD_ACCU_STEPS 171 | 172 | if self.__C.VERBOSE: 173 | if dataset_eval is not None: 174 | mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT['val'] 175 | else: 176 | mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT['test'] 177 | 178 | print("\r[version %s][epoch %2d][step %4d/%4d][%s] loss: %.4f, lr: %.2e" % ( 179 | self.__C.VERSION,epoch + 1,step,int(data_size / self.__C.BATCH_SIZE),mode_str, 180 | loss.cpu().data.numpy() / self.__C.SUB_BATCH_SIZE,optim._rate), end=' ') 181 | 182 | # Gradient norm clipping 183 | if self.__C.GRAD_NORM_CLIP > 0: 184 | nn.utils.clip_grad_norm_(net.parameters(),self.__C.GRAD_NORM_CLIP) 185 | 186 | # Save the gradient information 187 | for name in range(len(named_params)): 188 | norm_v = torch.norm(named_params[name][1].grad).cpu().data.numpy() \ 189 | if named_params[name][1].grad is not None else 0 190 | grad_norm[name] += norm_v * self.__C.GRAD_ACCU_STEPS 191 | # print('Param %-3s Name %-80s Grad_Norm %-20s'% 192 | # (str(grad_wt), 193 | # params[grad_wt][0], 194 | # str(norm_v))) 195 | 196 | optim.step() 197 | 198 | time_end = time.time() 199 | print('Finished in {}s'.format(int(time_end-time_start))) 200 | 201 | # print('') 202 | epoch_finish = epoch + 1 203 | 204 | # Save checkpoint 205 | state = {'state_dict': net.state_dict(),'optimizer': optim.optimizer.state_dict(),'lr_base': optim.lr_base} 206 | 
torch.save(state,self.__C.CKPTS_PATH +'ckpt_' + self.__C.VERSION +'/epoch' + str(epoch_finish) +'.pkl') 207 | 208 | # Logging 209 | logfile = open(self.__C.LOG_PATH +'log_run_' + self.__C.VERSION + '.txt','a+') 210 | logfile.write('epoch = ' + str(epoch_finish) +' loss = ' + str(loss_sum / data_size) +'\n' +'lr = ' + str(optim._rate) +'\n\n') 211 | logfile.close() 212 | 213 | # Eval after every epoch 214 | if dataset_eval is not None: 215 | acc = self.eval(dataset_eval,state_dict=net.state_dict(),valid=True) 216 | if acc>=best_acc: 217 | best_acc=acc 218 | patience=0 219 | else: 220 | patience+=1 221 | # if self.__C.VERBOSE: 222 | # logfile = open( 223 | # self.__C.LOG_PATH + 224 | # 'log_run_' + self.__C.VERSION + '.txt', 225 | # 'a+' 226 | # ) 227 | # for name in range(len(named_params)): 228 | # logfile.write( 229 | # 'Param %-3s Name %-80s Grad_Norm %-25s\n' % ( 230 | # str(name), 231 | # named_params[name][0], 232 | # str(grad_norm[name] / data_size * self.__C.BATCH_SIZE) 233 | # ) 234 | # ) 235 | # logfile.write('\n') 236 | # logfile.close() 237 | 238 | loss_sum = 0 239 | grad_norm = np.zeros(len(named_params)) 240 | 241 | 242 | # Evaluation 243 | def eval(self, dataset, state_dict=None, valid=False): 244 | 245 | # Load parameters 246 | if self.__C.CKPT_PATH is not None: 247 | print('Warning: you are now using CKPT_PATH args, ' 248 | 'CKPT_VERSION and CKPT_EPOCH will not work') 249 | 250 | path = self.__C.CKPT_PATH 251 | else: 252 | path = self.__C.CKPTS_PATH + \ 253 | 'ckpt_' + self.__C.CKPT_VERSION + \ 254 | '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl' 255 | 256 | val_ckpt_flag = False 257 | if state_dict is None: 258 | val_ckpt_flag = True 259 | print('Loading ckpt {}'.format(path)) 260 | state_dict = torch.load(path)['state_dict'] 261 | print('Finish!') 262 | 263 | # Store the prediction list 264 | qid_list = [ques['question_id'] for ques in dataset.ques_list] 265 | ans_ix_list = [] 266 | pred_list = [] 267 | 268 | data_size = dataset.data_size 269 | token_size = dataset.token_size 270 | ans_size = dataset.ans_size 271 | pretrained_emb = dataset.pretrained_emb 272 | 273 | gen_funcs = {"softmax": torch.softmax,"sparsemax": partial(sparsemax, k=512),"tvmax": "tvmax"} 274 | 275 | gen_func = gen_funcs[self.__C.gen_func] 276 | 277 | if self.__C.MODEL=='mca': 278 | net = Net(self.__C,pretrained_emb,token_size,ans_size,gen_func=gen_func) 279 | else: 280 | net = Net_mfb(self.__C,pretrained_emb,token_size,ans_size,gen_func=gen_func) 281 | 282 | net.cuda() 283 | net.eval() 284 | 285 | if self.__C.N_GPU > 1: 286 | net = nn.DataParallel(net, device_ids=self.__C.DEVICES) 287 | 288 | net.load_state_dict(state_dict) 289 | 290 | dataloader = Data.DataLoader( 291 | dataset, 292 | batch_size=self.__C.EVAL_BATCH_SIZE, 293 | shuffle=False, 294 | num_workers=self.__C.NUM_WORKERS, 295 | pin_memory=True 296 | ) 297 | 298 | for step, ( 299 | img_feat_iter, 300 | ques_ix_iter, 301 | ans_iter 302 | ) in enumerate(dataloader): 303 | print("\rEvaluation: [step %4d/%4d]" % ( 304 | step, 305 | int(data_size / self.__C.EVAL_BATCH_SIZE), 306 | ), end=' ') 307 | 308 | img_feat_iter = img_feat_iter.cuda() 309 | ques_ix_iter = ques_ix_iter.cuda() 310 | 311 | pred = net( 312 | img_feat_iter, 313 | ques_ix_iter 314 | ) 315 | pred_np = pred.cpu().data.numpy() 316 | pred_argmax = np.argmax(pred_np, axis=1) 317 | 318 | # Save the answer index 319 | if pred_argmax.shape[0] != self.__C.EVAL_BATCH_SIZE: 320 | pred_argmax = np.pad( 321 | pred_argmax, 322 | (0, self.__C.EVAL_BATCH_SIZE - pred_argmax.shape[0]), 323 | 
mode='constant', 324 | constant_values=-1 325 | ) 326 | 327 | ans_ix_list.append(pred_argmax) 328 | 329 | # Save the whole prediction vector 330 | if self.__C.TEST_SAVE_PRED: 331 | if pred_np.shape[0] != self.__C.EVAL_BATCH_SIZE: 332 | pred_np = np.pad( 333 | pred_np, 334 | ((0, self.__C.EVAL_BATCH_SIZE - pred_np.shape[0]), (0, 0)), 335 | mode='constant', 336 | constant_values=-1 337 | ) 338 | 339 | pred_list.append(pred_np) 340 | 341 | print('') 342 | ans_ix_list = np.array(ans_ix_list).reshape(-1) 343 | 344 | result = [{ 345 | 'answer': dataset.ix_to_ans[str(ans_ix_list[qix])], # ix_to_ans(load with json) keys are type of string 346 | 'question_id': int(qid_list[qix]) 347 | }for qix in range(qid_list.__len__())] 348 | 349 | # Write the results to result file 350 | if valid: 351 | if val_ckpt_flag: 352 | result_eval_file = \ 353 | self.__C.CACHE_PATH + \ 354 | 'result_run_' + self.__C.CKPT_VERSION + \ 355 | '.json' 356 | else: 357 | result_eval_file = \ 358 | self.__C.CACHE_PATH + \ 359 | 'result_run_' + self.__C.VERSION + \ 360 | '.json' 361 | 362 | else: 363 | if self.__C.CKPT_PATH is not None: 364 | result_eval_file = \ 365 | self.__C.RESULT_PATH + \ 366 | 'result_run_' + self.__C.CKPT_VERSION + \ 367 | '.json' 368 | else: 369 | result_eval_file = \ 370 | self.__C.RESULT_PATH + \ 371 | 'result_run_' + self.__C.CKPT_VERSION + \ 372 | '_epoch' + str(self.__C.CKPT_EPOCH) + \ 373 | '.json' 374 | 375 | print('Save the result to file: {}'.format(result_eval_file)) 376 | 377 | json.dump(result, open(result_eval_file, 'w')) 378 | 379 | # Save the whole prediction vector 380 | if self.__C.TEST_SAVE_PRED: 381 | 382 | if self.__C.CKPT_PATH is not None: 383 | ensemble_file = \ 384 | self.__C.PRED_PATH + \ 385 | 'result_run_' + self.__C.CKPT_VERSION + \ 386 | '.json' 387 | else: 388 | ensemble_file = \ 389 | self.__C.PRED_PATH + \ 390 | 'result_run_' + self.__C.CKPT_VERSION + \ 391 | '_epoch' + str(self.__C.CKPT_EPOCH) + \ 392 | '.json' 393 | 394 | print('Save the prediction vector to file: {}'.format(ensemble_file)) 395 | 396 | pred_list = np.array(pred_list).reshape(-1, ans_size) 397 | result_pred = [{ 398 | 'pred': pred_list[qix], 399 | 'question_id': int(qid_list[qix]) 400 | }for qix in range(qid_list.__len__())] 401 | 402 | pickle.dump(result_pred, open(ensemble_file, 'wb+'), protocol=-1) 403 | 404 | 405 | # Run validation script 406 | if valid: 407 | # create vqa object and vqaRes object 408 | ques_file_path = self.__C.QUESTION_PATH['val'] 409 | ans_file_path = self.__C.ANSWER_PATH['val'] 410 | 411 | vqa = VQA(ans_file_path, ques_file_path) 412 | vqaRes = vqa.loadRes(result_eval_file, ques_file_path) 413 | 414 | # create vqaEval object by taking vqa and vqaRes 415 | vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 416 | 417 | # evaluate results 418 | """ 419 | If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function 420 | By default it uses all the question ids in annotation file 421 | """ 422 | vqaEval.evaluate() 423 | 424 | # print accuracies 425 | print("\n") 426 | print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 427 | # print("Per Question Type Accuracy is the following:") 428 | # for quesType in vqaEval.accuracy['perQuestionType']: 429 | # print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType])) 430 | # print("\n") 431 | print("Per Answer Type Accuracy is the following:") 432 | for ansType in 
vqaEval.accuracy['perAnswerType']: 433 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 434 | print("\n") 435 | 436 | if val_ckpt_flag: 437 | print('Write to log file: {}'.format( 438 | self.__C.LOG_PATH + 439 | 'log_run_' + self.__C.CKPT_VERSION + '.txt', 440 | 'a+') 441 | ) 442 | 443 | logfile = open( 444 | self.__C.LOG_PATH + 445 | 'log_run_' + self.__C.CKPT_VERSION + '.txt', 446 | 'a+' 447 | ) 448 | 449 | else: 450 | print('Write to log file: {}'.format( 451 | self.__C.LOG_PATH + 452 | 'log_run_' + self.__C.VERSION + '.txt', 453 | 'a+') 454 | ) 455 | 456 | logfile = open( 457 | self.__C.LOG_PATH + 458 | 'log_run_' + self.__C.VERSION + '.txt', 459 | 'a+' 460 | ) 461 | 462 | logfile.write("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 463 | for ansType in vqaEval.accuracy['perAnswerType']: 464 | logfile.write("%s : %.02f " % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 465 | logfile.write("\n\n") 466 | logfile.close() 467 | return vqaEval.accuracy['overall'] 468 | 469 | 470 | def run(self, run_mode): 471 | if run_mode == 'train': 472 | self.empty_log(self.__C.VERSION) 473 | self.train(self.dataset, self.dataset_eval) 474 | 475 | elif run_mode == 'val': 476 | self.eval(self.dataset, valid=True) 477 | 478 | elif run_mode == 'test': 479 | self.eval(self.dataset) 480 | 481 | else: 482 | exit(-1) 483 | 484 | 485 | def empty_log(self, version): 486 | print('Initializing log file ........') 487 | if (os.path.exists(self.__C.LOG_PATH + 'log_run_' + version + '.txt')): 488 | os.remove(self.__C.LOG_PATH + 'log_run_' + version + '.txt') 489 | print('Finished!') 490 | print('') 491 | 492 | 493 | 494 | 495 | -------------------------------------------------------------------------------- /core/model/__pycache__/mca.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/mca.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/mfb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/mfb.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/net.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/net_mfb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/net_mfb.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/net_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/net_utils.cpython-36.pyc 
-------------------------------------------------------------------------------- /core/model/__pycache__/optim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/optim.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/tv2d_layer_2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/tv2d_layer_2.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/tv2d_layer_batch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/tv2d_layer_batch.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/__pycache__/tv2d_numba.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/core/model/__pycache__/tv2d_numba.cpython-36.pyc -------------------------------------------------------------------------------- /core/model/basis_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | class BasisFunctions(object): 6 | def __init__(self): 7 | pass 8 | 9 | def __len__(self): 10 | """Number of basis functions.""" 11 | pass 12 | 13 | def evaluate(self, t): 14 | pass 15 | 16 | def integrate_t2_times_psi(self, a, b): 17 | """Compute integral int_a^b (t**2) * psi(t).""" 18 | pass 19 | 20 | def integrate_t_times_psi(self, a, b): 21 | """Compute integral int_a^b t * psi(t).""" 22 | pass 23 | 24 | def integrate_psi(self, a, b): 25 | """Compute integral int_a^b psi(t).""" 26 | pass 27 | 28 | 29 | class GaussianBasisFunctions(BasisFunctions): 30 | """ 31 | Function phi(t)=Gaussian(t;Mu,Sigma) 32 | Mu and Sigma obtained from the data (probability density function) 33 | self.mu = mu_j 34 | self.sigma = sigma_j 35 | """ 36 | def __init__(self, mu, sigma): 37 | self.mu = mu.unsqueeze(0) #torch.Size([1, N, 2, 1]) 38 | self.sigma = sigma.unsqueeze(0) #torch.Size([1, N, 2, 2]) 39 | 40 | 41 | def __repr__(self): 42 | return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})" 43 | 44 | def __len__(self): 45 | """Number of basis functions.""" 46 | #self.mu=[1,N,2,1] 47 | return self.mu.size(1) 48 | 49 | def _phi(self, t, sigma): 50 | sigma_inv= 1/2. * (sigma.inverse()+ torch.transpose(sigma.inverse(),-1,-2)) #to avoid numerical problems 51 | return 1. / (2. 
* math.pi * ((sigma.det().unsqueeze(2).unsqueeze(3))**(1./2.)) )* torch.exp(-.5 * torch.transpose(t,-1,-2) @ sigma_inv @ t) 52 | 53 | def _integrate_product_of_gaussians(self, Mu, Sigma): 54 | sigma = self.sigma + Sigma #torch.Size([batch, N, 2, 2]) 55 | return self._phi(Mu - self.mu, sigma) 56 | 57 | def evaluate(self, t): 58 | return self._phi((t-self.mu), self.sigma) 59 | 60 | def integrate_t2_times_psi_gaussian(self, Mu, Sigma): 61 | """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t). 62 | """ 63 | 64 | S_tilde = self._integrate_product_of_gaussians(Mu, Sigma) 65 | sigma_tilde = ((1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))) + (1/2. * (self.sigma.inverse() + torch.transpose(self.sigma.inverse(),-1,-2)))) 66 | sigma_tilde=(1/2. * (sigma_tilde.inverse() + torch.transpose(sigma_tilde.inverse(),-1,-2))) 67 | mu_tilde= sigma_tilde @ ((1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))) @ Mu + (1/2. * (self.sigma.inverse() + torch.transpose(self.sigma.inverse(),-1,-2))) @ self.mu) 68 | 69 | return S_tilde * (sigma_tilde + mu_tilde @ torch.transpose(mu_tilde,-2,-1)) 70 | 71 | def integrate_t_times_psi_gaussian(self, Mu, Sigma): 72 | """Compute integral int N(t; Mu, Sigma) * t * psi(t). 73 | """ 74 | S_tilde = self._integrate_product_of_gaussians(Mu, Sigma) 75 | sigma_tilde = ((1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))) + (1/2. * (self.sigma.inverse() + torch.transpose(self.sigma.inverse(),-1,-2)))) 76 | sigma_tilde=(1/2. * (sigma_tilde.inverse() + torch.transpose(sigma_tilde.inverse(),-1,-2))) 77 | mu_tilde= sigma_tilde @ ((1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))) @ Mu + (1/2. * (self.sigma.inverse() + torch.transpose(self.sigma.inverse(),-1,-2))) @ self.mu) 78 | 79 | return S_tilde * mu_tilde 80 | 81 | def integrate_psi_gaussian(self, Mu, Sigma): 82 | """Compute integral int N(t; Mu, Sigma) * psi(t).""" 83 | return self._integrate_product_of_gaussians(Mu, Sigma) 84 | 85 | 86 | # adding this functions for 2D continuous sparsemax 87 | 88 | def sqrtm(self, M): 89 | # M is a 2x2 positive define matrix 90 | # M([batch, N, 2, 2]) 91 | device=M.device 92 | dtype=M.dtype 93 | 94 | s=torch.sqrt(M[:,0,0,0]*M[:,0,1,1]-M[:,0,0,1]*M[:,0,0,1]) 95 | t=torch.sqrt(M[:,0,0,0]+M[:,0,1,1]+2.*s) 96 | identity = torch.eye(2,dtype=dtype, device=device).unsqueeze(0) 97 | batch_identity = identity.repeat(M.size(0), 1, 1).unsqueeze(1) 98 | 99 | return (1./t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)*(M+s.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)*batch_identity)) 100 | 101 | def inv(self, M): 102 | # to avoid numerical problems 103 | return (1/2. * (M.inverse()+ torch.transpose(M.inverse(),-1,-2))) 104 | 105 | def get_radius_parameters(self, theta, mu_tilde, Sigma_tilde): 106 | device=theta.device 107 | dtype=theta.dtype 108 | 109 | inv_Sigma_tilde = self.inv(Sigma_tilde) 110 | a = torch.tensor([[math.cos(theta)], [math.sin(theta)]], dtype=dtype, device=device) 111 | sigma_sq = 1./(torch.transpose(a,-1,-2) @ inv_Sigma_tilde @ a) # [batch, N, 1, 1] 112 | r0 = sigma_sq * torch.transpose(a,-1,-2) @ inv_Sigma_tilde @ mu_tilde # [batch, N, 1, 1] 113 | P = inv_Sigma_tilde - (sigma_sq * inv_Sigma_tilde @ a @ torch.transpose(a,-1,-2) @ inv_Sigma_tilde) # [batch, N, 2, 2] 114 | s_tilde = torch.sqrt(sigma_sq / (2. 
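The closed-form moment integrals above (`integrate_psi_gaussian`, `integrate_t_times_psi_gaussian`, `integrate_t2_times_psi_gaussian`) are instances of the standard product-of-Gaussians identity. As a reference, with psi_j(t) = N(t; mu_j, Sigma_j) the j-th basis function and N(t; mu, Sigma) the attention density (`S_tilde`, `mu_tilde`, `sigma_tilde` in the code correspond to the tilde quantities below):

```latex
\mathcal{N}(t;\mu,\Sigma)\,\mathcal{N}(t;\mu_j,\Sigma_j)
  = \underbrace{\mathcal{N}(\mu;\mu_j,\,\Sigma+\Sigma_j)}_{\tilde S}\;
    \mathcal{N}(t;\tilde\mu,\tilde\Sigma),
\quad
\tilde\Sigma = \bigl(\Sigma^{-1}+\Sigma_j^{-1}\bigr)^{-1},
\quad
\tilde\mu   = \tilde\Sigma\bigl(\Sigma^{-1}\mu+\Sigma_j^{-1}\mu_j\bigr),
```

so the three integrals evaluate to

```latex
\int \mathcal{N}(t;\mu,\Sigma)\,\psi_j(t)\,dt = \tilde S,
\qquad
\int t\,\mathcal{N}(t;\mu,\Sigma)\,\psi_j(t)\,dt = \tilde S\,\tilde\mu,
\qquad
\int t\,t^{\top}\mathcal{N}(t;\mu,\Sigma)\,\psi_j(t)\,dt
  = \tilde S\,\bigl(\tilde\Sigma+\tilde\mu\,\tilde\mu^{\top}\bigr).
```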
* math.pi * Sigma_tilde.det().unsqueeze(-1).unsqueeze(-1))) * torch.exp(-.5 * torch.transpose(mu_tilde,-1,-2) @ P @ mu_tilde) # [batch, N, 1, 1] 115 | sigma=torch.sqrt(sigma_sq) # [batch, N, 1, 1] 116 | 117 | return sigma, r0, s_tilde 118 | 119 | def norm_uni_gaussian(self, t): 120 | # normalized univariate gaussian (mu=0, sigma_sq=1) 121 | return ((1./ (math.sqrt(2.*math.pi))) * torch.exp(-(t ** 2) / 2.)) # [batch, N, 1, 1] 122 | 123 | def integrate_through_radius(self, theta, mu_tilde, Sigma_tilde): 124 | # returns f_theta*s_tilde 125 | # for the forward pass 126 | sigma, r0, s_tilde = self.get_radius_parameters(theta, mu_tilde, Sigma_tilde) 127 | 128 | f_theta = (self.norm_uni_gaussian((1-r0)/sigma)*(2*sigma**3 + sigma*(r0**2 +r0)) + 129 | self.norm_uni_gaussian(-r0/sigma)*(-2.*sigma**3 - sigma*(r0**2 - 1)) - 130 | (torch.erf((1-r0)/(math.sqrt(2.)*sigma)) - torch.erf(-r0/(math.sqrt(2.)*sigma))) * (r0**3 + (3*sigma**2 - 1)*r0)/2 131 | ) 132 | 133 | return (f_theta * s_tilde) 134 | 135 | def integrate_through_radius_t_N(self, theta, mu_tilde, Sigma_tilde, Mu, Sigma): 136 | device=theta.device 137 | dtype=theta.dtype 138 | 139 | a = torch.tensor([[math.cos(theta)], [math.sin(theta)]], dtype=dtype, device=device) 140 | sigma, r0, s_tilde = self.get_radius_parameters(theta, mu_tilde, Sigma_tilde) 141 | lbd = -torch.sqrt(1 / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 142 | 143 | const=torch.sqrt(-2. * lbd) * self.sqrtm(Sigma) @ a * sigma 144 | g_theta= ( (((const*r0) + (Mu*sigma)) * self.norm_uni_gaussian(-r0/sigma)) + 145 | - ((const*(1+r0)+(Mu*sigma)) * self.norm_uni_gaussian((1-r0)/sigma)) + 146 | .5 * (((torch.sqrt(-2. * lbd) * self.sqrtm(Sigma) @ a)*(sigma**2 + r0**2) + Mu*r0) * (torch.erf((1-r0)/(math.sqrt(2)*sigma))-torch.erf(-r0/(math.sqrt(2.)*sigma)))) 147 | ) 148 | 149 | return (g_theta * s_tilde) # [batch, N, 2, 1] 150 | 151 | def integrate_through_radius_ttT_N(self, theta, mu_tilde, Sigma_tilde, Mu, Sigma): 152 | device=theta.device 153 | dtype=theta.dtype 154 | 155 | a = torch.tensor([[math.cos(theta)], [math.sin(theta)]], dtype=dtype, device=device) 156 | sigma, r0, s_tilde = self.get_radius_parameters(theta, mu_tilde, Sigma_tilde) 157 | lbd = -torch.sqrt(1 / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 158 | sqrtm_Sigma=self.sqrtm(Sigma) 159 | 160 | A= (-2.*lbd) * sqrtm_Sigma @ a @ torch.transpose(a,-1,-2) @ torch.transpose(sqrtm_Sigma, -1, -2) 161 | B = torch.sqrt(-2.*lbd) * ((sqrtm_Sigma @ a @ torch.transpose(Mu,-1,-2)) + (Mu @ torch.transpose(a, -1, -2) @ torch.transpose(sqrtm_Sigma, -1, -2)) ) 162 | C= Mu @ torch.transpose(Mu, -1, -2) 163 | A_l = (sigma ** 3) * A 164 | B_l = (sigma ** 2) * (3*r0*A + B) 165 | C_l = sigma * ((3*(r0**2)*A) + (2*r0*B) + C) 166 | D_l = ((r0**3)*A) + ((r0**2)*B) + (r0*C) 167 | 168 | m_theta = ( ( ((2+(-r0/sigma)**2)*A_l - (r0/sigma)*B_l + C_l) * (self.norm_uni_gaussian(-r0/sigma))) - 169 | ( ( (2+((1-r0)/sigma)**2)*A_l + ((1-r0)/sigma)*B_l + C_l) * (self.norm_uni_gaussian((1-r0)/sigma))) + 170 | (.5 * (B_l + D_l) * (torch.erf((1-r0)/(math.sqrt(2)*sigma))-torch.erf(-r0/(math.sqrt(2.)*sigma)))) 171 | ) 172 | 173 | return (m_theta * s_tilde) # [batch, N, 2, 2] 174 | 175 | def integrate_through_radius_N(self, theta, mu_tilde, Sigma_tilde): 176 | sigma, r0, s_tilde = self.get_radius_parameters(theta, mu_tilde, Sigma_tilde) 177 | 178 | h_theta=(sigma * (self.norm_uni_gaussian(-r0/sigma) - self.norm_uni_gaussian((1-r0)/sigma)) + 179 | 
(r0/2)*(torch.erf((1-r0)/(math.sqrt(2)*sigma)) - torch.erf(-r0/(math.sqrt(2.)*sigma))) 180 | ) 181 | 182 | return (h_theta * s_tilde) # [batch, N, 1, 1] 183 | 184 | def integrate_psi(self, Mu, Sigma): 185 | # returns the result for the forward pass 186 | # simple sum with n_points for the numerical integration 187 | lbd = -torch.sqrt(1. / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 188 | mu_tilde= (1./ torch.sqrt(-2.*lbd) * self.inv(self.sqrtm(Sigma))) @ (self.mu-Mu) # [batch, N, 2, 1] 189 | Sigma_tilde= 1. / (-2.*lbd) * self.inv(self.sqrtm(Sigma))@ self.sigma @ self.inv(self.sqrtm(Sigma)) # [batch, N, 2, 2] 190 | 191 | n_points=100 # integrate with 100 points 192 | values=torch.zeros(mu_tilde.size(0),mu_tilde.size(1),n_points,1).cuda() 193 | theta=torch.linspace(0,2.*math.pi,n_points).cuda() 194 | 195 | for i in range(n_points): 196 | values[:,:,i]= (-lbd * self.integrate_through_radius(theta[i], mu_tilde, Sigma_tilde)).squeeze(-1) 197 | 198 | result=(2*math.pi * torch.mean(values, dim=2)).unsqueeze(-1) 199 | 200 | return result # [batch, N, 1, 1] 201 | 202 | def integrate_t_times_psi(self, Mu, Sigma): 203 | # returns the result of the first integral for the backward pass 204 | # simple sum with n_points for the numerical integration 205 | lbd = -torch.sqrt(1 / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 206 | mu_tilde= (1./ torch.sqrt(-2*lbd) * self.inv(self.sqrtm(Sigma))) @ (self.mu-Mu) # [batch, N, 2, 1] 207 | Sigma_tilde= 1. / (-2*lbd) * self.inv(self.sqrtm(Sigma))@ self.sigma @ self.inv(self.sqrtm(Sigma)) # [batch, N, 2, 2] 208 | 209 | n_points=100 # integrate with 100 points 210 | values=torch.zeros(mu_tilde.size(0),mu_tilde.size(1),n_points,2).cuda() 211 | theta=torch.linspace(0,2*math.pi,n_points).cuda() 212 | 213 | for i in range(n_points): 214 | values[:,:,i]= (self.integrate_through_radius_t_N(theta[i], mu_tilde, Sigma_tilde, Mu, Sigma)).reshape([mu_tilde.size(0),mu_tilde.size(1),2]) 215 | 216 | result=(2*math.pi * torch.mean(values, dim=2)).reshape([mu_tilde.size(0),mu_tilde.size(1),2,1]) 217 | 218 | return result # [batch, N, 2, 1] 219 | 220 | def integrate_t2_times_psi(self, Mu, Sigma): 221 | # returns the result of the third integral for the backward pass 222 | # simple sum with n_points for the numerical integration 223 | lbd = -torch.sqrt(1 / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 224 | mu_tilde= (1./ torch.sqrt(-2*lbd) * self.inv(self.sqrtm(Sigma))) @ (self.mu-Mu) # [batch, N, 2, 1] 225 | Sigma_tilde= 1. / (-2*lbd) * self.inv(self.sqrtm(Sigma))@ self.sigma @ self.inv(self.sqrtm(Sigma)) # [batch, N, 2, 2] 226 | 227 | n_points=100 # integrate with 100 points 228 | values=torch.zeros(mu_tilde.size(0),mu_tilde.size(1),n_points,4).cuda() 229 | theta=torch.linspace(0,2*math.pi,n_points).cuda() 230 | 231 | for i in range(n_points): 232 | values[:,:,i]= (self.integrate_through_radius_ttT_N(theta[i], mu_tilde, Sigma_tilde, Mu, Sigma)).reshape([mu_tilde.size(0),mu_tilde.size(1),4]) 233 | 234 | result=(2*math.pi * torch.mean(values, dim=2)).reshape([mu_tilde.size(0),mu_tilde.size(1),2,2]) 235 | 236 | return result # [batch, N, 2, 2] 237 | 238 | def integrate_normal(self, Mu, Sigma): 239 | lbd = -torch.sqrt(1 / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 240 | mu_tilde= (1./ torch.sqrt(-2*lbd) * self.inv(self.sqrtm(Sigma))) @ (self.mu-Mu) # [batch, N, 2, 1] 241 | Sigma_tilde= 1. 
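The numerical routines here (`integrate_psi`, `integrate_t_times_psi`, `integrate_t2_times_psi`, `integrate_normal`) handle the continuous sparsemax case, where the attention density has bounded elliptical support and the integrals no longer reduce to a single closed form. The code first changes variables (the `mu_tilde`, `Sigma_tilde` computed above), evaluates the radial direction in closed form via the `integrate_through_radius*` helpers (expressed with the univariate Gaussian pdf and `erf`), and approximates the remaining angular integral with a simple K-point average:

```latex
\int_{\mathcal{E}} f(t)\,dt
  \;=\; \int_{0}^{2\pi} g(\theta)\,d\theta
  \;\approx\; \frac{2\pi}{K}\sum_{k=1}^{K} g(\theta_k),
\qquad \theta_k \ \text{equally spaced on } [0,2\pi],\ \ K = 100 \ (\texttt{n\_points}).
```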
/ (-2*lbd) * self.inv(self.sqrtm(Sigma))@ self.sigma @ self.inv(self.sqrtm(Sigma)) # [batch, N, 2, 2] 242 | 243 | n_points=100 # integrate with 100 points 244 | values=torch.zeros(mu_tilde.size(0),mu_tilde.size(1),n_points,1).cuda() 245 | theta=torch.linspace(0,2*math.pi,n_points).cuda() 246 | 247 | for i in range(n_points): 248 | values[:,:,i]= (self.integrate_through_radius_N(theta[i], mu_tilde, Sigma_tilde)).squeeze(-1) 249 | 250 | result=(2*math.pi * torch.mean(values, dim=2)).unsqueeze(-1) 251 | 252 | return result # [batch, N, 1, 1] 253 | 254 | def area_ellipse(self, Mu, Sigma): 255 | lbd = -torch.sqrt(1 / (math.pi * torch.sqrt(Sigma.det()))).unsqueeze(-1).unsqueeze(-1) # [batch, 1, 1, 1] 256 | op= self.inv(Sigma) / (-2.*lbd) 257 | area = math.pi / (torch.sqrt(op.det())) 258 | return area # [batch,1] 259 | 260 | def aux(self, Mu, Sigma): 261 | aux=(Mu@torch.transpose(Mu,-1,-2)) + (Sigma/(self.area_ellipse(Mu, Sigma).unsqueeze(-1).unsqueeze(-1))) 262 | return aux # [batch,1,2,2] 263 | -------------------------------------------------------------------------------- /core/model/continuous_softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # torch.autograd.set_detect_anomaly(True) 5 | 6 | 7 | class ContinuousSoftmaxFunction(torch.autograd.Function): 8 | 9 | @classmethod 10 | def _expectation_phi_psi(cls, ctx, Mu, Sigma): 11 | """Compute expectation of phi(t) * psi(t).T under N(mu, sigma_sq).""" 12 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 13 | total_basis = sum(num_basis) 14 | V = torch.zeros((Mu.shape[0], 6, total_basis), dtype=ctx.dtype, device=ctx.device) 15 | offsets = torch.cumsum(torch.IntTensor(num_basis).to(ctx.device), dim=0) 16 | start = 0 17 | for j, basis_functions in enumerate(ctx.psi): 18 | V[:, 0, start:offsets[j]]=basis_functions.integrate_t_times_psi_gaussian(Mu,Sigma).squeeze(-1)[:,:,0] 19 | V[:, 1, start:offsets[j]]=basis_functions.integrate_t_times_psi_gaussian(Mu,Sigma).squeeze(-1)[:,:,1] 20 | V[:, 2, start:offsets[j]]=basis_functions.integrate_t2_times_psi_gaussian(Mu,Sigma)[:,:,0,0] 21 | V[:, 3, start:offsets[j]]=basis_functions.integrate_t2_times_psi_gaussian(Mu,Sigma)[:,:,0,1] 22 | V[:, 4, start:offsets[j]]=basis_functions.integrate_t2_times_psi_gaussian(Mu,Sigma)[:,:,1,0] 23 | V[:, 5, start:offsets[j]]=basis_functions.integrate_t2_times_psi_gaussian(Mu,Sigma)[:,:,1,1] 24 | start = offsets[j] 25 | return V # [batch,6,N] 26 | 27 | 28 | @classmethod 29 | def _expectation_psi(cls, ctx, Mu, Sigma): 30 | """Compute expectation of psi under N(mu, sigma_sq).""" 31 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 32 | total_basis = sum(num_basis) 33 | r = torch.zeros(Mu.shape[0], total_basis, dtype=ctx.dtype, device=ctx.device) 34 | offsets = torch.cumsum(torch.IntTensor(num_basis).to(ctx.device), dim=0) 35 | start = 0 36 | for j, basis_functions in enumerate(ctx.psi): 37 | r[:, start:offsets[j]] = basis_functions.integrate_psi_gaussian(Mu, Sigma).squeeze(-2).squeeze(-1) 38 | start = offsets[j] 39 | return r # [batch,N] 40 | 41 | @classmethod 42 | def _expectation_phi(cls, ctx, Mu, Sigma): 43 | v = torch.zeros(Mu.shape[0], 6, dtype=ctx.dtype, device=ctx.device) 44 | v[:, 0:2]=Mu.squeeze(1).squeeze(-1) 45 | v[:, 2:6]=((Mu @ torch.transpose(Mu,-1,-2)) + Sigma).view(-1,4) 46 | return v # [batch,6] 47 | 48 | 49 | @classmethod 50 | def forward(cls, ctx, theta, psi): 51 | # We assume a Gaussian 52 | # We have: 53 | # Mu:[batch,1,2,1] and 
Sigma:[batch,1,2,2] 54 | #theta=[(Sigma)^-1 @ Mu, -0.5*(Sigma)^-1] 55 | #theta: batch x 6 56 | #phi(t)=[t,tt^t] 57 | #p(t)= Gaussian(t; Mu, Sigma) 58 | 59 | ctx.dtype = theta.dtype 60 | ctx.device = theta.device 61 | ctx.psi = psi 62 | 63 | Sigma=(-2*theta[:,2:6].view(-1,2,2)) 64 | Sigma=(1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))).unsqueeze(1) # torch.Size([batch, 1, 2, 2]) 65 | Mu=(Sigma @ (theta[:,0:2].view(-1,2,1)).unsqueeze(1)) # torch.Size([batch, 1, 2, 1]) 66 | 67 | r=cls._expectation_psi(ctx, Mu, Sigma) 68 | ctx.save_for_backward(Mu, Sigma, r) 69 | return r # [batch, N] 70 | 71 | @classmethod 72 | def backward(cls, ctx, grad_output): 73 | Mu, Sigma, r = ctx.saved_tensors 74 | J = cls._expectation_phi_psi(ctx, Mu, Sigma) # batch,6,N 75 | e_phi = cls._expectation_phi(ctx, Mu, Sigma) # batch,6 76 | e_psi = cls._expectation_psi(ctx, Mu, Sigma) # batch,N 77 | J -= torch.bmm(e_phi.unsqueeze(2), e_psi.unsqueeze(1)) 78 | grad_input = torch.matmul(J, grad_output.unsqueeze(2)).squeeze(2) 79 | return grad_input, None 80 | 81 | class ContinuousSoftmax(nn.Module): 82 | def __init__(self, psi=None): 83 | super(ContinuousSoftmax, self).__init__() 84 | self.psi = psi 85 | 86 | def forward(self, theta): 87 | return ContinuousSoftmaxFunction.apply(theta, self.psi) 88 | -------------------------------------------------------------------------------- /core/model/continuous_sparsemax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # torch.autograd.set_detect_anomaly(True) 5 | 6 | 7 | class ContinuousSparsemaxFunction(torch.autograd.Function): 8 | 9 | @classmethod 10 | def _expectation_phi_psi(cls, ctx, Mu, Sigma): 11 | """Compute expectation of phi(t) * psi(t).T under N(mu, sigma_sq).""" 12 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 13 | total_basis = sum(num_basis) 14 | V = torch.zeros((Mu.shape[0], 6, total_basis), dtype=ctx.dtype, device=ctx.device) 15 | offsets = torch.cumsum(torch.IntTensor(num_basis).to(ctx.device), dim=0) 16 | start = 0 17 | for j, basis_functions in enumerate(ctx.psi): 18 | integral_t_times_psi=(basis_functions.integrate_t_times_psi(Mu,Sigma).squeeze(-1)).to(ctx.device) 19 | integral_t2_times_psi=basis_functions.integrate_t2_times_psi(Mu,Sigma).to(ctx.device) 20 | 21 | V[:, 0, start:offsets[j]]=integral_t_times_psi[:,:,0] 22 | V[:, 1, start:offsets[j]]=integral_t_times_psi[:,:,1] 23 | V[:, 2, start:offsets[j]]=integral_t2_times_psi[:,:,0,0] 24 | V[:, 3, start:offsets[j]]=integral_t2_times_psi[:,:,0,1] 25 | V[:, 4, start:offsets[j]]=integral_t2_times_psi[:,:,1,0] 26 | V[:, 5, start:offsets[j]]=integral_t2_times_psi[:,:,1,1] 27 | start = offsets[j] 28 | return V # [batch,6,N] 29 | 30 | 31 | @classmethod 32 | def _expectation_psi(cls, ctx, Mu, Sigma): 33 | """Compute expectation of psi under N(mu, sigma_sq).""" 34 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 35 | total_basis = sum(num_basis) 36 | r = torch.zeros(Mu.shape[0], total_basis, dtype=ctx.dtype, device=ctx.device) 37 | offsets = torch.cumsum(torch.IntTensor(num_basis).to(ctx.device), dim=0) 38 | start = 0 39 | for j, basis_functions in enumerate(ctx.psi): 40 | r[:, start:offsets[j]] = basis_functions.integrate_psi(Mu, Sigma).squeeze(-2).squeeze(-1) 41 | start = offsets[j] 42 | return r # [batch,N] 43 | 44 | 45 | @classmethod 46 | def _expectation_phi(cls, ctx, Mu, Sigma): 47 | 48 | num_basis = [len(basis_functions) for basis_functions in ctx.psi] 49 | total_basis = 
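In both `ContinuousSoftmaxFunction` and `ContinuousSparsemaxFunction`, the `forward` pass receives the score function through its canonical (natural) parameters `theta` and first recovers the Gaussian parameters. With sufficient statistics phi(t) = (t, t t^T):

```latex
p(t) \;\propto\; \exp\!\bigl(\theta_1^{\top} t + t^{\top}\Theta_2\, t\bigr)
   \;=\; \mathcal{N}(t;\mu,\Sigma),
\qquad
\theta_1 = \Sigma^{-1}\mu,
\qquad
\Theta_2 = -\tfrac{1}{2}\Sigma^{-1},
```

which the code inverts as Sigma = (-2 Theta_2)^{-1} (symmetrized for numerical stability) and Mu = Sigma theta_1. For the continuous softmax, the `backward` pass above then uses the exponential-family identity for the Jacobian of r = E_p[psi(t)]:

```latex
\frac{\partial r}{\partial \theta}
  \;=\; \mathbb{E}_{p}\!\bigl[\phi(t)\,\psi(t)^{\top}\bigr]
   \;-\; \mathbb{E}_{p}[\phi(t)]\;\mathbb{E}_{p}[\psi(t)]^{\top},
```

i.e. the covariance between phi and the basis functions, which is exactly the `J -= torch.bmm(...)` correction in the code.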
sum(num_basis) 50 | v = torch.zeros((Mu.shape[0], 6, total_basis), dtype=ctx.dtype, device=ctx.device) 51 | offsets = torch.cumsum(torch.IntTensor(num_basis).to(ctx.device), dim=0) 52 | start = 0 53 | 54 | for j, basis_functions in enumerate(ctx.psi): 55 | integral_normal=basis_functions.integrate_normal(Mu, Sigma).to(ctx.device) # [batch, N, 1, 1] 56 | aux=(basis_functions.aux(Mu, Sigma)).to(ctx.device) 57 | 58 | v[:, 0, start:offsets[j]]=Mu.squeeze(-1)[:,:,0] * integral_normal.squeeze(-1).squeeze(-1) 59 | v[:, 1, start:offsets[j]]=Mu.squeeze(-1)[:,:,1] * integral_normal.squeeze(-1).squeeze(-1) 60 | v[:, 2, start:offsets[j]]=(aux * integral_normal)[:,:,0,0] 61 | v[:, 3, start:offsets[j]]=(aux * integral_normal)[:,:,0,1] 62 | v[:, 4, start:offsets[j]]=(aux * integral_normal)[:,:,1,0] 63 | v[:, 5, start:offsets[j]]=(aux * integral_normal)[:,:,1,1] 64 | start = offsets[j] 65 | return v # [batch,6,N] 66 | 67 | 68 | @classmethod 69 | def forward(cls, ctx, theta, psi): 70 | # We assume a Gaussian 71 | # We have: 72 | # Mu:[batch,1,2,1] and Sigma:[batch,1,2,2] 73 | #theta=[(Sigma)^-1 @ Mu, -0.5*(Sigma)^-1] 74 | #theta: batch x 6 75 | #phi(t)=[t,tt^t] 76 | #p(t)= Gaussian(t; Mu, Sigma) 77 | 78 | ctx.dtype = theta.dtype 79 | ctx.device = theta.device 80 | ctx.psi = psi 81 | 82 | Sigma=(-2*theta[:,2:6].view(-1,2,2)) 83 | Sigma=(1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))).unsqueeze(1) # torch.Size([batch, 1, 2, 2]) 84 | Mu=(Sigma @ (theta[:,0:2].view(-1,2,1)).unsqueeze(1)) # torch.Size([batch, 1, 2, 1]) 85 | 86 | r=cls._expectation_psi(ctx, Mu, Sigma) 87 | ctx.save_for_backward(Mu, Sigma, r) 88 | return r # [batch, N] 89 | 90 | @classmethod 91 | def backward(cls, ctx, grad_output): 92 | Mu, Sigma, r = ctx.saved_tensors 93 | J = cls._expectation_phi_psi(ctx, Mu, Sigma) - cls._expectation_phi(ctx, Mu, Sigma) # batch,6,N 94 | grad_input = torch.matmul(J, grad_output.unsqueeze(2)).squeeze(2) 95 | return grad_input, None 96 | 97 | class ContinuousSparsemax(nn.Module): 98 | def __init__(self, psi=None): 99 | super(ContinuousSparsemax, self).__init__() 100 | self.psi = psi 101 | 102 | def forward(self, theta): 103 | return ContinuousSparsemaxFunction.apply(theta, self.psi) 104 | -------------------------------------------------------------------------------- /core/model/mca.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | from core.model.net_utils import FC, MLP, LayerNorm 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch, math 12 | 13 | from functools import partial 14 | from core.model.tv2d_layer_2 import TV2DFunction 15 | from entmax import sparsemax 16 | 17 | 18 | # ------------------------------ 19 | # ---- Multi-Head Attention ---- 20 | # ------------------------------ 21 | 22 | class MHAtt(nn.Module): 23 | def __init__(self, __C, gen_func=torch.softmax): 24 | super(MHAtt, self).__init__() 25 | self.__C = __C 26 | 27 | self.linear_v = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 28 | self.linear_k = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 29 | self.linear_q = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 30 | self.linear_merge = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 31 | 32 | self.dropout = 
nn.Dropout(__C.DROPOUT_R) 33 | 34 | if str(gen_func)=='tvmax': 35 | self.gen_func='tvmax' 36 | self.sparsemax = partial(sparsemax, k=512) 37 | self.tvmax = TV2DFunction.apply 38 | else: 39 | self.gen_func = gen_func 40 | 41 | 42 | 43 | def forward(self, v, k, q, mask): 44 | n_batches = q.size(0) 45 | 46 | v = self.linear_v(v).view(n_batches,-1,self.__C.MULTI_HEAD,self.__C.HIDDEN_SIZE_HEAD).transpose(1, 2) 47 | 48 | k = self.linear_k(k).view(n_batches,-1,self.__C.MULTI_HEAD,self.__C.HIDDEN_SIZE_HEAD).transpose(1, 2) 49 | 50 | q = self.linear_q(q).view(n_batches,-1,self.__C.MULTI_HEAD,self.__C.HIDDEN_SIZE_HEAD).transpose(1, 2) 51 | 52 | atted = self.att(v, k, q, mask) 53 | atted = atted.transpose(1, 2).contiguous().view(n_batches,-1,self.__C.HIDDEN_SIZE) 54 | 55 | atted = self.linear_merge(atted) 56 | 57 | return atted 58 | 59 | def att(self, value, key, query, mask): 60 | 61 | d_k = query.size(-1) 62 | 63 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 64 | if mask is not None: 65 | scores = scores.masked_fill(mask, -1e9) 66 | aux=0 67 | 68 | if str(self.gen_func)=='tvmax': 69 | att_map = self.sparsemax(scores, dim=-1) 70 | else: 71 | att_map = self.gen_func(scores, dim=-1) 72 | 73 | att_map = self.dropout(att_map) 74 | return torch.matmul(att_map, value) 75 | 76 | 77 | # --------------------------- 78 | # ---- Feed Forward Nets ---- 79 | # --------------------------- 80 | 81 | class FFN(nn.Module): 82 | def __init__(self, __C): 83 | super(FFN, self).__init__() 84 | 85 | self.mlp = MLP( 86 | in_size=__C.HIDDEN_SIZE, 87 | mid_size=__C.FF_SIZE, 88 | out_size=__C.HIDDEN_SIZE, 89 | dropout_r=__C.DROPOUT_R, 90 | use_relu=True) 91 | 92 | def forward(self, x): 93 | return self.mlp(x) 94 | 95 | 96 | # ------------------------ 97 | # ---- Self Attention ---- 98 | # ------------------------ 99 | 100 | class SA(nn.Module): 101 | def __init__(self, __C,gen_func=torch.softmax): 102 | super(SA, self).__init__() 103 | 104 | self.mhatt = MHAtt(__C,gen_func=gen_func) 105 | self.ffn = FFN(__C) 106 | 107 | self.dropout1 = nn.Dropout(__C.DROPOUT_R) 108 | self.norm1 = LayerNorm(__C.HIDDEN_SIZE) 109 | 110 | self.dropout2 = nn.Dropout(__C.DROPOUT_R) 111 | self.norm2 = LayerNorm(__C.HIDDEN_SIZE) 112 | 113 | def forward(self, x, x_mask): 114 | x = self.norm1(x + self.dropout1(self.mhatt(x, x, x, x_mask))) 115 | 116 | x = self.norm2(x + self.dropout2(self.ffn(x))) 117 | 118 | return x 119 | 120 | 121 | # ------------------------------- 122 | # ---- Self Guided Attention ---- 123 | # ------------------------------- 124 | 125 | class SGA(nn.Module): 126 | def __init__(self, __C,gen_func=torch.softmax): 127 | super(SGA, self).__init__() 128 | 129 | self.mhatt1 = MHAtt(__C,gen_func=gen_func) 130 | self.mhatt2 = MHAtt(__C) 131 | self.ffn = FFN(__C) 132 | 133 | self.dropout1 = nn.Dropout(__C.DROPOUT_R) 134 | self.norm1 = LayerNorm(__C.HIDDEN_SIZE) 135 | 136 | self.dropout2 = nn.Dropout(__C.DROPOUT_R) 137 | self.norm2 = LayerNorm(__C.HIDDEN_SIZE) 138 | 139 | self.dropout3 = nn.Dropout(__C.DROPOUT_R) 140 | self.norm3 = LayerNorm(__C.HIDDEN_SIZE) 141 | 142 | def forward(self, x, y, x_mask, y_mask): 143 | y = self.norm1(y + self.dropout1(self.mhatt1(y, y, y, y_mask))) 144 | y = self.norm2(y + self.dropout2(self.mhatt2(x, x, y, x_mask))) 145 | y = self.norm3(y + self.dropout3(self.ffn(y))) 146 | 147 | return y 148 | 149 | 150 | # ------------------------------------------------ 151 | # ---- MAC Layers Cascaded by Encoder-Decoder ---- 152 | # ------------------------------------------------ 153 | 154 | 
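Right before the encoder-decoder cascade (`MCA_ED`, below), a minimal shape-check sketch for the `SA` and `SGA` blocks defined above. The config is a hypothetical stand-in for `cfgs/*.yml` (only the fields these modules read), chosen so that `HIDDEN_SIZE == MULTI_HEAD * HIDDEN_SIZE_HEAD` as the attention reshaping assumes:

```python
import torch
from types import SimpleNamespace
from core.model.mca import SA, SGA

# hypothetical config values; the real ones come from the yml configs
__C = SimpleNamespace(HIDDEN_SIZE=512, MULTI_HEAD=8, HIDDEN_SIZE_HEAD=64,
                      DROPOUT_R=0.1, FF_SIZE=2048)

x = torch.randn(2, 14, 512)               # e.g. 14 question tokens
y = torch.randn(2, 196, 512)              # 14x14 = 196 image grid features
x_mask = torch.zeros(2, 1, 1, 14).bool()  # True = masked out
y_mask = torch.zeros(2, 1, 1, 196).bool()

sa, sga = SA(__C), SGA(__C)
print(sa(x, x_mask).shape)               # torch.Size([2, 14, 512])
print(sga(x, y, x_mask, y_mask).shape)   # torch.Size([2, 196, 512])
```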
class MCA_ED(nn.Module): 155 | def __init__(self, __C, gen_func=torch.softmax): 156 | super(MCA_ED, self).__init__() 157 | 158 | self.enc_list = nn.ModuleList([SA(__C, gen_func=torch.softmax) for _ in range(__C.LAYER)]) 159 | self.dec_list = nn.ModuleList([SGA(__C, gen_func=gen_func) for _ in range(__C.LAYER)]) 160 | 161 | def forward(self, x, y, x_mask, y_mask): 162 | # Get hidden vector 163 | 164 | for enc in self.enc_list: 165 | x = enc(x, x_mask) 166 | 167 | 168 | for dec in self.dec_list: 169 | y = dec(x, y, x_mask, y_mask) 170 | 171 | return x, y 172 | -------------------------------------------------------------------------------- /core/model/mfb.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Pengbing Gao https://github.com/nbgao 5 | # -------------------------------------------------------- 6 | 7 | from core.model.net_utils import MLP 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from functools import partial 12 | from core.model.tv2d_layer_2 import TV2DFunction 13 | from entmax import sparsemax 14 | 15 | # ------------------------------------------------------------- 16 | # ---- Multi-Model Hign-order Bilinear Pooling Co-Attention---- 17 | # ------------------------------------------------------------- 18 | 19 | 20 | class MFB(nn.Module): 21 | def __init__(self, __C, img_feat_size, ques_feat_size, is_first): 22 | super(MFB, self).__init__() 23 | self.__C = __C 24 | self.is_first = is_first 25 | self.proj_i = nn.Linear(img_feat_size, __C.MFB_K * __C.MFB_O) 26 | self.proj_q = nn.Linear(ques_feat_size, __C.MFB_K * __C.MFB_O) 27 | self.dropout = nn.Dropout(__C.DROPOUT_R) 28 | self.pool = nn.AvgPool1d(__C.MFB_K, stride=__C.MFB_K) 29 | 30 | def forward(self, img_feat, ques_feat, exp_in=1): 31 | ''' 32 | img_feat.size() -> (N, C, img_feat_size) C = 1 or 100 33 | ques_feat.size() -> (N, 1, ques_feat_size) 34 | z.size() -> (N, C, MFB_O) 35 | exp_out.size() -> (N, C, K*O) 36 | ''' 37 | batch_size = img_feat.shape[0] 38 | img_feat = self.proj_i(img_feat) # (N, C, K*O) 39 | ques_feat = self.proj_q(ques_feat) # (N, 1, K*O) 40 | 41 | exp_out = img_feat * ques_feat # (N, C, K*O) 42 | exp_out = self.dropout(exp_out) if self.is_first else self.dropout(exp_out * exp_in) # (N, C, K*O) 43 | z = self.pool(exp_out) * self.__C.MFB_K # (N, C, O) 44 | z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z)) 45 | z = F.normalize(z.view(batch_size, -1)) # (N, C*O) 46 | z = z.view(batch_size, -1, self.__C.MFB_O) # (N, C, O) 47 | return z, exp_out 48 | 49 | 50 | class QAtt(nn.Module): 51 | def __init__(self, __C): 52 | super(QAtt, self).__init__() 53 | self.__C = __C 54 | self.mlp = MLP( 55 | in_size=__C.LSTM_OUT_SIZE, 56 | mid_size=__C.HIDDEN_SIZE, 57 | out_size=__C.Q_GLIMPSES, 58 | dropout_r=__C.DROPOUT_R, 59 | use_relu=True 60 | ) 61 | 62 | def forward(self, ques_feat): 63 | ''' 64 | ques_feat.size() -> (N, T, LSTM_OUT_SIZE) 65 | qatt_feat.size() -> (N, LSTM_OUT_SIZE * Q_GLIMPSES) 66 | ''' 67 | qatt_maps = self.mlp(ques_feat) # (N, T, Q_GLIMPSES) 68 | qatt_maps = F.softmax(qatt_maps, dim=1) # (N, T, Q_GLIMPSES) 69 | 70 | qatt_feat_list = [] 71 | for i in range(self.__C.Q_GLIMPSES): 72 | mask = qatt_maps[:, :, i:i + 1] # (N, T, 1) 73 | mask = mask * ques_feat # (N, T, LSTM_OUT_SIZE) 74 | mask = torch.sum(mask, dim=1) # (N, LSTM_OUT_SIZE) 75 | qatt_feat_list.append(mask) 76 | 
qatt_feat = torch.cat(qatt_feat_list, dim=1) # (N, LSTM_OUT_SIZE*Q_GLIMPSES) 77 | 78 | return qatt_feat 79 | 80 | 81 | class IAtt(nn.Module): 82 | def __init__(self, __C, img_feat_size, ques_att_feat_size, gen_func): 83 | super(IAtt, self).__init__() 84 | self.__C = __C 85 | self.dropout = nn.Dropout(__C.DROPOUT_R) 86 | self.mfb = MFB(__C, img_feat_size, ques_att_feat_size, True) 87 | self.mlp = MLP( 88 | in_size=__C.MFB_O, 89 | mid_size=__C.HIDDEN_SIZE, 90 | out_size=__C.I_GLIMPSES, 91 | dropout_r=__C.DROPOUT_R, 92 | use_relu=True) 93 | 94 | if str(gen_func)=='tvmax': 95 | self.gen_func='tvmax' 96 | self.sparsemax = partial(sparsemax, k=512) 97 | self.tvmax = TV2DFunction.apply 98 | else: 99 | self.gen_func = gen_func 100 | 101 | def forward(self, img_feat, ques_att_feat): 102 | ''' 103 | img_feats.size() -> (N, C, FRCN_FEAT_SIZE) 104 | ques_att_feat.size() -> (N, LSTM_OUT_SIZE * Q_GLIMPSES) 105 | iatt_feat.size() -> (N, MFB_O * I_GLIMPSES) 106 | ''' 107 | ques_att_feat = ques_att_feat.unsqueeze(1) # (N, 1, LSTM_OUT_SIZE * Q_GLIMPSES) 108 | img_feat = self.dropout(img_feat) 109 | z, _ = self.mfb(img_feat, ques_att_feat) # (N, C, O) 110 | 111 | iatt_maps = self.mlp(z) # (N, C, I_GLIMPSES) 112 | 113 | if str(self.gen_func)=='tvmax': 114 | iatt_maps = iatt_maps.transpose(1,2) 115 | 116 | for i in range(iatt_maps.size(0)): 117 | for j in range(iatt_maps.size(1)): 118 | iatt_maps[i,j] = self.tvmax(iatt_maps[i,j].view(14,14)).view(14*14) 119 | iatt_maps = iatt_maps.transpose(1,2) 120 | 121 | iatt_maps = self.sparsemax(iatt_maps,dim=1) 122 | else: 123 | iatt_maps = self.gen_func(iatt_maps, dim=1) # (N, C, I_GLIMPSES) 124 | 125 | iatt_feat_list = [] 126 | for i in range(self.__C.I_GLIMPSES): 127 | mask = iatt_maps[:, :, i:i + 1] # (N, C, 1) 128 | mask = mask * img_feat # (N, C, FRCN_FEAT_SIZE) 129 | mask = torch.sum(mask, dim=1) # (N, FRCN_FEAT_SIZE) 130 | iatt_feat_list.append(mask) 131 | iatt_feat = torch.cat(iatt_feat_list, dim=1) # (N, FRCN_FEAT_SIZE*I_GLIMPSES) 132 | 133 | return iatt_feat 134 | 135 | 136 | class CoAtt(nn.Module): 137 | def __init__(self, __C, gen_func=torch.softmax): 138 | super(CoAtt, self).__init__() 139 | self.__C = __C 140 | 141 | img_feat_size = __C.HIDDEN_SIZE 142 | img_att_feat_size = img_feat_size * __C.I_GLIMPSES 143 | ques_att_feat_size = __C.LSTM_OUT_SIZE * __C.Q_GLIMPSES 144 | 145 | self.q_att = QAtt(__C) 146 | self.i_att = IAtt(__C, img_feat_size, ques_att_feat_size, gen_func) 147 | 148 | if self.__C.HIGH_ORDER: # MFH 149 | self.mfh1 = MFB(__C, img_att_feat_size, ques_att_feat_size, True) 150 | self.mfh2 = MFB(__C, img_att_feat_size, ques_att_feat_size, False) 151 | else: # MFB 152 | self.mfb = MFB(__C, img_att_feat_size, ques_att_feat_size, True) 153 | 154 | def forward(self, img_feat, ques_feat): 155 | ''' 156 | img_feat.size() -> (N, C, FRCN_FEAT_SIZE) 157 | ques_feat.size() -> (N, T, LSTM_OUT_SIZE) 158 | z.size() -> MFH:(N, 2*O) / MFB:(N, O) 159 | ''' 160 | ques_feat = self.q_att(ques_feat) # (N, LSTM_OUT_SIZE*Q_GLIMPSES) 161 | fuse_feat = self.i_att(img_feat, ques_feat) # (N, FRCN_FEAT_SIZE*I_GLIMPSES) 162 | 163 | if self.__C.HIGH_ORDER: # MFH 164 | z1, exp1 = self.mfh1(fuse_feat.unsqueeze(1), ques_feat.unsqueeze(1)) # z1:(N, 1, O) exp1:(N, C, K*O) 165 | z2, _ = self.mfh2(fuse_feat.unsqueeze(1), ques_feat.unsqueeze(1), exp1) # z2:(N, 1, O) _:(N, C, K*O) 166 | z = torch.cat((z1.squeeze(1), z2.squeeze(1)), 1) # (N, 2*O) 167 | else: # MFB 168 | z, _ = self.mfb(fuse_feat.unsqueeze(1), ques_feat.unsqueeze(1)) # z:(N, 1, O) _:(N, C, K*O) 169 | z = z.squeeze(1) # 
(N, O) 170 | 171 | return z 172 | -------------------------------------------------------------------------------- /core/model/net.py: -------------------------------------------------------------------------------- 1 | from core.model.net_utils import FC, MLP, LayerNorm 2 | from core.model.mca import MCA_ED 3 | 4 | import numpy as np 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch 9 | from core.model.tv2d_layer_2 import TV2DFunction 10 | from entmax import sparsemax 11 | from functools import partial 12 | from torch import Tensor 13 | 14 | from core.model.basis_functions import GaussianBasisFunctions 15 | from core.model.continuous_sparsemax import ContinuousSparsemax 16 | from core.model.continuous_softmax import ContinuousSoftmax 17 | import math 18 | 19 | 20 | # -------------------------------------------------------------- 21 | # ---- Flatten the sequence (image in continuous attention) ---- 22 | # -------------------------------------------------------------- 23 | 24 | class AttFlat(nn.Module): 25 | def __init__(self, __C, gen_func=torch.softmax): 26 | super(AttFlat, self).__init__() 27 | self.__C = __C 28 | 29 | self.attention=__C.attention 30 | self.gen_func=gen_func 31 | 32 | if str(gen_func)=='tvmax': 33 | self.sparsemax = partial(sparsemax, k=512) 34 | self.tvmax = TV2DFunction.apply 35 | 36 | self.mlp = MLP( 37 | in_size=__C.HIDDEN_SIZE, 38 | mid_size=__C.FLAT_MLP_SIZE, 39 | out_size=__C.FLAT_GLIMPSES, 40 | dropout_r=__C.DROPOUT_R, 41 | use_relu=True) 42 | 43 | self.linear_merge = nn.Linear(__C.HIDDEN_SIZE * __C.FLAT_GLIMPSES,__C.FLAT_OUT_SIZE) 44 | 45 | if (self.attention=='cont-sparsemax'): 46 | self.transform = ContinuousSparsemax(psi=None) # use basis functions in 'psi' to define continuous sparsemax 47 | else: 48 | self.transform = ContinuousSoftmax(psi=None) # use basis functions in 'psi' to define continuous softmax 49 | 50 | device='cuda' 51 | 52 | # compute F and G offline for one length = 14*14 = 196 53 | self.Gs = [None] 54 | self.psi = [None] 55 | max_seq_len=14*14 # 196 grid features 56 | attn_num_basis=100 # 100 basis functions 57 | nb_waves=attn_num_basis 58 | self.psi.append([]) 59 | self.add_gaussian_basis_functions(self.psi[1],nb_waves,device=device) 60 | 61 | 62 | # stack basis functions 63 | padding=True 64 | length=max_seq_len 65 | if padding: 66 | shift=1/float(2*math.sqrt(length)) 67 | positions_x = torch.linspace(-0.5+shift, 1.5-shift, int(2*math.sqrt(length))) 68 | positions_x, positions_y=torch.meshgrid(positions_x,positions_x) 69 | positions_x=positions_x.flatten() 70 | positions_y=positions_y.flatten() 71 | else: 72 | shift = 1 / float(2*math.sqrt(length)) 73 | positions_x = torch.linspace(shift, 1-shift, int(math.sqrt(length))) 74 | positions_x, positions_y=torch.meshgrid(positions_x,positions_x) 75 | positions_x=positions_x.flatten() 76 | positions_y=positions_y.flatten() 77 | 78 | positions=torch.zeros(len(positions_x),2,1).to(device) 79 | for position in range(1,len(positions_x)+1): 80 | positions[position-1]=torch.tensor([[positions_x[position-1]],[positions_y[position-1]]]) 81 | 82 | F = torch.zeros(nb_waves, positions.size(0)).unsqueeze(2).unsqueeze(3).to(device) # torch.Size([N, 196, 1, 1]) 83 | # print(positions.size()) # torch.Size([196, 2, 1]) 84 | basis_functions = self.psi[1][0] 85 | # print(basis_functions.evaluate(positions[0]).size()) # torch.Size([N, 1, 1]) 86 | 87 | for i in range(0,positions.size(0)): 88 | F[:,i]=basis_functions.evaluate(positions[i])[:] 89 | 90 | penalty = .01 # Ridge penalty 91 
| I = torch.eye(nb_waves).to(device) 92 | F=F.squeeze(-2).squeeze(-1) # torch.Size([N, 196]) 93 | G = F.t().matmul((F.matmul(F.t()) + penalty * I).inverse()) # torch.Size([196, N]) 94 | if padding: 95 | G = G[length:-length, :] 96 | G=torch.cat([G[7:21,:],G[35:49,:],G[63:77,:],G[91:105,:],G[119:133,:],G[147:161,:],G[175:189,:],G[203:217,:],G[231:245,:],G[259:273,:],G[287:301,:],G[315:329,:],G[343:357,:],G[371:385,:]]) 97 | 98 | self.Gs.append(G.to(device)) 99 | 100 | def add_gaussian_basis_functions(self, psi, nb_basis, device): 101 | 102 | steps=int(math.sqrt(nb_basis)) 103 | 104 | mu_x=torch.linspace(0,1,steps) 105 | mu_y=torch.linspace(0,1,steps) 106 | mux,muy=torch.meshgrid(mu_x,mu_y) 107 | mux=mux.flatten() 108 | muy=muy.flatten() 109 | 110 | mus=[] 111 | for mu in range(1,nb_basis+1): 112 | mus.append([[mux[mu-1]],[muy[mu-1]]]) 113 | mus=torch.tensor(mus).to(device) 114 | 115 | sigmas=[] 116 | for sigma in range(1,nb_basis+1): 117 | sigmas.append([[0.001,0.],[0.,0.001]]) # it is possible to change this matrix 118 | sigmas=torch.tensor(sigmas).to(device) # in continuous softmax we have sigmas=torch.DoubleTensor(sigmas).to(device) 119 | 120 | assert mus.size(0) == nb_basis 121 | psi.append(GaussianBasisFunctions(mu=mus, sigma=sigmas)) 122 | 123 | def value_function(self, values, mask=None): 124 | # Approximate B * F = values via multivariate regression. 125 | # Use a ridge penalty. The solution is B = values * G 126 | # x:(batch,L,D) 127 | G = self.Gs[1] 128 | B = torch.transpose(values,-1,-2) @ G 129 | return B 130 | 131 | def forward(self, x, x_mask): 132 | att = self.mlp(x) 133 | att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2),-1e9) 134 | 135 | if str(self.gen_func)=='tvmax': 136 | att = att.squeeze(-1).view(-1,14,14) 137 | for i in range(att.size(0)): 138 | att[i] = self.tvmax(att[i]) 139 | att = self.sparsemax(att.view(-1,14*14)).unsqueeze(-1) 140 | 141 | else: 142 | att = self.gen_func(att.squeeze(-1), dim=-1).unsqueeze(-1) 143 | 144 | # compute distribution parameters 145 | max_seq_len=196 146 | length=max_seq_len 147 | 148 | positions_x = torch.linspace(0., 1., int(math.sqrt(length))) 149 | positions_x, positions_y=torch.meshgrid(positions_x,positions_x) 150 | positions_x=positions_x.flatten() 151 | positions_y=positions_y.flatten() 152 | positions=torch.zeros(len(positions_x),2,1).to(x.device) 153 | for position in range(1,len(positions_x)+1): 154 | positions[position-1]=torch.tensor([[positions_x[position-1]],[positions_y[position-1]]]) 155 | 156 | # positions: (196, 2, 1) 157 | # positions.unsqueeze(0): (1, 196, 2, 1) 158 | # att.unsqueeze(-1): (batch, 196, 1, 1) 159 | Mu= torch.sum(positions.unsqueeze(0) @ att.unsqueeze(-1), 1) # (batch, 2, 1) 160 | Sigma=torch.sum(((positions @ torch.transpose(positions,-1,-2)).unsqueeze(0) * att.unsqueeze(-1)),1) - (Mu @ torch.transpose(Mu,-1,-2)) # (batch, 2, 2) 161 | Sigma=Sigma + (torch.tensor([[1.,0.],[0.,1.]])*1e-6).to(x.device) # to avoid problems with small values 162 | 163 | 164 | if (self.attention=='cont-sparsemax'): 165 | Sigma=9.*math.pi*torch.sqrt(Sigma.det().unsqueeze(-1).unsqueeze(-1))*Sigma 166 | 167 | # get `mu` and `sigma` as the canonical parameters `theta` 168 | theta1 = ((1/2. * (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2))) @ Mu).flatten(1) 169 | theta2 = (-1. / 2. * (1/2. 
* (Sigma.inverse() + torch.transpose(Sigma.inverse(),-1,-2)))).flatten(1) 170 | theta = torch.zeros(x.size(0), 6, device=x.device ) #torch.Size([batch, 6]) 171 | theta[:,0:2]=theta1 172 | theta[:,2:6]=theta2 173 | 174 | # map to a probability density over basis functions 175 | self.transform.psi = self.psi[1] 176 | r = self.transform(theta) # batch x nb_basis 177 | 178 | # compute B using a multivariate regression 179 | # batch x D x N 180 | B = self.value_function(x, mask=None) 181 | 182 | # (bs, nb_basis) -> (bs, 1, nb_basis) 183 | r = r.unsqueeze(1) # batch x 1 x nb_basis 184 | 185 | # (bs, hdim, nb_basis) * (bs, nb_basis, 1) -> (bs, hdim, 1) 186 | # get the context vector 187 | # batch x values_size x 1 188 | context = torch.matmul(B, r.transpose(-1, -2)) 189 | context = context.transpose(-1, -2) # batch x 1 x values_size 190 | 191 | att_list = [] 192 | for i in range(self.__C.FLAT_GLIMPSES): 193 | att_list.append(torch.sum(att[:, :, i: i + 1] * x, dim=1)) 194 | 195 | x_atted = torch.cat(att_list, dim=1) # don't need this for continuous attention 196 | 197 | x_atted=context.squeeze(1) # for continuous softmax/sparsemax 198 | 199 | x_atted = self.linear_merge(x_atted) # linear_merge is used to compute Wx 200 | return x_atted 201 | 202 | 203 | 204 | 205 | 206 | # ---------------------------------------------------------------- 207 | # ---- Flatten the sequence (question and discrete attention) ---- 208 | # ---------------------------------------------------------------- 209 | # this is also used to flatten the image features with discrete attention 210 | class AttFlatText(nn.Module): 211 | def __init__(self, __C, gen_func=torch.softmax): 212 | super(AttFlatText, self).__init__() 213 | self.__C = __C 214 | 215 | self.gen_func=gen_func 216 | 217 | if str(gen_func)=='tvmax': 218 | self.sparsemax = partial(sparsemax, k=512) 219 | self.tvmax = TV2DFunction.apply 220 | 221 | self.mlp = MLP( 222 | in_size=__C.HIDDEN_SIZE, 223 | mid_size=__C.FLAT_MLP_SIZE, 224 | out_size=__C.FLAT_GLIMPSES, 225 | dropout_r=__C.DROPOUT_R, 226 | use_relu=True) 227 | 228 | self.linear_merge = nn.Linear(__C.HIDDEN_SIZE * __C.FLAT_GLIMPSES,__C.FLAT_OUT_SIZE) 229 | 230 | def forward(self, x, x_mask): 231 | att = self.mlp(x) 232 | att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2),-1e9) 233 | 234 | if str(self.gen_func)=='tvmax': 235 | att = att.squeeze(-1).view(-1,14,14) 236 | for i in range(att.size(0)): 237 | att[i] = self.tvmax(att[i]) 238 | att = self.sparsemax(att.view(-1,14*14)).unsqueeze(-1) 239 | 240 | else: 241 | att = self.gen_func(att.squeeze(-1), dim=-1).unsqueeze(-1) 242 | att_list = [] 243 | for i in range(self.__C.FLAT_GLIMPSES): 244 | att_list.append(torch.sum(att[:, :, i: i + 1] * x, dim=1)) 245 | 246 | x_atted = torch.cat(att_list, dim=1) 247 | x_atted = self.linear_merge(x_atted) 248 | return x_atted 249 | 250 | 251 | # ------------------------- 252 | # ---- Main MCAN Model ---- 253 | # ------------------------- 254 | 255 | class Net(nn.Module): 256 | def __init__(self, __C, pretrained_emb, token_size, answer_size, gen_func=torch.softmax): 257 | super(Net, self).__init__() 258 | 259 | self.embedding = nn.Embedding(num_embeddings=token_size, embedding_dim=__C.WORD_EMBED_SIZE) 260 | 261 | # Loading the GloVe embedding weights 262 | if __C.USE_GLOVE: 263 | self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 264 | 265 | self.attention=__C.attention #added this 266 | 267 | 268 | #if __C.USE_IMG_POS_EMBEDDINGS: 269 | # self.img_pos_x_embeddings = nn.Embedding(num_embeddings=14, 
embedding_dim=int(__C.HIDDEN_SIZE/2)) 270 | # torch.nn.init.xavier_uniform_(self.img_pos_x_embeddings.weight) 271 | # self.img_pos_y_embeddings = nn.Embedding(num_embeddings=14, embedding_dim=int(__C.HIDDEN_SIZE/2)) 272 | # torch.nn.init.xavier_uniform_(self.img_pos_y_embeddings.weight) 273 | # self.use_img_pos_embeddings = __C.USE_IMG_POS_EMBEDDINGS 274 | 275 | self.lstm = nn.LSTM( 276 | input_size=__C.WORD_EMBED_SIZE, 277 | hidden_size=__C.HIDDEN_SIZE, 278 | num_layers=1, 279 | batch_first=True) 280 | 281 | self.img_feat_linear = nn.Linear( 282 | __C.IMG_FEAT_SIZE, 283 | __C.HIDDEN_SIZE) 284 | 285 | self.gen_func=gen_func 286 | self.backbone = MCA_ED(__C, gen_func) 287 | 288 | if (self.attention=='discrete'): 289 | self.attflat_img = AttFlatText(__C, self.gen_func) 290 | else: # use continuous attention 291 | self.attflat_img = AttFlat(__C, self.gen_func) 292 | 293 | self.attflat_lang = AttFlatText(__C) 294 | 295 | self.proj_norm = LayerNorm(__C.FLAT_OUT_SIZE) 296 | self.proj = nn.Linear(__C.FLAT_OUT_SIZE, answer_size) 297 | 298 | 299 | 300 | 301 | def forward(self, img_feat, ques_ix): 302 | 303 | # Make mask 304 | lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2)) 305 | img_feat_mask = self.make_mask(img_feat) 306 | 307 | # Pre-process Language Feature 308 | lang_feat = self.embedding(ques_ix) 309 | lang_feat, _ = self.lstm(lang_feat) 310 | 311 | # Pre-process Image Feature 312 | img_feat = self.img_feat_linear(img_feat) 313 | 314 | #if self.use_img_pos_embeddings: 315 | # for i in range(img_feat.size(0)): 316 | # pos = torch.LongTensor(np.mgrid[0:14,0:14]).cuda() 317 | # img_feat[i]+=torch.cat([self.img_pos_x_embeddings(pos[0].view(-1)), self.img_pos_y_embeddings(pos[1].view(-1))],1) 318 | 319 | # Backbone Framework 320 | lang_feat, img_feat = self.backbone( 321 | lang_feat, 322 | img_feat, 323 | lang_feat_mask, 324 | img_feat_mask) 325 | 326 | lang_feat = self.attflat_lang( 327 | lang_feat, 328 | lang_feat_mask) 329 | 330 | img_feat = self.attflat_img( 331 | img_feat, 332 | img_feat_mask) 333 | 334 | proj_feat = lang_feat + img_feat 335 | proj_feat = self.proj_norm(proj_feat) 336 | proj_feat = torch.sigmoid(self.proj(proj_feat)) 337 | 338 | return proj_feat 339 | 340 | 341 | # Masking 342 | def make_mask(self, feature): 343 | return (torch.sum(torch.abs(feature),dim=-1) == 0).unsqueeze(1).unsqueeze(2) 344 | -------------------------------------------------------------------------------- /core/model/net_mfb.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # OpenVQA 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Pengbing Gao https://github.com/nbgao 5 | # -------------------------------------------------------- 6 | 7 | from core.model.mfb import CoAtt 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | # ------------------------------------------------------- 13 | # ---- Main MFB/MFH model with Co-Attention Learning ---- 14 | # ------------------------------------------------------- 15 | 16 | 17 | class Net_mfb(nn.Module): 18 | def __init__(self, __C, pretrained_emb, token_size, answer_size, gen_func=torch.softmax): 19 | super(Net_mfb, self).__init__() 20 | self.__C = __C 21 | 22 | self.embedding = nn.Embedding(num_embeddings=token_size,embedding_dim=__C.WORD_EMBED_SIZE) 23 | 24 | self.img_feat_linear = nn.Linear(__C.IMG_FEAT_SIZE,__C.HIDDEN_SIZE) 25 | 26 | 27 | # Loading the GloVe embedding weights 28 | if __C.USE_GLOVE: 29 | 
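To summarize the continuous-attention flattening implemented by `AttFlat` in `core/model/net.py` above: the discrete attention weights p over the L = 14x14 = 196 grid positions t_l in [0,1]^2 are used only to fit a Gaussian, whose canonical parameters are passed to the continuous transform, and the result is combined with a value function obtained by ridge regression:

```latex
\mu = \sum_{\ell=1}^{L} p_\ell\, t_\ell,
\qquad
\Sigma = \sum_{\ell=1}^{L} p_\ell\, t_\ell t_\ell^{\top} - \mu\mu^{\top},
\qquad
\theta = \Bigl(\Sigma^{-1}\mu,\; -\tfrac{1}{2}\Sigma^{-1}\Bigr),
```

```latex
r = \rho(\theta) \in \mathbb{R}^{N},
\qquad
B = V^{\top} G \in \mathbb{R}^{D \times N},
\qquad
c = B\, r \in \mathbb{R}^{D},
```

where rho is `ContinuousSoftmax` or `ContinuousSparsemax` over the N = 100 Gaussian basis functions, V (L x D) holds the grid features, G is the ridge-regression matrix precomputed in `__init__` (`self.Gs[1]`), and the context c is what `linear_merge` projects to the output (in the continuous sparsemax branch, Sigma is first rescaled as in the code). The question side and the discrete-attention baseline keep using `AttFlatText`.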
self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb)) 30 | 31 | self.lstm = nn.LSTM( 32 | input_size=__C.WORD_EMBED_SIZE, 33 | hidden_size=__C.LSTM_OUT_SIZE, 34 | num_layers=1, 35 | batch_first=True) 36 | 37 | self.gen_func=gen_func 38 | 39 | self.dropout = nn.Dropout(__C.DROPOUT_R) 40 | self.dropout_lstm = nn.Dropout(__C.DROPOUT_R) 41 | self.backbone = CoAtt(__C, gen_func) 42 | 43 | if __C.HIGH_ORDER: # MFH 44 | self.proj = nn.Linear(2*__C.MFB_O, answer_size) 45 | else: # MFB 46 | self.proj = nn.Linear(__C.MFB_O, answer_size) 47 | 48 | def forward(self, img_feat, ques_ix): 49 | 50 | # Make mask 51 | lang_feat_mask = self.make_mask(ques_ix.unsqueeze(2)) 52 | img_feat_mask = self.make_mask(img_feat) 53 | 54 | # Pre-process Image Feature 55 | img_feat = self.img_feat_linear(img_feat) 56 | 57 | # Pre-process Language Feature 58 | ques_feat = self.embedding(ques_ix) # (N, T, WORD_EMBED_SIZE) 59 | ques_feat = self.dropout(ques_feat) 60 | ques_feat, _ = self.lstm(ques_feat) # (N, T, LSTM_OUT_SIZE) 61 | ques_feat = self.dropout_lstm(ques_feat) 62 | 63 | z = self.backbone(img_feat, ques_feat) # MFH:(N, 2*O) / MFB:(N, O) 64 | proj_feat = self.proj(z) # (N, answer_size) 65 | 66 | return proj_feat 67 | 68 | # Masking 69 | def make_mask(self, feature): 70 | return (torch.sum(torch.abs(feature),dim=-1) == 0).unsqueeze(1).unsqueeze(2) 71 | -------------------------------------------------------------------------------- /core/model/net_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | import torch.nn as nn 8 | import torch 9 | 10 | 11 | class FC(nn.Module): 12 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): 13 | super(FC, self).__init__() 14 | self.dropout_r = dropout_r 15 | self.use_relu = use_relu 16 | 17 | self.linear = nn.Linear(in_size, out_size) 18 | 19 | if use_relu: 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | if dropout_r > 0: 23 | self.dropout = nn.Dropout(dropout_r) 24 | 25 | def forward(self, x): 26 | x = self.linear(x) 27 | 28 | if self.use_relu: 29 | x = self.relu(x) 30 | 31 | if self.dropout_r > 0: 32 | x = self.dropout(x) 33 | 34 | return x 35 | 36 | 37 | class MLP(nn.Module): 38 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): 39 | super(MLP, self).__init__() 40 | 41 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) 42 | self.linear = nn.Linear(mid_size, out_size) 43 | 44 | def forward(self, x): 45 | return self.linear(self.fc(x)) 46 | 47 | 48 | class LayerNorm(nn.Module): 49 | def __init__(self, size, eps=1e-6): 50 | super(LayerNorm, self).__init__() 51 | self.eps = eps 52 | 53 | self.a_2 = nn.Parameter(torch.ones(size)) 54 | self.b_2 = nn.Parameter(torch.zeros(size)) 55 | 56 | def forward(self, x): 57 | mean = x.mean(-1, keepdim=True) 58 | std = x.std(-1, keepdim=True) 59 | 60 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 61 | 62 | -------------------------------------------------------------------------------- /core/model/optim.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # 
Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import torch.optim as Optim 9 | 10 | 11 | class WarmupOptimizer(object): 12 | def __init__(self, lr_base, optimizer, data_size, batch_size, model='mca'): 13 | self.optimizer = optimizer 14 | self._step = 0 15 | self.lr_base = lr_base 16 | self._rate = 0 17 | self.data_size = data_size 18 | self.batch_size = batch_size 19 | self.model = model 20 | 21 | 22 | def step(self): 23 | self._step += 1 24 | 25 | rate = self.rate() 26 | for p in self.optimizer.param_groups: 27 | p['lr'] = rate 28 | if self.model=='mca': 29 | self._rate = rate 30 | else: 31 | self._rate = self.lr_base 32 | 33 | self.optimizer.step() 34 | 35 | 36 | def zero_grad(self): 37 | self.optimizer.zero_grad() 38 | 39 | 40 | def rate(self, step=None): 41 | if step is None: 42 | step = self._step 43 | 44 | if step <= int(self.data_size / self.batch_size * 1): 45 | r = self.lr_base * 1/4. 46 | elif step <= int(self.data_size / self.batch_size * 2): 47 | r = self.lr_base * 2/4. 48 | elif step <= int(self.data_size / self.batch_size * 3): 49 | r = self.lr_base * 3/4. 50 | else: 51 | r = self.lr_base 52 | 53 | return r 54 | 55 | 56 | def get_optim(__C, model, data_size, lr_base=None): 57 | if lr_base is None: 58 | lr_base = __C.LR_BASE 59 | 60 | return WarmupOptimizer( 61 | lr_base, 62 | Optim.Adam( 63 | filter(lambda p: p.requires_grad, model.parameters()), 64 | lr=0, 65 | betas=__C.OPT_BETAS, 66 | eps=__C.OPT_EPS 67 | ), 68 | data_size, 69 | __C.BATCH_SIZE, __C.MODEL 70 | ) 71 | 72 | 73 | def adjust_lr(optim, decay_r): 74 | optim.lr_base *= decay_r 75 | -------------------------------------------------------------------------------- /core/model/tv2d_layer_2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import label 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import Module 6 | 7 | from core.model.tv2d_numba import prox_tv2d 8 | 9 | from time import perf_counter 10 | from numba import jit 11 | import sys 12 | 13 | @jit(nopython=True) 14 | def isin(x, l): 15 | for i in l: 16 | if x==i: 17 | return True 18 | return False 19 | 20 | @jit(nopython=True) 21 | def back(Y, dX, dY): 22 | neigbhours=list([(1,1)]) 23 | del neigbhours[-1] 24 | group=[(0,0)] 25 | del group[-1] 26 | n=0 27 | idx_grouped = [(200,200)for x in range(196)] 28 | count=0 29 | value=0 30 | s=0 31 | while True: 32 | if len(neigbhours)!=0: 33 | while len(neigbhours)!=0: 34 | if Y[neigbhours[0][0],neigbhours[0][1]] == value: 35 | a = neigbhours[0][0] 36 | b = neigbhours[0][1] 37 | del neigbhours[0] 38 | count+=1 39 | s+=dY[a,b] 40 | group.append((a,b)) 41 | idx_grouped[n]=(a,b) 42 | n+=1 43 | if b<13 and isin((a,b+1), idx_grouped)==False and isin((a,b+1), neigbhours)==False: 44 | neigbhours.append((a,b+1)) 45 | if a<13 and isin((a+1,b), idx_grouped)==False and isin((a+1,b), neigbhours)==False: 46 | neigbhours.append((a+1,b)) 47 | if b>0 and isin((a,b-1), idx_grouped)==False and isin((a,b-1), neigbhours)==False: 48 | neigbhours.append((a,b-1)) 49 | if a>0 and isin((a-1,b), idx_grouped)==False and isin((a-1,b), neigbhours)==False: 50 | neigbhours.append((a-1,b)) 51 | else: 52 | del neigbhours[0] 53 | else: 54 | if len(group)>0: 55 | o=s/count 56 | count=0 57 | for x in group: 58 | dX[x[0],x[1]]=o 59 | group=[(0,0)] 60 | del group[0] 61 | 62 | if n>=196: 63 | 
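The `WarmupOptimizer` above scales the base learning rate by 1/4, 2/4 and 3/4 during the first three epochs (measured in optimizer steps of `data_size / batch_size`) and uses the full `lr_base` afterwards; `adjust_lr` later multiplies `lr_base` itself by a decay factor. A tiny illustration with assumed values:

```python
# illustrative only: the dataset size, batch size and base LR are assumed
lr_base, data_size, batch_size = 1e-4, 400_000, 64
steps_per_epoch = int(data_size / batch_size)

def rate(step):
    # mirrors WarmupOptimizer.rate() in core/model/optim.py
    if step <= steps_per_epoch * 1:
        return lr_base * 1 / 4.
    elif step <= steps_per_epoch * 2:
        return lr_base * 2 / 4.
    elif step <= steps_per_epoch * 3:
        return lr_base * 3 / 4.
    return lr_base

print(rate(1), rate(steps_per_epoch + 1), rate(3 * steps_per_epoch + 1))
# 2.5e-05 5e-05 0.0001
```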
break 64 | B=False 65 | for i in range(14): 66 | for j in range(14): 67 | if isin((i,j), idx_grouped)==False: 68 | value = Y[i,j] 69 | s = dY[i,j] 70 | count+=1 71 | group.append((i, j)) 72 | idx_grouped[n] = (i, j) 73 | n+=1 74 | if j<13 and isin((i,j+1), idx_grouped)==False and isin((i,j+1), neigbhours)==False: 75 | neigbhours.append((i,j+1)) 76 | if i<13 and isin((i+1,j), idx_grouped)==False and isin((i+1,j), neigbhours)==False: 77 | neigbhours.append((i+1,j)) 78 | if j>0 and isin((i,j-1), idx_grouped)==False and isin((i,j-1), neigbhours)==False: 79 | neigbhours.append((i,j-1)) 80 | if i>0 and isin((i-1,j), idx_grouped)==False and isin((i-1,j), neigbhours)==False: 81 | neigbhours.append((i-1,j)) 82 | B=True 83 | break 84 | if B: 85 | break 86 | return dX 87 | 88 | class TV2DFunction(Function): 89 | 90 | @staticmethod 91 | def forward(ctx, X, alpha=0.01, max_iter=35, tol=1e-2): 92 | torch.set_num_threads(8) 93 | ctx.digits_tol = int(-np.log10(tol)) // 2 94 | 95 | X_np = X.detach().cpu().numpy() 96 | n_rows, n_cols = X_np.shape 97 | Y_np = prox_tv2d(X_np.ravel(), 98 | step_size=alpha / 2, 99 | n_rows=n_rows, 100 | n_cols=n_cols, 101 | max_iter=max_iter, 102 | tol=tol) 103 | 104 | 105 | Y_np = Y_np.reshape(n_rows, n_cols) 106 | Y = torch.from_numpy(Y_np) # double-precision 107 | Y = torch.as_tensor(Y, dtype=X.dtype, device=X.device) 108 | ctx.save_for_backward(Y.detach()) # TODO figure out why detach everywhere 109 | 110 | return Y 111 | 112 | @staticmethod 113 | def backward(ctx, dY): 114 | #with torch.autograd.profiler.profile(use_cuda=True) as prof) 115 | torch.set_num_threads(8) 116 | Y, = ctx.saved_tensors 117 | """ 118 | tic = perf_counter() 119 | dY_np = dY.cpu().numpy() 120 | dX_np = np.zeros((8,8)) 121 | 122 | Y_np_round = Y.cpu().numpy().round(ctx.digits_tol) 123 | # TODO speed me up. Maybe with scikit-image label? 124 | uniq, inv = np.unique(Y_np_round, return_inverse=True) 125 | 126 | inv = inv.reshape((8,8)) 127 | 128 | for j in range(len(uniq)): 129 | objs, n_objs = label(inv == j) 130 | for k in range(1, n_objs + 1): 131 | obj = objs == k 132 | obj_mean = (obj * dY_np).sum() / obj.sum() 133 | dX_np += obj_mean * obj 134 | #tac=perf_counter() 135 | #print(torch.as_tensor(dX_np, dtype=dY.dtype, device=dY.device)) 136 | #print('vlad', tac-tic) 137 | #tic=perf_counter() 138 | """ 139 | Y_np = np.array(Y.cpu()).round(ctx.digits_tol) 140 | dY_np = np.array(dY.cpu()) 141 | dX = np.zeros((14,14)) 142 | dX = back(Y_np, dX, dY_np) 143 | dX = torch.as_tensor(dX, dtype=dY.dtype, device=dY.device) 144 | #tac=perf_counter() 145 | #print(dX) 146 | #print('pedro', tac-tic) 147 | 148 | return dX, None 149 | 150 | 151 | _tv2d = TV2DFunction.apply 152 | 153 | 154 | class TV2D(Module): 155 | 156 | def __init__(self, alpha=1, max_iter=1000, tol=1e-12): 157 | """2D Total Variation layer 158 | 159 | Computes argmax_P 0.5 ||X - P||^2 + alpha * tv_penalty(P) 160 | 161 | where tv_penalty(P) = sum_j^N sum_i=1^M | P[i, j] - P[i - 1, j] | 162 | + sum_i^M sum_j=1^N | P[i, j] - P[i, j - 1] | 163 | 164 | using Douglas-Rachford splitting, and a direct O(n log n) algorithm for 165 | each row and column subproblem. 166 | 167 | Parameters: 168 | 169 | alpha: float, 170 | the strength of the fused lasso regularization 171 | 172 | max_iter: int, 173 | the number of Douglas-Rachford outer iterations 174 | 175 | tol: int, 176 | fixed-point stopping criteria for Douglas-Rachford. 
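A minimal usage sketch for the `TV2DFunction` defined above, chained with sparsemax the way the `tvmax` branches in `net.py` and `mfb.py` apply it to a 14x14 attention map (the input here is random and purely illustrative):

```python
import torch
from functools import partial
from entmax import sparsemax
from core.model.tv2d_layer_2 import TV2DFunction

scores = torch.randn(14, 14, requires_grad=True)  # one 14x14 attention score map

smoothed = TV2DFunction.apply(scores)             # TV prox -> piecewise-constant map
probs = partial(sparsemax, k=512)(smoothed.view(1, 14 * 14), dim=-1)

print(probs.shape)           # torch.Size([1, 196])
print(probs.sum().item())    # 1.0: sparsemax returns a (sparse) probability vector

loss = (probs ** 2).sum()
loss.backward()              # flows through the numba-based backward defined above
print(scores.grad.shape)     # torch.Size([14, 14])
```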
177 | """ 178 | 179 | self.alpha = alpha 180 | self.max_iter = max_iter 181 | self.tol = tol 182 | 183 | def forward(self, X): 184 | return _tv2d(X, self.alpha, self.max_iter, self.tol) 185 | 186 | 187 | if __name__ == '__main__': 188 | sys.settrace 189 | 190 | X = torch.randn(14, 14, requires_grad=True) 191 | 192 | Y = _tv2d(X) 193 | #tic = perf_counter() 194 | Y[1, 2].backward() 195 | #tac = perf_counter() 196 | #print(tac-tic) 197 | #print(X.grad) 198 | """ 199 | print("Gradient check") 200 | from torch.autograd import gradcheck 201 | for _ in range(20): 202 | X = torch.randn(6, 6, dtype=torch.double, requires_grad=True) 203 | test = gradcheck(_tv2d, X) 204 | print(test) 205 | """ 206 | -------------------------------------------------------------------------------- /core/model/tv2d_numba.py: -------------------------------------------------------------------------------- 1 | # most of this file from Fabian Pedregosa 2 | # https://github.com/openopt/copt/blob/master/copt/tv_prox.py 3 | 4 | import numpy as np 5 | import warnings 6 | try: 7 | from numba import njit 8 | except ImportError: 9 | from functools import wraps 10 | 11 | def njit(*args, **kw): 12 | if len(args) == 1 and len(kw)== 0 and hasattr(args[0], '__call__'): 13 | func = args[0] 14 | @wraps(func) 15 | def inner_function(*args, **kwargs): 16 | return func(*args, **kwargs) 17 | return inner_function 18 | else: 19 | def inner_function(function): 20 | @wraps(function) 21 | def wrapper(*args, **kwargs): 22 | return function(*args, **kwargs) 23 | return wrapper 24 | return inner_function 25 | 26 | 27 | @njit 28 | def _prox_tv1d(step_size, input, output): 29 | """low level function call, no checks are performed""" 30 | width = input.size + 1 31 | index_low = np.zeros(width, dtype=np.int32) 32 | slope_low = np.zeros(width, dtype=input.dtype) 33 | index_up = np.zeros(width, dtype=np.int32) 34 | slope_up = np.zeros(width, dtype=input.dtype) 35 | index = np.zeros(width, dtype=np.int32) 36 | z = np.zeros(width, dtype=input.dtype) 37 | y_low = np.empty(width, dtype=input.dtype) 38 | y_up = np.empty(width, dtype=input.dtype) 39 | s_low, c_low, s_up, c_up, c = 0, 0, 0, 0, 0 40 | y_low[0] = y_up[0] = 0 41 | y_low[1] = input[0] - step_size 42 | y_up[1] = input[0] + step_size 43 | incr = 1 44 | 45 | for i in range(2, width): 46 | y_low[i] = y_low[i-1] + input[(i - 1) * incr] 47 | y_up[i] = y_up[i-1] + input[(i - 1) * incr] 48 | 49 | y_low[width-1] += step_size 50 | y_up[width-1] -= step_size 51 | slope_low[0] = np.inf 52 | slope_up[0] = -np.inf 53 | z[0] = y_low[0] 54 | 55 | for i in range(1, width): 56 | c_low += 1 57 | c_up += 1 58 | index_low[c_low] = index_up[c_up] = i 59 | slope_low[c_low] = y_low[i]-y_low[i-1] 60 | while (c_low > s_low+1) and (slope_low[max(s_low, c_low-1)] <= slope_low[c_low]): 61 | c_low -= 1 62 | index_low[c_low] = i 63 | if c_low > s_low+1: 64 | slope_low[c_low] = (y_low[i]-y_low[index_low[c_low-1]]) / (i-index_low[c_low-1]) 65 | else: 66 | slope_low[c_low] = (y_low[i]-z[c]) / (i-index[c]) 67 | 68 | slope_up[c_up] = y_up[i]-y_up[i-1] 69 | while (c_up > s_up+1) and (slope_up[max(c_up-1, s_up)] >= slope_up[c_up]): 70 | c_up -= 1 71 | index_up[c_up] = i 72 | if c_up > s_up + 1: 73 | slope_up[c_up] = (y_up[i]-y_up[index_up[c_up-1]]) / (i-index_up[c_up-1]) 74 | else: 75 | slope_up[c_up] = (y_up[i]-z[c]) / (i-index[c]) 76 | 77 | while (c_low == s_low+1) and (c_up > s_up+1) and (slope_low[c_low] >= slope_up[s_up+1]): 78 | c += 1 79 | s_up += 1 80 | index[c] = index_up[s_up] 81 | z[c] = y_up[index[c]] 82 | index_low[s_low] 
= index[c] 83 | slope_low[c_low] = (y_low[i]-z[c]) / (i-index[c]) 84 | while (c_up == s_up+1) and (c_low>s_low+1) and (slope_up[c_up]<=slope_low[s_low+1]): 85 | c += 1 86 | s_low += 1 87 | index[c] = index_low[s_low] 88 | z[c] = y_low[index[c]] 89 | index_up[s_up] = index[c] 90 | slope_up[c_up] = (y_up[i]-z[c]) / (i-index[c]) 91 | 92 | for i in range(1, c_low - s_low + 1): 93 | index[c+i] = index_low[s_low+i] 94 | z[c+i] = y_low[index[c+i]] 95 | c = c + c_low-s_low 96 | j, i = 0, 1 97 | while i <= c: 98 | a = (z[i]-z[i-1]) / (index[i]-index[i-1]) 99 | while j < index[i]: 100 | output[j * incr] = a 101 | output[j * incr] = a 102 | j += 1 103 | i += 1 104 | return 105 | 106 | 107 | @njit 108 | def prox_tv1d_cols(stepsize, a, n_rows, n_cols): 109 | """apply prox_tv1d along columns of the matri a 110 | """ 111 | A = a.reshape((n_rows, n_cols)) 112 | out = np.empty_like(A) 113 | for i in range(n_cols): 114 | _prox_tv1d(stepsize, A[:, i], out[:, i]) 115 | return out.ravel() 116 | 117 | 118 | @njit 119 | def prox_tv1d_rows(stepsize, a, n_rows, n_cols): 120 | """apply prox_tv1d along rows of the matri a 121 | """ 122 | A = a.reshape((n_rows, n_cols)) 123 | out = np.empty_like(A) 124 | for i in range(n_rows): 125 | _prox_tv1d(stepsize, A[i, :], out[i, :]) 126 | return out.ravel() 127 | 128 | 129 | def c_prox_tv2d(step_size, x, n_rows, n_cols, max_iter, tol): 130 | """ 131 | Douglas-Rachford to minimize a 2-dimensional total variation. 132 | Reference: https://arxiv.org/abs/1411.0589 133 | """ 134 | n_features = n_rows * n_cols 135 | p = np.zeros(n_features) 136 | q = np.zeros(n_features) 137 | 138 | for it in range(max_iter): 139 | y = x + p 140 | y = prox_tv1d_cols(step_size, y, n_rows, n_cols) 141 | p += x - y 142 | x = y + q 143 | x = prox_tv1d_rows(step_size, x, n_rows, n_cols) 144 | q += y - x 145 | 146 | # check convergence 147 | accuracy = np.max(np.abs(y - x)) 148 | if accuracy < tol: 149 | break 150 | else: 151 | warnings.warn("prox_tv2d did not converged to desired accuracy\n" + 152 | "Accuracy reached: %s" % accuracy) 153 | return x 154 | 155 | 156 | def prox_tv2d(w, step_size, n_rows, n_cols, max_iter=500, tol=1e-2): 157 | """ 158 | Computes the proximal operator of the 2-dimensional total variation operator. 159 | This solves a problem of the form 160 | argmin_x TV(x) + (1/(2 stepsize)) ||x - w||^2 161 | where TV(x) is the two-dimensional total variation. It does so using the 162 | Douglas-Rachford algorithm [Barbero and Sra, 2014]. 163 | Parameters 164 | ---------- 165 | w: array 166 | vector of coefficients 167 | step_size: float 168 | step size (often denoted gamma) in proximal objective function 169 | max_iter: int 170 | tol: float 171 | References 172 | ---------- 173 | Condat, Laurent. "A direct algorithm for 1D total variation denoising." 174 | IEEE Signal Processing Letters (2013) 175 | Barbero, Alvaro, and Suvrit Sra. "Modular proximal optimization for 176 | multidimensional total-variation regularization." arXiv preprint 177 | arXiv:1411.0589 (2014). 
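Examples
--------
A minimal usage sketch (hedged: the values below are illustrative; a 14x14
grid matches the feature maps extracted elsewhere in this repository, and
step_size=0.5 mirrors the default used by tv2d_numba below).

>>> import numpy as np
>>> w = np.random.randn(14 * 14)
>>> x = prox_tv2d(w, step_size=0.5, n_rows=14, n_cols=14)
>>> x.shape
(196,)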
178 | """ 179 | 180 | x = w.copy().astype(np.float64) 181 | return c_prox_tv2d(step_size, x, n_rows, n_cols, max_iter, tol) 182 | 183 | 184 | def tv2d_numba(X, max_iter=1000, tol=1e-2): 185 | n_rows, n_cols = X.shape 186 | x = X.ravel() 187 | p = prox_tv2d(x, step_size=.5, n_rows=n_rows, n_cols=n_cols, 188 | max_iter=max_iter, tol=tol) 189 | return p.reshape(n_rows, n_cols) 190 | 191 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch.nn import init 8 | import numpy as np 9 | from torch.autograd import gradcheck, Variable 10 | import pickle 11 | import os 12 | from torchvision import transforms 13 | from PIL import Image 14 | 15 | class AttentiveCNN(nn.Module): 16 | def __init__(self): 17 | super(AttentiveCNN, self).__init__() 18 | 19 | # ResNet-152 backend 20 | resnet = models.resnet152(pretrained=True) 21 | modules = list(resnet.children())[:-2] # delete the last fc layer and avg pool. 22 | resnet_conv = nn.Sequential(*modules) # last conv feature 23 | 24 | self.resnet_conv = resnet_conv 25 | 26 | def forward(self, images): 27 | ''' 28 | Input: images 29 | Output: V=[v_1, ..., v_n], v_g 30 | ''' 31 | # Last conv layer feature map 32 | A = self.resnet_conv(images) 33 | # V = [ v_1, v_2, ..., v_49 ] 34 | V = A.view(A.size(0), A.size(1), -1).transpose(1,2) 35 | 36 | return V 37 | 38 | transform = transforms.Compose([ 39 | transforms.Resize((448, 448)), 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.485, 0.456, 0.406), 42 | (0.229, 0.224, 0.225))]) 43 | 44 | model=AttentiveCNN().cuda() 45 | model.eval() 46 | 47 | i=0 48 | paths=[] 49 | for image_path in os.listdir('./val2014'): 50 | i+=1 51 | image = Image.open(os.path.join('./val2014', image_path)).convert('RGB') 52 | image=transform(image) 53 | if len(paths)==0: 54 | images = image.unsqueeze(0) 55 | else: 56 | images = torch.cat([images,image.unsqueeze(0)],0) 57 | paths.append(image_path) 58 | if i%1000==0: 59 | print(i) 60 | if images.size(0)==10: 61 | v = model(images.cuda()) 62 | for j in range(v.size(0)): 63 | np.savez('./features/val/'+paths[j].replace('.jpg',''),v[j].clone().detach().cpu().numpy()) 64 | paths=[] 65 | del(images) 66 | del(v) 67 | v = model(images.cuda()) 68 | for j in range(v.size(0)): 69 | np.savez('./features/val/'+paths[j].replace('.jpg',''),v[j].clone().detach().cpu().numpy()) 70 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | from cfgs.base_cfgs import Cfgs 8 | from core.exec import Execution 9 | import argparse, yaml 10 | 11 | 12 | def parse_args(): 13 | ''' 14 | Parse input arguments 15 | ''' 16 | parser = argparse.ArgumentParser(description='MCAN Args') 17 | 18 | parser.add_argument('--M', dest='MODEL', 19 | default='mca', 20 | type=str, required=True) 21 | 22 | parser.add_argument('--RUN', dest='RUN_MODE', 23 | 
choices=['train', 'val', 'test'], 24 | help='{train, val, test}', 25 | type=str, required=True) 26 | 27 | parser.add_argument('--MODEL', dest='MODEL_size', 28 | choices=['small', 'large'], 29 | help='{small, large}', 30 | default='small', type=str) 31 | 32 | parser.add_argument('--SPLIT', dest='TRAIN_SPLIT', 33 | choices=['train', 'train+val', 'train+val+vg'], 34 | help="set training split, " 35 | "eg.'train', 'train+val+vg'" 36 | "set 'train' can trigger the " 37 | "eval after every epoch", 38 | type=str) 39 | 40 | parser.add_argument('--EVAL_EE', dest='EVAL_EVERY_EPOCH', 41 | help='set True to evaluate the ' 42 | 'val split when an epoch finished' 43 | "(only work when train with " 44 | "'train' split)", 45 | type=bool) 46 | 47 | parser.add_argument('--SAVE_PRED', dest='TEST_SAVE_PRED', 48 | help='set True to save the ' 49 | 'prediction vectors' 50 | '(only work in testing)', 51 | type=bool) 52 | 53 | parser.add_argument('--BS', dest='BATCH_SIZE', 54 | help='batch size during training', 55 | type=int) 56 | 57 | parser.add_argument('--MAX_EPOCH', dest='MAX_EPOCH', 58 | help='max training epoch', 59 | type=int) 60 | 61 | parser.add_argument('--PRELOAD', dest='PRELOAD', 62 | help='pre-load the features into memory' 63 | 'to increase the I/O speed', 64 | type=bool) 65 | 66 | parser.add_argument('--GPU', dest='GPU', 67 | help="gpu select, eg.'0, 1, 2'", 68 | type=str) 69 | 70 | parser.add_argument('--SEED', dest='SEED', 71 | help='fix random seed', 72 | type=int) 73 | 74 | parser.add_argument('--VERSION', dest='VERSION', 75 | help='version control', 76 | type=str) 77 | 78 | parser.add_argument('--RESUME', dest='RESUME', 79 | help='resume training', 80 | type=bool) 81 | 82 | parser.add_argument('--CKPT_V', dest='CKPT_VERSION', 83 | help='checkpoint version', 84 | type=str) 85 | 86 | parser.add_argument('--CKPT_E', dest='CKPT_EPOCH', 87 | help='checkpoint epoch', 88 | type=int) 89 | 90 | parser.add_argument('--CKPT_PATH', dest='CKPT_PATH', 91 | help='load checkpoint path, we ' 92 | 'recommend that you use ' 93 | 'CKPT_VERSION and CKPT_EPOCH ' 94 | 'instead', 95 | type=str) 96 | 97 | parser.add_argument('--ACCU', dest='GRAD_ACCU_STEPS', 98 | help='reduce gpu memory usage', 99 | type=int) 100 | 101 | parser.add_argument('--NW', dest='NUM_WORKERS', 102 | help='multithreaded loading', 103 | type=int) 104 | 105 | parser.add_argument('--PINM', dest='PIN_MEM', 106 | help='use pin memory', 107 | type=bool) 108 | 109 | parser.add_argument('--VERB', dest='VERBOSE', 110 | help='verbose print', 111 | type=bool) 112 | 113 | parser.add_argument('--DATA_PATH', dest='DATASET_PATH', 114 | help='vqav2 dataset root path', 115 | type=str) 116 | 117 | parser.add_argument('--FEAT_PATH', dest='FEATURE_PATH', 118 | help='bottom up features root path', 119 | type=str) 120 | 121 | parser.add_argument('--POS_EMB', dest='USE_IMG_POS_EMBEDDINGS', 122 | help='verbose print', 123 | type=bool) 124 | 125 | parser.add_argument('--gen_func', default='softmax') 126 | 127 | parser.add_argument('--attention', default='discrete') 128 | 129 | parser.add_argument('--MFB_O', default=1000, type=int) 130 | 131 | parser.add_argument('--MFB_K', default=5, type=int) 132 | 133 | parser.add_argument('--I_GLIMPSES', default=2, type=int) 134 | 135 | parser.add_argument('--Q_GLIMPSES', default=2, type=int) 136 | 137 | parser.add_argument('--LSTM_OUT_SIZE', default=1024, type=int) 138 | args = parser.parse_args() 139 | return args 140 | 141 | 142 | if __name__ == '__main__': 143 | __C = Cfgs() 144 | 145 | args = parse_args() 146 | args_dict = 
__C.parse_to_dict(args) 147 | 148 | cfg_file = "cfgs/{}_model.yml".format(args.MODEL) 149 | print(cfg_file) 150 | with open(cfg_file, 'r') as f: 151 | yaml_dict = yaml.load(f) 152 | 153 | args_dict = {**yaml_dict, **args_dict} 154 | __C.add_args(args_dict) 155 | __C.proc() 156 | 157 | print('Hyper Parameters:') 158 | print(__C) 159 | 160 | __C.check_path() 161 | 162 | execution = Execution(__C) 163 | execution.run(__C.RUN_MODE) 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /utils/__pycache__/vqa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/utils/__pycache__/vqa.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/vqaEval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/mcan-vqa-continuous-attention/58c57041a7bd2691da05888828eb691920342c36/utils/__pycache__/vqaEval.cpython-36.pyc -------------------------------------------------------------------------------- /utils/proc_ansdict.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # mcan-vqa (Deep Modular Co-Attention Networks) 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yuhao Cui https://github.com/cuiyuhao1996 5 | # -------------------------------------------------------- 6 | 7 | import sys 8 | sys.path.append('../') 9 | from core.data.ans_punct import prep_ans 10 | import json 11 | 12 | DATASET_PATH = '../datasets/vqa/' 13 | 14 | ANSWER_PATH = { 15 | 'train': DATASET_PATH + 'v2_mscoco_train2014_annotations.json', 16 | 'val': DATASET_PATH + 'v2_mscoco_val2014_annotations.json', 17 | 'vg': DATASET_PATH + 'VG_annotations.json', 18 | } 19 | 20 | # Loading answer word list 21 | stat_ans_list = \ 22 | json.load(open(ANSWER_PATH['train'], 'r'))['annotations'] + \ 23 | json.load(open(ANSWER_PATH['val'], 'r'))['annotations'] 24 | 25 | 26 | def ans_stat(stat_ans_list): 27 | ans_to_ix = {} 28 | ix_to_ans = {} 29 | ans_freq_dict = {} 30 | 31 | for ans in stat_ans_list: 32 | ans_proc = prep_ans(ans['multiple_choice_answer']) 33 | if ans_proc not in ans_freq_dict: 34 | ans_freq_dict[ans_proc] = 1 35 | else: 36 | ans_freq_dict[ans_proc] += 1 37 | 38 | ans_freq_filter = ans_freq_dict.copy() 39 | for ans in ans_freq_dict: 40 | if ans_freq_dict[ans] <= 8: 41 | ans_freq_filter.pop(ans) 42 | 43 | for ans in ans_freq_filter: 44 | ix_to_ans[ans_to_ix.__len__()] = ans 45 | ans_to_ix[ans] = ans_to_ix.__len__() 46 | 47 | return ans_to_ix, ix_to_ans 48 | 49 | ans_to_ix, ix_to_ans = ans_stat(stat_ans_list) 50 | # print(ans_to_ix.__len__()) 51 | json.dump([ans_to_ix, ix_to_ans], open('../core/data/answer_dict.json', 'w')) 52 | -------------------------------------------------------------------------------- /utils/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 
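# A minimal usage sketch (hedged: the JSON file names below are the standard
# VQA-v2 download names, used here only as illustrative placeholders; they are
# not shipped with this repository):
#
#     vqa = VQA('v2_mscoco_val2014_annotations.json',
#               'v2_OpenEnded_mscoco_val2014_questions.json')
#     quesIds = vqa.getQuesIds(quesTypes=['how many'])
#     anns = vqa.loadQA(quesIds[:5])
#     vqa.showQA(anns)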
8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | 24 | class VQA: 25 | def __init__(self, annotation_file=None, question_file=None): 26 | """ 27 | Constructor of VQA helper class for reading and visualizing questions and answers. 28 | :param annotation_file (str): location of VQA annotation file 29 | :return: 30 | """ 31 | # load dataset 32 | self.dataset = {} 33 | self.questions = {} 34 | self.qa = {} 35 | self.qqa = {} 36 | self.imgToQA = {} 37 | if not annotation_file == None and not question_file == None: 38 | print('loading VQA annotations and questions into memory...') 39 | time_t = datetime.datetime.utcnow() 40 | dataset = json.load(open(annotation_file, 'r')) 41 | questions = json.load(open(question_file, 'r')) 42 | print(datetime.datetime.utcnow() - time_t) 43 | self.dataset = dataset 44 | self.questions = questions 45 | self.createIndex() 46 | 47 | def createIndex(self): 48 | # create index 49 | print('creating index...') 50 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 51 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 52 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 53 | for ann in self.dataset['annotations']: 54 | imgToQA[ann['image_id']] += [ann] 55 | qa[ann['question_id']] = ann 56 | for ques in self.questions['questions']: 57 | qqa[ques['question_id']] = ques 58 | print('index created!') 59 | 60 | # create class members 61 | self.qa = qa 62 | self.qqa = qqa 63 | self.imgToQA = imgToQA 64 | 65 | def info(self): 66 | """ 67 | Print information about the VQA annotation file. 68 | :return: 69 | """ 70 | for key, value in self.dataset['info'].items(): 71 | print('%s: %s' % (key, value)) 72 | 73 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 74 | """ 75 | Get question ids that satisfy given filter conditions. 
default skips that filter 76 | :param imgIds (int array) : get question ids for given imgs 77 | quesTypes (str array) : get question ids for given question types 78 | ansTypes (str array) : get question ids for given answer types 79 | :return: ids (int array) : integer array of question ids 80 | """ 81 | imgIds = imgIds if type(imgIds) == list else [imgIds] 82 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 83 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 84 | 85 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 86 | anns = self.dataset['annotations'] 87 | else: 88 | if not len(imgIds) == 0: 89 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], []) 90 | else: 91 | anns = self.dataset['annotations'] 92 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 93 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 94 | ids = [ann['question_id'] for ann in anns] 95 | return ids 96 | 97 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 98 | """ 99 | Get image ids that satisfy given filter conditions. default skips that filter 100 | :param quesIds (int array) : get image ids for given question ids 101 | quesTypes (str array) : get image ids for given question types 102 | ansTypes (str array) : get image ids for given answer types 103 | :return: ids (int array) : integer array of image ids 104 | """ 105 | quesIds = quesIds if type(quesIds) == list else [quesIds] 106 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 107 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 108 | 109 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 110 | anns = self.dataset['annotations'] 111 | else: 112 | if not len(quesIds) == 0: 113 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], []) 114 | else: 115 | anns = self.dataset['annotations'] 116 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 117 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 118 | ids = [ann['image_id'] for ann in anns] 119 | return ids 120 | 121 | def loadQA(self, ids=[]): 122 | """ 123 | Load questions and answers with the specified question ids. 124 | :param ids (int array) : integer ids specifying question ids 125 | :return: qa (object array) : loaded qa objects 126 | """ 127 | if type(ids) == list: 128 | return [self.qa[id] for id in ids] 129 | elif type(ids) == int: 130 | return [self.qa[ids]] 131 | 132 | def showQA(self, anns): 133 | """ 134 | Display the specified annotations. 135 | :param anns (array of object): annotations to display 136 | :return: None 137 | """ 138 | if len(anns) == 0: 139 | return 0 140 | for ann in anns: 141 | quesId = ann['question_id'] 142 | print("Question: %s" % (self.qqa[quesId]['question'])) 143 | for ans in ann['answers']: 144 | print("Answer %d: %s" % (ans['answer_id'], ans['answer'])) 145 | 146 | def loadRes(self, resFile, quesFile): 147 | """ 148 | Load result file and return a result object. 
149 | :param resFile (str) : file name of result file 150 | :return: res (obj) : result api object 151 | """ 152 | res = VQA() 153 | res.questions = json.load(open(quesFile)) 154 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 155 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 156 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 157 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) 158 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 159 | 160 | print('Loading and preparing results... ') 161 | time_t = datetime.datetime.utcnow() 162 | anns = json.load(open(resFile)) 163 | assert type(anns) == list, 'results is not an array of objects' 164 | annsQuesIds = [ann['question_id'] for ann in anns] 165 | assert set(annsQuesIds) == set(self.getQuesIds()), \ 166 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in the annotation file or there is at least one question id that does not belong to the question ids in the annotation file.' 167 | for ann in anns: 168 | quesId = ann['question_id'] 169 | if res.dataset['task_type'] == 'Multiple Choice': 170 | assert ann['answer'] in self.qqa[quesId][ 171 | 'multiple_choices'], 'predicted answer is not one of the multiple choices' 172 | qaAnn = self.qa[quesId] 173 | ann['image_id'] = qaAnn['image_id'] 174 | ann['question_type'] = qaAnn['question_type'] 175 | ann['answer_type'] = qaAnn['answer_type'] 176 | print('DONE (t=%0.2fs)' % ((datetime.datetime.utcnow() - time_t).total_seconds())) 177 | 178 | res.dataset['annotations'] = anns 179 | res.createIndex() 180 | return res 181 | -------------------------------------------------------------------------------- /utils/vqaEval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | __author__='aagrawal' 4 | 5 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 6 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
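# A minimal end-to-end sketch of how this evaluator is typically wired together
# with utils/vqa.py (hedged: 'results.json' and the annotation/question file
# names are illustrative placeholders, not paths fixed by this repository):
#
#     from vqa import VQA
#     from vqaEval import VQAEval
#
#     vqa = VQA('v2_mscoco_val2014_annotations.json',
#               'v2_OpenEnded_mscoco_val2014_questions.json')
#     vqaRes = vqa.loadRes('results.json',
#                          'v2_OpenEnded_mscoco_val2014_questions.json')
#     vqaEval = VQAEval(vqa, vqaRes, n=2)   # n: decimal places for accuracies
#     vqaEval.evaluate()
#     print(vqaEval.accuracy['overall'], vqaEval.accuracy['perAnswerType'])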
7 | import sys 8 | import re 9 | 10 | class VQAEval: 11 | def __init__(self, vqa, vqaRes, n=2): 12 | self.n = n 13 | self.accuracy = {} 14 | self.evalQA = {} 15 | self.evalQuesType = {} 16 | self.evalAnsType = {} 17 | self.vqa = vqa 18 | self.vqaRes = vqaRes 19 | self.params = {'question_id': vqa.getQuesIds()} 20 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", 21 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", 22 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", 23 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", 24 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 25 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 27 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", 28 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", 29 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", 30 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", 31 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", 32 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", 33 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", 34 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", 35 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", 36 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", 37 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 38 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 39 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 40 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", 41 | "youll": "you'll", "youre": "you're", "youve": "you've"} 42 | self.manualMap = { 'none': '0', 43 | 'zero': '0', 44 | 'one': '1', 45 | 'two': '2', 46 | 'three': '3', 47 | 'four': '4', 48 | 'five': '5', 49 | 'six': '6', 50 | 'seven': '7', 51 | 'eight': '8', 52 | 'nine': '9', 53 | 'ten': '10' 54 | } 55 | self.articles = ['a', 56 | 'an', 57 | 'the' 58 | ] 59 | 60 | 61 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 62 | self.commaStrip = re.compile("(\d)(,)(\d)") 63 | self.punct = [';', r"/", '[', 
']', '"', '{', '}', 64 | '(', ')', '=', '+', '\\', '_', '-', 65 | '>', '<', '@', '`', ',', '?', '!'] 66 | 67 | 68 | def evaluate(self, quesIds=None): 69 | if quesIds == None: 70 | quesIds = [quesId for quesId in self.params['question_id']] 71 | gts = {} 72 | res = {} 73 | for quesId in quesIds: 74 | gts[quesId] = self.vqa.qa[quesId] 75 | res[quesId] = self.vqaRes.qa[quesId] 76 | 77 | # ================================================= 78 | # Compute accuracy 79 | # ================================================= 80 | accQA = [] 81 | accQuesType = {} 82 | accAnsType = {} 83 | print ("computing accuracy") 84 | step = 0 85 | for quesId in quesIds: 86 | resAns = res[quesId]['answer'] 87 | resAns = resAns.replace('\n', ' ') 88 | resAns = resAns.replace('\t', ' ') 89 | resAns = resAns.strip() 90 | resAns = self.processPunctuation(resAns) 91 | resAns = self.processDigitArticle(resAns) 92 | gtAcc = [] 93 | gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] 94 | if len(set(gtAnswers)) > 1: 95 | for ansDic in gts[quesId]['answers']: 96 | ansDic['answer'] = self.processPunctuation(ansDic['answer']) 97 | for gtAnsDatum in gts[quesId]['answers']: 98 | otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] 99 | matchingAns = [item for item in otherGTAns if item['answer']==resAns] 100 | acc = min(1, float(len(matchingAns))/3) 101 | gtAcc.append(acc) 102 | quesType = gts[quesId]['question_type'] 103 | ansType = gts[quesId]['answer_type'] 104 | avgGTAcc = float(sum(gtAcc))/len(gtAcc) 105 | accQA.append(avgGTAcc) 106 | if quesType not in accQuesType: 107 | accQuesType[quesType] = [] 108 | accQuesType[quesType].append(avgGTAcc) 109 | if ansType not in accAnsType: 110 | accAnsType[ansType] = [] 111 | accAnsType[ansType].append(avgGTAcc) 112 | self.setEvalQA(quesId, avgGTAcc) 113 | self.setEvalQuesType(quesId, quesType, avgGTAcc) 114 | self.setEvalAnsType(quesId, ansType, avgGTAcc) 115 | if step%100 == 0: 116 | self.updateProgress(step/float(len(quesIds))) 117 | step = step + 1 118 | 119 | self.setAccuracy(accQA, accQuesType, accAnsType) 120 | print ("Done computing accuracy") 121 | 122 | def processPunctuation(self, inText): 123 | outText = inText 124 | for p in self.punct: 125 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): 126 | outText = outText.replace(p, '') 127 | else: 128 | outText = outText.replace(p, ' ') 129 | outText = self.periodStrip.sub("", 130 | outText, 131 | re.UNICODE) 132 | return outText 133 | 134 | def processDigitArticle(self, inText): 135 | outText = [] 136 | tempText = inText.lower().split() 137 | for word in tempText: 138 | word = self.manualMap.setdefault(word, word) 139 | if word not in self.articles: 140 | outText.append(word) 141 | else: 142 | pass 143 | for wordId, word in enumerate(outText): 144 | if word in self.contractions: 145 | outText[wordId] = self.contractions[word] 146 | outText = ' '.join(outText) 147 | return outText 148 | 149 | def setAccuracy(self, accQA, accQuesType, accAnsType): 150 | self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) 151 | self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} 152 | self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} 153 | 154 | def setEvalQA(self, quesId, acc): 155 | self.evalQA[quesId] = round(100*acc, self.n) 156 | 157 | def 
setEvalQuesType(self, quesId, quesType, acc): 158 | if quesType not in self.evalQuesType: 159 | self.evalQuesType[quesType] = {} 160 | self.evalQuesType[quesType][quesId] = round(100*acc, self.n) 161 | 162 | def setEvalAnsType(self, quesId, ansType, acc): 163 | if ansType not in self.evalAnsType: 164 | self.evalAnsType[ansType] = {} 165 | self.evalAnsType[ansType][quesId] = round(100*acc, self.n) 166 | 167 | def updateProgress(self, progress): 168 | barLength = 20 169 | status = "" 170 | if isinstance(progress, int): 171 | progress = float(progress) 172 | if not isinstance(progress, float): 173 | progress = 0 174 | status = "error: progress var must be float\r\n" 175 | if progress < 0: 176 | progress = 0 177 | status = "Halt...\r\n" 178 | if progress >= 1: 179 | progress = 1 180 | status = "Done...\r\n" 181 | block = int(round(barLength*progress)) 182 | text = "\rFinished Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status) 183 | sys.stdout.write(text) 184 | sys.stdout.flush() 185 | 186 | --------------------------------------------------------------------------------