├── plot ├── __init__.py ├── plotter.py └── visualize_att.py ├── .gitignore ├── config ├── dataset_creation │ ├── dme_v1.yaml │ ├── sts2017_v1.yaml │ ├── sts2017_v5.yaml │ ├── insegcat_v4.yaml │ ├── sts2017_v4.yaml │ ├── insegcat_v1.yaml │ └── insegcat_v3.yaml ├── insegcat │ ├── config_nomask.yaml │ ├── config_regionintext.yaml │ ├── config_cropregion.yaml │ ├── config_ours.yaml │ └── config_drawregion.yaml ├── sts2017 │ ├── config_cropregion.yaml │ ├── config_nomask.yaml │ ├── config_ours.yaml │ ├── config_drawregion.yaml │ └── config_regionintext.yaml └── dme │ ├── config_nomask.yaml │ ├── config_cropregion.yaml │ ├── config_ours.yaml │ ├── config_drawregion.yaml │ └── config_regionintext.yaml ├── requirements.txt ├── misc ├── printer.py ├── image_processing.py ├── git.py ├── compute_answer_weights.py ├── dirs.py └── io.py ├── core ├── models │ ├── components │ │ ├── utils.py │ │ ├── image.py │ │ ├── classification.py │ │ ├── fusion.py │ │ ├── text.py │ │ └── attention.py │ ├── model_factory.py │ └── models.py ├── train_vault │ ├── optimizers.py │ ├── criteria.py │ ├── comet.py │ ├── logbook.py │ ├── train_utils.py │ └── looper.py └── datasets │ ├── loaders_factory.py │ ├── aux.py │ ├── visual.py │ ├── nlp.py │ └── vqa.py ├── create_dataset.py ├── LICENSE ├── testing ├── visualize_attention_maps.py ├── plot_metrics_per_class.py ├── plot_predictions.py └── test_dataset.py ├── metrics └── metrics.py ├── train.py ├── README.md ├── inference.py ├── plot_metrics.py └── dataset_factory └── qa_factory.py /plot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | 4 | .vscode/ 5 | 6 | dataset_factory/coco 7 | 8 | miccai2023/ 9 | 10 | config/cholec 11 | 12 | data 13 | -------------------------------------------------------------------------------- /config/dataset_creation/dme_v1.yaml: -------------------------------------------------------------------------------- 1 | dataset: DME 2 | path_orig: /home/sergio814/Documents/PhD/code/data/dme_dataset_8_balanced 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/DME_v1 4 | 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | comet_ml 2 | matplotlib 3 | pandas 4 | urllib3~>1.26.17 5 | requests-toolbelt==0.10.1 6 | scikit-learn 7 | torch~>2.2.0 8 | torchvision==0.14.1 9 | tqdm=~>4.66.3 10 | wandb==0.12.10 11 | -------------------------------------------------------------------------------- /misc/printer.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Functions to print stuff in the console 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern
 7 | 
 8 | def print_section(section_name):
 9 |     print(40*"~")
10 |     print(section_name)
11 |     print(40*"~")
12 | 
13 | def print_line():
14 |     print(40*'-')
15 | 
16 | def print_event(text):
17 |     print('-> Now doing:', text, '...')
--------------------------------------------------------------------------------
/core/models/components/utils.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   Utilities for VQA components
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern
 7 | 
 8 | def expand_like_2D(to_expand, reference):
 9 |     # expands input with dims [B, K] to dimensions of reference which are [B, K, M, M]
10 |     expanded = to_expand.unsqueeze_(2).unsqueeze_(3).expand_as(reference)
11 |     return expanded
--------------------------------------------------------------------------------
/misc/image_processing.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   Image processing functions
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern
 7 | 
 8 | from PIL import Image
 9 | import numpy as np
10 | 
11 | 
12 | 
13 | 
14 | def resize_and_save(path_in, path_out, size = 448):
15 |     """Resize an image to (size, size) and save it to path_out.
16 |     """
17 |     im = Image.open(path_in)
18 | 
19 |     im_resized = im.resize((size, size), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
20 | 
21 |     im_resized.save(path_out)
--------------------------------------------------------------------------------
/core/models/model_factory.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   Model factory
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern
 7 | 
 8 | import torch.nn as nn
 9 | from . import models
10 | 
11 | 
12 | def get_vqa_model(config, vocab_words, vocab_answers):
13 |     # function to provide a vqa model
14 | 
15 |     model = getattr(models, config['model'])(config, vocab_words, vocab_answers)
16 | 
17 |     if config['data_parallel'] and config['cuda']:
18 |         model = nn.DataParallel(model).cuda()
19 | 
20 |     return model
--------------------------------------------------------------------------------
/create_dataset.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   Dataset creation for localized questions, depending on chosen config file
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern
 7 | 
 8 | import os
 9 | import json
10 | 
11 | from misc import io
12 | from dataset_factory import regions
13 | 
14 | # get args and config
15 | args = io.get_config_file_name()
16 | config = io.read_config(args.path_config)
17 | 
18 | # invoke constructor depending on class name given in the config file
19 | dataset_obj = getattr(regions, config['dataset'])(config)
20 | 
21 | dataset_obj.print_details()
22 | dataset_obj.build()
23 | dataset_obj.created_successfully()
--------------------------------------------------------------------------------
/misc/git.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   GIT-related functions and classes
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D.
Student, ARTORG Center, University of Bern 7 | 8 | import os 9 | import subprocess 10 | 11 | def get_commit_hash(): 12 | """Function to get the commit hash. 13 | 14 | Returns 15 | ------- 16 | str 17 | Commit hash of this version of the code 18 | """ 19 | old_path = os.getcwd() 20 | os.chdir(os.path.dirname(os.path.abspath(__file__))) # enter path of train file which is repo folder 21 | try: 22 | h = subprocess.check_output(["git", "log", "--pretty=format:%H", "-n", "1"]).decode() 23 | except: 24 | h = "UnknownHash" 25 | os.chdir(old_path) 26 | return h -------------------------------------------------------------------------------- /config/dataset_creation/sts2017_v1.yaml: -------------------------------------------------------------------------------- 1 | dataset: STS2017 2 | path_data: /home/sergio814/Documents/PhD/code/data/Tools/STS2017 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/STS2017_v1 4 | 5 | # options for image preparation (pre-processing) 6 | overwrite_img: False 7 | resize: True 8 | size: 448 9 | 10 | overwrite_qa: True # whether or not QA file should be overwritten 11 | num_regions: 10 # how many regions to generate for each class of each image. Half of them are generated to have answer No and half with answer Yes. Should be an even number 12 | min_regions: 4 # minimum number of regions to be generated when segmentation region is too big or too small. Should be an even number 13 | 14 | threshold: 1 15 | threshold_as_percentage: False 16 | min_window_side: 150 17 | max_window_side: 512 18 | proportion_deviation: 0.2 # deviation around one for the windows 19 | window_offset: 10 # leave 10 pixels as border (i.e. sample the random regions excluding the borders) 20 | -------------------------------------------------------------------------------- /config/dataset_creation/sts2017_v5.yaml: -------------------------------------------------------------------------------- 1 | dataset: STS2017 2 | path_data: /home/sergio814/Documents/PhD/code/data/Tools/STS2017 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/STS2017_v5 4 | 5 | # options for image preparation (pre-processing) 6 | overwrite_img: False 7 | resize: True 8 | size: 448 9 | 10 | overwrite_qa: True # whether or not QA file should be overwritten 11 | num_regions: 10 # how many regions to generate for each class of each image. Half of them are generated to have answer No and half with answer Yes. Should be an even number 12 | min_regions: 4 # minimum number of regions to be generated when segmentation region is too big or too small. Should be an even number 13 | 14 | threshold: 1 15 | threshold_as_percentage: True 16 | min_window_side: 150 17 | max_window_side: 512 18 | proportion_deviation: 0.2 # deviation around one for the windows 19 | window_offset: 10 # leave 10 pixels as border (i.e. sample the random regions excluding the borders) 20 | -------------------------------------------------------------------------------- /config/dataset_creation/insegcat_v4.yaml: -------------------------------------------------------------------------------- 1 | dataset: Insegcat 2 | path_data: /home/sergio814/Documents/PhD/code/data/Tools/insegcat-2 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/INSEGCAT_v4 4 | 5 | # options for image preparation (pre-processing) 6 | overwrite_img: False 7 | resize: True 8 | size: 448 9 | 10 | overwrite_qa: True # whether or not QA file should be overwritten 11 | num_regions: 10 # how many regions to generate for each class of each image. 
Half of them are generated to have answer No and half with answer Yes. Should be an even number 12 | min_regions: 4 # minimum number of regions to be generated when segmentation region is too big or too small. Should be an even number 13 | 14 | threshold: 2 15 | threshold_as_percentage: True 16 | min_window_side: 100 17 | max_window_side: 260 18 | proportion_deviation: 0.2 # deviation around one for the windows 19 | window_offset: 10 # leave 10 pixels as border (i.e. sample the random regions excluding the borders) 20 | -------------------------------------------------------------------------------- /config/dataset_creation/sts2017_v4.yaml: -------------------------------------------------------------------------------- 1 | dataset: STS2017 2 | path_data: /home/sergio814/Documents/PhD/code/data/Tools/STS2017 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/STS2017_v4 4 | 5 | # options for image preparation (pre-processing) 6 | overwrite_img: False 7 | resize: True 8 | size: 448 9 | 10 | overwrite_qa: True # whether or not QA file should be overwritten 11 | num_regions: 10 # how many regions to generate for each class of each image. Half of them are generated to have answer No and half with answer Yes. Should be an even number 12 | min_regions: 4 # minimum number of regions to be generated when segmentation region is too big or too small. Should be an even number 13 | 14 | threshold: 10 15 | threshold_as_percentage: False 16 | min_window_side: 150 17 | max_window_side: 512 18 | proportion_deviation: 0.2 # deviation around one for the windows 19 | window_offset: 10 # leave 10 pixels as border (i.e. sample the random regions excluding the borders) 20 | -------------------------------------------------------------------------------- /config/dataset_creation/insegcat_v1.yaml: -------------------------------------------------------------------------------- 1 | dataset: Insegcat 2 | path_data: /home/sergio814/Documents/PhD/code/data/Tools/insegcat-2 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/INSEGCAT_v1 4 | 5 | # options for image preparation (pre-processing) 6 | overwrite_img: False 7 | resize: True 8 | size: 448 9 | 10 | overwrite_qa: True # whether or not QA file should be overwritten 11 | num_regions: 10 # how many regions to generate for each class of each image. Half of them are generated to have answer No and half with answer Yes. Should be an even number 12 | min_regions: 4 # minimum number of regions to be generated when segmentation region is too big or too small. Should be an even number 13 | 14 | threshold: 1 15 | threshold_as_percentage: False 16 | min_window_side: 100 17 | max_window_side: 260 18 | proportion_deviation: 0.2 # deviation around one for the windows 19 | window_offset: 10 # leave 10 pixels as border (i.e. sample the random regions excluding the borders) 20 | -------------------------------------------------------------------------------- /config/dataset_creation/insegcat_v3.yaml: -------------------------------------------------------------------------------- 1 | dataset: Insegcat 2 | path_data: /home/sergio814/Documents/PhD/code/data/Tools/insegcat-2 3 | path_output: /home/sergio814/Documents/PhD/code/data/Tools/INSEGCAT_v3 4 | 5 | # options for image preparation (pre-processing) 6 | overwrite_img: False 7 | resize: True 8 | size: 448 9 | 10 | overwrite_qa: True # whether or not QA file should be overwritten 11 | num_regions: 10 # how many regions to generate for each class of each image. 
Half of them are generated to have answer No and half with answer Yes. Should be an even number 12 | min_regions: 4 # minimum number of regions to be generated when segmentation region is too big or too small. Should be an even number 13 | 14 | threshold: 10 15 | threshold_as_percentage: False 16 | min_window_side: 100 17 | max_window_side: 260 18 | proportion_deviation: 0.2 # deviation around one for the windows 19 | window_offset: 10 # leave 10 pixels as border (i.e. sample the random regions excluding the borders) 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sergio Tascon Morales 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/train_vault/optimizers.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Optimizers 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | import torch 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | 11 | def get_optimizer(config, model, add_scheduler=False): 12 | 13 | if 'adam' in config['optimizer']: 14 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), config['learning_rate']) 15 | elif 'adadelta' in config['optimizer']: 16 | optimizer = torch.optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()), config['learning_rate']) 17 | elif 'rmsprop' in config['optimizer']: 18 | optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), config['learning_rate']) 19 | elif 'sgd' in config['optimizer']: 20 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), config['learning_rate']) 21 | 22 | if add_scheduler: 23 | scheduler = ReduceLROnPlateau(optimizer, 'min') 24 | return optimizer, scheduler 25 | else: 26 | return optimizer -------------------------------------------------------------------------------- /core/train_vault/criteria.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Loss functions file 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | from torch import nn 9 | import numpy as np 10 | import torch 11 | import os 12 | 13 | def get_criterion(config, device, ignore_index = None, weights = None): 14 | # function to return a criterion. By default I set reduction to 'sum' so that batch averages are not performed because I want the average across the whole dataset 15 | if config['loss'] == 'crossentropy': 16 | if weights is not None: 17 | crit = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='mean', weight=weights).to(device) 18 | else: 19 | crit = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum').to(device) 20 | elif config['loss'] == 'bce': 21 | if weights is not None: 22 | crit = nn.BCEWithLogitsLoss(reduction='mean').to(device) 23 | else: 24 | crit = nn.BCEWithLogitsLoss(reduction='sum').to(device) 25 | else: 26 | raise ValueError("Unknown loss function.") 27 | 28 | return crit -------------------------------------------------------------------------------- /core/datasets/loaders_factory.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Script to provide dataloaders 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | import numpy as np 9 | from torch.utils.data import DataLoader 10 | import collections 11 | import torch 12 | from . import visual, vqa, aux 13 | 14 | 15 | def get_vqa_loader(subset, config, shuffle=False, draw_regions=False): 16 | 17 | # create visual dataset for images 18 | dataset_visual = visual.get_visual_dataset(subset, config) 19 | 20 | # create vqa dataset for questions and answers 21 | dataset_vqa = vqa.get_vqa_dataset(subset, config, dataset_visual, draw_regions=draw_regions) 22 | 23 | dataloader = DataLoader( dataset_vqa, 24 | batch_size = config['batch_size'], 25 | shuffle=shuffle, 26 | num_workers=config['num_workers'], 27 | pin_memory=config['pin_memory'], 28 | collate_fn=aux.collater 29 | ) 30 | if subset == 'train': 31 | return dataloader, dataset_vqa.map_index_word, dataset_vqa.map_index_answer, dataset_vqa.index_unknown_answer 32 | else: 33 | return dataloader -------------------------------------------------------------------------------- /misc/compute_answer_weights.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Script to compute the weights for the answers 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | from os.path import join as jp 9 | import os 10 | import pickle 11 | from collections import Counter 12 | import torch 13 | 14 | path_base = '/home/sergio814/Documents/PhD/code/data/Tools/DME_v2/' 15 | path_processed = jp(path_base, 'processed') 16 | path_output = jp(path_base, 'answer_weights') 17 | os.makedirs(path_output, exist_ok=True) 18 | path_output_file = jp(path_output, 'w.pt') 19 | 20 | # read train QA pairs using pickle 21 | path_trainset = jp(path_processed, 'trainset.pickle') 22 | with open(path_trainset, 'rb') as f: 23 | trainset = pickle.load(f) 24 | 25 | answers = [e['answer_index'] for e in trainset] 26 | 27 | countings = Counter(answers).most_common() 28 | countings_dict = {e[0]:e[1] for e in countings} 29 | weights = torch.zeros(len(countings_dict)) 30 | for i in range(weights.shape[0]): 31 | weights[i] = countings_dict[i] 32 | 33 | # normalize weights as suggested in https://discuss.pytorch.org/t/weights-in-weighted-loss-nn-crossentropyloss/69514 34 | weights = 1 - weights/weights.sum() 35 | 36 | # save weights to target file 37 | torch.save(weights, path_output_file) -------------------------------------------------------------------------------- /core/models/components/image.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Image embedding script 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | from torch import nn 9 | from torchvision import models 10 | from torchvision.models.resnet import ResNet152_Weights 11 | 12 | def get_visual_feature_extractor(config): 13 | if 'resnet' in config['visual_extractor']: 14 | model = ResNetExtractor(config['imagenet_weights']) 15 | else: 16 | raise ValueError("Unknown model for visual feature extraction") 17 | return model 18 | 19 | class ResNetExtractor(nn.Module): 20 | def __init__(self, imagenet): 21 | super().__init__() 22 | self.pre_trained = imagenet 23 | if self.pre_trained: 24 | self.net_base = models.resnet152(weights=ResNet152_Weights.DEFAULT) 25 | else: 26 | self.net_base = models.resnet152(weights=ResNet152_Weights.NONE) 27 | modules = list(self.net_base.children())[:-2] # ignore avgpool layer and classifier 28 | self.extractor = nn.Sequential(*modules) 29 | # freeze weights 30 | for p in self.extractor.parameters(): 31 | p.requires_grad = False 32 | 33 | def forward(self, x): 34 | x = self.extractor(x) # [B, 2048, 14, 14] if input is [B, 3, 448, 448] 35 | return x -------------------------------------------------------------------------------- /core/datasets/aux.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Auxiliary functions for dataset and dataloader creation 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | import torch 9 | import collections 10 | import numpy as np 11 | 12 | 13 | def collater(batch): 14 | # function to collate several samples of a batch 15 | if torch.is_tensor(batch[0]): 16 | return torch.stack(batch, 0) 17 | elif type(batch[0]).__module__ == np.__name__ and type(batch[0]).__name__ == 'ndarray': 18 | return torch.stack([torch.from_numpy(sample) for sample in batch], 0) 19 | elif isinstance(batch[0], int): 20 | return torch.LongTensor(batch) 21 | elif isinstance(batch[0], float): 22 | return torch.tensor(batch) # * use DoubleTensor? 
23 |     elif isinstance(batch[0], dict):
24 |         res = dict.fromkeys(batch[0].keys())
25 |         for k in res.keys():
26 |             res[k] = [s[k] for s in batch]
27 |         return {k:collater(v) for k,v in res.items()}
28 |     elif isinstance(batch[0], collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
29 |         return torch.tensor(batch, dtype=torch.int)  # ! integers, because only the lists with all answers reach this branch and the indices are integers
30 |     else:
31 |         raise ValueError("Unknown type of samples in the batch. Add condition to collater function")
32 | 
33 | 
34 | 
35 | 
--------------------------------------------------------------------------------
/core/models/components/classification.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   Output classifier definitions
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern
 7 | 
 8 | from torch import nn
 9 | 
10 | def get_classfier(input_size, config):
11 |     hidden_size = config['classifier_hidden_size']
12 |     if config['num_answers'] == 2:
13 |         output_size = 1 # if binary classification, output has one neuron and BCEWithLogitsLoss can be used
14 |     else:
15 |         output_size = config['num_answers']
16 |     dropout_percentage = config['classifier_dropout']
17 | 
18 |     # create MLP
19 |     classifier = Classifier(input_size, hidden_size, output_size, drop=dropout_percentage)
20 |     return classifier
21 | 
22 | class Classifier(nn.Module):
23 |     def __init__(self, input_size, hidden_size, output_classes, drop=0.0):
24 |         super().__init__()
25 |         self.input_size = input_size
26 |         self.hidden_size = hidden_size
27 |         self.num_classes = output_classes
28 | 
29 |         self.drop1 = nn.Dropout(drop)
30 |         self.fc1 = nn.Linear(self.input_size, self.hidden_size)
31 |         self.relu = nn.ReLU()
32 |         self.drop2 = nn.Dropout(drop)
33 |         self.fc2 = nn.Linear(self.hidden_size, self.num_classes)
34 | 
35 |     def forward(self, x):
36 |         x = self.fc1(self.drop1(x))
37 |         x = self.relu(x)
38 |         x = self.fc2(self.drop2(x))
39 | 
40 |         return x
--------------------------------------------------------------------------------
/core/train_vault/comet.py:
--------------------------------------------------------------------------------
 1 | # Project:
 2 | #   Localized Questions in VQA
 3 | # Description:
 4 | #   Comet ML functions
 5 | # Author:
 6 | #   Sergio Tascon-Morales, Ph.D.
Student, ARTORG Center, University of Bern 7 | 8 | from comet_ml import Experiment, ExistingExperiment 9 | from misc.git import get_commit_hash 10 | 11 | def get_new_experiment(config, path_config): 12 | if not config['comet_ml']: 13 | return None 14 | config_file_name = path_config.split("/")[-1].split(".")[0] 15 | comet_exp = Experiment() # Requires to have api key in config file in home folder as suggested in the documentation 16 | comet_exp.add_tags([config['model'], config_file_name]) 17 | comet_exp.log_parameters(config) 18 | comet_exp.log_asset(path_config, file_name=path_config.split("/")[-1].split(".")[0]) # log yaml file to comet ml 19 | comet_exp.log_other('commit_hash', get_commit_hash()) # log commit hash 20 | return comet_exp 21 | 22 | def get_existing_experiment(config): 23 | if not config['comet_ml']: 24 | return None 25 | else: 26 | return ExistingExperiment(previous_experiment=config['experiment_key']) 27 | 28 | def log_metrics(exp, metrics, epoch, to_log='all'): 29 | if exp is None: 30 | return 31 | if to_log == 'all': 32 | exp.log_metrics(metrics, epoch=epoch) 33 | else: # if only some of the metrics in the dictionary should be logged 34 | to_be_logged = {v: metrics[k] for k, v in to_log.items()} 35 | exp.log_metrics(to_be_logged, epoch=epoch) -------------------------------------------------------------------------------- /core/models/components/fusion.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Multimodal fusion mechanisms 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | from torch import nn 9 | import torch 10 | 11 | def get_fuser(fusion_method, size_dim_1_a, size_dim1_b): 12 | 13 | if 'cat' in fusion_method: 14 | fuser = ConcatenationFusion() 15 | fused_size = size_dim_1_a + size_dim1_b 16 | elif 'mul' in fusion_method: 17 | fuser = HadamardFusion() 18 | fused_size = size_dim_1_a 19 | elif 'sum' in fusion_method: 20 | fuser = AdditionFusion() 21 | fused_size = size_dim_1_a 22 | else: 23 | raise ValueError("Unsupported fusion method") 24 | 25 | return fuser, fused_size 26 | 27 | class ConcatenationFusion(nn.Module): 28 | # concatenates two [B, L] tensors 29 | def __init__(self): 30 | super().__init__() 31 | self.fuser = torch.cat 32 | def forward(self, x_1, x_2): 33 | return self.fuser((x_1, x_2), dim=1) # [B, 2L] 34 | 35 | class HadamardFusion(nn.Module): 36 | # Element-wise multiplication 37 | def __init__(self): 38 | super().__init__() 39 | self.fuser = torch.mul 40 | def forward(self, x_1, x_2): 41 | return self.fuser(x_1, x_2) # [B, L] 42 | 43 | class AdditionFusion(nn.Module): 44 | # Addition of two tensors 45 | def __init__(self): 46 | super().__init__() 47 | self.fuser = torch.add 48 | def forward(self, x_1, x_2): 49 | return self.fuser(x_1, x_2) # [B, L] -------------------------------------------------------------------------------- /core/models/components/text.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Text embedding script 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | from torch import nn 9 | 10 | def get_text_feature_extractor(config, vocab_words): 11 | word_embedding_size = config['word_embedding_size'] 12 | num_layers_LSTM = config['num_layers_LSTM'] 13 | question_feature_size = config['question_feature_size'] 14 | 15 | # instanciate the text encoder 16 | embedder = LSTMEncoder(word_embedding_size, num_layers_LSTM, question_feature_size, vocab_words) 17 | 18 | return embedder 19 | 20 | 21 | class LSTMEncoder(nn.Module): 22 | 23 | def __init__(self, word_embedding_size, num_layers_LSTM, lstm_features, vocab_words): 24 | super().__init__() 25 | self.vocab_words = vocab_words 26 | self.word_embedding_size = word_embedding_size 27 | self.num_layers_LSTM = num_layers_LSTM 28 | self.lstm_features = lstm_features 29 | 30 | # create word embedding 31 | self.embedding = nn.Embedding(num_embeddings=len(self.vocab_words)+1, embedding_dim=self.word_embedding_size, padding_idx=0) 32 | 33 | # create sequence encoder 34 | self.rnn = nn.LSTM(input_size=self.word_embedding_size, hidden_size=self.lstm_features, num_layers=self.num_layers_LSTM) 35 | 36 | def forward(self, question_vector): 37 | # question vector should be [B, max_question_length] 38 | x = self.embedding(question_vector) # [B, max_question_length, word_embedding_size] 39 | x = x.transpose(0,1) # put sequence dimension first, batch dim second [max_question_length, B, word_embedding_size] 40 | self.rnn.flatten_parameters() # * attempt to remove warning about non contiguous weights. 41 | output, (hn, cn) = self.rnn(x) # output is [max_question_length, B, lstm_features], hn and cn are [1, B, lstm_features] 42 | return cn.squeeze(0) -------------------------------------------------------------------------------- /misc/dirs.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Functions to handle folders and files 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | import os 9 | from os.path import join as jp 10 | import ntpath 11 | import shutil 12 | 13 | def create_folder(path): 14 | # creation of folders 15 | if not os.path.exists(path): 16 | try: 17 | os.mkdir(path) # try to create folder 18 | except: 19 | os.makedirs(path, exist_ok=True) # create full path 20 | 21 | def is_empty(path): 22 | # check if a folder is empty 23 | return len(os.listdir(path)) < 1 24 | 25 | def clean_folder(path): 26 | if is_empty(path): # already empty 27 | return 28 | else: 29 | files_and_folders = os.listdir(path) 30 | for elem in files_and_folders: 31 | if os.path.isdir(jp(path,elem)): 32 | shutil.rmtree(jp(path, elem)) 33 | else: 34 | os.remove(jp(path,elem)) 35 | 36 | def list_folders(path): 37 | # lists folders only in path 38 | return [k for k in os.listdir(path) if os.path.isdir(jp(path, k))] 39 | 40 | def list_files(path): 41 | return [k for k in os.listdir(path) if not os.path.isdir(jp(path, k))] 42 | 43 | def create_folders_within_folder(parent_path, folder_name_list): 44 | # create folders and return paths 45 | paths = [] 46 | for folder_name in folder_name_list: 47 | folder_path = jp(parent_path, folder_name) 48 | create_folder(folder_path) 49 | paths.append(folder_path) 50 | return paths 51 | 52 | def get_filename(path): 53 | head, tail = ntpath.split(path) 54 | return tail or ntpath.basename(head) 55 | 56 | def get_filename_without_extension(path): 57 | filename = get_filename(path) 58 | return os.path.splitext(filename)[0] 59 | 60 | def get_filename_with_extension(path): 61 | filename = get_filename(path) 62 | return filename 63 | 64 | def remove_whole_folder(path): 65 | shutil.rmtree(path) -------------------------------------------------------------------------------- /core/train_vault/logbook.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # ... 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | import os 9 | import json 10 | from os.path import join as jp 11 | 12 | class Logbook(object): 13 | def __init__(self, init_info = None): 14 | if init_info is not None: 15 | self.book = init_info 16 | else: 17 | self.book = {'general': {}, 'train': {}, 'val': {}} # by default it will have some general info, some training info and some val info 18 | 19 | def log_metric(self, stage, metric_name, value, epoch): 20 | if stage != 'train' and stage != 'val': 21 | raise ValueError("stage must be either train or val") 22 | # check if metric was already created previously. 
If not, create it and log 23 | if metric_name not in self.book[stage]: 24 | self.book[stage][metric_name] = {} 25 | self.book[stage][metric_name][epoch] = value 26 | else: # if metric already exists, check if it was already reported for the epoch index, then report 27 | # check if metric was already logged for the given epoch (just as sanity check) 28 | if epoch in self.book[stage][metric_name]: 29 | print("Warning: Entry already exists for given epoch index " + str(epoch)) 30 | self.book[stage][metric_name][epoch] = value 31 | 32 | def log_metrics(self, stage, metrics, epoch): 33 | # receives dictionary 34 | for k,v in metrics.items(): 35 | self.log_metric(stage, k, v, epoch) 36 | 37 | def log_general_info(self, key, value): 38 | # log anything as general info to the book 39 | self.book['general'][key] = value 40 | 41 | def save_logbook(self, path): 42 | # path is just the path to the folder where the json file should be stored, not the full path to the file 43 | with open(jp(path, 'logbook.json'), 'w') as f: 44 | json.dump(self.book, f) 45 | 46 | def load_logbook(self, path): 47 | if not os.path.exists(jp(path, 'logbook.json')): 48 | raise Exception("File logbook.json does not exists at " + path) 49 | else: 50 | with open(jp(path, 'logbook.json')) as f: 51 | self.book = json.load(f) -------------------------------------------------------------------------------- /testing/visualize_attention_maps.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Visualization of attention maps 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | import sys 9 | sys.path.append('/home/sergio814/Documents/PhD/code/locvqa/') 10 | 11 | import comet_ml 12 | import torch 13 | from misc import io 14 | from os.path import join as jp 15 | from metrics import metrics 16 | from tqdm import tqdm 17 | from plot import visualize_att 18 | from misc import printer, dirs 19 | from core.datasets import loaders_factory 20 | from core.models import model_factory 21 | from core.train_vault import criteria, optimizers, train_utils, looper 22 | 23 | args = io.get_config_file_name() 24 | 25 | def main(): 26 | # read config file 27 | config = io.read_config(args.path_config) 28 | 29 | config['train_from'] = 'best' # set this parameter to best so that best model is loaded for validation part 30 | config['comet_ml'] = False 31 | model_name = config['model'] 32 | 33 | device = torch.device('cuda' if torch.cuda.is_available() and config['cuda'] else 'cpu') 34 | 35 | # get loaders 36 | _, vocab_words, vocab_answers, index_unk_answer = loaders_factory.get_vqa_loader('train', config, shuffle=True) 37 | test_loader = loaders_factory.get_vqa_loader('test', config, shuffle=False) 38 | 39 | # get model 40 | model = model_factory.get_vqa_model(config, vocab_words, vocab_answers) 41 | 42 | # create optimizer 43 | optimizer, scheduler = optimizers.get_optimizer(config, model, add_scheduler=True) 44 | 45 | # get best epoch 46 | best_epoch, _, _, _, path_logs = train_utils.initialize_experiment(config, model, optimizer, args.path_config, lower_is_better=True) 47 | 48 | dirs.create_folder(jp(path_logs, 'att_maps')) 49 | 50 | model.eval() 51 | with torch.no_grad(): 52 | for i, sample in enumerate(tqdm(test_loader)): 53 | print('Batch', i+1, '/', len(test_loader)) 54 | # move data to GPU 55 | question = sample['question'].to(device) 56 | visual = sample['visual'].to(device) 57 | answer = 
sample['answer'].to(device) 58 | question_indexes = sample['question_id'] # keep in cpu 59 | mask = sample['mask'].to(device) 60 | 61 | visualize_att.plot_attention_maps(model_name, model, visual, question, mask, answer, vocab_words, path_logs, question_indexes, vocab_answers) 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Performance assessment 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | 9 | import torch 10 | import numpy as np 11 | from tqdm import tqdm 12 | from torch import nn 13 | from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve 14 | 15 | 16 | def vqa_accuracy(predicted, true): 17 | """ Compute the accuracies for a batch according to VQA challenge accuracy""" 18 | # in this case true is a [B, 10] matrix where ever row contains all answers for the particular question 19 | _, predicted_index = predicted.max(dim=1, keepdim=True) # should be [B, 1] where every row is an index 20 | agreement = torch.eq(predicted_index.view(true.size(0),1), true).sum(dim=1) # row-wise check of times that answer in predicted_index is in true 21 | 22 | return torch.min(agreement*0.3, torch.ones_like(agreement)).float().sum() # returning batch sum 23 | 24 | def accuracy(pred, gt): 25 | return torch.eq(pred, gt).sum()/pred.shape[0] 26 | 27 | def batch_strict_accuracy(predicted, true): 28 | # in this case true is a [B] tensor with the answers 29 | sm = nn.Softmax(dim=1) 30 | probs = sm(predicted) 31 | _, predicted_index = probs.max(dim=1) # should be [B, 1] where every row is an index 32 | return torch.eq(predicted_index, true).sum() # returning sum 33 | 34 | def batch_binary_accuracy(predicted, true): 35 | # input predicted already contains the indexes of the answers 36 | return torch.eq(predicted, true).sum() # returning sum 37 | 38 | def compute_auc_ap(targets_and_preds): 39 | # input is an Nx2 tensor where the first column contains the target answer for all samples and the second column containes the sigmoided predictions 40 | targets_and_preds_np = targets_and_preds.cpu().numpy() 41 | auc = roc_auc_score(targets_and_preds_np[:,0], targets_and_preds_np[:,1]) # eventually take np.ones((targets_and_preds_np.shape[0],)) - targets_and_preds_np[:,1] 42 | ap = average_precision_score(targets_and_preds_np[:,0], targets_and_preds_np[:,1], pos_label=1) 43 | return auc, ap 44 | 45 | def compute_roc_prc(targets_and_preds, positive_label = 1): 46 | y_true = targets_and_preds[:,0] 47 | y_pred = targets_and_preds[:,1] 48 | fpr, tpr, thresholds_roc = roc_curve(y_true, y_pred) 49 | auc = roc_auc_score(y_true, y_pred) 50 | precision, recall, thresholds_pr = precision_recall_curve(y_true, y_pred, pos_label=positive_label) 51 | ap = average_precision_score(y_true, y_pred, pos_label=positive_label) 52 | return auc, ap, (fpr, tpr, thresholds_roc), (precision, recall, thresholds_pr) -------------------------------------------------------------------------------- /testing/plot_metrics_per_class.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Script to plot metrics per class 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | import sys 9 | sys.path.append('/home/sergio814/Documents/PhD/code/locvqa/') 10 | 11 | from os.path import join as jp 12 | import misc.io as io 13 | import torch 14 | import json 15 | from tqdm import tqdm 16 | import os 17 | from plot import plotter 18 | from metrics import metrics 19 | import numpy as np 20 | 21 | # read config name from CLI argument --path_config 22 | args = io.get_config_file_name() 23 | 24 | def main(): 25 | # read config file 26 | config = io.read_config(args.path_config) 27 | 28 | config_file_name = args.path_config.split("/")[-1].split(".")[0] 29 | 30 | path_logs = jp(config['logs_dir'], config['dataset'], config_file_name) 31 | 32 | path_qa = jp(config['path_data'], 'processed') 33 | 34 | # read qa test using pickle 35 | qa_test = io.read_pickle(jp(path_qa, 'testset.pickle')) 36 | # read qa val using pickle 37 | qa_val = io.read_pickle(jp(path_qa, 'valset.pickle')) 38 | # read map answer to index 39 | map_answer2idx = io.read_pickle(jp(path_qa, 'map_answer_index.pickle')) 40 | 41 | if config['num_answers'] == 2: 42 | path_test_answers_file = jp(path_logs, 'answers', 'answers_epoch_test.pt') 43 | if not os.path.exists(path_test_answers_file): 44 | raise Exception("Test set answers haven't been generated with inference.py") 45 | answers_test = torch.load(path_test_answers_file, map_location=torch.device('cpu')) 46 | # build dictionary with key: answer, value: probability 47 | id2prob = {answers_test['results'][i,0].item(): answers_test['answers'][i,1].item() for i in range(answers_test['answers'].shape[0])} 48 | # add probability to qa_test 49 | for q in qa_test: 50 | q['prob'] = id2prob[q['question_id']] 51 | # separate qa_test into groups based on object_object field, putting as value the gt answer and the probability 52 | objects_classes = set([q['question_object'] for q in qa_test]) 53 | qa_test_per_class = {obj_class: [] for obj_class in objects_classes} 54 | for q in tqdm(qa_test): 55 | qa_test_per_class[q['question_object']].append((q['answer_index'], q['prob'])) 56 | # compute metrics per class 57 | for k, v in qa_test_per_class.items(): 58 | matrix = np.array(v, dtype=float) 59 | # compute metrics 60 | auc_test, ap_test, roc_test, prc_test = metrics.compute_roc_prc(matrix) 61 | plotter.plot_roc_prc(roc_test, auc_test, prc_test, ap_test, title=k, save=True, path=path_logs, suffix=k) 62 | else: 63 | raise NotImplementedError 64 | 65 | 66 | if __name__ == '__main__': 67 | main() -------------------------------------------------------------------------------- /config/insegcat/config_nomask.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: insegcat 8 | path_data: data/INSEGCAT_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. 
no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_IgnoreMask # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/insegcat/config_regionintext.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: insegcat 8 | path_data: data/INSEGCAT_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: True # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. 
During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_Base # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/sts2017/config_cropregion.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: sts2017 8 | path_data: data/STS2017_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_MaskRegion # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/sts2017/config_nomask.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: sts2017 8 | path_data: data/STS2017_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_IgnoreMask # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/insegcat/config_cropregion.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: insegcat 8 | path_data: data/INSEGCAT_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_MaskRegion # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/insegcat/config_ours.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: insegcat 8 | path_data: data/INSEGCAT_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_LocalizedAttention # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/sts2017/config_ours.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: sts2017 8 | path_data: data/STS2017_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_LocalizedAttention # VQA_MaskRegion, VQA_IgnoreMask 44 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 45 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 46 | 47 | # visual feature extraction and pre-processing of images 48 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 49 | size: 448 50 | imagenet_weights: True 51 | visual_extractor: resnet # options are resnet, 52 | 53 | # text feature extraction 54 | word_embedding_size: 300 55 | num_layers_LSTM: 1 56 | 57 | # attention 58 | attention: True 59 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 60 | number_of_glimpses: 2 61 | attention_dropout: 0.25 62 | attention_fusion: mul # options are cat, mul, sum 63 | 64 | # fusion 65 | fusion: cat # options are cat, mul, sum 66 | 67 | # classifier 68 | classifier_hidden_size: 1024 69 | classifier_dropout: 0.25 70 | -------------------------------------------------------------------------------- /config/dme/config_nomask.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: dme 8 | path_data: data/DME_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 5 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: crossentropy 25 | weighted_loss: True 26 | batch_size: 64 27 | num_workers: 4 28 | pin_memory: True 29 | data_parallel: True 30 | cuda: True 31 | learning_rate: 0.0001 32 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 33 | epochs: 100 34 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 35 | patience: 20 # patience for the early stopping condition 36 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 37 | 38 | # ******************************************** 39 | # model structure 40 | # ******************************************** 41 | 42 | # visual feature extraction and pre-processing of images 43 | size: 448 44 | model: VQA_IgnoreMask # VQA_MaskRegion, VQA_IgnoreMask 45 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 46 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 47 | 48 | # visual feature extraction and pre-processing of images 49 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 50 | size: 448 51 | imagenet_weights: True 52 | visual_extractor: resnet # options are resnet, 53 | 54 | # text feature extraction 55 | word_embedding_size: 300 56 | num_layers_LSTM: 1 57 | 58 | # attention 59 | attention: True 60 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 61 | number_of_glimpses: 2 62 | attention_dropout: 0.25 63 | attention_fusion: mul # options are cat, mul, sum 64 | 65 | # fusion 66 | fusion: cat # options are cat, mul, sum 67 | 68 | # classifier 69 | classifier_hidden_size: 1024 70 | classifier_dropout: 0.25 71 | -------------------------------------------------------------------------------- /config/sts2017/config_drawregion.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: sts2017 8 | path_data: data/STS2017_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | draw_regions: True 11 | 12 | # training and data-loading-related parameters 13 | mask_as_text: False # whether mask should be in the question or as a separate mask 14 | 15 | # text pre-processing 16 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 17 | max_question_length: 12 # length for questions that do not contain the region 18 | max_question_length_alt: 21 # length for questions that contain the region as text 19 | process_qa_again: False # whether or not QA pairs should be pre-processed again 20 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
21 | tokenizer: spacy # which tokenizer to use 22 | min_word_frequency: 0 23 | 24 | # training and data-loading-related parameters 25 | loss: bce 26 | batch_size: 64 27 | num_workers: 4 28 | pin_memory: True 29 | data_parallel: True 30 | cuda: True 31 | learning_rate: 0.0001 32 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 33 | epochs: 100 34 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 35 | patience: 20 # patience for the early stopping condition 36 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 37 | 38 | # ******************************************** 39 | # model structure 40 | # ******************************************** 41 | 42 | # visual feature extraction and pre-processing of images 43 | size: 448 44 | model: VQA_Base # VQA_MaskRegion, VQA_IgnoreMask 45 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 46 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 47 | 48 | # visual feature extraction and pre-processing of images 49 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 50 | size: 448 51 | imagenet_weights: True 52 | visual_extractor: resnet # options are resnet, 53 | 54 | # text feature extraction 55 | word_embedding_size: 300 56 | num_layers_LSTM: 1 57 | 58 | # attention 59 | attention: True 60 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 61 | number_of_glimpses: 2 62 | attention_dropout: 0.25 63 | attention_fusion: mul # options are cat, mul, sum 64 | 65 | # fusion 66 | fusion: cat # options are cat, mul, sum 67 | 68 | # classifier 69 | classifier_hidden_size: 1024 70 | classifier_dropout: 0.25 71 | -------------------------------------------------------------------------------- /config/dme/config_cropregion.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: dme 8 | path_data: data/DME_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 5 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: crossentropy 25 | weighted_loss: True 26 | batch_size: 64 27 | num_workers: 4 28 | pin_memory: True 29 | data_parallel: True 30 | cuda: True 31 | learning_rate: 0.0001 32 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 33 | epochs: 100 34 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 35 | patience: 20 # patience for the early stopping condition 36 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 37 | 38 | # ******************************************** 39 | # model structure 40 | # ******************************************** 41 | 42 | # visual feature extraction and pre-processing of images 43 | size: 448 44 | model: VQA_MaskRegion # VQA_MaskRegion, VQA_IgnoreMask 45 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 46 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 47 | 48 | # visual feature extraction and pre-processing of images 49 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 50 | size: 448 51 | imagenet_weights: True 52 | visual_extractor: resnet # options are resnet, 53 | 54 | # text feature extraction 55 | word_embedding_size: 300 56 | num_layers_LSTM: 1 57 | 58 | # attention 59 | attention: True 60 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 61 | number_of_glimpses: 2 62 | attention_dropout: 0.25 63 | attention_fusion: mul # options are cat, mul, sum 64 | 65 | # fusion 66 | fusion: cat # options are cat, mul, sum 67 | 68 | # classifier 69 | classifier_hidden_size: 1024 70 | classifier_dropout: 0.25 71 | -------------------------------------------------------------------------------- /config/dme/config_ours.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: dme 8 | path_data: data/DME_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: False # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 5 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: crossentropy 25 | weighted_loss: True 26 | batch_size: 64 27 | num_workers: 4 28 | pin_memory: True 29 | data_parallel: True 30 | cuda: True 31 | learning_rate: 0.0001 32 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 33 | epochs: 100 34 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 35 | patience: 20 # patience for the early stopping condition 36 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 37 | 38 | # ******************************************** 39 | # model structure 40 | # ******************************************** 41 | 42 | # visual feature extraction and pre-processing of images 43 | size: 448 44 | model: VQA_LocalizedAttention # VQA_MaskRegion, VQA_IgnoreMask 45 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 46 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 47 | 48 | # visual feature extraction and pre-processing of images 49 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 50 | size: 448 51 | imagenet_weights: True 52 | visual_extractor: resnet # options are resnet, 53 | 54 | # text feature extraction 55 | word_embedding_size: 300 56 | num_layers_LSTM: 1 57 | 58 | # attention 59 | attention: True 60 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 61 | number_of_glimpses: 2 62 | attention_dropout: 0.25 63 | attention_fusion: mul # options are cat, mul, sum 64 | 65 | # fusion 66 | fusion: cat # options are cat, mul, sum 67 | 68 | # classifier 69 | classifier_hidden_size: 1024 70 | classifier_dropout: 0.25 71 | -------------------------------------------------------------------------------- /config/insegcat/config_drawregion.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: insegcat 8 | path_data: data/INSEGCAT_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | draw_regions: True 11 | 12 | # training and data-loading-related parameters 13 | mask_as_text: False # whether mask should be in the question or as a separate mask 14 | 15 | # text pre-processing 16 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 17 | max_question_length: 12 # length for questions that do not contain the region 18 | max_question_length_alt: 21 # length for questions that contain the region as text 19 | process_qa_again: False # whether or not QA pairs should be pre-processed again 20 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
21 | tokenizer: spacy # which tokenizer to use 22 | min_word_frequency: 0 23 | 24 | # training and data-loading-related parameters 25 | loss: bce 26 | batch_size: 64 27 | num_workers: 4 28 | pin_memory: True 29 | data_parallel: True 30 | cuda: True 31 | learning_rate: 0.0001 32 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 33 | epochs: 100 34 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 35 | patience: 20 # patience for the early stopping condition 36 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 37 | 38 | # ******************************************** 39 | # model structure 40 | # ******************************************** 41 | 42 | # visual feature extraction and pre-processing of images 43 | size: 448 44 | model: VQA_Base # VQA_MaskRegion, VQA_IgnoreMask 45 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 46 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 47 | 48 | # visual feature extraction and pre-processing of images 49 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 50 | size: 448 51 | imagenet_weights: True 52 | visual_extractor: resnet # options are resnet, 53 | 54 | # text feature extraction 55 | word_embedding_size: 300 56 | num_layers_LSTM: 1 57 | 58 | # attention 59 | attention: True 60 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 61 | number_of_glimpses: 2 62 | attention_dropout: 0.25 63 | attention_fusion: mul # options are cat, mul, sum 64 | 65 | # fusion 66 | fusion: cat # options are cat, mul, sum 67 | 68 | # classifier 69 | classifier_hidden_size: 1024 70 | classifier_dropout: 0.25 71 | -------------------------------------------------------------------------------- /config/sts2017/config_regionintext.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: sts2017 8 | path_data: data/STS2017_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: True # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 2 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: bce 25 | batch_size: 64 26 | num_workers: 4 27 | pin_memory: True 28 | data_parallel: True 29 | cuda: True 30 | learning_rate: 0.0001 31 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 32 | epochs: 100 33 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 34 | patience: 20 # patience for the early stopping condition 35 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 36 | 37 | # ******************************************** 38 | # model structure 39 | # ******************************************** 40 | 41 | # visual feature extraction and pre-processing of images 42 | size: 448 43 | model: VQA_Base # VQA_MaskRegion, VQA_IgnoreMask 44 | attenuation_factor: 0.5 45 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 46 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 47 | 48 | # visual feature extraction and pre-processing of images 49 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 50 | size: 448 51 | imagenet_weights: True 52 | visual_extractor: resnet # options are resnet, 53 | 54 | # text feature extraction 55 | word_embedding_size: 300 56 | num_layers_LSTM: 1 57 | 58 | # attention 59 | attention: True 60 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 61 | number_of_glimpses: 2 62 | attention_dropout: 0.25 63 | attention_fusion: mul # options are cat, mul, sum 64 | 65 | # fusion 66 | fusion: cat # options are cat, mul, sum 67 | 68 | # classifier 69 | classifier_hidden_size: 1024 70 | classifier_dropout: 0.25 71 | -------------------------------------------------------------------------------- /config/dme/config_drawregion.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: dme 8 | path_data: data/DME_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | draw_regions: True 11 | 12 | # training and data-loading-related parameters 13 | mask_as_text: False # whether mask should be in the question or as a separate mask 14 | 15 | # text pre-processing 16 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 17 | max_question_length: 12 # length for questions that do not contain the region 18 | max_question_length_alt: 21 # length for questions that contain the region as text 19 | process_qa_again: False # whether or not QA pairs should be pre-processed again 20 | num_answers: 5 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
21 | tokenizer: spacy # which tokenizer to use 22 | min_word_frequency: 0 23 | 24 | # training and data-loading-related parameters 25 | loss: crossentropy 26 | weighted_loss: True 27 | batch_size: 64 28 | num_workers: 4 29 | pin_memory: True 30 | data_parallel: True 31 | cuda: True 32 | learning_rate: 0.0001 33 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 34 | epochs: 100 35 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 36 | patience: 20 # patience for the early stopping condition 37 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 38 | 39 | # ******************************************** 40 | # model structure 41 | # ******************************************** 42 | 43 | # visual feature extraction and pre-processing of images 44 | size: 448 45 | model: VQA_Base # VQA_MaskRegion, VQA_IgnoreMask 46 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 47 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 48 | 49 | # visual feature extraction and pre-processing of images 50 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 51 | size: 448 52 | imagenet_weights: True 53 | visual_extractor: resnet # options are resnet, 54 | 55 | # text feature extraction 56 | word_embedding_size: 300 57 | num_layers_LSTM: 1 58 | 59 | # attention 60 | attention: True 61 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 62 | number_of_glimpses: 2 63 | attention_dropout: 0.25 64 | attention_fusion: mul # options are cat, mul, sum 65 | 66 | # fusion 67 | fusion: cat # options are cat, mul, sum 68 | 69 | # classifier 70 | classifier_hidden_size: 1024 71 | classifier_dropout: 0.25 72 | -------------------------------------------------------------------------------- /config/dme/config_regionintext.yaml: -------------------------------------------------------------------------------- 1 | # logging 2 | comet_ml: False 3 | experiment_key: 4 | logs_dir: logs 5 | 6 | # dataset info 7 | dataset: dme 8 | path_data: data/DME_v1/ # must contain a folder named images and a folder named qa 9 | augment: True # only horizontal random flips in this case because rotations are not useful 10 | 11 | # training and data-loading-related parameters 12 | mask_as_text: True # whether mask should be in the question or as a separate mask 13 | 14 | # text pre-processing 15 | alt_questions: False # whether or not to use the questions that describe the region in the text (i.e. no external masks) 16 | max_question_length: 12 # length for questions that do not contain the region 17 | max_question_length_alt: 21 # length for questions that contain the region as text 18 | process_qa_again: False # whether or not QA pairs should be pre-processed again 19 | num_answers: 5 # binary in this case so we can use all answers. During QA pre-processing UNK token will be added so final number of possible answers will be this plus one. 
20 | tokenizer: spacy # which tokenizer to use 21 | min_word_frequency: 0 22 | 23 | # training and data-loading-related parameters 24 | loss: crossentropy 25 | weighted_loss: True 26 | batch_size: 64 27 | num_workers: 4 28 | pin_memory: True 29 | data_parallel: True 30 | cuda: True 31 | learning_rate: 0.0001 32 | optimizer: adam # options are adam, adadelta, rmsprop, sgd 33 | epochs: 100 34 | train_from: scratch # whether or not to resume training from some checkpoint. Options are best, last, or scratch 35 | patience: 20 # patience for the early stopping condition 36 | metric_to_monitor: 'loss_val' # which metric to monitor to see when to consider change as improvement. eg loss_val, acc_val, auc_val, ap_val 37 | 38 | # ******************************************** 39 | # model structure 40 | # ******************************************** 41 | 42 | # visual feature extraction and pre-processing of images 43 | size: 448 44 | model: VQA_Base # VQA_MaskRegion, VQA_IgnoreMask 45 | attenuation_factor: 0.5 46 | visual_feature_size: 2048 # number of feature maps from the visual feature extractor 47 | question_feature_size: 1024 # size of embedded question (same as lstm_features) 48 | 49 | # visual feature extraction and pre-processing of images 50 | pre_extracted_visual_feat: False # must be false during visual feature extraction! 51 | size: 448 52 | imagenet_weights: True 53 | visual_extractor: resnet # options are resnet, 54 | 55 | # text feature extraction 56 | word_embedding_size: 300 57 | num_layers_LSTM: 1 58 | 59 | # attention 60 | attention: True 61 | attention_middle_size: 512 # size of the feature maps in the attention mechanism after the first operations (before internal fusion) 62 | number_of_glimpses: 2 63 | attention_dropout: 0.25 64 | attention_fusion: mul # options are cat, mul, sum 65 | 66 | # fusion 67 | fusion: cat # options are cat, mul, sum 68 | 69 | # classifier 70 | classifier_hidden_size: 1024 71 | classifier_dropout: 0.25 72 | -------------------------------------------------------------------------------- /testing/plot_predictions.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Script to plot some predictions from a model 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | import sys 9 | sys.path.append('/home/sergio814/Documents/PhD/code/locvqa/') 10 | 11 | import os 12 | import json 13 | import pickle 14 | import random 15 | from tqdm import tqdm 16 | import torch 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | from os.path import join as jp 20 | from PIL import Image 21 | from collections import Counter 22 | from plot import plotter 23 | 24 | DATASET_NAME = {'cholec': 'CholecVQA', 'sts2017': 'STS2017'} 25 | 26 | dataset = 'cholec' 27 | version = 'v1' 28 | subset = 'test' # 'val' or 'test' 29 | config = '003' 30 | num_examples = 50 # number of examples to plot 31 | 32 | # define paths 33 | path_data = '/home/sergio814/Documents/PhD/code/data/Tools/' + DATASET_NAME[dataset] + '_' + version 34 | path_qa = jp(path_data, 'processed') 35 | path_images = jp(path_data, 'images') 36 | path_logs = '/home/sergio814/Documents/PhD/code/logs/' + dataset + '/config_' + config 37 | path_output = jp(path_logs, 'prediction_examples') 38 | os.makedirs(path_output, exist_ok=True) 39 | os.makedirs(jp(path_output, subset), exist_ok=True) 40 | 41 | # load questions 42 | with open(jp(path_qa, subset + 'set.pickle'), 'rb') as f: 43 | qa = pickle.load(f) 44 | 45 | # load dictionary idx to answer 46 | with open(jp(path_qa, 'map_index_answer.pickle'), 'rb') as f: 47 | idx2answer = pickle.load(f) 48 | 49 | # load predictions 50 | preds = torch.load(jp(path_logs, 'answers', 'answers_epoch_' + subset + '.pt'))['results'] # using results field, which has question_id and answer based on 0.5 threshold 51 | question_id2pred = {preds[i,0].item(): preds[i,1].item() for i in range(preds.shape[0])} 52 | 53 | # add predictions to qa 54 | for q in tqdm(qa): 55 | q['prediction'] = idx2answer[question_id2pred[q['question_id']]] 56 | 57 | # for now, I will focus on errouneous predictions 58 | qa_wrong = [q for q in qa if q['prediction'] != q['answer']] 59 | for i in range(num_examples): 60 | example = random.choice(qa_wrong) 61 | id_example = example['question_id'] 62 | path_image = jp(path_images, subset, example['image_name']) 63 | image = np.array(Image.open(path_image)) 64 | mask = np.zeros(example['mask_size'], dtype=np.uint8) 65 | mask[example['mask_coords'][0][0]:example['mask_coords'][0][0] + example['mask_coords'][1], example['mask_coords'][0][1]:example['mask_coords'][0][1] + example['mask_coords'][2]] = 255 66 | fig, ax = plt.subplots() 67 | plt.title(example['question'] + '\n' 68 | + 'GT: ' + example['answer'] + '\n' 69 | + 'Pred: ' + example['prediction'] + '\n' 70 | + 'Question id: ' + str(example['question_id'])) 71 | plotter.overlay_mask(image, mask, mask, alpha = 0.3, save = True, path_without_ext=jp(path_output, subset, str(i).zfill(3)), ax=ax, fig = fig) 72 | 73 | -------------------------------------------------------------------------------- /plot/plotter.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Plotting functions 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | from matplotlib import pyplot as plt 9 | from os.path import join as jp 10 | import numpy as np 11 | 12 | def show_image(img, title=None, cmap=None): 13 | plt.imshow(img, cmap=cmap) 14 | if title is not None: 15 | plt.title(title) 16 | plt.show() 17 | 18 | def plot_learning_curve(metric_dict_train, metric_dict_val, metric_name, x_label='epoch', title="Learning curve", save=False, path=None): 19 | """ Input dictionaries are expected to have epoch indexes (string) as keys and floats as values""" 20 | fig = plt.figure() 21 | if metric_name == 'loss': 22 | top_val = max(max(list(metric_dict_train.values())), max(list(metric_dict_val.values()))) 23 | else: 24 | top_val = 1.0 25 | metric_name = metric_name.upper() 26 | 27 | # plot train metrics 28 | plt.plot([int(e) for e in metric_dict_train.keys()], list(metric_dict_train.values()), label=metric_name + ' train', linewidth=2, color='orange') 29 | # plot val metrics 30 | plt.plot([int(e) for e in metric_dict_val.keys()], list(metric_dict_val.values()), label=metric_name + ' val', linewidth=2, color='blue') 31 | plt.xticks([int(e) for e in metric_dict_train.keys()]) 32 | plt.grid() 33 | plt.title(title) 34 | plt.xlabel(x_label) 35 | plt.ylim((0, top_val)) 36 | plt.ylabel(metric_name) 37 | plt.legend() 38 | if save: 39 | if path is not None: 40 | plt.savefig(jp(path, metric_name + '.png'), dpi=300) 41 | else: 42 | raise ValueError 43 | 44 | 45 | def plot_roc_prc(roc, auc, prc, ap, title='ROC and PRC plots', save=True, path=None, suffix=''): 46 | f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) 47 | f.suptitle(title) 48 | # plot PRC 49 | ax1.plot(prc[1], prc[0], label = "PRC , AP: " + "{:.3f}".format(ap)) 50 | #ax1.plot([0, 1], [no_skill, no_skill], linestyle='--', color = colors[k], label='No Skill') 51 | ax1.set_xlabel("recall") 52 | ax1.set_ylabel("precision") 53 | ax1.grid() 54 | ax1.legend() 55 | 56 | # plot ROC 57 | ax2.plot(roc[0], roc[1],label = "ROC, AUC: " + "{:.3f}".format(auc)) 58 | #ax2.plot(fpr_dumb, tpr_dumb, linestyle="--", color = "gray", label="No Skill") 59 | ax2.set_xlabel("fpr") 60 | ax2.set_ylabel("tpr") 61 | ax2.grid() 62 | ax2.legend() 63 | 64 | if save and path is not None: 65 | plt.savefig(jp(path, 'ROC_PRC_' + suffix + '.png'), dpi=300) 66 | 67 | 68 | def overlay_mask(img, mask, gt, save= False, path_without_ext=None, alpha = 0.7, fig = None, ax = None): 69 | masked = np.ma.masked_where(mask ==0, mask) 70 | gt = np.ma.masked_where(gt==0, gt) 71 | if fig is None or ax is None: 72 | fig, ax = plt.subplots() 73 | ax.imshow(img, 'gray', interpolation='none') 74 | ax.imshow(masked, 'jet', interpolation='none', alpha=alpha) 75 | ax.imshow(gt, 'pink', interpolation='none', alpha=alpha) 76 | #fig.set_facecolor("black") 77 | fig.tight_layout() 78 | ax.axis('off') 79 | if save: 80 | plt.savefig(path_without_ext + '.png', bbox_inches='tight') 81 | plt.close() 82 | else: 83 | plt.show() -------------------------------------------------------------------------------- /core/datasets/visual.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Visual dataset handling 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | import os 9 | from os.path import join as jp 10 | from PIL import Image 11 | import torchvision.transforms as transforms 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class ImagesDataset(Dataset): 16 | 17 | def __init__(self, subset, config, transform=None): 18 | self.subset = subset 19 | self.path_img = jp(config['path_data'], 'images', subset) 20 | self.transform = transform 21 | self.images = os.listdir(self.path_img) # list all images in folder 22 | self.map_name_index = {img:i for i, img in enumerate(self.images)} 23 | self.map_index_name = self.images 24 | 25 | def get_by_name(self, image_name): 26 | return self.__getitem__(self.map_name_index[image_name]) 27 | 28 | def __getitem__(self, index): 29 | sample = {} 30 | sample['name'] = self.map_index_name[index] 31 | sample['path'] = jp(self.path_img, sample['name']) # full relative path to image 32 | sample['visual'] = Image.open(sample['path']).convert('RGB') 33 | 34 | # apply transform(s) 35 | if self.transform is not None: 36 | sample['visual'] = self.transform(sample['visual']) 37 | 38 | return sample 39 | 40 | def __len__(self): 41 | return len(self.images) 42 | 43 | 44 | def default_transform(size): 45 | """Define basic (standard) transform for input images, as required by image processor 46 | 47 | Parameters 48 | ---------- 49 | size : int or tuple 50 | new size for the images 51 | 52 | Returns 53 | ------- 54 | torchvision transform 55 | composed transform for files 56 | """ 57 | transform = transforms.Compose([ 58 | transforms.Resize(size), 59 | transforms.CenterCrop(size), 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 62 | std=[0.229, 0.224, 0.225]) 63 | ]) 64 | return transform 65 | 66 | def default_inverse_transform(): 67 | # undoes basic ImageNet normalization 68 | transform = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], 69 | std = [ 1/0.229, 1/0.224, 1/0.225 ]), 70 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], 71 | std = [ 1., 1., 1. ]), 72 | ]) 73 | return transform 74 | 75 | def get_visual_dataset(subset, config, transform=None): 76 | """Get visual dataset either from images or from extracted features 77 | 78 | Parameters 79 | ---------- 80 | split : str 81 | split name (train, val, test, trainval) 82 | options_visual : dict 83 | visual options as determined in yaml file 84 | transform : torchvision transform, optional 85 | transform to be applied to images, by default None 86 | 87 | Returns 88 | ------- 89 | images dataset 90 | images dataset with images or feature maps (depending on options_visual['mode']) 91 | """ 92 | 93 | if transform is None: 94 | transform = default_transform(config['size']) 95 | visual_dataset = ImagesDataset(subset, config, transform) # create images dataset 96 | return visual_dataset -------------------------------------------------------------------------------- /testing/test_dataset.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Script to test a dataset in terms of question_id unicity, balance of answers, and visualization of random examples against the GT. 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | import sys 9 | sys.path.append('/home/sergio814/Documents/PhD/code/locvqa/') 10 | 11 | import os 12 | import json 13 | import pickle 14 | import random 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | from os.path import join as jp 18 | from PIL import Image 19 | from collections import Counter 20 | from plot import plotter 21 | 22 | 23 | path_base = '/home/sergio814/Documents/PhD/code/data/Tools/INSEGCAT_v1' 24 | path_output = jp(path_base, 'test_output') 25 | os.makedirs(path_output, exist_ok=True) 26 | path_qa = jp(path_base, 'qa') 27 | path_processed = jp(path_base, 'processed') 28 | path_images = jp(path_base, 'images') 29 | 30 | subset = 'test' 31 | os.makedirs(jp(path_output, subset), exist_ok=True) 32 | n_examples = 50 33 | 34 | # load questions 35 | path_questions = jp(path_qa, subset + '_qa.json') 36 | with open(path_questions, 'r') as f: 37 | qa = json.load(f) 38 | 39 | path_processed_qa = jp(path_processed, subset + 'set.pickle') 40 | with open(path_processed_qa, 'rb') as f: 41 | processed_qa = pickle.load(f) 42 | 43 | # build dict question_id to entry from processed_qa 44 | processed_qa_dict = {q['question_id']: q for q in processed_qa} 45 | 46 | # first, check unicity of question ids 47 | question_ids = [q['question_id'] for q in qa] 48 | if len(question_ids) == len(set(question_ids)): 49 | print('PASSED: All question ids are unique') 50 | else: 51 | print('FAILED: There are repeated question ids') 52 | 53 | # second, check balance of answers 54 | answers = [q['answer'] for q in qa] 55 | answer_counts = Counter(answers).most_common() 56 | print('Answer counts:') 57 | for answer, count in answer_counts: 58 | print(answer, count) 59 | 60 | # third, visualize random examples 61 | print('Generating random examples. 
To be saved at', path_output) 62 | for i in range(n_examples): 63 | example = random.choice(qa) 64 | id_example = example['question_id'] 65 | processed_example = processed_qa_dict[id_example] 66 | path_image = jp(path_images, subset, example['image_name']) 67 | image = np.array(Image.open(path_image)) 68 | # build mask from coordinates 69 | mask = np.zeros(example['mask_size'], dtype=np.uint8) 70 | mask[example['mask_coords'][0][0]:example['mask_coords'][0][0] + example['mask_coords'][1], example['mask_coords'][0][1]:example['mask_coords'][0][1] + example['mask_coords'][2]] = 255 71 | # overlay mask on image, and question and answer as title 72 | fig, ax = plt.subplots() 73 | plt.title(example['question'] + ' ' + example['answer']) 74 | # compare information from qa and processed_qa 75 | # print('Question qa:', example['question'], 'Question processed_qa:', processed_example['question']) 76 | # print('Image name qa:', example['image_name'], 'Image name processed_qa:', processed_example['image_name']) 77 | # print('Answer qa:', example['answer'], 'Answer processed_qa:', processed_example['answer']) 78 | # print('Mask coords qa:', example['mask_coords'], 'Mask coords processed_qa:', processed_example['mask_coords']) 79 | # print('Mask size qa:', example['mask_size'], 'Mask size processed_qa:', processed_example['mask_size']) 80 | plotter.overlay_mask(image, mask, mask, alpha = 0.3, save = True, path_without_ext=jp(path_output, subset, str(i).zfill(3) + '_' + str(id_example)), ax=ax, fig = fig) 81 | 82 | -------------------------------------------------------------------------------- /misc/io.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # I/O file 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | import yaml 9 | import argparse 10 | import pickle 11 | import json 12 | import torch 13 | import os 14 | from PIL import Image 15 | import numpy as np 16 | from matplotlib import pyplot as plt 17 | from os.path import join as jp 18 | 19 | def get_config_file_name(pre_extract = False, single=False): 20 | """Function to create CLI argument parser and return corresponding args 21 | 22 | pre_extract 23 | 24 | single 25 | Whether or not a single image is to be processed. 
26 | 27 | Returns 28 | ------- 29 | parser 30 | argument parser 31 | """ 32 | parser = argparse.ArgumentParser( 33 | description='Read config file name', 34 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 35 | 36 | parser.add_argument('--path_config', default='vqa/config/vqa2/default.yaml', type=str, help='path to a yaml options file') 37 | 38 | if pre_extract: 39 | parser.add_argument('--subset', default='train', type=str, help='subset to be processed') 40 | 41 | if single: 42 | parser.add_argument('--path_image', default='/home/sergio814/Documents/PhD/code/data/to_inpaint/grade_0/inpainted/IDRiD_133.jpg', type=str, help='Path to image') 43 | parser.add_argument('--path_mask', default='/home/sergio814/Documents/PhD/code/data/dme_dataset_8_balanced/masks/train/maskA/whole_image_mask.tif', type=str, help='Path to mask') 44 | parser.add_argument('--question', default='What is the diabetic macular edema grade for this image?', type=str) 45 | parser.add_argument('--path_output', default=os.getcwd()) 46 | return parser.parse_args() 47 | 48 | 49 | def read_image(path): 50 | # function to read an image using PIL Image 51 | return np.array(Image.open(path)) 52 | 53 | def save_image(image, path): 54 | # function to save an image using PIL Image 55 | Image.fromarray(image).save(path) 56 | 57 | def read_weights(config): 58 | # Function to read (class) weights that come from the answer distribution and were previously computed using compute_answer_weights.py 59 | path_weights = jp(config['path_data'], 'answer_weights', 'w.pt') 60 | if not os.path.exists(path_weights): 61 | raise FileNotFoundError 62 | weights = torch.load(path_weights) 63 | return weights 64 | 65 | def read_config(path_config): 66 | """Function to read the config file from path_config 67 | 68 | Parameters 69 | ---------- 70 | path_config : str 71 | path to config file 72 | 73 | Returns 74 | ------- 75 | dict 76 | parsed config file 77 | """ 78 | with open(path_config, "r") as ymlfile: 79 | cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) 80 | return cfg 81 | 82 | 83 | def save_pickle(data, path): 84 | """Function to save a pickle file in the specified path 85 | 86 | Parameters 87 | ---------- 88 | data : list 89 | data to be saved 90 | path : str 91 | path including format for pickle file 92 | """ 93 | with open(path, 'wb') as f: 94 | pickle.dump(data, f) 95 | 96 | def save_json(data, path): 97 | with open(path, 'w') as f: 98 | json.dump(data, f) 99 | 100 | def read_json(path): 101 | with open(path, 'r') as f: 102 | return json.load(f) 103 | 104 | def read_pickle(path): 105 | """Function to read a pickle file from the specified path 106 | 107 | Parameters 108 | ---------- 109 | path : str 110 | path including format for pickle file 111 | 112 | Returns 113 | ------- 114 | list 115 | data read from pickle file 116 | """ 117 | with open(path, 'rb') as f: 118 | return pickle.load(f) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Main train file 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | # IMPORTANT: All configurations are made through the yaml config file which is located in config//.yaml. The path to this file is 9 | # specified using CLI arguments, with --path_config . 
If you don't use comet ml, set the parameter comet_ml to False 10 | 11 | import time 12 | import comet_ml 13 | import torch 14 | import misc.io as io 15 | from core.datasets import loaders_factory 16 | from core.models import model_factory 17 | from core.train_vault import criteria, optimizers, train_utils, looper, comet 18 | 19 | # read config name from CLI argument --path_config 20 | args = io.get_config_file_name() 21 | 22 | def main(): 23 | # read config file 24 | config = io.read_config(args.path_config) 25 | if 'draw_regions' in config: 26 | draw_regions = config['draw_regions'] 27 | else: 28 | draw_regions = False 29 | 30 | device = torch.device('cuda' if torch.cuda.is_available() and config['cuda'] else 'cpu') 31 | 32 | train_loader, vocab_words, vocab_answers, index_unk_answer = loaders_factory.get_vqa_loader('train', config, shuffle=True, draw_regions=draw_regions) 33 | 34 | print('Num batches train: ', len(train_loader)) 35 | print('Num samples train:', len(train_loader.dataset)) 36 | 37 | val_loader = loaders_factory.get_vqa_loader('val', config, shuffle=False, draw_regions=draw_regions) 38 | 39 | print('Num batches val: ', len(val_loader)) 40 | print('Num samples val:', len(val_loader.dataset)) 41 | 42 | model = model_factory.get_vqa_model(config, vocab_words, vocab_answers) 43 | 44 | if 'weighted_loss' in config: 45 | if config['weighted_loss']: 46 | answer_weights = io.read_weights(config) # if use of weights is required, read them from folder where they were previously saved using compute_answer_weights scripts 47 | else: 48 | answer_weights = None # If false, just set variable to None 49 | else: 50 | answer_weights = None 51 | 52 | # create criterion 53 | criterion = criteria.get_criterion(config, device, ignore_index = index_unk_answer, weights=answer_weights) 54 | 55 | # create optimizer 56 | optimizer, scheduler = optimizers.get_optimizer(config, model, add_scheduler=True) 57 | 58 | # initialize experiment 59 | start_epoch, comet_experiment, early_stopping, logbook, path_logs = train_utils.initialize_experiment(config, model, optimizer, args.path_config, lower_is_better=True) 60 | 61 | # log config info 62 | logbook.log_general_info('config', config) 63 | 64 | # get train and val functions 65 | train, validate = looper.get_looper_functions(config) 66 | 67 | # train loop 68 | for epoch in range(start_epoch, config['epochs']+1): 69 | print('Epoch: ', epoch) 70 | # train for one epoch 71 | train_epoch_metrics = train(train_loader, model, criterion, optimizer, device, epoch, config, logbook, comet_exp=comet_experiment) 72 | comet.log_metrics(comet_experiment, train_epoch_metrics, epoch) 73 | # validation 74 | val_epoch_metrics, val_results = validate(val_loader, model, criterion, device, epoch, config, logbook, comet_exp=comet_experiment) 75 | comet.log_metrics(comet_experiment, val_epoch_metrics, epoch) 76 | # run step of scheduler 77 | scheduler.step(val_epoch_metrics[config['metric_to_monitor']]) 78 | # save validation answers for current epoch 79 | train_utils.save_results(val_results, epoch, config, path_logs) 80 | logbook.save_logbook(path_logs) 81 | # check early stopping condition 82 | early_stopping(val_epoch_metrics, config['metric_to_monitor'], model, optimizer, epoch) 83 | # if patience was reached, stop train loop 84 | if early_stopping.early_stop: 85 | print("Early stopping") 86 | break 87 | 88 | if __name__ == '__main__': 89 | main() -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
# Localized Questions in Medical Visual Question Answering

This is the official repository of the paper "Localized Questions in Medical Visual Question Answering" (MICCAI 2023). We also have a [Project Website](https://sergiotasconmorales.github.io/conferences/miccai2023.html).

**Are you attending MICCAI 2023 in Vancouver? Let's connect! Here is my [LinkedIn](https://www.linkedin.com/in/sergio-tascon/), or drop me an email at sergio.tasconmorales@unibe.ch.**

Our paper presents a method to answer questions about image regions by using localized attention, in which a target region can be given to the model so that the answer focuses on the user-defined region.

🔥 Repo updates
- [x] Data download
- [x] Training
- [x] Inference
- [x] Metrics plotting
- [x] Running the code in this repo to make sure everything works
- [ ] Make sure everything works after updating requirement versions (update done on 06.05.2025)

## Installing requirements
After cloning the repo, create a new environment with Python 3.9, activate it, and then install the required packages by running:

    pip install -r requirements.txt

---

## Data

You can access the datasets [here](https://zenodo.org/record/8192556). After downloading the data, decompress it in the repo folder and make sure it follows the structure below (i.e., rename the data folder to `data` so that you do not have to change the data paths in the config files):

**📂data**\
┣ **📂STS2017_v1**   # RIS dataset\
┣ **📂INSEGCAT_v1**   # INSEGCAT dataset\
┗ **📂DME_v1**   # DME dataset\

Each of the above dataset folders should contain two folders: `images` and `qa`. A third folder named `processed` is created during dataset class instantiation when you run the training script. I included this processed data too, so that you can reproduce our results more easily (if you want to generate the processed data again, set `process_qa_again` to `True` in the config files). The DME dataset also contains a folder named `answer_weights`, which contains the class weights for the answers. The other two datasets do not require this, since they are balanced.

---

## Config files

Please refer to the following table for the names of the config files that lead to the results of the different baselines. Note that in our paper we took the average of 5 models trained with different seeds, so if you train only once, do not expect to obtain the same results reported in the paper.

| **Baseline**   | **Config name**          |
|----------------|--------------------------|
| No mask        | config_nomask.yaml       |
| Region in text | config_regionintext.yaml |
| Crop region    | config_cropregion.yaml   |
| Draw region    | config_drawregion.yaml   |
| Ours           | config_ours.yaml         |

Note that the config files listed in the table above are available for each dataset in the `config` folder.

In the config files, do not forget to configure the paths according to your system.
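If you want to sanity-check a configuration before launching a run, the helper in `misc/io.py` parses the YAML file into a plain Python dictionary. The snippet below is a minimal sketch (it assumes you run it from the repository root so that `misc` is importable); note that `train.py` and `inference.py` always re-read the file passed via `--path_config`, so permanent changes must be made in the YAML file itself.

    from misc import io

    # parse the YAML config into a plain dict (same helper used by train.py)
    config = io.read_config('config/dme/config_ours.yaml')

    # inspect a few settings before training
    print(config['model'], config['loss'], config['batch_size'])

    # in-memory overrides like this only affect your own script,
    # not train.py, which reads the YAML file again at startup
    config['batch_size'] = 32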
57 | 58 | 59 | --- 60 | 61 | ## Training a model 62 | 63 | To train a model, run 64 | 65 | python train.py --path_config config/<dataset>/config_XX.yaml 66 | 67 | Where `<dataset>` can be one of (`dme`, `insegcat`, `sts2017`) and XX should be changed according to the table above (note that `sts2017` corresponds to the dataset called `RIS` in the paper). The model weights will be stored in the logs folder specified in the config file. Weights and optimizer parameters are saved both for the best and last version of the model. A file named `logbook.json` will contain the config parameters as well as the values of the learning curves. In the folder `answers` the answers are stored for each epoch. 68 | 69 | --- 70 | 71 | ## Inference 72 | 73 | To run inference, run 74 | 75 | python inference.py --path_config config/<dataset>/config_XX.yaml 76 | 77 | After inference, the metrics are printed for the validation and test sets. Also, the folder `answers` will contain the answer files for test and validation (`answers_epoch_val.pt` and `answers_epoch_test.pt`). 78 | 79 | --- 80 | 81 | ## Plotting results 82 | 83 | To plot the metrics, run 84 | 85 | python plot_metrics.py --path_config config/<dataset>/config_XX.yaml 86 | 87 | This will produce plots of the learning curves, as well as metrics for test and validation, in the logs folder specified in the YAML config file. 88 | 89 | ## Citation 90 | 91 | This work was carried out at the [AIMI Lab](https://www.artorg.unibe.ch/research/aimi/index_eng.html) of the [ARTORG Center for Biomedical Engineering Research](https://www.artorg.unibe.ch) of the [University of Bern](https://www.unibe.ch/index_eng.html). Please cite this work as: 92 | 93 | > @inproceedings{tascon2023localized,\ 94 | title={Localized Questions in Medical Visual Question Answering},\ 95 | author={Tascon-Morales, Sergio and M{\'a}rquez-Neila, Pablo and Sznitman, Raphael},\ 96 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},\ 97 | pages={--},\ 98 | year={2023},\ 99 | organization={Springer} 100 | } 101 | 102 | --- 103 | 104 | ## Acknowledgements 105 | 106 | This project was partially funded by the Swiss National Science Foundation through grant 191983. 107 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Inference script to get results and plots 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D.
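The `answers_epoch_*.pt` files mentioned above are plain `torch.save` dumps. For the binary datasets (STS2017/RIS and INSEGCAT), the loopers shown later in this listing store a dict with a `results` tensor of `(question_id, predicted_label)` rows and an `answers` tensor of `(target, probability)` rows. A small sketch for inspecting one such file, assuming that layout (the path is only an example):

```python
# Inspect a saved answers file from a binary experiment: 'results' holds
# (question_id, predicted_label) and 'answers' holds (target, probability).
import torch

path = 'logs/insegcat/config_ours/answers/answers_epoch_test.pt'  # example path
data = torch.load(path, map_location='cpu')
results, answers = data['results'], data['answers']

preds = (answers[:, 1] > 0.5).long()   # threshold probabilities at 0.5
targets = answers[:, 0].long()
acc = (preds == targets).float().mean().item()
print(f'{results.shape[0]} questions, accuracy {100 * acc:.2f}%')
```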
Student, ARTORG Center, University of Bern 7 | 8 | import comet_ml 9 | import torch 10 | import misc.io as io 11 | from metrics import metrics 12 | from misc import printer 13 | from core.datasets import loaders_factory 14 | from core.models import model_factory 15 | from core.train_vault import criteria, optimizers, train_utils, looper 16 | 17 | args = io.get_config_file_name() 18 | 19 | def main(): 20 | # read config file 21 | config = io.read_config(args.path_config) 22 | 23 | config['train_from'] = 'best' # set this parameter to best so that best model is loaded for validation part 24 | config['comet_ml'] = False 25 | 26 | if 'draw_regions' in config: 27 | draw_regions = config['draw_regions'] 28 | else: 29 | draw_regions = False 30 | 31 | device = torch.device('cuda' if torch.cuda.is_available() and config['cuda'] else 'cpu') 32 | 33 | # get loaders 34 | train_loader, vocab_words, vocab_answers, index_unk_answer = loaders_factory.get_vqa_loader('train', config, shuffle=True, draw_regions=draw_regions) 35 | val_loader = loaders_factory.get_vqa_loader('val', config, shuffle=False, draw_regions=draw_regions) 36 | test_loader = loaders_factory.get_vqa_loader('test', config, shuffle=False, draw_regions=draw_regions) 37 | 38 | # get model 39 | model = model_factory.get_vqa_model(config, vocab_words, vocab_answers) 40 | 41 | if 'weighted_loss' in config: 42 | if config['weighted_loss']: 43 | answer_weights = io.read_weights(config) # if use of weights is required, read them from folder where they were previously saved using compute_answer_weights scripts 44 | else: 45 | answer_weights = None # If false, just set variable to None 46 | else: 47 | answer_weights = None 48 | 49 | # create criterion 50 | criterion = criteria.get_criterion(config, device, ignore_index = index_unk_answer, weights=answer_weights) 51 | 52 | # create optimizer 53 | optimizer, scheduler = optimizers.get_optimizer(config, model, add_scheduler=True) 54 | 55 | # get best epoch 56 | best_epoch, _, _, _, path_logs = train_utils.initialize_experiment(config, model, optimizer, args.path_config, lower_is_better=True) 57 | 58 | # get validation function 59 | _, validate = looper.get_looper_functions(config) 60 | 61 | metrics_test, results_test = validate(test_loader, model, criterion, device, 0, config, None, comet_exp=None) 62 | print("Test set was evaluated for epoch", best_epoch-1, "which was the epoch with the lowest", config['metric_to_monitor'], "during training") 63 | print(metrics_test) 64 | train_utils.save_results(results_test, 'test', config, path_logs) # test results saved as epoch 0 65 | 66 | metrics_val, results_val = validate(val_loader, model, criterion, device, 0, config, None, comet_exp=None) 67 | print("Metrics after inference on the val set, best epoch") 68 | print(metrics_val) 69 | train_utils.save_results(results_val, 'val', config, path_logs) 70 | 71 | # adding code to get results for each type of question 72 | # get question types 73 | qa_data_test = test_loader.dataset.dataset_qa 74 | qa_data_val = val_loader.dataset.dataset_qa 75 | # i need to generate a dict, where each key is a question type and the value is a tensor where the first row is the answer and the second one is the probability 76 | # i need to do this for the test and val set 77 | # first, let's get a dict question_id to question_type 78 | question_id_to_type_test = {e['question_id']:e['question_type'] for e in qa_data_test} 79 | types_test = list(set(question_id_to_type_test.values())) 80 | type2idx_test = {t:i for i,t in 
enumerate(types_test)} 81 | question_id_to_type_val = {e['question_id']:e['question_type'] for e in qa_data_val} 82 | types_val = list(set(question_id_to_type_val.values())) 83 | type2idx_val = {t:i for i,t in enumerate(types_val)} 84 | # dicts to store answers to get auc, acc and ap from 85 | answers_group_test = {} 86 | answers_group_val = {} 87 | typeidx_test = torch.zeros(len(results_test['answers']), dtype=torch.long) 88 | typeidx_val = torch.zeros(len(results_val['answers']), dtype=torch.long) 89 | for i in range(typeidx_test.shape[0]): 90 | typeidx_test[i] = type2idx_test[question_id_to_type_test[results_test['results'][i,0].item()]] 91 | for i in range(typeidx_val.shape[0]): 92 | typeidx_val[i] = type2idx_val[question_id_to_type_val[results_val['results'][i,0].item()]] 93 | # now for each type, get metrics 94 | for k,v in type2idx_test.items(): 95 | answers_group_test[k] = results_test['answers'][typeidx_test==v] 96 | for k,v in type2idx_val.items(): 97 | answers_group_val[k] = results_val['answers'][typeidx_val==v] 98 | # now get metrics 99 | printer.print_line() 100 | # test 101 | for k,v in answers_group_test.items(): 102 | auc_test, ap_test = metrics.compute_auc_ap(v) 103 | print("AUC for test set, question type", k, "is", '{:.3f}'.format(auc_test)) 104 | print("AP for test set, question type", k, "is", '{:.3f}'.format(ap_test)) 105 | # val 106 | for k,v in answers_group_val.items(): 107 | auc_val, ap_val = metrics.compute_auc_ap(v) 108 | print("AUC for val set, question type", k, "is", '{:.2f}'.format(auc_val)) 109 | print("AP for val set, question type", k, "is", '{:.2f}'.format(ap_val)) 110 | 111 | printer.print_line() 112 | if not 'question_object' in qa_data_test[0]: return 113 | # get results for each question object 114 | # get object types 115 | question_id_to_qobject_test = {e['question_id']:e['question_object'] for e in qa_data_test} 116 | qobjects_test = list(set(question_id_to_qobject_test.values())) 117 | qobject2idx_test = {t:i for i,t in enumerate(qobjects_test)} 118 | answers_group_test = {} 119 | qobjectidx_test = torch.zeros(len(results_test['answers']), dtype=torch.long) 120 | for i in range(qobjectidx_test.shape[0]): 121 | qobjectidx_test[i] = qobject2idx_test[question_id_to_qobject_test[results_test['results'][i,0].item()]] 122 | for k,v in qobject2idx_test.items(): 123 | answers_group_test[k] = results_test['answers'][qobjectidx_test==v] 124 | for k,v in answers_group_test.items(): 125 | auc_test, ap_test = metrics.compute_auc_ap(v) 126 | print("AUC for test set, question object", k, "is", '{:.3f}'.format(auc_test)) 127 | print("AP for test set, question object", k, "is", '{:.3f}'.format(ap_test)) 128 | 129 | 130 | if __name__ == '__main__': 131 | main() -------------------------------------------------------------------------------- /plot_metrics.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Script for metrics plotting 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
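`metrics.compute_auc_ap` (defined in `metrics/metrics.py`, not shown in this excerpt) receives the `(target, probability)` tensors grouped above. A hypothetical sketch of such a function with scikit-learn, only to make the expected input explicit:

```python
# Hypothetical sketch of compute_auc_ap: column 0 holds the binary target,
# column 1 the predicted probability of the positive class.
from sklearn.metrics import average_precision_score, roc_auc_score

def compute_auc_ap(answers):
    targets = answers[:, 0].cpu().numpy()
    probs = answers[:, 1].cpu().numpy()
    return roc_auc_score(targets, probs), average_precision_score(targets, probs)
```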
Student, ARTORG Center, University of Bern 7 | 8 | from os.path import join as jp 9 | import misc.io as io 10 | import torch 11 | import json 12 | import pickle 13 | import os 14 | from plot import plotter 15 | from metrics import metrics 16 | from collections import Counter 17 | 18 | # read config name from CLI argument --path_config 19 | args = io.get_config_file_name() 20 | 21 | def main(): 22 | # read config file 23 | config = io.read_config(args.path_config) 24 | 25 | config_file_name = args.path_config.split("/")[-1].split(".")[0] 26 | 27 | path_logs = jp(config['logs_dir'], config['dataset'], config_file_name) 28 | 29 | # first, plot logged learning curves for all available metrics 30 | with open(jp(path_logs, 'logbook.json'), 'r') as f: 31 | logbook = json.load(f) 32 | 33 | general_info = logbook['general'] 34 | train_metrics = logbook['train'] 35 | val_metrics = logbook['val'] 36 | 37 | #* assumption: all reported train metrics were also reported for validation 38 | 39 | for (k_train, v_train), (k_val, v_val) in zip(train_metrics.items(), val_metrics.items()): 40 | assert k_train.split('_')[0] == k_val.split('_')[0] # check that metrics correspond 41 | metric_name = k_train.split('_')[0] 42 | 43 | plotter.plot_learning_curve(v_train, v_val, metric_name, title=general_info['config']['model'] + ' ' + config_file_name, save=True, path=path_logs) 44 | 45 | # if model is binary, plot ROC and PRC curves along with AUC and AP 46 | if config['num_answers'] == 2: 47 | # VAL 48 | # now go to answers folder and read info from there 49 | path_val_answers_file = jp(path_logs, 'answers', 'answers_epoch_val.pt') 50 | answers_best_val_epoch = torch.load(path_val_answers_file, map_location=torch.device('cpu')) # dictionary with keys: results, answers. results contains tensor with (question_index, model's answer), answers is (target, prob) 51 | 52 | auc_val, ap_val, roc_val, prc_val = metrics.compute_roc_prc(answers_best_val_epoch['answers']) 53 | plotter.plot_roc_prc(roc_val, auc_val, prc_val, ap_val, title='Validation plots', save=True, path=path_logs, suffix='val') 54 | 55 | # TEST 56 | # plot curves for test set, if it has been processed with inference.py 57 | path_test_answers_file = jp(path_logs, 'answers', 'answers_epoch_test.pt') 58 | if not os.path.exists(path_test_answers_file): 59 | raise Exception("Test set answers haven't been generated with inference.py") 60 | answers_test = torch.load(path_test_answers_file, map_location=torch.device('cpu')) 61 | 62 | auc_test, ap_test, roc_test, prc_test = metrics.compute_roc_prc(answers_test['answers']) 63 | plotter.plot_roc_prc(roc_test, auc_test, prc_test, ap_test, title='Test plots', save=True, path=path_logs, suffix='test') 64 | else: # for dme, compute accuracies for each type of question 65 | path_val_answers_file = jp(path_logs, 'answers', 'answers_epoch_val.pt') 66 | answers_best_val_epoch = torch.load(path_val_answers_file, map_location=torch.device('cpu')) # contains two columns, first one is question id and second one is the predicted answer 67 | id2pred = {answers_best_val_epoch[i,0].item(): answers_best_val_epoch[i,1].item() for i in range(answers_best_val_epoch.shape[0])} 68 | # open qa file to get question types 69 | path_qa_file_val = jp(config['path_data'], 'processed', 'valset.pickle') 70 | with open(path_qa_file_val, 'rb') as f: 71 | qa_val = pickle.load(f) 72 | # add prediction to qa_val 73 | for q in qa_val: 74 | q['prediction'] = id2pred[q['question_id']] 75 | # group questions by type 76 | types_counts = 
Counter([e['question_type'] for e in qa_val]).most_common() 77 | question_types = {e[0]:e[1] for e in types_counts} 78 | indexes_types = {e[0]:0 for e in types_counts} 79 | groups_type = {k:torch.zeros((v,2)) for k,v in question_types.items()} 80 | all_types = torch.zeros((len(qa_val),2)) 81 | # fill groups_type with targets and predictions 82 | for i, q in enumerate(qa_val): 83 | groups_type[q['question_type']][indexes_types[q['question_type']],0] = q['answer_index'] 84 | groups_type[q['question_type']][indexes_types[q['question_type']],1] = q['prediction'] 85 | indexes_types[q['question_type']] += 1 86 | all_types[i,0] = q['answer_index'] 87 | all_types[i,1] = q['prediction'] 88 | # compute accuracy for each type 89 | accuracies = {k:torch.eq(v[:,0], v[:,1]).sum()/v.shape[0] for k,v in groups_type.items()} 90 | # print accuracies 91 | print(config_file_name) 92 | print('Validation accuracies by type:') 93 | for k,v in accuracies.items(): 94 | print(f'{k}: {100*v:.2f}') 95 | print('Overall accuracy: {:.2f}'.format(100*torch.eq(all_types[:,0], all_types[:,1]).sum()/all_types.shape[0])) 96 | print('*'*50) 97 | 98 | # do exactly the same for test set 99 | path_test_answers_file = jp(path_logs, 'answers', 'answers_epoch_test.pt') 100 | if not os.path.exists(path_test_answers_file): 101 | raise Exception("Test set answers haven't been generated with inference.py") 102 | answers_test = torch.load(path_test_answers_file, map_location=torch.device('cpu')) 103 | id2pred = {answers_test[i,0].item(): answers_test[i,1].item() for i in range(answers_test.shape[0])} 104 | # open qa file to get question types 105 | path_qa_file_test = jp(config['path_data'], 'processed', 'testset.pickle') 106 | with open(path_qa_file_test, 'rb') as f: 107 | qa_test = pickle.load(f) 108 | # add prediction to qa_val 109 | for q in qa_test: 110 | q['prediction'] = id2pred[q['question_id']] 111 | # group questions by type 112 | types_counts = Counter([e['question_type'] for e in qa_test]).most_common() 113 | question_types = {e[0]:e[1] for e in types_counts} 114 | indexes_types = {e[0]:0 for e in types_counts} 115 | groups_type = {k:torch.zeros((v,2)) for k,v in question_types.items()} 116 | all_types = torch.zeros((len(qa_test),2)) 117 | # fill groups_type with targets and predictions 118 | for i, q in enumerate(qa_test): 119 | groups_type[q['question_type']][indexes_types[q['question_type']],0] = q['answer_index'] 120 | groups_type[q['question_type']][indexes_types[q['question_type']],1] = q['prediction'] 121 | indexes_types[q['question_type']] += 1 122 | all_types[i,0] = q['answer_index'] 123 | all_types[i,1] = q['prediction'] 124 | # compute accuracy for each type 125 | accuracies = {k:torch.eq(v[:,0], v[:,1]).sum()/v.shape[0] for k,v in groups_type.items()} 126 | # print accuracies 127 | print('Test accuracies by type:') 128 | for k,v in accuracies.items(): 129 | print(f'{k}: {100*v:.2f}') 130 | print('Overall accuracy: {:.2f}'.format(100*torch.eq(all_types[:,0], all_types[:,1]).sum()/all_types.shape[0])) 131 | print('*'*50) 132 | 133 | 134 | if __name__ == '__main__': 135 | main() -------------------------------------------------------------------------------- /core/train_vault/train_utils.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Training utilities 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
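The plotting script above assumes that `logbook.json` contains a `general` section (with the config) plus `train` and `val` sections holding the logged learning curves. A short sketch for inspecting a logbook manually (the path is an example):

```python
# Peek into a logbook.json produced during training.
import json

with open('logs/dme/config_ours/logbook.json', 'r') as f:  # example path
    logbook = json.load(f)

print('Model:', logbook['general']['config']['model'])
for split in ('train', 'val'):
    print(split, 'metrics logged:', list(logbook[split].keys()))
```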
Student, ARTORG Center, University of Bern 7 | 8 | from os.path import join as jp 9 | import numpy as np 10 | import shutil 11 | import torch 12 | import os 13 | from misc import dirs 14 | from . import comet, logbook 15 | 16 | 17 | def save_results(results, epoch_index, config, path_logs): 18 | # save tensor with indexes and answers produced by the model 19 | file_name = 'answers_epoch_' + str(epoch_index) + '.pt' 20 | path_answers = jp(path_logs, 'answers') 21 | dirs.create_folder(path_answers) 22 | torch.save(results, jp(path_answers, file_name)) 23 | return 24 | 25 | def sync_if_parallel(config): 26 | if config['data_parallel']: 27 | torch.cuda.synchronize() 28 | 29 | class EarlyStopping: 30 | """Early stops the training if validation loss doesn't improve after a given patience.""" 31 | def __init__(self, config, path_logs, lower_is_better=True): 32 | self.patience = config['patience'] 33 | self.path_log = path_logs 34 | self.verbose = True 35 | self.counter = 0 36 | self.lower_is_better = lower_is_better 37 | self.model_name = config['model'] 38 | self.best_score_new = None 39 | self.early_stop = False 40 | self.best_score_old = np.Inf 41 | self.data_parallel = config['data_parallel'] 42 | 43 | def __call__(self, metrics, metric_name, model, optimizer, epoch): 44 | 45 | score = metrics[metric_name] 46 | 47 | if self.best_score_new is None: 48 | self.best_score_new = score 49 | self.save_checkpoint(score, model, optimizer, self.path_log, metric_name, epoch) 50 | elif not self.lower_is_better and score <= self.best_score_new: 51 | self.counter += 1 52 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 53 | self.save_checkpoint(score, model, optimizer, self.path_log, metric_name, epoch, best=False) 54 | if self.counter >= self.patience: 55 | self.early_stop = True 56 | elif self.lower_is_better and score >= self.best_score_new: 57 | self.counter += 1 58 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 59 | self.save_checkpoint(score, model, optimizer, self.path_log, metric_name, epoch, best=False) 60 | if self.counter >= self.patience: 61 | self.early_stop = True 62 | else: 63 | self.best_score_new = score 64 | self.save_checkpoint(score, model, optimizer, self.path_log, metric_name, epoch) 65 | self.counter = 0 66 | 67 | def update_attributes(self, new_attributes): 68 | for k,v in new_attributes.items(): 69 | setattr(self, k, v) 70 | 71 | def save_checkpoint(self, score, model, optimizer, path_experiment, metric_name, epoch, best=True): 72 | '''Saves model when validation metric improves.''' 73 | if best: 74 | info_file_name = 'best_checkpoint_info.pt' 75 | model_file_name = 'best_checkpoint_model.pt' 76 | optimizer_file_name = 'best_checkpoint_optimizer.pt' 77 | early_stop_file_name = 'best_checkpoint_early_stop.pt' 78 | if self.verbose: 79 | print(f'Metric {metric_name} improved ({self.best_score_old:.4f} --> {self.best_score_new:.4f}). 
Saving model ...') 80 | else: 81 | info_file_name = 'last_checkpoint_info.pt' 82 | model_file_name = 'last_checkpoint_model.pt' 83 | optimizer_file_name = 'last_checkpoint_optimizer.pt' 84 | early_stop_file_name = 'last_checkpoint_early_stop.pt' 85 | 86 | # save info 87 | info = {'epoch': epoch, 'model': self.model_name, metric_name: score} 88 | torch.save(info, jp(path_experiment, info_file_name)) 89 | 90 | # save model parameters 91 | if not self.data_parallel: 92 | torch.save(model.state_dict(), jp(path_experiment, model_file_name)) 93 | else: 94 | torch.save(model.module.state_dict(), jp(path_experiment, model_file_name)) 95 | 96 | # save optimizer 97 | torch.save(optimizer.state_dict(), jp(path_experiment, optimizer_file_name)) 98 | 99 | # save parameters of early stop 100 | torch.save(vars(self), jp(path_experiment, early_stop_file_name)) 101 | 102 | # if it's best, make copy of pt files 103 | if best: 104 | shutil.copyfile(jp(path_experiment, info_file_name), jp(path_experiment, 'last_checkpoint_info.pt')) 105 | shutil.copyfile(jp(path_experiment, model_file_name), jp(path_experiment, 'last_checkpoint_model.pt')) 106 | shutil.copyfile(jp(path_experiment, optimizer_file_name), jp(path_experiment, 'last_checkpoint_optimizer.pt')) 107 | shutil.copyfile(jp(path_experiment, early_stop_file_name), jp(path_experiment, 'last_checkpoint_early_stop.pt')) 108 | 109 | self.best_score_old = self.best_score_new 110 | 111 | return 112 | 113 | 114 | def initialize_experiment(config, model, optimizer, path_config, lower_is_better=False, classif=False): 115 | path_logs = jp(config['logs_dir'], config['dataset'], path_config.split("/")[-1].split(".")[0]) 116 | if classif: 117 | path_logs = path_logs + '_classif' 118 | dirs.create_folder(path_logs) 119 | 120 | if config['train_from'] == 'scratch': # if training from scratch is desired 121 | dirs.clean_folder(path_logs) # remove old files if any 122 | # create comet ml experiment 123 | comet_experiment = comet.get_new_experiment(config, path_config) 124 | # instantiate early_stop object 125 | early_stopping = EarlyStopping(config, path_logs, lower_is_better=lower_is_better) # lower_is_better=False means metric to monitor E.S. 
should increase 126 | start_epoch = 1 127 | book = logbook.Logbook() 128 | 129 | elif config['train_from'] == 'last' or config['train_from'] == 'best': # from last saved checkpoint 130 | 131 | if config['comet_ml'] and config['experiment_key'] is None: 132 | raise ValueError("Please enter experiment key for comet experiment") 133 | 134 | # load info and model+optimizer parameters 135 | info = torch.load(jp(path_logs, config['train_from'] + '_checkpoint_info.pt')) 136 | if torch.cuda.is_available(): 137 | model_params = torch.load(jp(path_logs, config['train_from'] + '_checkpoint_model.pt')) 138 | optimizer_params = torch.load(jp(path_logs, config['train_from'] + '_checkpoint_optimizer.pt')) 139 | else: 140 | model_params = torch.load(jp(path_logs, config['train_from'] + '_checkpoint_model.pt'), map_location=torch.device('cpu')) 141 | optimizer_params = torch.load(jp(path_logs, config['train_from'] + '_checkpoint_optimizer.pt'), map_location=torch.device('cpu')) 142 | 143 | if not config['data_parallel']: 144 | model.load_state_dict(model_params) 145 | else: 146 | model.module.load_state_dict(model_params) 147 | optimizer.load_state_dict(optimizer_params) 148 | start_epoch = info['epoch'] + 1 149 | 150 | # resume comet experiment 151 | comet_experiment = comet.get_existing_experiment(config) 152 | 153 | # initialize early stopping 154 | early_stopping = EarlyStopping(config, path_logs, lower_is_better=lower_is_better) 155 | early_stopping_params = torch.load(jp(path_logs, config['train_from'] + '_checkpoint_early_stop.pt')) 156 | early_stopping.update_attributes(early_stopping_params) 157 | 158 | # create logbook 159 | book = logbook.Logbook() 160 | book.load_logbook(path_logs) 161 | 162 | else: 163 | raise ValueError("Wrong value for train_from option in config file. Options are best, last or scratch") 164 | 165 | return start_epoch, comet_experiment, early_stopping, book, path_logs -------------------------------------------------------------------------------- /core/train_vault/looper.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Train loops 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
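The checkpoint files written by `EarlyStopping.save_checkpoint` above can also be inspected or re-used outside of `initialize_experiment`. A minimal sketch, assuming the default file names and an example logs folder:

```python
# Inspect the best checkpoint saved by EarlyStopping. The info file stores the
# epoch, the model name and the monitored metric value.
import torch
from os.path import join as jp

path_logs = 'logs/sts2017/config_ours'   # example; use your own logs folder
info = torch.load(jp(path_logs, 'best_checkpoint_info.pt'), map_location='cpu')
state_dict = torch.load(jp(path_logs, 'best_checkpoint_model.pt'), map_location='cpu')

print('Best epoch:', info['epoch'], '| model:', info['model'])
print('Parameter tensors in state dict:', len(state_dict))
# To evaluate: build the model with model_factory.get_vqa_model(...) and call
# model.load_state_dict(state_dict).
```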
Student, ARTORG Center, University of Bern 7 | 8 | import torch 9 | from torch import nn 10 | from metrics import metrics 11 | 12 | def train(train_loader, model, criterion, optimizer, device, epoch, config, logbook, comet_exp=None, consistency_term=None, rels=None): 13 | # set train mode 14 | model.train() 15 | 16 | # Initialize variables for collecting metrics from all batches 17 | loss_epoch = 0.0 18 | running_loss = 0.0 19 | acc_epoch = 0.0 20 | 21 | for i, sample in enumerate(train_loader): 22 | 23 | # move data to GPU 24 | question = sample['question'].to(device) 25 | visual = sample['visual'].to(device) 26 | answer = sample['answer'].to(device) 27 | mask = sample['mask'].to(device) 28 | 29 | # clear parameter gradients 30 | optimizer.zero_grad() 31 | 32 | # get output from model 33 | output = model(visual, question, mask) 34 | 35 | # compute loss 36 | loss = criterion(output, answer) 37 | 38 | loss.backward() 39 | 40 | optimizer.step() 41 | 42 | 43 | running_loss += loss.item() 44 | if comet_exp is not None and i%10 == 9: # log every 10 iterations 45 | comet_exp.log_metric('loss_train_step', running_loss/10, step=len(train_loader)*(epoch-1) + i+1) 46 | running_loss = 0.0 47 | 48 | # compute accuracy 49 | acc = metrics.batch_strict_accuracy(output, answer) 50 | 51 | # laters: save to logger and print 52 | loss_epoch += loss.item() 53 | acc_epoch += acc.item() 54 | 55 | metrics_dict = {'loss_train': loss_epoch/len(train_loader.dataset), 'acc_train': acc_epoch/len(train_loader.dataset)} 56 | 57 | logbook.log_metrics('train', metrics_dict, epoch) 58 | 59 | return metrics_dict# returning average for all samples 60 | 61 | 62 | def validate(val_loader, model, criterion, device, epoch, config, logbook, comet_exp=None, consistency_term=None, rels=None): 63 | 64 | #if config['mainsub']: 65 | # denominator_acc = 2*len(val_loader.dataset) 66 | #else: 67 | denominator_acc = len(val_loader.dataset) 68 | 69 | # tensor to save results 70 | results = torch.zeros((denominator_acc, 2), dtype=torch.int64) 71 | 72 | # set evaluation mode 73 | model.eval() 74 | 75 | # Initialize variables for collecting metrics from all batches 76 | loss_epoch = 0.0 77 | acc_epoch = 0.0 78 | 79 | offset = 0 80 | with torch.no_grad(): 81 | for i, sample in enumerate(val_loader): 82 | batch_size = sample['question'].size(0) 83 | 84 | # move data to GPU 85 | question = sample['question'].to(device) 86 | visual = sample['visual'].to(device) 87 | answer = sample['answer'].to(device) 88 | question_indexes = sample['question_id'] # keep in cpu 89 | mask = sample['mask'].to(device) 90 | output = model(visual, question, mask) 91 | 92 | 93 | loss = criterion(output, answer) 94 | 95 | # compute accuracy 96 | acc = metrics.batch_strict_accuracy(output, answer) 97 | 98 | # save answer indexes and answers 99 | sm = nn.Softmax(dim=1) 100 | probs = sm(output) 101 | _, pred = probs.max(dim=1) 102 | 103 | results[offset:offset+batch_size,0] = question_indexes 104 | results[offset:offset+batch_size,1] = pred 105 | offset += batch_size 106 | 107 | loss_epoch += loss.item() 108 | acc_epoch += acc.item() 109 | 110 | metrics_dict = {'loss_val': loss_epoch/denominator_acc, 'acc_val': acc_epoch/denominator_acc} 111 | if logbook is not None: 112 | logbook.log_metrics('val', metrics_dict, epoch) 113 | 114 | return metrics_dict, results # returning averages for all samples 115 | 116 | 117 | 118 | # ------------------------------------------------------------------------------------------------ 119 | # functions for binary case 120 | # 
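`batch_strict_accuracy` and `batch_binary_accuracy` come from `metrics/metrics.py`, which is not part of this excerpt. Judging from how the loops above and below accumulate them and divide by the dataset size, they plausibly return the number of correct predictions in the batch; a hypothetical sketch:

```python
# Hypothetical sketches of the batch accuracy helpers used by the loopers;
# both return the count of correct predictions for the batch.
import torch

def batch_strict_accuracy(output, answer):
    # output: [B, num_answers] logits, answer: [B] ground-truth class indices
    return (output.argmax(dim=1) == answer).sum().float()

def batch_binary_accuracy(pred, answer):
    # pred: [B] hard 0/1 predictions, answer: [B] binary targets
    return (pred == answer.float()).sum().float()
```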
------------------------------------------------------------------------------------------------ 121 | 122 | 123 | def train_binary(train_loader, model, criterion, optimizer, device, epoch, config, logbook, comet_exp=None, consistency_term=None): 124 | # tensor to save results 125 | results = torch.zeros((len(train_loader.dataset), 2), dtype=torch.int64) # to store question id, model's answer 126 | answers = torch.zeros(len(train_loader.dataset), 2) # store target answer, prob 127 | 128 | # set train mode 129 | model.train() 130 | 131 | # Initialize variables for collecting metrics from all batches 132 | loss_epoch = 0.0 133 | acc_epoch = 0.0 134 | offset = 0 135 | for i, sample in enumerate(train_loader): 136 | batch_size = sample['question'].size(0) 137 | 138 | # move data to GPU 139 | question = sample['question'].to(device) 140 | visual = sample['visual'].to(device) 141 | answer = sample['answer'].to(device) 142 | question_indexes = sample['question_id'] # keep in cpu 143 | mask = sample['mask'].to(device) 144 | 145 | # clear parameter gradients 146 | optimizer.zero_grad() 147 | # get output from model 148 | output = model(visual, question, mask) 149 | 150 | # compute loss 151 | loss = criterion(output.squeeze_(dim=-1), answer.float()) # cast to float because of BCEWithLogitsLoss 152 | 153 | loss.backward() 154 | 155 | optimizer.step() 156 | 157 | # add running loss 158 | loss_epoch += loss.item() 159 | # save probs and answers 160 | m = nn.Sigmoid() 161 | pred = m(output.data.cpu()) 162 | # compute accuracy 163 | acc = metrics.batch_binary_accuracy((pred > 0.5).float().to(device), answer) 164 | results[offset:offset+batch_size,:] = torch.cat((question_indexes.view(batch_size, 1), torch.round(pred.view(batch_size,1))), dim=1) 165 | answers[offset:offset+batch_size] = torch.cat((answer.data.cpu().view(batch_size, 1), pred.view(batch_size,1)), dim=1) 166 | offset += batch_size 167 | acc_epoch += acc.item() 168 | 169 | # compute AUC and AP for current epoch 170 | auc, ap = metrics.compute_auc_ap(answers) 171 | metrics_dict = {'loss_train': loss_epoch/len(train_loader.dataset), 'auc_train': auc, 'ap_train': ap, 'acc_train': acc_epoch/len(train_loader.dataset)} 172 | logbook.log_metrics('train', metrics_dict, epoch) 173 | return metrics_dict 174 | 175 | 176 | def validate_binary(val_loader, model, criterion, device, epoch, config, logbook, comet_exp=None): 177 | # tensor to save results 178 | results = torch.zeros((len(val_loader.dataset), 2), dtype=torch.int64) # to store question id, model's answer 179 | answers = torch.zeros(len(val_loader.dataset), 2) # store target answer, prob 180 | 181 | # set evaluation mode 182 | model.eval() 183 | 184 | # Initialize variables for collecting metrics from all batches 185 | loss_epoch = 0.0 186 | acc_epoch = 0.0 187 | offset = 0 188 | with torch.no_grad(): 189 | for i, sample in enumerate(val_loader): 190 | batch_size = sample['question'].size(0) 191 | 192 | # move data to GPU 193 | question = sample['question'].to(device) 194 | visual = sample['visual'].to(device) 195 | answer = sample['answer'].to(device) 196 | question_indexes = sample['question_id'] # keep in cpu 197 | mask = sample['mask'].to(device) 198 | 199 | # get output from model 200 | output = model(visual, question, mask) 201 | 202 | # compute loss 203 | loss = criterion(output.squeeze_(dim=-1), answer.float()) 204 | 205 | # save probs and answers 206 | m = nn.Sigmoid() 207 | pred = m(output.data.cpu()) 208 | # compute accuracy 209 | acc = metrics.batch_binary_accuracy((pred > 
0.5).float().to(device), answer) 210 | results[offset:offset+batch_size,0] = question_indexes 211 | results[offset:offset+batch_size,1] = torch.round(pred) 212 | answers[offset:offset+batch_size] = torch.cat((answer.data.cpu().view(batch_size, 1), pred.view(batch_size,1)), dim=1) 213 | offset += batch_size 214 | 215 | loss_epoch += loss.item() 216 | acc_epoch += acc.item() 217 | 218 | # compute AUC and AP for current epoch for all samples, using info in results 219 | auc, ap = metrics.compute_auc_ap(answers) 220 | metrics_dict = {'loss_val': loss_epoch/len(val_loader.dataset), 'auc_val': auc, 'ap_val': ap, 'acc_val': acc_epoch/len(val_loader.dataset)} 221 | if logbook is not None: 222 | logbook.log_metrics('val', metrics_dict, epoch) 223 | return metrics_dict, {'results': results, 'answers': answers} 224 | 225 | 226 | def get_looper_functions(config): 227 | if config['num_answers'] == 2: 228 | train_fn = train_binary 229 | val_fn = validate_binary 230 | else: 231 | train_fn = train 232 | val_fn = validate 233 | return train_fn, val_fn -------------------------------------------------------------------------------- /core/datasets/nlp.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # NLP functions and classes 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | from collections import Counter 9 | import re 10 | from tqdm import tqdm 11 | from os.path import join as jp 12 | import itertools 13 | import json 14 | import torch 15 | import pickle 16 | import os 17 | 18 | def get_top_answers(answers, nans=2000): 19 | counts = Counter(answers).most_common() # get counts 20 | if len(set(answers)) == nans: # for binary case, return both answers 21 | return [e[0] for e in counts] # in this case return all answers, ordered from most to least frequent 22 | top_answers = [elem[0] for elem in counts[:nans]] 23 | return top_answers 24 | 25 | def clean_text(text): 26 | text = text.lower().replace("\n", " ").replace("\r", " ") 27 | # replace numbers and punctuation with space 28 | punc_list = '!"#$%&()*+,-./:;<=>?@[\]^_{|}~' # * Leave numbers because some questions use coordinates+ '0123456789' 29 | t = str.maketrans(dict.fromkeys(punc_list, " ")) 30 | text = text.translate(t) 31 | 32 | # replace single quote with empty character 33 | t = str.maketrans(dict.fromkeys("'`", "")) 34 | text = text.translate(t) 35 | 36 | # remove double spaces 37 | text = text.replace(" ", " ") 38 | 39 | return text 40 | 41 | 42 | def tokenizer_nltk(text, tokenizer): 43 | text = clean_text(text) 44 | tokens = tokenizer(text) 45 | return tokens 46 | 47 | def tokenizer_spacy(text, tokenizer): 48 | text = clean_text(text) 49 | tokens = list(tokenizer(text)) 50 | tokens_list_of_strings = [str(token) for token in tokens] 51 | return tokens_list_of_strings 52 | 53 | def tokenizer_re(text): 54 | WORD = re.compile(r'\w+') 55 | text = clean_text(text) 56 | tokens = WORD.findall(text) 57 | return tokens 58 | 59 | def add_tokens(qa_samples, tokenizer_name): 60 | """Function to add tokens to data 61 | 62 | Parameters 63 | ---------- 64 | qa_samples : _type_ 65 | _description_ 66 | tokenizer_name : _type_ 67 | _description_ 68 | 69 | Returns 70 | ------- 71 | list 72 | original list of samples with each sample having a new field for the tokens 73 | 74 | Raises 75 | ------ 76 | ValueError 77 | _description_ 78 | """ 79 | if tokenizer_name == 'nltk': 80 | from nltk import word_tokenize 81 | 
elif tokenizer_name == 'spacy': 82 | from spacy.tokenizer import Tokenizer 83 | from spacy.lang.en import English 84 | lang = English() 85 | tokenizer = Tokenizer(lang.vocab) 86 | 87 | 88 | for elem in tqdm(qa_samples): 89 | question_text = elem['question'] 90 | if tokenizer_name == 'nltk': 91 | elem['question_tokens'] = tokenizer_nltk(question_text, word_tokenize) 92 | elif tokenizer_name == 'spacy': 93 | elem['question_tokens'] = tokenizer_spacy(question_text, tokenizer) 94 | elif tokenizer_name == 're': 95 | elem['question_tokens'] = tokenizer_re(question_text) 96 | else: 97 | raise ValueError('Unknown tokenizer') 98 | 99 | return qa_samples 100 | 101 | 102 | def add_UNK_token_and_build_word_maps(data, min_word_frequency): 103 | # function to build vocabulary from question words and then build maps to indexes 104 | all_words_in_all_questions = list(itertools.chain.from_iterable(elem['question_tokens'] for elem in data)) 105 | # count and sort 106 | counts = Counter(all_words_in_all_questions).most_common() 107 | # get list of words (vocabulary for questions) 108 | vocab_words_in_questions = [elem[0] for elem in counts if elem[1] > min_word_frequency] 109 | # add_entry for tokens with UNK to data 110 | for elem in tqdm(data): 111 | elem['question_tokens_with_UNK'] = [w if w in vocab_words_in_questions else 'UNK' for w in elem['question_tokens']] 112 | # build maps 113 | vocab_words_in_questions.append('UNK') # Add UNK to the vocabulary 114 | map_word_index = {elem:i+1 for i,elem in enumerate(vocab_words_in_questions)} #* +1 to avoid same symbol of padding 115 | map_index_word = {v:k for k,v in map_word_index.items()} 116 | return data, map_word_index, map_index_word 117 | 118 | 119 | def add_UNK_token(data, vocab): 120 | for elem in tqdm(data): 121 | elem['question_tokens_with_UNK'] = [w if w in vocab else 'UNK' for w in elem['question_tokens']] 122 | return data 123 | 124 | 125 | 126 | def encode_questions(data, map_word_index, question_vector_length): 127 | for elem in tqdm(data): 128 | # add question length 129 | elem['question_length'] = min(question_vector_length, len(elem['question_tokens_with_UNK'])) 130 | elem['question_word_indexes'] = [0]*question_vector_length # create list with question_vector_length zeros 131 | for i, word in enumerate(elem['question_tokens_with_UNK']): 132 | if i < question_vector_length: 133 | # using padding to the right. Add padding left? 134 | elem['question_word_indexes'][i] = map_word_index[word] # replace word with index in vocabulary 135 | return data 136 | 137 | 138 | 139 | def encode_answers(data, map_answer_index): 140 | # function to encode answers. 
If they are not in the answer vocab, they are mapped to -1 141 | if 'answers_occurence' in data[0]: # if there are multiple answers (VQA2 dataset) 142 | for i, elem in enumerate(data): 143 | answers = [] 144 | answers_indexes = [] 145 | answers_count = [] 146 | unknown_answer_symbol = map_answer_index['UNK'] 147 | elem['answer_index'] = map_answer_index.get(elem['answer'], unknown_answer_symbol) # unknown_answer_symbol for unknown answers 148 | for answer in elem['answers_occurence']: 149 | answer_index = map_answer_index.get(answer[0], unknown_answer_symbol) 150 | #if answer_index != unknown_answer_symbol: 151 | answers += answer[1]*[answer[0]] # add all answers 152 | answers_indexes += answer[1]*[answer_index] 153 | answers_count.append(answer[1]) 154 | elem['answers'] = answers 155 | elem['answers_indexes'] = answers_indexes 156 | elem['answers_counts'] = answers_count 157 | else: 158 | for elem in tqdm(data): 159 | unknown_answer_symbol = map_answer_index['UNK'] 160 | elem['answer_index'] = map_answer_index.get(elem['answer'], unknown_answer_symbol) # unknown_answer_symbol for unknown answers 161 | return data 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | def process_qa(config, data_train, data_val, data_test, data_testdev = None, alt_questions = False): 170 | 171 | if alt_questions: 172 | max_question_len = config['max_question_length_alt'] 173 | else: 174 | max_question_len = config['max_question_length'] 175 | 176 | # function to process questions and answers using functions from nlp.py This function can be used on other datasets 177 | all_answers = [elem['answer'] for elem in data_train] 178 | 179 | # get top answers 180 | top_answers = get_top_answers(all_answers, config['num_answers']) 181 | 182 | # get maps for answers 183 | top_answers.append('UNK') # add unknown symbol answer 184 | map_answer_index = {elem:i for i, elem in enumerate(top_answers)} 185 | map_index_answer = top_answers.copy() 186 | 187 | # remove examples for which answer is not in top answers 188 | # data_train = nlp.remove_examples_if_answer_not_common(data_train, top_answers) 189 | 190 | # tokenize questions for each subset 191 | print('Tokenizing questions...') 192 | data_train = add_tokens(data_train, config['tokenizer']) 193 | data_val = add_tokens(data_val, config['tokenizer']) 194 | data_test = add_tokens(data_test, config['tokenizer']) 195 | if data_testdev is not None: 196 | data_testdev = add_tokens(data_testdev, config['tokenizer']) 197 | 198 | # insert UNK tokens and build word maps 199 | print("Adding UNK tokens...") 200 | data_train, map_word_index, map_index_word = add_UNK_token_and_build_word_maps(data_train, config['min_word_frequency']) 201 | words_vocab_list = list(map_index_word.values()) 202 | data_val = add_UNK_token(data_val, words_vocab_list) 203 | data_test = add_UNK_token(data_test, words_vocab_list) 204 | if data_testdev is not None: 205 | data_testdev = add_UNK_token(data_testdev, words_vocab_list) 206 | 207 | # encode questions 208 | print("Encoding questions...") 209 | data_train = encode_questions(data_train, map_word_index, max_question_len) 210 | data_val = encode_questions(data_val, map_word_index, max_question_len) 211 | data_test = encode_questions(data_test, map_word_index, max_question_len) 212 | if data_testdev is not None: 213 | data_testdev = encode_questions(data_testdev, map_word_index, max_question_len) 214 | 215 | # encode answers 216 | print("Encoding answers...") 217 | data_train = encode_answers(data_train, map_answer_index) 218 | data_val = 
encode_answers(data_val, map_answer_index) 219 | if 'answer' in data_test[0]: # if test set has answers 220 | data_test = encode_answers(data_test, map_answer_index) 221 | 222 | # build return dictionaries 223 | if data_testdev is not None: 224 | sets = {'trainset': data_train, 'valset': data_val, 'testset': data_test, 'testdevset': data_testdev} 225 | else: 226 | sets = {'trainset': data_train, 'valset': data_val, 'testset': data_test} 227 | maps = {'map_index_word': map_index_word, 'map_word_index': map_word_index, 'map_index_answer': map_index_answer, 'map_answer_index': map_answer_index} 228 | 229 | # sets: {'trainset': trainset, 'valset': valset, 'testset': testset, 'testdevset': testdevset} 230 | # maps: {'map_index_word': map_index_word, 'map_word_index': map_word_index, 'map_index_answer': map_index_answer, 'map_answer_index': map_answer_index} 231 | return sets, maps -------------------------------------------------------------------------------- /core/datasets/vqa.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # VQA dataset classes 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | 9 | import pickle 10 | import os 11 | import json 12 | import random 13 | from os.path import join as jp 14 | import torch 15 | from tqdm import tqdm 16 | import numpy as np 17 | import pandas as pd 18 | from PIL import Image, ImageDraw 19 | import torchvision.transforms as T 20 | from torch.utils.data import Dataset 21 | import torchvision.transforms.functional as TF 22 | from copy import deepcopy 23 | 24 | from misc import io 25 | from . import nlp 26 | from . import visual as vis 27 | 28 | 29 | class VQABase(Dataset): 30 | def __init__(self, subset, config, dataset_visual): 31 | self.subset = subset 32 | self.config = config 33 | self.dataset_visual = dataset_visual 34 | self.path_annotations_and_questions = jp(config['path_data'], 'qa') 35 | self.path_processed = jp(config['path_data'], 'processed') 36 | if 'mask_as_text' in config: 37 | self.mask_as_text = config['mask_as_text'] 38 | else: 39 | self.mask_as_text = False 40 | if not os.path.exists(self.path_processed) or len(os.listdir(self.path_processed))<1 or (subset == 'train' and config['process_qa_again']): 41 | self.pre_process_qa() # pre-process qa, produce pickle files 42 | 43 | # load pre-processed qa 44 | self.read_prep_rocessed(self.path_processed) 45 | 46 | def pre_process_qa(self): 47 | raise NotImplementedError # to be implemented in baby class 48 | 49 | def read_prep_rocessed(self, path_files): 50 | # define paths 51 | path_map_index_word = jp(path_files, 'map_index_word.pickle') 52 | path_map_word_index = jp(path_files, 'map_word_index.pickle') 53 | path_map_index_answer = jp(path_files, 'map_index_answer.pickle') 54 | path_map_answer_index = jp(path_files, 'map_answer_index.pickle') 55 | path_dataset = jp(path_files, self.subset + 'set.pickle') 56 | 57 | # read files 58 | with open(path_map_index_word, 'rb') as f: 59 | self.map_index_word = pickle.load(f) 60 | with open(path_map_word_index, 'rb') as f: 61 | self.map_word_index = pickle.load(f) 62 | with open(path_map_index_answer, 'rb') as f: 63 | self.map_index_answer = pickle.load(f) 64 | with open(path_map_answer_index, 'rb') as f: 65 | self.map_answer_index = pickle.load(f) 66 | with open(path_dataset, 'rb') as f: 67 | self.dataset_qa = pickle.load(f) 68 | 69 | # save unknown answer index 70 | self.index_unknown_answer = 
self.map_answer_index['UNK'] 71 | 72 | def __getitem__(self, index): 73 | sample = {} 74 | 75 | # get qa pair 76 | item_qa = self.dataset_qa[index] 77 | 78 | # get visual 79 | sample['visual'] = self.dataset_visual.get_by_name(item_qa['image_name'])['visual'] 80 | 81 | # get question 82 | sample['question_id'] = item_qa['question_id'] 83 | sample['question'] = torch.LongTensor(item_qa['question_word_indexes']) 84 | 85 | # get answer 86 | sample['answer'] = item_qa['answer_index'] 87 | if 'answer_indexes' in item_qa: # trick so that this class can be used with non-vqa2 data 88 | sample['answers'] = item_qa['answers_indexes'] 89 | 90 | return sample 91 | 92 | def __len__(self): 93 | return len(self.dataset_qa) 94 | 95 | 96 | class VQARegionsSingle(VQABase): 97 | """Class for dataloader that contains questions about a single region 98 | 99 | Parameters 100 | ---------- 101 | VQABase : Parent class 102 | Base class for VQA dataset. 103 | """ 104 | def __init__(self, subset, config, dataset_visual, draw_regions=False): 105 | super().__init__(subset, config, dataset_visual) 106 | self.augment = config['augment'] 107 | self.draw_regions = draw_regions 108 | 109 | def transform(self, image, mask, size): 110 | 111 | if self.subset == 'train': # only for training samples 112 | 113 | # Random horizontal flipping 114 | if random.random() > 0.5: 115 | image = TF.hflip(image) 116 | mask = TF.hflip(mask) 117 | 118 | ## random rotation in small range 119 | #if random.random() > 0.5: 120 | # angle = random.randint(-10, 10) 121 | # image = TF.rotate(image, angle) 122 | # mask = TF.rotate(mask, angle) 123 | 124 | # Transform to tensor 125 | if not torch.is_tensor(image): 126 | image = TF.to_tensor(image) 127 | if not torch.is_tensor(mask): 128 | mask = TF.to_tensor(mask) 129 | return image, mask 130 | 131 | def get_mask(self, mask_coords, mask_size): 132 | # mask_coords has the format ((y,x), h, w) 133 | if self.config['dataset'] == 'dme': # requires ellipse regions 134 | mask_ref = Image.new('L', mask_size, 0) 135 | mask = ImageDraw.Draw(mask_ref) 136 | mask.ellipse([(mask_coords[0][1], mask_coords[0][0]),(mask_coords[0][1] + mask_coords[2], mask_coords[0][0] + mask_coords[1])], fill=1) 137 | mask = torch.from_numpy(np.array(mask_ref)) 138 | else: 139 | mask = torch.zeros(mask_size, dtype=torch.uint8) 140 | mask[mask_coords[0][0]:mask_coords[0][0]+mask_coords[1] , mask_coords[0][1]:mask_coords[0][1]+mask_coords[2]] = 1 141 | return mask.unsqueeze_(0) 142 | 143 | def draw_region(self, img, coords, r=2): 144 | if self.config['dataset'] == 'dme': # requires ellipse regions 145 | img_ref = T.ToPILImage()(img) 146 | ((y,x), h, w) = coords 147 | draw = ImageDraw.Draw(img_ref) 148 | draw.ellipse([(x, y),(x + w, y + h)], outline='red') 149 | img_ref = np.array(img_ref) 150 | img_ref = img_ref.transpose(2,0,1) 151 | img_ref = torch.from_numpy(img_ref) 152 | return img_ref 153 | else: 154 | ((y,x), h, w) = coords 155 | 156 | for i in range(3): 157 | img[i, y-r:y+h+r, x-r:x+r] = 0 158 | img[i, y-r:y+r, x-r:x+w+r] = 0 159 | img[i, y-r:y+h+r, x+w-r:x+w+r] = 0 160 | img[i, y+h-r:y+h+r, x-r:x+w+r] = 0 161 | 162 | # set red channel line to red 163 | img[0, y-r:y+h+r, x-r:x+r] = 1 164 | img[0, y-r:y+r, x-r:x+w+r] = 1 165 | img[0, y-r:y+h+r, x+w-r:x+w+r] = 1 166 | img[0, y+h-r:y+h+r, x-r:x+w+r] = 1 167 | return img 168 | 169 | def get_by_question_id(self, question_id): 170 | for i in range(len(self.dataset_qa)): 171 | if self.dataset_qa[i]['question_id'] == question_id: 172 | return self.__getitem__(i) 173 | 174 | def 
regenerate(self, question_ids): 175 | # reduce self.dataset_qa to only contain question_ids 176 | temp = [item for item in self.dataset_qa if item['question_id'] in question_ids] 177 | self.dataset_qa = temp 178 | 179 | # override getitem method 180 | def __getitem__(self, index): 181 | sample = {} 182 | 183 | # get qa pair 184 | item_qa = self.dataset_qa[index] 185 | 186 | # get visual 187 | visual = self.dataset_visual.get_by_name(item_qa['image_name'])['visual'] 188 | if self.draw_regions: 189 | # first, apply inverse transform to get original image 190 | visual = vis.default_inverse_transform()(visual) 191 | visual = self.draw_region(visual, item_qa['mask_coords']) 192 | visual = vis.default_transform(self.config['size'])(T.ToPILImage()(visual)) 193 | mask = self.get_mask(item_qa['mask_coords'], item_qa['mask_size']) 194 | 195 | if self.augment: 196 | sample['visual'], sample['mask'] = self.transform(visual, mask, 448) 197 | else: 198 | sample['visual'] = visual 199 | sample['mask'] = mask 200 | 201 | # get question 202 | sample['question_id'] = item_qa['question_id'] 203 | 204 | # if mask should be included in the questions 205 | 206 | sample['question'] = torch.LongTensor(item_qa['question_word_indexes']) 207 | 208 | # get answer 209 | sample['answer'] = item_qa['answer_index'] 210 | 211 | return sample 212 | 213 | # define preprocessing method for qa pairs 214 | def pre_process_qa(self): 215 | 216 | # define paths to save pickle files. Have to process all of them at the same time because the train set determines possible answers and vocabularies 217 | data_train = json.load(open(jp(self.path_annotations_and_questions, 'train_qa.json'), 'r')) 218 | data_val = json.load(open(jp(self.path_annotations_and_questions, 'val_qa.json'), 'r')) 219 | data_test = json.load(open(jp(self.path_annotations_and_questions, 'test_qa.json'), 'r')) 220 | 221 | if self.mask_as_text: 222 | # exchange question_alt to question 223 | for data in tqdm([data_train, data_val, data_test], desc='mask_as_text is set to True, therefore alt questions are used'): 224 | for item in data: 225 | item['question'], item['question_alt'] = item['question_alt'], item['question'] 226 | 227 | sets, maps = nlp.process_qa(self.config, data_train, data_val, data_test, alt_questions=self.mask_as_text) 228 | 229 | # define paths to save pickle files 230 | if not os.path.exists(self.path_processed): 231 | os.mkdir(self.path_processed) 232 | for name, data in sets.items(): 233 | io.save_pickle(data, jp(self.path_processed, name + '.pickle')) 234 | for name, data in maps.items(): 235 | io.save_pickle(data, jp(self.path_processed, name + '.pickle')) 236 | 237 | 238 | 239 | def get_vqa_dataset(subset, config, dataset_visual, draw_regions=False): 240 | # provides dataset class for current training config 241 | if config['dataset'] in ['cholec', 'sts2017', 'insegcat', 'dme']: 242 | dataset_vqa = VQARegionsSingle(subset, config, dataset_visual, draw_regions=draw_regions) 243 | elif config['dataset'] == 'IDRID': 244 | raise NotImplementedError 245 | else: 246 | dataset_vqa = VQABase(subset, config, dataset_visual) 247 | 248 | return dataset_vqa -------------------------------------------------------------------------------- /plot/visualize_att.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Guided attention visualization 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
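Throughout the dataset classes above (and the attention plots below), a region is described by `mask_coords = ((y, x), h, w)` in the resized image space and converted to a binary mask. A small self-contained example of the rectangular case, mirroring the non-DME branch of `VQARegionsSingle.get_mask` (the DME dataset uses ellipses instead); the coordinates are made up for illustration:

```python
# Build a rectangular region mask from mask_coords = ((y, x), h, w).
import torch

mask_coords = ((100, 150), 60, 80)   # example: top-left y=100, x=150, 60x80 px
mask_size = (448, 448)               # images are resized to 448x448

(y, x), h, w = mask_coords
mask = torch.zeros(mask_size, dtype=torch.uint8)
mask[y:y + h, x:x + w] = 1
mask = mask.unsqueeze(0)             # add channel dim -> [1, 448, 448]
print(mask.shape, 'region pixels:', int(mask.sum()))
```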
Student, ARTORG Center, University of Bern 7 | 8 | from matplotlib import pyplot as plt 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from os.path import join as jp 13 | from core.datasets.visual import default_inverse_transform as dit 14 | import cv2 15 | 16 | m = nn.Sigmoid() 17 | sm = nn.Softmax(dim=2) 18 | 19 | att = {} 20 | def get_att_map(name): 21 | def hook(model, input, output): 22 | att[name] = output.detach() 23 | return hook 24 | 25 | def plot_attention_maps(model_name, model, visual, question, mask, answer, vocab_words, path_logs, question_indexes, vocab_answers): 26 | # if data parallel, get the model 27 | if model.__class__.__name__ == 'DataParallel': 28 | model = model.module 29 | if model_name == 'VQA_MaskRegion': 30 | model.attention_mechanism.conv2.register_forward_hook(get_att_map('attention_mechanism.conv2')) 31 | output = model(visual, question, mask) 32 | pred = (m(output.data.cpu())>0.5).to(torch.int64) 33 | k = att['attention_mechanism.conv2'] # size 64, 2, 14, 14 for sts 34 | h = k.clone() 35 | h = h.view(output.shape[0],2,14*14) 36 | h_out = sm(h) 37 | g_out = h_out.view(output.shape[0],2,14,14) 38 | for i_s in range(g_out.shape[0]): # for every element of the batch 39 | image = dit()(visual[i_s]).permute(1,2,0).cpu().numpy() 40 | plt.ioff() 41 | f, ax = plt.subplots(1, 3) 42 | f.tight_layout() 43 | ax[0].imshow(image) 44 | ax[0].axis('off') 45 | question_words_encoded = [vocab_words[question[i_s, i].item()] for i in range(question.shape[1]) if question[i_s, i].item()!= 0] 46 | question_text = ' '.join(question_words_encoded) 47 | f.suptitle(question_text + "\n GT: " + str(vocab_answers[answer[i_s].item()]) + ', Pred: ' + str(vocab_answers[pred[i_s].item()])) 48 | ax[0].set_title('Image') 49 | #ax[0].set_title(question_text + "\n, GT: " + str(vocab_answers[answer[i_s].item()])) 50 | if pred[i_s].item() == answer[i_s].item(): 51 | f.set_facecolor("green") 52 | else: 53 | f.set_facecolor("r") 54 | for i_glimpse in range(g_out.shape[1]): # for every glimpse 55 | img1 = g_out[i_s, i_glimpse, :, :].cpu().numpy() 56 | heatmap = cv2.resize(img1, (image.shape[1], image.shape[0])) 57 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap) - np.min(heatmap)) 58 | heatmap = np.uint8(255*heatmap) 59 | #heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) 60 | norm = plt.Normalize() 61 | heatmap = plt.cm.jet(norm(heatmap)) 62 | superimposed = heatmap[:,:,:3] * 0.4 + image*mask[i_s].permute(1,2,0).cpu().numpy() 63 | superimposed = 255*(superimposed - np.min(superimposed))/(np.max(superimposed) - np.min(superimposed)) 64 | ax[i_glimpse+1].imshow(superimposed.astype(np.uint8)) 65 | ax[i_glimpse+1].axis('off') 66 | ax[i_glimpse+1].set_title("Glimpse " + str(i_glimpse+1)) 67 | plt.savefig(jp(path_logs, 'att_maps', str(question_indexes[i_s].item()) + '.png') ,bbox_inches='tight') 68 | plt.close() 69 | elif model_name == 'vQA_IgnoreMask': 70 | pass 71 | elif model_name == 'VQA_LocalizedAttention': 72 | model.attention_mechanism.conv2.register_forward_hook(get_att_map('attention_mechanism.conv2')) 73 | output = model(visual, question, mask) 74 | pred = (m(output.data.cpu())>0.5).to(torch.int64) 75 | k = att['attention_mechanism.conv2'] # size 64, 2, 14, 14 for sts 76 | h = k.clone() 77 | h = h.view(output.shape[0],2,14*14) 78 | h_out = sm(h) 79 | g_out = h_out.view(output.shape[0],2,14,14) 80 | for i_s in range(g_out.shape[0]): # for every element of the batch 81 | image = dit()(visual[i_s]).permute(1,2,0).cpu().numpy() 82 | f, ax = plt.subplots(1, 3) 83 
| f.tight_layout() 84 | ax[0].imshow(image) 85 | if not np.count_nonzero(mask[i_s].cpu().numpy()) == mask[i_s].shape[-1]*mask[i_s].shape[-2]: 86 | masked = np.ma.masked_where(mask[i_s].permute(1,2,0).cpu().numpy() ==0, mask[i_s].permute(1,2,0).cpu().numpy()) 87 | ax[0].imshow(masked, 'jet', interpolation='none', alpha=0.5) 88 | ax[0].axis('off') 89 | question_words_encoded = [vocab_words[question[i_s, i].item()] for i in range(question.shape[1]) if question[i_s, i].item()!= 0] 90 | question_text = ' '.join(question_words_encoded) 91 | f.suptitle(question_text + "\n GT: " + str(vocab_answers[answer[i_s].item()]) + ', Pred: ' + str(vocab_answers[pred[i_s].item()])) 92 | ax[0].set_title('Image') 93 | #ax[0].set_title(question_text + "\n, GT: " + str(vocab_answers[answer[i_s].item()])) 94 | if pred[i_s].item() == answer[i_s].item(): 95 | f.set_facecolor("green") 96 | else: 97 | f.set_facecolor("r") 98 | for i_glimpse in range(g_out.shape[1]): # for every glimpse 99 | img1 = g_out[i_s, i_glimpse, :, :].cpu().numpy() 100 | heatmap = cv2.resize(img1, (image.shape[1], image.shape[0])) 101 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap) - np.min(heatmap)) 102 | heatmap = np.uint8(255*heatmap) 103 | #heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) 104 | norm = plt.Normalize() 105 | heatmap = plt.cm.jet(norm(heatmap)) 106 | superimposed = heatmap[:,:,:3] * 0.4 + image 107 | superimposed = 255*(superimposed - np.min(superimposed))/(np.max(superimposed) - np.min(superimposed)) 108 | ax[i_glimpse+1].imshow(superimposed.astype(np.uint8)) 109 | ax[i_glimpse+1].axis('off') 110 | ax[i_glimpse+1].set_title("Glimpse " + str(i_glimpse+1)) 111 | plt.savefig(jp(path_logs, 'att_maps', str(question_indexes[i_s].item()) + '.png') ,bbox_inches='tight') 112 | plt.close() 113 | elif model_name == 'VQARS_4': 114 | pass 115 | else: 116 | raise ValueError('Model not supported') 117 | 118 | def plot_attention_single(model_name, model, visual, question, mask, path_output, question_id, case, idx2ans, mask_as_text=False, tag='404'): 119 | if model.__class__.__name__ == 'DataParallel': 120 | model = model.module 121 | if model_name == 'VQA_MaskRegion': 122 | model.attention_mechanism.conv2.register_forward_hook(get_att_map('attention_mechanism.conv2')) 123 | output = model(visual, question, mask) 124 | pred = (m(output.data.cpu())>0.5).to(torch.int64) 125 | #pred = torch.argmax(output, dim=1) 126 | k = att['attention_mechanism.conv2'] # size 64, 2, 14, 14 for sts 127 | h = k.clone() 128 | h = h.view(output.shape[0],2,14*14) 129 | h_out = sm(h) 130 | g_out = h_out.view(output.shape[0],2,14,14) 131 | for i_s in range(g_out.shape[0]): # for every element of the batch 132 | image = dit()(visual[i_s]).permute(1,2,0).cpu().numpy() 133 | answer = idx2ans[pred[i_s].item()] 134 | for i_glimpse in range(g_out.shape[1]): # for every glimpse 135 | plt.ioff() 136 | img1 = g_out[i_s, i_glimpse, :, :].cpu().numpy() 137 | heatmap = cv2.resize(img1, (image.shape[1], image.shape[0])) 138 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap) - np.min(heatmap)) 139 | heatmap = np.uint8(255*heatmap) 140 | norm = plt.Normalize() 141 | heatmap = plt.cm.jet(norm(heatmap)) 142 | superimposed = heatmap[:,:,:3] * 0.4 + image*mask[i_s].permute(1,2,0).cpu().numpy() 143 | superimposed = 255*(superimposed - np.min(superimposed))/(np.max(superimposed) - np.min(superimposed)) 144 | # save image 145 | if i_glimpse == 0: # only first glimpse 146 | plt.imsave(jp(path_output, tag, case, str(question_id[i_s].item()) + '_' + model_name + '_g' + 
str(i_glimpse) + '_'+ str(answer) +'.png') ,superimposed.astype(np.uint8)) 147 | elif model_name == 'VQA_LocalizedAttention': 148 | model.attention_mechanism.conv2.register_forward_hook(get_att_map('attention_mechanism.conv2')) 149 | output = model(visual, question, mask) 150 | pred = (m(output.data.cpu())>0.5).to(torch.int64) 151 | #pred = torch.argmax(output, dim=1) 152 | k = att['attention_mechanism.conv2'] # size 64, 2, 14, 14 for sts 153 | h = k.clone() 154 | h = h.view(output.shape[0],2,14*14) 155 | h_out = sm(h) 156 | g_out = h_out.view(output.shape[0],2,14,14) 157 | for i_s in range(g_out.shape[0]): # for every element of the batch 158 | image = dit()(visual[i_s]).permute(1,2,0).cpu().numpy() 159 | answer = idx2ans[pred[i_s].item()] 160 | for i_glimpse in range(g_out.shape[1]): # for every glimpse 161 | plt.ioff() 162 | img1 = g_out[i_s, i_glimpse, :, :].cpu().numpy() 163 | heatmap = cv2.resize(img1, (image.shape[1], image.shape[0])) 164 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap) - np.min(heatmap)) 165 | heatmap = np.uint8(255*heatmap) 166 | norm = plt.Normalize() 167 | heatmap = plt.cm.jet(norm(heatmap)) 168 | superimposed = heatmap[:,:,:3]*mask[i_s].permute(1,2,0).cpu().numpy() * 0.4 + image 169 | superimposed = 255*(superimposed - np.min(superimposed))/(np.max(superimposed) - np.min(superimposed)) 170 | # save image 171 | if i_glimpse == 0: # only first glimpse 172 | plt.imsave(jp(path_output,tag, case, str(question_id[i_s].item()) + '_' + model_name + '_g' + str(i_glimpse) + '_'+ str(answer) +'.png') ,superimposed.astype(np.uint8)) 173 | elif model_name == 'VQA_IgnoreMask' or (model_name == 'VQA_Base'): 174 | model.attention_mechanism.conv2.register_forward_hook(get_att_map('attention_mechanism.conv2')) 175 | output = model(visual, question, mask) 176 | pred = (m(output.data.cpu())>0.5).to(torch.int64) 177 | #pred = torch.argmax(output, dim=1) 178 | k = att['attention_mechanism.conv2'] # size 64, 2, 14, 14 for sts 179 | h = k.clone() 180 | h = h.view(output.shape[0],2,14*14) 181 | h_out = sm(h) 182 | g_out = h_out.view(output.shape[0],2,14,14) 183 | for i_s in range(g_out.shape[0]): # for every element of the batch 184 | image = dit()(visual[i_s]).permute(1,2,0).cpu().numpy() 185 | answer = idx2ans[pred[i_s].item()] 186 | for i_glimpse in range(g_out.shape[1]): # for every glimpse 187 | plt.ioff() 188 | img1 = g_out[i_s, i_glimpse, :, :].cpu().numpy() 189 | heatmap = cv2.resize(img1, (image.shape[1], image.shape[0])) 190 | heatmap = (heatmap - np.min(heatmap))/(np.max(heatmap) - np.min(heatmap)) 191 | heatmap = np.uint8(255*heatmap) 192 | norm = plt.Normalize() 193 | heatmap = plt.cm.jet(norm(heatmap)) 194 | superimposed = heatmap[:,:,:3] * 0.4 + image 195 | superimposed = 255*(superimposed - np.min(superimposed))/(np.max(superimposed) - np.min(superimposed)) 196 | # save image 197 | if i_glimpse == 0: # only first glimpse 198 | plt.imsave(jp(path_output, tag, case, str(question_id[i_s].item()) + '_' + model_name + '_g' + str(i_glimpse) + '_'+ str(answer) +'.png') ,superimposed.astype(np.uint8)) 199 | 200 | elif model_name == 'VQA_Base' and not mask_as_text: 201 | pass -------------------------------------------------------------------------------- /dataset_factory/qa_factory.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # QA pair creation 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. 
Student, ARTORG Center, University of Bern 7 | 8 | import os 9 | import numpy as np 10 | import random 11 | from PIL import Image 12 | from os.path import join as jp 13 | from skimage.measure import regionprops 14 | 15 | 16 | def generate_random_window(h, w, min_side, max_side, prop, regions_in_subwindow=False, offset=0): 17 | p = random.random() # random number to decide whether random width or random height is sampled first 18 | if p >= 0.5: 19 | # sample width first 20 | random_w = random.randint(min_side, max_side) 21 | # to generate the random height I move back to the resized space so that the proportion is kept there instead of in the original space 22 | random_h = random.randint(max(min_side, round((1-prop)*random_w)), min(max_side, round((1+prop)*random_w))) 23 | else: 24 | # sample height first 25 | random_h = random.randint(min_side, max_side) 26 | # to generate the random width I move back to the resized space so that the proportion is kept there instead of in the original space 27 | random_w = random.randint(max(min_side, round((1-prop)*random_h)), min(max_side, round((1+prop)*random_h))) 28 | if regions_in_subwindow: 29 | # force windows to be in a specific sub-range of the original image so that in the resized space they are also within a sub-range (a centered sub-window) 30 | top_left_corner = (random.randint(offset, h - offset - random_h), random.randint(offset, w - offset - random_w)) 31 | else: 32 | top_left_corner = (random.randint(0, h - random_h), random.randint(0, w - random_w)) 33 | return top_left_corner, random_h, random_w 34 | 35 | def convert_region_coords(top_left, window_h, window_w, h, w, new_side): 36 | """coords convertion function 37 | 38 | Parameters 39 | ---------- 40 | top_left : tuple 41 | top left corner of the region in the original image size 42 | window_h : int 43 | height of the region in the original image size 44 | window_w : int 45 | width of the region in the original image size 46 | h : int 47 | height of the original image 48 | w : int 49 | width of the original image 50 | new_side : int 51 | new side of the resized image (square image) 52 | 53 | Returns 54 | ------- 55 | tuple 56 | converted coords with the format ((top_left_y, top_left_x), h_, w_)) 57 | """ 58 | # convert to resized image coordinates 59 | Ry, Rx = new_side/h, new_side/w 60 | top_left_resized = (round(top_left[0]*Ry), round(top_left[1]*Rx)) 61 | window_h_resized = round(window_h*Ry) 62 | window_w_resized = round(window_w*Rx) 63 | return top_left_resized, window_h_resized, window_w_resized 64 | 65 | def generate_questions_about_regions(config, mask_gt, class_name, partial_qa_id, image_name, balanced=True, dataset='cholec'): 66 | """Generates questions about regions for the Cholec dataset 67 | 68 | Parameters 69 | ---------- 70 | config : config dict 71 | _description_ 72 | mask_gt : numpy array 73 | _description_ 74 | class_name : str 75 | _description_ 76 | partial_qa_id : int 77 | partial question id created from image index and class index 78 | image_name : str 79 | name of the image with extension 80 | balanced : bool, optional 81 | whether or not the dataset shoould be balanced, by default True 82 | dataset : str, optional 83 | name of the dataset, by default 'Cholec' 84 | 85 | Returns 86 | ------- 87 | list 88 | List with QA pairs about random regions for current image 89 | """ 90 | # get info from config 91 | num_regions = config['num_regions'] 92 | min_window_side = config['min_window_side'] 93 | max_window_side = config['max_window_side'] 94 | threshold = 
config['threshold'] 95 | proportion = config['proportion_deviation'] 96 | 97 | # first, get number of pixels in the image 98 | num_pixels_img = mask_gt.shape[0] * mask_gt.shape[1] 99 | 100 | # now get number of 1s in mask 101 | if mask_gt.ndim > 2: 102 | num_pixels_mask = np.sum(mask_gt[:,:,0]) # take only one channel (all of them have the same amount - this info is in the Cholec info) 103 | else: 104 | num_pixels_mask = np.sum(mask_gt) 105 | 106 | if num_pixels_mask == 0: 107 | return [] # if there are no pixels in the mask, return empty list 108 | 109 | if dataset == 'sts2017' or dataset == 'insegcat': # due to the shape of the tools, use bounding box 110 | props = regionprops(mask_gt) 111 | # add the areas of all bounding boxes 112 | num_pixels_mask = 0 113 | for prop in props: 114 | minr, minc, maxr, maxc = prop.bbox 115 | num_pixels_mask += (maxr - minr)*(maxc - minc) 116 | 117 | # define number of regions. For now, I define it as a function that moves between 3 and num_regions. The idea is that if the GT region is too small 118 | # The idea here is that if the number of pixels in the mask is close to num_pixels_img/2, then num_regions should be produced. If the mask is too big or too small, then reduce the number of regions 119 | # so that small tissues are not redundant (regions on big tissues would correspond to different parts of the image, but small tissues would be just sampled several times) 120 | # This is described on pp. 23 of notebook 4 121 | if num_pixels_mask <= num_pixels_img/2: 122 | num_regions_recomputed = np.min([num_regions, np.max([config['min_regions'], int((2*(num_regions + (num_regions/2))/num_pixels_img)*num_pixels_mask)])]) 123 | if num_regions_recomputed%2 != 0: # make even so that half of the questions can be answered with yes and half with no 124 | num_regions_recomputed += 1 125 | else: 126 | num_regions_recomputed = np.min([num_regions, np.max([config['min_regions'], int(-1*(2*(num_regions + (num_regions/2))/num_pixels_img)*(num_pixels_mask - num_pixels_img/2) + num_regions + (num_regions/2) )])]) 127 | if num_regions_recomputed%2 != 0: # make even so that half of the questions can be answered with yes and half with no 128 | num_regions_recomputed += 1 129 | 130 | 131 | qa_group = [] 132 | i_region = 0 # region index for current image 133 | num_questions_yes = 0 134 | num_questions_no = 0 135 | budget = 100*num_regions_recomputed 136 | while num_questions_yes < round(num_regions_recomputed/2) or num_questions_no < round(num_regions_recomputed/2): # while not complete 137 | # generate randomly-sized region with random location 138 | top_left, window_h, window_w = generate_random_window(mask_gt.shape[0], mask_gt.shape[1], min_window_side, max_window_side, proportion, regions_in_subwindow=True, offset=config['window_offset']) 139 | # convert coordinates of the random region to the resized space 140 | top_left_resized, window_h_resized, window_w_resized = convert_region_coords(top_left, window_h, window_w, mask_gt.shape[0], mask_gt.shape[1], config['size']) 141 | # build mask array 142 | mask_region = np.zeros_like(mask_gt, dtype = np.uint8) 143 | mask_region[top_left[0]:top_left[0]+window_h, top_left[1]:top_left[1]+window_w] = 1 # * Important: to be used like this in dataset class to create the mask, but setting it to 255 144 | 145 | num_pixels_in_region = np.count_nonzero(mask_gt*mask_region) 146 | 147 | # if threshold parameter should be treated as a percentage of the region, then compute it 148 | if config['threshold_as_percentage']: 149 | threshold = 
int(config['threshold']*np.count_nonzero(mask_region)/100) 150 | 151 | if (num_pixels_in_region >= threshold) and num_questions_yes < round(num_regions_recomputed/2): # if answer is yes and i haven't reached the maximum number of positive questions 152 | answer = 'yes' 153 | question_linked_to_region_mask = ('is there ' + class_name + ' in this region?').lower() 154 | question_mentioning_region = ('is there ' + class_name + ' in the region with top left corner at (' + str(top_left_resized[0]) + ', ' + str(top_left_resized[1]) + ') and height ' + str(window_h_resized) + ' and width ' + str(window_w_resized) + '?').lower() 155 | qa_group.append({ 156 | 'image_name': image_name, 157 | 'question': question_linked_to_region_mask, 158 | 'question_alt': question_mentioning_region, 159 | 'question_id': int(str(i_region+1).zfill(3) + partial_qa_id), 160 | 'question_type': 'region', 161 | 'mask_coords': (top_left_resized, window_h_resized, window_w_resized), 162 | 'mask_coords_orig': (top_left, window_h, window_w), # save coords in original space just in case 163 | 'answer': answer, 164 | 'mask_size': (config['size'], config['size']), 165 | 'mask_size_orig': mask_region.shape, # original shape 166 | 'question_object': class_name 167 | }) 168 | num_questions_yes += 1 169 | i_region += 1 170 | 171 | elif num_pixels_in_region == 0 and num_questions_no < round(num_regions_recomputed/2): # if answer is no and i haven't reached the maximum number of negative questions 172 | answer = 'no' 173 | 174 | question_linked_to_region_mask = ('is there ' + class_name + ' in this region?').lower() 175 | question_mentioning_region = ('is there ' + class_name + ' in the region with top left corner at (' + str(top_left_resized[0]) + ', ' + str(top_left_resized[1]) + ') and height ' + str(window_h_resized) + ' and width ' + str(window_w_resized) + '?').lower() 176 | 177 | qa_group.append({ 178 | 'image_name': image_name, 179 | 'question': question_linked_to_region_mask, 180 | 'question_alt': question_mentioning_region, 181 | 'question_id': int(str(i_region+1).zfill(3) + partial_qa_id), 182 | 'question_type': 'region', 183 | 'mask_coords': (top_left_resized, window_h_resized, window_w_resized), 184 | 'mask_coords_orig': (top_left, window_h, window_w), 185 | 'answer': answer, 186 | 'mask_size': (config['size'], config['size']), 187 | 'mask_size_orig': mask_region.shape, 188 | 'question_object': class_name 189 | }) 190 | num_questions_no += 1 191 | i_region += 1 192 | else: 193 | budget -= 1 194 | 195 | if budget == 0: # if I can't find a region that satisfies the conditions, I stop 196 | print('WARNING: budget exceeded for image ' + image_name + ' and class ' + class_name + '. 
Skipping class') 197 | return [] 198 | 199 | 200 | assert num_questions_no == num_questions_yes # sanity check: check balance 201 | return qa_group 202 | 203 | 204 | def generate_questions_about_image(config, labels, mask_code, image_name, h, w, img_idx): 205 | list_classes_int = set(mask_code.keys()) 206 | labels_in_image = set(labels) 207 | 208 | code2class = {v: k for k, v in mask_code.items()} 209 | 210 | to_choose_from = list_classes_int - labels_in_image # classes that are not in the image (negative questions) 211 | 212 | num_yes = 0 213 | num_no = 0 214 | qa_group = [] 215 | 216 | # first, create positive questions 217 | labels_in_image_text = [mask_code[l] for l in labels_in_image] 218 | idx = 0 219 | for class_name in labels_in_image_text: 220 | question_linked_to_region_mask = ('is there ' + class_name + ' in this image?').lower() 221 | question_mentioning_region = ('is there ' + class_name + ' in the region with top left corner at (' + str(0) + ', ' + str(0) + ') and height ' + str(config['size']) + ' and width ' + str(config['size']) + '?').lower() 222 | qa_group.append({ 223 | 'image_name': image_name, 224 | 'question': question_linked_to_region_mask, 225 | 'question_alt': question_mentioning_region, 226 | 'question_id': int(str(idx+1).zfill(2) + '2' + str(code2class[class_name]).zfill(2) + img_idx), 227 | 'question_type': 'whole', 228 | 'mask_coords': ((0,0), config['size'], config['size']), 229 | 'mask_coords_orig': ((0,0), h, w), 230 | 'answer': 'yes', 231 | 'mask_size': (config['size'], config['size']), 232 | 'mask_size_orig': (h, w), 233 | 'question_object': class_name 234 | }) 235 | num_yes += 1 236 | idx += 1 237 | 238 | # then, create negative questions 239 | for i in range(len(labels_in_image)): 240 | # choose a random class that is not in the image 241 | class_name = mask_code[np.random.choice(list(to_choose_from))] 242 | question_linked_to_region_mask = ('is there ' + class_name + ' in this image?').lower() 243 | question_mentioning_region = ('is there ' + class_name + ' in the region with top left corner at (' + str(0) + ', ' + str(0) + ') and height ' + str(h) + ' and width ' + str(w) + '?').lower() 244 | qa_group.append({ 245 | 'image_name': image_name, 246 | 'question': question_linked_to_region_mask, 247 | 'question_alt': question_mentioning_region, 248 | 'question_id': int(str(idx+1).zfill(2) + '2' + str(code2class[class_name]).zfill(2) + img_idx), 249 | 'question_type': 'whole', 250 | 'mask_coords': ((0,0), config['size'], config['size']), 251 | 'mask_coords_orig': ((0,0), h, w), 252 | 'answer': 'no', 253 | 'mask_size': (config['size'], config['size']), 254 | 'mask_size_orig': (h, w), 255 | 'question_object': class_name 256 | }) 257 | num_no += 1 258 | idx += 1 259 | 260 | assert num_yes == num_no # sanity check: check balance 261 | return qa_group 262 | 263 | -------------------------------------------------------------------------------- /core/models/models.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Model definition file for VQA 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D.
Student, ARTORG Center, University of Bern 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from .components import image, text, attention, fusion, classification 12 | 13 | class VQA_Base(nn.Module): 14 | # base class for simple VQA model 15 | def __init__(self, config, vocab_words, vocab_answers): 16 | super().__init__() 17 | self.visual_feature_size = config['visual_feature_size'] 18 | self.question_feature_size = config['question_feature_size'] 19 | self.pre_visual = config['pre_extracted_visual_feat'] 20 | self.use_attention = config['attention'] 21 | self.number_of_glimpses = config['number_of_glimpses'] 22 | self.visual_size_before_fusion = self.visual_feature_size # 2048 by default, changes if attention 23 | # Create modules for the model 24 | 25 | # if necesary, create module for offline visual feature extraction 26 | if not self.pre_visual: 27 | self.image = image.get_visual_feature_extractor(config) 28 | 29 | # create module for text feature extraction 30 | self.text = text.get_text_feature_extractor(config, vocab_words) 31 | 32 | # if necessary, create attention module 33 | if self.use_attention: 34 | self.visual_size_before_fusion = self.number_of_glimpses*self.visual_feature_size 35 | self.attention_mechanism = attention.get_attention_mechanism(config) 36 | else: 37 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1)) 38 | 39 | # create multimodal fusion module 40 | self.fuser, fused_size = fusion.get_fuser(config['fusion'], self.visual_size_before_fusion, self.question_feature_size) 41 | 42 | # create classifier 43 | self.classifer = classification.get_classfier(fused_size, config) 44 | 45 | 46 | def forward(self, v, q, m): #* For simplicity, VQA_Base receives the mask as an input, but it is not used 47 | # if required, extract visual features from visual input 48 | if not self.pre_visual: 49 | v = self.image(v) # [B, 2048, 14, 14] 50 | 51 | # l2 norm 52 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 53 | 54 | # extract text features 55 | q = self.text(q) 56 | 57 | # if required, apply attention 58 | if self.use_attention: 59 | v = self.attention_mechanism(v, q) # should apply attention too 60 | else: 61 | v = self.avgpool(v).squeeze_() # [B, 2048] 62 | 63 | # apply multimodal fusion 64 | fused = self.fuser(v, q) 65 | 66 | # apply MLP 67 | x = self.classifer(fused) 68 | 69 | return x 70 | 71 | # VQA Models 72 | 73 | class VQA_MaskRegion(VQA_Base): 74 | # First model for region-based VQA, with single mask. Input image is multiplied with the mask to produced a masked version which is sent to the model as normal 75 | # A.k.a. 
VQARS_1 76 | def __init__(self, config, vocab_words, vocab_answers): 77 | # call mom 78 | super().__init__(config, vocab_words, vocab_answers) 79 | 80 | # override forward method to accept mask 81 | def forward(self, v, q, m): 82 | # if required, extract visual features from visual input 83 | if self.pre_visual: 84 | raise ValueError("This model does not allow pre-extracted features") 85 | else: 86 | v = self.image(torch.mul(v, m)) # [B, 2048, 14, 14] MASK IS INCLUDED HERE 87 | 88 | # l2 norm 89 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 90 | 91 | # extract text features 92 | q = self.text(q) 93 | 94 | # if required, apply attention 95 | if self.use_attention: 96 | v = self.attention_mechanism(v, q) # should apply attention too 97 | else: 98 | v = self.avgpool(v).squeeze_(dim=-1).squeeze_(dim=-1) # [B, 2048] 99 | 100 | # apply multimodal fusion 101 | fused = self.fuser(v, q) 102 | 103 | # apply MLP 104 | x = self.classifer(fused) 105 | 106 | return x 107 | 108 | 109 | class VQA_IgnoreMask(VQA_Base): 110 | # First model for region-based VQA, with single mask, but the mask is totally ignored. This model measures the ability of the system to answer without masks 111 | # A.k.a. VQARS_2 112 | def __init__(self, config, vocab_words, vocab_answers): 113 | # call mom 114 | super().__init__(config, vocab_words, vocab_answers) 115 | 116 | # override forward method to accept mask 117 | def forward(self, v, q, m): 118 | # if required, extract visual features from visual input 119 | if self.pre_visual: 120 | raise ValueError("This model does not allow pre-extracted features") 121 | else: 122 | v = self.image(v) # [B, 2048, 14, 14] MASK IS INCLUDED HERE 123 | 124 | # l2 norm 125 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 126 | 127 | # extract text features 128 | q = self.text(q) 129 | 130 | # if required, apply attention 131 | if self.use_attention: 132 | v = self.attention_mechanism(v, q) # should apply attention too 133 | else: 134 | v = self.avgpool(v).squeeze_() # [B, 2048] 135 | 136 | # apply multimodal fusion 137 | fused = self.fuser(v, q) 138 | 139 | # apply MLP 140 | x = self.classifer(fused) 141 | 142 | return x 143 | 144 | 145 | class VQARS_3(VQA_Base): 146 | # Same as model VQARS_1 but mask is resized, flattened and then injected into the multimodal fusion 147 | def __init__(self, config, vocab_words, vocab_answers): 148 | # call mom 149 | super().__init__(config, vocab_words, vocab_answers) 150 | 151 | # create fuser to include reshaped, vectorized mask 152 | self.fuser2, fused_size2 = fusion.get_fuser(config['fusion'], self.visual_size_before_fusion + self.question_feature_size, 14*14) 153 | 154 | # correct classifier 155 | self.classifer = classification.get_classfier(fused_size2, config) 156 | 157 | # override forward method to accept mask 158 | def forward(self, v, q, m): 159 | # if required, extract visual features from visual input 160 | if self.pre_visual: 161 | raise ValueError("This model does not allow pre-extracted features") 162 | else: 163 | v = self.image(v) # [B, 2048, 14, 14] 164 | 165 | # resize mask 166 | m = transforms.Resize(14)(m) #should become size (B,14,14) 167 | m = m.view(-1, 14*14) 168 | 169 | # extract text features 170 | q = self.text(q) 171 | 172 | # if required, apply attention 173 | if self.use_attention: 174 | v = self.attention_mechanism(v, q) # should apply attention too 175 | else: 176 | v = self.avgpool(v).squeeze_() # [B, 2048] 177 | 178 | # apply multimodal fusion 179 | fused = self.fuser(v, q) 180 | fused = 
self.fuser2(fused, m) 181 | 182 | # apply MLP 183 | x = self.classifer(fused) 184 | 185 | return x 186 | 187 | 188 | class VQARS_4(VQA_Base): 189 | # Model that requires attention. Mask is used to mask the attention maps of the attention mechanism. 190 | def __init__(self, config, vocab_words, vocab_answers): 191 | if not config['attention']: 192 | raise ValueError("This model requires attention. Please set to True in the config file") 193 | 194 | # call mom 195 | super().__init__(config, vocab_words, vocab_answers) 196 | 197 | # replace attention mechanism 198 | self.attention_mechanism = attention.get_attention_mechanism(config, special='Att1') 199 | 200 | # override forward method to accept mask 201 | def forward(self, v, q, m): 202 | # if required, extract visual features from visual input 203 | if self.pre_visual: 204 | raise ValueError("This model does not allow pre-extracted features") 205 | else: 206 | v = self.image(v) 207 | 208 | # extract text features 209 | q = self.text(q) 210 | 211 | # resize mask 212 | m = transforms.Resize(14)(m) #should become size (B,1,14,14) 213 | m = m.view(m.size(0),-1, 14*14) # [B,1,196] 214 | 215 | # if required, apply attention 216 | if self.use_attention: 217 | v = self.attention_mechanism(v, m, q) 218 | else: 219 | raise ValueError("This model requires attention") 220 | 221 | # apply multimodal fusion 222 | fused = self.fuser(v, q) 223 | 224 | # apply MLP 225 | x = self.classifer(fused) 226 | 227 | return x 228 | 229 | 230 | class VQARS_5(VQA_Base): 231 | # Alternative model that considers mask as an additional channel along with the input image 232 | def __init__(self, config, vocab_words, vocab_answers): 233 | # call mom 234 | super().__init__(config, vocab_words, vocab_answers) 235 | 236 | # re-define first layer of visual feature extractor so that it admits 4 channels as input instead of 3 237 | self.image.net_base.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 238 | modules = list(self.image.net_base.children())[:-2] # ignore avgpool layer and classifier 239 | self.image.extractor = nn.Sequential(*modules) 240 | for p in self.image.extractor.parameters(): 241 | p.requires_grad = False 242 | # override forward method to accept mask 243 | def forward(self, v, q, m): 244 | # if required, extract visual features from visual input 245 | if self.pre_visual: 246 | raise ValueError("This model does not allow pre-extracted features") 247 | else: 248 | v = self.image(torch.cat((v, m), 1)) # [B, 2048, 14, 14] MASK IS INCLUDED HERE 249 | 250 | # extract text features 251 | q = self.text(q) 252 | 253 | # if required, apply attention 254 | if self.use_attention: 255 | v = self.attention_mechanism(v, q) # should apply attention too 256 | else: 257 | v = self.avgpool(v).squeeze_() # [B, 2048] 258 | 259 | # apply multimodal fusion 260 | fused = self.fuser(v, q) 261 | 262 | # apply MLP 263 | x = self.classifer(fused) 264 | 265 | return x 266 | 267 | 268 | class VQARS_6(VQA_Base): 269 | # Same as base class. Had to override forward because of number of arguments passed in training functions 270 | def __init__(self, config, vocab_words, vocab_answers): 271 | if not config['attention']: 272 | raise ValueError("This model requires attention. 
Please set to True in the config file") 273 | 274 | # call mom 275 | super().__init__(config, vocab_words, vocab_answers) 276 | 277 | # replace attention mechanism 278 | self.attention_mechanism = attention.get_attention_mechanism(config, special='Att2') 279 | 280 | # override forward method to accept mask 281 | def forward(self, v, q, m): 282 | # if required, extract visual features from visual input 283 | if self.pre_visual: 284 | raise ValueError("This model does not allow pre-extracted features") 285 | else: 286 | v = self.image(v) # [B, 2048, 14, 14] 287 | 288 | # extract text features 289 | q = self.text(q) 290 | 291 | # resize mask 292 | m = transforms.Resize(14)(m) #should become size (B,1,14,14) 293 | m = m.view(m.size(0),-1, 14*14) # [B,1,196] 294 | 295 | # if required, apply attention 296 | if self.use_attention: 297 | v = self.attention_mechanism(v, m, q) 298 | else: 299 | raise ValueError("This model requires attention") 300 | 301 | # apply multimodal fusion 302 | fused = self.fuser(v, q) 303 | 304 | # apply MLP 305 | x = self.classifer(fused) 306 | 307 | return x 308 | 309 | 310 | class VQA_LocalizedAttention(VQA_Base): 311 | # Model that requires attention. Mask is used to mask the attention maps of the attention mechanism. 312 | # A.k.a. VQARS_7 313 | def __init__(self, config, vocab_words, vocab_answers): 314 | if not config['attention']: 315 | raise ValueError("This model requires attention. Please set to True in the config file") 316 | 317 | # call mom 318 | super().__init__(config, vocab_words, vocab_answers) 319 | 320 | # replace attention mechanism 321 | self.attention_mechanism = attention.get_attention_mechanism(config, special='Att3') 322 | 323 | # override forward method to accept mask 324 | def forward(self, v, q, m): 325 | # if required, extract visual features from visual input 326 | if not self.pre_visual: 327 | v = self.image(v) 328 | 329 | # l2 norm 330 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 331 | 332 | # extract text features 333 | q = self.text(q) 334 | 335 | # resize mask 336 | m = transforms.Resize(14)(m) #should become size (B,1,14,14) 337 | m = m.view(m.size(0),-1, 14*14) # [B,1,196] 338 | 339 | # if required, apply attention 340 | if self.use_attention: 341 | v = self.attention_mechanism(v, m, q) 342 | else: 343 | raise ValueError("This model requires attention") 344 | 345 | # apply multimodal fusion 346 | fused = self.fuser(v, q) 347 | 348 | # apply MLP 349 | x = self.classifer(fused) 350 | 351 | return x 352 | 353 | class VQARS_8(VQA_Base): 354 | # Model that requires attention. Mask is used to mask the attention maps of the attention mechanism. 355 | # A.k.a. VQARS_7 356 | def __init__(self, config, vocab_words, vocab_answers): 357 | if not config['attention']: 358 | raise ValueError("This model requires attention. 
Please set to True in the config file") 359 | 360 | # call mom 361 | super().__init__(config, vocab_words, vocab_answers) 362 | 363 | # replace attention mechanism 364 | self.attention_mechanism = attention.get_attention_mechanism(config, special='Att5') 365 | 366 | # override forward method to accept mask 367 | def forward(self, v, q, m): 368 | # if required, extract visual features from visual input 369 | if not self.pre_visual: 370 | v = self.image(v) 371 | 372 | # l2 norm 373 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 374 | 375 | # extract text features 376 | q = self.text(q) 377 | 378 | # resize mask 379 | m = transforms.Resize(14)(m) #should become size (B,1,14,14) 380 | m = m.view(m.size(0),-1, 14*14) # [B,1,196] 381 | 382 | # if required, apply attention 383 | if self.use_attention: 384 | v = self.attention_mechanism(v, m, q) 385 | else: 386 | raise ValueError("This model requires attention") 387 | 388 | # apply multimodal fusion 389 | fused = self.fuser(v, q) 390 | 391 | # apply MLP 392 | x = self.classifer(fused) 393 | 394 | return x 395 | 396 | 397 | class VQARS_9(VQA_Base): 398 | # Model that requires attention. Mask is used to mask the attention maps of the attention mechanism. 399 | # A.k.a. VQARS_7 400 | def __init__(self, config, vocab_words, vocab_answers): 401 | if not config['attention']: 402 | raise ValueError("This model requires attention. Please set to True in the config file") 403 | 404 | # call mom 405 | super().__init__(config, vocab_words, vocab_answers) 406 | 407 | # replace attention mechanism 408 | self.attention_mechanism = attention.get_attention_mechanism(config, special='Att6') 409 | 410 | # override forward method to accept mask 411 | def forward(self, v, q, m): 412 | # if required, extract visual features from visual input 413 | if not self.pre_visual: 414 | v = self.image(v) 415 | 416 | # l2 norm 417 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 418 | 419 | # extract text features 420 | q = self.text(q) 421 | 422 | # resize mask 423 | m = transforms.Resize(14)(m) #should become size (B,1,14,14) 424 | m = m.view(m.size(0),-1, 14*14) # [B,1,196] 425 | 426 | # if required, apply attention 427 | if self.use_attention: 428 | v = self.attention_mechanism(v, m, q) 429 | else: 430 | raise ValueError("This model requires attention") 431 | 432 | # apply multimodal fusion 433 | fused = self.fuser(v, q) 434 | 435 | # apply MLP 436 | x = self.classifer(fused) 437 | 438 | return x 439 | 440 | class VQA_LocalizedAttentionScale(VQA_Base): 441 | # Model that requires attention. Mask is used to mask the attention maps of the attention mechanism. 442 | # Added feature: scale the attention in the region by abs(max(ouside) - max(inside)). 443 | # A.k.a. VQARS_7 444 | def __init__(self, config, vocab_words, vocab_answers): 445 | if not config['attention']: 446 | raise ValueError("This model requires attention. 
Please set to True in the config file") 447 | 448 | # call mom 449 | super().__init__(config, vocab_words, vocab_answers) 450 | 451 | # replace attention mechanism 452 | self.attention_mechanism = attention.get_attention_mechanism(config, special='Att4') 453 | 454 | # override forward method to accept mask 455 | def forward(self, v, q, m): 456 | # if required, extract visual features from visual input 457 | if not self.pre_visual: 458 | v = self.image(v) 459 | 460 | # l2 norm 461 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 462 | 463 | # extract text features 464 | q = self.text(q) 465 | 466 | # resize mask 467 | m = transforms.Resize(14)(m) #should become size (B,1,14,14) 468 | m = m.view(m.size(0),-1, 14*14) # [B,1,196] 469 | 470 | # if required, apply attention 471 | if self.use_attention: 472 | v = self.attention_mechanism(v, m, q) 473 | else: 474 | raise ValueError("This model requires attention") 475 | 476 | # apply multimodal fusion 477 | fused = self.fuser(v, q) 478 | 479 | # apply MLP 480 | x = self.classifer(fused) 481 | 482 | return x -------------------------------------------------------------------------------- /core/models/components/attention.py: -------------------------------------------------------------------------------- 1 | # Project: 2 | # Localized Questions in VQA 3 | # Description: 4 | # Attention-related definitions 5 | # Author: 6 | # Sergio Tascon-Morales, Ph.D. Student, ARTORG Center, University of Bern 7 | 8 | from . import fusion 9 | from . import utils 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import torch 13 | 14 | def get_attention_mechanism(config, special=None): 15 | # get attention parameters 16 | visual_features_size = config['visual_feature_size'] 17 | question_feature_size = config['question_feature_size'] 18 | attention_middle_size = config['attention_middle_size'] 19 | number_of_glimpses = config['number_of_glimpses'] 20 | attention_fusion = config['attention_fusion'] 21 | dropout_attention = config['attention_dropout'] 22 | if special is None: # Normal attention mechanism 23 | attention = AttentionMechanismBase(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention) 24 | elif special == 'Att1': # special attention mechanism 1 25 | attention = AttentionMechanism_1(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention) 26 | elif special == 'Att2': 27 | attention = AttentionMechanism_2(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention) 28 | elif special == 'Att3': 29 | attention = AttentionMechanism_3(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention) 30 | elif special == 'Att4': 31 | attention = AttentionMechanism_4(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention) 32 | elif special == 'Att5': 33 | # fall back to a default attenuation factor of 0.1 if none is given in the config 34 | attenuation_factor = config.get('attenuation_factor', 0.1) 35 | if 'attenuation_factor' not in config: 36 | print('Using default attenuation factor of 0.1') 37 | attention = AttentionMechanism_5(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention, attenuation_factor=attenuation_factor) 38 | elif special == 'Att6': 39 | attention =
AttentionMechanism_6(visual_features_size, question_feature_size, attention_middle_size, number_of_glimpses, attention_fusion, drop=dropout_attention) 40 | return attention 41 | 42 | 43 | def apply_attention(visual_features, attention): 44 | # visual features has size [b, m, k, k] 45 | # attention has size [b, glimpses, k, k] 46 | b, m = visual_features.size()[:2] # batch size, number of feature maps 47 | glimpses = attention.size(1) 48 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 49 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 50 | attention = F.softmax(attention, dim = -1).unsqueeze(2) # [b, glimpses, 1, k*k] 51 | attended = attention*visual_features # use broadcasting to weight the feature maps 52 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 53 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 54 | 55 | class AttentionMechanismBase(nn.Module): 56 | """Normal attention mechanism""" 57 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0): 58 | super().__init__() 59 | self.conv1 = nn.Conv2d(visual_features_size, attention_middle_size, 1, bias=False) 60 | self.lin1 = nn.Linear(question_feature_size, attention_middle_size) 61 | self.fuser, self.size_after_fusion = fusion.get_fuser(fusion_method, attention_middle_size, attention_middle_size) 62 | self.relu = nn.ReLU() 63 | self.drop = nn.Dropout(drop) 64 | self.conv2 = nn.Conv2d(self.size_after_fusion, glimpses, 1) 65 | 66 | def forward(self, visual_features, question_features, return_maps=False): 67 | # first, compute attention vectors 68 | v = self.conv1(self.drop(visual_features)) 69 | q = self.lin1(self.drop(question_features)) 70 | q = utils.expand_like_2D(q, v) 71 | x = self.relu(self.fuser(v, q)) 72 | x = self.conv2(self.drop(x)) 73 | 74 | if return_maps: # if maps have to be returned, save them in a variable 75 | maps = x.clone() 76 | 77 | # then, apply attention vectors to input visual features 78 | x = apply_attention(visual_features, x) 79 | 80 | if return_maps: 81 | return x, maps 82 | else: 83 | return x 84 | 85 | 86 | class AttentionMechanism_1(AttentionMechanismBase): 87 | """Attention mechanism for model VQARS_4 to include mask before softmax and then softmax only part that mask keeps""" 88 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0): 89 | super().__init__(visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=drop) 90 | 91 | # same as general function above but receiving mask and applying it before the softmax 92 | def apply_attention(self, visual_features, attention, mask): 93 | # visual features has size [b, m, k, k] 94 | # attention has size [b, glimpses, k, k] 95 | # mask has size [b, 1, k*k] 96 | b, m = visual_features.size()[:2] # batch size, number of feature maps 97 | glimpses = attention.size(1) 98 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 99 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 100 | attention = attention*mask #! 
Apply mask 101 | for i in range(glimpses): 102 | attention[:,i,:][mask.squeeze().to(torch.bool)] = F.softmax(attention[:,i,:][mask.squeeze().to(torch.bool)], dim=-1) 103 | attention.unsqueeze_(2) 104 | attended = attention*visual_features # use broadcasting to weight the feature maps 105 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 106 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 107 | 108 | # override forward method 109 | def forward(self, visual_features, mask, question_features): 110 | # first, compute attention vectors 111 | v = self.conv1(self.drop(visual_features)) 112 | q = self.lin1(self.drop(question_features)) 113 | q = utils.expand_like_2D(q, v) 114 | x = self.relu(self.fuser(v, q)) 115 | x = self.conv2(self.drop(x)) 116 | 117 | # then, apply attention vectors to input visual features 118 | x = self.apply_attention(visual_features, x, mask) 119 | return x 120 | 121 | class AttentionMechanism_2(AttentionMechanismBase): 122 | """Attention mechanism for model VQARS_6 = VQARS_4 to include mask before softmax and then softmax only part that mask keeps""" 123 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0): 124 | super().__init__(visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=drop) 125 | 126 | # same as general function above but receiving mask and applying it before the softmax 127 | def apply_attention(self, visual_features, attention, mask): 128 | # visual features has size [b, m, k, k] 129 | # attention has size [b, glimpses, k, k] 130 | # mask has size [b, 1, k*k] 131 | b, m = visual_features.size()[:2] # batch size, number of feature maps 132 | glimpses = attention.size(1) 133 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 134 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 135 | attention[:,0,:] = attention[:,0,:]*mask.squeeze() # apply to first glimpse only 136 | attention[:,0,:][mask.squeeze().to(torch.bool)] = F.softmax(attention[:,0,:][mask.squeeze().to(torch.bool)], dim=-1) # again, only first glimpse 137 | attention.unsqueeze_(2) 138 | attended = attention*visual_features # use broadcasting to weight the feature maps 139 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 140 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 141 | 142 | # override forward method 143 | def forward(self, visual_features, mask, question_features): 144 | # first, compute attention vectors 145 | v = self.conv1(self.drop(visual_features)) 146 | q = self.lin1(self.drop(question_features)) 147 | q = utils.expand_like_2D(q, v) 148 | x = self.relu(self.fuser(v, q)) 149 | x = self.conv2(self.drop(x)) 150 | 151 | # then, apply attention vectors to input visual features 152 | x = self.apply_attention(visual_features, x, mask) 153 | return x 154 | 155 | 156 | class AttentionMechanism_3(AttentionMechanismBase): 157 | """Attention mechanism for model VQARS_7 to include mask after softmax""" 158 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0): 159 | super().__init__(visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=drop) 160 | 161 | # same as general function above but receiving mask and applying it before the softmax 162 | def 
apply_attention(self, visual_features, attention, mask): 163 | # visual features has size [b, m, k, k] 164 | # attention has size [b, glimpses, k, k] 165 | # mask has size [b, 1, k*k] 166 | b, m = visual_features.size()[:2] # batch size, number of feature maps 167 | glimpses = attention.size(1) 168 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 169 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 170 | attention = F.softmax(attention, dim = -1) # [b, glimpses, k*k] 171 | attention = attention*mask #! Apply mask 172 | attention.unsqueeze_(2) 173 | attended = attention*visual_features # use broadcasting to weight the feature maps 174 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 175 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 176 | 177 | # override forward method 178 | def forward(self, visual_features, mask, question_features, return_maps=False): 179 | # first, compute attention vectors 180 | v = self.conv1(self.drop(visual_features)) 181 | q = self.lin1(self.drop(question_features)) 182 | q = utils.expand_like_2D(q, v) 183 | x = self.relu(self.fuser(v, q)) 184 | x = self.conv2(self.drop(x)) 185 | 186 | if return_maps: # if maps have to be returned, save them in a variable 187 | maps = x.clone() 188 | 189 | # then, apply attention vectors to input visual features 190 | x = self.apply_attention(visual_features, x, mask) 191 | 192 | if return_maps: 193 | return x, maps 194 | else: 195 | return x 196 | 197 | 198 | class AttentionMechanism_4(AttentionMechanismBase): 199 | """Attention mechanism for model VQARS_7 to include mask after softmax, but region attention is scaled by (1- (max(outside) - max(inside)))""" 200 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0): 201 | super().__init__(visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=drop) 202 | 203 | # same as general function above but receiving mask and applying it before the softmax 204 | def apply_attention(self, visual_features, attention, mask): 205 | # visual features has size [b, m, k, k] 206 | # attention has size [b, glimpses, k, k] 207 | # mask has size [b, 1, k*k] 208 | b, m = visual_features.size()[:2] # batch size, number of feature maps 209 | glimpses = attention.size(1) 210 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 211 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 212 | attention = F.softmax(attention, dim = -1) # [b, glimpses, k*k] 213 | not_mask = torch.max(mask) - mask 214 | attention = (attention*mask)*(1 - torch.abs(torch.max(attention*not_mask) - torch.max(attention*mask))) #! 
Apply mask 215 | attention.unsqueeze_(2) 216 | attended = attention*visual_features # use broadcasting to weight the feature maps 217 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 218 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 219 | 220 | # override forward method 221 | def forward(self, visual_features, mask, question_features, return_maps=False): 222 | # first, compute attention vectors 223 | v = self.conv1(self.drop(visual_features)) 224 | q = self.lin1(self.drop(question_features)) 225 | q = utils.expand_like_2D(q, v) 226 | x = self.relu(self.fuser(v, q)) 227 | x = self.conv2(self.drop(x)) 228 | 229 | if return_maps: # if maps have to be returned, save them in a variable 230 | maps = x.clone() 231 | 232 | # then, apply attention vectors to input visual features 233 | x = self.apply_attention(visual_features, x, mask) 234 | 235 | if return_maps: 236 | return x, maps 237 | else: 238 | return x 239 | 240 | 241 | class AttentionMechanism_5(AttentionMechanismBase): 242 | """Same as AttentionMechanism_3 but instead of masking, magnify inside of the region and attenuate outside""" 243 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0, attenuation_factor = 0.1): 244 | super().__init__(visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=drop) 245 | self.attenuation_factor = attenuation_factor 246 | 247 | # same as general function above but receiving mask and applying it before the softmax 248 | def apply_attention(self, visual_features, attention, mask): 249 | # visual features has size [b, m, k, k] 250 | # attention has size [b, glimpses, k, k] 251 | # mask has size [b, 1, k*k] 252 | b, m = visual_features.size()[:2] # batch size, number of feature maps 253 | glimpses = attention.size(1) 254 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 255 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 256 | attention = F.softmax(attention, dim = -1) # [b, glimpses, k*k] 257 | not_mask = torch.max(mask) - mask 258 | attention = attention*mask + attention*not_mask*self.attenuation_factor 259 | attention.unsqueeze_(2) 260 | attended = attention*visual_features # use broadcasting to weight the feature maps 261 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 262 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 263 | 264 | # override forward method 265 | def forward(self, visual_features, mask, question_features, return_maps=False): 266 | # first, compute attention vectors 267 | v = self.conv1(self.drop(visual_features)) 268 | q = self.lin1(self.drop(question_features)) 269 | q = utils.expand_like_2D(q, v) 270 | x = self.relu(self.fuser(v, q)) 271 | x = self.conv2(self.drop(x)) 272 | 273 | if return_maps: # if maps have to be returned, save them in a variable 274 | maps = x.clone() 275 | 276 | # then, apply attention vectors to input visual features 277 | x = self.apply_attention(visual_features, x, mask) 278 | 279 | if return_maps: 280 | return x, maps 281 | else: 282 | return x 283 | 284 | 285 | class AttentionMechanism_6(AttentionMechanismBase): 286 | """Same as AttentionMechanism_3 but apply the mask to one glimpse only""" 287 | def __init__(self, visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=0.0): 288 | 
super().__init__(visual_features_size, question_feature_size, attention_middle_size, glimpses, fusion_method, drop=drop) 289 | 290 | # same as the general function above, but the mask is applied after the softmax to every glimpse except the first one 291 | def apply_attention(self, visual_features, attention, mask): 292 | # visual features has size [b, m, k, k] 293 | # attention has size [b, glimpses, k, k] 294 | # mask has size [b, 1, k*k] 295 | b, m = visual_features.size()[:2] # batch size, number of feature maps 296 | glimpses = attention.size(1) 297 | visual_features = visual_features.view(b, 1, m, -1) # vectorize feature maps [b, 1, m, k*k] 298 | attention = attention.view(b, glimpses, -1) # vectorize attention maps [b, glimpses, k*k] 299 | attention = F.softmax(attention, dim = -1) # [b, glimpses, k*k] 300 | attention_g1 = attention[:, 0, :] # keep the pre-mask attention of the first glimpse 301 | attention = attention*mask # mask all glimpses 302 | attention[:, 0, :] = attention_g1 # restore the first glimpse, which thus attends to the whole image 303 | # net effect: the mask is applied to every glimpse except the first one 304 | #attention[:, 0, :] = (attention[:, 0, :].unsqueeze(1)*mask).squeeze(1) 305 | attention.unsqueeze_(2) 306 | attended = attention*visual_features # use broadcasting to weight the feature maps 307 | attended = attended.sum(dim=-1) # sum in the spatial dimension [b, glimpses, m] 308 | return attended.view(b, -1) # return vectorized version with size [b, glimpses*m] 309 | 310 | # override forward method 311 | def forward(self, visual_features, mask, question_features, return_maps=False): 312 | # first, compute attention vectors 313 | v = self.conv1(self.drop(visual_features)) 314 | q = self.lin1(self.drop(question_features)) 315 | q = utils.expand_like_2D(q, v) 316 | x = self.relu(self.fuser(v, q)) 317 | x = self.conv2(self.drop(x)) 318 | 319 | if return_maps: # if maps have to be returned, save them in a variable 320 | maps = x.clone() 321 | 322 | # then, apply attention vectors to input visual features 323 | x = self.apply_attention(visual_features, x, mask) 324 | 325 | if return_maps: 326 | return x, maps 327 | else: 328 | return x --------------------------------------------------------------------------------
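For illustration, the snippet below is a minimal, self-contained sketch (not a file of this repository) of the localized-attention idea implemented above by AttentionMechanism_3 and used by VQA_LocalizedAttention: the attention logits are softmax-normalized over the spatial positions and then multiplied by the resized region mask, so only grid cells inside the region contribute to the attended feature vector. The function and variable names here are illustrative assumptions; the shapes follow the comments in attention.py (a 14x14 feature grid with 2 glimpses).

import torch
import torch.nn.functional as F

def localized_attention_sketch(visual_features, attention_logits, region_mask):
    # visual_features: [b, m, k, k], attention_logits: [b, glimpses, k, k], region_mask: [b, 1, k*k]
    b, m = visual_features.size()[:2]
    glimpses = attention_logits.size(1)
    v = visual_features.view(b, 1, m, -1)        # [b, 1, m, k*k]
    a = attention_logits.view(b, glimpses, -1)   # [b, glimpses, k*k]
    a = F.softmax(a, dim=-1) * region_mask       # mask applied after softmax: zero weight outside the region
    attended = (a.unsqueeze(2) * v).sum(dim=-1)  # weighted sum over spatial positions -> [b, glimpses, m]
    return attended.view(b, -1)                  # [b, glimpses*m]

# toy usage with random tensors (sizes are assumptions matching the comments above)
feats = torch.randn(2, 2048, 14, 14)
logits = torch.randn(2, 2, 14, 14)
mask = torch.zeros(2, 1, 14 * 14)
mask[:, :, :49] = 1.0                            # pretend the region covers the first 49 grid cells
print(localized_attention_sketch(feats, logits, mask).shape)  # torch.Size([2, 4096])

Masking after the softmax, as sketched here and in AttentionMechanism_3, keeps the weights that the full-image softmax assigned to positions inside the region, whereas masking before the softmax (AttentionMechanism_1) renormalizes the attention over the region only.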