├── src ├── modeling │ ├── models │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── vit.py │ │ └── albef_model.py │ ├── continual_learner.py │ ├── __init__.py │ ├── adaptered_output.py │ ├── vilt_clf.py │ └── albef.py ├── data │ ├── image_datasets │ │ ├── __init__.py │ │ ├── vizwizimages_dataset.py │ │ ├── vgimages_dataset.py │ │ ├── get_avg_images.py │ │ ├── flickr30kimages_dataset.py │ │ ├── cocoimages_dataset.py │ │ └── cocoimages_dataset_crossvqas.py │ ├── visionlanguage_datasets │ │ ├── __init__.py │ │ ├── nlvr2_dataset.py │ │ ├── snli_ve_dataset.py │ │ ├── vcr_dataset.py │ │ └── vqa_dataset.py │ └── image_collation.py ├── coco_mean_image.png ├── utils │ ├── coco_mean_image.png │ ├── seed_utils.py │ ├── wandb.py │ ├── vqa_utils.py │ ├── make_table.py │ ├── image_utils.py │ └── word_utils.py ├── configs │ ├── wandb_config.py │ ├── adapter_configs.py │ ├── model_configs.py │ ├── task_configs_fed.py │ └── task_configs.py ├── train_vilt.sh ├── train_albef.sh └── train │ ├── visionlanguage_tasks │ ├── train_nlvr2.py │ ├── train_vcr.py │ ├── train_snli_ve.py │ ├── train_vqa.py │ └── train_vqa_crossvqa.py │ ├── train_vision.py │ └── train_lowshot_multimodal.py ├── assets ├── dat.png └── fedvqa.png ├── accelerate_config.yaml ├── requirements.txt ├── LICENSE ├── README.md └── .gitignore /src/modeling/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/image_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/visionlanguage_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/dat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaokunChen245/FedDAT/HEAD/assets/dat.png -------------------------------------------------------------------------------- /assets/fedvqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaokunChen245/FedDAT/HEAD/assets/fedvqa.png -------------------------------------------------------------------------------- /src/coco_mean_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaokunChen245/FedDAT/HEAD/src/coco_mean_image.png -------------------------------------------------------------------------------- /src/utils/coco_mean_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaokunChen245/FedDAT/HEAD/src/utils/coco_mean_image.png -------------------------------------------------------------------------------- /src/configs/wandb_config.py: -------------------------------------------------------------------------------- 1 | wandb_config = { 2 | 'entity': 'vl_in_cl', 3 | 'api_key': '8c8e51a8e53a186730151df59aef40cdd5293e92', 4 | 'project_name': 'climb-cl', 5 | 'log_freq': 100, 6 | } 7 | -------------------------------------------------------------------------------- /src/utils/seed_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def set_seed(seed): 7 | 
random.seed(seed) 8 | np.random.seed(seed) 9 | torch.manual_seed(seed) 10 | -------------------------------------------------------------------------------- /src/configs/adapter_configs.py: -------------------------------------------------------------------------------- 1 | from transformers import PfeifferConfig, HoulsbyConfig, ParallelConfig, CompacterConfig 2 | 3 | ADAPTER_MAP = { 4 | 'pfeiffer': PfeifferConfig, 5 | 'houlsby': HoulsbyConfig, 6 | 'parallel': ParallelConfig, 7 | 'compacter': CompacterConfig, 8 | } 9 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: { } 3 | distributed_type: MULTI_GPU 4 | fsdp_config: { } 5 | #machine_rank: 0 6 | main_training_function: main 7 | main_process_port: 6012 8 | mixed_precision: fp16 9 | num_machines: 1 10 | num_processes: 1 11 | use_cpu: false -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.10.2+cu113 3 | torchvision==0.11.3+cu113 4 | tokenizers==0.11.6 5 | transformers==4.16.2 6 | Pillow==9.2.0 7 | tqdm 8 | numpy 9 | git+https://github.com/rwightman/pytorch-image-models.git 10 | wandb 11 | datasets 12 | jsonlines 13 | scikit-learn 14 | 15 | -------------------------------------------------------------------------------- /src/modeling/continual_learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EncoderWrapper(nn.Module): 7 | 8 | def __init__(self, **kwargs): 9 | super().__init__() 10 | 11 | def forward(self, **kwargs): 12 | pass 13 | 14 | 15 | class ContinualLearner(nn.Module): 16 | 17 | def __init__(self, **kwargs): 18 | super().__init__() 19 | 20 | def forward(self, **kwargs): 21 | pass 22 | 23 | def get_encoder(self): 24 | pass 25 | -------------------------------------------------------------------------------- /src/train_vilt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 TOKENIZERS_PARALLELISM=false accelerate launch \ 2 | --config_file accelerate_config.yaml \ 3 | src/train/main.py \ 4 | --encoder_name vilt \ 5 | --pretrained_model_name ./models/vilt-b32-mlm \ 6 | --climb_data_dir '' \ 7 | --do_train \ 8 | --model_path ./models \ 9 | --output_dir ./logs \ 10 | --batch_size 2 \ 11 | --val_batch_size 2 \ 12 | --comm_round 30 \ 13 | --local_epochs 1 \ 14 | --lr 1e-4 \ 15 | --optimizer_mode dat \ 16 | --seed 1 \ 17 | --adapter_reduction_factor 16 \ 18 | --adapter_config pfeiffer \ 19 | --splits train_small val test_small \ 20 | --ordered_cl_tasks domain -------------------------------------------------------------------------------- /src/train_albef.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 TOKENIZERS_PARALLELISM=false accelerate launch \ 2 | --config_file ./accelerate_config.yaml \ 3 | src/train/main.py \ 4 | --encoder_name albef_no_distill \ 5 | --pretrained_model_name ./models/ALBEF.pth \ 6 | --climb_data_dir '' \ 7 | --do_train \ 8 | --model_path ./models/ \ 9 | --output_dir ./logs/ \ 10 | --batch_size 2 \ 11 | --val_batch_size 2 \ 12 | --lr 1e-4 \ 13 | 
--optimizer_mode dat \ 14 | --seed 2 \ 15 | --adapter_reduction_factor 16 \ 16 | --adapter_config pfeiffer \ 17 | --splits train_small val test \ 18 | --ordered_cl_tasks domain 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .vilt import load_vilt_encoder, create_vilt_continual_learner_model 2 | from .albef import load_albef, create_albef_continual_learner_model 3 | from .viltbert import load_viltbert_encoder, create_viltbert_continual_learner_model 4 | 5 | load_encoder_map = { 6 | 'vilt': load_vilt_encoder, 7 | 'viltbert': load_viltbert_encoder, 8 | 'albef_distill': load_albef, 9 | 'albef_no_distill': load_albef 10 | } 11 | 12 | create_continual_learner_map = { 13 | 'vilt': create_vilt_continual_learner_model, 14 | 'albef_distill': create_albef_continual_learner_model, 15 | 'albef_no_distill': create_albef_continual_learner_model, 16 | 'viltbert': create_viltbert_continual_learner_model, 17 | } 18 | -------------------------------------------------------------------------------- /src/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | 4 | 5 | class WandBLogger: 6 | 7 | def __init__(self): 8 | 9 | self.is_initialized = False 10 | 11 | def initialize(self, wandb_config, experiment_name): 12 | 13 | os.environ['WANDB_API_KEY'] = wandb_config['api_key'] 14 | wandb.init(entity=wandb_config['entity'], 15 | project=wandb_config['project_name'], 16 | name=experiment_name) 17 | self.is_initialized = True 18 | self.log_freq = wandb_config['log_freq'] 19 | 20 | def log(self, log_dict): 21 | 22 | if self.is_initialized: 23 | wandb.log(log_dict) 24 | 25 | def get_log_freq(self): 26 | if self.is_initialized: 27 | return self.log_freq 28 | else: 29 | return 100 30 | 31 | 32 | wandb_logger = WandBLogger() 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tejas Srinivasan, Ting-Yun Chang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/data/image_datasets/vizwizimages_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import transforms as T 17 | from torch.utils.data import Dataset 18 | 19 | from PIL import Image 20 | from src.utils.image_utils import resize_image 21 | 22 | 23 | class vizwizImagesDataset(Dataset): 24 | 25 | def __init__(self, coco_dir: str, data_dir: str, visual_input_type: str, task_key: str, image_size=(384, 640), transform=None): 26 | 27 | ''' 28 | Initializes an MSCOCOImagesDataset instance that handles image-side processing for VQA and other tasks that use MS-COCO images 29 | coco_dir: directory that contains MS-COCO data (images within 'images' folder) 30 | visual_input_type: format of visual input to model 31 | image_size: tuple indicating size of image input to model 32 | ''' 33 | 34 | self.image_size = image_size 35 | self.raw_transform = T.Compose([ 36 | T.Resize(image_size), 37 | T.ToTensor(), # [0, 1] 38 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [-1, 1] 39 | ]) 40 | 41 | self.pil_transform = T.Resize(size=384, max_size=640) 42 | 43 | def get_image_data(self, image_id: str) -> Image: 44 | ''' 45 | Loads image corresponding to image_id, re-sizes and returns PIL.Image object 46 | ''' 47 | p = f'/home/stud/zhangya/carvendata/vizwiz/images/{image_id}' 48 | image = Image.open(p) 49 | image = image.convert('RGB') 50 | if min(list(image.size)) > 384 or hasattr(self, 'use_albef'): 51 | image = self.pil_transform(image) 52 | return image -------------------------------------------------------------------------------- /src/data/image_datasets/vgimages_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import transforms as T 17 | from torch.utils.data import Dataset 18 | 19 | from PIL import Image 20 | from src.utils.image_utils import resize_image 21 | 22 | 23 | class VGImagesDataset(Dataset): 24 | 25 | def __init__(self, coco_dir: str, data_dir: str, visual_input_type: str, task_key: str, image_size=(384, 640)): 26 | 27 | ''' 28 | Initializes an MSCOCOImagesDataset instance that handles image-side processing for VQA and other tasks that use MS-COCO images 29 | coco_dir: directory that contains MS-COCO data (images within 'images' folder) 30 | visual_input_type: format of visual input to model 31 | image_size: tuple indicating size of image input to model 32 | ''' 33 | 34 | self.image_size = image_size 35 | self.raw_transform = T.Compose([ 36 | T.Resize(image_size), 37 | T.ToTensor(), # [0, 1] 38 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [-1, 1] 39 | ]) 40 | 41 | self.pil_transform = T.Resize(size=384, max_size=640) 42 | 43 | def get_image_data(self, image_id: str) -> Image: 44 | ''' 45 | Loads image corresponding to image_id, re-sizes and returns PIL.Image object 46 
| ''' 47 | image_id = image_id.replace('n', '') 48 | p = f'./data/vg/VG_100K/{image_id}.jpg' 49 | image = Image.open(p) 50 | image = image.convert('RGB') 51 | if min(list(image.size)) > 384 or hasattr(self, 'use_albef'): 52 | image = self.pil_transform(image) 53 | return image -------------------------------------------------------------------------------- /src/utils/vqa_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle as pkl 3 | import os 4 | 5 | import torch 6 | 7 | from collections import defaultdict, Counter 8 | from src.utils.word_utils import normalize_word 9 | 10 | 11 | class FeatureHook: 12 | def __init__(self, module): 13 | self.hook = module.register_forward_hook(self.hook_fn) 14 | 15 | def hook_fn(self, module, input, output): 16 | self.embedding = output 17 | 18 | def close(self): 19 | self.hook.remove() 20 | 21 | def get_score(occurences): 22 | if occurences == 0: 23 | return 0.0 24 | elif occurences == 1: 25 | return 0.3 26 | elif occurences == 2: 27 | return 0.6 28 | elif occurences == 3: 29 | return 0.9 30 | else: 31 | return 1.0 32 | 33 | 34 | def create_vqa_labels(vqa_dir): 35 | train_annotations = json.load(open(os.path.join(vqa_dir, 'v2_mscoco_train2014_annotations.json')))['annotations'] 36 | val_annotations = json.load(open(os.path.join(vqa_dir, 'v2_mscoco_val2014_annotations.json')))['annotations'] 37 | 38 | all_major_answers = [] 39 | for anno in train_annotations: 40 | all_major_answers.append(normalize_word(anno['multiple_choice_answer'])) 41 | for anno in val_annotations: 42 | all_major_answers.append(normalize_word(anno['multiple_choice_answer'])) 43 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9} 44 | 45 | ans2label = {k: i for i, k in enumerate(counter.keys())} 46 | print("Number of labels: {}".format(len(ans2label))) 47 | 48 | pkl.dump(ans2label, open(os.path.join(vqa_dir, 'ans2label.pkl'), 'wb')) 49 | 50 | 51 | ''' 52 | def target_tensor(len, labels, scores): 53 | """ create the target by labels and scores """ 54 | target = [0]*len 55 | for id, l in enumerate(labels): 56 | target[l] = scores[id] 57 | 58 | return torch.tensor(target) 59 | ''' # this seems more straightforward to me 60 | 61 | 62 | def target_tensor(num_labels, labels, scores): 63 | """ create the target by labels and scores """ 64 | target = torch.zeros(num_labels) 65 | target[labels] = torch.tensor(scores) 66 | 67 | return target 68 | 69 | 70 | if __name__ == '__main__': 71 | create_vqa_labels('/nfs/data2/yyang/climb_data/vqav2/') 72 | -------------------------------------------------------------------------------- /src/utils/make_table.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pprint 3 | import os 4 | import pdb 5 | import numpy as np 6 | from collections import defaultdict 7 | import sys 8 | import glob 9 | 10 | 11 | def merge_all_results(all_scores, fns, backbone): 12 | for fn in fns: 13 | with open(fn, "r") as f: 14 | rdict = json.load(f) 15 | 16 | name = os.path.basename(fn).split('_')[:-1] 17 | if len(name) == 2: 18 | algo = backbone 19 | t_order = 'task0' 20 | t_name = 'NA' 21 | elif len(name) == 3: 22 | algo = 'single' 23 | t_order, t_name = name[1:] 24 | elif len(name) == 4: 25 | t_order, t_name, algo = name[1:] 26 | 27 | for k in rdict.keys(): 28 | scores = np.array(list(rdict[k].values())) 29 | test_scores, dev_scores = scores[:, 0], scores[:, 1] 30 | 31 | n_shot = k.split('-')[-1] 32 | if 'vision' in fn: 33 | 
all_scores[algo][t_order][t_name][n_shot] = f'{test_scores[0]:.1f}' 34 | else: 35 | # assert scores.shape == (3,3) 36 | all_scores[backbone][algo][t_order][t_name][n_shot] = f'{test_scores.mean():.1f} ±{test_scores.std():.1f}' 37 | 38 | return all_scores 39 | 40 | 41 | def dump_outputs(all_scores, task_name): 42 | out_fn = f"{task_name}.json" 43 | with open(out_fn, "w") as outfile: 44 | outfile.write(json.dumps(all_scores)) 45 | with open(out_fn, "r") as f: 46 | rdict = json.load(f) 47 | 48 | pp = pprint.PrettyPrinter() 49 | pp.pprint(rdict) 50 | 51 | 52 | if __name__ == "__main__": 53 | assert len(sys.argv) == 2, "input task name" 54 | task_name = sys.argv[1] 55 | tree = lambda: defaultdict(tree) 56 | all_scores = tree() 57 | 58 | if task_name in ['coco', 'imagenet', 'inat2019', 'places365']: 59 | dir_name = 'vision_only' 60 | fns = glob.glob(f"/data/experiments/MCL/{dir_name}/{task_name}_*") 61 | all_scores = merge_all_results(all_scores, fns, 'ViLT') 62 | else: 63 | dir_name = 'lang_only' 64 | 65 | fns = glob.glob(f"/data/experiments/MCL/{dir_name}/{task_name}_*") 66 | all_scores = merge_all_results(all_scores, fns, 'ViLT') 67 | fns = glob.glob(f"/data/experiments/MCL/{dir_name}/viltbert/{task_name}_*") 68 | all_scores = merge_all_results(all_scores, fns, 'ViLTBERT') 69 | 70 | dump_outputs(all_scores, task_name) 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedDAT (Federated Dual-Adapter Teacher) 2 | 3 | An approach for foundation model finetuning in multi-modal heterogeneous federated learning. [[Pre-print]](https://arxiv.org/pdf/2308.12305.pdf) 4 | 5 | ![Problem Setup](/assets/fedvqa.png "Problem Setup") 6 | 7 | We propose the Dual-Adapter Teacher (DAT) module and apply Mutual Knowledge Distillation (MKD) to mitigate client-local data heterogeneity across different modalities. 8 | 9 | ![Method](/assets/dat.png "Method") 10 | 11 | 12 | --- 13 | 14 | ## Setup 15 | 16 | 1. Create a Conda environment with Python 3.8 17 | 18 | ``` 19 | conda create -n feddat python=3.8 20 | conda activate feddat 21 | ``` 22 | 23 | 2. Install requirements 24 | 25 | ``` 26 | git clone https://github.com/HaokunChen245/FedDAT.git 27 | pip install -r requirements.txt 28 | pip install -U adapters 29 | pip install accelerate 30 | ``` 31 | 3. 
Prepare datasets and pretrained models 32 | 33 | | Dataset | Link | 34 | | :----:| :----: | 35 | | AQUA | https://github.com/noagarcia/ArtVQA/tree/master/AQUA | 36 | | COCO-QA | http://www.cs.toronto.edu/~mren/imageqa/data/cocoqa/cocoqa-2015-05-17.zip | 37 | | Images for COCO-QA | https://cocodataset.org/#download | 38 | | Abstract Scenes | https://visualqa.org/download.html | 39 | | VizWiz | https://vizwiz.org/tasks-and-datasets/vqa/ | 40 | | GQA | https://cs.stanford.edu/people/dorarad/gqa/download.html | 41 | | VG_100K | https://huggingface.co/datasets/visual_genome | 42 | | Function & Scene (CLOVE benchmark) | TODO | 43 | 44 | 45 | Put the datasets in the folder /data 46 | 47 | | Model | Link | 48 | | :----:| :----: | 49 | | ALBEF | https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth | 50 | | ViLT | https://huggingface.co/dandelin/vilt-b32-mlm | 51 | | BERT | https://huggingface.co/bert-base-uncased/tree/main | 52 | 53 | Put the models in the folder /models 54 | 55 | --- 56 | 57 | ## Run 58 | 59 | ``` 60 | # Training with ViLT 61 | bash src/train_vilt.sh 62 | 63 | # Training with ALBEF 64 | bash src/train_albef.sh 65 | ``` 66 | 67 | --- 68 | 69 | ## Citation 70 | 71 | ```bibtex 72 | @article{chen2023feddat, 73 | title={FedDAT: An Approach for Foundation Model Finetuning in Multi-Modal Heterogeneous Federated Learning}, 74 | author={Chen, Haokun and Zhang, Yao and Krompass, Denis and Gu, Jindong and Tresp, Volker}, 75 | journal={arXiv preprint arXiv:2308.12305}, 76 | year={2023} 77 | } 78 | ``` 79 | 80 | -------------------------------------------------------------------------------- /src/data/image_collation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | import pdb 13 | from typing import List, Dict 14 | import torch 15 | from transformers import BertTokenizer 16 | 17 | from PIL import Image 18 | from src.utils.image_utils import resize_image 19 | from src.utils.vqa_utils import get_score, target_tensor 20 | 21 | from src.data.image_datasets.cocoimages_dataset import MSCOCOImagesDataset 22 | 23 | ALLOWED_VISUAL_INPUT_TYPES = ['raw', # A raw tensor of size (3, W, H) 24 | 'pil-image', # A PIL.Image instance 25 | 'fast-rcnn' # A set of R features, each of dim H 26 | ] 27 | 28 | 29 | def image_collate(images: List, 30 | visual_input_type: str): 31 | """ 32 | Converts list of B images into a batched image input 33 | 34 | Args: 35 | images: list of B images - type(image) can vary according to visual_input_type (see ALLOWED_VISUAL_INPUT_TYPES above) 36 | visual_input_type: one element from ALLOWED_VISUAL_INPUT_TYPES 37 | 38 | Returns: 39 | collated_images: list/Tensor that contains all the images collated together into a batch input 40 | """ 41 | 42 | if visual_input_type == 'pil-image': 43 | # returns list of PIL.Image objects 44 | collated_images = images 45 | 46 | elif visual_input_type == 'raw': 47 | # Stacks individual raw image tensors to give (B, 3, W, H) tensor 48 | collated_images = torch.stack(images, dim=0) 49 | 50 | elif visual_input_type == 'fast-rcnn': 51 | # Stack the image tensors, doing padding if necessary for the sequence of region features 52 | # Each element is a tensor of shape [R_i, H], returns tensor of shape [B, max(R_i), H] 53 | max_len = max([t.shape[0] for t in images]) 54 | image_tensors_padded = [] 
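# The loop below pads each (R_i, H) region-feature tensor along dim 0 with torch.zeros rows (default float32) up to max_len, so the whole batch can be stacked into a single (B, max_len, H) tensor.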
55 | for i in range(len(images)): 56 | padding_tensor = torch.zeros(max_len - images[i].shape[0], images[i].shape[1]) 57 | padded_tensor = torch.cat((images[i], padding_tensor), dim=0) 58 | assert padded_tensor.shape[0] == max_len 59 | image_tensors_padded.append(padded_tensor) 60 | collated_images = torch.stack(image_tensors_padded, dim=0) # Pads region features with 0 vectors to give (B, R, hv) tensor 61 | 62 | return collated_images 63 | -------------------------------------------------------------------------------- /src/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | from PIL import Image 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def resize_image(image, desired_shape): 9 | try: 10 | w, h = image.size 11 | # logging.info("size: {}".format((w, h))) 12 | if image.mode == 'CMYK': 13 | image = image.convert('RGB') 14 | 15 | if w > h: 16 | d_w = max(desired_shape) 17 | d_h = min(desired_shape) 18 | # print("d_w: {} d_h: {}".format(d_w, d_h)) 19 | if w >= d_w: 20 | new_h = int(h * d_w / w) 21 | # print("new_h: {}".format(new_h)) 22 | if new_h > d_h: 23 | image = image.resize((int(w * d_h / h), d_h), resample=0) 24 | else: 25 | image = image.resize((d_w, new_h), resample=0) 26 | else: 27 | if h > d_h: 28 | new_w = int(d_h * w / h) 29 | # print("new_w: {}".format(new_w)) 30 | image = image.resize((new_w, d_h), resample=0) 31 | else: 32 | d_h = max(desired_shape) 33 | d_w = min(desired_shape) 34 | # print("d_w: {} d_h: {}".format(d_w, d_h)) 35 | if h >= d_h: 36 | new_w = int(w * d_h / h) 37 | # print("new_w: {}".format(new_w)) 38 | if new_w > d_w: 39 | image = image.resize((d_w, int(h * d_w / w)), resample=0) 40 | else: 41 | image = image.resize((new_w, d_h), resample=0) 42 | else: 43 | if w > d_w: 44 | new_h = int(d_w * h / w) 45 | # print("new_h: {}".format(new_h)) 46 | image = image.resize((d_w, new_h), resample=0) 47 | 48 | image_arr = np.asarray(image) # size: (w, h) -> shape (h, w) 49 | if len(image_arr.shape) < 3: # for grayscale images, stack channels 50 | image_arr = np.stack((image_arr,) * 3, axis=-1) 51 | elif len(image_arr.shape) == 3 and image_arr.shape[2] > 3: 52 | image_arr = image_arr[:, :, :3] 53 | padded_image = np.zeros((d_h, d_w, 3,), dtype=np.float64) 54 | padded_image[:image_arr.shape[0], :image_arr.shape[1]] = image_arr 55 | return padded_image 56 | except Exception as e: 57 | d_w = max(desired_shape) 58 | d_h = min(desired_shape) 59 | padded_image = np.zeros((d_h, d_w, 3,), dtype=np.float64) 60 | return padded_image 61 | -------------------------------------------------------------------------------- /src/configs/model_configs.py: -------------------------------------------------------------------------------- 1 | from src.modeling.albef import * 2 | from src.modeling.vilt import * 3 | from src.modeling.vilt_clf import * 4 | from src.modeling.viltbert import * 5 | 6 | ALLOWED_CL_ENCODERS = ["vilt", "viltbert", "flava", "albef_distill", "albef_no_distill"] 7 | 8 | #### for ViLT 9 | vilt_config = { 10 | 'encoder_dim': 768, 11 | 'visual_input_type': 'pil-image', 12 | 'encoder_class': ViltEncoderWrapper, 13 | 'batch2inputs_converter': convert_batch_to_vilt_input_dict, 14 | 'encoder_name': 'ViLT' 15 | } 16 | 17 | 18 | viltbert_config = { 19 | "encoder_dim": 768, 20 | "visual_input_type": "pil-image", 21 | "encoder_class": ViltBertEncoderWrapper, 22 | "batch2inputs_converter": convert_batch_to_viltbert_input_dict, 23 | "encoder_name": "ViLT-BERT", 24 | } 25 | 
viltbert_lang_seq_config = { 26 | "encoder_dim": 768, 27 | "visual_input_type": "pil-image", 28 | "encoder_class": ViltBertEncoderWrapper, 29 | "classifier_class": ViltBertForSequenceClassification, 30 | "batch2inputs_converter": convert_seq_batch_to_vilt_input_dict, 31 | } 32 | viltbert_lang_mc_config = { 33 | "encoder_dim": 768, 34 | "visual_input_type": "pil-image", 35 | "encoder_class": ViltBertEncoderWrapper, 36 | "classifier_class": ViltBertForMultipleChoice, 37 | "batch2inputs_converter": convert_mc_batch_to_vilt_input_dict, 38 | } 39 | 40 | config_bert = { 41 | "architectures": [ 42 | "BertForMaskedLM" 43 | ], 44 | "attention_probs_dropout_prob": 0.1, 45 | "hidden_act": "gelu", 46 | "hidden_dropout_prob": 0.1, 47 | "hidden_size": 768, 48 | "initializer_range": 0.02, 49 | "intermediate_size": 3072, 50 | "layer_norm_eps": 1e-12, 51 | "max_position_embeddings": 512, 52 | "model_type": "bert", 53 | "num_attention_heads": 12, 54 | "num_hidden_layers": 12, 55 | "pad_token_id": 0, 56 | "type_vocab_size": 2, 57 | "vocab_size": 30522, 58 | "fusion_layer": 6, 59 | "encoder_width": 768 60 | } 61 | 62 | albef_no_distill_config = { 63 | "text_encoder": "bert-base-uncased", 64 | "text_decoder": "bert-base-uncased", 65 | "image_res": 384, 66 | "visual_input_type": "pil-image", 67 | "bert_config": config_bert, 68 | "batch2inputs_converter": convert_batch_to_albef_input_dict, 69 | "distill": False, 70 | "encoder_class": ALBEFWrapper, 71 | "encoder_name": "albef_no_distill", 72 | } 73 | 74 | albef_distill_config = { 75 | "text_encoder": "bert-base-uncased", 76 | "text_decoder": "bert-base-uncased", 77 | "image_res": 384, 78 | "visual_input_type": "pil-image", 79 | "bert_config": config_bert, 80 | "distill": True, 81 | "batch2inputs_converter": convert_batch_to_albef_input_dict, 82 | "encoder_class": ALBEFWrapper, 83 | "encoder_name": "albef_distill", 84 | } 85 | 86 | model_configs = { 87 | "vilt": vilt_config, 88 | "albef_distill": albef_distill_config, 89 | "albef_no_distill": albef_no_distill_config, 90 | } 91 | -------------------------------------------------------------------------------- /src/modeling/adaptered_output.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import sys 4 | import logging 5 | from accelerate.logging import get_logger 6 | import itertools 7 | import pdb 8 | import time 9 | from PIL import Image 10 | from typing import List, Dict 11 | from typing_extensions import OrderedDict 12 | 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from transformers import BertConfig, BertTokenizer, BertModel 20 | from transformers import ViltConfig, ViltProcessor, ViltModel 21 | from transformers import BertTokenizerFast 22 | from transformers import logging as transformers_logging 23 | from src.modeling.models.adapter import Adapter, init_bert_weights 24 | from src.modeling.continual_learner import EncoderWrapper, ContinualLearner 25 | from src.modeling.models.vit import Attention 26 | import loralib as lora 27 | 28 | class Attention_lorad(nn.Module): 29 | def __init__(self, layer, dim): 30 | super().__init__() 31 | self.layer = layer 32 | self.query = lora.Linear(dim, dim, r=16) 33 | self.value = lora.Linear(dim, dim, r=16) 34 | 35 | def forward(self, x, register_hook=False): 36 | B, N, C = x.shape 37 | qkv = self.layer.qkv(x).reshape(B, N, 3, self.layer.num_heads, C // self.layer.num_heads).permute(2, 0, 3, 1, 4) 38 | q_lora = 
self.query(x).reshape(B, N, self.layer.num_heads, C // self.layer.num_heads).permute(0, 2, 1, 3) 39 | v_lora = self.value(x).reshape(B, N, self.layer.num_heads, C // self.layer.num_heads).permute(0, 2, 1, 3) 40 | q, k, v = q_lora, qkv[1], v_lora # make torchscript happy (cannot use tensor as tuple) 41 | 42 | attn = (q @ k.transpose(-2, -1)) * self.layer.scale 43 | attn = attn.softmax(dim=-1) 44 | attn = self.layer.attn_drop(attn) 45 | 46 | if register_hook: 47 | self.layer.save_attention_map(attn) 48 | attn.register_hook(self.layer.save_attn_gradients) 49 | 50 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 51 | x = self.layer.proj(x) 52 | x = self.layer.proj_drop(x) 53 | return x 54 | 55 | class Adaptered_BertOutput(nn.Module): 56 | def __init__(self, layer, adapter_config): 57 | super().__init__() 58 | self.layer = layer 59 | self.adapter = Adapter(**adapter_config, model_dim=768) 60 | 61 | def forward(self, hidden_states, input_tensor): 62 | hidden_states = self.layer.dense(hidden_states) 63 | hidden_states = self.layer.dropout(hidden_states) 64 | hidden_states = self.adapter.adapter_layer_forward_bert(hidden_states, input_tensor, self.layer.LayerNorm) 65 | return hidden_states 66 | 67 | class Adaptered_ViltOutput(nn.Module): 68 | def __init__(self, layer, adapter_config) -> None: 69 | super().__init__() 70 | self.layer = layer 71 | self.adapter = Adapter(**adapter_config, model_dim=768) 72 | 73 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 74 | hidden_states = self.layer.dense(hidden_states) 75 | hidden_states = self.layer.dropout(hidden_states) 76 | hidden_states = hidden_states + input_tensor 77 | 78 | hidden_states = self.adapter(hidden_states, hidden_states) 79 | return hidden_states -------------------------------------------------------------------------------- /src/data/image_datasets/get_avg_images.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | import pdb 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import transforms as T 17 | from torch.utils.data import Dataset 18 | from PIL import Image 19 | from torchvision.utils import save_image 20 | from torch.utils import data 21 | 22 | class MSCOCOImagesDataset(Dataset): 23 | 24 | def __init__(self, coco_dir, image_size=(384, 384)): 25 | 26 | self.images_dir = os.path.join(coco_dir, 'images') # Images across all 2017 splits stored in same directory 27 | self.image_size = image_size 28 | 29 | image_filenames = os.listdir(self.images_dir) 30 | self.imageid2filename = {} 31 | for fn in image_filenames: 32 | image_id = int(fn.strip('.jpg')) 33 | self.imageid2filename[image_id] = os.path.join(self.images_dir, fn) 34 | self.imageids = list(self.imageid2filename.keys()) 35 | self.num_images = len(self.imageids) 36 | 37 | self.raw_transform = T.Compose([ 38 | T.Resize(image_size), 39 | T.ToTensor(), # [0, 1] 40 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [-1, 1] 41 | ]) 42 | 43 | 44 | def get_raw_image_tensor(self, image_fn): 45 | image = Image.open(image_fn) 46 | image = image.convert('RGB') 47 | image_tensor = self.raw_transform(image) 48 | image.close() 49 | return image_tensor # (B, 3, W, H) 50 | 51 | def __getitem__(self, i): 52 | image_id = self.imageids[i] 53 | 
image_fn = self.imageid2filename[image_id] 54 | x = self.get_raw_image_tensor(image_fn) 55 | return x 56 | 57 | def __len__(self): 58 | return self.num_images 59 | 60 | 61 | def save_imgs(image, fn, n=4): 62 | def denorm(x): 63 | """Convert the range from [-1, 1] to [0, 1].""" 64 | out = (x + 1) / 2 65 | return out.clamp_(0, 1) 66 | 67 | imgs = denorm(image[:n].cpu()) 68 | save_image(imgs, f'{fn}.png', nrow=n, padding=0) 69 | print(f'Save {fn}.png!', imgs.shape) 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | dataset = MSCOCOImagesDataset('/data/datasets/MCL/ms-coco/') 75 | data_loader = data.DataLoader(dataset=dataset, 76 | batch_size=2048, 77 | shuffle=False, 78 | drop_last=False, 79 | num_workers=8, 80 | ) 81 | 82 | print('# images:', dataset.num_images) 83 | sum_image = None 84 | for batch in tqdm(data_loader): 85 | batch = batch.cuda() 86 | if sum_image is not None: 87 | sum_image += batch.sum(0) 88 | else: 89 | sum_image = batch.sum(0) 90 | tmp_mean_img = batch[:4].sum(0) / 4 91 | tmp_cat = torch.cat((batch[:4], tmp_mean_img.unsqueeze(0)), 0) 92 | save_imgs(tmp_cat, 0, 5) 93 | pdb.set_trace() 94 | mean_image = sum_image / dataset.num_images 95 | save_imgs(mean_image.unsqueeze(0), 'coco_mean_image', 1) 96 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /src/data/image_datasets/flickr30kimages_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import transforms as T 17 | from torch.utils.data import Dataset 18 | 19 | from PIL import Image 20 | 21 | 22 | class Flickr30KImagesDataset(Dataset): 23 | 24 | def __init__(self, flickr_dir: str, visual_input_type: str, image_size=(384,640), transform=None): 25 | 26 | ''' 27 | Initializes a Flickr30KImagesDataset instance that handles image-side processing for SNLI-VE and other tasks that use Flickr images 28 | coco_dir: directory that contains Flickr30K data (images within 'flickr30k_images' folder) 29 | visual_input_type: format of visual input to model 30 | image_size: tuple indicating size of image input to model 31 | ''' 32 | 33 | self.images_dir = os.path.join(flickr_dir, 'flickr30k_images') # Images across all 2017 splits stored in same directory 34 | self.image_size = image_size 35 | self.visual_input_type = visual_input_type 36 | assert visual_input_type in ['pil-image', 'raw', 'fast-rcnn'] 37 | 38 | image_filenames = os.listdir(self.images_dir) 39 | self.imageid2filename = {} 40 | for fn in image_filenames: 41 | image_id = int(fn.strip('.jpg')) 42 | self.imageid2filename[image_id] = os.path.join(self.images_dir, fn) 43 | self.imageids = list(self.imageid2filename.keys()) 44 | 45 | self.raw_transform = 
T.Compose([ 46 | T.Resize(image_size), 47 | T.ToTensor(), # [0, 1] 48 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [-1, 1] 49 | ]) 50 | 51 | self.pil_transform = T.Resize(image_size) 52 | 53 | def get_image_data(self, image_id: str): 54 | 55 | ''' 56 | Returns image data according to required visual_input_type. Output format varies by visual_input_type 57 | ''' 58 | 59 | if self.visual_input_type == 'pil-image': 60 | return self.get_pil_image(image_id) 61 | 62 | if self.visual_input_type == 'raw': 63 | return self.get_raw_image_tensor(image_id) 64 | 65 | elif self.visual_input_type == 'fast-rcnn': 66 | raise NotImplementedError("Have not implemented Fast-RCNN feature inputs for Flickr30K images!") 67 | 68 | 69 | def get_pil_image(self, image_id: str) -> Image: 70 | ''' 71 | Loads image corresponding to image_id, re-sizes and returns PIL.Image object 72 | ''' 73 | 74 | assert image_id in self.imageid2filename.keys() 75 | image_fn = self.imageid2filename[image_id] 76 | image = Image.open(image_fn) 77 | image = image.convert('RGB') 78 | if min(list(image.size)) > 384: 79 | image = self.pil_transform(image) 80 | return image 81 | 82 | def get_raw_image_tensor(self, image_id: str) -> torch.Tensor: 83 | ''' 84 | Loads image corresponding to image_id, re-sizes, and returns tensor of size (3, W, H) 85 | ''' 86 | 87 | assert image_id in self.imageid2filename.keys() 88 | image_fn = self.imageid2filename[image_id] 89 | image = Image.open(image_fn) 90 | image = image.convert('RGB') 91 | 92 | image_tensor = self.raw_transform(image) 93 | 94 | image.close() 95 | return image_tensor # (B, 3, W, H) 96 | 97 | if __name__ == '__main__': 98 | 99 | dataset = Flickr30KImagesDataset('/data/datasets/MCL/flickr30k/', 'raw') 100 | imgid = dataset.imageids[0] 101 | x = dataset.get_image_data(imgid) 102 | print(x.shape) 103 | -------------------------------------------------------------------------------- /src/data/image_datasets/cocoimages_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import transforms as T 17 | from torch.utils.data import Dataset 18 | 19 | from PIL import Image 20 | 21 | class MSCOCOImagesDataset(Dataset): 22 | 23 | def __init__(self, coco_dir: str, visual_input_type: str, image_size=(384,640), transform=None): 24 | 25 | ''' 26 | Initializes an MSCOCOImagesDataset instance that handles image-side processing for VQA and other tasks that use MS-COCO images 27 | coco_dir: directory that contains MS-COCO data (images within 'images' folder) 28 | visual_input_type: format of visual input to model 29 | image_size: tuple indicating size of image input to model 30 | ''' 31 | 32 | self.images_dir = os.path.join(coco_dir, 'images') # Images across all 2017 splits stored in same directory 33 | self.image_size = image_size 34 | 35 | self.visual_input_type = visual_input_type 36 | assert visual_input_type in ['pil-image', 'raw', 'fast-rcnn'] 37 | 38 | image_filenames = os.listdir(self.images_dir) 39 | self.imageid2filename = {} 40 | for fn in image_filenames: 41 | fn = fn.split('_')[-1] 42 | image_id = int(fn.strip('.jpg')) 43 | self.imageid2filename[image_id] = os.path.join(self.images_dir, fn) 44 | self.imageids = 
list(set(list(self.imageid2filename.keys()))) 45 | 46 | self.raw_transform = T.Compose([ 47 | T.Resize(image_size), 48 | T.ToTensor(), # [0, 1] 49 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [-1, 1] 50 | ]) 51 | 52 | #self.pil_transform = T.Resize(image_size) 53 | self.pil_transform = T.Resize(size=384, max_size=640) 54 | 55 | 56 | def get_image_data(self, image_id: str): 57 | 58 | ''' 59 | Returns image data according to required visual_input_type. Output format varies by visual_input_type 60 | ''' 61 | 62 | if self.visual_input_type == 'pil-image': 63 | return self.get_pil_image(image_id) 64 | 65 | if self.visual_input_type == 'raw': 66 | return self.get_raw_image_tensor(image_id) 67 | 68 | elif self.visual_input_type == 'fast-rcnn': 69 | raise NotImplementedError("Have not implemented Fast-RCNN feature inputs for MS-COCO images!") 70 | 71 | def get_pil_image(self, image_id: str) -> Image: 72 | ''' 73 | Loads image corresponding to image_id, re-sizes and returns PIL.Image object 74 | ''' 75 | 76 | assert image_id in self.imageid2filename.keys() 77 | image_fn = self.imageid2filename[image_id] 78 | image = Image.open(image_fn) 79 | image = image.convert('RGB') 80 | if min(list(image.size)) > 384 or hasattr(self, 'use_albef'): 81 | image = self.pil_transform(image) 82 | return image 83 | 84 | def get_raw_image_tensor(self, image_id: str) -> torch.Tensor: 85 | ''' 86 | Loads image corresponding to image_id, re-sizes, and returns tensor of size (3, W, H) 87 | ''' 88 | 89 | assert image_id in self.imageid2filename.keys() 90 | image_fn = self.imageid2filename[image_id] 91 | image = Image.open(image_fn) 92 | image = image.convert('RGB') 93 | 94 | image_tensor = self.raw_transform(image) 95 | 96 | image.close() 97 | return image_tensor # (B, 3, W, H) 98 | 99 | if __name__ == '__main__': 100 | 101 | dataset = MSCOCOImagesDataset('/data/datasets/MCL/ms-coco/', 'raw') 102 | imgid = dataset.imageids[0] 103 | x = dataset.get_image_data(imgid) 104 | print(x.shape) 105 | -------------------------------------------------------------------------------- /src/train/visionlanguage_tasks/train_nlvr2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import time 9 | import math 10 | import shutil 11 | import pickle as pkl 12 | import copy 13 | import pdb 14 | from tqdm import tqdm 15 | from typing import List, Dict, Tuple 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | from torch.optim import AdamW 21 | from transformers import get_polynomial_decay_schedule_with_warmup 22 | 23 | from src.data.visionlanguage_datasets.nlvr2_dataset import build_nlvr2_dataloader 24 | from src.train.visionlanguage_tasks.task_trainer import TaskTrainer 25 | from src.utils.wandb import wandb_logger 26 | 27 | sys.path.insert(0, '.') 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | 30 | logger = logging.getLogger(__name__) 31 | logging.basicConfig( 32 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 33 | datefmt='%m/%d/%Y %H:%M:%S', 34 | level=logging.INFO) 35 | 36 | class NLVR2Trainer(TaskTrainer): 37 | 38 | def __init__(self, 39 | logger, 40 | args: argparse.Namespace, 41 | task_configs: Dict, 42 | model_config: Dict, 43 | device: torch.device, 44 | task_key, 45 | task_output_dir, 46 | accelerator): 47 | 48 | ''' 49 | Initializes a Trainer that handles training of a model on the VCR task 50 | 51 | args: Arguments 
provided by user 52 | task_configs: dictionary containing task-specific configuration parameters for all tasks 53 | model_config: dictionary containing model-specific configuration parameters 54 | device: cuda/cpu 55 | ''' 56 | 57 | super().__init__() 58 | 59 | self.args = args 60 | self.local_epochs = args.local_epochs 61 | self.device = device 62 | self.accelerator = accelerator 63 | self.task_output_dir = task_output_dir 64 | self.task_key = task_key 65 | 66 | self.nlvr_config = task_configs['nlvr2'] 67 | self.data_dir = os.path.join(args.climb_data_dir, self.nlvr_config['data_dir']) 68 | 69 | # Model-specific stuff 70 | self.visual_input_type = model_config['visual_input_type'] 71 | self.batch2inputs_converter = model_config['batch2inputs_converter'] 72 | 73 | # Create dataloaders for training and validation 74 | self.nlvr_train_dataloader = build_nlvr2_dataloader(args=args, 75 | data_dir=self.data_dir, 76 | split='train', 77 | visual_input_type=self.visual_input_type) 78 | 79 | self.nlvr_val_dataloader = build_nlvr2_dataloader(args=args, 80 | data_dir=self.data_dir, 81 | split='val', 82 | visual_input_type=self.visual_input_type) 83 | 84 | # Training hyperparameters 85 | self.num_epochs = self.nlvr_config['num_epochs'] 86 | self.lr = self.nlvr_config['lr'] 87 | self.adam_epsilon = self.nlvr_config['adam_epsilon'] 88 | self.weight_decay = self.nlvr_config['weight_decay'] 89 | self.loss_criterion = nn.CrossEntropyLoss() 90 | 91 | self.nlvr_train_dataloader.dataset.convert_to_low_shot(num_shots_per_class=2048) 92 | self.nlvr_val_dataloader.dataset.convert_to_low_shot(num_shots_per_class=256) 93 | self.max_steps = len(self.nlvr_train_dataloader) * self.num_epochs 94 | self.warmup_ratio = 0.1 # TODO remove hard code 95 | self.hparams = { 96 | 'lr': self.lr, 97 | 'weight_decay': self.weight_decay, 98 | 'adam_epsilon': self.adam_epsilon, 99 | } 100 | 101 | def get_train_dataloader(self): 102 | return self.nlvr_train_dataloader 103 | 104 | def get_collate_fn(self): 105 | return self.nlvr_train_dataloader.collate_fn 106 | -------------------------------------------------------------------------------- /src/train/visionlanguage_tasks/train_vcr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import time 9 | import math 10 | import shutil 11 | import pickle as pkl 12 | import copy 13 | import pdb 14 | from tqdm import tqdm 15 | from typing import List, Dict, Tuple 16 | 17 | sys.path.insert(0, '.') 18 | 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | from torch.optim import AdamW 23 | from transformers import get_polynomial_decay_schedule_with_warmup 24 | 25 | from src.data.visionlanguage_datasets.vcr_dataset import build_vcr_dataloader 26 | from src.train.visionlanguage_tasks.task_trainer import TaskTrainer 27 | from src.utils.wandb import wandb_logger 28 | 29 | logger = logging.getLogger(__name__) 30 | logging.basicConfig( 31 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 32 | datefmt='%m/%d/%Y %H:%M:%S', 33 | level=logging.INFO) 34 | 35 | 36 | class VCRTrainer(TaskTrainer): 37 | 38 | def __init__(self, 39 | logger, 40 | args: argparse.Namespace, 41 | task_configs: Dict, 42 | model_config: Dict, 43 | device: torch.device, 44 | task_key, 45 | task_output_dir, 46 | accelerator): 47 | 48 | ''' 49 | Initializes a Trainer that handles training of a model on the VCR task 50 | 51 | args: Arguments 
provided by user 52 | task_configs: dictionary containing task-specific configuration parameters for all tasks 53 | model_config: dictionary containing model-specific configuration parameters 54 | device: cuda/cpu 55 | ''' 56 | 57 | super().__init__() 58 | 59 | self.args = args 60 | self.local_epochs = args.local_epochs 61 | self.device = device 62 | self.accelerator = accelerator 63 | self.task_output_dir = task_output_dir 64 | self.task_key = task_key 65 | 66 | self.vcr_config = task_configs['vcr'] 67 | self.data_dir = os.path.join(args.climb_data_dir, self.vcr_config['data_dir']) 68 | self.task_type = self.vcr_config['task_type'] 69 | 70 | # Model-specific stuff 71 | self.visual_input_type = model_config['visual_input_type'] 72 | self.batch2inputs_converter = model_config['batch2inputs_converter'] 73 | 74 | # Create dataloaders for training and validation 75 | self.vcr_train_dataloader = build_vcr_dataloader(args=args, 76 | data_dir=self.data_dir, 77 | split='train', 78 | task_type=self.task_type, 79 | visual_input_type=self.visual_input_type) 80 | 81 | self.vcr_val_dataloader = build_vcr_dataloader(args=args, 82 | data_dir=self.data_dir, 83 | split='val', 84 | task_type=self.task_type, 85 | visual_input_type=self.visual_input_type) 86 | 87 | # Training hyperparameters 88 | self.num_epochs = self.vcr_config['num_epochs'] 89 | self.lr = self.vcr_config['lr'] 90 | self.adam_epsilon = self.vcr_config['adam_epsilon'] 91 | self.weight_decay = self.vcr_config['weight_decay'] 92 | self.loss_criterion = nn.CrossEntropyLoss() 93 | 94 | self.vcr_train_dataloader.dataset.convert_to_low_shot(low_shot_percentage=0.05) 95 | self.vcr_val_dataloader.dataset.convert_to_low_shot(low_shot_percentage=0.05) 96 | self.max_steps = len(self.vcr_train_dataloader) * self.num_epochs 97 | self.warmup_ratio = 0.1 # TODO remove hard code 98 | self.hparams = { 99 | 'lr': self.lr, 100 | 'weight_decay': self.weight_decay, 101 | 'adam_epsilon': self.adam_epsilon, 102 | } 103 | 104 | def get_train_dataloader(self): 105 | return self.vcr_train_dataloader 106 | 107 | def get_collate_fn(self): 108 | return self.vcr_train_dataloader.collate_fn -------------------------------------------------------------------------------- /src/train/visionlanguage_tasks/train_snli_ve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import time 9 | import math 10 | import shutil 11 | import pickle as pkl 12 | import copy 13 | import pdb 14 | from tqdm import tqdm 15 | from typing import List, Dict, Tuple 16 | 17 | sys.path.insert(0, '.') 18 | 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | from torch.optim import AdamW 23 | from transformers import get_polynomial_decay_schedule_with_warmup 24 | 25 | from src.data.image_datasets.flickr30kimages_dataset import Flickr30KImagesDataset 26 | from src.data.visionlanguage_datasets.snli_ve_dataset import build_snli_ve_dataloader 27 | from src.train.visionlanguage_tasks.task_trainer import TaskTrainer 28 | from src.utils.wandb import wandb_logger 29 | 30 | logger = logging.getLogger(__name__) 31 | logging.basicConfig( 32 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 33 | datefmt='%m/%d/%Y %H:%M:%S', 34 | level=logging.INFO) 35 | 36 | class SNLIVETrainer(TaskTrainer): 37 | 38 | def __init__(self, 39 | logger, 40 | args: argparse.Namespace, 41 | task_configs: Dict, 42 | model_config: Dict, 43 | 
device: torch.device, 44 | task_key, 45 | task_output_dir, 46 | accelerator): 47 | 48 | ''' 49 | Initializes a Trainer that handles training of a model on the VCR task 50 | 51 | args: Arguments provided by user 52 | task_configs: dictionary containing task-specific configuration parameters for all tasks 53 | model_config: dictionary containing model-specific configuration parameters 54 | device: cuda/cpu 55 | ''' 56 | 57 | super().__init__() 58 | 59 | self.args = args 60 | self.local_epochs = args.local_epochs 61 | self.device = device 62 | self.accelerator = accelerator 63 | self.task_output_dir = task_output_dir 64 | self.task_key = task_key 65 | 66 | self.snli_ve_config = task_configs['snli-ve'] 67 | self.data_dir = os.path.join(args.climb_data_dir, self.snli_ve_config['data_dir']) 68 | 69 | # Model-specific stuff 70 | self.visual_input_type = model_config['visual_input_type'] 71 | self.batch2inputs_converter = model_config['batch2inputs_converter'] 72 | 73 | # Load Flickr30K Images dataset for image data backbone 74 | images_source = self.snli_ve_config['images_source'] 75 | flickr30k_config = task_configs[images_source] 76 | images_dataset = Flickr30KImagesDataset(os.path.join(args.climb_data_dir, flickr30k_config['data_dir']), 77 | visual_input_type=self.visual_input_type) 78 | 79 | # Create dataloaders for training and validation 80 | self.snli_ve_train_dataloader = build_snli_ve_dataloader(args=args, 81 | data_dir=self.data_dir, 82 | images_dataset=images_dataset, 83 | split='train', 84 | visual_input_type=self.visual_input_type) 85 | 86 | self.snli_ve_dev_dataloader = build_snli_ve_dataloader(args=args, 87 | data_dir=self.data_dir, 88 | images_dataset=images_dataset, 89 | split='dev', 90 | visual_input_type=self.visual_input_type) 91 | 92 | # Training hyperparameters 93 | self.num_epochs = self.snli_ve_config['num_epochs'] 94 | self.lr = self.snli_ve_config['lr'] 95 | self.adam_epsilon = self.snli_ve_config['adam_epsilon'] 96 | self.weight_decay = self.snli_ve_config['weight_decay'] 97 | self.loss_criterion = nn.CrossEntropyLoss() 98 | 99 | self.snli_ve_train_dataloader.dataset.convert_to_low_shot(num_shots_per_class=2048) 100 | self.snli_ve_dev_dataloader.dataset.convert_to_low_shot(num_shots_per_class=256) 101 | 102 | self.max_steps = len(self.snli_ve_train_dataloader) * self.num_epochs 103 | self.warmup_ratio = 0.1 # TODO remove hard code 104 | self.hparams = { 105 | 'lr': self.lr, 106 | 'weight_decay': self.weight_decay, 107 | 'adam_epsilon': self.adam_epsilon, 108 | } 109 | 110 | def get_train_dataloader(self): 111 | return self.snli_ve_train_dataloader 112 | 113 | def get_collate_fn(self): 114 | return self.snli_ve_train_dataloader.collate_fn -------------------------------------------------------------------------------- /src/modeling/vilt_clf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | from accelerate.logging import get_logger 5 | import itertools 6 | import pdb 7 | import time 8 | from PIL import Image 9 | from typing import List, Dict 10 | from typing_extensions import OrderedDict 11 | 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from transformers import BertConfig, BertTokenizer, BertModel 19 | from transformers import ViltConfig, ViltProcessor, ViltModel 20 | from transformers import BertTokenizerFast 21 | from transformers import logging as transformers_logging 22 | 23 | from 
src.modeling.continual_learner import EncoderWrapper, ContinualLearner 24 | 25 | 26 | class ViltForImageClassification(nn.Module): 27 | 28 | def __init__(self, encoder_dim: int, num_labels: int): 29 | ''' 30 | Modified ViLT model for image classification tasks 31 | args: 32 | encoder - instance of ViltEncoderWrapper class 33 | encoder_dim - output dimension of vilt encoder 34 | num_labels - number of output labels for image classification task 35 | ''' 36 | 37 | super().__init__() 38 | self.encoder_dim = encoder_dim 39 | self.clf_layer = nn.Sequential( 40 | nn.Linear(encoder_dim, encoder_dim*2), 41 | nn.LayerNorm(encoder_dim*2), 42 | nn.GELU(), 43 | nn.Linear(encoder_dim*2, num_labels) 44 | ) 45 | 46 | def forward(self, encoder, images: List, texts: List[str]) -> torch.FloatTensor: 47 | ''' 48 | Does forward pass of image and text inputs through model, where texts are dummy texts 49 | 50 | Args: 51 | images - batch_size-sized list of num_images-sized list of PIL Image objects 52 | texts - list of dummy text strings 53 | ''' 54 | encodings = encoder.process_inputs(images, texts) 55 | encoder_output = encoder(**encodings) 56 | 57 | output_logits = self.clf_layer(encoder_output) 58 | return output_logits 59 | 60 | 61 | class ViltForSequenceClassification(nn.Module): 62 | 63 | def __init__(self, encoder_dim: int, num_labels: int): 64 | ''' 65 | Modified ViLT model for text classification tasks 66 | 67 | Args: 68 | encoder_dim - output dimension of vilt encoder 69 | num_labels - number of output labels for text classification task 70 | ''' 71 | 72 | super().__init__() 73 | self.encoder_dim = encoder_dim 74 | self.clf_layer = nn.Sequential( 75 | nn.Linear(encoder_dim, encoder_dim*2), 76 | nn.LayerNorm(encoder_dim*2), 77 | nn.GELU(), 78 | nn.Linear(encoder_dim*2, num_labels) 79 | ) 80 | 81 | def forward(self, encoder, images: List, texts: List[str]) -> torch.FloatTensor: 82 | ''' 83 | Does forward pass of image and text inputs through model, where image is averaged image 84 | 85 | Args: 86 | images - batch_size-sized list of "average image"'sPIL Image objects 87 | texts - list of text strings 88 | ''' 89 | 90 | encodings = encoder.process_inputs(images, texts) 91 | # expand to batch size 92 | bs = len(encodings['input_ids']) 93 | encodings['pixel_values'] = encodings['pixel_values'].expand([bs, *encodings['pixel_values'].shape[1:]]) 94 | encodings['pixel_mask'] = encodings['pixel_mask'].expand([bs, *encodings['pixel_mask'].shape[1:]]) 95 | encoder_output = encoder(**encodings) 96 | output_logits = self.clf_layer(encoder_output) 97 | return output_logits 98 | 99 | 100 | class ViltForMultipleChoice(nn.Module): 101 | 102 | def __init__(self, encoder_dim: int, num_labels: int): 103 | ''' 104 | Modified ViLT model for text multiple-choice tasks 105 | Args: 106 | encoder_dim - output dimension of vilt encoder 107 | num_labels - number of choices for multi-choice task 108 | ''' 109 | super().__init__() 110 | self.encoder_dim = encoder_dim 111 | self.num_labels = num_labels 112 | self.clf_layer = nn.Sequential( 113 | nn.Dropout(0.1), 114 | nn.Linear(encoder_dim, 1) 115 | ) 116 | 117 | def forward(self, encoder, images, texts): 118 | encodings = encoder.process_inputs(images, texts) 119 | # unflat_input_ids = encodings['input_ids'].view(self.num_labels, 32, -1).transpose(0, 1) 120 | bs = len(encodings['input_ids']) 121 | encodings['pixel_values'] = encodings['pixel_values'].expand([bs, *encodings['pixel_values'].shape[1:]]) 122 | encodings['pixel_mask'] = encodings['pixel_mask'].expand([bs, 
*encodings['pixel_mask'].shape[1:]]) 123 | encoder_output = encoder(**encodings) 124 | reshape_output = encoder_output.view(self.num_labels, -1, self.encoder_dim).transpose(0, 1).contiguous() 125 | 126 | output_logits = self.clf_layer(reshape_output).squeeze() 127 | return output_logits -------------------------------------------------------------------------------- /src/utils/word_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | "shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | "someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | "theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | 
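# A worked example of the normalize_word function defined further below, which uses this
# contractions table: punctuation is stripped or spaced out, number words are mapped via
# manual_map, articles are dropped, contractions are restored and commas removed.
# For instance, normalize_word("the Two dogs") lowercases the text, drops "the",
# maps "two" -> "2" and returns "2 dogs".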
"yall'd've": "y'all'd've", 116 | "y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | "you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | -------------------------------------------------------------------------------- /src/data/image_datasets/cocoimages_dataset_crossvqas.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torchvision import transforms as T 17 | from torch.utils.data import Dataset 18 | 19 | from PIL import Image 20 | 21 | 22 | class MSCOCOImagesDataset(Dataset): 23 | 24 | def __init__(self, coco_dir: str, data_dir: str, visual_input_type: str, task_key: str, image_size=(384, 640), transform=None): 25 | 26 | ''' 27 | Initializes an MSCOCOImagesDataset instance that handles image-side processing for VQA and other tasks that use MS-COCO images 28 | coco_dir: directory that contains MS-COCO data (images within 'images' folder) 29 | visual_input_type: format of visual input to model 30 | image_size: tuple indicating size of image input to model 31 | ''' 32 | 33 | self.images_dir = [os.path.join(coco_dir, dir) for dir in data_dir] 34 | self.image_size = image_size 35 | 36 | self.visual_input_type = visual_input_type 37 | assert visual_input_type in ['pil-image', 'raw', 'fast-rcnn'] 38 | 39 | if task_key in ['art', 'med']: 40 | image_filenames = os.listdir(self.images_dir[0]) 41 | elif task_key in ['pvqa']: 42 | image_filenames = os.listdir(self.images_dir[0]) + os.listdir(self.images_dir[1]) + os.listdir(self.images_dir[2]) 43 | elif task_key in ['toronto', 'abstract']: 44 | image_filenames = os.listdir(self.images_dir[0]) + os.listdir(self.images_dir[1]) 45 | self.imageid2filename = {} 46 | for fn in image_filenames: 47 | if task_key in ['abstract']: 48 | 
image_id = int(fn.strip('.png').split('_')[-1]) 49 | elif task_key in ['toronto']: 50 | image_id = int(fn.strip('.jpg').split('_')[-1]) 51 | elif task_key in ['pvqa']: 52 | image_id = fn.strip('.jpg') 53 | elif task_key in ['med']: 54 | image_id = fn.strip('.jpg').split('/')[-1] 55 | elif task_key in ['art']: 56 | image_id = int(fn.strip('.jpg').split('-')[0]) 57 | if task_key in ['art', 'med']: 58 | self.imageid2filename[image_id] = os.path.join(self.images_dir[0], fn) 59 | else: 60 | if 'train' in fn: 61 | self.imageid2filename[image_id] = os.path.join(self.images_dir[0], fn) 62 | elif 'val' in fn: 63 | self.imageid2filename[image_id] = os.path.join(self.images_dir[1], fn) 64 | elif 'test' in fn: 65 | self.imageid2filename[image_id] = os.path.join(self.images_dir[2], fn) 66 | 67 | self.imageids = list(set(list(self.imageid2filename.keys()))) 68 | 69 | # image_filenames = os.listdir(self.images_dir) 70 | # self.imageid2filename = {} 71 | # for fn in image_filenames: 72 | # fn = fn.split('_')[-1] 73 | # image_id = int(fn.strip('.jpg')) 74 | # self.imageid2filename[image_id] = os.path.join(self.images_dir, fn) 75 | # self.imageids = list(set(list(self.imageid2filename.keys()))) 76 | 77 | self.raw_transform = T.Compose([ 78 | T.Resize(image_size), 79 | T.ToTensor(), # [0, 1] 80 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [-1, 1] 81 | ]) 82 | 83 | self.pil_transform = T.Resize(size=384, max_size=640) 84 | 85 | 86 | def get_image_data(self, image_id: str): 87 | 88 | ''' 89 | Returns image data according to required visual_input_type. Output format varies by visual_input_type 90 | ''' 91 | 92 | if self.visual_input_type == 'pil-image': 93 | return self.get_pil_image(image_id) 94 | if self.visual_input_type == 'raw': 95 | return self.get_raw_image_tensor(image_id) 96 | elif self.visual_input_type == 'fast-rcnn': 97 | raise NotImplementedError("Have not implemented Fast-RCNN feature inputs for MS-COCO images!") 98 | 99 | def get_pil_image(self, image_id: str) -> Image: 100 | ''' 101 | Loads image corresponding to image_id, re-sizes and returns PIL.Image object 102 | ''' 103 | 104 | assert image_id in self.imageid2filename.keys() 105 | image_fn = self.imageid2filename[image_id] 106 | image = Image.open(image_fn) 107 | image = image.convert('RGB') 108 | if min(list(image.size)) > 384 or hasattr(self, 'use_albef'): 109 | image = self.pil_transform(image) 110 | return image 111 | 112 | def get_raw_image_tensor(self, image_id: str) -> torch.Tensor: 113 | ''' 114 | Loads image corresponding to image_id, re-sizes, and returns tensor of size (3, W, H) 115 | ''' 116 | 117 | assert image_id in self.imageid2filename.keys() 118 | image_fn = self.imageid2filename[image_id] 119 | image = Image.open(image_fn) 120 | image = image.convert('RGB') 121 | 122 | image_tensor = self.raw_transform(image) 123 | 124 | image.close() 125 | return image_tensor # (B, 3, W, H) -------------------------------------------------------------------------------- /src/train/visionlanguage_tasks/train_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import time 9 | import math 10 | import shutil 11 | import pickle as pkl 12 | import copy 13 | import pdb 14 | from tqdm import tqdm 15 | from typing import List, Dict, Tuple 16 | 17 | sys.path.insert(0, '.') 18 | 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | from torch.optim import AdamW 23 | 
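# A minimal sketch of how the lr / adam_epsilon / weight_decay hyperparameters and the
# max_steps / warmup_ratio values collected by these trainers are typically wired into an
# optimizer and warmup schedule. The function name and the `model` argument are assumed
# here for illustration; they are not part of this module.
def build_optimizer_and_scheduler_sketch(model, hparams, max_steps, warmup_ratio):
    # AdamW over all trainable parameters, using the task's hyperparameters
    optimizer = AdamW(model.parameters(),
                      lr=hparams['lr'],
                      eps=hparams['adam_epsilon'],
                      weight_decay=hparams['weight_decay'])
    # Polynomial decay with a linear warmup over the first warmup_ratio fraction of steps
    scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(warmup_ratio * max_steps),
        num_training_steps=max_steps)
    return optimizer, scheduler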
from transformers import get_polynomial_decay_schedule_with_warmup 24 | 25 | from src.data.image_datasets.cocoimages_dataset import MSCOCOImagesDataset 26 | from src.data.visionlanguage_datasets.vqa_dataset import build_vqa_dataloader 27 | from src.train.visionlanguage_tasks.task_trainer import TaskTrainer 28 | from src.utils.wandb import wandb_logger 29 | 30 | logger = logging.getLogger(__name__) 31 | logging.basicConfig( 32 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 33 | datefmt='%m/%d/%Y %H:%M:%S', 34 | level=logging.INFO) 35 | 36 | class VQATrainer(TaskTrainer): 37 | 38 | def __init__(self, 39 | logger, 40 | args: argparse.Namespace, 41 | task_configs: Dict, 42 | model_config: Dict, 43 | device: torch.device, 44 | task_key, 45 | task_output_dir, 46 | accelerator): 47 | 48 | ''' 49 | Initializes a Trainer that handles training of a model on the VQA task 50 | 51 | args: Arguments provided by user 52 | task_configs: dictionary containing task-specific configuration parameters for all tasks 53 | model_config: dictionary containing model-specific configuration parameters 54 | device: cuda/cpu 55 | ''' 56 | 57 | super().__init__() 58 | 59 | self.args = args 60 | self.local_epochs = args.local_epochs 61 | self.device = device 62 | self.accelerator = accelerator 63 | self.task_output_dir = task_output_dir 64 | self.task_key = task_key 65 | 66 | self.vqa_config = task_configs['vqa'] 67 | self.data_dir = os.path.join(args.climb_data_dir, self.vqa_config['data_dir']) 68 | 69 | # Model-specific stuff 70 | self.visual_input_type = model_config['visual_input_type'] 71 | self.batch2inputs_converter = model_config['batch2inputs_converter'] 72 | 73 | # Load COCO Images dataset for image data backbone 74 | images_source = self.vqa_config['images_source'] 75 | mscoco_config = task_configs[images_source] 76 | self.images_dataset = MSCOCOImagesDataset(coco_dir=os.path.join(args.climb_data_dir, mscoco_config['data_dir']), 77 | visual_input_type=args.visual_input_type) 78 | 79 | # Create dataloaders for training and validation 80 | self.vqa_train_dataloader = build_vqa_dataloader(args=args, 81 | data_dir=self.data_dir, 82 | images_dataset=self.images_dataset, 83 | split='train', 84 | visual_input_type=self.visual_input_type) 85 | 86 | self.vqa_val_dataloader = build_vqa_dataloader(args=args, 87 | data_dir=self.data_dir, 88 | images_dataset=self.images_dataset, 89 | split='val', 90 | visual_input_type=self.visual_input_type) 91 | 92 | # Training hyperparameters 93 | self.num_epochs = self.vqa_config['num_epochs'] 94 | self.lr = self.vqa_config['lr'] 95 | self.adam_epsilon = self.vqa_config['adam_epsilon'] 96 | self.weight_decay = self.vqa_config['weight_decay'] 97 | self.hparams = { 98 | 'lr': self.lr, 99 | 'weight_decay': self.weight_decay, 100 | 'adam_epsilon': self.adam_epsilon, 101 | } 102 | 103 | self.loss_criterion = nn.BCEWithLogitsLoss(reduction='mean') 104 | 105 | self.vqa_train_dataloader.dataset.convert_to_low_shot(low_shot_percentage=0.05) 106 | self.vqa_val_dataloader.dataset.convert_to_low_shot(low_shot_percentage=0.05) 107 | self.max_steps = len(self.vqa_train_dataloader) * self.num_epochs 108 | self.warmup_ratio = 0.1 # TODO remove hard code 109 | 110 | def compute_score_with_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 111 | ''' 112 | Given logits for each answer in VQA classification, selects answer with max logit and returns VQA-score for that answer 113 | logits: logits for each answer - size=(batch_size, num_answers) 114 | labels: label
for each answer in {0, 0.3, 0.6, 1} (batch_size, num_answers) 115 | 116 | Returns: 117 | scores: score of predicted answer (batch_size, num_answers) 118 | ''' 119 | 120 | logits = torch.max(logits, 1)[1].data # argmax 121 | one_hots = torch.zeros(*labels.size()).to(self.device) 122 | one_hots.scatter_(1, logits.view(-1, 1), 1) 123 | scores = (one_hots * labels) 124 | return scores 125 | 126 | def get_train_dataloader(self): 127 | return self.vqa_train_dataloader 128 | 129 | def get_collate_fn(self): 130 | return self.vqa_train_dataloader.collate_fn 131 | -------------------------------------------------------------------------------- /src/modeling/models/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def init_bert_weights(module): 6 | """Initialize the weights.""" 7 | if isinstance(module, (nn.Linear, nn.Embedding)): 8 | # std defaults to 0.02, this might need to be changed 9 | module.weight.data.normal_(mean=0.0, std=0.02) 10 | elif isinstance(module, nn.LayerNorm): 11 | module.bias.data.zero_() 12 | module.weight.data.fill_(1.0) 13 | if isinstance(module, nn.Linear) and module.bias is not None: 14 | module.bias.data.zero_() 15 | 16 | class Adapter(nn.Module): 17 | """ 18 | The adapters first project the original 19 | d-dimensional features into a smaller dimension, m, apply 20 | a nonlinearity, then project back to d dimensions. 21 | """ 22 | def __init__(self, names, device, model_dim=768, adapter_reduction_factor=16): 23 | super().__init__() 24 | self.actv = nn.ReLU() 25 | self.scaling = 1.0 26 | self.gating = False 27 | print(names) 28 | 29 | if isinstance(names, str): 30 | names = [names] 31 | self.adapter_dict = {} 32 | for name in names: 33 | if 'adapter' in name: 34 | n = f'{name}_down' 35 | setattr(self, n, nn.Linear(model_dim, model_dim//adapter_reduction_factor).to(device)) 36 | m = getattr(self, n) 37 | m.apply(init_bert_weights) 38 | for p in m.parameters(): 39 | p.requires_grad = True 40 | n = f'{name}_up' 41 | setattr(self, n, nn.Linear(model_dim//adapter_reduction_factor, model_dim).to(device)) 42 | m = getattr(self, n) 43 | m.apply(init_bert_weights) 44 | for p in m.parameters(): 45 | p.requires_grad = True 46 | 47 | elif name in ['gating']: 48 | # for each client we init a spec gating 49 | setattr(self, f'{name}_module', nn.Linear(model_dim, 2).to(device)) 50 | m = getattr(self, f'{name}_module') 51 | m.apply(init_bert_weights) 52 | for p in m.parameters(): 53 | p.requires_grad = True 54 | 55 | if hasattr(self, 'adapter_2_down'): 56 | for m in [self.adapter_2_down, self.adapter_2_up]: 57 | for p in m.parameters(): 58 | p.requires_grad = False 59 | 60 | def deactivate_gating(self): 61 | self.gating = False 62 | 63 | def activate_gating(self): 64 | self.gating = True 65 | 66 | def set_active_adapter(self, name): 67 | if isinstance(name, str): 68 | self.active_adapter_down = getattr(self, f'{name}_down') 69 | self.active_adapter_up = getattr(self, f'{name}_up') 70 | 71 | if name == 'adapter_0': 72 | for m in [self.adapter_0_down, self.adapter_0_up]: 73 | for p in m.parameters(): 74 | p.requires_grad = True 75 | for m in [self.adapter_1_down, self.adapter_1_up]: 76 | for p in m.parameters(): 77 | p.requires_grad = False 78 | 79 | elif name == 'adapter_1': 80 | for m in [self.adapter_1_down, self.adapter_1_up]: 81 | for p in m.parameters(): 82 | p.requires_grad = True 83 | for m in [self.adapter_0_down, self.adapter_0_up]: 84 | for p in 
m.parameters(): 85 | p.requires_grad = False 86 | 87 | elif isinstance(name, list): 88 | for n in name: 89 | m = getattr(self, f'{n}_down') 90 | for p in m.parameters(): 91 | p.requires_grad = True 92 | m = getattr(self, f'{n}_up') 93 | for p in m.parameters(): 94 | p.requires_grad = True 95 | return 96 | 97 | def adapter_layer_forward_bert(self, hidden_states, input_tensor, layer_norm): 98 | hidden_states, residual = self.pre_forward(hidden_states, input_tensor, layer_norm) 99 | hidden_states = self.forward(hidden_states, residual) 100 | hidden_states = self.post_forward(hidden_states, input_tensor, layer_norm) 101 | return hidden_states 102 | 103 | def pre_forward(self, hidden_states, input_tensor, layer_norm): 104 | residual = hidden_states # residual_before_ln = True 105 | if layer_norm: 106 | hidden_states = layer_norm(hidden_states + input_tensor) 107 | else: 108 | hidden_states = hidden_states + input_tensor 109 | return hidden_states, residual 110 | 111 | def post_forward(self, hidden_states, input_tensor, layer_norm): 112 | if layer_norm: 113 | hidden_states = layer_norm(hidden_states + input_tensor) 114 | else: 115 | hidden_states = hidden_states + input_tensor 116 | return hidden_states 117 | 118 | def get_agg_out(self, outs, weights): 119 | agg_out = weights[:, :, 0].unsqueeze(-1) * outs[0] 120 | for i, out in enumerate(outs[1:]): 121 | agg_out += weights[:, :, i+1].unsqueeze(-1) * out 122 | return agg_out 123 | 124 | def forward(self, hidden_states, input_tensor): 125 | if not self.gating: 126 | # one adapter for all 127 | down = self.active_adapter_down(hidden_states) 128 | down = self.actv(down) 129 | up = self.active_adapter_up(down) 130 | 131 | hidden_states = input_tensor + up 132 | 133 | elif hasattr(self, 'adapter_2_down'): 134 | up_outs = [] 135 | for i in [0, 2]: 136 | adapter_down = getattr(self, f'adapter_{i}_down') 137 | down_out = adapter_down(hidden_states) 138 | down_out = self.actv(down_out) 139 | adapter_up = getattr(self, f'adapter_{i}_up') 140 | up_out = adapter_up(down_out) 141 | up_outs.append(up_out) 142 | 143 | # weight_up = F.softmax(self.gating_module(hidden_states) + 10**-6, dim=-1) 144 | weight_up = torch.ones(list(up_out.shape)[:-1] + [2]).to('cuda') * 0.5 145 | agg_up_out = self.get_agg_out(up_outs, weight_up) 146 | hidden_states = input_tensor + agg_up_out * self.scaling 147 | 148 | else: 149 | # one gating for all 150 | up_outs = [] 151 | for i in range(2): 152 | adapter_down = getattr(self, f'adapter_{i}_down') 153 | down_out = adapter_down(hidden_states) 154 | down_out = self.actv(down_out) 155 | adapter_up = getattr(self, f'adapter_{i}_up') 156 | up_out = adapter_up(down_out) 157 | up_outs.append(up_out) 158 | 159 | # weight_up = F.softmax(self.gating_module(hidden_states) + 10**-6, dim=-1) 160 | weight_up = torch.ones(list(up_out.shape)[:-1] + [2]).to('cuda') * 0.5 161 | agg_up_out = self.get_agg_out(up_outs, weight_up) 162 | hidden_states = input_tensor + agg_up_out * self.scaling 163 | return hidden_states -------------------------------------------------------------------------------- /src/data/visionlanguage_datasets/nlvr2_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import jsonlines 6 | import logging 7 | import glob 8 | from tqdm import tqdm 9 | import pickle 10 | import pdb 11 | from PIL import Image 12 | from typing import List, Dict 13 | 14 | import numpy as np 15 | import torch 16 | from torch.utils.data import Dataset 17 
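# A minimal sketch of how the task trainers above typically build the NLVR2 train/val
# dataloaders defined in this module and trim them to a low-shot setting; the function
# name and the shot counts are assumed here for illustration.
def build_low_shot_nlvr2_dataloaders_sketch(args, data_dir, visual_input_type):
    train_loader = build_nlvr2_dataloader(args=args, data_dir=data_dir,
                                          split='train', visual_input_type=visual_input_type)
    val_loader = build_nlvr2_dataloader(args=args, data_dir=data_dir,
                                        split='val', visual_input_type=visual_input_type)
    # convert_to_low_shot keeps a fixed number of examples per label (NLVR2 has 2 labels),
    # sampled with a fixed random seed for reproducibility
    train_loader.dataset.convert_to_low_shot(num_shots_per_class=2048)
    val_loader.dataset.convert_to_low_shot(num_shots_per_class=256)
    return train_loader, val_loader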
| import random 18 | 19 | from PIL import Image 20 | from torchvision import transforms as T 21 | 22 | logger = logging.getLogger(__name__) 23 | logging.basicConfig( 24 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 25 | datefmt='%m/%d/%Y %H:%M:%S', 26 | level=logging.INFO) 27 | 28 | 29 | 30 | class NLVR2Dataset(Dataset): 31 | 32 | def __init__(self, 33 | data_dir: str, 34 | split: str, 35 | **kwargs): 36 | 37 | """ 38 | Initiates the NLVR2Dataset - loads all the sentences and corresponding image IDs and output label 39 | Every item in self.data corresponds to a single NLVR2 input 40 | 41 | Args: 42 | data_dir : path containing NLVR2 annotations and images 43 | split: either train/val/test split 44 | 45 | Returns: 46 | Loads all annotations into self.data, where each item is a single NLVR2 input 47 | """ 48 | 49 | self.data_dir = data_dir 50 | self.num_labels = 2 51 | self.split = split 52 | 53 | rename_split = {'train': 'train', 'val': 'dev', 'test': 'test1'} 54 | _split = rename_split[split] 55 | self.image_dir = os.path.join(data_dir, 'images', _split) 56 | 57 | # Load if cached data exist 58 | self.cached_data_file = os.path.join(data_dir, 'cached_nlvr2_data', f'{_split}.pkl') 59 | if os.path.exists(self.cached_data_file): 60 | with open(self.cached_data_file, 'rb') as f: 61 | self.data = pickle.load(open(self.cached_data_file, 'rb')) 62 | else: 63 | annotations_file = os.path.join(data_dir, 'data', f'{_split}.json') 64 | 65 | self.data = [] 66 | # https://github.com/facebookresearch/vilbert-multi-task/blob/main/vilbert/datasets/nlvr2_dataset.py 67 | with jsonlines.open(annotations_file) as reader: 68 | for annotation in reader: 69 | # logger.info(annotation) 70 | example = {} 71 | example["id"] = annotation["identifier"] 72 | example["image_id_0"] = os.path.join(self.image_dir, ( 73 | "-".join(annotation["identifier"].split("-")[:-1]) + "-img0.png" 74 | )) 75 | example["image_id_1"] = os.path.join(self.image_dir, ( 76 | "-".join(annotation["identifier"].split("-")[:-1]) + "-img1.png" 77 | )) 78 | example["sentence"] = str(annotation["sentence"]) 79 | example["labels"] = 0 if str(annotation["label"]) == "False" else 1 80 | self.data.append(example) 81 | 82 | with open(self.cached_data_file, 'wb') as f: 83 | pickle.dump(self.data, f) 84 | 85 | self.n_examples = len(self.data) 86 | logger.info("Loaded NLVRv2 {} dataset, with {} examples".format(split, self.n_examples)) 87 | self.pil_transform = T.Resize(size=384, max_size=640) 88 | 89 | def get_pil_image(self, image_fn): 90 | image = Image.open(image_fn) 91 | image = image.convert('RGB') 92 | if min(list(image.size)) > 384: 93 | image = self.pil_transform(image) 94 | return image 95 | 96 | def __len__(self): 97 | return self.n_examples 98 | 99 | def __getitem__(self, index: int): 100 | 101 | """ 102 | Args: 103 | index : index of element in self.data to return as data instance 104 | 105 | Returns: 106 | dictionary containing inputs and targets for model to do NLVR 107 | """ 108 | 109 | example = self.data[index] 110 | img1 = self.get_pil_image(example["image_id_0"]) 111 | img2 = self.get_pil_image(example["image_id_1"]) 112 | image = [img1, img2] 113 | 114 | return {'text': example["sentence"], 115 | 'image': image, 116 | 'label': example["labels"]} 117 | 118 | def convert_to_low_shot(self, num_shots_per_class: int): 119 | """ 120 | Args: 121 | num_shots_per_class: int, denoting number of examples for each output label in low-shot setting 122 | """ 123 | 124 | logger.info("Converting NLVR2 train split into low-shot 
dataset, with {} examples per class...".format(num_shots_per_class)) 125 | new_data = [] 126 | for i in range(self.num_labels): 127 | i_examples = [d for d in self.data if d['labels'] == i] 128 | low_shot_examples = random.Random(1).sample(i_examples, num_shots_per_class) 129 | new_data.extend(low_shot_examples) 130 | self.data = new_data 131 | self.n_examples = len(self.data) 132 | 133 | logger.info("Converted into low-shot dataset, with {} examples".format(self.n_examples)) 134 | 135 | def nlvr2_batch_collate(batch: List[Dict], 136 | visual_input_type: str): 137 | 138 | """ 139 | Collates each model input for all batch items into a single model input (e.g. converts a list of input_ids into a matrix of size (batch_size, max_len)) 140 | 141 | Args: 142 | batch - list of batch items, each item being a dictionary returned by Dataset's __getitem__ method 143 | visual_input_type: string which specifies the type of visual input 144 | 145 | Returns: 146 | Dictionary containing batched inputs and outputs 147 | """ 148 | 149 | assert visual_input_type == 'pil-image' 150 | texts = [x['text'] for x in batch] 151 | pil_objs = [x['image'] for x in batch] 152 | labels = [x['label'] for x in batch] 153 | 154 | return {'raw_texts': texts, 155 | 'images': pil_objs, 156 | 'labels': torch.LongTensor(labels)} 157 | 158 | def build_nlvr2_dataloader(args, 159 | data_dir: str, 160 | split: str, 161 | visual_input_type: str, 162 | **kwargs) -> torch.utils.data.DataLoader: 163 | 164 | """ 165 | Creates the NLVR2 Dataloader, which gives batches of NLVR2 inputs and outputs 166 | 167 | Args: 168 | data_dir : path containing NLVR questions and annotations. 169 | split: either train/val split 170 | visual_input_type: format of visual input to model 171 | 172 | Returns: 173 | DataLoader object 174 | """ 175 | 176 | logger.info("Creating NLVR2 {} dataloader with batch size of {}".format(split, int(args.batch_size/2))) 177 | 178 | if visual_input_type != "pil-image": 179 | raise NotImplementedError("Have not implemented other inputs for NLVR2 images!") 180 | 181 | dataset = NLVR2Dataset(data_dir, split, **kwargs) 182 | dataloader = torch.utils.data.DataLoader( 183 | dataset, 184 | num_workers = args.num_workers, 185 | batch_size = int(args.batch_size/2), 186 | shuffle = (split=='train'), 187 | collate_fn = lambda x: nlvr2_batch_collate(x, visual_input_type) 188 | ) 189 | return dataloader 190 | 191 | 192 | ''' 193 | if __name__ == '__main__': 194 | 195 | class Args: 196 | def __init__(self): 197 | self.batch_size = 4 198 | self.num_workers = 2 199 | self.visual_input_type = 'pil-image' 200 | 201 | args = Args() 202 | data_dir = '/data/datasets/MCL/nlvr2/' 203 | split = 'val' #'train' 204 | 205 | from transformers import BertTokenizer 206 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 207 | 208 | nlvr_dataloader = build_nlvr2_dataloader(args, data_dir,'val', args.visual_input_type, tokenizer=tokenizer) 209 | 210 | for batch in nlvr_dataloader: 211 | pdb.set_trace() 212 | ''' 213 | -------------------------------------------------------------------------------- /src/configs/task_configs_fed.py: -------------------------------------------------------------------------------- 1 | from src.train.visionlanguage_tasks.train_vqa import VQATrainer 2 | from src.train.visionlanguage_tasks.train_vqa_crossvqa import VQATrainerCross 3 | from src.train.visionlanguage_tasks.train_nlvr2 import NLVR2Trainer 4 | from src.train.visionlanguage_tasks.train_snli_ve import SNLIVETrainer 5 | from 
src.train.visionlanguage_tasks.train_vcr import VCRTrainer 6 | import copy 7 | 8 | SUPPORTED_VL_TASKS = [ 9 | "vqa", 10 | "abstract", 11 | "toronto", 12 | "vizwiz", 13 | "pvqa", 14 | "med", 15 | "art", 16 | ] + ['vqa', 'nlvr2', 'snli-ve', 'vcr'] 17 | 18 | SUPPORTED_ORDER = ["Order1", "Order2", "Order3", "Order4", "Order5", "Debug_order"] 19 | 20 | data_root = "./data" 21 | 22 | mscoco_config = { 23 | "data_dir": data_root + "/mscoco", 24 | } 25 | abstract_image_config = { 26 | "data_dir": [ 27 | data_root + "/vqa_abstract/train2015", 28 | data_root + "/vqa_abstract/val2015", 29 | ], 30 | } 31 | toronto_image_config = { 32 | "data_dir": [ 33 | data_root + "/mscoco/train2014", 34 | data_root + "/mscoco/val2014", 35 | ] 36 | } 37 | art_image_config = {"data_dir": [data_root + "/AQUA/SemArt/Images"]} 38 | 39 | clove_function_a_config = { 40 | "task_name": "clove_function_a", 41 | "data_dir": data_root + "/CLOVE/json/function", 42 | "images_source": "vgd", 43 | "splits": ["train", "val_small"], 44 | "num_labels": 100, 45 | "num_images": 1, 46 | "model_type": "classification", 47 | "num_epochs": 20, 48 | "lr": 1e-4, 49 | "weight_decay": 1e-2, 50 | "adam_epsilon": 1e-8, 51 | "warmup_ratio": 0.1, 52 | "task_trainer": VQATrainerCross, 53 | "random_baseline_score": 0.0, 54 | } 55 | 56 | clove_function_b_config = copy.deepcopy(clove_function_a_config) 57 | clove_function_b_config["task_name"] = "clove_function_b" 58 | clove_function_c_config = copy.deepcopy(clove_function_a_config) 59 | clove_function_c_config["task_name"] = "clove_function_c" 60 | clove_function_d_config = copy.deepcopy(clove_function_a_config) 61 | clove_function_d_config["task_name"] = "clove_function_d" 62 | clove_function_e_config = copy.deepcopy(clove_function_a_config) 63 | clove_function_e_config["task_name"] = "clove_function_e" 64 | 65 | clove_scene_a_config = { 66 | "task_name": "clove_scene_a", 67 | "data_dir": data_root + "/CLOVE/json/scene", 68 | "images_source": "vgd", 69 | "splits": ["train", "val_small"], 70 | "num_labels": 100, 71 | "num_images": 1, 72 | "model_type": "classification", 73 | "num_epochs": 20, 74 | "lr": 1e-4, 75 | "weight_decay": 1e-2, 76 | "adam_epsilon": 1e-8, 77 | "warmup_ratio": 0.1, 78 | "task_trainer": VQATrainerCross, 79 | "random_baseline_score": 0.0, 80 | } 81 | 82 | clove_scene_b_config = copy.deepcopy(clove_scene_a_config) 83 | clove_scene_b_config["task_name"] = "clove_scene_b" 84 | clove_scene_c_config = copy.deepcopy(clove_scene_a_config) 85 | clove_scene_c_config["task_name"] = "clove_scene_c" 86 | clove_scene_d_config = copy.deepcopy(clove_scene_a_config) 87 | clove_scene_d_config["task_name"] = "clove_scene_d" 88 | clove_scene_e_config = copy.deepcopy(clove_scene_a_config) 89 | clove_scene_e_config["task_name"] = "clove_scene_e" 90 | clove_scene_f_config = copy.deepcopy(clove_scene_a_config) 91 | clove_scene_f_config["task_name"] = "clove_scene_f" 92 | 93 | vizwiz_config = { 94 | "task_name": "vizwiz", 95 | "data_dir": data_root + "/vizwiz", 96 | "images_source": "vizwiz", 97 | "splits": ["train", "val_small"], 98 | "num_labels": 100, 99 | "num_images": 1, 100 | "model_type": "classification", 101 | "num_epochs": 20, 102 | "lr": 1e-4, 103 | "weight_decay": 1e-2, 104 | "adam_epsilon": 1e-8, 105 | "warmup_ratio": 0.1, 106 | "task_trainer": VQATrainerCross, 107 | "random_baseline_score": 0.0, 108 | } 109 | 110 | gqa_config = { 111 | "task_name": "gqa", 112 | "data_dir": data_root + "/GQA", 113 | "images_source": "vg", 114 | "splits": ["train", "val_small"], 115 | "num_labels": 100, 116 
| "num_images": 1, 117 | "model_type": "classification", 118 | "num_epochs": 20, 119 | "lr": 1e-4, 120 | "weight_decay": 1e-2, 121 | "adam_epsilon": 1e-8, 122 | "warmup_ratio": 0.1, 123 | "task_trainer": VQATrainerCross, 124 | "random_baseline_score": 0.0, 125 | } 126 | 127 | abstract_config = { 128 | "task_name": "abstract", 129 | "data_dir": data_root + "/vqa_abstract", 130 | "images_source": "abstract_image", 131 | "splits": ["train", "val_small"], 132 | "num_labels": 100, 133 | "num_images": 1, 134 | "model_type": "classification", 135 | "num_epochs": 20, # Yao: original 10 136 | "lr": 1e-4, 137 | "weight_decay": 1e-2, 138 | "adam_epsilon": 1e-8, 139 | "warmup_ratio": 0.1, 140 | "task_trainer": VQATrainerCross, 141 | "random_baseline_score": 0.0, 142 | } 143 | 144 | 145 | toronto_config = { 146 | "task_name": "toronto", 147 | "data_dir": data_root + "/torontoCOCO", 148 | "images_source": "toronto_image", 149 | "splits": ["train", "val"], 150 | "num_labels": 100, 151 | "num_images": 1, 152 | "model_type": "classification", 153 | "num_epochs": 20, # Yao: original 10 154 | "lr": 1e-4, 155 | "weight_decay": 1e-2, 156 | "adam_epsilon": 1e-8, 157 | "warmup_ratio": 0.1, 158 | "task_trainer": VQATrainerCross, 159 | "random_baseline_score": 0.0, 160 | } 161 | 162 | art_config = { 163 | "task_name": "art", 164 | "data_dir": data_root + "/albef/art", 165 | "images_source": "art_image", 166 | "splits": ["train", "val"], 167 | "num_labels": 100, 168 | "num_images": 1, 169 | "model_type": "classification", 170 | "num_epochs": 20, # Yao: original 10 171 | "lr": 1e-4, 172 | "weight_decay": 1e-2, 173 | "adam_epsilon": 1e-8, 174 | "warmup_ratio": 0.1, 175 | "task_trainer": VQATrainerCross, 176 | "random_baseline_score": 0.0, 177 | 178 | } 179 | 180 | mscoco_config = { 181 | 'data_dir': 'ms-coco/', 182 | } 183 | 184 | flickr_config = { 185 | 'data_dir': 'flickr30k/', 186 | } 187 | 188 | vqa_config = { 189 | 'task_name': 'VQAv2', 190 | 'data_dir': 'vqav2/', 191 | 'images_source': 'ms-coco', 192 | 'splits': ['train', 'val'], 193 | 'num_labels': 3129, 194 | 'num_images': 1, 195 | 'model_type': 'classification', 196 | 'num_epochs': 10, 197 | 'lr': 1e-4, 198 | 'weight_decay': 1e-2, 199 | 'adam_epsilon': 1e-8, 200 | 'warmup_ratio': 0.1, 201 | 'task_trainer': VQATrainerCross, 202 | 'random_baseline_score': 0.0, 203 | } 204 | 205 | nlvr_config = { 206 | 'task_name': 'NLVRv2', 207 | 'data_dir': 'nlvr2/', 208 | 'splits': ['train', 'val'], 209 | 'num_labels': 2, 210 | 'num_images': 2, 211 | 'model_type': 'classification', 212 | 'num_epochs': 10, 213 | 'lr': 1e-4, 214 | 'weight_decay': 1e-2, 215 | 'adam_epsilon': 1e-8, 216 | 'warmup_ratio': 0.1, 217 | 'task_trainer': NLVR2Trainer, 218 | 'random_baseline_score': 50.0, 219 | } 220 | 221 | snli_ve_config = { 222 | 'task_name': 'SNLI-VE', 223 | 'data_dir': 'snli-ve/', 224 | 'images_source': 'flickr30k', 225 | 'splits': ['train', 'dev', 'test'], 226 | 'num_labels': 3, 227 | 'num_images': 1, 228 | 'model_type': 'classification', 229 | 'num_epochs': 5, 230 | 'lr': 5e-5, 231 | 'weight_decay': 1e-2, 232 | 'adam_epsilon': 1e-8, 233 | 'warmup_ratio': 0.1, 234 | 'task_trainer': SNLIVETrainer, 235 | 'random_baseline_score': 33.33, 236 | } 237 | 238 | vcr_config = { 239 | 'task_name': 'VCR', 240 | 'data_dir': 'vcr/', 241 | 'splits': ['train', 'dev', 'test'], 242 | 'num_labels': 4, 243 | 'num_images': 1, 244 | 'model_type': 'multi-choice', 245 | 'task_type': 'answer', 246 | 'num_choices': 4, 247 | 'num_epochs': 10, 248 | 'lr': 1e-4, 249 | 'weight_decay': 1e-2, 250 | 
'adam_epsilon': 1e-8, 251 | 'warmup_ratio': 0.1, 252 | 'task_trainer': VCRTrainer, 253 | 'random_baseline_score': 25.0, 254 | } 255 | 256 | task_configs = { 257 | "ms-coco": mscoco_config, 258 | "flickr30k": flickr_config, 259 | "abstract": abstract_config, 260 | "clove_scene_a": clove_scene_a_config, 261 | "clove_scene_b": clove_scene_b_config, 262 | "clove_scene_c": clove_scene_c_config, 263 | "clove_scene_d": clove_scene_d_config, 264 | "clove_scene_e": clove_scene_e_config, 265 | "clove_scene_f": clove_scene_f_config, 266 | "clove_function_a": clove_function_a_config, 267 | "clove_function_b": clove_function_b_config, 268 | "clove_function_c": clove_function_c_config, 269 | "clove_function_d": clove_function_d_config, 270 | "clove_function_e": clove_function_e_config, 271 | "toronto": toronto_config, 272 | "vizwiz": vizwiz_config, 273 | "gqa": gqa_config, 274 | "art": art_config, 275 | 'vqa': vqa_config, 276 | 'nlvr2': nlvr_config, 277 | 'snli-ve': snli_ve_config, 278 | 'vcr': vcr_config, 279 | "abstract_image": abstract_image_config, 280 | "toronto_image": toronto_image_config, 281 | "art_image": art_image_config, 282 | } 283 | -------------------------------------------------------------------------------- /src/data/visionlanguage_datasets/snli_ve_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import json 5 | import logging 6 | import random 7 | import glob 8 | import base64 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import pickle as pkl 12 | import pdb 13 | import jsonlines 14 | from typing import List, Dict 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torchvision import transforms as T 20 | from torch.utils.data import Dataset 21 | 22 | from PIL import Image 23 | from src.utils.image_utils import resize_image 24 | 25 | from src.data.image_datasets.flickr30kimages_dataset import Flickr30KImagesDataset 26 | from src.data.image_collation import image_collate 27 | 28 | logger = logging.getLogger(__name__) 29 | logging.basicConfig( 30 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 31 | datefmt='%m/%d/%Y %H:%M:%S', 32 | level=logging.INFO) 33 | 34 | class SnliVEDataset(Dataset): 35 | 36 | def __init__(self, 37 | data_dir: str, 38 | images_dataset: Flickr30KImagesDataset, 39 | split: str, 40 | **kwargs): 41 | 42 | """ 43 | Initiates the SnliVEDataset - loads all the questions (and converts to input IDs using the tokenizer, if provided) 44 | and answers (including converting each to a numeric label, and a score based on occurence from annotators) 45 | Every item in self.data corresponds to a single VE hypothesis and corresponding image 46 | 47 | Args: 48 | data_dir : path containing SNLI-VE hypotheses and annotations. 
49 | images_dataset : instance of Flickr30KImagesDataset, that is used to retrieve the Flickr30K image for each question 50 | split: either train/val split 51 | 52 | 53 | Returns: 54 | Loads all annotations into self.data, where each item is a single SNLI-VE pair 55 | """ 56 | 57 | self.data_dir = data_dir 58 | self.images_dataset = images_dataset 59 | self.image_dir = os.path.join(data_dir, 'flickr30k_images') 60 | self.split = split 61 | self.tokenizer = kwargs['tokenizer'] if 'tokenizer' in kwargs else None 62 | 63 | self.annotations_file = os.path.join(data_dir, 'snli_ve_{}.jsonl'.format(split)) 64 | self.categories = ['entailment', 'contradiction', 'neutral'] 65 | self.cat2label = {cat: i for i, cat in enumerate(self.categories)} 66 | self.num_labels = len(self.categories) 67 | 68 | self.cached_data_file = os.path.join(data_dir, 'cached_ve_data', 'snli-ve_{}.pkl'.format(split)) 69 | if os.path.exists(self.cached_data_file): 70 | self.data = pkl.load(open(self.cached_data_file, 'rb')) 71 | else: 72 | self.data = [] 73 | json_lines = jsonlines.open(self.annotations_file) 74 | for line in tqdm(json_lines): 75 | image_id = int(line['Flickr30K_ID']) 76 | hypothesis = str(line['sentence2']) 77 | gold_label = self.cat2label[line['gold_label']] 78 | 79 | if self.tokenizer is not None: 80 | tokens = self.tokenizer.tokenize(hypothesis) 81 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 82 | else: 83 | input_ids = [] 84 | 85 | doc = {'image_id': image_id, 86 | 'hypothesis': hypothesis, 87 | 'hypothesis_input_ids': input_ids, 88 | 'label': gold_label} 89 | self.data.append(doc) 90 | 91 | pkl.dump(self.data, open(self.cached_data_file, 'wb')) 92 | 93 | logger.info("Loaded SNLI-VE {} dataset, with {} examples".format(self.split, len(self.data))) 94 | 95 | def __len__(self): 96 | return len(self.data) 97 | 98 | def __getitem__(self, index: int): 99 | 100 | """ 101 | Args: 102 | index : index of element in self.data to return as data instance 103 | 104 | Returns: 105 | dictionary containing inputs and targets for model to do SNLI-VE 106 | """ 107 | 108 | example = self.data[index] 109 | 110 | # Tokenize the input hypothesis 111 | hypothesis = example['hypothesis'] 112 | input_ids = example['hypothesis_input_ids'] 113 | 114 | # Get the image tensor from ImageDataset 115 | image_id = example['image_id'] 116 | image = self.images_dataset.get_image_data(image_id) 117 | 118 | label = example['label'] 119 | 120 | return {'hypothesis': hypothesis, 121 | 'input_ids': input_ids, 122 | 'image': image, 123 | 'label': label 124 | } 125 | 126 | 127 | def convert_to_low_shot(self, num_shots_per_class: int): 128 | """ 129 | Args: 130 | num_shots_per_class: int, denoting number of examples for each output label in low-shot setting 131 | """ 132 | 133 | logger.info("Converting SNLI-VE train split into low-shot dataset, with {} examples per class...".format(num_shots_per_class)) 134 | new_data = [] 135 | for i in range(self.num_labels): 136 | i_examples = [d for d in self.data if d['label'] == i] 137 | low_shot_examples = random.Random(1).sample(i_examples, num_shots_per_class) 138 | new_data.extend(low_shot_examples) 139 | self.data = new_data 140 | self.n_examples = len(self.data) 141 | logger.info("Converted into low-shot dataset, with {} examples".format(self.n_examples)) 142 | 143 | 144 | def snlive_batch_collate(batch: List[Dict], 145 | visual_input_type: str): 146 | 147 | """ 148 | Collates each model input for all batch items into a single model input (e.g. 
converts a list of input_ids into a matrix of size (batch_size, max_len)) 149 | 150 | Args: 151 | batch - list of batch items, each item being a dictionary returned by Dataset's __getitem__ method 152 | visual_input_type: string which specifies the type of visual input 153 | 154 | Returns: 155 | Dictionary containing batched inputs and outputs 156 | """ 157 | 158 | #pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0] # should be 0, but doing this anyway 159 | pad_token = 0 # tokenizer.pad_token_id 160 | 161 | # Pad the text inputs 162 | hypotheses = [x['hypothesis'] for x in batch] 163 | input_ids = [x['input_ids'] for x in batch] 164 | max_len = max([len(x) for x in input_ids]) 165 | input_ids_padded = [] 166 | attn_masks = [] 167 | for i in range(len(input_ids)): 168 | ids_padded = input_ids[i] + [pad_token]*(max_len - len(input_ids[i])) 169 | attn_mask = [1]*len(input_ids[i]) + [0]*(max_len - len(input_ids[i])) 170 | 171 | input_ids_padded.append(ids_padded) 172 | attn_masks.append(attn_mask) 173 | input_ids = torch.tensor(input_ids_padded, dtype=torch.long) 174 | attn_mask = torch.tensor(attn_masks, dtype=torch.long) 175 | 176 | # Stack the target tensors 177 | # Create labels tensor 178 | labels = [x['label'] for x in batch] 179 | labels = torch.tensor(labels, dtype=torch.long) 180 | 181 | # Depending on the visual_input_type variable, process the images accordingly 182 | images = [x['image'] for x in batch] 183 | images = image_collate(images, visual_input_type) 184 | 185 | return {'raw_texts': hypotheses, 186 | 'input_ids': input_ids, 187 | 'attn_mask': attn_mask, 188 | 'images': images, 189 | 'labels': labels} 190 | 191 | def build_snli_ve_dataloader(args, 192 | data_dir: str, 193 | images_dataset: Flickr30KImagesDataset, 194 | split: str, 195 | visual_input_type: str, 196 | **kwargs) -> torch.utils.data.DataLoader: 197 | 198 | """ 199 | Creates the SNLI-VE Dataloader, which gives batches of SNLI-VE inputs and outputs 200 | 201 | Args: 202 | args 203 | data_dir : path containing SNLI-VE hypotheses and annotations. 
204 | images_dataset : instance of Flickr30KImagesDataset, that is used to retrieve the Flickr30K image for each question 205 | split: either train/val split 206 | visual_input_type: format of visual input to model 207 | 208 | Returns: 209 | DataLoader object 210 | """ 211 | 212 | 213 | batch_size = args.batch_size 214 | shuffle = True if split == 'train' else False 215 | 216 | logger.info("Creating SNLI-VE {} dataloader with batch size of {}".format(split, batch_size)) 217 | 218 | dataset = SnliVEDataset(data_dir, images_dataset, split, **kwargs) 219 | dataloader = torch.utils.data.DataLoader( 220 | dataset, 221 | num_workers=args.num_workers, 222 | batch_size=batch_size, 223 | shuffle=shuffle, 224 | collate_fn=lambda x: snlive_batch_collate(x, visual_input_type)) 225 | 226 | return dataloader 227 | 228 | if __name__ == '__main__': 229 | data_dir = '/data/datasets/MCL/snli-ve/' 230 | class Args: 231 | def __init__(self): 232 | self.batch_size = 4 233 | self.shuffle = True 234 | self.num_workers = 2 235 | self.visual_input_type = 'pil-image' 236 | args = Args() 237 | 238 | images_dataset = Flickr30KImagesDataset('/data/datasets/MCL/flickr30k/', args.visual_input_type) 239 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 240 | snli_ve_dataloader = build_snli_ve_dataloader(args, data_dir, images_dataset, 'train', args.visual_input_type, tokenizer=tokenizer) 241 | 242 | for batch in snli_ve_dataloader: 243 | pdb.set_trace() 244 | -------------------------------------------------------------------------------- /src/modeling/models/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.vision_transformer import _cfg, PatchEmbed 7 | from timm.models.registry import register_model 8 | from timm.models.layers import trunc_normal_, DropPath 9 | from .adapter import Adapter 10 | 11 | 12 | class Mlp(nn.Module): 13 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 14 | """ 15 | 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 40 | self.scale = qk_scale or head_dim ** -0.5 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | self.attn_gradients = None 46 | self.attention_map = None 47 | 48 | def save_attn_gradients(self, attn_gradients): 49 | self.attn_gradients = attn_gradients 50 | 51 | def get_attn_gradients(self): 52 | return self.attn_gradients 53 | 54 | def save_attention_map(self, attention_map): 55 | self.attention_map = 
attention_map 56 | 57 | def get_attention_map(self): 58 | return self.attention_map 59 | 60 | def forward(self, x, register_hook=False): 61 | B, N, C = x.shape 62 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 63 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 64 | 65 | attn = (q @ k.transpose(-2, -1)) * self.scale 66 | attn = attn.softmax(dim=-1) 67 | attn = self.attn_drop(attn) 68 | 69 | if register_hook: 70 | self.save_attention_map(attn) 71 | attn.register_hook(self.save_attn_gradients) 72 | 73 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 74 | x = self.proj(x) 75 | x = self.proj_drop(x) 76 | return x 77 | 78 | 79 | class Block(nn.Module): 80 | 81 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 82 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, adapter_config=None): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim) 85 | self.attn = Attention( 86 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 87 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 88 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 89 | self.norm2 = norm_layer(dim) 90 | mlp_hidden_dim = int(dim * mlp_ratio) 91 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 92 | 93 | if adapter_config is None: 94 | self.adaptered = False 95 | else: 96 | self.adaptered = True 97 | self.adapter = Adapter(**adapter_config, model_dim=dim) 98 | 99 | def forward(self, x, register_hook=False): 100 | # drop_path is something like DropOut 101 | x = x + self.attn(self.norm1(x), register_hook=register_hook) 102 | 103 | # in ViT, layernorm is also applied after self-attention 104 | if self.adaptered: 105 | # ViTOutput, following Vilt 106 | x = x + self.mlp(self.norm2(x)) 107 | x = self.adapter(x, x) 108 | else: 109 | x = x + self.mlp(self.norm2(x)) 110 | return x 111 | 112 | class VisionTransformer(nn.Module): 113 | """ Vision Transformer 114 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 115 | https://arxiv.org/abs/2010.11929 116 | """ 117 | 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, adapter_config=None): 121 | """ 122 | Args: 123 | img_size (int, tuple): input image size 124 | patch_size (int, tuple): patch size 125 | in_chans (int): number of input channels 126 | num_classes (int): number of classes for classification head 127 | embed_dim (int): embedding dimension 128 | depth (int): depth of transformer 129 | num_heads (int): number of attention heads 130 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 131 | qkv_bias (bool): enable bias for qkv if True 132 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 133 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 134 | drop_rate (float): dropout rate 135 | attn_drop_rate (float): attention dropout rate 136 | drop_path_rate (float): stochastic depth rate 137 | norm_layer: (nn.Module): normalization layer 138 | """ 139 | super().__init__() 140 | self.num_features = self.embed_dim = embed_dim # 
num_features for consistency with other models 141 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 142 | 143 | self.patch_embed = PatchEmbed( 144 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 145 | num_patches = self.patch_embed.num_patches 146 | 147 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 148 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 149 | self.pos_drop = nn.Dropout(p=drop_rate) 150 | 151 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 152 | self.blocks = nn.ModuleList([ 153 | Block( 154 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 155 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, adapter_config=adapter_config) 156 | for i in range(depth)]) 157 | self.norm = norm_layer(embed_dim) 158 | 159 | trunc_normal_(self.pos_embed, std=.02) 160 | trunc_normal_(self.cls_token, std=.02) 161 | self.apply(self._init_weights) 162 | 163 | def _init_weights(self, m): 164 | if isinstance(m, nn.Linear): 165 | trunc_normal_(m.weight, std=.02) 166 | if isinstance(m, nn.Linear) and m.bias is not None: 167 | nn.init.constant_(m.bias, 0) 168 | elif isinstance(m, nn.LayerNorm): 169 | nn.init.constant_(m.bias, 0) 170 | nn.init.constant_(m.weight, 1.0) 171 | 172 | @torch.jit.ignore 173 | def no_weight_decay(self): 174 | return {'pos_embed', 'cls_token'} 175 | 176 | def forward(self, x, register_blk=-1): 177 | B = x.shape[0] 178 | x = self.patch_embed(x) 179 | 180 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 181 | x = torch.cat((cls_tokens, x), dim=1) 182 | 183 | x = x + self.pos_embed[:, :x.size(1), :] 184 | x = self.pos_drop(x) 185 | 186 | for i, blk in enumerate(self.blocks): 187 | x = blk(x, register_blk == i) 188 | x = self.norm(x) 189 | 190 | return x 191 | 192 | 193 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 194 | # interpolate position embedding 195 | embedding_size = pos_embed_checkpoint.shape[-1] 196 | num_patches = visual_encoder.patch_embed.num_patches 197 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 198 | # height (== width) for the checkpoint position embedding 199 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 200 | # height (== width) for the new position embedding 201 | new_size = int(num_patches ** 0.5) 202 | 203 | if orig_size != new_size: 204 | # class_token and dist_token are kept unchanged 205 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 206 | # only the position tokens are interpolated 207 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 208 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 209 | pos_tokens = torch.nn.functional.interpolate( 210 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 211 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 212 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 213 | print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2)) 214 | 215 | return new_pos_embed 216 | else: 217 | return pos_embed_checkpoint 218 | -------------------------------------------------------------------------------- /src/configs/task_configs.py: -------------------------------------------------------------------------------- 1 | from 
src.train.visionlanguage_tasks.train_vqa import VQATrainer 2 | from src.train.visionlanguage_tasks.task_trainer import TaskTrainer 3 | 4 | SUPPORTED_VL_TASKS = [ 5 | "vqa", 6 | "abstract", 7 | "toronto", 8 | "pvqa", 9 | "med", 10 | "art", 11 | "abstract_albef", 12 | "toronto_albef", 13 | "pvqa_albef", 14 | "med_albef", 15 | "art_albef", 16 | ] 17 | SUPPORTED_ORDER = ["Order1", "Order2", "Order3", "Order4", "Order5", "Debug_order"] 18 | 19 | mscoco_config = { 20 | "data_dir": "/nfs/data3/zhangya/albef/mscoco/", 21 | } 22 | abstract_image_config = { 23 | "data_dir": [ 24 | "/nfs/data3/zhangya/vqa_abstract/train2015", 25 | "/nfs/data3/zhangya/vqa_abstract/val2015", 26 | ], 27 | } 28 | toronto_image_config = { 29 | "data_dir": [ 30 | "/nfs/data3/zhangya/albef/mscoco/train2014", 31 | "/nfs/data3/zhangya/albef/mscoco/val2014", 32 | ] 33 | } 34 | 35 | pvqa_image_config = { 36 | "data_dir": [ 37 | "/nfs/data3/zhangya/PathVQA_data/PathVQA/split/images/train", 38 | "/nfs/data3/zhangya/albef/pvqa/split/images/val", 39 | "/nfs/data3/zhangya/albef/pvqa/split/images/test", 40 | ] 41 | } 42 | med_image_config = {"data_dir": ["/nfs/data3/yyang/med_images/all_images"]} 43 | art_image_config = {"data_dir": ["/nfs/data3/yyang/AQUA/SemArt/Images"]} 44 | 45 | flickr_config = { 46 | "data_dir": "flickr30k/", 47 | } 48 | 49 | vqa_config = { 50 | "task_name": "VQAv2", 51 | "data_dir": "vqav2/", 52 | "images_source": "ms-coco", 53 | "splits": ["train", "val"], 54 | "num_labels": 3129, 55 | "num_images": 1, 56 | "model_type": "classification", 57 | "num_epochs": 50, # Yao: original 10 58 | "lr": 1e-4, 59 | "weight_decay": 1e-2, 60 | "adam_epsilon": 1e-8, 61 | "warmup_ratio": 0.1, 62 | "task_trainer": VQATrainer, 63 | "random_baseline_score": 0.0, 64 | } 65 | 66 | abstract_config = { 67 | "task_name": "abstract", 68 | "data_dir": "/nfs/data3/zhangya/vqa_abstract", 69 | "images_source": "abstract_image", 70 | "splits": ["train", "val_small"], 71 | "num_labels": 500, 72 | "num_images": 1, 73 | "model_type": "classification", 74 | "num_epochs": 20, # Yao: original 10 75 | "lr": 1e-4, 76 | "weight_decay": 1e-2, 77 | "adam_epsilon": 1e-8, 78 | "warmup_ratio": 0.1, 79 | "task_trainer": VQATrainer, 80 | "random_baseline_score": 0.0, 81 | } 82 | 83 | toronto_config = { 84 | "task_name": "toronto", 85 | "data_dir": "/nfs/data3/zhangya/torontoCOCO", 86 | "images_source": "toronto_image", 87 | "splits": ["train", "val"], 88 | "num_labels": 430, 89 | "num_images": 1, 90 | "model_type": "classification", 91 | "num_epochs": 20, # Yao: original 10 92 | "lr": 1e-4, 93 | "weight_decay": 1e-2, 94 | "adam_epsilon": 1e-8, 95 | "warmup_ratio": 0.1, 96 | "task_trainer": VQATrainer, 97 | "random_baseline_score": 0.0, 98 | } 99 | pvqa_config = { 100 | "task_name": "pvqa", 101 | "data_dir": "/nfs/data3/zhangya/albef/pvqa", 102 | "images_source": "pvqa_image", 103 | "splits": ["train", "val"], 104 | "num_labels": 2540, 105 | "num_images": 1, 106 | "model_type": "classification", 107 | "num_epochs": 20, # Yao: original 10 108 | "lr": 1e-4, 109 | "weight_decay": 1e-2, 110 | "adam_epsilon": 1e-8, 111 | "warmup_ratio": 0.1, 112 | "task_trainer": VQATrainer, 113 | "random_baseline_score": 0.0, 114 | } 115 | 116 | med_config = { 117 | "task_name": "med", 118 | "data_dir": "/nfs/data3/zhangya/VQA-Med-2019", 119 | "images_source": "med_image", 120 | "splits": ["train", "val"], 121 | "num_labels": 1701, 122 | "num_images": 1, 123 | "model_type": "classification", 124 | "num_epochs": 20, # Yao: original 10 125 | "lr": 1e-4, 126 | "weight_decay": 1e-2, 
127 | "adam_epsilon": 1e-8, 128 | "warmup_ratio": 0.1, 129 | "task_trainer": VQATrainer, 130 | "random_baseline_score": 0.0, 131 | } 132 | 133 | art_config = { 134 | "task_name": "art", 135 | "data_dir": "/nfs/data3/zhangya/albef/art", 136 | "images_source": "art_image", 137 | "splits": ["train", "val"], 138 | "num_labels": 326, 139 | "num_images": 1, 140 | "model_type": "classification", 141 | "num_epochs": 20, # Yao: original 10 142 | "lr": 1e-4, 143 | "weight_decay": 1e-2, 144 | "adam_epsilon": 1e-8, 145 | "warmup_ratio": 0.1, 146 | "task_trainer": VQATrainer, 147 | "random_baseline_score": 0.0, 148 | 149 | } 150 | 151 | abstract_albef_config = { 152 | "task_name": "abstract", 153 | "train": "/nfs/data3/zhangya/vqa_abstract/vqa_abstract_train.json", 154 | "train_small": "/nfs/data3/zhangya/vqa_abstract/abstract_train_small.json", 155 | "val": "/nfs/data3/zhangya/vqa_abstract/abstract_val.json", 156 | "test": "/nfs/data3/zhangya/vqa_abstract/abstract_test_small.json", 157 | "images": "/nfs/data3/zhangya/vqa_abstract", 158 | "answer_list": "/nfs/data3/zhangya/vqa_abstract/answer_list.json", 159 | "splits": ["train", "val"], 160 | # 'num_labels': 500, 161 | # 'num_images': 1, 162 | # 'model_type': 'classification', 163 | "num_epochs": 20, # Yao: original 10 164 | # 'lr': 1e-4, 165 | "weight_decay": 1e-2, 166 | "adam_epsilon": 1e-8, 167 | # 'warmup_ratio': 0.1, 168 | "task_trainer": VQATrainer, 169 | # 'random_baseline_score': 0.0, 170 | # 'low_shot_config': {'task_trainer': LowShotVQATrainer, 171 | # 'type': 'percentage', 172 | # 'percentage': 0.05, 173 | # 'eval_epochs': [6, 8, 10, 15, 20, 25, 30]} 174 | } 175 | 176 | toronto_albef_config = { 177 | "task_name": "toronto", 178 | "train": "/nfs/data3/zhangya/torontoCOCO/toronto_train.json", 179 | "train_small": "/nfs/data3/zhangya/torontoCOCO/toronto_train_small.json", 180 | "val": "/nfs/data3/zhangya/torontoCOCO/toronto_val.json", 181 | "test": "/nfs/data3/zhangya/torontoCOCO/toronto_test_small.json", 182 | "images": "/nfs/data3/zhangya/albef/mscoco", 183 | "answer_list": "/nfs/data3/zhangya/torontoCOCO/answer_list.json", 184 | "splits": ["train", "val"], 185 | # 'num_labels': 430, 186 | # 'num_images': 1, 187 | # 'model_type': 'classification', 188 | "num_epochs": 20, # Yao: original 10 189 | # 'lr': 1e-4, 190 | "weight_decay": 1e-2, 191 | "adam_epsilon": 1e-8, 192 | # 'warmup_ratio': 0.1, 193 | "task_trainer": VQATrainer, 194 | # 'random_baseline_score': 0.0, 195 | # 'low_shot_config': {'task_trainer': LowShotVQATrainer, 196 | # 'type': 'percentage', 197 | # 'percentage': 0.05, 198 | # 'eval_epochs': [6, 8, 10, 15, 20, 25, 30]} 199 | } 200 | pvqa_albef_config = { 201 | "task_name": "pvqa", 202 | "train": "/nfs/data3/zhangya/albef/pvqa/pvqa_train.json", 203 | "train_small": "/nfs/data3/zhangya/albef/pvqa/pvqa_train_small.json", 204 | "val": "/nfs/data3/zhangya/albef/pvqa/pvqa_val.json", 205 | "test": "/nfs/data3/zhangya/albef/pvqa/pvqa_test_small.json", 206 | "answer_list": "/nfs/data3/zhangya/albef/pvqa/answer_list_small.json", 207 | "images": "/nfs/data3/zhangya/PathVQA_data/PathVQA/split/images", 208 | "splits": ["train", "val"], 209 | "num_labels": 2540, 210 | "num_images": 1, 211 | "model_type": "classification", 212 | "num_epochs": 20, # Yao: original 10 213 | "lr": 1e-4, 214 | "weight_decay": 1e-2, 215 | "adam_epsilon": 1e-8, 216 | "warmup_ratio": 0.1, 217 | "task_trainer": VQATrainer, 218 | "random_baseline_score": 0.0, 219 | } 220 | 221 | med_albef_config = { 222 | "task_name": "med", 223 | "train": 
"/nfs/data3/zhangya/VQA-Med-2019/med2019_train.json", 224 | "train_small": "/nfs/data3/zhangya/VQA-Med-2019/med2019_train.json", 225 | "val": "/nfs/data3/zhangya/VQA-Med-2019/med2019_val.json", 226 | "test": "/nfs/data3/zhangya/VQA-Med-2019/med2019_test.json", 227 | "answer_list": "/nfs/data3/zhangya/VQA-Med-2019/answer_list_trainval.json", 228 | "images": "/nfs/data3/zhangya/VQA-Med-2019", 229 | "splits": ["train", "val"], 230 | "num_labels": 1701, 231 | "num_images": 1, 232 | "model_type": "classification", 233 | "num_epochs": 20, # Yao: original 10 234 | "lr": 1e-4, 235 | "weight_decay": 1e-2, 236 | "adam_epsilon": 1e-8, 237 | "warmup_ratio": 0.1, 238 | "task_trainer": VQATrainer, 239 | "random_baseline_score": 0.0, 240 | 241 | } 242 | 243 | art_albef_config = { 244 | "task_name": "art", 245 | "train": "/nfs/data3/zhangya/albef/art/art_train.json", 246 | "train_small": "/nfs/data3/zhangya/albef/art/art_train_small.json", 247 | "val": "/nfs/data3/zhangya/albef/art/art_val.json", 248 | "test": "/nfs/data3/zhangya/albef/art/art_test_small.json", 249 | "images": "/nfs/data3/yyang/AQUA/SemArt/Images", 250 | "answer_list": "/nfs/data3/zhangya/albef/art/answer_list_small.json", 251 | "splits": ["train", "val"], 252 | # 'num_labels': 326, 253 | # 'num_images': 1, 254 | # 'model_type': 'classification', 255 | "num_epochs": 20, # Yao: original 10 256 | # 'lr': 1e-4, 257 | "weight_decay": 1e-2, 258 | "adam_epsilon": 1e-8, 259 | # 'warmup_ratio': 0.1, 260 | "task_trainer": VQATrainer, 261 | # 'random_baseline_score': 0.0, 262 | # 'low_shot_config': {'task_trainer': LowShotVQATrainer, 263 | # 'type': 'percentage', 264 | # 'percentage': 0.05, 265 | # 'eval_epochs': [6, 8, 10, 15, 20, 25, 30]} 266 | 267 | } 268 | 269 | task_configs = { 270 | "ms-coco": mscoco_config, 271 | "flickr30k": flickr_config, 272 | "vqa": vqa_config, 273 | "abstract": abstract_config, 274 | "toronto": toronto_config, 275 | "pvqa": pvqa_config, 276 | "med": med_config, 277 | "art": art_config, 278 | "abstract_albef": abstract_albef_config, 279 | "toronto_albef": toronto_albef_config, 280 | "pvqa_albef": pvqa_albef_config, 281 | "med_albef": med_albef_config, 282 | "art_albef": art_albef_config, 283 | "abstract_image": abstract_image_config, 284 | "toronto_image": toronto_image_config, 285 | "pvqa_image": pvqa_image_config, 286 | "med_image": med_image_config, 287 | "art_image": art_image_config, 288 | } 289 | -------------------------------------------------------------------------------- /src/train/visionlanguage_tasks/train_vqa_crossvqa.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import argparse 3 | import datetime 4 | import json 5 | import logging 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import math 11 | import shutil 12 | import pickle as pkl 13 | import pdb 14 | from tqdm import tqdm 15 | from typing import List, Dict 16 | import torch.nn.functional as F 17 | 18 | from src.modeling.continual_learner import ContinualLearner 19 | 20 | sys.path.insert(0, ".") 21 | 22 | import numpy as np 23 | import torch 24 | from torch import nn 25 | from torch.optim import AdamW 26 | from transformers import get_polynomial_decay_schedule_with_warmup 27 | from src.data.image_datasets.vgimages_dataset import VGImagesDataset 28 | from src.data.image_datasets.vizwizimages_dataset import vizwizImagesDataset 29 | from src.data.visionlanguage_datasets.vqa_dataset_crossvqa import ( 30 | build_vqa_vilt_dataloader, 31 | build_vqa_albef_dataloader, 32 | ) 33 | 
from src.train.visionlanguage_tasks.task_trainer import TaskTrainer 34 | from src.utils.wandb import wandb_logger 35 | from src.modeling.models.tokenization_bert import BertTokenizer 36 | from src.utils.seed_utils import set_seed 37 | 38 | 39 | class VQATrainerCross(TaskTrainer): 40 | def __init__( 41 | self, 42 | logger, 43 | args: argparse.Namespace, 44 | task_configs: Dict, 45 | model_config: Dict, 46 | device: torch.device, 47 | task_key: str, 48 | task_output_dir=None, 49 | client_id=-1, 50 | accelerator=None, 51 | ): 52 | 53 | """ 54 | Initializes a Trainer that handles training of a model on the VQA task 55 | 56 | args: Arguments provided by user 57 | task_configs: dictionary containing task-specific configuration parameters for all tasks 58 | model_config: dictionary containing model-specific configuration parameters 59 | device: cuda/cpu 60 | """ 61 | 62 | super().__init__() 63 | self.accelerator = accelerator 64 | self.device = self.accelerator.device 65 | # make sure different process gets different seed 66 | set_seed(args.seed + self.accelerator.process_index) 67 | self.logger = logger 68 | 69 | # Create W&B experiment 70 | if args.do_wandb_logging: 71 | self.logger.info( 72 | "W&B project: {}, experiment: {}".format( 73 | "CARVEN", task_output_dir.split("/")[-3] 74 | ) 75 | ) 76 | if self.accelerator.is_main_process: 77 | self.accelerator.init_trackers(project_name="CARVEN") 78 | self.accelerator.trackers[0].run.name = ( 79 | task_output_dir.split("/")[-4] 80 | + "/" 81 | + task_output_dir.split("/")[-3] 82 | + "/" 83 | + task_output_dir.split("/")[-1] 84 | ) 85 | 86 | self.args = args 87 | self.local_epochs = args.local_epochs 88 | self.task_key = task_key 89 | self.task_output_dir = task_output_dir 90 | 91 | self.vqa_config = task_configs[self.task_key] # in task_configs.py 92 | self.batch2inputs_converter = model_config["batch2inputs_converter"] 93 | 94 | # Model-specific stuff 95 | if "vilt" in args.encoder_name: 96 | self.model_name = "vilt" 97 | self.data_dir = os.path.join( 98 | args.climb_data_dir, self.vqa_config["data_dir"] 99 | ) # vqa_abstract 100 | self.visual_input_type = model_config["visual_input_type"] # pil_image 101 | 102 | # Create dataloaders for training and validation 103 | # Load COCO Images dataset for image data backbone 104 | images_source = self.vqa_config["images_source"] 105 | if task_key=='gqa' or "clove" in task_key: 106 | self.images_dataset = VGImagesDataset( 107 | coco_dir=args.climb_data_dir, 108 | data_dir=None, 109 | visual_input_type=args.visual_input_type, 110 | task_key=self.task_key, 111 | ) 112 | elif task_key=='vizwiz': 113 | self.images_dataset = vizwizImagesDataset( 114 | coco_dir=args.climb_data_dir, 115 | data_dir=None, 116 | visual_input_type=args.visual_input_type, 117 | task_key=self.task_key, 118 | ) 119 | else: 120 | mscoco_config = task_configs[images_source] 121 | from src.data.image_datasets.cocoimages_dataset_crossvqas import MSCOCOImagesDataset 122 | self.images_dataset = MSCOCOImagesDataset( 123 | coco_dir=args.climb_data_dir, 124 | data_dir=mscoco_config["data_dir"], 125 | visual_input_type=args.visual_input_type, 126 | task_key=self.task_key, 127 | ) 128 | 129 | self.vqa_train_dataloader = build_vqa_vilt_dataloader( 130 | logger=self.logger, 131 | args=args, 132 | data_dir=self.data_dir, 133 | images_dataset=self.images_dataset, 134 | split=self.args.splits[0], 135 | task_key=self.task_key, 136 | visual_input_type=self.visual_input_type, 137 | client_id=client_id, 138 | ) 139 | 140 | self.vqa_val_dataloader = 
build_vqa_vilt_dataloader( 141 | logger=self.logger, 142 | args=args, 143 | data_dir=self.data_dir, 144 | images_dataset=self.images_dataset, 145 | split=self.args.splits[1], 146 | task_key=self.task_key, 147 | visual_input_type=self.visual_input_type, 148 | client_id=client_id, 149 | ) 150 | 151 | self.vqa_test_dataloader = build_vqa_vilt_dataloader( 152 | logger=self.logger, 153 | args=args, 154 | data_dir=self.data_dir, 155 | images_dataset=self.images_dataset, 156 | split=self.args.splits[2], 157 | task_key=self.task_key, 158 | visual_input_type=self.visual_input_type, 159 | client_id=client_id, 160 | ) 161 | else: 162 | self.model_name = "albef" 163 | self.data_dir = os.path.join( 164 | args.climb_data_dir, self.vqa_config["data_dir"] 165 | ) # vqa_abstract 166 | 167 | images_source = self.vqa_config["images_source"] 168 | if task_key=='gqa' or "clove" in task_key: 169 | self.images_dataset = VGImagesDataset( 170 | coco_dir=args.climb_data_dir, 171 | data_dir=None, 172 | visual_input_type=args.visual_input_type, 173 | task_key=self.task_key, 174 | ) 175 | elif task_key=='vizwiz': 176 | self.images_dataset = vizwizImagesDataset( 177 | coco_dir=args.climb_data_dir, 178 | data_dir=None, 179 | visual_input_type=args.visual_input_type, 180 | task_key=self.task_key, 181 | ) 182 | else: 183 | mscoco_config = task_configs[images_source] 184 | from src.data.image_datasets.cocoimages_dataset_crossvqas import MSCOCOImagesDataset 185 | self.images_dataset = MSCOCOImagesDataset( 186 | coco_dir=args.climb_data_dir, 187 | data_dir=mscoco_config["data_dir"], 188 | visual_input_type=args.visual_input_type, 189 | task_key=self.task_key, 190 | ) 191 | self.vqa_train_dataloader = build_vqa_albef_dataloader( 192 | logger=self.logger, 193 | args=args, 194 | data_dir=self.data_dir, 195 | images_dataset=self.images_dataset, 196 | vqa_config=self.vqa_config, 197 | split=self.args.splits[0], 198 | task_key=self.task_key, 199 | client_id=client_id, 200 | ) 201 | 202 | self.vqa_val_dataloader = build_vqa_albef_dataloader( 203 | logger=self.logger, 204 | args=args, 205 | data_dir=self.data_dir, 206 | images_dataset=self.images_dataset, 207 | vqa_config=self.vqa_config, 208 | split=self.args.splits[1], 209 | task_key=self.task_key, 210 | client_id=client_id, 211 | ) 212 | 213 | self.vqa_test_dataloader = build_vqa_albef_dataloader( 214 | logger=self.logger, 215 | args=args, 216 | data_dir=self.data_dir, 217 | images_dataset=self.images_dataset, 218 | vqa_config=self.vqa_config, 219 | split=self.args.splits[2], 220 | task_key=self.task_key, 221 | client_id=client_id, 222 | ) 223 | 224 | ( 225 | self.vqa_train_dataloader, 226 | self.vqa_val_dataloader, 227 | self.vqa_test_dataloader, 228 | ) = self.accelerator.prepare( 229 | self.vqa_train_dataloader, self.vqa_val_dataloader, self.vqa_test_dataloader 230 | ) 231 | 232 | # Training hyperparameters 233 | self.num_epochs = self.args.num_epochs 234 | self.lr = self.args.lr 235 | self.adam_epsilon = self.vqa_config["adam_epsilon"] 236 | self.weight_decay = self.vqa_config["weight_decay"] 237 | self.loss_criterion = nn.BCEWithLogitsLoss(reduction="mean") 238 | self.max_steps = len(self.vqa_train_dataloader) * self.num_epochs 239 | self.warmup_ratio = 0.1 # TODO remove hard code 240 | 241 | def compute_score_with_logits( 242 | self, logits: torch.Tensor, labels: torch.Tensor 243 | ) -> torch.Tensor: 244 | """ 245 | Given logits for each answer in VQA classification, selects answer with max logit and returns VQA-score for that answer 246 | logits: logits for each answer - 
size=(batch_size, num_answers)
247 |             labels: label for each answer in {0, 0.3, 0.6, 1} (batch_size, num_answers)
248 | 
249 |         Returns:
250 |             scores: score of predicted answer (batch_size, num_answers)
251 |         """
252 | 
253 |         logits = torch.max(logits, 1)[1].data  # argmax
254 |         one_hots = torch.zeros(*labels.size()).to(self.device)
255 |         one_hots.scatter_(1, logits.view(-1, 1), 1)
256 |         scores = one_hots * labels
257 |         return scores
258 | 
259 |     def get_train_dataloader(self):
260 |         return self.vqa_train_dataloader
261 | 
262 |     def get_collate_fn(self):
263 |         return self.vqa_train_dataloader.collate_fn
264 | 
265 |     def add_alpha(self, epoch, batch, step):
266 |         if epoch > 0:  # alpha weights the distillation term; ramped from 0 to 0.4 over the first epoch
267 |             alpha = 0.4
268 |         else:
269 |             alpha = 0.4 * min(1, step / len(self.vqa_train_dataloader))
270 |         batch.append(alpha)
271 |         return batch
272 | 
--------------------------------------------------------------------------------
/src/data/visionlanguage_datasets/vcr_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import time
4 | import json
5 | import logging
6 | import random
7 | import glob
8 | import base64
9 | from tqdm import tqdm
10 | from collections import defaultdict
11 | import pickle as pkl
12 | import pdb
13 | import jsonlines
14 | from typing import List, Dict
15 | 
16 | import numpy as np
17 | import torch
18 | import torch.nn.functional as F
19 | from torchvision import transforms as T
20 | from torch.utils.data import Dataset
21 | 
22 | from PIL import Image
23 | 
24 | from src.data.image_collation import image_collate
25 | 
26 | logger = logging.getLogger(__name__)
27 | logging.basicConfig(
28 |     format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
29 |     datefmt='%m/%d/%Y %H:%M:%S',
30 |     level=logging.INFO)
31 | 
32 | GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall',
33 |                         'Skyler', 'Frankie', 'Pat', 'Quinn', 'Morgan', 'Finley', 'Harley', 'Robbie', 'Sidney', 'Tommie',
34 |                         'Ashley', 'Carter', 'Adrian', 'Clarke', 'Logan', 'Mickey', 'Nicky', 'Parker', 'Tyler',
35 |                         'Reese', 'Charlie', 'Austin', 'Denver', 'Emerson', 'Tatum', 'Dallas', 'Haven', 'Jordan',
36 |                         'Robin', 'Rory', 'Bellamy', 'Salem', 'Sutton', 'Gray', 'Shae', 'Kyle', 'Alex', 'Ryan',
37 |                         'Cameron', 'Dakota']
38 | 
39 | 
40 | def process_list(mytext, objects):
41 |     ## Turn each referenced object into a readable mention: person indices get a gender-neutral name, other objects become 'the gray <object>'
42 | 
43 |     ## processing the text
44 |     text = ''
45 |     for element in mytext:
46 |         #print(element)
47 |         if(type(element) == list):  #### If it's a list we need to process each object
48 |             for subelement in element:
49 |                 if(objects[int(subelement)] == 'person'):
50 |                     temporal_text = GENDER_NEUTRAL_NAMES[int(subelement)]
51 |                 else:
52 |                     temporal_text = 'the gray ' + str(objects[int(subelement)]).strip()
53 |         elif(type(element) == int):
54 |             if(objects[int(element)] == 'person'):
55 |                 temporal_text = GENDER_NEUTRAL_NAMES[int(element)]
56 |             else:
57 |                 temporal_text = 'the gray ' + str(objects[int(element)])
58 |         else:
59 |             temporal_text = element
60 |         text += temporal_text + ' '
61 |     #print('text: ', text)
62 |     return text
63 | 
64 | class VCRDataset(Dataset):
65 | 
66 |     def __init__(self,
67 |                  data_dir: str,
68 |                  split: str,
69 |                  task_type='qa',
70 |                  **kwargs):
71 | 
72 |         """
73 |         Initializes the VCRDataset - loads all the questions and answers, concatenates each question and answer into a choice text
74 |         (and converts to input IDs using the tokenizer, if provided) and stores all
4 choice texts and label 75 | Every item in self.data corresponds to a single VCR input 76 | 77 | Args: 78 | data_dir : path containing VCR questions and annotations 79 | split: either train/val/test split 80 | task_type: either 'qa' or 'qar', depending on if we do Q->A or QA->R 81 | 82 | Returns: 83 | Loads all annotations into self.data, where each item is a single VCR input 84 | """ 85 | 86 | self.data_dir = data_dir 87 | self.images_dataset = self.data_dir + 'draw_images/bbox/' 88 | 89 | self.image_dir = os.path.join(data_dir, 'vcr') 90 | self.split = split 91 | self.task_type = task_type 92 | self.tokenizer = kwargs['tokenizer'] if 'tokenizer' in kwargs else None 93 | 94 | self.annotations_file = os.path.join(data_dir, 'annotation/{}.jsonl'.format(split)) 95 | 96 | self.cached_data_file = os.path.join(data_dir, 'cached_vcr_data', 'vcr_'+ str(task_type) + '_' + '{}.pkl'.format(split)) 97 | if os.path.exists(self.cached_data_file): 98 | self.data = pkl.load(open(self.cached_data_file, 'rb')) 99 | else: 100 | self.data = [] 101 | json_lines = jsonlines.open(self.annotations_file) 102 | count = 0 103 | for line in tqdm(json_lines): 104 | 105 | image_path = os.path.join('drawn_images/' + str(split) + '/' + str(task_type)+ '/' + str(line['annot_id']) +'.jpg') ## train-0, train-1, train-2 106 | multichoice_texts = [] 107 | objects = line['objects'] ### objects 108 | 109 | question = process_list(line['question'], objects) ### question 110 | if(task_type == 'qa'): 111 | ### answers: question + ' [SEP] ' + answer 112 | for answer in line['answer_choices']: 113 | answer1 = process_list(answer, objects) 114 | text = question + ' [SEP] ' + answer1 115 | multichoice_texts.append(text) 116 | label = int(line['answer_label']) ##number 117 | 118 | else: 119 | ### rationales: question + '[SEP]' + answer + '[SEP]' + rationale 120 | answer = process_list( line['answer_choices'][int(line['answer_label'])], objects) 121 | for rationale in line['rationale_choices']: 122 | rationale1 = process_list(rationale, objects) 123 | text = question + ' [SEP] ' + answer + ' [SEP] ' + rationale1 124 | multichoice_texts.append(text) 125 | label = int(line['rationale_label']) ##number 126 | 127 | if self.tokenizer is not None: 128 | multichoice_tokens = [self.tokenizer.tokenize(text) for text in multichoice_texts] 129 | multichoice_input_ids = [self.tokenizer.convert_tokens_to_ids(t) for t in multichoice_tokens] 130 | else: 131 | multichoice_input_ids = [] 132 | 133 | doc = {'image_path': image_path, 134 | 'texts': multichoice_texts, 135 | 'input_ids': multichoice_input_ids, 136 | 'label': label} 137 | self.data.append(doc) 138 | 139 | pkl.dump(self.data, open(self.cached_data_file, 'wb')) 140 | self.n_examples = len(self.data) 141 | logger.info("Loaded VCR-{} {} dataset, with {} examples".format(self.task_type, self.split, len(self.data))) 142 | 143 | def __len__(self): 144 | return self.n_examples 145 | 146 | def __getitem__(self, index: int): 147 | 148 | """ 149 | Args: 150 | index : index of element in self.data to return as data instance 151 | 152 | Returns: 153 | dictionary containing inputs and targets for model to do VCR 154 | """ 155 | 156 | example = self.data[index] 157 | 158 | image_fn = os.path.join(self.data_dir, example['image_path']) 159 | pil_transform = T.Resize(size=384, max_size=640) 160 | image = Image.open(image_fn) 161 | image = image.convert('RGB') 162 | if min(list(image.size)) > 384: 163 | image = pil_transform(image) 164 | 165 | texts = example['texts'] 166 | label = example['label'] 167 | 168 
| return {'texts': texts, 169 | 'image': image, 170 | 'label': label 171 | } 172 | 173 | def convert_to_low_shot(self, low_shot_percentage: float): 174 | """ 175 | Args: 176 | low_shot_percentage: float between 0 and 1, telling what % of full data to retain for low-shot setting 177 | """ 178 | 179 | logger.info("Converting VCR train split into low-shot dataset, with {:.2f}% training samples...".format(low_shot_percentage*100.0)) 180 | n_low_shot_examples = int(low_shot_percentage*self.n_examples) 181 | 182 | new_data = random.Random(1).sample(self.data, n_low_shot_examples) 183 | self.data = new_data 184 | self.n_examples = len(self.data) 185 | 186 | logger.info("Converted into low-shot dataset, with {} examples".format(self.n_examples)) 187 | 188 | def vcr_batch_collate(batch: List[Dict], 189 | visual_input_type: str): 190 | 191 | """ 192 | Collates each model input for all batch items into a single model input (e.g. converts a list of input_ids into a matrix of size (batch_size, max_len)) 193 | 194 | Args: 195 | batch - list of batch items, each item being a dictionary returned by Dataset's __getitem__ method 196 | visual_input_type: string which specifies the type of visual input 197 | 198 | Returns: 199 | Dictionary containing batched inputs and outputs 200 | """ 201 | 202 | assert visual_input_type == 'pil-image' 203 | texts = [x['texts'] for x in batch] 204 | pil_objs = [x['image'] for x in batch] 205 | labels = [x['label'] for x in batch] 206 | 207 | return {'raw_texts': texts, 208 | 'images': pil_objs, 209 | 'labels': torch.LongTensor(labels)} 210 | 211 | def build_vcr_dataloader(args, 212 | data_dir: str, 213 | split: str, 214 | task_type: str, 215 | visual_input_type: str, 216 | **kwargs) -> torch.utils.data.DataLoader: 217 | 218 | """ 219 | Creates the VCR Dataloader, which gives batches of VCR inputs and outputs 220 | 221 | Args: 222 | data_dir : path containing VCR questions and annotations. 223 | split: either train/val split 224 | task_type: either 'qa' or 'qar', depending on if we do Q->A or QA->R 225 | visual_input_type: format of visual input to model 226 | 227 | Returns: 228 | DataLoader object 229 | """ 230 | 231 | shuffle = True if split == 'train' else False 232 | 233 | assert visual_input_type == 'pil-image' # VCR not supported for other visual inputs yet! 
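    # Note (assumption): each VCR item expands into 4 answer-choice texts, so the dataloader below divides the batch size by 4 to keep the number of text-image pairs per step close to args.batch_size.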
234 | 235 | logger.info("Creating VCR {} dataloader with batch size of {}".format(split, args.batch_size//4)) 236 | 237 | dataset = VCRDataset(data_dir, split, task_type, **kwargs) 238 | 239 | dataloader = torch.utils.data.DataLoader( 240 | dataset, 241 | num_workers=args.num_workers, 242 | batch_size=args.batch_size//4, 243 | shuffle=shuffle, 244 | collate_fn=lambda x: vcr_batch_collate(x, visual_input_type)) 245 | return dataloader 246 | 247 | 248 | if __name__ == '__main__': 249 | 250 | class Args: 251 | def __init__(self): 252 | self.batch_size = 16 253 | self.num_workers = 2 254 | self.visual_input_type = 'pil-image' 255 | 256 | args = Args() 257 | data_dir = '/data/datasets/MCL/vcr/' 258 | #annotation_dir = '/data/datasets/MCL/vcr/annotation/' 259 | split = 'val' #'train' 260 | text = ['Why', 'is', [0], 'smiling', 'at', [1], '?'] 261 | objects = ['person', 'person', 'bottle'] 262 | #process_list(text, objects) 263 | 264 | from transformers import BertTokenizer 265 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 266 | #vcr.VCRDataset(data_dir, split, tokenizer, task_type='qa') 267 | 268 | vcr_train_dataloader = build_vcr_dataloader(args, data_dir, split= 'train', tokenizer = tokenizer, task_type = 'qa', visual_input_type=args.visual_input_type) 269 | vcr_val_dataloader = build_vcr_dataloader(args, data_dir, split= 'val', tokenizer = tokenizer, task_type = 'qa', visual_input_type=args.visual_input_type) 270 | 271 | vcr_train_dataloader = build_vcr_dataloader(args, data_dir, split= 'train', task_type = 'qar', visual_input_type=args.visual_input_type) 272 | vcr_val_dataloader = build_vcr_dataloader(args, data_dir, split= 'val', task_type = 'qar', visual_input_type=args.visual_input_type) 273 | 274 | for batch in vcr_val_dataloader: 275 | pdb.set_trace() -------------------------------------------------------------------------------- /src/data/visionlanguage_datasets/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | import time 5 | import json 6 | import logging 7 | import random 8 | import glob 9 | import base64 10 | from tqdm import tqdm 11 | from collections import defaultdict 12 | import pickle as pkl 13 | import pdb 14 | from typing import List, Dict 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torchvision import transforms as T 20 | from torch.utils.data import Dataset 21 | 22 | from PIL import Image 23 | from src.utils.vqa_utils import get_score, target_tensor 24 | 25 | from src.data.image_datasets.cocoimages_dataset import MSCOCOImagesDataset 26 | from src.data.image_collation import image_collate 27 | 28 | logger = logging.getLogger(__name__) 29 | logging.basicConfig( 30 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 31 | datefmt='%m/%d/%Y %H:%M:%S', 32 | level=logging.INFO) 33 | 34 | class VQADataset(Dataset): 35 | 36 | def __init__(self, 37 | data_dir: str, 38 | images_dataset: MSCOCOImagesDataset, 39 | split: str, 40 | **kwargs): 41 | 42 | """ 43 | Initiates the VQADataset - loads all the questions (and converts to input IDs using the tokenizer, if provided) 44 | and answers (including converting each to a numeric label, and a score based on occurence from annotators) 45 | Every item in self.data corresponds to a single QA pair, with a corresponding image 46 | 47 | Args: 48 | data_dir : path containing VQA questions and annotations. 
Also contains mapping from each answer in set of possible answers to a numerical label 49 | images_dataset : instance of MSCOCOImagesDataset, that is used to retrieve the MS-COCO image for each question 50 | split: either train/val split 51 | 52 | Returns: 53 | Loads all annotations into self.data, where each item is a single VQA pair 54 | """ 55 | 56 | self.images_dataset = images_dataset 57 | self.data_dir = data_dir 58 | self.split = split 59 | self.tokenizer = kwargs['tokenizer'] if 'tokenizer' in kwargs else None 60 | 61 | self.annotations_file = os.path.join(data_dir, 'v2_mscoco_{}2014_annotations.json'.format(split)) 62 | self.questions_file = os.path.join(data_dir, 'v2_OpenEnded_mscoco_{}2014_questions.json'.format(split)) 63 | self.ans2label_file = os.path.join(data_dir, 'ans2label.pkl'.format(split)) 64 | 65 | # Load mapping from answers to labels 66 | self.ans2label = pkl.load(open(self.ans2label_file, 'rb')) 67 | self.label2ans = {v: k for k, v in self.ans2label.items()} 68 | self.num_labels = len(self.label2ans) 69 | self.num_answers = len(self.ans2label) 70 | 71 | self.cached_data_file = os.path.join(data_dir, 'cached_vqa_data', 'vqa_{}.pkl'.format(split)) 72 | if os.path.exists(self.cached_data_file): 73 | # Load cached data 74 | self.data = pkl.load(open(self.cached_data_file, 'rb')) 75 | 76 | else: 77 | # Create map from question id to question 78 | questions = json.load(open(self.questions_file))['questions'] 79 | qid2qdata = {x['question_id']: x for x in questions} 80 | 81 | # Create data for each annotation 82 | annotations = json.load(open(self.annotations_file))['annotations'] 83 | self.data = [] 84 | for anno in annotations: 85 | qid = anno['question_id'] 86 | correct_answer = anno['multiple_choice_answer'] 87 | image_id = anno['image_id'] 88 | 89 | # Retrieve the question for this annotation 90 | qdata = qid2qdata[qid] 91 | assert qdata['image_id'] == image_id 92 | question = qdata['question'] 93 | if self.tokenizer is not None: 94 | tokens = self.tokenizer.tokenize(question) 95 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 96 | else: 97 | tokens = [] 98 | input_ids = [] 99 | 100 | # Map from each crowdsourced answer to occurrences in annotation 101 | answers = [a['answer'] for a in anno['answers']] 102 | answer_count = defaultdict(int) 103 | for ans in answers: 104 | answer_count[ans] += 1 105 | 106 | # Get label and score (0.3/0.6/1) corresponding to each crowdsourced answer 107 | labels = [] 108 | scores = [] 109 | answers = [] 110 | for answer in answer_count: 111 | if answer not in self.ans2label: 112 | continue 113 | labels.append(self.ans2label[answer]) 114 | score = get_score(answer_count[answer]) 115 | scores.append(score) 116 | answers.append(answer) 117 | 118 | # Store pre-processed example 119 | example = {'question_id': qid, 120 | 'image_id': image_id, 121 | 'question': question, 122 | 'question_input_ids': input_ids, 123 | 'correct_answer': correct_answer, 124 | 'labels': labels, 125 | 'answers': answers, 126 | 'scores': scores} 127 | self.data.append(example) 128 | 129 | pkl.dump(self.data, open(self.cached_data_file, 'wb')) 130 | 131 | self.n_examples = len(self.data) 132 | 133 | logger.info("Loaded VQAv2 {} dataset, with {} examples".format(self.split, len(self.data))) 134 | 135 | def __len__(self): 136 | return len(self.data) 137 | 138 | def __getitem__(self, index: int): 139 | 140 | """ 141 | Args: 142 | index : index of element in self.data to return as data instance 143 | 144 | Returns: 145 | dictionary containing inputs and targets 
for model to do VQA 146 | 147 | """ 148 | 149 | example = self.data[index] 150 | question_id = example['question_id'] 151 | 152 | # Tokenize the input question 153 | question = example['question'] 154 | input_ids = example['question_input_ids'] 155 | 156 | # Get the image tensor from ImageDataset 157 | image_id = example['image_id'] 158 | image = self.images_dataset.get_image_data(image_id) 159 | 160 | labels = example['labels'] 161 | scores = example['scores'] 162 | target_scores = target_tensor(self.num_labels, labels, scores) 163 | 164 | return {'question': question, 165 | 'input_ids': input_ids, 166 | 'image': image, 167 | 'labels': labels, 168 | 'target_scores': target_scores, 169 | 'question_id': question_id 170 | } 171 | 172 | def convert_to_low_shot(self, low_shot_percentage: float): 173 | """ 174 | Args: 175 | low_shot_percentage: float between 0 and 1, telling what % of full data to retain for low-shot setting 176 | """ 177 | 178 | logger.info("Converting VQA train split into low-shot dataset, with {:.2f}% training samples...".format(low_shot_percentage*100.0)) 179 | n_low_shot_examples = int(low_shot_percentage*self.n_examples) 180 | 181 | new_data = random.Random(1).sample(self.data, n_low_shot_examples) 182 | self.data = new_data 183 | self.n_examples = len(self.data) 184 | 185 | logger.info("Converted into low-shot dataset, with {} examples".format(self.n_examples)) 186 | 187 | def vqa_batch_collate(batch: List[Dict], 188 | visual_input_type: str): 189 | 190 | """ 191 | Collates each model input for all batch items into a single model input (e.g. converts a list of input_ids into a matrix of size (batch_size, max_len)) 192 | 193 | Args: 194 | batch - list of batch items, each item being a dictionary returned by Dataset's __getitem__ method 195 | visual_input_type: string which specifies the type of visual input 196 | 197 | Returns: 198 | Dictionary containing batched inputs and outputs 199 | """ 200 | 201 | pad_token = 0 # tokenizer.pad_token_id 202 | 203 | # Pad the text inputs 204 | questions = [x['question'] for x in batch] 205 | input_ids = [x['input_ids'] for x in batch] 206 | max_len = max([len(x) for x in input_ids]) 207 | input_ids_padded = [] 208 | attn_masks = [] 209 | for i in range(len(input_ids)): 210 | ids_padded = input_ids[i] + [pad_token]*(max_len - len(input_ids[i])) 211 | attn_mask = [1]*len(input_ids[i]) + [0]*(max_len - len(input_ids[i])) 212 | 213 | input_ids_padded.append(ids_padded) 214 | attn_masks.append(attn_mask) 215 | input_ids = torch.tensor(input_ids_padded, dtype=torch.long) 216 | attn_mask = torch.tensor(attn_masks, dtype=torch.long) 217 | 218 | # Stack the target tensors 219 | batch_labels = [x['labels'] for x in batch] 220 | batch_scores = [x['target_scores'] for x in batch] 221 | batch_scores = torch.stack(batch_scores, dim=0) 222 | 223 | # Depending on the visual_input_type variable, process the images accordingly 224 | images = [x['image'] for x in batch] 225 | images = image_collate(images, visual_input_type) 226 | 227 | return {'raw_texts': questions, 228 | 'input_ids': input_ids, 229 | 'attn_mask': attn_mask, 230 | 'images': images, 231 | 'target_scores': batch_scores, 232 | 'labels': batch_labels} 233 | 234 | def build_vqa_dataloader(args, 235 | data_dir: str, 236 | images_dataset: MSCOCOImagesDataset, 237 | split: str, 238 | visual_input_type: str, 239 | **kwargs) -> torch.utils.data.DataLoader: 240 | 241 | """ 242 | Creates the VQA Dataloader, which gives batches of VQA inputs and outputs 243 | 244 | Args: 245 | data_dir : path 
containing VQA questions and annotations. 246 | images_dataset : instance of MSCOCOImagesDataset, that is used to retrieve the MS-COCO image for each question 247 | split: either train/val split 248 | visual_input_type: format of visual input to model 249 | 250 | Returns: 251 | DataLoader object 252 | """ 253 | 254 | batch_size = args.batch_size 255 | shuffle = True if split == 'train' else False 256 | 257 | logger.info("Creating VQAv2 {} dataloader with batch size of {}".format(split, batch_size)) 258 | 259 | dataset = VQADataset(data_dir, images_dataset, split, **kwargs) 260 | num_labels = dataset.num_labels 261 | dataloader = torch.utils.data.DataLoader( 262 | dataset, 263 | num_workers=args.num_workers, 264 | batch_size=batch_size, 265 | shuffle=shuffle, 266 | collate_fn=lambda x: vqa_batch_collate(x, visual_input_type)) 267 | return dataloader 268 | 269 | if __name__ == '__main__': 270 | data_dir = '/data/datasets/MCL/vqav2/' 271 | #dataset = VQADataset(data_dir, None, 'train', None) 272 | class Args: 273 | def __init__(self): 274 | self.batch_size = 4 275 | self.shuffle = True 276 | self.num_workers = 2 277 | self.visual_input_type = 'pil-image' 278 | args = Args() 279 | 280 | from transformers import BertTokenizer 281 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 282 | 283 | images_dataset = MSCOCOImagesDataset('/data/datasets/MCL/ms-coco/', args.visual_input_type) 284 | vqa_dataloader = build_vqa_dataloader(args, data_dir, images_dataset, 'val', args.visual_input_type, tokenizer=tokenizer) 285 | 286 | for batch in vqa_dataloader: 287 | pdb.set_trace() 288 | -------------------------------------------------------------------------------- /src/train/train_vision.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import datetime 4 | import json 5 | import logging 6 | import os 7 | import random 8 | import sys 9 | sys.path.insert(0, '.') 10 | import time 11 | import math 12 | import shutil 13 | import pickle as pkl 14 | from PIL import Image 15 | import copy 16 | import pdb 17 | from tqdm import tqdm 18 | 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | from sklearn.metrics import f1_score 23 | from torch.optim import AdamW 24 | from transformers import get_polynomial_decay_schedule_with_warmup 25 | from transformers import BertTokenizer 26 | 27 | from modeling import load_encoder_map 28 | from configs.model_configs import model_configs 29 | from configs.task_configs import task_configs 30 | from utils.seed_utils import set_seed 31 | 32 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 33 | 34 | logging.basicConfig() 35 | logger = logging.getLogger(__name__) 36 | logger.setLevel(logging.DEBUG) 37 | 38 | def train_vision(args, encoder, task_config, model_config, tokenizer, device): 39 | ''' 40 | Train the encoder on the vision-only downstream tasks 41 | 42 | args: arguments provided by user 43 | encoder: trained encoder on upstream tasks 44 | task_config: dictionary containing task-specific configuration parameters for the current task 45 | model_config: dictionary containing model-specific configuration parameters 46 | tokenizer: tokenzier of the backbone encoder 47 | device: cuda/cpu 48 | ''' 49 | 50 | # get upstream algo name for logging 51 | upstream_name = args.checkpoint_name.split('/')[-2] 52 | for short in ['adapter', 'ewc', 'replay', 'sequent', 'bottom9']: 53 | if short in args.checkpoint_name: 54 | upstream_name += f"_{short}" 55 | break 
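    # Illustrative example (hypothetical path): checkpoint_name='./logs/vilt-ewc/model.pt' gives upstream_name='vilt-ewc' via split('/')[-2], and the 'ewc' substring match then appends '_ewc', so 'vilt-ewc_ewc' is logged below.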
56 |     logger.info(f"Upstream Task: {upstream_name}")
57 | 
58 |     # config
59 |     task_name = task_config['task_name']
60 |     num_labels = task_config['num_labels']
61 |     data_dir = task_config['data_dir']
62 |     # coco-obj-cls (multi-label) uses percentage for low-shot learning; other single-label tasks: N-shot per class
63 |     n_shot = args.num_shot if args.task_name == 'coco-cls' else int(args.num_shot)
64 |     subsample_seed = args.subsample_seed
65 |     output_dir = args.output_dir
66 | 
67 |     # Create model
68 |     batch2inputs_converter = model_config['batch2inputs_converter']
69 |     encoder_dim = model_config['encoder_dim']
70 |     visual_input_type = model_config['visual_input_type']
71 |     classifier_class = model_config['classifier_class']
72 |     model = classifier_class(encoder=encoder,
73 |                              encoder_dim=encoder_dim,
74 |                              num_labels=num_labels)
75 | 
76 |     model.to(device)
77 | 
78 |     # Create dataloaders for training, validation, and test sets
79 |     if args.task_name == 'imagenet':
80 |         from data.vision_datasets.imagenet_dataset import get_data_loader
81 |     elif args.task_name == 'places365':
82 |         from data.vision_datasets.places365_dataset import get_data_loader
83 |     elif args.task_name == 'inat2019':
84 |         from data.vision_datasets.inat2019_dataset import get_data_loader
85 |     elif args.task_name == 'coco-cls':
86 |         from data.vision_datasets.coco_cls_dataset import get_data_loader
87 |     else:
88 |         raise NotImplementedError("get_data_loader not implemented for this task!")
89 | 
90 |     eval_fn = eval_coco if args.task_name == 'coco-cls' else eval_acc
91 | 
92 |     train_dataloader = get_data_loader(
93 |         args,
94 |         data_dir,
95 |         'train',
96 |         n_shot,
97 |         subsample_seed
98 |     )
99 |     val_dataloader = get_data_loader(
100 |         args,
101 |         data_dir,
102 |         'val',
103 |         n_shot
104 |     )
105 |     test_dataloader = get_data_loader(
106 |         args,
107 |         data_dir,
108 |         'test'
109 |     )
110 | 
111 |     # Training hyperparameters
112 |     num_epochs = task_config['num_epochs']
113 |     lr = task_config['lr']
114 |     adam_epsilon = task_config['adam_epsilon']
115 |     weight_decay = task_config['weight_decay']
116 |     warmup_ratio = task_config['warmup_ratio']
117 | 
118 |     # Create optimizer
119 |     if args.task_name == 'coco-cls':
120 |         loss_criterion = nn.BCEWithLogitsLoss()
121 |     else:
122 |         loss_criterion = nn.CrossEntropyLoss()
123 | 
124 |     no_decay = ['bias', 'LayerNorm.weight']
125 |     optimizer_grouped_parameters = [
126 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
127 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
128 |     ]
129 |     # https://github.com/dandelin/ViLT/blob/master/vilt/modules/vilt_utils.py#L236
130 |     optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=adam_epsilon, betas=(0.9, 0.98))
131 |     # Create Scheduler
132 |     # https://github.com/dandelin/ViLT/blob/master/vilt/modules/vilt_utils.py#L263
133 |     max_steps = len(train_dataloader) * num_epochs
134 |     scheduler = get_polynomial_decay_schedule_with_warmup(
135 |         optimizer,
136 |         num_warmup_steps=int(max_steps * warmup_ratio),
137 |         num_training_steps=max_steps,
138 |         lr_end=0,
139 |         power=1,
140 |     )
141 | 
142 | 
143 |     best_score = 0
144 |     model.zero_grad()
145 |     model.train()
146 |     for epoch in range(1, num_epochs+1):
147 |         for step, batch in enumerate(tqdm(train_dataloader, desc='Training epoch {}'.format(epoch))):
148 |             labels = batch['labels'].to(device)
149 |             inputs = batch2inputs_converter(batch)
150 | 
151 |             logits = model(**inputs)
152 |             loss = loss_criterion(logits, labels)
153 | 
154 |             loss.backward()
155 |             optimizer.step()
156 |             scheduler.step()
157 |             optimizer.zero_grad()
158 | 
159 |             if step % 50 == 0:
160 |                 print('loss:', loss.item())
161 | 
162 |         # Eval on the val set and update the best model
163 |         if epoch > 5 and epoch%2 == 0:
164 |             eval_score = eval_fn(args, model, val_dataloader, device, batch2inputs_converter)
165 |             logger.info("Evaluation after epoch {}: {:.2f}".format(epoch, eval_score))
166 | 
167 |             if eval_score > best_score:
168 |                 logger.info("New best evaluation score: {:.2f}".format(eval_score))
169 |                 best_score = eval_score
170 |                 best_epoch = epoch
171 |                 best_model = copy.deepcopy(model)
172 | 
173 |     # eval the best model on the test set & write the results
174 |     test_score = eval_fn(args, best_model, test_dataloader, device, batch2inputs_converter)
175 | 
176 | def eval_coco(args, model, eval_dataloader, device, batch2inputs_converter):
177 |     '''
178 |     Evaluation for the MS-COCO object classification task, using F1-micro as the metric
179 | 
180 |     model: the trained model
181 |     eval_dataloader: dataloader for evaluation
182 |     device: cuda/cpu
183 |     batch2inputs_converter: a model-specific function that converts inputs to the format the model takes
184 |     '''
185 | 
186 |     model.eval()
187 |     act_fn = nn.Sigmoid()
188 | 
189 |     all_labels = torch.zeros((len(eval_dataloader.dataset), 80), dtype=torch.bool)  # 80 MS-COCO object categories
190 |     all_preds = all_labels.clone()
191 |     offset = 0
192 |     for step, batch in enumerate(tqdm(eval_dataloader, desc='Evaluating...')):
193 |         labels = batch['labels']
194 |         inputs = batch2inputs_converter(batch)
195 |         with torch.no_grad():
196 |             logits = model(**inputs)
197 |         preds = act_fn(logits) > 0.5
198 | 
199 |         all_labels[offset: offset+len(labels)] = labels.bool().cpu()
200 |         all_preds[offset: offset+len(labels)] = preds.cpu()
201 | 
202 |         offset += len(labels)
203 | 
204 |     f1 = f1_score(all_labels, all_preds, average='micro')*100.0
205 |     logger.info(f'Eval_F1: {f1:.3f}')
206 | 
207 |     model.train()
208 |     return f1
209 | 
210 | 
211 | def eval_acc(args, model, eval_dataloader, device, batch2inputs_converter):
212 |     '''
213 |     Evaluation on the dev set and test set, using accuracy as the metric
214 | 
215 |     model: the trained model
216 |     eval_dataloader: dataloader for evaluation
217 |     device: cuda/cpu
218 |     batch2inputs_converter: a model-specific function that converts inputs to the format the model takes
219 |     '''
220 | 
221 |     model.eval()
222 |     eval_score = 0
223 |     for step, batch in enumerate(tqdm(eval_dataloader, desc='Evaluating...')):
224 |         labels = batch['labels']
225 |         inputs = batch2inputs_converter(batch)
226 |         with torch.no_grad():
227 |             logits = model(**inputs)
228 | 
229 |         batch_scores = (logits.argmax(-1).cpu() == labels)
230 |         eval_score += batch_scores.sum().item()
231 | 
232 |     eval_score = eval_score/len(eval_dataloader.dataset)*100.0
233 |     logger.info(f'Eval_acc: {eval_score:.3f}')
234 | 
235 |     model.train()
236 |     return eval_score
237 | 
238 | 
239 | 
240 | def main():
241 | 
242 |     parser = argparse.ArgumentParser()
243 | 
244 |     ## Required parameters
245 |     parser.add_argument("--task_name", default=None, type=str, required=True,
246 |                         help="The name of the vision-only task.")
247 |     parser.add_argument("--encoder_name", default=None, type=str, required=True, choices=['vilt'],
248 |                         help="The name of the base pretrained encoder.")
249 |     parser.add_argument("--model_catog", default='vilt-vl', type=str,
250 |                         help="The category for model class.")
251 |     parser.add_argument("--checkpoint_name", default=None, type=str,
required=True, 252 | help="Name of the checkpoint model load.") 253 | parser.add_argument("--pretrained_model_name", default="dandelin/vilt-b32-mlm", type=str, 254 | help="Name of the pretrained model") 255 | parser.add_argument("--output_dir", type=str, required=True, 256 | help="Name of output directory, where all experiment results and checkpoints are saved.") 257 | 258 | parser.add_argument("--batch_size", type=int, default=32, 259 | help="Batch size.") 260 | parser.add_argument("--num_workers", type=int, default=2, 261 | help="Number of workers for dataloader") 262 | parser.add_argument("--seed", type=int, default=42, 263 | help="Random seed.") 264 | 265 | # only used by few-shot downstream tasks 266 | parser.add_argument("--num_shot", type=float, 267 | help="Number of training data per class OR the ratio of the original training set") 268 | parser.add_argument("--subsample_seed", type=int, 269 | help="Random seed for few-shot sampling.") 270 | 271 | 272 | args = parser.parse_args() 273 | print(args) 274 | 275 | if not os.path.exists(args.output_dir): 276 | os.makedirs(args.output_dir) 277 | 278 | device = torch.device( 279 | "cuda" if torch.cuda.is_available() else "cpu") 280 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 281 | 282 | set_seed(args) 283 | 284 | 285 | # Load the Encoder model 286 | model_config = model_configs[args.model_catog] 287 | load_encoder_method = load_encoder_map[args.encoder_name] 288 | encoder = load_encoder_method(args.checkpoint_name, device, args.pretrained_model_name) 289 | 290 | results = [] 291 | logger.info("-"*100) 292 | logger.info("Training models on downstream vision-only tasks...") 293 | 294 | # Load the correct training method for current CL task, and call the training method 295 | task_config = task_configs[args.task_name] 296 | logger.info("-"*100) 297 | train_vision(args, encoder, task_config, model_config, tokenizer, device) 298 | 299 | if __name__ == '__main__': 300 | main() 301 | -------------------------------------------------------------------------------- /src/modeling/models/albef_model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from src.modeling.models.vit import VisionTransformer 3 | from src.modeling.models.xbert import BertConfig, BertModel, BertLMHeadModel 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | 11 | 12 | class ALBEF(nn.Module): 13 | def __init__(self, 14 | text_encoder=None, 15 | text_decoder=None, 16 | tokenizer=None, 17 | config=None, 18 | ): 19 | super().__init__() 20 | 21 | self.tokenizer = tokenizer 22 | self.distill = config["distill"] 23 | 24 | self.visual_encoder = VisionTransformer( 25 | img_size=config["image_res"], patch_size=16, embed_dim=768, depth=12, num_heads=12, 26 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 27 | adapter_config=config["adapter_config"] if 'adapter_config' in config.keys() else None 28 | ) 29 | 30 | config_encoder = BertConfig(**config["bert_config"]) 31 | config_decoder = BertConfig(**config["bert_config"]) 32 | config_decoder.fusion_layer = 0 33 | config_decoder.num_hidden_layers = 6 34 | 35 | if 'adapter_config' in config.keys(): 36 | config_encoder.adapter_config = config["adapter_config"] 37 | config_decoder.adapter_config = config["adapter_config"] 38 | 39 | # self.text_encoder = BertModel.from_pretrained(text_encoder,force_download=True, config=config_encoder, add_pooling_layer=False) 40 | 
BERT_LOCAL_PATH = "/home/stud/yyang/CARVEN/bert-base-uncased" 41 | self.text_encoder = BertModel.from_pretrained(BERT_LOCAL_PATH, local_files_only=True, config=config_encoder, add_pooling_layer=False) 42 | self.text_decoder = BertLMHeadModel.from_pretrained(BERT_LOCAL_PATH, local_files_only=True, config=config_decoder) 43 | 44 | if self.distill: 45 | self.visual_encoder_m = VisionTransformer( 46 | img_size=config["image_res"], patch_size=16, embed_dim=768, depth=12, num_heads=12, 47 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 48 | adapter_config=config["adapter_config"] if 'adapter_config' in config.keys() else None) 49 | ##online model changed to local model 50 | self.text_encoder_m = BertModel.from_pretrained(BERT_LOCAL_PATH, local_files_only=True, config=config_encoder, add_pooling_layer=False) 51 | self.text_decoder_m = BertLMHeadModel.from_pretrained(BERT_LOCAL_PATH, local_files_only=True, config=config_decoder) 52 | self.model_pairs = [[self.visual_encoder, self.visual_encoder_m], 53 | [self.text_encoder, self.text_encoder_m], 54 | [self.text_decoder, self.text_decoder_m], 55 | ] 56 | self.copy_params() 57 | self.momentum = 0.995 58 | 59 | def set_active_gating(self): 60 | for i in range(len(self.text_encoder.encoder.layer)): 61 | self.text_encoder.encoder.layer[i].output.adapter.gating_activated = True 62 | 63 | for i in range(len(self.text_decoder.bert.encoder.layer)): 64 | self.text_decoder.bert.encoder.layer[i].output.adapter.gating_activated = True 65 | 66 | for i in range(len(self.visual_encoder.blocks)): 67 | self.visual_encoder.blocks[i].adapter.gating_activated = True 68 | 69 | def forward(self, image, question, answer=None, alpha=0, k=None, weights=None, train=True, prev_f=None): 70 | 71 | image_embeds = self.visual_encoder(image) 72 | if prev_f is not None: 73 | image_embeds += self.pnn_layer_visual(prev_f[0]) 74 | 75 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) 76 | 77 | if train: 78 | """ 79 | k: number of answers for each question 80 | weights: weight for each answer 81 | """ 82 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 83 | 84 | question_output = self.text_encoder(question.input_ids, 85 | attention_mask=question.attention_mask, 86 | encoder_hidden_states=image_embeds, 87 | encoder_attention_mask=image_atts, 88 | return_dict=True) 89 | if prev_f is not None: 90 | text_embeds = self.pnn_layer_text(prev_f[1]) 91 | 92 | question_states = [] 93 | question_atts = [] 94 | for b, n in enumerate(k): 95 | question_states += [question_output.last_hidden_state[b]] * n 96 | question_atts += [question.attention_mask[b]] * n 97 | question_states = torch.stack(question_states, 0) 98 | question_atts = torch.stack(question_atts, 0) 99 | 100 | if self.distill: 101 | with torch.no_grad(): 102 | # to do 103 | self._momentum_update() 104 | image_embeds_m = self.visual_encoder_m(image) 105 | question_output_m = self.text_encoder_m(question.input_ids, 106 | attention_mask=question.attention_mask, 107 | encoder_hidden_states=image_embeds_m, 108 | encoder_attention_mask=image_atts, 109 | return_dict=True) 110 | 111 | question_states_m = [] 112 | for b, n in enumerate(k): 113 | question_states_m += [question_output_m.last_hidden_state[b]] * n 114 | question_states_m = torch.stack(question_states_m, 0) 115 | 116 | logits_m = self.text_decoder_m(answer.input_ids, 117 | attention_mask=answer.attention_mask, 118 | encoder_hidden_states=question_states_m, 119 | 
encoder_attention_mask=question_atts, 120 | return_logits=True, 121 | ) 122 | 123 | answer_output = self.text_decoder(answer.input_ids, 124 | attention_mask=answer.attention_mask, 125 | encoder_hidden_states=question_states, 126 | encoder_attention_mask=question_atts, 127 | labels=answer_targets, 128 | return_dict=True, 129 | soft_labels=F.softmax(logits_m, dim=-1), 130 | alpha=alpha, 131 | reduction="none", 132 | ) 133 | else: 134 | answer_output = self.text_decoder(answer.input_ids, 135 | attention_mask=answer.attention_mask, 136 | encoder_hidden_states=question_states, 137 | encoder_attention_mask=question_atts, 138 | labels=answer_targets, 139 | return_dict=True, 140 | reduction="none", 141 | ) 142 | loss = weights * answer_output.loss 143 | loss = loss.sum() / image.size(0) 144 | 145 | return (loss, answer_output.logits[:, :-1, :].contiguous()) # logits: (batch, words, vocab_size(30522)) 146 | 147 | 148 | else: 149 | question_output = self.text_encoder(question.input_ids, # tokenized question 150 | attention_mask=question.attention_mask, 151 | encoder_hidden_states=image_embeds, 152 | encoder_attention_mask=image_atts, 153 | return_dict=True) # last_hidden_state: (batch, words, 768) 154 | topk_ids, topk_probs = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 155 | answer.input_ids, answer.attention_mask, k) # answer.input_ids: [num_answers, max_len]; k=128 156 | return topk_ids, topk_probs 157 | 158 | @torch.no_grad() 159 | def copy_params(self): 160 | for model_pair in self.model_pairs: 161 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 162 | param_m.data.copy_(param.data) # initialize 163 | param_m.requires_grad = False # not update by gradient 164 | 165 | @torch.no_grad() 166 | def _momentum_update(self): 167 | for model_pair in self.model_pairs: 168 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 169 | param_m.data = param_m.data * self.momentum + param.data * (1.0 - self.momentum) 170 | 171 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 172 | # question_states: last_hidden_state of Multimodel Encoder; answer_ids: tokenized answers 173 | num_ques = question_states.size(0) 174 | start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token 175 | 176 | start_output = self.text_decoder(start_ids, 177 | encoder_hidden_states=question_states, 178 | encoder_attention_mask=question_atts, 179 | return_dict=True, 180 | reduction="none") # logits: (batch, word, 30522), word here is bos token 181 | logits = start_output.logits[:, 0, :] # first token's logit, (batch, 30522) 182 | 183 | # topk_probs: top-k probability 184 | # topk_ids: [num_question, k] 185 | answer_first_token = answer_ids[:, 1] # [num_answers,] 186 | prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token) # (batch, num_answers) 187 | topk_probs, topk_ids = prob_first_token.topk(k, dim=1) 188 | 189 | # answer input: [num_question*k, answer_len] 190 | input_ids = [] 191 | input_atts = [] 192 | for b, topk_id in enumerate(topk_ids): 193 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 194 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 195 | input_ids = torch.cat(input_ids, dim=0) # (num_question*k, answer_len) 196 | input_atts = torch.cat(input_atts, dim=0) 197 | 198 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 199 | 200 | # repeat encoder's output for top-k answers 201 | 
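        # Each question's encoder states are tiled k times so the decoder can score all k candidate answers in one batched pass; the per-candidate LM loss below is combined with the first-token probability and re-normalized to re-rank the top-k answers.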
question_states = tile(question_states, 0, k) # (num_question*k, words, 768) 202 | question_atts = tile(question_atts, 0, k) 203 | 204 | output = self.text_decoder(input_ids, # logits(num_question*k, answer_len, 30522) 205 | attention_mask=input_atts, 206 | encoder_hidden_states=question_states, 207 | encoder_attention_mask=question_atts, 208 | labels=targets_ids, 209 | return_dict=True, 210 | reduction="none") 211 | 212 | answer_loss = output.loss 213 | answer_loss = answer_loss.view(input_ids.size(0), -1) # (num_question*k, 1) 214 | 215 | # topk_prob: first token probability 216 | topk_probs = topk_probs.view(-1, 1) # (num_question*k, 1) 217 | log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1) # (num_question*k, 2) 218 | 219 | # re-calculate log probabilities for the answer sequences using chain rule 220 | log_probs_sum = log_probs.sum(1) 221 | log_probs_sum = log_probs_sum.view(num_ques, k) # (num_question, k) 222 | 223 | topk_probs = F.softmax(log_probs_sum, dim=-1) # (num_question, k) 224 | # get top-k after re-ranking 225 | topk_probs, rerank_id = topk_probs.topk(k, dim=1) # (num_question, k) 226 | topk_ids = torch.gather(topk_ids, 1, rerank_id) # (num_question, k) 227 | 228 | return topk_ids, topk_probs 229 | 230 | 231 | def tile(x, dim, n_tile): 232 | init_dim = x.size(dim) 233 | repeat_idx = [1] * x.dim() 234 | repeat_idx[dim] = n_tile 235 | x = x.repeat(*(repeat_idx)) 236 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 237 | return torch.index_select(x, dim, order_index.to(x.device)) 238 | -------------------------------------------------------------------------------- /src/train/train_lowshot_multimodal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import time 9 | import math 10 | import shutil 11 | import pickle as pkl 12 | import copy 13 | import yaml 14 | import pdb 15 | 16 | sys.path.insert(0, '.') 17 | 18 | import numpy as np 19 | import torch 20 | from tqdm import tqdm 21 | 22 | from transformers.adapters import AdapterConfig 23 | 24 | from modeling import load_encoder_map, create_continual_learner_map 25 | 26 | from cl_algorithms import ExperienceReplayMemory, EWC 27 | from cl_evaluation.evaluate_cl_algorithm import upstream_knowledge_transfer_eval, catastrophic_forgetting_eval 28 | from configs.model_configs import model_configs, ALLOWED_CL_ENCODERS 29 | from configs.task_configs import task_configs, SUPPORTED_VL_TASKS 30 | from configs.adapter_configs import ADAPTER_MAP 31 | from utils.seed_utils import set_seed 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | device = torch.device( 36 | "cuda" if torch.cuda.is_available() else "cpu") 37 | #device = torch.device("cpu") 38 | 39 | def train_low_shot(args, low_shot_model, low_shot_task_key, model_config, device): 40 | 41 | low_shot_task_name = task_configs[low_shot_task_key]['task_name'] 42 | low_shot_config = task_configs[low_shot_task_key]['low_shot_config'] 43 | 44 | # Create the Trainer method for the current CL task, and call the train method 45 | logger.info("Training {} model on low-shot task {}, low_shot_config={}".format(args.encoder_name, 46 | low_shot_task_name, 47 | low_shot_config)) 48 | task_trainer_class = low_shot_config['task_trainer'] 49 | task_trainer = task_trainer_class(args, task_configs, model_config, device, low_shot_config=low_shot_config) 50 | 
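# The trainer class comes from the task's low_shot_config; it is assumed to expose a train() method returning (best_eval_score, best_model), of which only the score is returned to the caller together with the low_shot_config that produced it.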
best_eval_score, best_model = task_trainer.train(low_shot_model) 51 | 52 | return best_eval_score, low_shot_config 53 | 54 | def main(): 55 | 56 | parser = argparse.ArgumentParser() 57 | 58 | ## Required parameters 59 | parser.add_argument("--encoder_name", default=None, type=str, required=True, choices=ALLOWED_CL_ENCODERS, 60 | help="The name of the base pretrained encoder.") 61 | parser.add_argument("--pretrained_model_name", default=None, type=str, required=True, 62 | help="Name of pretrained model weights to load.") 63 | parser.add_argument("--ordered_cl_tasks", type=str, required=True, 64 | help="Ordered list of VL task keys for continual learning, seprated by commas.") 65 | parser.add_argument("--cl_algorithm", type=str, required=True, choices=['singletask_ft', 66 | 'sequential_ft', 67 | 'experience_replay', 68 | 'ewc', 69 | 'adapter', 70 | 'freeze_encoder', 71 | 'freeze_bottom_k_layers'], 72 | help="Name of Continual Learning algorithm used.") 73 | parser.add_argument("--climb_data_dir", type=str, required=True, default='/data/datasets/MCL/', 74 | help="Directory where all the MCL data is stored") 75 | 76 | # Arguments specific to experience replay algorithm 77 | parser.add_argument("--memory_percentage", type=float, default=0.0, 78 | help="Percentage of tasks' training samples saved into memory.") 79 | parser.add_argument("--memory_sampling_strategy", type=str, choices=['random', 'random-balanced'], 80 | help="Strategy for sampling memory buffer samples.") 81 | parser.add_argument("--replay_frequency", type=int, 82 | help="Number of training steps after which to do a memory replay step.") 83 | 84 | # Arguments specific to Adapters algorithm 85 | parser.add_argument("--adapter_config", choices=list(ADAPTER_MAP.keys()), 86 | help="Type of Adapter architecture") 87 | parser.add_argument("--adapter_reduction_factor", type=int, default=0, 88 | help="Downsampling ratio for adapter layers") 89 | 90 | # Arguments specific to EWC algorithm 91 | parser.add_argument("--ewc_fisher_sample_percentage", type=float, default=0.0, 92 | help="Percentage of training samples for computing Fisher information matrix per task") 93 | parser.add_argument("--ewc_loss_weight", type=float, default=0.0, 94 | help="Factoring for scaling the EWC loss") 95 | 96 | # Arguments specific to frozen bottom-k layers algorithm 97 | parser.add_argument("--layers_to_freeze", type=int, default=0, 98 | help="Number of layers to freeze (if freezing bottom-k layers)") 99 | 100 | parser.add_argument("--output_dir", type=str, required=True, 101 | help="Name of output directory, where all experiment results and checkpoints are saved.") 102 | 103 | parser.add_argument("--batch_size", type=int, default=32, 104 | help="Batch size.") 105 | parser.add_argument("--num_workers", type=int, default=2, 106 | help="Number of workers for dataloader") 107 | parser.add_argument("--seed", type=int, default=42, 108 | help="Random seed.") 109 | 110 | args = parser.parse_args() 111 | args.ordered_cl_tasks = args.ordered_cl_tasks.split(',') 112 | 113 | # --------------------- Set up experiment directories 114 | experiment_name = '{}-{}'.format(args.encoder_name, args.cl_algorithm) 115 | if args.cl_algorithm == 'adapter': 116 | experiment_name = '{}_{}'.format(experiment_name, args.adapter_config) 117 | elif args.cl_algorithm == 'freeze_bottom_k_layers': 118 | experiment_name = experiment_name.replace('_k_layers', '{}layers'.format(args.layers_to_freeze)) 119 | for i, task_key in enumerate(args.ordered_cl_tasks): 120 | experiment_name = 
'{}-task{}_{}'.format(experiment_name, i, task_key) 121 | output_dir = os.path.join(args.output_dir, experiment_name) 122 | results_file = os.path.join(output_dir, 'lowshot_results.json') 123 | if not os.path.isdir(output_dir): 124 | os.makedirs(output_dir) 125 | 126 | set_seed(args.seed)  # seed_utils.set_seed expects the integer seed, not the full args namespace 127 | 128 | # --------------------- Ensure all the tasks for continual learning are supported VL tasks --------------------- 129 | for task_key in args.ordered_cl_tasks: 130 | assert task_key in SUPPORTED_VL_TASKS 131 | 132 | # --------------------- Load the correct ContinualLearner model, based on encoder_name argument --------------------- 133 | model_config = model_configs[args.encoder_name] 134 | create_model_method = create_continual_learner_map[args.encoder_name] 135 | model = create_model_method(model_name_or_path=args.pretrained_model_name, 136 | ordered_cl_tasks=args.ordered_cl_tasks, 137 | model_config=model_config, 138 | task_configs=task_configs, 139 | device=device) 140 | args.visual_input_type = model_config['visual_input_type'] 141 | 142 | # ------------------------------------------ Print some model info ------------------------------------------ 143 | logger.info("Successfully initialized {}-based Continual Learner".format(model_config['encoder_name'])) 144 | logger.info("{} task heads: {}".format(len(args.ordered_cl_tasks), ','.join(args.ordered_cl_tasks))) 145 | logger.info("CL Algorithm: {}".format(args.cl_algorithm)) 146 | total_params = sum(p.numel() for p in model.parameters()) 147 | logger.info('Total Parameters: {:.2f}M'.format(total_params*10**-6)) 148 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad == True) 149 | logger.info('Trainable Parameters: {:.2f}M ({:.2f}%)'.format(trainable_params*10**-6, (trainable_params/total_params*100))) 150 | logger.info('Model checkpoints loaded from {}'.format(output_dir)) 151 | logger.info("-"*100) 152 | 153 | results = [] 154 | if os.path.exists(results_file): 155 | results = json.load(open(results_file)) 156 | logger.info("-"*100) 157 | logger.info("Cached results:") 158 | for i, r in enumerate(results): 159 | task_key = r.get('task_key', r.get('lowshot_task_key'))  # multi-task entries store 'lowshot_task_key' instead of 'task_key' 160 | best_score = r['best_low_shot_score'] 161 | logger.info("Task #{}: {} - best score = {:.2f}".format(i+1, task_configs[task_key]['task_name'], best_score)) 162 | task_trainers = {} 163 | 164 | logger.info("-"*100) 165 | logger.info("Doing low-shot transfer to Vision-Language tasks...") 166 | 167 | if args.cl_algorithm == 'singletask_ft': 168 | # --------------------- Do low-shot training of pre-trained encoder on a single task --------------------- 169 | task_key = args.ordered_cl_tasks[0] 170 | low_shot_model = copy.deepcopy(model) 171 | low_shot_eval_score, low_shot_config = train_low_shot(args, low_shot_model, task_key, model_config, device) 172 | logger.info("Best {} evaluation score = {:.2f}".format(task_key, low_shot_eval_score)) 173 | 174 | # --------------------- Save low-shot results --------------------- 175 | config_copy = copy.deepcopy(low_shot_config) 176 | config_copy.pop('task_trainer', None) 177 | task_results = { 178 | 'task_key': task_key, 179 | 'best_low_shot_score': low_shot_eval_score, 180 | 'low_shot_config': config_copy, 181 | } 182 | results.append(task_results) 183 | json.dump(results, open(results_file, 'w')) 184 | logger.info("Saved low-shot transfer results so far!") 185 | 186 | else: 187 | # Iterate through each task, load its checkpoint and do low-shot training on all tasks after it 188 | for task_num, task_key in enumerate(args.ordered_cl_tasks): 189 
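# For each upstream task in the CL ordering: locate its checkpoint under output_dir/checkpoints/task{num}_{key}, load it into the continual learner, then run low-shot training on every task that comes after it, appending each result to lowshot_results.json.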
| 190 | # --------------------- Find model checkpoint for this task, load the checkpoint and move onto next CL task --------------------- 191 | logger.info("-"*100) 192 | task_name = task_configs[task_key]['task_name'] 193 | task_output_dir = os.path.join(output_dir, 'checkpoints', 'task{}_{}'.format(task_num, task_key)) 194 | 195 | assert os.path.exists(os.path.join(task_output_dir, 'model')) 196 | logger.info("Found checkpoint for task {}!".format(task_name)) 197 | try: 198 | model.load_state_dict(torch.load(os.path.join(task_output_dir, 'model'))) 199 | except Exception as e: 200 | ckpt_state_dict = torch.load(os.path.join(task_output_dir, 'model')) 201 | initialized = {k: False for k in model.state_dict().keys()} 202 | for k in ckpt_state_dict.keys(): 203 | model.state_dict()[k].copy_(ckpt_state_dict[k]) 204 | initialized[k] = True 205 | logger.info("Uninitialized keys: {}".format(','.join([k for k in initialized.keys() if initialized[k] is False]))) 206 | torch.save(model.state_dict(), os.path.join(task_output_dir, 'model')) 207 | logger.info("Saved model with uninitialized keys as new checkpoint") 208 | logger.info("Loaded model checkpoint from task {}!".format(task_name)) 209 | 210 | # --------------------- Do low-shot training on all tasks after task_num --------------------- 211 | low_shot_tasks = args.ordered_cl_tasks[task_num+1:] 212 | logger.info("Doing low-shot transfer to tasks {} using checkpoint from {}".format(','.join(low_shot_tasks), task_name)) 213 | 214 | for low_shot_task_key in low_shot_tasks: 215 | low_shot_task_num = args.ordered_cl_tasks.index(low_shot_task_key) 216 | low_shot_model = copy.deepcopy(model) 217 | low_shot_eval_score, low_shot_config = train_low_shot(args, low_shot_model, low_shot_task_key, model_config, device) 218 | logger.info("Best {} evaluation score = {:.2f}".format(low_shot_task_key, low_shot_eval_score)) 219 | 220 | 221 | # --------------------- Save low-shot results so far --------------------- 222 | config_copy = copy.deepcopy(low_shot_config) 223 | config_copy.pop('task_trainer', None) 224 | task_results = { 225 | 'upstream_task_num': task_num, 226 | 'upstream_task_key': task_key, 227 | 'lowshot_task_num': low_shot_task_num, 228 | 'lowshot_task_key': low_shot_task_key, 229 | 'best_low_shot_score': low_shot_eval_score, 230 | 'low_shot_config': config_copy, 231 | } 232 | results.append(task_results) 233 | json.dump(results, open(results_file, 'w')) 234 | logger.info("Saved low-shot transfer results so far!") 235 | logger.info("-"*100) 236 | logger.info("-"*100) 237 | 238 | if __name__ == '__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /src/modeling/albef.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import itertools 5 | import pdb 6 | import time 7 | from PIL import Image 8 | from typing import List, Dict 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from transformers import BertConfig, BertTokenizer, BertModel 16 | from src.modeling.models.albef_model import ALBEF 17 | from src.modeling.models.vit import interpolate_pos_embed 18 | from transformers import BertTokenizerFast 19 | from transformers import logging as transformers_logging 20 | 21 | from src.modeling.continual_learner import EncoderWrapper, ContinualLearner 22 | 23 | 24 | class ALBEFWrapper(EncoderWrapper): 25 | 26 | def __init__(self, albef: ALBEF, 
device: torch.device): 27 | """ 28 | Wrapper around ALBEF model from huggingface library 29 | this is the class that gets saved during checkpointing for continual learning 30 | args: 31 | albef - instance of ALBEF class 32 | device - gpu/cuda 33 | """ 34 | 35 | super().__init__() 36 | self.albef = albef 37 | self.device = device 38 | 39 | BERT_LOCAL_PATH = './models/bert-base-uncased' 40 | self.tokenizer = BertTokenizer.from_pretrained(BERT_LOCAL_PATH, local_files_only=True) 41 | 42 | def forward(self, batch) -> torch.FloatTensor: 43 | """ 44 | Does forward pass of input encodings through ALBEF 45 | 46 | Args: 47 | batch: List containing inputs ALBEF's forward pass 48 | 49 | Returns: 50 | loss 51 | """ 52 | images = batch["images"].to(self.device, non_blocking=True) 53 | 54 | if batch["train"]: 55 | weights = batch["weights"].to(self.device, non_blocking=True) 56 | question_input = self.tokenizer(batch["questions"], padding="longest", truncation=True, max_length=25, return_tensors="pt", ).to(self.device) 57 | answer_input = self.tokenizer(batch["answers"], padding="longest", return_tensors="pt").to(self.device) 58 | loss, logits = self.albef(image=images, question=question_input, answer=answer_input, train=True, 59 | alpha=batch["alpha"], k=batch["n"], weights=weights,) 60 | return [loss, logits] 61 | else: 62 | question_input = self.tokenizer(batch["questions"], padding="longest", return_tensors="pt").to(self.device) 63 | answer_list = [answer + "[SEP]" for answer in batch["answer_list"]] 64 | answer_input = self.tokenizer(answer_list, padding="longest", return_tensors="pt").to(self.device) 65 | # todo: double-check: ALBEF pads to longest answer in the answer list, but we pad to 'max_length' because logits used in CL methods need to be of the same shape 66 | # abstract: 8; med: 49; pvqa: 51; art: 5; toronto: 6 67 | # answer_input: dict{inpput_ids: tensor, attention_mask: tensor, token_type_ids: tensor}, all tensors are of shape (num_answers, words) 68 | # if the answers from different dataset need to be of the same length, use following line 69 | # answer_input = self.tokenizer(answer_list, padding='max_length', return_tensors='pt', max_length=51).to(self.device) 70 | 71 | topk_ids, topk_probs = self.albef(image=images, question=question_input, answer=answer_input, train=False, k=batch["k"], ) 72 | return [topk_ids, topk_probs] 73 | 74 | def freeze_all_weights(self): # todo 75 | """ 76 | Freeze all parameters in self.albef 77 | """ 78 | 79 | for p in self.albef.parameters(): 80 | p.requires_grad = False 81 | 82 | def freeze_bottom_k_layers(self, k: int): # todo 83 | """ 84 | Freeze embedding parameters and bottom K transformer layer parameters 85 | """ 86 | 87 | assert k < len(self.albef.encoder.layer) 88 | for p in self.albef.embeddings.parameters(): 89 | p.requires_grad = False 90 | for i in range(k): 91 | for p in self.albef.encoder.layer[i].parameters(): 92 | p.requires_grad = False 93 | 94 | def freeze_encoder(self): # todo 95 | raise NotImplementedError 96 | 97 | 98 | class ALBEFContinualLearner(ContinualLearner): 99 | # I should be ALBEF now 100 | def __init__(self, ordered_cl_tasks: List[str], albef_model: ALBEFWrapper, task_configs: Dict): 101 | """ 102 | The actual Continual Learning model 103 | 104 | arguments: 105 | ordered_cl_tasks - list of CL task keys that will be encountered by the ContinualLearner 106 | albef_model - instance of ALBEFEncoderWrapper class 107 | task_configs - dictionary which contains task-specific configurations/hparams for each task in ordered_cl_tasks, 
not used for now 108 | """ 109 | 110 | super().__init__() 111 | self.albef_model = albef_model 112 | # self.ordered_cl_tasks = ordered_cl_tasks 113 | # self.task_configs = task_configs 114 | # 115 | # self.task_layer_dict = {} # 116 | # for task_key in ordered_cl_tasks: 117 | # self.add_task_layer(task_key, task_configs[task_key]) 118 | # self.task_layer = nn.ModuleDict(self.task_layer_dict) 119 | 120 | def set_active_lora(self): 121 | import loralib as lora 122 | for i in range(len(self.albef_model.albef.text_encoder.encoder.layer)): 123 | in_f, out_f = self.albef_model.albef.text_encoder.encoder.layer[i].attention.self.query.weight.shape[:2] 124 | self.albef_model.albef.text_encoder.encoder.layer[i].attention.self.query = lora.Linear(in_f, out_f, r=16) 125 | self.albef_model.albef.text_encoder.encoder.layer[i].attention.self.value = lora.Linear(in_f, out_f, r=16) 126 | 127 | for i in range(len(self.albef_model.albef.text_decoder.bert.encoder.layer)): 128 | in_f, out_f = self.albef_model.albef.text_decoder.bert.encoder.layer[i].attention.self.query.weight.shape[:2] 129 | self.albef_model.albef.text_decoder.bert.encoder.layer[i].attention.self.query = lora.Linear(in_f, out_f, r=16) 130 | self.albef_model.albef.text_decoder.bert.encoder.layer[i].attention.self.value = lora.Linear(in_f, out_f, r=16) 131 | 132 | from src.modeling.adaptered_output import Attention_lorad 133 | for i in range(len(self.albef_model.albef.visual_encoder.blocks)): 134 | self.albef_model.albef.visual_encoder.blocks[i].attn = Attention_lorad( 135 | self.albef_model.albef.visual_encoder.blocks[i].attn, 136 | 768 137 | ) 138 | 139 | def set_active_adapter(self, name): 140 | for i in range(len(self.albef_model.albef.text_encoder.encoder.layer)): 141 | self.albef_model.albef.text_encoder.encoder.layer[i].output.adapter.set_active_adapter(name) 142 | 143 | for i in range(len(self.albef_model.albef.text_decoder.bert.encoder.layer)): 144 | self.albef_model.albef.text_decoder.bert.encoder.layer[i].output.adapter.set_active_adapter(name) 145 | 146 | for i in range(len(self.albef_model.albef.visual_encoder.blocks)): 147 | self.albef_model.albef.visual_encoder.blocks[i].adapter.set_active_adapter(name) 148 | 149 | def deactivate_gating(self): 150 | for i in range(len(self.albef_model.albef.text_encoder.encoder.layer)): 151 | self.albef_model.albef.text_encoder.encoder.layer[i].output.adapter.deactivate_gating() 152 | 153 | for i in range(len(self.albef_model.albef.text_decoder.bert.encoder.layer)): 154 | self.albef_model.albef.text_decoder.bert.encoder.layer[i].output.adapter.deactivate_gating() 155 | 156 | for i in range(len(self.albef_model.albef.visual_encoder.blocks)): 157 | self.albef_model.albef.visual_encoder.blocks[i].adapter.deactivate_gating() 158 | 159 | def activate_gating(self): 160 | for i in range(len(self.albef_model.albef.text_encoder.encoder.layer)): 161 | self.albef_model.albef.text_encoder.encoder.layer[i].output.adapter.activate_gating() 162 | 163 | for i in range(len(self.albef_model.albef.text_decoder.bert.encoder.layer)): 164 | self.albef_model.albef.text_decoder.bert.encoder.layer[i].output.adapter.activate_gating() 165 | 166 | for i in range(len(self.albef_model.albef.visual_encoder.blocks)): 167 | self.albef_model.albef.visual_encoder.blocks[i].adapter.activate_gating() 168 | 169 | def forward(self, task_key: str, batch: Dict): 170 | """ 171 | Does forward pass of image and text inputs through model, 172 | 173 | Args: 174 | task_key - string which indicates which task to do forward pass for 175 | 176 | 
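batch - dict of model inputs (e.g. as produced by convert_batch_to_albef_input_dict: images, questions, answers, weights, n, alpha), passed straight through to ALBEFWrapper.forward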
Returns: 177 | https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPooling 178 | """ 179 | 180 | # task_config = self.task_configs[task_key] 181 | output = self.albef_model(batch) 182 | return output 183 | 184 | 185 | def load_albef(logger, model_config, checkpoint_name: str, device: torch.device, pretrained_albef_name: str, ) -> ALBEFWrapper: 186 | """ 187 | Method to load ALBEFWrapper, around specified pre-trained albef 188 | 189 | args: 190 | checkpoint_name: name of ALBEF checkpoint to load encoder from 191 | device: torch.device 192 | pretrained_albef_name: pretrained albef name for processor/config 193 | 194 | returns: 195 | albef_model: ALBEFWrapper initialized with checkpoint 196 | """ 197 | logger.info("-" * 100) 198 | logger.info("Loading ALBEF model: {}".format(checkpoint_name)) 199 | 200 | BERT_LOCAL_PATH = './models/bert-base-uncased' 201 | logger.info("Loading tokenizer from: {}".format(BERT_LOCAL_PATH)) 202 | tokenizer = BertTokenizer.from_pretrained(BERT_LOCAL_PATH, local_files_only=True) 203 | 204 | if checkpoint_name == pretrained_albef_name: # load pretrained albef model 205 | logger.info("Loading pretrained ALBEF model: {}".format(pretrained_albef_name)) 206 | model = ALBEF(config=model_config, text_encoder=model_config["text_encoder"], text_decoder=model_config["text_decoder"], tokenizer=tokenizer, ) 207 | 208 | checkpoint = torch.load(pretrained_albef_name, map_location="cpu") 209 | state_dict = checkpoint["model"] 210 | # todo: if load pre-fine-tuned albef, comment out the following two lines 211 | pos_embed_reshaped = interpolate_pos_embed(state_dict["visual_encoder.pos_embed"], model.visual_encoder) 212 | state_dict["visual_encoder.pos_embed"] = pos_embed_reshaped 213 | 214 | # todo: if load pre-fine-tuned albef, comment out the following if 215 | if model_config["distill"]: 216 | m_pos_embed_reshaped = interpolate_pos_embed(state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m) 217 | state_dict["visual_encoder_m.pos_embed"] = m_pos_embed_reshaped 218 | 219 | for key in list(state_dict.keys()): 220 | if "bert" in key: 221 | encoder_key = key.replace("bert.", "") 222 | state_dict[encoder_key] = state_dict[key] 223 | # intialize text decoder as multimodal encoder (last 6 layers of model.text_encoder) 224 | if "text_encoder" in key: 225 | if "layer" in key: 226 | encoder_keys = key.split(".") 227 | layer_num = int(encoder_keys[4]) 228 | if layer_num < 6: 229 | del state_dict[key] 230 | continue 231 | else: 232 | decoder_layer_num = layer_num - 6 233 | encoder_keys[4] = str(decoder_layer_num) 234 | encoder_key = ".".join(encoder_keys) 235 | else: 236 | encoder_key = key 237 | decoder_key = encoder_key.replace("text_encoder", "text_decoder") 238 | state_dict[decoder_key] = state_dict[key] 239 | 240 | del state_dict[key] 241 | model.load_state_dict(state_dict, strict=False) 242 | 243 | albef_model = ALBEFWrapper(model, device) 244 | 245 | else: 246 | raise ValueError("Checkpoint name {} not supported".format(checkpoint_name)) 247 | 248 | logger.info("Successfully loaded pretrained ALBEF") 249 | return albef_model 250 | 251 | 252 | def create_albef_continual_learner_model(logger, model_name_or_path: str, ordered_cl_tasks: List[str], model_config: Dict, task_configs: Dict, device: torch.device, ): 253 | """ 254 | Creates an instance of ALBEFContinualLearner, with the encoder initialized from model_name_or_path 255 | 256 | Args: 257 | model_name_or_path: Name/path of model to load encoder 
checkpoint from 258 | ordered_cl_tasks: List of task_keys to do continual learning on 259 | model_config: Dictionary containing ALBEF model configuration 260 | task_configs: Dictionary containing task-specific configurations for the CL tasks 261 | device: torch.device (cpu or cuda) 262 | 263 | Returns: 264 | cl_model: instance of ALBEFContinualLearner 265 | """ 266 | 267 | albef_model = load_albef(logger, model_config=model_config, checkpoint_name=model_name_or_path, device=device, pretrained_albef_name=model_name_or_path, ) 268 | 269 | cl_model = ALBEFContinualLearner(ordered_cl_tasks=ordered_cl_tasks, albef_model=albef_model, task_configs=task_configs, ) 270 | logger.info("Successfully created and initialized ALBEF Continual Learner model") 271 | 272 | return cl_model 273 | 274 | 275 | def convert_batch_to_albef_input_dict(batch: List): 276 | """ 277 | Convert inputs from the batch collator into the dict 278 | format consumed by ALBEFWrapper.forward 279 | """ 280 | return { 281 | "images": batch[0], 282 | "questions": batch[1], 283 | "answers": batch[2], 284 | "weights": batch[3], 285 | "n": batch[4], 286 | "alpha": batch[5], } 287 | --------------------------------------------------------------------------------
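A minimal usage sketch for convert_batch_to_albef_input_dict (not part of the repository): it shows how a collated VQA batch is assumed to map onto the dict read by ALBEFWrapper.forward. The image resolution, alpha value, and batch contents are illustrative assumptions, and the "train" key is added manually because the converter does not set the "train"/"k"/"answer_list" keys that ALBEFWrapper.forward expects.

import torch
from src.modeling.albef import convert_batch_to_albef_input_dict  # assumes the repo root is on PYTHONPATH

# Illustrative collated batch: the converter indexes it positionally in this order.
collated = (
    torch.zeros(2, 3, 384, 384),                   # images (resolution is an assumption)
    ["what color is the cat?", "how many dogs?"],  # questions
    ["black", "two"],                              # answers (one per question here)
    torch.ones(2),                                 # weights: per-answer loss weights
    [1, 1],                                        # n: number of answers per question
    0.4,                                           # alpha: weight for soft-label distillation (value assumed)
)

batch = convert_batch_to_albef_input_dict(collated)
batch["train"] = True   # read by ALBEFWrapper.forward but not set by the converter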