├── .gitignore ├── README.md ├── advqa_eval.py ├── aokvqa_mc_eval.ipynb ├── artvqa_eval.py ├── cli.py ├── configs ├── aokvqa.yaml ├── aokvqg.yaml ├── artvqa.yaml ├── artvqg.yaml ├── bert_config.json ├── caption_coco.yaml ├── generate_questions_aokvqa.yaml ├── generate_questions_coco.yaml ├── generate_questions_coco_vastai.yaml ├── generate_questions_pathvqa.yaml ├── generate_questions_vg.yaml ├── generic.yaml ├── med_config.json ├── nlvr.yaml ├── nocaps.yaml ├── okvqa.yaml ├── okvqa_lavis.yaml ├── okvqg.yaml ├── pathvqa.yaml ├── pathvqg.yaml ├── pretrain.yaml ├── retrieval_coco.yaml ├── retrieval_flickr.yaml ├── retrieval_msrvtt.yaml ├── rsvqa_lr.yaml ├── vqa.yaml ├── vqa_ablations.yaml ├── vqa_ablations_vastai.yaml ├── vqa_rephrasings.yaml ├── vqg.yaml └── vqg_vastai.yaml ├── convert_advqa.py ├── convert_aokvqa.py ├── convert_aqua.py ├── convert_okvqa.py ├── convert_pathvqa.py ├── convert_rsvqa.py ├── convert_vqa_ce.py ├── convert_vqa_rephrasings.py ├── data ├── __init__.py ├── coco_karpathy_dataset.py ├── flickr30k_dataset.py ├── nlvr_dataset.py ├── nocaps_dataset.py ├── pretrain_dataset.py ├── utils.py ├── video_dataset.py ├── vqa_dataset.py └── vqg_dataset.py ├── environment.yaml ├── examples ├── evaluate.sh ├── generate_synthetic_data.sh ├── self_train_synthetic.sh └── train_teacher.sh ├── generate_questions.py ├── models ├── __init__.py ├── blip.py ├── blip_itm.py ├── blip_nlvr.py ├── blip_pretrain.py ├── blip_retrieval.py ├── blip_vqa.py ├── med.py ├── nlvr_encoder.py └── vit.py ├── okvqa_eval.py ├── pathvqa_eval.py ├── pytest.ini ├── requirements.txt ├── rsvqa_lr_eval.py ├── schemas.py ├── setup.md ├── tests ├── __init__.py ├── conftest.py ├── test_cli.py ├── test_datasets.py ├── test_generate_questions.py └── test_models.py ├── train_vqa.py ├── train_vqg.py ├── transform └── randaugment.py ├── utils.py ├── vqa_ce_eval.py ├── vqa_eval_tools ├── __init__.py ├── vqa.py └── vqa_eval.py ├── vqa_introspect_eval.py ├── vqa_rephrasings_eval.py └── vqav2_eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized files 2 | *.pyc 3 | __pycache__/ 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | dist/ 10 | build/ 11 | *.egg-info/ 12 | *.egg 13 | 14 | # Virtual environments 15 | venv/ 16 | env/ 17 | *.env 18 | 19 | # IDE-specific files 20 | .idea/ 21 | .vscode/ 22 | *.pydevproject 23 | 24 | # Miscellaneous 25 | *.swp 26 | *~ 27 | .DS_Store 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Conference](https://img.shields.io/badge/CVPR-2023-blue)](https://openaccess.thecvf.com/content/CVPR2023/html/Khan_Q_How_To_Specialize_Large_Vision-Language_Models_to_Data-Scarce_VQA_CVPR_2023_paper.html) 2 | [![Paper](http://img.shields.io/badge/paper-arxiv.2306.03932-B31B1B.svg)](https://arxiv.org/abs/2306.03932) 3 | 4 | # SelTDA 5 | This repository will hold the official code of SelTDA, the self-training framework introduced in our CVPR 2023 paper "Q: How to Specialize Large Vision-Language Models to Data-Scarce VQA Tasks? A: Self-Train on Unlabeled Images!". 
6 | 7 | 8 | ![seltda_teaser](https://user-images.githubusercontent.com/4918041/225918833-7d744775-260a-4bc3-a642-7279531b5b07.png) 9 | 10 | ## Environment 11 | ```bash 12 | conda env create -f environment.yaml 13 | ``` 14 | 15 | ## Data 16 | ### Downloads and Preprocessing 17 | - [PathVQA](https://github.com/UCSD-AI4H/PathVQA) 18 | - then use `convert_pathvqa.py` 19 | - [RSVQA](https://rsvqa.sylvainlobry.com/) 20 | - then use `convert_rsvqa.py` 21 | - OK-VQA and A-OKVQA (use [LAVIS](https://github.com/salesforce/LAVIS)) 22 | - LAVIS should automatically put them in the correct format, but if not, you can use `convert_okvqa.py` 23 | - [VQA Counterexamples](https://github.com/cdancette/detect-shortcuts) 24 | - then use `convert_vqa_ce.py` 25 | - [AdVQA](https://adversarialvqa.org/download.html) 26 | - then use `convert_advqa.py` 27 | - [VQA Rephrasings](https://facebookresearch.github.io/VQA-Rephrasings/) 28 | - then use `convert_vqa_rephrasings.py` 29 | 30 | In general, the code expects that each VQA dataset is represented by a single JSON object that is a list of dictionaries. In `schemas.py`, we provide Pydantic models which you can use to define your own datasets or verify that the data is in the correct format. 31 | 32 | ## Experiments 33 | See the `examples/` directory to see examples of: 34 | - training the teacher 35 | - `examples/train_teacher.sh` 36 | - generating synthetic data with the teacher 37 | - `examples/generate_synthetic_data.sh` 38 | - self-training with the synthetic data 39 | - `examples/self_train_synthetic.sh` 40 | - evaluations 41 | - `examples/evaluate.sh` 42 | 43 | ## Citation 44 | ``` 45 | @InProceedings{Khan_2023_CVPR, 46 | author = {Khan, Zaid and BG, Vijay Kumar and Schulter, Samuel and Yu, Xiang and Fu, Yun and Chandraker, Manmohan}, 47 | title = {Q: How To Specialize Large Vision-Language Models to Data-Scarce VQA Tasks? A: Self-Train on Unlabeled Images!}, 48 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 49 | month = {June}, 50 | year = {2023}, 51 | pages = {15005-15015} 52 | } 53 | ``` 54 | 55 | 56 | ## Acknowledgements 57 | This code is heavily based on [salesforce/BLIP](https://github.com/salesforce/BLIP). -------------------------------------------------------------------------------- /advqa_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm.notebook import tqdm 3 | import json 4 | from pprint import PrettyPrinter 5 | from vqa_eval_tools import VQA, VQAEval 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | 10 | pp = PrettyPrinter() 11 | 12 | # The annotations are missing an "question_type" key, so we create a new annotation file which does 13 | # have the key. We just copy the "answer_type" key to "question_type", they are the same thing, I think. 14 | original_annotation_file = ( 15 | "/net/acadia10a/data/zkhan/advqa/v1_mscoco_val2017_advqa_annotations.json" 16 | ) 17 | question_file = ( 18 | "/net/acadia10a/data/zkhan/advqa/v1_OpenEnded_mscoco_val2017_advqa_questions.json" 19 | ) 20 | 21 | # This one doesn't have to exist, we create it from the original annotation file. 22 | modified_annotations_file = ( 23 | "/net/acadia10a/data/zkhan/advqa/nb017_val2017_annotations_w_qtype.json" 24 | ) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = ArgumentParser() 29 | parser.add_argument( 30 | "result_file", help="Path to a JSON result file generated by an evaluation." 
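        # (The result file is expected to be a JSON list of records shaped like
        # {"question_id": ..., "answer": ...}, e.g. the vqa_result.json written into an
        # experiment's result/ directory; question_id is coerced to an int further below.)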
31 | ) 32 | args = parser.parse_args() 33 | 34 | results_file = args.result_file 35 | # results_file = '/net/acadia4a/data/zkhan/mithril/advqa-0-shot-evals/35_blip_vqa_baseline/result/vqa_result.json' 36 | 37 | # The annotations are missing an "question_type" key, so we create a new annotation file which does 38 | # have the key. We just copy the "answer_type" key to "question_type", they are the same thing, I think. 39 | with open(original_annotation_file, "r") as f: 40 | annotations = json.load(f) 41 | 42 | for record in annotations["annotations"]: 43 | record["question_type"] = record["answer_type"] 44 | 45 | with open(modified_annotations_file, "w") as f: 46 | json.dump(annotations, f) 47 | 48 | advqa_obj = VQA( 49 | annotation_file=modified_annotations_file, question_file=question_file 50 | ) 51 | 52 | # We have to convert the question_id field to be an integer >.< 53 | with open(results_file, "r") as f: 54 | predicted = json.load(f) 55 | 56 | for element in predicted: 57 | element["question_id"] = int(element["question_id"]) 58 | 59 | with open(results_file, "w") as f: 60 | json.dump(predicted, f) 61 | 62 | result_obj = advqa_obj.loadRes( 63 | resFile=results_file, 64 | quesFile="/net/acadia10a/data/zkhan/advqa/v1_OpenEnded_mscoco_val2017_advqa_questions.json", 65 | ) 66 | 67 | advqa_eval = VQAEval(advqa_obj, result_obj, n=2) 68 | advqa_eval.evaluate() 69 | print(f"Completed evaluation of {results_file}") 70 | pp.pprint(advqa_eval.accuracy) 71 | with open(Path(results_file).parent / "advqa_eval.json", "w") as f: 72 | json.dump(advqa_eval.accuracy, f) 73 | -------------------------------------------------------------------------------- /artvqa_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | For AQUA (Art VQA), we don't use the VQAv2 evaluation code. 3 | That's because the VQAv2 evaluation code assumes there are 4 | multiple answers for each question, but in AQUA, there's only 5 | one answer for each question. We just do an exact match evaluation 6 | following the AQUA paper. 7 | """ 8 | import json 9 | from unittest import result 10 | from tqdm import tqdm 11 | import json 12 | from pprint import PrettyPrinter 13 | from vqa_eval_tools import VQA, VQAEval 14 | from argparse import ArgumentParser 15 | from pathlib import Path 16 | import schemas 17 | import pandas as pd 18 | 19 | 20 | pp = PrettyPrinter() 21 | 22 | annotation_file = "/net/acadia4a/data/zkhan/vqa-art/test_annotations.json" 23 | question_file = "/net/acadia4a/data/zkhan/vqa-art/test_questions.json" 24 | 25 | 26 | def exact_match_eval(annotation_file, question_file, result_file): 27 | with open(annotation_file, "r") as f: 28 | annotations = json.load(f)["annotations"] 29 | with open(question_file, "r") as f: 30 | questions = json.load(f)["questions"] 31 | with open(result_file, "r") as f: 32 | results = json.load(f) 33 | 34 | annotations = [schemas.VQAAnnotationRecord.parse_obj(a) for a in annotations] 35 | questions = [schemas.QuestionRecord.parse_obj(q) for q in questions] 36 | 37 | annotation_lookup_table = {a.question_id: a for a in annotations} 38 | evaluation_records = [] 39 | for answer_record in results: 40 | ground_truth = annotation_lookup_table[answer_record["question_id"]] 41 | # It's a list, but there's only one answer for each VQA art question. 42 | # So we just take the first one and do an exact match. 
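        # (ground_truth is a schemas.VQAAnnotationRecord; its answers field holds
        # schemas.VQAAnnotationSubRecord entries, and convert_aqua.py writes exactly
        # one such entry per AQUA question.)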
43 | true_answer = ground_truth.answers[0].answer 44 | is_correct = answer_record["answer"] == true_answer 45 | question_type = ground_truth.question_type 46 | evaluation_records.append( 47 | { 48 | "question_id": answer_record["question_id"], 49 | "answer": answer_record["answer"], 50 | "question_type": question_type, 51 | "is_correct": is_correct, 52 | "true_answer": true_answer, 53 | } 54 | ) 55 | 56 | frame = pd.DataFrame(evaluation_records) 57 | mask = frame["question_type"] == "external knowledge" 58 | accuracies = { 59 | "overall": frame["is_correct"].sum() / len(frame), 60 | "external knowledge": frame[mask]["is_correct"].sum() / len(frame[mask]), 61 | "no external knowledge": frame[~mask]["is_correct"].sum() / len(frame[~mask]), 62 | } 63 | accuracies = {k: round(v, 4) for k, v in accuracies.items()} 64 | return accuracies 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = ArgumentParser() 69 | parser.add_argument( 70 | "result_file", help="Path to a JSON result file generated by an evaluation." 71 | ) 72 | args = parser.parse_args() 73 | 74 | results_file = args.result_file 75 | 76 | accuracies = exact_match_eval(annotation_file, question_file, results_file) 77 | 78 | pp.pprint(accuracies) 79 | with open(Path(results_file).parent / "artvqa_eval.json", "w") as f: 80 | json.dump(accuracies, f) 81 | -------------------------------------------------------------------------------- /cli.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf, DictConfig 2 | import hydra 3 | from argparse import ArgumentParser, Namespace 4 | from pathlib import Path 5 | import torch 6 | import os 7 | 8 | # import ruamel.yaml as yaml 9 | from typing import Tuple 10 | 11 | 12 | def make_parser(default_config_path: str) -> ArgumentParser: 13 | parser = ArgumentParser() 14 | parser.add_argument("--config", default=default_config_path) 15 | parser.add_argument( 16 | "--output_dir", default="/net/acadia4a/data/zkhan/mithril/sandbox" 17 | ) 18 | parser.add_argument("--evaluate", action="store_true") 19 | parser.add_argument("--device", default="cuda") 20 | parser.add_argument("--seed", default=42, type=int) 21 | parser.add_argument( 22 | "--world_size", default=1, type=int, help="number of distributed processes" 23 | ) 24 | parser.add_argument( 25 | "--dist_url", default="env://", help="url used to set up distributed training" 26 | ) 27 | parser.add_argument("--distributed", default=True, type=bool) 28 | parser.add_argument("--overrides", nargs="+", default=[]) 29 | return parser 30 | 31 | 32 | def load_config(args: Namespace) -> DictConfig: 33 | config_path = Path(args.config) 34 | with hydra.initialize(config_path=str(config_path.parent), version_base=None): 35 | config = hydra.compose(config_name=config_path.stem, overrides=args.overrides) 36 | return config 37 | 38 | 39 | def parse_args(default_config_path: str) -> Tuple[Namespace, DictConfig]: 40 | """ 41 | Parse command line arguments and config.yaml file. 42 | 43 | Args: 44 | default_config_path (str): This config will be the default if the user 45 | doesn't provide a different config. 46 | Returns: 47 | args (Namespace): Parsed arguments from command line. 48 | config (DictConfig): The parsed config. 49 | """ 50 | parser = make_parser(default_config_path) 51 | args = parser.parse_args() 52 | config = load_config(args) 53 | return args, config 54 | 55 | 56 | def setup(args: Namespace, config: DictConfig) -> None: 57 | """Do housekeeping needed in general before training. 
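
    Specifically: sets the torch hub cache directory when config.torch_home is set,
    creates output_dir and its result/ subdirectory, and saves the resolved config
    to output_dir/config.yaml.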
58 | 59 | Args: 60 | args (Namespace): Parsed arguments from command line. 61 | config (DictConfig): The config produced by hydra. 62 | """ 63 | if config.torch_home: 64 | torch.hub.set_dir(config.torch_home) 65 | args.result_dir = os.path.join(args.output_dir, "result") 66 | 67 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 68 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 69 | 70 | OmegaConf.save(config, os.path.join(args.output_dir, "config.yaml")) 71 | # yaml.dump(config, open(os.path.join(args.output_dir, "config.yaml"), "w")) 72 | -------------------------------------------------------------------------------- /configs/aokvqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10b/data/zkhan/coco2017' 2 | vg_root: null 3 | train_files: ['train'] 4 | ann_root: '/net/acadia4a/data/zkhan/aokvqa' 5 | dataset_name: aokvqa 6 | truncate_train_dataset_to: null 7 | 8 | # We don't stop you from setting all of these to True... 9 | # but you probably shouldn't. 10 | append_rationale_to_answer: false 11 | append_rationale_to_question: false 12 | use_rationale_as_answer: false 13 | 14 | # AOKVQA doesn't have a test-dev set like VQAv2 does, and 15 | # we can only score the actual test set once a week. The BLIP 16 | # code is not set up to use the validation set at all. So, we 17 | # add a flag that will force the model to use the validation set 18 | # in place of the test set, because the code currently only 19 | # recognizes training and test sets. 20 | use_validation_set_as_test_set: false 21 | 22 | # set pretrained as a file path or an url 23 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 24 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 25 | 26 | # size of vit model; base or large 27 | vit: 'base' 28 | batch_size_train: 16 29 | batch_size_test: 32 30 | vit_grad_ckpt: False 31 | vit_ckpt_layer: 0 32 | init_lr: 2e-5 33 | 34 | image_size: 480 35 | 36 | k_test: 128 37 | inference: 'rank' 38 | 39 | # optimizer 40 | weight_decay: 0.05 41 | min_lr: 0 42 | max_epoch: 10 43 | 44 | torch_home: /net/acadia10b/data/zkhan/torch_home 45 | wandb: true -------------------------------------------------------------------------------- /configs/aokvqg.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10a/data/zkhan/coco2017' 2 | vg_root: null 3 | train_files: ['train'] 4 | ann_root: '/net/acadia10a/data/zkhan/aokvqa' 5 | dataset_name: aokvqa 6 | truncate_train_dataset_to: null 7 | 8 | append_rationale_to_answer: False 9 | append_rationale_to_question: False 10 | 11 | # Increase this (~160) when generating use_rationale=true. 12 | tokenizer_max_length: 40 13 | use_rationale: false 14 | # Whether the model learns to generate the rationale first, then the question, or the reverse.
15 | generate_rationale_first: false 16 | 17 | # set pretrained as a file path or an url 18 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 19 | 20 | use_validation_set_as_test_set: true 21 | 22 | # size of vit model; base or large 23 | vit: 'base' 24 | vit_grad_ckpt: False 25 | vit_ckpt_layer: 0 26 | batch_size: 32 27 | init_lr: 1e-5 28 | 29 | # vit: 'large' 30 | # vit_grad_ckpt: True 31 | # vit_ckpt_layer: 5 32 | # batch_size: 16 33 | # init_lr: 2e-6 34 | 35 | image_size: 384 36 | 37 | # generation configs 38 | max_length: 20 39 | min_length: 5 40 | num_beams: 3 41 | prompt: '' 42 | 43 | # optimizer 44 | weight_decay: 0.05 45 | min_lr: 0 46 | max_epoch: 5 47 | 48 | 49 | torch_home: /net/acadia10a/data/zkhan/torch_home 50 | wandb: true -------------------------------------------------------------------------------- /configs/artvqa.yaml: -------------------------------------------------------------------------------- 1 | ann_root: /net/acadia4a/data/zkhan/vqa-art 2 | vqa_root: /net/acadia4a/data/zkhan/SemArt/Images 3 | train_files: ['train'] 4 | dataset_name: artvqa 5 | truncate_train_dataset_to: null 6 | 7 | # set pretrained as a file path or an url 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 9 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 10 | 11 | # size of vit model; base or large 12 | vit: 'base' 13 | batch_size_train: 16 14 | batch_size_test: 32 15 | vit_grad_ckpt: False 16 | vit_ckpt_layer: 0 17 | init_lr: 2e-5 18 | 19 | image_size: 480 20 | 21 | k_test: 128 22 | inference: 'rank' 23 | 24 | # optimizer 25 | weight_decay: 0.05 26 | min_lr: 0 27 | max_epoch: 10 28 | 29 | torch_home: /net/acadia10a/data/zkhan/torch_home 30 | wandb: true -------------------------------------------------------------------------------- /configs/artvqg.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: /net/acadia4a/data/zkhan/SemArt/Images 2 | vg_root: null 3 | train_files: ['train_no_external_knowledge_questions'] 4 | ann_root: /net/acadia4a/data/zkhan/vqa-art 5 | dataset_name: artvqa 6 | truncate_train_dataset_to: null 7 | 8 | # Keep these false for AQUA (Art VQA), there are no 9 | # rationales in this dataset. 10 | append_rationale_to_answer: False 11 | append_rationale_to_question: False 12 | 13 | # Increase this (~160) when generating use_rationale=true. 14 | tokenizer_max_length: 40 15 | use_rationale: false 16 | # Whether the model learns to generate the rationale first, then the question, or the reverse.
17 | generate_rationale_first: false 18 | 19 | # set pretrained as a file path or an url 20 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 21 | 22 | use_validation_set_as_test_set: true 23 | 24 | # size of vit model; base or large 25 | vit: 'base' 26 | vit_grad_ckpt: False 27 | vit_ckpt_layer: 0 28 | batch_size: 32 29 | init_lr: 1e-5 30 | 31 | # vit: 'large' 32 | # vit_grad_ckpt: True 33 | # vit_ckpt_layer: 5 34 | # batch_size: 16 35 | # init_lr: 2e-6 36 | 37 | image_size: 384 38 | 39 | # generation configs 40 | max_length: 20 41 | min_length: 5 42 | num_beams: 3 43 | prompt: '' 44 | 45 | # optimizer 46 | weight_decay: 0.05 47 | min_lr: 0 48 | max_epoch: 5 49 | 50 | 51 | torch_home: /net/acadia10a/data/zkhan/torch_home 52 | wandb: true -------------------------------------------------------------------------------- /configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | coco_gt_root: 'annotation/coco_gt' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 7 | 8 | # size of vit model; base or large 9 | vit: 'base' 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | batch_size: 32 13 | init_lr: 1e-5 14 | 15 | # vit: 'large' 16 | # vit_grad_ckpt: True 17 | # vit_ckpt_layer: 5 18 | # batch_size: 16 19 | # init_lr: 2e-6 20 | 21 | image_size: 384 22 | 23 | # generation configs 24 | max_length: 20 25 | min_length: 5 26 | num_beams: 3 27 | prompt: 'a picture of ' 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 5 33 | 34 | -------------------------------------------------------------------------------- /configs/generate_questions_aokvqa.yaml: -------------------------------------------------------------------------------- 1 | image_folder: /net/acadia10a/data/zkhan/coco2017/ 2 | output_folder: /net/acadia10a/data/zkhan/aokvqa/ 3 | annotations: /net/acadia10a/data/zkhan/aokvqa/train.json 4 | pretrained: /net/acadia10a/data/mithril/blip_aokvqg_no_vqav2/checkpoint_04.pth 5 | output_annotations_name: /net/acadia10a/data/zkhan/aokvqa/aokvqa_generated_qa.json 6 | image_size: 384 7 | max_length: 30 8 | min_length: 5 9 | num_beams: 3 10 | prompt: 'Question:' 11 | torch_home: /net/acadia10a/data/zkhan/torch_home 12 | vit: base 13 | questions_per_image: 2 14 | multimodal_encoder_decoder_config: /home/mai/zkhan/BLIP/configs/med_config.json 15 | batch_size: 64 16 | truncate_to: null 17 | num_workers: 4 18 | top_p: 0.9 19 | vqa_dataset_origin: 'vqa' 20 | dry_run: false 21 | shuffle: true 22 | 23 | parse_rationale: false 
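
# A hedged usage sketch (not taken from the repo's own docs): assuming
# generate_questions.py reads --config and Hydra-style --overrides through
# cli.parse_args, a run with this config would look roughly like
#   python generate_questions.py --config configs/generate_questions_aokvqa.yaml \
#     --overrides questions_per_image=1 top_p=0.9 dry_run=true
# See examples/generate_synthetic_data.sh for the invocation actually used.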
-------------------------------------------------------------------------------- /configs/generate_questions_coco.yaml: -------------------------------------------------------------------------------- 1 | image_folder: /net/acadia10a/data/zkhan/coco2017/unlabeled2017 2 | output_folder: /net/acadia10a/data/zkhan/vqav2_annotations/ 3 | annotations: null 4 | pretrained: /net/acadia10a/data/zkhan/mithril/blip_vqg_2/checkpoint_04.pth 5 | output_annotations_name: coco_generated_qa.json 6 | image_size: 384 7 | max_length: 30 8 | min_length: 5 9 | num_beams: 3 10 | prompt: 'Question: ' 11 | torch_home: /net/acadia10b/data/zkhan/torch_home 12 | vit: base 13 | questions_per_image: 1 14 | multimodal_encoder_decoder_config: /home/mai/zkhan/BLIP/configs/med_config.json 15 | batch_size: 64 16 | truncate_to: null 17 | num_workers: 4 18 | top_p: 0.9 19 | vqa_dataset_origin: 'vqa' 20 | dry_run: false 21 | shuffle: false 22 | 23 | parse_rationale: false 24 | truncate_to_strict: null -------------------------------------------------------------------------------- /configs/generate_questions_coco_vastai.yaml: -------------------------------------------------------------------------------- 1 | image_folder: /home/zkhan/coco/unlabeled2017 2 | output_folder: /home/zkhan/vqav2_annotations 3 | annotations: null 4 | pretrained: null 5 | output_annotations_name: null 6 | image_size: 384 7 | max_length: 30 8 | min_length: 5 9 | num_beams: 3 10 | prompt: 'Question: ' 11 | torch_home: null 12 | vit: base 13 | questions_per_image: 1 14 | multimodal_encoder_decoder_config: /home/zkhan/blip/configs/med_config.json 15 | batch_size: 64 16 | truncate_to: null 17 | truncate_to_strict: null 18 | num_workers: 4 19 | top_p: 0.9 20 | vqa_dataset_origin: 'vqa' 21 | dry_run: false 22 | shuffle: false 23 | 24 | parse_rationale: false -------------------------------------------------------------------------------- /configs/generate_questions_pathvqa.yaml: -------------------------------------------------------------------------------- 1 | image_folder: /net/acadia4a/data/zkhan/pathvqa/images 2 | output_folder: /net/acadia4a/data/zkhan/pathvqa 3 | annotations: /net/acadia4a/data/zkhan/pathvqa/ 4 | pretrained: /net/acadia10a/data/zkhan/mithril/blip_vqg_2/checkpoint_04.pth 5 | output_annotations_name: coco_generated_qa.json 6 | image_size: 384 7 | max_length: 30 8 | min_length: 5 9 | num_beams: 3 10 | prompt: 'Question: ' 11 | torch_home: /net/acadia10a/data/zkhan/torch_home 12 | vit: base 13 | questions_per_image: 2 14 | multimodal_encoder_decoder_config: /home/mai/zkhan/BLIP/configs/med_config.json 15 | batch_size: 64 16 | truncate_to: null 17 | num_workers: 4 18 | top_p: 0.9 19 | vqa_dataset_origin: 'vqa' 20 | dry_run: false 21 | shuffle: false 22 | 23 | parse_rationale: false -------------------------------------------------------------------------------- /configs/generate_questions_vg.yaml: -------------------------------------------------------------------------------- 1 | image_folder: /net/acadia10a/data/zkhan/visual-genome-sandbox/image/ 2 | output_folder: /net/acadia10a/data/zkhan/vqav2_annotations/ 3 | annotations: null 4 | pretrained: /net/acadia10a/data/zkhan/mithril/blip_vqg_2/checkpoint_04.pth 5 | output_annotations_name: vg_generated_qa.json 6 | image_size: 384 7 | max_length: 30 8 | min_length: 5 9 | num_beams: 3 10 | prompt: 'Question:' 11 | torch_home: /net/acadia10a/data/zkhan/torch_home 12 | vit: base 13 | questions_per_image: 2 14 | multimodal_encoder_decoder_config: /home/mai/zkhan/BLIP/configs/med_config.json 15 
| batch_size: 64 16 | truncate_to: null 17 | num_workers: 4 18 | top_p: 0.9 19 | vqa_dataset_origin: 'vg' 20 | dry_run: false 21 | shuffle: false 22 | 23 | parse_rationale: false 24 | -------------------------------------------------------------------------------- /configs/generic.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10a/data/zkhan/coco2014/' 2 | train_files: ['train'] 3 | val_file: val 4 | ann_root: '/net/acadia10a/data/zkhan/aokvqa' 5 | dataset_name: generic_vqa 6 | answer_list: answer_list 7 | truncate_train_dataset_to: null 8 | 9 | # set pretrained as a file path or an url 10 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 11 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 12 | 13 | # size of vit model; base or large 14 | vit: 'base' 15 | batch_size_train: 16 16 | batch_size_test: 32 17 | vit_grad_ckpt: False 18 | vit_ckpt_layer: 0 19 | init_lr: 2e-5 20 | 21 | image_size: 480 22 | 23 | k_test: 128 24 | inference: 'rank' 25 | 26 | # optimizer 27 | weight_decay: 0.05 28 | min_lr: 0 29 | max_epoch: 10 30 | 31 | torch_home: /net/acadia10a/data/zkhan/torch_home 32 | wandb: true -------------------------------------------------------------------------------- /configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/nlvr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/NLVR2/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth' 6 | 7 | #size of vit model; base or large 8 | vit: 'base' 9 | batch_size_train: 16 10 | batch_size_test: 64 11 | vit_grad_ckpt: False 12 | vit_ckpt_layer: 0 13 | max_epoch: 15 14 | 15 | image_size: 384 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-5 20 | min_lr: 0 21 | 22 | -------------------------------------------------------------------------------- /configs/nocaps.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/nocaps/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 6 | 7 | vit: 'base' 8 | batch_size: 32 9 | 10 | image_size: 384 11 | 12 | max_length: 20 13 | min_length: 5 14 | num_beams: 3 15 | prompt: 'a picture of ' -------------------------------------------------------------------------------- /configs/okvqa.yaml: -------------------------------------------------------------------------------- 1 | 
ann_root: /net/acadia10a/data/zkhan/ok-vqa 2 | vqa_root: '/net/acadia10a/data/zkhan/coco2014' 3 | train_files: ['train'] 4 | dataset_name: okvqa 5 | truncate_train_dataset_to: null 6 | 7 | # set pretrained as a file path or an url 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 9 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 10 | 11 | # size of vit model; base or large 12 | vit: 'base' 13 | batch_size_train: 16 14 | batch_size_test: 32 15 | vit_grad_ckpt: False 16 | vit_ckpt_layer: 0 17 | init_lr: 2e-5 18 | 19 | image_size: 480 20 | 21 | k_test: 128 22 | inference: 'rank' 23 | 24 | # optimizer 25 | weight_decay: 0.05 26 | min_lr: 0 27 | max_epoch: 10 28 | 29 | torch_home: /net/acadia10a/data/zkhan/torch_home 30 | wandb: true -------------------------------------------------------------------------------- /configs/okvqa_lavis.yaml: -------------------------------------------------------------------------------- 1 | ann_root: /net/acadia10a/data/zkhan/ok-vqa 2 | vqa_root: '/net/acadia10a/data/zkhan/coco2014' 3 | train_files: ['train'] 4 | dataset_name: okvqa 5 | truncate_train_dataset_to: null 6 | 7 | # set pretrained as a file path or an url 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 9 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 10 | 11 | # size of vit model; base or large 12 | vit: 'base' 13 | batch_size_train: 16 14 | batch_size_test: 32 15 | vit_grad_ckpt: False 16 | vit_ckpt_layer: 0 17 | init_lr: 3e-5 18 | 19 | image_size: 480 20 | 21 | k_test: 128 22 | inference: 'rank' 23 | 24 | # optimizer 25 | weight_decay: 0.02 26 | min_lr: 1e-5 27 | max_epoch: 7 28 | 29 | torch_home: /net/acadia10a/data/zkhan/torch_home 30 | wandb: true -------------------------------------------------------------------------------- /configs/okvqg.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10a/data/zkhan/coco2014' 2 | vg_root: null 3 | train_files: ['train'] 4 | ann_root: '/net/acadia10a/data/zkhan/ok-vqa' 5 | dataset_name: okvqa 6 | truncate_train_dataset_to: null 7 | 8 | # Keep these false for OK-VQA, there are no 9 | # rationales in this dataset. 10 | append_rationale_to_answer: False 11 | append_rationale_to_question: False 12 | 13 | # Increase this (~160) when generating use_rationale=true. 14 | tokenizer_max_length: 40 15 | use_rationale: false 16 | # Whether the model learns to generate the rationale first, then the question, or the reverse.
17 | generate_rationale_first: false 18 | 19 | # set pretrained as a file path or an url 20 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 21 | 22 | use_validation_set_as_test_set: true 23 | 24 | # size of vit model; base or large 25 | vit: 'base' 26 | vit_grad_ckpt: False 27 | vit_ckpt_layer: 0 28 | batch_size: 32 29 | init_lr: 1e-5 30 | 31 | # vit: 'large' 32 | # vit_grad_ckpt: True 33 | # vit_ckpt_layer: 5 34 | # batch_size: 16 35 | # init_lr: 2e-6 36 | 37 | image_size: 384 38 | 39 | # generation configs 40 | max_length: 20 41 | min_length: 5 42 | num_beams: 3 43 | prompt: '' 44 | 45 | # optimizer 46 | weight_decay: 0.05 47 | min_lr: 0 48 | max_epoch: 5 49 | 50 | 51 | torch_home: /net/acadia10a/data/zkhan/torch_home 52 | wandb: true -------------------------------------------------------------------------------- /configs/pathvqa.yaml: -------------------------------------------------------------------------------- 1 | ann_root: /net/acadia4a/data/zkhan/pathvqa 2 | vqa_root: /net/acadia4a/data/zkhan/pathvqa/images 3 | train_files: ['train'] 4 | dataset_name: pathvqa 5 | truncate_train_dataset_to: null 6 | 7 | # set pretrained as a file path or an url 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 9 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 10 | 11 | # size of vit model; base or large 12 | vit: 'base' 13 | batch_size_train: 16 14 | batch_size_test: 16 15 | vit_grad_ckpt: False 16 | vit_ckpt_layer: 0 17 | init_lr: 2e-5 18 | 19 | image_size: 480 20 | 21 | k_test: 128 22 | inference: 'rank' 23 | 24 | # optimizer 25 | weight_decay: 0.05 26 | min_lr: 0 27 | max_epoch: 10 28 | 29 | torch_home: /net/acadia10a/data/zkhan/torch_home 30 | wandb: true -------------------------------------------------------------------------------- /configs/pathvqg.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: /net/acadia4a/data/zkhan/pathvqa/images 2 | vg_root: null 3 | train_files: ['train'] 4 | ann_root: /net/acadia4a/data/zkhan/pathvqa 5 | dataset_name: pathvqa 6 | truncate_train_dataset_to: null 7 | 8 | # Keep these false for PathVQA, there are no 9 | # rationales in this dataset. 10 | append_rationale_to_answer: False 11 | append_rationale_to_question: False 12 | 13 | # Increase this (~160) when generating use_rationale=true. 14 | tokenizer_max_length: 40 15 | use_rationale: false 16 | # Whether the model learns to generate the rationale first, then the question, or the reverse.
17 | generate_rationale_first: false 18 | 19 | # set pretrained as a file path or an url 20 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 21 | 22 | use_validation_set_as_test_set: true 23 | 24 | # size of vit model; base or large 25 | vit: 'base' 26 | vit_grad_ckpt: False 27 | vit_ckpt_layer: 0 28 | batch_size: 32 29 | init_lr: 1e-5 30 | 31 | # vit: 'large' 32 | # vit_grad_ckpt: True 33 | # vit_ckpt_layer: 5 34 | # batch_size: 16 35 | # init_lr: 2e-6 36 | 37 | image_size: 384 38 | 39 | # generation configs 40 | max_length: 20 41 | min_length: 5 42 | num_beams: 3 43 | prompt: '' 44 | 45 | # optimizer 46 | weight_decay: 0.05 47 | min_lr: 0 48 | max_epoch: 5 49 | 50 | 51 | torch_home: /net/acadia10a/data/zkhan/torch_home 52 | wandb: true -------------------------------------------------------------------------------- /configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json', 2 | '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json', 3 | ] 4 | laion_path: '' 5 | 6 | # size of vit model; base or large 7 | vit: 'base' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 224 12 | batch_size: 75 13 | 14 | queue_size: 57600 15 | alpha: 0.4 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-4 20 | min_lr: 1e-6 21 | warmup_lr: 1e-6 22 | lr_decay_rate: 0.9 23 | max_epoch: 20 24 | warmup_steps: 3000 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /configs/retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/coco/images/' 2 | ann_root: 'annotation' 3 | dataset: 'coco' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 12 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 256 28 | negative_all_rank: True 29 | 30 | # optimizer 31 | weight_decay: 0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /configs/retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/export/share/datasets/vision/flickr30k/' 2 | ann_root: 'annotation' 3 | dataset: 'flickr' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth' 7 | 8 | # size of vit model; base or large 9 | 10 | vit: 'base' 11 | batch_size_train: 32 12 | batch_size_test: 64 13 | vit_grad_ckpt: True 14 | vit_ckpt_layer: 4 15 | init_lr: 1e-5 16 | 17 | # vit: 'large' 18 | # batch_size_train: 16 19 | # batch_size_test: 32 20 | # vit_grad_ckpt: True 21 | # vit_ckpt_layer: 10 22 | # init_lr: 5e-6 23 | 24 | image_size: 384 25 | queue_size: 57600 26 | alpha: 0.4 27 | k_test: 128 28 | negative_all_rank: False 29 | 30 | # optimizer 31 | weight_decay: 
0.05 32 | min_lr: 0 33 | max_epoch: 6 34 | 35 | -------------------------------------------------------------------------------- /configs/retrieval_msrvtt.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 6 | 7 | # size of vit model; base or large 8 | vit: 'base' 9 | batch_size: 64 10 | k_test: 128 11 | image_size: 384 12 | num_frm_test: 8 -------------------------------------------------------------------------------- /configs/rsvqa_lr.yaml: -------------------------------------------------------------------------------- 1 | ann_root: /net/acadia4a/data/zkhan/rsvqa/low_resolution 2 | vqa_root: /net/acadia4a/data/zkhan/rsvqa/low_resolution/Images_LR 3 | train_files: ['train'] 4 | dataset_name: rsvqa 5 | truncate_train_dataset_to: null 6 | 7 | # set pretrained as a file path or an url 8 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 9 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 10 | 11 | # size of vit model; base or large 12 | vit: 'base' 13 | batch_size_train: 16 14 | batch_size_test: 16 15 | vit_grad_ckpt: False 16 | vit_ckpt_layer: 0 17 | init_lr: 2e-5 18 | 19 | image_size: 480 20 | 21 | k_test: 128 22 | inference: 'rank' 23 | 24 | # optimizer 25 | weight_decay: 0.05 26 | min_lr: 0 27 | max_epoch: 10 28 | 29 | torch_home: /net/acadia10a/data/zkhan/torch_home 30 | wandb: true -------------------------------------------------------------------------------- /configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10a/data/zkhan/coco2014/' #followed by train2014/ 2 | vg_root: '/net/acadia10a/data/zkhan/visual-genome-sandbox' #followed by image/ 3 | train_files: ['vqa_train','vqa_val', 'vg_qa'] 4 | ann_root: '/net/acadia10a/data/zkhan/vqav2_annotations' 5 | dataset_name: vqa 6 | truncate_train_dataset_to: null 7 | 8 | # We don't stop you from setting all of these to True... 9 | # but you probably shouldn't. Also, these aren't meaningful 10 | # for VQA, since there are no rationales in VQA. 
Yet :) 11 | append_rationale_to_answer: false 12 | append_rationale_to_question: false 13 | use_rationale_as_answer: false 14 | 15 | # set pretrained as a file path or an url 16 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 17 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 18 | 19 | # size of vit model; base or large 20 | vit: 'base' 21 | batch_size_train: 16 22 | batch_size_test: 32 23 | vit_grad_ckpt: False 24 | vit_ckpt_layer: 0 25 | init_lr: 2e-5 26 | 27 | image_size: 480 28 | 29 | k_test: 128 30 | inference: 'rank' 31 | 32 | # optimizer 33 | weight_decay: 0.05 34 | min_lr: 0 35 | max_epoch: 10 36 | 37 | torch_home: /net/acadia10a/data/zkhan/torch_home 38 | wandb: true -------------------------------------------------------------------------------- /configs/vqa_ablations.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10a/data/zkhan/coco2014/' #followed by train2014/ 2 | vg_root: '/net/acadia10a/data/zkhan/visual-genome-sandbox' #followed by image/ 3 | train_files: ['vqa_train'] 4 | ann_root: '/net/acadia10a/data/zkhan/vqav2_annotations' 5 | dataset_name: vqa 6 | truncate_train_dataset_to: null 7 | 8 | # We don't stop you from setting all of these to True... 9 | # but you probably shouldn't. Also, these aren't meaningful 10 | # for VQA, since there are no rationales in VQA. Yet :) 11 | append_rationale_to_answer: false 12 | append_rationale_to_question: false 13 | use_rationale_as_answer: false 14 | 15 | # set pretrained as a file path or an url 16 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 17 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 18 | 19 | # size of vit model; base or large 20 | vit: 'base' 21 | batch_size_train: 32 22 | batch_size_test: 32 23 | vit_grad_ckpt: False 24 | vit_ckpt_layer: 0 25 | init_lr: 2e-5 26 | 27 | image_size: 480 28 | 29 | k_test: 128 30 | inference: 'rank' 31 | 32 | # optimizer 33 | weight_decay: 0.05 34 | min_lr: 0 35 | max_epoch: 10 36 | 37 | torch_home: /net/acadia10a/data/zkhan/torch_home 38 | wandb: true 39 | save_last_only: true -------------------------------------------------------------------------------- /configs/vqa_ablations_vastai.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/home/zkhan/coco' 2 | vg_root: '' 3 | train_files: ['vqa_train'] 4 | ann_root: '/home/zkhan/vqav2_annotations' 5 | dataset_name: vqa 6 | truncate_train_dataset_to: null 7 | 8 | # We don't stop you from setting all of these to True... 9 | # but you probably shouldn't. Also, these aren't meaningful 10 | # for VQA, since there are no rationales in VQA. 
Yet :) 11 | append_rationale_to_answer: false 12 | append_rationale_to_question: false 13 | use_rationale_as_answer: false 14 | 15 | # set pretrained as a file path or an url 16 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 17 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 18 | 19 | # size of vit model; base or large 20 | vit: 'base' 21 | batch_size_train: 8 22 | batch_size_test: 16 23 | vit_grad_ckpt: False 24 | vit_ckpt_layer: 0 25 | init_lr: 2e-5 26 | 27 | image_size: 480 28 | 29 | k_test: 128 30 | inference: 'rank' 31 | 32 | # optimizer 33 | weight_decay: 0.05 34 | min_lr: 0 35 | max_epoch: 10 36 | 37 | torch_home: null 38 | wandb: true 39 | save_last_only: true -------------------------------------------------------------------------------- /configs/vqa_rephrasings.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/net/acadia10b/data/zkhan/coco2014/' 2 | train_files: ['train'] 3 | ann_root: '/net/acadia4a/data/zkhan/vqa-rephrasings' 4 | dataset_name: vqa_rephrasings 5 | truncate_train_dataset_to: null 6 | 7 | # We don't stop you from setting all of these to True... 8 | # but you probably shouldn't. Also, these aren't meaningful 9 | # for VQA, since there are no rationales in VQA. Yet :) 10 | append_rationale_to_answer: false 11 | append_rationale_to_question: false 12 | use_rationale_as_answer: false 13 | 14 | # set pretrained as a file path or an url 15 | # pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 16 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 17 | 18 | # size of vit model; base or large 19 | vit: 'base' 20 | batch_size_train: 16 21 | batch_size_test: 32 22 | vit_grad_ckpt: False 23 | vit_ckpt_layer: 0 24 | init_lr: 2e-5 25 | 26 | image_size: 480 27 | 28 | k_test: 128 29 | inference: 'rank' 30 | 31 | # optimizer 32 | weight_decay: 0.05 33 | min_lr: 0 34 | max_epoch: 10 35 | 36 | torch_home: /net/acadia10b/data/zkhan/torch_home 37 | wandb: true -------------------------------------------------------------------------------- /configs/vqg.yaml: -------------------------------------------------------------------------------- 1 | 2 | vqa_root: '/net/acadia10a/data/zkhan/coco2014/' #followed by train2014/ 3 | vg_root: '/net/acadia10a/data/zkhan/visual-genome-sandbox' #followed by image/ 4 | train_files: ['vqa_train'] 5 | ann_root: '/net/acadia10a/data/zkhan/vqav2_annotations' 6 | dataset_name: vqa 7 | truncate_train_dataset_to: null 8 | 9 | # set pretrained as a file path or an url 10 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 11 | 12 | tokenizer_max_length: 40 13 | 14 | # size of vit model; base or large 15 | vit: 'base' 16 | vit_grad_ckpt: False 17 | vit_ckpt_layer: 0 18 | batch_size: 32 19 | init_lr: 1e-5 20 | 21 | # vit: 'large' 22 | # vit_grad_ckpt: True 23 | # vit_ckpt_layer: 5 24 | # batch_size: 16 25 | # init_lr: 2e-6 26 | 27 | image_size: 384 28 | 29 | # generation configs 30 | max_length: 20 31 | min_length: 5 32 | num_beams: 3 33 | prompt: '' 34 | 35 | # optimizer 36 | weight_decay: 0.05 37 | min_lr: 0 38 | max_epoch: 5 39 | 40 | 41 | torch_home: /net/acadia10a/data/zkhan/torch_home 42 | wandb: true 43 | save_last_only: true 
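
# A hedged usage sketch (not taken from the repo's own docs): assuming train_vqg.py
# reads --config, --output_dir, and Hydra-style --overrides through cli.parse_args,
# the question generator (the SelTDA teacher) would be trained roughly like
#   python train_vqg.py --config configs/vqg.yaml --output_dir /path/to/experiment \
#     --overrides batch_size=16 max_epoch=5
# See examples/train_teacher.sh for the invocation actually used.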
-------------------------------------------------------------------------------- /configs/vqg_vastai.yaml: -------------------------------------------------------------------------------- 1 | 2 | vqa_root: /home/zkhan/coco 3 | vg_root: '' #followed by image/ 4 | train_files: ['vqa_train'] 5 | ann_root: '/home/zkhan/vqav2_annotations' 6 | dataset_name: vqa 7 | truncate_train_dataset_to: null 8 | 9 | # set pretrained as a file path or an url 10 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' 11 | 12 | tokenizer_max_length: 40 13 | 14 | # size of vit model; base or large 15 | vit: 'base' 16 | vit_grad_ckpt: False 17 | vit_ckpt_layer: 0 18 | batch_size: 16 19 | init_lr: 1e-5 20 | 21 | # vit: 'large' 22 | # vit_grad_ckpt: True 23 | # vit_ckpt_layer: 5 24 | # batch_size: 16 25 | # init_lr: 2e-6 26 | 27 | image_size: 384 28 | 29 | # generation configs 30 | max_length: 20 31 | min_length: 5 32 | num_beams: 3 33 | prompt: '' 34 | 35 | # optimizer 36 | weight_decay: 0.05 37 | min_lr: 0 38 | max_epoch: 5 39 | 40 | 41 | torch_home: null 42 | wandb: true 43 | save_last_only: true -------------------------------------------------------------------------------- /convert_advqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts the AdVQA dataset into a format usable by our training code. 3 | The images in the AdVQA training dataset come from CC3M, Fakeddit and VCR, but the 4 | testing / validation images come only from COCO. 5 | 6 | AdVQA makes available a training, validation and testing split. Right now, I only use 7 | the validation split. 8 | """ 9 | 10 | import json 11 | from typing import List, Dict, Literal, Optional 12 | import cattrs 13 | from omegaconf import DictConfig 14 | from pathlib import Path 15 | import logging 16 | from pydantic import BaseModel 17 | import schemas 18 | from enum import Enum 19 | from tqdm import tqdm 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | handler = logging.StreamHandler() 24 | formatter = logging.Formatter( 25 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 26 | ) 27 | handler.setFormatter(formatter) 28 | logger.addHandler(handler) 29 | logger.setLevel(logging.INFO) 30 | 31 | 32 | IMAGES_ROOT = Path("/net/acadia10a/data/zkhan/coco2014") 33 | ADVQA_ROOT = Path("/net/acadia10a/data/zkhan/advqa") 34 | 35 | 36 | def make_val_record(record: schemas.QuestionRecord) -> schemas.TestingRecord: 37 | return schemas.TestingRecord( 38 | question_id=record.question_id, 39 | question=record.question, 40 | image=record.image_id, 41 | dataset="advqa", 42 | ) 43 | 44 | 45 | def pad_coco_id(coco_id: int) -> str: 46 | return f"{coco_id:0>12}" 47 | 48 | 49 | def point_record_to_coco_image_file( 50 | record: schemas.TrainingRecord, split: str 51 | ) -> schemas.TrainingRecord: 52 | coco_image_id = pad_coco_id(record.image) 53 | relative_path_to_image = f"{split}2014/COCO_{split}2014_{coco_image_id}.jpg" 54 | record.image = relative_path_to_image 55 | return record 56 | 57 | 58 | if __name__ == "__main__": 59 | with open( 60 | "/net/acadia10a/data/zkhan/advqa/v1_OpenEnded_mscoco_val2017_advqa_questions.json", 61 | "r", 62 | ) as f: 63 | val_questions_raw = json.load(f) 64 | 65 | logger.info("Converting %d records", len(val_questions_raw["questions"])) 66 | val_questions = [ 67 | make_val_record(schemas.QuestionRecord.parse_obj(_)) 68 | for _ in val_questions_raw["questions"] 69 | ] 70 | 71 | 
logger.info("Verifying all image paths are correct.") 72 | val_records = [ 73 | point_record_to_coco_image_file(r, "val") for r in tqdm(val_questions) 74 | ] 75 | for record in tqdm(val_records): 76 | assert (IMAGES_ROOT / record.image).exists() 77 | 78 | with open(ADVQA_ROOT / "val.json", "w") as f: 79 | json.dump([r.dict() for r in val_records], f) 80 | logger.info( 81 | "Wrote %d validation records to %s", len(val_records), ADVQA_ROOT / "val.json" 82 | ) 83 | -------------------------------------------------------------------------------- /convert_aokvqa.py: -------------------------------------------------------------------------------- 1 | import cli 2 | import json 3 | import attrs 4 | from typing import List, Dict, Optional 5 | import cattrs 6 | from omegaconf import DictConfig 7 | from pathlib import Path 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | handler = logging.StreamHandler() 12 | formatter = logging.Formatter( 13 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 14 | ) 15 | handler.setFormatter(formatter) 16 | logger.addHandler(handler) 17 | logger.setLevel(logging.INFO) 18 | 19 | SPLIT_TO_ORIGINAL_ANNOTATION_NAMES = { 20 | "test": "aokvqa_v1p0_test.json", 21 | "train": "aokvqa_v1p0_train.json", 22 | "val": "aokvqa_v1p0_val.json", 23 | } 24 | 25 | 26 | SPLIT_TO_NEW_ANNOTATION_NAMES = { 27 | "train": "train.json", 28 | "val": "val.json", 29 | "test": "test.json", 30 | } 31 | 32 | IMAGES_ROOT = Path("/net/acadia10a/data/zkhan/coco2017") 33 | 34 | 35 | @attrs.define 36 | class VQAV2Record: 37 | dataset: str 38 | image: str 39 | question: str 40 | question_id: str # Really an int, but I think we can use a string. 41 | answer: Optional[List[str]] = None 42 | rationales: Optional[List[str]] = None 43 | 44 | @classmethod 45 | def from_dict(cls, the_dict: Dict): 46 | return cattrs.structure(the_dict, cls) 47 | 48 | 49 | @attrs.define 50 | class AOKVQARecord: 51 | difficult_direct_answer: bool 52 | image_id: int 53 | question: str 54 | question_id: str 55 | split: str 56 | choices: Optional[List[str]] = None 57 | rationales: Optional[List[str]] = None 58 | correct_choice_idx: Optional[int] = None 59 | direct_answers: Optional[List[str]] = None 60 | 61 | @classmethod 62 | def from_dict(cls, the_dict: Dict): 63 | return cattrs.structure(the_dict, cls) 64 | 65 | 66 | def convert_aokvqa_to_vqav2(record: AOKVQARecord) -> VQAV2Record: 67 | return VQAV2Record( 68 | answer=record.direct_answers, 69 | dataset="aokvqa", 70 | image=record.image_id, 71 | question=record.question, 72 | question_id=record.question_id, 73 | rationales=record.rationales, 74 | ) 75 | 76 | 77 | def convert_coco_id_to_coco_name(coco_id: int, prefix="") -> str: 78 | return f"{prefix}{coco_id:012}.jpg" 79 | 80 | 81 | def point_record_to_coco_image_file( 82 | record: AOKVQARecord, coco_image_root: Path, split: str 83 | ) -> None: 84 | coco_filename = convert_coco_id_to_coco_name(record.image) 85 | record.image = f"coco-images/{coco_filename}" 86 | 87 | 88 | def load_split(split: str, aokvqa_root: Path) -> List[AOKVQARecord]: 89 | annotation_filename = SPLIT_TO_ORIGINAL_ANNOTATION_NAMES[split] 90 | with open(aokvqa_root / annotation_filename, "r") as f: 91 | annotations = json.load(f) 92 | 93 | records = [AOKVQARecord.from_dict(ann) for ann in annotations] 94 | logger.info("Loaded %d records from %s", len(records), annotation_filename) 95 | return records 96 | 97 | 98 | def serialize_records( 99 | records: List[AOKVQARecord], annotation_root: Path, split: str 100 
| ) -> None: 101 | records_as_dict = [attrs.asdict(record) for record in records] 102 | output_filename = annotation_root / SPLIT_TO_NEW_ANNOTATION_NAMES[split] 103 | if output_filename.exists(): 104 | logger.info("Overwriting existing annotations %s", output_filename) 105 | with open(output_filename, "w") as f: 106 | json.dump(records_as_dict, f) 107 | logger.info("Wrote %d records to %s", len(records_as_dict), output_filename) 108 | 109 | 110 | def save_answer_list_as_json(annotation_root: Path): 111 | # For multiple choice answering, methods usually use a list 112 | # of common answers and then select the answer to a question 113 | # by ranking them. This is easier than directly generating the 114 | # the answer. AOKVQA provides this list as specialized_vocab_train.csv 115 | # but our code needs it to be JSON. It's a simple transformation. 116 | 117 | with open(annotation_root / "specialized_vocab_train.csv", "r") as f: 118 | words = [_.strip() for _ in f.readlines()] 119 | 120 | with open(annotation_root / "answer_list.json", "w") as f: 121 | json.dump(words, f) 122 | 123 | 124 | def main(config: DictConfig) -> None: 125 | for split in ("train", "val", "test"): 126 | logger.info("Processing split %s", split) 127 | records = load_split(split, Path(config.ann_root)) 128 | records = [convert_aokvqa_to_vqav2(_) for _ in records] 129 | logger.info("Converted %d records to VQAv2 format", len(records)) 130 | logger.info("Verifying images exist") 131 | for record in records: 132 | point_record_to_coco_image_file(record, Path(config.vqa_root), split) 133 | try: 134 | assert (IMAGES_ROOT / record.image).exists() 135 | except: 136 | import ipdb 137 | 138 | ipdb.set_trace() 139 | serialize_records(records, Path(config.ann_root), split) 140 | save_answer_list_as_json(Path(config.ann_root)) 141 | 142 | 143 | if __name__ == "__main__": 144 | args, config = cli.parse_args(default_config_path="./configs/aokvqa.yaml") 145 | main(config) 146 | -------------------------------------------------------------------------------- /convert_aqua.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from omegaconf import OmegaConf 4 | from dataclasses import dataclass 5 | import schemas 6 | from pydantic import BaseModel, validator 7 | from tqdm import tqdm 8 | import logging 9 | from typing import Tuple, List, Union 10 | 11 | AQUA_ROOT = Path("/net/acadia4a/data/zkhan/vqa-art") 12 | 13 | logger = logging.getLogger(__name__) 14 | handler = logging.StreamHandler() 15 | formatter = logging.Formatter( 16 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 17 | ) 18 | handler.setFormatter(formatter) 19 | logger.addHandler(handler) 20 | logger.setLevel(logging.INFO) 21 | 22 | 23 | def load_json(_path: Path): 24 | with open(_path, "r") as f: 25 | return json.load(f) 26 | 27 | 28 | def write_json(_path: Path, data: dict): 29 | with open(_path, "w") as f: 30 | json.dump(data, f) 31 | 32 | 33 | @dataclass 34 | class RawAnnotations: 35 | train: Path = AQUA_ROOT / "raw_annotations" / "train.json" 36 | val: Path = AQUA_ROOT / "raw_annotations" / "val.json" 37 | test: Path = AQUA_ROOT / "raw_annotations" / "test.json" 38 | 39 | 40 | @dataclass 41 | class OutputAnnotations: 42 | train: Path = AQUA_ROOT / "train.json" 43 | train_no_external_knowledge_questions: Path = ( 44 | AQUA_ROOT / "train_no_external_knowledge_questions.json" 45 | ) 46 | val: Path = AQUA_ROOT / "val.json" 47 | test: Path = AQUA_ROOT / "test.json" 
48 | val_annotations: Path = AQUA_ROOT / "val_annotations.json" 49 | val_questions: Path = AQUA_ROOT / "val_questions.json" 50 | test_annotations: Path = AQUA_ROOT / "test_annotations.json" 51 | test_questions: Path = AQUA_ROOT / "test_questions.json" 52 | answer_list: Path = AQUA_ROOT / "answer_list.json" 53 | 54 | 55 | @dataclass 56 | class Config: 57 | raw_annotations: RawAnnotations = RawAnnotations() 58 | output_annotations: OutputAnnotations = OutputAnnotations() 59 | semart_images_dir: Path = Path("/net/acadia4a/data/zkhan/SemArt/Images") 60 | 61 | 62 | class AquaRecord(BaseModel): 63 | image: str 64 | question: str 65 | answer: str 66 | need_external_knowledge: bool 67 | 68 | @validator("question") 69 | def normalize(cls, v): 70 | # The original question does not have a question mark at the end, which is different from 71 | # every other VQA dataset. We add the question mark here to make it consistent. 72 | return f"{v}?" 73 | 74 | 75 | def convert_aqua_record_to_train_record( 76 | record: AquaRecord, question_id: int 77 | ) -> schemas.TrainingRecord: 78 | return schemas.TrainingRecord( 79 | question=record.question, 80 | answer=[record.answer], 81 | image=record.image, 82 | dataset="aqua", 83 | question_id=question_id, 84 | ) 85 | 86 | 87 | def convert_aqua_record_to_test_record( 88 | record: AquaRecord, question_id: int 89 | ) -> schemas.TestingRecord: 90 | return schemas.TestingRecord( 91 | question=record.question, 92 | image=record.image, 93 | dataset="aqua", 94 | question_id=question_id, 95 | ) 96 | 97 | 98 | def convert_aqua_record_to_question_record( 99 | record: AquaRecord, question_id: int 100 | ) -> schemas.QuestionRecord: 101 | return schemas.QuestionRecord( 102 | question=record.question, 103 | image_id=record.image, 104 | question_id=question_id, 105 | ) 106 | 107 | 108 | def convert_aqua_record_to_annotation_record( 109 | record: AquaRecord, question_id: int 110 | ) -> schemas.VQAAnnotationRecord: 111 | return schemas.VQAAnnotationRecord( 112 | question_type="external knowledge" 113 | if record.need_external_knowledge 114 | else "no external knowledge", 115 | answers=[ 116 | schemas.VQAAnnotationSubRecord( 117 | answer=record.answer, answer_confidence="yes", answer_id=0 118 | ) 119 | ], 120 | image_id=record.image, 121 | answer_type="yes/no" if record.answer in ("yes", "no") else "other", 122 | question_id=question_id, 123 | ) 124 | 125 | 126 | def make_answer_list(records: List[AquaRecord]) -> List[str]: 127 | answers = set() 128 | for record in records: 129 | answers.add(record.answer) 130 | return list(answers) 131 | 132 | 133 | def verify_image_exists_for_record(image_name: str, semart_images_dir: Path) -> bool: 134 | return (semart_images_dir / image_name).exists() 135 | 136 | 137 | if __name__ == "__main__": 138 | conf: Config = OmegaConf.structured(Config) 139 | train_records_raw = [ 140 | AquaRecord.parse_obj(_) for _ in tqdm(load_json(conf.raw_annotations.train)) 141 | ] 142 | train_question_ids = [i for i in range(len(train_records_raw))] 143 | logger.info("Loaded %d raw train records", len(train_records_raw)) 144 | train_records = [ 145 | convert_aqua_record_to_train_record(_, i) 146 | for i, _ in enumerate(train_records_raw) 147 | ] 148 | logger.info( 149 | "Converted all %d raw train records to train_vqa format", len(train_records) 150 | ) 151 | 152 | train_records_no_external_knowledge_questions = [ 153 | convert_aqua_record_to_train_record(record, i) 154 | for i, record in zip(train_question_ids, train_records_raw) 155 | if not 
record.need_external_knowledge 156 | ] 157 | 158 | val_records_raw = [ 159 | AquaRecord.parse_obj(_) for _ in tqdm(load_json(conf.raw_annotations.val)) 160 | ] 161 | val_question_ids = [ 162 | i + train_question_ids[-1] + 1 for i in range(len(val_records_raw)) 163 | ] 164 | logger.info( 165 | "Loaded %d raw val records. Converting to train_vqa format", 166 | len(val_records_raw), 167 | ) 168 | val_records = [ 169 | convert_aqua_record_to_test_record(_, i) 170 | for i, _ in zip(val_question_ids, val_records_raw) 171 | ] 172 | 173 | test_records_raw = [ 174 | AquaRecord.parse_obj(_) for _ in tqdm(load_json(conf.raw_annotations.test)) 175 | ] 176 | logger.info( 177 | "Loaded %d raw test records. Converting to train_vqa format", 178 | len(test_records_raw), 179 | ) 180 | test_question_ids = [ 181 | i + val_question_ids[-1] + 1 for i in range(len(test_records_raw)) 182 | ] 183 | test_records = [ 184 | convert_aqua_record_to_test_record(_, i) 185 | for i, _ in zip(test_question_ids, test_records_raw) 186 | ] 187 | 188 | logger.info("Converting test/val records to vqa_eval_tools format") 189 | test_annotations = [ 190 | convert_aqua_record_to_annotation_record(_, i) 191 | for i, _ in tqdm(zip(test_question_ids, test_records_raw)) 192 | ] 193 | test_questions = [ 194 | convert_aqua_record_to_question_record(_, i) 195 | for i, _ in zip(test_question_ids, test_records_raw) 196 | ] 197 | val_annotations = [ 198 | convert_aqua_record_to_annotation_record(_, i) 199 | for i, _ in tqdm(zip(val_question_ids, val_records_raw)) 200 | ] 201 | val_questions = [ 202 | convert_aqua_record_to_question_record(_, i) 203 | for i, _ in zip(val_question_ids, val_records_raw) 204 | ] 205 | 206 | logger.info("Verifying that all images exist") 207 | for r in tqdm(set(_.image for _ in train_records + val_records + test_records)): 208 | assert verify_image_exists_for_record(r, conf.semart_images_dir) 209 | 210 | logger.info("Writing train/val/test records to disk") 211 | write_json(conf.output_annotations.train, [_.dict() for _ in train_records]) 212 | write_json(conf.output_annotations.val, [_.dict() for _ in val_records]) 213 | write_json(conf.output_annotations.test, [_.dict() for _ in test_records]) 214 | write_json( 215 | conf.output_annotations.train_no_external_knowledge_questions, 216 | [_.dict() for _ in train_records_no_external_knowledge_questions], 217 | ) 218 | 219 | logger.info("Writing test/val records to vqa_eval_tools format") 220 | write_json( 221 | conf.output_annotations.test_annotations, 222 | {"annotations": [_.dict() for _ in test_annotations]}, 223 | ) 224 | write_json( 225 | conf.output_annotations.test_questions, 226 | {"questions": [_.dict() for _ in test_questions]}, 227 | ) 228 | 229 | write_json( 230 | conf.output_annotations.val_annotations, 231 | {"annotations": [_.dict() for _ in val_annotations]}, 232 | ) 233 | write_json( 234 | conf.output_annotations.val_questions, 235 | {"questions": [_.dict() for _ in val_questions]}, 236 | ) 237 | 238 | logger.info("Making answer list") 239 | answer_list = make_answer_list(test_records_raw) 240 | logger.info("Made answer list with %d answers", len(answer_list)) 241 | write_json(conf.output_annotations.answer_list, answer_list) 242 | logger.info("Done") 243 | logger.info("Wrote all files to %s", AQUA_ROOT) 244 | -------------------------------------------------------------------------------- /convert_okvqa.py: -------------------------------------------------------------------------------- 1 | import cli 2 | import json 3 | from typing import List, 
Dict, Literal, Optional 4 | import cattrs 5 | from omegaconf import DictConfig 6 | from pathlib import Path 7 | import logging 8 | from pydantic import BaseModel 9 | import schemas 10 | from enum import Enum 11 | from tqdm import tqdm 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | handler = logging.StreamHandler() 16 | formatter = logging.Formatter( 17 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 18 | ) 19 | handler.setFormatter(formatter) 20 | logger.addHandler(handler) 21 | logger.setLevel(logging.INFO) 22 | 23 | 24 | IMAGES_ROOT = Path("/net/acadia10a/data/zkhan/coco2014") 25 | OKVQA_ROOT = Path("/net/acadia10a/data/zkhan/ok-vqa") 26 | 27 | 28 | class SplitType(Enum): 29 | train = "train" 30 | val = "val" 31 | 32 | 33 | class OkvqaSplit(BaseModel): 34 | split: SplitType 35 | annotations: List[schemas.VQAAnnotationRecord] 36 | questions: List[schemas.QuestionRecord] 37 | 38 | @classmethod 39 | def build_from_path( 40 | cls, split_name: str, annotation_path: Path, question_path: Path 41 | ): 42 | logger.info( 43 | f"Loading {split_name} split from {annotation_path} and {question_path}" 44 | ) 45 | with open(annotation_path, "r") as f: 46 | annotations = json.load(f)["annotations"] 47 | with open(question_path, "r") as f: 48 | questions = json.load(f)["questions"] 49 | logger.info( 50 | "Structuring %d annotations and %d questions", 51 | len(annotations), 52 | len(questions), 53 | ) 54 | annotations = [ 55 | schemas.VQAAnnotationRecord.parse_obj(a) for a in tqdm(annotations) 56 | ] 57 | questions = [schemas.QuestionRecord.parse_obj(q) for q in tqdm(questions)] 58 | assert len(annotations) == len(questions) 59 | return cls( 60 | split=SplitType(split_name), annotations=annotations, questions=questions 61 | ) 62 | 63 | 64 | # OKVQA only has training and validation splits. Both of them have annotations and questions. 65 | # This is in contrast to VQAv2, which has a test split with only questions. 66 | # So while we could make a single function to convert both splits to `TrainingRecord`s, we 67 | # keep them separate to be consistent with VQAv2, where the testing split is structured 68 | # as `TestingRecord`s. 
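# A minimal sketch (illustrative values only) of the two shapes this script writes,
# mirroring the fields used in make_train_record / make_val_record below. Each split
# is serialized as a single JSON list of such dictionaries:
#
#   training record:
#     {"dataset": "okvqa", "image": "train2014/COCO_train2014_000000000009.jpg",
#      "question": "What sport is shown?", "question_id": 1, "answer": ["skiing", ...]}
#
#   testing record (no answers; predictions are scored against the original
#   OK-VQA annotation files instead):
#     {"dataset": "okvqa", "image": "val2014/COCO_val2014_000000000042.jpg",
#      "question": "What sport is shown?", "question_id": 2}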
69 | 70 | 71 | def make_train_record( 72 | annotation: schemas.VQAAnnotationRecord, question: schemas.QuestionRecord 73 | ) -> schemas.TrainingRecord: 74 | assert annotation.question_id == question.question_id 75 | assert question.image_id == annotation.image_id 76 | return schemas.TrainingRecord( 77 | dataset="okvqa", 78 | image=annotation.image_id, 79 | question=question.question, 80 | question_id=question.question_id, 81 | answer=[_.answer for _ in annotation.answers], 82 | ) 83 | 84 | 85 | def make_val_record(record: schemas.QuestionRecord) -> schemas.TestingRecord: 86 | return schemas.TestingRecord( 87 | question_id=record.question_id, 88 | question=record.question, 89 | image=record.image_id, 90 | dataset="okvqa", 91 | ) 92 | 93 | 94 | def pad_coco_id(coco_id: int) -> str: 95 | return f"{coco_id:0>12}" 96 | 97 | 98 | def point_record_to_coco_image_file( 99 | record: schemas.TrainingRecord, split: SplitType 100 | ) -> schemas.TrainingRecord: 101 | coco_image_id = pad_coco_id(record.image) 102 | relative_path_to_image = ( 103 | f"{split.value}2014/COCO_{split.value}2014_{coco_image_id}.jpg" 104 | ) 105 | record.image = relative_path_to_image 106 | return record 107 | 108 | 109 | if __name__ == "__main__": 110 | train = OkvqaSplit.build_from_path( 111 | "train", 112 | annotation_path=OKVQA_ROOT / "mscoco_train2014_annotations.json", 113 | question_path=OKVQA_ROOT / "OpenEnded_mscoco_train2014_questions.json", 114 | ) 115 | 116 | val = OkvqaSplit.build_from_path( 117 | "val", 118 | annotation_path=OKVQA_ROOT / "mscoco_val2014_annotations.json", 119 | question_path=OKVQA_ROOT / "OpenEnded_mscoco_val2014_questions.json", 120 | ) 121 | logger.info( 122 | "Converting %d annotations and questions into training records", 123 | len(train.annotations), 124 | ) 125 | train_records = [ 126 | make_train_record(a, q) 127 | for a, q in tqdm(zip(train.annotations, train.questions)) 128 | ] 129 | 130 | logger.info("Verifying all %d images are present", len(train_records)) 131 | train_records = [ 132 | point_record_to_coco_image_file(r, train.split) for r in tqdm(train_records) 133 | ] 134 | for record in tqdm(train_records): 135 | assert (IMAGES_ROOT / record.image).exists() 136 | 137 | logger.info("Converting %d questions into validation records", len(val.questions)) 138 | val_records = [make_val_record(q) for q in tqdm(val.questions)] 139 | logger.info("Verifying all %d images are present", len(val_records)) 140 | val_records = [ 141 | point_record_to_coco_image_file(r, val.split) for r in tqdm(val_records) 142 | ] 143 | for record in tqdm(val_records): 144 | assert (IMAGES_ROOT / record.image).exists() 145 | 146 | # Use these for training and generating answers, but use the 147 | # original annotations for scoring the answers. 
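    # A sketch of that scoring step (run separately, e.g. via okvqa_eval.py), assuming
    # the bundled vqa_eval_tools package follows the standard VQA evaluation interface;
    # the result-file path is a placeholder:
    #
    #   from vqa_eval_tools import VQA, VQAEval
    #   questions = str(OKVQA_ROOT / "OpenEnded_mscoco_val2014_questions.json")
    #   vqa = VQA(str(OKVQA_ROOT / "mscoco_val2014_annotations.json"), questions)
    #   predictions = vqa.loadRes("path/to/result_file.json", questions)
    #   VQAEval(vqa, predictions, n=2).evaluate()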
148 | with open(OKVQA_ROOT / "train.json", "w") as f: 149 | json.dump([r.dict() for r in train_records], f) 150 | logger.info( 151 | "Wrote %d training records to %s", len(train_records), OKVQA_ROOT / "train.json" 152 | ) 153 | 154 | with open(OKVQA_ROOT / "val.json", "w") as f: 155 | json.dump([r.dict() for r in val_records], f) 156 | logger.info( 157 | "Wrote %d validation records to %s", len(val_records), OKVQA_ROOT / "val.json" 158 | ) 159 | -------------------------------------------------------------------------------- /convert_pathvqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from omegaconf import OmegaConf 4 | from dataclasses import dataclass 5 | import schemas 6 | from pydantic import BaseModel, validator 7 | from tqdm import tqdm 8 | import logging 9 | from typing import Tuple, List, Union, Dict 10 | 11 | PATHVQA_ROOT = Path("/net/acadia4a/data/zkhan/pathvqa") 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | handler = logging.StreamHandler() 16 | formatter = logging.Formatter( 17 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 18 | ) 19 | handler.setFormatter(formatter) 20 | logger.addHandler(handler) 21 | logger.setLevel(logging.INFO) 22 | 23 | 24 | def load_json(_path: Path): 25 | with open(_path, "r") as f: 26 | return json.load(f) 27 | 28 | 29 | def write_json(_path: Path, data: dict): 30 | with open(_path, "w") as f: 31 | json.dump(data, f) 32 | 33 | 34 | # For evaluating this dataset, use an exact match evaluation. 35 | # There's only one answer per question in PathVQA, so you 36 | # can't use the VQAv2 evaluation code. 37 | @dataclass 38 | class OutputAnnotations: 39 | train: Path = PATHVQA_ROOT / "train.json" 40 | val: Path = PATHVQA_ROOT / "val.json" 41 | test: Path = PATHVQA_ROOT / "test.json" 42 | answer_list: Path = PATHVQA_ROOT / "answer_list.json" 43 | # The combined test+val file below can be used to generate questions on images unused during training.
44 | test_val_combined: Path = PATHVQA_ROOT / "test_val_combined.json" 45 | 46 | 47 | @dataclass 48 | class Config: 49 | raw_annotations: Path = PATHVQA_ROOT / "all_data.json" 50 | output_annotations: OutputAnnotations = OutputAnnotations() 51 | pathvqa_images_dir: Path = Path("/net/acadia4a/data/zkhan/pathvqa/images") 52 | 53 | 54 | class PathVQARecord_SuffixQA(BaseModel): 55 | image: str 56 | question: str 57 | answer: str 58 | 59 | 60 | class PathVQARecord_SuffixVQA(BaseModel): 61 | answer_type: str 62 | img_id: str 63 | label: Dict[str, int] 64 | question_id: int 65 | question_type: str 66 | sent: str 67 | 68 | 69 | class PathVQADump(BaseModel): 70 | test_qa: List[PathVQARecord_SuffixQA] 71 | test_vqa: List[PathVQARecord_SuffixVQA] 72 | train_qa: List[PathVQARecord_SuffixQA] 73 | train_vqa: List[PathVQARecord_SuffixVQA] 74 | val_qa: List[PathVQARecord_SuffixQA] 75 | val_vqa: List[PathVQARecord_SuffixVQA] 76 | 77 | 78 | def convert_pathvqa_record_to_train_record( 79 | record_suffixqa: PathVQARecord_SuffixQA, record_suffixvqa: PathVQARecord_SuffixVQA 80 | ) -> schemas.TrainingRecord: 81 | return schemas.TrainingRecord( 82 | question=record_suffixqa.question, 83 | answer=[record_suffixqa.answer], 84 | image=record_suffixqa.image, 85 | dataset="pathvqa", 86 | question_id=record_suffixvqa.question_id, 87 | ) 88 | 89 | 90 | def convert_pathvqa_record_to_evaluation_record( 91 | record_suffixqa: PathVQARecord_SuffixQA, record_suffixvqa: PathVQARecord_SuffixVQA 92 | ) -> schemas.MinimalEvaluationRecord: 93 | return schemas.MinimalEvaluationRecord( 94 | question=record_suffixqa.question, 95 | answer=record_suffixqa.answer, 96 | image=record_suffixqa.image, 97 | dataset="pathvqa", 98 | question_id=record_suffixvqa.question_id, 99 | question_type=record_suffixvqa.question_type, 100 | answer_type=record_suffixvqa.answer_type, 101 | ) 102 | 103 | 104 | def make_training_records(pathvqa_dump: PathVQADump) -> List[schemas.TrainingRecord]: 105 | training_records = [] 106 | for record_suffixqa, record_suffixvqa in zip( 107 | pathvqa_dump.train_qa, pathvqa_dump.train_vqa 108 | ): 109 | training_records.append( 110 | convert_pathvqa_record_to_train_record(record_suffixqa, record_suffixvqa) 111 | ) 112 | return training_records 113 | 114 | 115 | def make_validation_records( 116 | pathvqa_dump: PathVQADump, 117 | ) -> List[schemas.MinimalEvaluationRecord]: 118 | validation_records = [] 119 | for record_suffixqa, record_suffixvqa in zip( 120 | pathvqa_dump.val_qa, pathvqa_dump.val_vqa 121 | ): 122 | validation_records.append( 123 | convert_pathvqa_record_to_evaluation_record( 124 | record_suffixqa, record_suffixvqa 125 | ) 126 | ) 127 | return validation_records 128 | 129 | 130 | def make_testing_records( 131 | pathvqa_dump: PathVQADump, 132 | ) -> List[schemas.MinimalEvaluationRecord]: 133 | testing_records = [] 134 | for record_suffixqa, record_suffixvqa in zip( 135 | pathvqa_dump.test_qa, pathvqa_dump.test_vqa 136 | ): 137 | testing_records.append( 138 | convert_pathvqa_record_to_evaluation_record( 139 | record_suffixqa, record_suffixvqa 140 | ) 141 | ) 142 | return testing_records 143 | 144 | 145 | def redirect_image_and_verify( 146 | image_dir: Path, 147 | record: Union[schemas.TrainingRecord, schemas.MinimalEvaluationRecord], 148 | ) -> None: 149 | image_name = record.image 150 | split, *_ = image_name.split("_") 151 | 152 | path = f"{split}/{image_name}.jpg" 153 | 154 | record.image = path 155 | try: 156 | assert (image_dir / record.image).exists(), f"Image {path} does not exist" 157 | except 
AssertionError: 158 | import ipdb 159 | 160 | ipdb.set_trace() 161 | 162 | 163 | def make_answer_list(test_records: List[schemas.MinimalEvaluationRecord]) -> List[str]: 164 | answer_list = [] 165 | for record in test_records: 166 | answer_list.append(record.answer) 167 | return list(set(answer_list)) 168 | 169 | 170 | def filter_to_unique_images( 171 | records: List[schemas.TrainingRecord], 172 | ) -> List[schemas.TrainingRecord]: 173 | unique_images = set() 174 | filtered_records = [] 175 | for record in records: 176 | if record.image not in unique_images: 177 | unique_images.add(record.image) 178 | filtered_records.append(record) 179 | return filtered_records 180 | 181 | 182 | if __name__ == "__main__": 183 | conf: Config = OmegaConf.structured(Config) 184 | pathvqa_dump = PathVQADump.parse_obj(load_json(conf.raw_annotations)) 185 | training_records = make_training_records(pathvqa_dump) 186 | logger.info("Made %d training records", len(training_records)) 187 | validation_records = make_validation_records(pathvqa_dump) 188 | logger.info("Made %d validation records", len(validation_records)) 189 | testing_records = make_testing_records(pathvqa_dump) 190 | logger.info("Made %d testing records", len(testing_records)) 191 | logger.info("Verifying all records") 192 | for record in tqdm(training_records + validation_records + testing_records): 193 | redirect_image_and_verify(conf.pathvqa_images_dir, record) 194 | 195 | answer_list = make_answer_list(testing_records) 196 | logger.info("Made answer list with %d answers", len(answer_list)) 197 | 198 | write_json( 199 | conf.output_annotations.train, [record.dict() for record in training_records] 200 | ) 201 | write_json( 202 | conf.output_annotations.val, [record.dict() for record in validation_records] 203 | ) 204 | write_json( 205 | conf.output_annotations.test, [record.dict() for record in testing_records] 206 | ) 207 | write_json(conf.output_annotations.answer_list, answer_list) 208 | write_json( 209 | conf.output_annotations.test_val_combined, 210 | [ 211 | record.dict() 212 | for record in filter_to_unique_images(testing_records + validation_records) 213 | ], 214 | ) 215 | -------------------------------------------------------------------------------- /convert_rsvqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from omegaconf import OmegaConf 4 | from dataclasses import dataclass 5 | import schemas 6 | from pydantic import BaseModel, validator, ValidationError 7 | from tqdm import tqdm 8 | import logging 9 | from typing import Tuple, List, Union, Dict, Optional 10 | 11 | # Remote Sensing VQA is made up of two datasets. One is 12 | # low resolution, the other is high resolution. They have different 13 | # numbers of images, and probably different types and numbers of questions. 14 | # We handle each separately. 
15 | RSVQA_LR_ROOT = Path("/net/acadia4a/data/zkhan/rsvqa/low_resolution") 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | handler = logging.StreamHandler() 20 | formatter = logging.Formatter( 21 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 22 | ) 23 | handler.setFormatter(formatter) 24 | logger.addHandler(handler) 25 | logger.setLevel(logging.INFO) 26 | 27 | 28 | def load_json(_path: Path): 29 | with open(_path, "r") as f: 30 | return json.load(f) 31 | 32 | 33 | def write_json(_path: Path, data: dict): 34 | with open(_path, "w") as f: 35 | json.dump(data, f) 36 | 37 | 38 | @dataclass 39 | class LowResOutputAnnotations: 40 | train: Path = RSVQA_LR_ROOT / "train.json" 41 | val: Path = RSVQA_LR_ROOT / "val.json" 42 | test: Path = RSVQA_LR_ROOT / "test.json" 43 | # Write the MinimalEvaluationRecord (s) to this file for the scoring. 44 | test_annotations: Path = RSVQA_LR_ROOT / "test_annotations.json" 45 | answer_list: Path = RSVQA_LR_ROOT / "answer_list.json" 46 | 47 | 48 | @dataclass 49 | class LowResRawAnnotations: 50 | all_answers: Path = RSVQA_LR_ROOT / "all_answers.json" 51 | all_questions: Path = RSVQA_LR_ROOT / "all_questions.json" 52 | train_answers: Path = RSVQA_LR_ROOT / "LR_split_train_answers.json" 53 | train_questions: Path = RSVQA_LR_ROOT / "LR_split_train_questions.json" 54 | 55 | 56 | @dataclass 57 | class LowResConfig: 58 | raw_annotations: LowResRawAnnotations = LowResRawAnnotations() 59 | output_annotations: LowResOutputAnnotations = LowResOutputAnnotations() 60 | # The images are named like 0.tif, 1.tif, etc. 61 | low_res_images_dir: Path = Path( 62 | "/net/acadia4a/data/zkhan/rsvqa/low_resolution/Images_LR" 63 | ) 64 | 65 | 66 | class LRAnswerRecord(BaseModel): 67 | id: int 68 | date_added: float 69 | question_id: int 70 | people_id: int 71 | answer: str 72 | active: bool 73 | 74 | 75 | class LRQuestionRecord(BaseModel): 76 | id: int 77 | date_added: float 78 | img_id: int 79 | people_id: int 80 | type: str 81 | question: str 82 | answers_ids: List[int] 83 | active: bool 84 | 85 | 86 | class RsVqaLrDump(BaseModel): 87 | all_answers: List[LRAnswerRecord] 88 | all_questions: List[LRQuestionRecord] 89 | # The train answers and questions are the same length as all_questions and 90 | # all_answers, but have null fields for questions which are not in the train 91 | # set. That is how we identify what is in the training split and what is in the 92 | # test split. 
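    # Illustratively (field values below are made up; only the null "answer" matters):
    #   {"id": 7, "question_id": 7, "answer": "yes", ...}  -> question 7 is in train
    #   {"id": 8, "question_id": 8, "answer": null, ...}   -> question 8 is test-only
    # make_splits() below relies on exactly this check: a.get("answer") is None.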
93 | train_answers: List[Dict] 94 | train_questions: List[Dict] 95 | 96 | @classmethod 97 | def from_cfg(cls, cfg: LowResConfig): 98 | return cls( 99 | all_answers=load_json(cfg.raw_annotations.all_answers)["answers"], 100 | all_questions=load_json(cfg.raw_annotations.all_questions)["questions"], 101 | train_answers=load_json(cfg.raw_annotations.train_answers)["answers"], 102 | train_questions=load_json(cfg.raw_annotations.train_questions)["questions"], 103 | ) 104 | 105 | 106 | class RsVqaLrSplit(BaseModel): 107 | answers: List[LRAnswerRecord] 108 | questions: List[LRQuestionRecord] 109 | 110 | 111 | class StandardFormatSplit(BaseModel): 112 | training_records: List[schemas.TrainingRecord] 113 | testing_records: List[schemas.MinimalEvaluationRecord] 114 | 115 | 116 | def make_splits(rsvqa_lr_dump: RsVqaLrDump) -> Tuple[RsVqaLrSplit]: 117 | logger.info("Making splits...") 118 | indices = [i for i in range(len(rsvqa_lr_dump.all_answers))] 119 | is_test = [] 120 | for a in rsvqa_lr_dump.train_answers: 121 | is_test.append(a.get("answer") is None) 122 | 123 | train_indices = [i for i, test in zip(indices, is_test) if not test] 124 | test_indices = [i for i, test in zip(indices, is_test) if test] 125 | 126 | logger.info("Number of train questions: %d", len(train_indices)) 127 | logger.info("Number of test questions: %d", len(test_indices)) 128 | 129 | # Pull the records belonging to the train split from all_answers and all_questions. 130 | train_answers = [rsvqa_lr_dump.all_answers[i] for i in train_indices] 131 | train_questions = [rsvqa_lr_dump.all_questions[i] for i in train_indices] 132 | 133 | # Pull the records belonging to the test split from all_answers and all_questions. 134 | test_answers = [rsvqa_lr_dump.all_answers[i] for i in test_indices] 135 | test_questions = [rsvqa_lr_dump.all_questions[i] for i in test_indices] 136 | 137 | train_split = RsVqaLrSplit(answers=train_answers, questions=train_questions) 138 | test_split = RsVqaLrSplit(answers=test_answers, questions=test_questions) 139 | 140 | # Sanity check that each answer belongs to each question. 141 | logger.info("Sanity checking each answer and question are paired correctly") 142 | for a, q in zip(train_split.answers, train_split.questions): 143 | assert a.answer == rsvqa_lr_dump.all_answers[q.answers_ids[0]].answer 144 | 145 | for a, q in zip(test_split.answers, test_split.questions): 146 | assert a.answer == rsvqa_lr_dump.all_answers[q.answers_ids[0]].answer 147 | # Check to make sure each question only has one answer. In other VQA datasets, 148 | # questions can have multiple answers. From my brief perusal, RSVQA does not, but 149 | # let's double check that assumption. 
150 | for q in rsvqa_lr_dump.all_questions: 151 | assert len(q.answers_ids) == 1 152 | logger.info("Sanity check passed") 153 | 154 | return train_split, test_split 155 | 156 | 157 | def convert_rsvqa_lr_record_tuple_to_vqa_format( 158 | question: LRQuestionRecord, answer: LRAnswerRecord 159 | ) -> Tuple[schemas.TrainingRecord, schemas.MinimalEvaluationRecord]: 160 | training_record = schemas.TrainingRecord( 161 | dataset="rsvqa_lr", 162 | image=question.img_id, 163 | question_id=question.id, 164 | question=question.question, 165 | answer=[answer.answer], 166 | ) 167 | 168 | try: 169 | evaluation_record = schemas.MinimalEvaluationRecord( 170 | question=question.question, 171 | dataset="rsvqa_lr", 172 | image=question.img_id, 173 | question_id=question.id, 174 | answer=answer.answer, 175 | question_type=question.type, 176 | answer_type="default", 177 | ) 178 | except ValidationError: 179 | import ipdb 180 | 181 | ipdb.set_trace() 182 | 183 | return training_record, evaluation_record 184 | 185 | 186 | def transform_split_into_std_format(split: RsVqaLrSplit) -> StandardFormatSplit: 187 | training_records = [] 188 | evaluation_records = [] 189 | for q, a in zip(split.questions, split.answers): 190 | ( 191 | training_record, 192 | evaluation_record, 193 | ) = convert_rsvqa_lr_record_tuple_to_vqa_format(q, a) 194 | training_records.append(training_record) 195 | evaluation_records.append(evaluation_record) 196 | 197 | return StandardFormatSplit( 198 | training_records=training_records, testing_records=evaluation_records 199 | ) 200 | 201 | 202 | def redirect_and_verify_image(record: schemas.TrainingRecord, cfg: LowResConfig): 203 | image_dir = cfg.low_res_images_dir 204 | 205 | record.image = f"{record.image}.tif" 206 | assert (image_dir / record.image).exists() 207 | 208 | 209 | def make_answer_list(records: List[schemas.TrainingRecord]) -> List[str]: 210 | answers = [] 211 | for record in records: 212 | answers.extend(record.answer) 213 | return list(set(answers)) 214 | 215 | 216 | if __name__ == "__main__": 217 | lr_config: LowResConfig = OmegaConf.structured(LowResConfig) 218 | lr_dump: RsVqaLrDump = RsVqaLrDump.from_cfg(lr_config) 219 | 220 | logger.info( 221 | "Loaded %d answers and %d questions", 222 | len(lr_dump.all_answers), 223 | len(lr_dump.all_questions), 224 | ) 225 | logger.info("Training questions: %d", len(lr_dump.train_questions)) 226 | logger.info("Training answers: %d", len(lr_dump.train_answers)) 227 | 228 | lr_train, lr_test = make_splits(lr_dump) 229 | 230 | logger.info("Converting splits into standard format") 231 | lr_train = transform_split_into_std_format(lr_train) 232 | lr_test = transform_split_into_std_format(lr_test) 233 | 234 | logger.info("Train split in standard format: %d", len(lr_train.training_records)) 235 | logger.info("Test split in standard format: %d", len(lr_test.testing_records)) 236 | 237 | logger.info("Verifying images in training split exist") 238 | for record in tqdm(lr_train.training_records): 239 | redirect_and_verify_image(record, lr_config) 240 | 241 | logger.info("Verifying images in test split exist") 242 | for record in tqdm(lr_test.training_records): 243 | redirect_and_verify_image(record, lr_config) 244 | for record in tqdm(lr_test.testing_records): 245 | redirect_and_verify_image(record, lr_config) 246 | 247 | logger.info("Saving splits to disk") 248 | write_json( 249 | lr_config.output_annotations.train, 250 | [_.dict() for _ in lr_train.training_records], 251 | ) 252 | write_json( 253 | lr_config.output_annotations.test, [_.dict() for _ 
in lr_test.training_records] 254 | ) 255 | write_json( 256 | lr_config.output_annotations.test_annotations, 257 | [_.dict() for _ in lr_test.testing_records], 258 | ) 259 | 260 | logger.info("Making answer list") 261 | answer_list = make_answer_list(lr_test.training_records) 262 | logger.info("Answer list has %d answers", len(answer_list)) 263 | write_json(lr_config.output_annotations.answer_list, answer_list) 264 | -------------------------------------------------------------------------------- /convert_vqa_ce.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script constructs the VQA-CE dataset from the VQAv2 dataset. 3 | The VQA-CE dataset is a slice of the VQAv2 validation set which was 4 | constructed so that models which have learned shortcuts to answer 5 | questions will perform poorly. The dataset does not require retraining 6 | a model, and is evaluation only. 7 | """ 8 | 9 | import json 10 | from pathlib import Path 11 | from typing import List 12 | from pydantic import BaseModel 13 | import schemas 14 | from tqdm import tqdm 15 | import logging 16 | import shutil 17 | 18 | logger = logging.getLogger(__name__) 19 | handler = logging.StreamHandler() 20 | formatter = logging.Formatter( 21 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 22 | ) 23 | handler.setFormatter(formatter) 24 | logger.addHandler(handler) 25 | logger.setLevel(logging.INFO) 26 | 27 | 28 | def load_json(path): 29 | with open(path, "r") as f: 30 | return json.load(f) 31 | 32 | 33 | VQA_V2_ROOT = Path("/net/acadia10a/data/zkhan/vqav2_annotations") 34 | PATH_RAW_ANNOTATIONS = VQA_V2_ROOT / "raw_val_annotations_coco_val2014.json" 35 | PATH_RAW_QUESTIONS = VQA_V2_ROOT / "raw_val_questions_coco_val2014.json" 36 | PATH_VAL_RECORDS_WHOLE = VQA_V2_ROOT / "vqa_val.json" 37 | 38 | VQA_CE_ROOT = Path("/net/acadia4a/data/zkhan/vqa-counterexamples") 39 | HARD_SLICE_PATH = VQA_CE_ROOT / "hard.json" 40 | COUNTEREXAMPLE_SLICE_PATH = VQA_CE_ROOT / "counterexamples.json" 41 | 42 | 43 | class Slice(BaseModel): 44 | questions: List[schemas.QuestionRecord] 45 | annotations: List[schemas.VQAAnnotationRecord] 46 | testing_records: List[schemas.TestingRecord] 47 | 48 | @classmethod 49 | def build_from_whole_slice_given_question_ids( 50 | cls, qids: List[int], whole_slice: "Slice" 51 | ) -> "Slice": 52 | qids = set(qids) 53 | questions = [q for q in whole_slice.questions if q.question_id in qids] 54 | annotations = [a for a in whole_slice.annotations if a.question_id in qids] 55 | testing_records = [ 56 | t for t in whole_slice.testing_records if t.question_id in qids 57 | ] 58 | # import ipdb  # leftover debugging breakpoint, disabled so the script runs unattended 59 | 60 | # ipdb.set_trace() 61 | # One question ID appears in the questions/annotations but not in the testing records. 62 | # Rather than track it down, we drop it from the questions and annotations below. 63 | try: 64 | assert len(questions) == len(annotations) == len(testing_records) 65 | except AssertionError: 66 | logger.warning( 67 | f"Number of questions, annotations, and testing records do not match."
68 | ) 69 | logger.warning(f"Questions: {len(questions)}") 70 | logger.warning(f"Annotations: {len(annotations)}") 71 | logger.warning(f"Testing records: {len(testing_records)}") 72 | qids_missing_from_testing_records = set( 73 | [q.question_id for q in questions] 74 | ) - set([t.question_id for t in testing_records]) 75 | logger.info( 76 | f"Dropping question IDs missing from testing records: {qids_missing_from_testing_records}" 77 | ) 78 | questions = [ 79 | q 80 | for q in questions 81 | if q.question_id not in qids_missing_from_testing_records 82 | ] 83 | annotations = [ 84 | a 85 | for a in annotations 86 | if a.question_id not in qids_missing_from_testing_records 87 | ] 88 | return cls( 89 | questions=questions, 90 | annotations=annotations, 91 | testing_records=testing_records, 92 | ) 93 | 94 | def serialize_slice(self, prefix: str): 95 | with open(VQA_CE_ROOT / f"{prefix}_questions.json", "w") as f: 96 | json.dump({"questions": [_.dict() for _ in tqdm(self.questions)]}, f) 97 | 98 | with open(VQA_CE_ROOT / f"{prefix}_annotations.json", "w") as f: 99 | json.dump({"annotations": [_.dict() for _ in tqdm(self.annotations)]}, f) 100 | 101 | with open(VQA_CE_ROOT / f"{prefix}_testing_records.json", "w") as f: 102 | json.dump([_.dict() for _ in tqdm(self.testing_records)], f) 103 | 104 | 105 | if __name__ == "__main__": 106 | logger.info("Loading raw annotations and questions.") 107 | raw_annotations = load_json(PATH_RAW_ANNOTATIONS) 108 | raw_questions = load_json(PATH_RAW_QUESTIONS) 109 | val_records_whole = load_json(PATH_VAL_RECORDS_WHOLE) 110 | logger.info( 111 | "Loaded %d raw annotations and %d raw questions.", 112 | len(raw_annotations), 113 | len(raw_questions), 114 | ) 115 | 116 | logger.info("Building easy slice.") 117 | hard_slice_qids = load_json(HARD_SLICE_PATH) 118 | counterexample_slice_qids = load_json(COUNTEREXAMPLE_SLICE_PATH) 119 | all_question_ids = set([q["question_id"] for q in raw_questions["questions"]]) 120 | easy_slice = all_question_ids - set(counterexample_slice_qids).union( 121 | set(hard_slice_qids) 122 | ) 123 | 124 | logger.info("Building whole slice.") 125 | whole_slice = Slice( 126 | questions=[ 127 | schemas.QuestionRecord.parse_obj(q) 128 | for q in tqdm(raw_questions["questions"]) 129 | ], 130 | annotations=[ 131 | schemas.VQAAnnotationRecord.parse_obj(a) 132 | for a in tqdm(raw_annotations["annotations"]) 133 | ], 134 | testing_records=[ 135 | schemas.TestingRecord.parse_obj(t) for t in tqdm(val_records_whole) 136 | ], 137 | ) 138 | 139 | hard_slice = Slice.build_from_whole_slice_given_question_ids( 140 | hard_slice_qids, whole_slice 141 | ) 142 | logger.info("Built hard slice with %d questions.", len(hard_slice.questions)) 143 | counterexample_slice = Slice.build_from_whole_slice_given_question_ids( 144 | counterexample_slice_qids, whole_slice 145 | ) 146 | logger.info( 147 | "Built counterexample slice with %d questions.", 148 | len(counterexample_slice.questions), 149 | ) 150 | easy_slice = Slice.build_from_whole_slice_given_question_ids( 151 | easy_slice, whole_slice 152 | ) 153 | logger.info("Built easy slice with %d questions.", len(easy_slice.questions)) 154 | 155 | logger.info("Serializing slices.") 156 | hard_slice.serialize_slice("hard") 157 | counterexample_slice.serialize_slice("counterexamples") 158 | easy_slice.serialize_slice("easy") 159 | 160 | logger.info("Copying answer list from VQAv2 to VQA-CE.") 161 | shutil.copyfile(VQA_V2_ROOT / "answer_list.json", VQA_CE_ROOT / "answer_list.json") 162 | 163 | # For ease of use, we copy 
counterexamples_testing_records.json to val.json, because 164 | # as of now, all the datasets are hardcoded to load the validation split from a file 165 | # called val.json. 166 | logger.info("Copying counterexamples_testing_records.json to val.json.") 167 | shutil.copyfile( 168 | VQA_CE_ROOT / "counterexamples_testing_records.json", VQA_CE_ROOT / "val.json" 169 | ) 170 | -------------------------------------------------------------------------------- /convert_vqa_rephrasings.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts the VQA-Rephrasings dataset into a format usable by `train_vqa.py`. 3 | VQA-Rephrasings is a test-only dataset, which supplies rephrasings of the questions 4 | in the VQAv2 validation set (the VQAv2 test set is not publicly available). The images 5 | come from the COCO14 validation set. 6 | """ 7 | 8 | 9 | import json 10 | from typing import List, Dict, Literal, Optional 11 | from omegaconf import DictConfig 12 | from pathlib import Path 13 | import logging 14 | import schemas 15 | from enum import Enum 16 | from tqdm import tqdm 17 | import shutil 18 | 19 | logger = logging.getLogger(__name__) 20 | handler = logging.StreamHandler() 21 | formatter = logging.Formatter( 22 | "%(asctime)s - %(name)s: %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 23 | ) 24 | handler.setFormatter(formatter) 25 | logger.addHandler(handler) 26 | logger.setLevel(logging.INFO) 27 | 28 | 29 | IMAGES_ROOT = Path("/net/acadia10a/data/zkhan/coco2014") 30 | VQA_REPHRASINGS_ROOT = Path("/net/acadia4a/data/zkhan/vqa-rephrasings") 31 | VQA_V2_ANSWER_LIST = "/net/acadia10a/data/zkhan/vqav2_annotations/answer_list.json" 32 | 33 | 34 | def load_json(path: Path) -> Dict: 35 | with open(path, "r") as f: 36 | return json.load(f) 37 | 38 | 39 | def make_val_record(record: schemas.QuestionRecord) -> schemas.TestingRecord: 40 | return schemas.TestingRecord( 41 | question_id=record.question_id, 42 | question=record.question, 43 | image=record.image_id, 44 | dataset="vqa-rephrasings", 45 | ) 46 | 47 | 48 | def pad_coco_id(coco_id: int) -> str: 49 | return f"{coco_id:0>12}" 50 | 51 | 52 | def point_record_to_coco_image_file( 53 | record: schemas.TrainingRecord, split: str 54 | ) -> schemas.TrainingRecord: 55 | coco_image_id = pad_coco_id(record.image) 56 | relative_path_to_image = f"{split}2014/COCO_{split}2014_{coco_image_id}.jpg" 57 | record.image = relative_path_to_image 58 | return record 59 | 60 | 61 | if __name__ == "__main__": 62 | val_questions_raw = load_json( 63 | VQA_REPHRASINGS_ROOT / "v2_OpenEnded_mscoco_valrep2014_humans_og_questions.json" 64 | ) 65 | 66 | logger.info("Converting %d records", len(val_questions_raw["questions"])) 67 | 68 | val_questions = [ 69 | make_val_record(schemas.QuestionRecord.parse_obj(_)) 70 | for _ in val_questions_raw["questions"] 71 | ] 72 | 73 | logger.info("Verifying all image paths are correct.") 74 | val_records = [ 75 | point_record_to_coco_image_file(r, "val") for r in tqdm(val_questions) 76 | ] 77 | for record in tqdm(val_records): 78 | assert (IMAGES_ROOT / record.image).exists() 79 | 80 | with open(VQA_REPHRASINGS_ROOT / "val.json", "w") as f: 81 | json.dump([r.dict() for r in val_records], f) 82 | logger.info( 83 | "Wrote %d validation records to %s", 84 | len(val_records), 85 | VQA_REPHRASINGS_ROOT / "val.json", 86 | ) 87 | shutil.copyfile(VQA_V2_ANSWER_LIST, VQA_REPHRASINGS_ROOT / "answer_list.json") 88 | 89 | # Make a fake training file just so `train_vqa.py` doesn't
complain. 90 | fake_training_record = schemas.TrainingRecord( 91 | question_id=0, 92 | question="fake question", 93 | image="/not/a/real/path.jpg", 94 | answer=["fake answer"], 95 | dataset="vqa-rephrasings", 96 | ) 97 | with open(VQA_REPHRASINGS_ROOT / "train.json", "w") as f: 98 | json.dump([fake_training_record.dict()], f) 99 | -------------------------------------------------------------------------------- /data/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | 12 | class coco_karpathy_train(Dataset): 13 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=""): 14 | """ 15 | image_root (string): Root directory of images (e.g. coco/images/) 16 | ann_root (string): directory to store the annotation file 17 | """ 18 | url = "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json" 19 | filename = "coco_karpathy_train.json" 20 | 21 | download_url(url, ann_root) 22 | 23 | self.annotation = json.load(open(os.path.join(ann_root, filename), "r")) 24 | self.transform = transform 25 | self.image_root = image_root 26 | self.max_words = max_words 27 | self.prompt = prompt 28 | 29 | self.img_ids = {} 30 | n = 0 31 | for ann in self.annotation: 32 | img_id = ann["image_id"] 33 | if img_id not in self.img_ids.keys(): 34 | self.img_ids[img_id] = n 35 | n += 1 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | image_path = os.path.join(self.image_root, ann["image"]) 45 | image = Image.open(image_path).convert("RGB") 46 | image = self.transform(image) 47 | 48 | caption = self.prompt + pre_caption(ann["caption"], self.max_words) 49 | 50 | return image, caption, self.img_ids[ann["image_id"]] 51 | 52 | 53 | class coco_karpathy_caption_eval(Dataset): 54 | def __init__(self, transform, image_root, ann_root, split): 55 | """ 56 | image_root (string): Root directory of images (e.g. coco/images/) 57 | ann_root (string): directory to store the annotation file 58 | split (string): val or test 59 | """ 60 | urls = { 61 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json", 62 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json", 63 | } 64 | filenames = {"val": "coco_karpathy_val.json", "test": "coco_karpathy_test.json"} 65 | 66 | download_url(urls[split], ann_root) 67 | 68 | self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r")) 69 | self.transform = transform 70 | self.image_root = image_root 71 | 72 | def __len__(self): 73 | return len(self.annotation) 74 | 75 | def __getitem__(self, index): 76 | 77 | ann = self.annotation[index] 78 | 79 | image_path = os.path.join(self.image_root, ann["image"]) 80 | image = Image.open(image_path).convert("RGB") 81 | image = self.transform(image) 82 | 83 | img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1] 84 | 85 | return image, int(img_id) 86 | 87 | 88 | class coco_karpathy_retrieval_eval(Dataset): 89 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 90 | """ 91 | image_root (string): Root directory of images (e.g. 
coco/images/) 92 | ann_root (string): directory to store the annotation file 93 | split (string): val or test 94 | """ 95 | urls = { 96 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json", 97 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json", 98 | } 99 | filenames = {"val": "coco_karpathy_val.json", "test": "coco_karpathy_test.json"} 100 | 101 | download_url(urls[split], ann_root) 102 | 103 | self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r")) 104 | self.transform = transform 105 | self.image_root = image_root 106 | 107 | self.text = [] 108 | self.image = [] 109 | self.txt2img = {} 110 | self.img2txt = {} 111 | 112 | txt_id = 0 113 | for img_id, ann in enumerate(self.annotation): 114 | self.image.append(ann["image"]) 115 | self.img2txt[img_id] = [] 116 | for i, caption in enumerate(ann["caption"]): 117 | self.text.append(pre_caption(caption, max_words)) 118 | self.img2txt[img_id].append(txt_id) 119 | self.txt2img[txt_id] = img_id 120 | txt_id += 1 121 | 122 | def __len__(self): 123 | return len(self.annotation) 124 | 125 | def __getitem__(self, index): 126 | 127 | image_path = os.path.join(self.image_root, self.annotation[index]["image"]) 128 | image = Image.open(image_path).convert("RGB") 129 | image = self.transform(image) 130 | 131 | return image, index 132 | -------------------------------------------------------------------------------- /data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | 12 | class flickr30k_train(Dataset): 13 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=""): 14 | """ 15 | image_root (string): Root directory of images (e.g. flickr30k/) 16 | ann_root (string): directory to store the annotation file 17 | """ 18 | url = "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json" 19 | filename = "flickr30k_train.json" 20 | 21 | download_url(url, ann_root) 22 | 23 | self.annotation = json.load(open(os.path.join(ann_root, filename), "r")) 24 | self.transform = transform 25 | self.image_root = image_root 26 | self.max_words = max_words 27 | self.prompt = prompt 28 | 29 | self.img_ids = {} 30 | n = 0 31 | for ann in self.annotation: 32 | img_id = ann["image_id"] 33 | if img_id not in self.img_ids.keys(): 34 | self.img_ids[img_id] = n 35 | n += 1 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | image_path = os.path.join(self.image_root, ann["image"]) 45 | image = Image.open(image_path).convert("RGB") 46 | image = self.transform(image) 47 | 48 | caption = self.prompt + pre_caption(ann["caption"], self.max_words) 49 | 50 | return image, caption, self.img_ids[ann["image_id"]] 51 | 52 | 53 | class flickr30k_retrieval_eval(Dataset): 54 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 55 | """ 56 | image_root (string): Root directory of images (e.g. 
flickr30k/) 57 | ann_root (string): directory to store the annotation file 58 | split (string): val or test 59 | """ 60 | urls = { 61 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json", 62 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json", 63 | } 64 | filenames = {"val": "flickr30k_val.json", "test": "flickr30k_test.json"} 65 | 66 | download_url(urls[split], ann_root) 67 | 68 | self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r")) 69 | self.transform = transform 70 | self.image_root = image_root 71 | 72 | self.text = [] 73 | self.image = [] 74 | self.txt2img = {} 75 | self.img2txt = {} 76 | 77 | txt_id = 0 78 | for img_id, ann in enumerate(self.annotation): 79 | self.image.append(ann["image"]) 80 | self.img2txt[img_id] = [] 81 | for i, caption in enumerate(ann["caption"]): 82 | self.text.append(pre_caption(caption, max_words)) 83 | self.img2txt[img_id].append(txt_id) 84 | self.txt2img[txt_id] = img_id 85 | txt_id += 1 86 | 87 | def __len__(self): 88 | return len(self.annotation) 89 | 90 | def __getitem__(self, index): 91 | 92 | image_path = os.path.join(self.image_root, self.annotation[index]["image"]) 93 | image = Image.open(image_path).convert("RGB") 94 | image = self.transform(image) 95 | 96 | return image, index 97 | -------------------------------------------------------------------------------- /data/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.utils import download_url 7 | 8 | from PIL import Image 9 | 10 | from data.utils import pre_caption 11 | 12 | 13 | class nlvr_dataset(Dataset): 14 | def __init__(self, transform, image_root, ann_root, split): 15 | """ 16 | image_root (string): Root directory of images 17 | ann_root (string): directory to store the annotation file 18 | split (string): train, val or test 19 | """ 20 | urls = { 21 | "train": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json", 22 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json", 23 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json", 24 | } 25 | filenames = { 26 | "train": "nlvr_train.json", 27 | "val": "nlvr_dev.json", 28 | "test": "nlvr_test.json", 29 | } 30 | 31 | download_url(urls[split], ann_root) 32 | self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r")) 33 | 34 | self.transform = transform 35 | self.image_root = image_root 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | image0_path = os.path.join(self.image_root, ann["images"][0]) 45 | image0 = Image.open(image0_path).convert("RGB") 46 | image0 = self.transform(image0) 47 | 48 | image1_path = os.path.join(self.image_root, ann["images"][1]) 49 | image1 = Image.open(image1_path).convert("RGB") 50 | image1 = self.transform(image1) 51 | 52 | sentence = pre_caption(ann["sentence"], 40) 53 | 54 | if ann["label"] == "True": 55 | label = 1 56 | else: 57 | label = 0 58 | 59 | words = sentence.split(" ") 60 | 61 | if "left" not in words and "right" not in words: 62 | if random.random() < 0.5: 63 | return image0, image1, sentence, label 64 | else: 65 | return image1, image0, sentence, label 66 | else: 67 | if 
random.random() < 0.5: 68 | return image0, image1, sentence, label 69 | else: 70 | new_words = [] 71 | for word in words: 72 | if word == "left": 73 | new_words.append("right") 74 | elif word == "right": 75 | new_words.append("left") 76 | else: 77 | new_words.append(word) 78 | 79 | sentence = " ".join(new_words) 80 | return image1, image0, sentence, label 81 | -------------------------------------------------------------------------------- /data/nocaps_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | 10 | class nocaps_eval(Dataset): 11 | def __init__(self, transform, image_root, ann_root, split): 12 | urls = { 13 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json", 14 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json", 15 | } 16 | filenames = {"val": "nocaps_val.json", "test": "nocaps_test.json"} 17 | 18 | download_url(urls[split], ann_root) 19 | 20 | self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r")) 21 | self.transform = transform 22 | self.image_root = image_root 23 | 24 | def __len__(self): 25 | return len(self.annotation) 26 | 27 | def __getitem__(self, index): 28 | 29 | ann = self.annotation[index] 30 | 31 | image_path = os.path.join(self.image_root, ann["image"]) 32 | image = Image.open(image_path).convert("RGB") 33 | image = self.transform(image) 34 | 35 | return image, int(ann["img_id"]) 36 | -------------------------------------------------------------------------------- /data/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | Image.MAX_IMAGE_PIXELS = None 12 | 13 | from data.utils import pre_caption 14 | import os, glob 15 | 16 | 17 | class pretrain_dataset(Dataset): 18 | def __init__(self, ann_file, laion_path, transform): 19 | 20 | self.ann_pretrain = [] 21 | for f in ann_file: 22 | print("loading " + f) 23 | ann = json.load(open(f, "r")) 24 | self.ann_pretrain += ann 25 | 26 | self.laion_path = laion_path 27 | if self.laion_path: 28 | self.laion_files = glob.glob(os.path.join(laion_path, "*.json")) 29 | 30 | print("loading " + self.laion_files[0]) 31 | with open(self.laion_files[0], "r") as f: 32 | self.ann_laion = json.load(f) 33 | 34 | self.annotation = self.ann_pretrain + self.ann_laion 35 | else: 36 | self.annotation = self.ann_pretrain 37 | 38 | self.transform = transform 39 | 40 | def reload_laion(self, epoch): 41 | n = epoch % len(self.laion_files) 42 | print("loading " + self.laion_files[n]) 43 | with open(self.laion_files[n], "r") as f: 44 | self.ann_laion = json.load(f) 45 | 46 | self.annotation = self.ann_pretrain + self.ann_laion 47 | 48 | def __len__(self): 49 | return len(self.annotation) 50 | 51 | def __getitem__(self, index): 52 | 53 | ann = self.annotation[index] 54 | 55 | image = Image.open(ann["image"]).convert("RGB") 56 | image = self.transform(image) 57 | caption = pre_caption(ann["caption"], 30) 58 | 59 | return image, caption 60 | -------------------------------------------------------------------------------- /data/utils.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils 9 | 10 | 11 | def pre_caption(caption, max_words=50): 12 | caption = re.sub( 13 | r"([.!\"()*#:;~])", 14 | " ", 15 | caption.lower(), 16 | ) 17 | caption = re.sub( 18 | r"\s{2,}", 19 | " ", 20 | caption, 21 | ) 22 | caption = caption.rstrip("\n") 23 | caption = caption.strip(" ") 24 | 25 | # truncate caption 26 | caption_words = caption.split(" ") 27 | if len(caption_words) > max_words: 28 | caption = " ".join(caption_words[:max_words]) 29 | 30 | return caption 31 | 32 | 33 | def pre_question(question, max_ques_words=50): 34 | question = re.sub( 35 | r"([.!\"()*#:;~])", 36 | "", 37 | question.lower(), 38 | ) 39 | question = question.rstrip(" ") 40 | 41 | # truncate question 42 | question_words = question.split(" ") 43 | if len(question_words) > max_ques_words: 44 | question = " ".join(question_words[:max_ques_words]) 45 | 46 | return question 47 | 48 | 49 | def save_result(result, result_dir, filename, remove_duplicate=""): 50 | result_file = os.path.join( 51 | result_dir, "%s_rank%d.json" % (filename, utils.get_rank()) 52 | ) 53 | final_result_file = os.path.join(result_dir, "%s.json" % filename) 54 | 55 | json.dump(result, open(result_file, "w")) 56 | 57 | dist.barrier() 58 | 59 | if utils.is_main_process(): 60 | # combine results from all processes 61 | result = [] 62 | 63 | for rank in range(utils.get_world_size()): 64 | result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank)) 65 | res = json.load(open(result_file, "r")) 66 | result += res 67 | 68 | if remove_duplicate: 69 | result_new = [] 70 | id_list = [] 71 | for res in result: 72 | if res[remove_duplicate] not in id_list: 73 | id_list.append(res[remove_duplicate]) 74 | result_new.append(res) 75 | result = result_new 76 | 77 | json.dump(result, open(final_result_file, "w")) 78 | print("result file saved to %s" % final_result_file) 79 | 80 | return final_result_file 81 | 82 | 83 | # from pycocotools.coco import COCO 84 | # from pycocoevalcap.eval import COCOEvalCap 85 | from torchvision.datasets.utils import download_url 86 | 87 | 88 | def coco_caption_eval(coco_gt_root, results_file, split): 89 | urls = { 90 | "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json", 91 | "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json", 92 | } 93 | filenames = { 94 | "val": "coco_karpathy_val_gt.json", 95 | "test": "coco_karpathy_test_gt.json", 96 | } 97 | 98 | download_url(urls[split], coco_gt_root) 99 | annotation_file = os.path.join(coco_gt_root, filenames[split]) 100 | 101 | # create coco object and coco_result object 102 | coco = COCO(annotation_file) 103 | coco_result = coco.loadRes(results_file) 104 | 105 | # create coco_eval object by taking coco and coco_result 106 | coco_eval = COCOEvalCap(coco, coco_result) 107 | 108 | # evaluate on a subset of images by setting 109 | # coco_eval.params['image_id'] = coco_result.getImgIds() 110 | # please remove this line when evaluating the full validation set 111 | # coco_eval.params['image_id'] = coco_result.getImgIds() 112 | 113 | # evaluate results 114 | # SPICE will take a few minutes the first time, but speeds up due to caching 115 | coco_eval.evaluate() 116 | 117 | # print output evaluation scores 118 | for metric, score in coco_eval.eval.items(): 119 | print(f"{metric}: 
{score:.3f}") 120 | 121 | return coco_eval 122 | -------------------------------------------------------------------------------- /data/video_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision.datasets.utils import download_url 3 | 4 | from PIL import Image 5 | import torch 6 | import numpy as np 7 | import random 8 | import decord 9 | from decord import VideoReader 10 | import json 11 | import os 12 | from data.utils import pre_caption 13 | 14 | decord.bridge.set_bridge("torch") 15 | 16 | 17 | class ImageNorm(object): 18 | """Apply Normalization to Image Pixels on GPU""" 19 | 20 | def __init__(self, mean, std): 21 | self.mean = torch.tensor(mean).view(1, 3, 1, 1) 22 | self.std = torch.tensor(std).view(1, 3, 1, 1) 23 | 24 | def __call__(self, img): 25 | 26 | if torch.max(img) > 1 and self.mean.max() <= 1: 27 | img.div_(255.0) 28 | return img.sub_(self.mean).div_(self.std) 29 | 30 | 31 | def load_jsonl(filename): 32 | with open(filename, "r") as f: 33 | return [json.loads(l.strip("\n")) for l in f.readlines()] 34 | 35 | 36 | class VideoDataset(Dataset): 37 | def __init__( 38 | self, 39 | video_root, 40 | ann_root, 41 | num_frm=4, 42 | frm_sampling_strategy="rand", 43 | max_img_size=384, 44 | video_fmt=".mp4", 45 | ): 46 | """ 47 | image_root (string): Root directory of video 48 | ann_root (string): directory to store the annotation file 49 | """ 50 | url = "https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl" 51 | filename = "msrvtt_test.jsonl" 52 | 53 | download_url(url, ann_root) 54 | self.annotation = load_jsonl(os.path.join(ann_root, filename)) 55 | 56 | self.num_frm = num_frm 57 | self.frm_sampling_strategy = frm_sampling_strategy 58 | self.max_img_size = max_img_size 59 | self.video_root = video_root 60 | self.video_fmt = video_fmt 61 | self.img_norm = ImageNorm( 62 | mean=(0.48145466, 0.4578275, 0.40821073), 63 | std=(0.26862954, 0.26130258, 0.27577711), 64 | ) 65 | 66 | self.text = [pre_caption(ann["caption"], 40) for ann in self.annotation] 67 | self.txt2video = [i for i in range(len(self.annotation))] 68 | self.video2txt = self.txt2video 69 | 70 | def __len__(self): 71 | return len(self.annotation) 72 | 73 | def __getitem__(self, index): 74 | 75 | ann = self.annotation[index] 76 | 77 | video_path = os.path.join(self.video_root, ann["clip_name"] + self.video_fmt) 78 | 79 | vid_frm_array = self._load_video_from_path_decord( 80 | video_path, height=self.max_img_size, width=self.max_img_size 81 | ) 82 | 83 | video = self.img_norm(vid_frm_array.float()) 84 | 85 | return video, ann["clip_name"] 86 | 87 | def _load_video_from_path_decord( 88 | self, 89 | video_path, 90 | height=None, 91 | width=None, 92 | start_time=None, 93 | end_time=None, 94 | fps=-1, 95 | ): 96 | try: 97 | if not height or not width: 98 | vr = VideoReader(video_path) 99 | else: 100 | vr = VideoReader(video_path, width=width, height=height) 101 | 102 | vlen = len(vr) 103 | 104 | if start_time or end_time: 105 | assert ( 106 | fps > 0 107 | ), "must provide video fps if specifying start and end time." 
108 | 109 | start_idx = min(int(start_time * fps), vlen) 110 | end_idx = min(int(end_time * fps), vlen) 111 | else: 112 | start_idx, end_idx = 0, vlen 113 | 114 | if self.frm_sampling_strategy == "uniform": 115 | frame_indices = np.arange( 116 | start_idx, end_idx, vlen / self.num_frm, dtype=int 117 | ) 118 | elif self.frm_sampling_strategy == "rand": 119 | frame_indices = sorted(random.sample(range(vlen), self.num_frm)) 120 | elif self.frm_sampling_strategy == "headtail": 121 | frame_indices_head = sorted( 122 | random.sample(range(vlen // 2), self.num_frm // 2) 123 | ) 124 | frame_indices_tail = sorted( 125 | random.sample(range(vlen // 2, vlen), self.num_frm // 2) 126 | ) 127 | frame_indices = frame_indices_head + frame_indices_tail 128 | else: 129 | raise NotImplementedError( 130 | "Invalid sampling strategy {} ".format(self.frm_sampling_strategy) 131 | ) 132 | 133 | raw_sample_frms = vr.get_batch(frame_indices) 134 | except Exception as e: 135 | return None 136 | 137 | raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2) 138 | 139 | return raw_sample_frms 140 | -------------------------------------------------------------------------------- /data/vqg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | from typing import Tuple 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | from data.utils import pre_question 10 | 11 | from torchvision.datasets.utils import download_url 12 | 13 | 14 | GENERATION_TEMPLATE = """Question: {question} Answer: {answer}""" 15 | GENERATION_TEMPLATE_WITH_RATIONALE = ( 16 | """Question: {question} Answer: {answer}. Rationale: {rationale}""" 17 | ) 18 | GENERATION_TEMPLATE_WITH_RATIONALE_FIRST = ( 19 | """Rationale: {rationale}. Question: {question}. 
Answer: {answer}""" 20 | ) 21 | 22 | 23 | class VqgDataset(Dataset): 24 | def __init__( 25 | self, 26 | transform, 27 | ann_root, 28 | vqa_root, 29 | vg_root, 30 | train_files=None, 31 | split="train", 32 | truncate_to=None, 33 | ): 34 | self.split = split 35 | 36 | self.transform = transform 37 | self.vqa_root = vqa_root 38 | self.vg_root = vg_root 39 | self.template = GENERATION_TEMPLATE 40 | self.truncate_to = truncate_to 41 | 42 | if split == "train": 43 | urls = { 44 | "vqa_train": "https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json", 45 | "vqa_val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json", 46 | "vg_qa": "https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json", 47 | } 48 | 49 | self.annotation = [] 50 | train_files = [] if train_files is None else train_files 51 | for f in train_files: 52 | download_url(urls[f], ann_root) 53 | self.annotation += json.load( 54 | open(os.path.join(ann_root, "%s.json" % f), "r") 55 | ) 56 | 57 | if truncate_to is not None: 58 | self.annotation = self.annotation[:truncate_to] 59 | 60 | else: 61 | download_url( 62 | "https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json", 63 | ann_root, 64 | ) 65 | self.annotation = json.load( 66 | open(os.path.join(ann_root, "vqa_test.json"), "r") 67 | ) 68 | 69 | download_url( 70 | "https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json", 71 | ann_root, 72 | ) 73 | self.answer_list = json.load( 74 | open(os.path.join(ann_root, "answer_list.json"), "r") 75 | ) 76 | 77 | def __len__(self): 78 | return len(self.annotation) 79 | 80 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, str, int]: 81 | 82 | ann = self.annotation[index] 83 | 84 | if ann["dataset"] == "vqa" or ann["dataset"] == "aokvqa": 85 | image_path = os.path.join(self.vqa_root, ann["image"]) 86 | elif ann["dataset"] == "vg": 87 | image_path = os.path.join(self.vg_root, ann["image"]) 88 | 89 | image = Image.open(image_path).convert("RGB") 90 | image = self.transform(image) 91 | 92 | if self.split == "test": 93 | question = pre_question(ann["question"]) 94 | question_id = ann["question_id"] 95 | return image, question, question_id 96 | 97 | elif self.split == "train": 98 | 99 | question = pre_question(ann["question"]) 100 | 101 | if ann["dataset"] == "vqa" or ann["dataset"] == "aokvqa": 102 | answer_weight = {} 103 | for answer in ann["answer"]: 104 | if answer in answer_weight.keys(): 105 | answer_weight[answer] += 1 / len(ann["answer"]) 106 | else: 107 | answer_weight[answer] = 1 / len(ann["answer"]) 108 | 109 | answers = list(answer_weight.keys()) 110 | weights = list(answer_weight.values()) 111 | 112 | elif ann["dataset"] == "vg": 113 | answers = [ann["answer"]] 114 | weights = [0.2] 115 | 116 | target = self.template.format(question=question, answer=",".join(answers)) 117 | 118 | return image, target, ann["question_id"] 119 | 120 | 121 | class AokVqgDataset(Dataset): 122 | def __init__( 123 | self, 124 | transform, 125 | ann_root, 126 | vqa_root, 127 | vg_root, 128 | train_files=None, 129 | split="train", 130 | truncate_to=None, 131 | use_rationale=False, 132 | generate_rationale_first=False, 133 | ): 134 | self.split = split 135 | self.transform = transform 136 | self.vqa_root = vqa_root 137 | self.vg_root = vg_root 138 | self.truncate_to = truncate_to 139 | self.use_rationale = use_rationale 140 | self.generate_rationale_first = generate_rationale_first 141 | 142 | if split == "train": 
143 | self.annotation = [] 144 | for f in train_files: 145 | self.annotation += json.load( 146 | open(os.path.join(ann_root, "%s.json" % f), "r") 147 | ) 148 | if self.truncate_to: 149 | self.annotation = self.annotation[: self.truncate_to] 150 | elif split == "val": 151 | self.annotation = json.load(open(os.path.join(ann_root, "val.json"), "r")) 152 | self.answer_list = json.load( 153 | open(os.path.join(ann_root, "answer_list.json"), "r") 154 | ) 155 | else: 156 | self.annotation = json.load(open(os.path.join(ann_root, "test.json"), "r")) 157 | self.answer_list = json.load( 158 | open(os.path.join(ann_root, "answer_list.json"), "r") 159 | ) 160 | 161 | def __len__(self): 162 | return len(self.annotation) 163 | 164 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, str, int]: 165 | 166 | ann = self.annotation[index] 167 | 168 | image_path = os.path.join(self.vqa_root, ann["image"]) 169 | 170 | image = Image.open(image_path).convert("RGB") 171 | image = self.transform(image) 172 | 173 | if self.split == "test": 174 | question = pre_question(ann["question"]) 175 | question_id = ann["question_id"] 176 | return image, question, question_id 177 | 178 | elif self.split == "train": 179 | 180 | question = pre_question(ann["question"]) 181 | 182 | answers = ann["answer"] 183 | 184 | if self.use_rationale: 185 | if self.generate_rationale_first: 186 | target = GENERATION_TEMPLATE_WITH_RATIONALE_FIRST.format( 187 | question=question, 188 | answer=",".join(answers), 189 | rationale=" ".join(ann["rationales"]), 190 | ) 191 | else: 192 | target = GENERATION_TEMPLATE_WITH_RATIONALE.format( 193 | question=question, 194 | answer=",".join(answers), 195 | rationale=" ".join(ann["rationales"]), 196 | ) 197 | else: 198 | target = GENERATION_TEMPLATE.format( 199 | question=question, answer=",".join(answers) 200 | ) 201 | 202 | return image, target, ann["question_id"] 203 | 204 | 205 | def vqa_collate_fn(batch): 206 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], [] 207 | for image, question, answer, weights in batch: 208 | image_list.append(image) 209 | question_list.append(question) 210 | weight_list += weights 211 | answer_list += answer 212 | n.append(len(answer)) 213 | return ( 214 | torch.stack(image_list, dim=0), 215 | question_list, 216 | answer_list, 217 | torch.Tensor(weight_list), 218 | n, 219 | ) 220 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: blip 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - aws-c-cal=0.5.11=h95a6274_0 9 | - aws-c-common=0.6.2=h7f98852_0 10 | - aws-c-event-stream=0.2.7=h3541f99_13 11 | - aws-c-io=0.10.5=hfb6a706_0 12 | - aws-checksums=0.1.11=ha31a3da_7 13 | - aws-sdk-cpp=1.8.186=hb4091e7_3 14 | - c-ares=1.18.1=h7f98852_0 15 | - ca-certificates=2022.9.24=ha878542_0 16 | - certifi=2022.9.24=pyhd8ed1ab_0 17 | - keyutils=1.6.1=h166bdaf_0 18 | - krb5=1.19.3=h3790be6_0 19 | - ld_impl_linux-64=2.38=h1181459_1 20 | - libcurl=7.79.1=h2574ce0_1 21 | - libedit=3.1.20191231=he28a2e2_2 22 | - libev=4.33=h516909a_1 23 | - libffi=3.3=he6710b0_2 24 | - libgcc-ng=11.2.0=h1234567_1 25 | - libgomp=11.2.0=h1234567_1 26 | - libnghttp2=1.43.0=h812cca2_1 27 | - libssh2=1.10.0=ha56f1ee_2 28 | - libstdcxx-ng=11.2.0=h1234567_1 29 | - ncurses=6.3=h5eee18b_3 30 | - openssl=1.1.1o=h166bdaf_0 31 | - pip=22.1.2=py38h06a4308_0 32 | - 
python=3.8.13=h12debd9_0 33 | - readline=8.1.2=h7f8727e_1 34 | - s2n=1.0.10=h9b69904_0 35 | - setuptools=61.2.0=py38h06a4308_0 36 | - sqlite=3.38.5=hc218d9a_0 37 | - tk=8.6.12=h1ccaba5_0 38 | - wheel=0.37.1=pyhd3eb1b0_0 39 | - xz=5.2.5=h7f8727e_1 40 | - zlib=1.2.12=h7f8727e_2 41 | - pip: 42 | - absl-py==1.2.0 43 | - accelerate==0.15.0 44 | - aiohttp==3.8.1 45 | - aiosignal==1.2.0 46 | - antlr4-python3-runtime==4.9.3 47 | - argon2-cffi==21.3.0 48 | - argon2-cffi-bindings==21.2.0 49 | - asttokens==2.0.5 50 | - async-timeout==4.0.2 51 | - attrs==22.1.0 52 | - backcall==0.2.0 53 | - beartype==0.11.0 54 | - beautifulsoup4==4.11.1 55 | - black==22.6.0 56 | - bleach==5.0.1 57 | - blis==0.7.9 58 | - cachetools==5.2.0 59 | - catalogue==2.0.8 60 | - cattrs==22.1.0 61 | - cffi==1.15.1 62 | - charset-normalizer==2.1.0 63 | - click==8.1.3 64 | - clip==1.0 65 | - confection==0.0.3 66 | - cupy-cuda11x==11.4.0 67 | - cupy-wheel==11.4.0 68 | - cycler==0.11.0 69 | - cymem==2.0.7 70 | - datasets==2.6.1 71 | - debugpy==1.6.2 72 | - decorator==5.1.1 73 | - defusedxml==0.7.1 74 | - dill==0.3.5.1 75 | - docker-pycreds==0.4.0 76 | - en-core-web-trf==3.4.1 77 | - entrypoints==0.4 78 | - evaluate==0.3.0 79 | - exceptiongroup==1.0.0rc8 80 | - executing==0.9.1 81 | - fairscale==0.4.4 82 | - faiss-cpu==1.7.2 83 | - fastjsonschema==2.16.1 84 | - fastrlock==0.8.1 85 | - filelock==3.7.1 86 | - fire==0.4.0 87 | - fonttools==4.34.4 88 | - frozenlist==1.3.1 89 | - fsspec==2022.7.1 90 | - ftfy==6.1.1 91 | - gdown==4.6.0 92 | - gitdb==4.0.9 93 | - gitpython==3.1.27 94 | - google-auth==2.10.0 95 | - google-auth-oauthlib==0.4.6 96 | - grpcio==1.47.0 97 | - huggingface-hub==0.10.1 98 | - hydra-core==1.2.0 99 | - idna==3.3 100 | - importlib-metadata==4.12.0 101 | - importlib-resources==5.9.0 102 | - iniconfig==1.1.1 103 | - ipdb==0.13.9 104 | - ipykernel==6.15.1 105 | - ipython==8.4.0 106 | - ipython-genutils==0.2.0 107 | - ipywidgets==7.7.1 108 | - jedi==0.18.1 109 | - jinja2==3.1.2 110 | - joblib==1.1.0 111 | - jq==1.3.0 112 | - jsonlines==3.1.0 113 | - jsonschema==4.8.0 114 | - jupyter==1.0.0 115 | - jupyter-client==7.3.4 116 | - jupyter-console==6.4.4 117 | - jupyter-core==4.11.1 118 | - jupyterlab-pygments==0.2.2 119 | - jupyterlab-widgets==1.1.1 120 | - kaleido==0.2.1 121 | - kiwisolver==1.4.4 122 | - langcodes==3.3.0 123 | - lxml==4.9.1 124 | - markdown==3.4.1 125 | - markupsafe==2.1.1 126 | - matplotlib==3.5.2 127 | - matplotlib-inline==0.1.3 128 | - mistune==0.8.4 129 | - multidict==6.0.2 130 | - multiprocess==0.70.13 131 | - murmurhash==1.0.9 132 | - mypy==0.991 133 | - mypy-extensions==0.4.3 134 | - nbclient==0.6.6 135 | - nbconvert==6.5.0 136 | - nbformat==5.4.0 137 | - nest-asyncio==1.5.5 138 | - nltk==3.7 139 | - notebook==6.4.12 140 | - numpy==1.23.1 141 | - oauthlib==3.2.0 142 | - omegaconf==2.2.2 143 | - opencv-python==4.6.0.66 144 | - orjson==3.7.11 145 | - packaging==21.3 146 | - pandas==1.4.3 147 | - pandocfilters==1.5.0 148 | - parso==0.8.3 149 | - pathspec==0.9.0 150 | - pathtools==0.1.2 151 | - pathy==0.10.1 152 | - patsy==0.5.3 153 | - pexpect==4.8.0 154 | - pickleshare==0.7.5 155 | - pillow==9.2.0 156 | - platformdirs==2.5.2 157 | - plotly==5.11.0 158 | - pluggy==1.0.0 159 | - preshed==3.0.8 160 | - prometheus-client==0.14.1 161 | - promise==2.3 162 | - prompt-toolkit==3.0.30 163 | - protobuf==3.19.4 164 | - psutil==5.9.1 165 | - ptyprocess==0.7.0 166 | - pure-eval==0.2.2 167 | - py==1.11.0 168 | - pyarrow==9.0.0 169 | - pyasn1==0.4.8 170 | - pyasn1-modules==0.2.8 171 | - pycocoevalcap==1.2 172 | - 
pycocotools==2.0.4 173 | - pycparser==2.21 174 | - pydantic==1.10.2 175 | - pydeprecate==0.3.2 176 | - pygments==2.12.0 177 | - pyinstrument==4.2.0 178 | - pyparsing==3.0.9 179 | - pyrsistent==0.18.1 180 | - pysocks==1.7.1 181 | - pytest==7.1.2 182 | - python-dateutil==2.8.2 183 | - python-pptx==0.6.21 184 | - pytorch-lightning==1.7.1 185 | - pytz==2022.2 186 | - pyyaml==6.0 187 | - pyzmq==23.2.0 188 | - qtconsole==5.3.1 189 | - qtpy==2.1.0 190 | - regex==2022.7.25 191 | - requests==2.28.1 192 | - requests-oauthlib==1.3.1 193 | - responses==0.18.0 194 | - rouge-score==0.1.2 195 | - rsa==4.9 196 | - ruamel-yaml==0.17.21 197 | - ruamel-yaml-clib==0.2.6 198 | - sacremoses==0.0.53 199 | - scikit-learn==1.1.3 200 | - scipy==1.9.3 201 | - seaborn==0.12.0 202 | - send2trash==1.8.0 203 | - sentencepiece==0.1.97 204 | - sentry-sdk==1.9.3 205 | - setproctitle==1.3.1 206 | - shortuuid==1.0.9 207 | - six==1.16.0 208 | - smart-open==6.3.0 209 | - smmap==5.0.0 210 | - soupsieve==2.3.2.post1 211 | - spacy==3.4.4 212 | - spacy-alignments==0.8.6 213 | - spacy-legacy==3.0.10 214 | - spacy-loggers==1.0.4 215 | - spacy-transformers==1.1.8 216 | - srsly==2.4.5 217 | - stack-data==0.3.0 218 | - statsmodels==0.13.5 219 | - sunburst==1.0.0a2 220 | - tabulate==0.9.0 221 | - tenacity==8.1.0 222 | - tensorboard==2.10.0 223 | - tensorboard-data-server==0.6.1 224 | - tensorboard-plugin-wit==1.8.1 225 | - termcolor==1.1.0 226 | - terminado==0.15.0 227 | - thinc==8.1.5 228 | - threadpoolctl==3.1.0 229 | - timm==0.4.12 230 | - tinycss2==1.1.1 231 | - tokenizers==0.12.1 232 | - toml==0.10.2 233 | - tomli==2.0.1 234 | - torch==1.12.0+cu113 235 | - torchaudio==0.12.0+cu113 236 | - torchmetrics==0.9.3 237 | - torchvision==0.13.0+cu113 238 | - tornado==6.2 239 | - tqdm==4.64.0 240 | - traitlets==5.3.0 241 | - transformers==4.21.3 242 | - typeguard==2.13.3 243 | - typer==0.7.0 244 | - typing==3.7.4.3 245 | - typing-extensions==4.3.0 246 | - urllib3==1.26.11 247 | - wandb==0.13.7 248 | - wasabi==0.10.1 249 | - wcwidth==0.2.5 250 | - webencodings==0.5.1 251 | - werkzeug==2.2.2 252 | - widgetsnbextension==3.6.1 253 | - xlsxwriter==3.0.3 254 | - xxhash==3.1.0 255 | - yarl==1.8.1 256 | - zipp==3.8.1 257 | prefix: /home/mai/zkhan/miniconda3/envs/blip 258 | -------------------------------------------------------------------------------- /examples/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # NOTE: This script shows how you can evaluate a list of checkpoints on a dataset. 4 | # In this case, the dataset is ArtVQA. You can change this to work for _any_ dataset 5 | # by changing the config file and the evaluation script (artvqa_eval.py). 6 | 7 | parentdir_name() { 8 | echo "$(basename "$(dirname "$1")")" 9 | } 10 | 11 | # Control the number of GPUs and GPU usage by setting 12 | # CUDA_VISIBLE_DEVICES. 13 | 14 | # FYI, this will clobber results when two models have the same name. 
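# For example (hypothetical GPU ids), restricting the evaluation to two GPUs might look like:
#   CUDA_VISIBLE_DEVICES=0,1 bash examples/evaluate.sh
# If you do that, lower --nproc_per_node in the call below so it matches the number of visible GPUs.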
15 | 16 | zero_shot_artvqa_generate_evaluate () { 17 | local CHECKPOINT=$1 18 | local OUTPUT_DIR=cache/artvqa_evals_$(parentdir_name $CHECKPOINT) 19 | echo -e "\e[1;34m\e[4mGenerating 0-shot ArtVQA results for $CHECKPOINT\e[0m" 20 | python -m torch.distributed.run --nproc_per_node=4 train_vqa.py \ 21 | --output_dir=$OUTPUT_DIR --evaluate \ 22 | --config configs/artvqa.yaml \ 23 | --overrides wandb=false \ 24 | pretrained=$CHECKPOINT \ 25 | batch_size_test=16 26 | echo -e "\e[1;32mResults are in $OUTPUT_DIR\e[0m" 27 | # We don't control the name of the result file; it is currently 28 | # hardcoded to be result/vqa_result.json. 29 | python artvqa_eval.py $OUTPUT_DIR/result/vqa_result.json 30 | } 31 | 32 | 33 | MODELS_TO_EVAL=( 34 | # Compare synthetic vs baseline models with VQAV2 post-training and A-OKVQA finetuning. 35 | /net/acadia10a/data/zkhan/mithril/aokvqa_finetuned/blip_vqa_baseline/checkpoint_09.pth 36 | /net/acadia10a/data/zkhan/mithril/aokvqa_finetuned/33-finetune-on-aokvqa-synth_17k/checkpoint_09.pth 37 | /net/acadia4a/data/zkhan/mithril/aokvqa_finetuned/33-finetune-on-aokvqa-synth_34k/checkpoint_09.pth 38 | /net/acadia4a/data/zkhan/mithril/aokvqa_finetuned/33-finetune-on-aokvqa-synth_51k/checkpoint_09.pth 39 | /net/acadia4a/data/zkhan/mithril/aokvqa_finetuned/34-finetune-on-aokvqa-synth_4k/checkpoint_09.pth 40 | /net/acadia4a/data/zkhan/mithril/aokvqa_finetuned/34-finetune-on-aokvqa-synth_8k/checkpoint_09.pth 41 | 42 | # Compare synthetic vs baseline models from A-OKVQA only. 43 | /net/acadia10a/data/zkhan/mithril/aokvqa_finetuned/blip/checkpoint_09.pth 44 | /net/acadia10a/data/zkhan/mithril/aokvqa_finetuned/25-aokvqa_synth_rationale_17k/checkpoint_09.pth 45 | /net/acadia10a/data/zkhan/mithril/aokvqa_finetuned/25-aokvqa_synth_rationale_34k/checkpoint_09.pth 46 | /net/acadia10a/data/zkhan/mithril/aokvqa_finetuned/25-aokvqa_synth_rationale_51k/checkpoint_09.pth 47 | ) 48 | 49 | 50 | for PATH_TO_CHECKPOINT in ${MODELS_TO_EVAL[@]}; do 51 | zero_shot_artvqa_generate_evaluate $PATH_TO_CHECKPOINT 52 | done -------------------------------------------------------------------------------- /examples/generate_synthetic_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Store the synthetic data within the same folder you store the real data. 3 | # We have to do it this way because of the dataset reader code. 4 | OUTPUT_DIR=cache/aokvqa 5 | TEACHER_WEIGHTS=cache/teacher_weights/checkpoint_04.pth 6 | OUTPUT_ANNOTATIONS_NAME=synthetic_data.json 7 | python generate_questions.py --config=configs/generate_questions_coco.yaml --overrides \ 8 | max_length=40 \ 9 | output_folder=$OUTPUT_DIR \ 10 | pretrained=$TEACHER_WEIGHTS \ 11 | questions_per_image=2 \ 12 | output_annotations_name=$OUTPUT_ANNOTATIONS_NAME -------------------------------------------------------------------------------- /examples/self_train_synthetic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OUTPUT_DIR=cache/self_trained_weights 3 | # Just provide the file name without the extension. 4 | # The reader code will look at the config file (in this case, configs/aokvqa.yaml) 5 | # and read both the synthetic JSON and real JSON from the root folder of the dataset. 
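# For example, assuming configs/aokvqa.yaml points its annotation root at cache/aokvqa (the
# OUTPUT_DIR used in examples/generate_synthetic_data.sh), the list below resolves to
# cache/aokvqa/train.json plus cache/aokvqa/synthetic_data.json, i.e. the real annotations
# and the teacher-generated ones.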
6 | TRAIN_FILES="[train,synthetic_data]" 7 | python -m torch.distributed.run --nproc_per_node=4 train_vqa.py \ 8 | --output_dir=$OUTPUT_DIR \ 9 | --config configs/aokvqa.yaml \ 10 | --overrides wandb=true \ 11 | train_files=$TRAIN_FILES \ 12 | truncate_train_dataset_to=34000 \ -------------------------------------------------------------------------------- /examples/train_teacher.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUTPUT_DIR=cache/teacher_weights 4 | python -m torch.distributed.run --master_port=37770 --nproc_per_node=4 train_vqg.py \ 5 | --config=configs/aokvqg.yaml \ 6 | --output_dir=$OUTPUT_DIR \ 7 | --overrides batch_size=64 wandb=True -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codezakh/SelTDA/8ca7d53fb6ef1d8dec62e52bc9a46df4a194e06b/models/__init__.py -------------------------------------------------------------------------------- /models/blip_itm.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | 11 | class BLIP_ITM(nn.Module): 12 | def __init__( 13 | self, 14 | med_config="configs/med_config.json", 15 | image_size=384, 16 | vit="base", 17 | vit_grad_ckpt=False, 18 | vit_ckpt_layer=0, 19 | embed_dim=256, 20 | ): 21 | """ 22 | Args: 23 | med_config (str): path for the mixture of encoder-decoder model's configuration file 24 | image_size (int): input image size 25 | vit (str): model size of vision transformer 26 | """ 27 | super().__init__() 28 | 29 | self.visual_encoder, vision_width = create_vit( 30 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 31 | ) 32 | self.tokenizer = init_tokenizer() 33 | med_config = BertConfig.from_json_file(med_config) 34 | med_config.encoder_width = vision_width 35 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 36 | 37 | text_width = self.text_encoder.config.hidden_size 38 | 39 | self.vision_proj = nn.Linear(vision_width, embed_dim) 40 | self.text_proj = nn.Linear(text_width, embed_dim) 41 | 42 | self.itm_head = nn.Linear(text_width, 2) 43 | 44 | def forward(self, image, caption, match_head="itm"): 45 | 46 | image_embeds = self.visual_encoder(image) 47 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 48 | image.device 49 | ) 50 | 51 | text = self.tokenizer( 52 | caption, 53 | padding="max_length", 54 | truncation=True, 55 | max_length=35, 56 | return_tensors="pt", 57 | ).to(image.device) 58 | 59 | if match_head == "itm": 60 | output = self.text_encoder( 61 | text.input_ids, 62 | attention_mask=text.attention_mask, 63 | encoder_hidden_states=image_embeds, 64 | encoder_attention_mask=image_atts, 65 | return_dict=True, 66 | ) 67 | itm_output = self.itm_head(output.last_hidden_state[:, 0, :]) 68 | return itm_output 69 | 70 | elif match_head == "itc": 71 | text_output = self.text_encoder( 72 | text.input_ids, 73 | attention_mask=text.attention_mask, 74 | return_dict=True, 75 | mode="text", 76 | ) 77 | image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) 78 | text_feat = F.normalize( 79 | self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 80 | 
) 81 | 82 | sim = image_feat @ text_feat.t() 83 | return sim 84 | 85 | 86 | def blip_itm(pretrained="", **kwargs): 87 | model = BLIP_ITM(**kwargs) 88 | if pretrained: 89 | model, msg = load_checkpoint(model, pretrained) 90 | assert len(msg.missing_keys) == 0 91 | return model 92 | -------------------------------------------------------------------------------- /models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig 2 | from models.nlvr_encoder import BertModel 3 | from models.vit import interpolate_pos_embed 4 | from models.blip import create_vit, init_tokenizer, is_url 5 | 6 | from timm.models.hub import download_cached_file 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | 15 | class BLIP_NLVR(nn.Module): 16 | def __init__( 17 | self, 18 | med_config="configs/med_config.json", 19 | image_size=480, 20 | vit="base", 21 | vit_grad_ckpt=False, 22 | vit_ckpt_layer=0, 23 | ): 24 | """ 25 | Args: 26 | med_config (str): path for the mixture of encoder-decoder model's configuration file 27 | image_size (int): input image size 28 | vit (str): model size of vision transformer 29 | """ 30 | super().__init__() 31 | 32 | self.visual_encoder, vision_width = create_vit( 33 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1 34 | ) 35 | self.tokenizer = init_tokenizer() 36 | med_config = BertConfig.from_json_file(med_config) 37 | med_config.encoder_width = vision_width 38 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 39 | 40 | self.cls_head = nn.Sequential( 41 | nn.Linear( 42 | self.text_encoder.config.hidden_size, 43 | self.text_encoder.config.hidden_size, 44 | ), 45 | nn.ReLU(), 46 | nn.Linear(self.text_encoder.config.hidden_size, 2), 47 | ) 48 | 49 | def forward(self, image, text, targets, train=True): 50 | 51 | image_embeds = self.visual_encoder(image) 52 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 53 | image.device 54 | ) 55 | image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0)) 56 | 57 | text = self.tokenizer(text, padding="longest", return_tensors="pt").to( 58 | image.device 59 | ) 60 | text.input_ids[:, 0] = self.tokenizer.enc_token_id 61 | 62 | output = self.text_encoder( 63 | text.input_ids, 64 | attention_mask=text.attention_mask, 65 | encoder_hidden_states=[image0_embeds, image1_embeds], 66 | encoder_attention_mask=[ 67 | image_atts[: image0_embeds.size(0)], 68 | image_atts[image0_embeds.size(0) :], 69 | ], 70 | return_dict=True, 71 | ) 72 | hidden_state = output.last_hidden_state[:, 0, :] 73 | prediction = self.cls_head(hidden_state) 74 | 75 | if train: 76 | loss = F.cross_entropy(prediction, targets) 77 | return loss 78 | else: 79 | return prediction 80 | 81 | 82 | def blip_nlvr(pretrained="", **kwargs): 83 | model = BLIP_NLVR(**kwargs) 84 | if pretrained: 85 | model, msg = load_checkpoint(model, pretrained) 86 | print("missing keys:") 87 | print(msg.missing_keys) 88 | return model 89 | 90 | 91 | def load_checkpoint(model, url_or_filename): 92 | if is_url(url_or_filename): 93 | cached_file = download_cached_file( 94 | url_or_filename, check_hash=False, progress=True 95 | ) 96 | checkpoint = torch.load(cached_file, map_location="cpu") 97 | elif os.path.isfile(url_or_filename): 98 | checkpoint = torch.load(url_or_filename, map_location="cpu") 99 | else: 100 | raise RuntimeError("checkpoint url or path is 
invalid") 101 | state_dict = checkpoint["model"] 102 | 103 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( 104 | state_dict["visual_encoder.pos_embed"], model.visual_encoder 105 | ) 106 | 107 | for key in list(state_dict.keys()): 108 | if "crossattention.self." in key: 109 | new_key0 = key.replace("self", "self0") 110 | new_key1 = key.replace("self", "self1") 111 | state_dict[new_key0] = state_dict[key] 112 | state_dict[new_key1] = state_dict[key] 113 | elif "crossattention.output.dense." in key: 114 | new_key0 = key.replace("dense", "dense0") 115 | new_key1 = key.replace("dense", "dense1") 116 | state_dict[new_key0] = state_dict[key] 117 | state_dict[new_key1] = state_dict[key] 118 | 119 | msg = model.load_state_dict(state_dict, strict=False) 120 | print("load checkpoint from %s" % url_or_filename) 121 | return model, msg 122 | -------------------------------------------------------------------------------- /models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel, BertLMHeadModel 2 | from models.blip import create_vit, init_tokenizer, load_checkpoint 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import BertTokenizer 8 | import numpy as np 9 | 10 | 11 | class BLIP_VQA(nn.Module): 12 | def __init__( 13 | self, 14 | med_config="configs/med_config.json", 15 | image_size=480, 16 | vit="base", 17 | vit_grad_ckpt=False, 18 | vit_ckpt_layer=0, 19 | ): 20 | """ 21 | Args: 22 | med_config (str): path for the mixture of encoder-decoder model's configuration file 23 | image_size (int): input image size 24 | vit (str): model size of vision transformer 25 | """ 26 | super().__init__() 27 | 28 | self.visual_encoder, vision_width = create_vit( 29 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1 30 | ) 31 | self.tokenizer = init_tokenizer() 32 | 33 | encoder_config = BertConfig.from_json_file(med_config) 34 | encoder_config.encoder_width = vision_width 35 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 36 | 37 | decoder_config = BertConfig.from_json_file(med_config) 38 | self.text_decoder = BertLMHeadModel(config=decoder_config) 39 | 40 | def forward( 41 | self, 42 | image, 43 | question, 44 | answer=None, 45 | n=None, 46 | weights=None, 47 | train=True, 48 | inference="rank", 49 | k_test=128, 50 | return_scores=False, 51 | ): 52 | 53 | image_embeds = self.visual_encoder(image) 54 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 55 | image.device 56 | ) 57 | 58 | question = self.tokenizer( 59 | question, 60 | padding="longest", 61 | truncation=True, 62 | max_length=35, 63 | return_tensors="pt", 64 | ).to(image.device) 65 | question.input_ids[:, 0] = self.tokenizer.enc_token_id 66 | 67 | if train: 68 | """ 69 | n: number of answers for each question 70 | weights: weight for each answer 71 | """ 72 | answer = self.tokenizer(answer, padding="longest", return_tensors="pt").to( 73 | image.device 74 | ) 75 | answer.input_ids[:, 0] = self.tokenizer.bos_token_id 76 | answer_targets = answer.input_ids.masked_fill( 77 | answer.input_ids == self.tokenizer.pad_token_id, -100 78 | ) 79 | 80 | question_output = self.text_encoder( 81 | question.input_ids, 82 | attention_mask=question.attention_mask, 83 | encoder_hidden_states=image_embeds, 84 | encoder_attention_mask=image_atts, 85 | return_dict=True, 86 | ) 87 | 88 | question_states = [] 89 | question_atts = [] 90 | for b, n 
in enumerate(n): 91 | question_states += [question_output.last_hidden_state[b]] * n 92 | question_atts += [question.attention_mask[b]] * n 93 | question_states = torch.stack(question_states, 0) 94 | question_atts = torch.stack(question_atts, 0) 95 | 96 | answer_output = self.text_decoder( 97 | answer.input_ids, 98 | attention_mask=answer.attention_mask, 99 | encoder_hidden_states=question_states, 100 | encoder_attention_mask=question_atts, 101 | labels=answer_targets, 102 | return_dict=True, 103 | reduction="none", 104 | ) 105 | 106 | loss = weights * answer_output.loss 107 | loss = loss.sum() / image.size(0) 108 | 109 | return loss 110 | 111 | else: 112 | question_output = self.text_encoder( 113 | question.input_ids, 114 | attention_mask=question.attention_mask, 115 | encoder_hidden_states=image_embeds, 116 | encoder_attention_mask=image_atts, 117 | return_dict=True, 118 | ) 119 | 120 | if inference == "generate": 121 | num_beams = 3 122 | question_states = question_output.last_hidden_state.repeat_interleave( 123 | num_beams, dim=0 124 | ) 125 | question_atts = torch.ones( 126 | question_states.size()[:-1], dtype=torch.long 127 | ).to(question_states.device) 128 | model_kwargs = { 129 | "encoder_hidden_states": question_states, 130 | "encoder_attention_mask": question_atts, 131 | } 132 | 133 | bos_ids = torch.full( 134 | (image.size(0), 1), 135 | fill_value=self.tokenizer.bos_token_id, 136 | device=image.device, 137 | ) 138 | 139 | outputs = self.text_decoder.generate( 140 | input_ids=bos_ids, 141 | max_length=10, 142 | min_length=1, 143 | num_beams=num_beams, 144 | eos_token_id=self.tokenizer.sep_token_id, 145 | pad_token_id=self.tokenizer.pad_token_id, 146 | **model_kwargs 147 | ) 148 | 149 | answers = [] 150 | for output in outputs: 151 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 152 | answers.append(answer) 153 | return answers 154 | 155 | elif inference == "rank": 156 | max_ids, max_scores = self.rank_answer( 157 | question_output.last_hidden_state, 158 | question.attention_mask, 159 | answer.input_ids, 160 | answer.attention_mask, 161 | k_test, 162 | return_top_probability=True, 163 | ) 164 | if return_scores: 165 | return max_ids, max_scores 166 | return max_ids 167 | 168 | def rank_answer( 169 | self, 170 | question_states, 171 | question_atts, 172 | answer_ids, 173 | answer_atts, 174 | k, 175 | return_top_probability=False, 176 | ): 177 | 178 | num_ques = question_states.size(0) 179 | start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token 180 | 181 | start_output = self.text_decoder( 182 | start_ids, 183 | encoder_hidden_states=question_states, 184 | encoder_attention_mask=question_atts, 185 | return_dict=True, 186 | reduction="none", 187 | ) 188 | logits = start_output.logits[:, 0, :] # first token's logit 189 | 190 | # topk_probs: top-k probability 191 | # topk_ids: [num_question, k] 192 | answer_first_token = answer_ids[:, 1] 193 | prob_first_token = F.softmax(logits, dim=1).index_select( 194 | dim=1, index=answer_first_token 195 | ) 196 | topk_probs, topk_ids = prob_first_token.topk(k, dim=1) 197 | 198 | # answer input: [num_question*k, answer_len] 199 | input_ids = [] 200 | input_atts = [] 201 | for b, topk_id in enumerate(topk_ids): 202 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 203 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 204 | input_ids = torch.cat(input_ids, dim=0) 205 | input_atts = torch.cat(input_atts, dim=0) 206 | 207 | targets_ids = input_ids.masked_fill( 208 | input_ids == 
self.tokenizer.pad_token_id, -100 209 | ) 210 | 211 | # repeat encoder's output for top-k answers 212 | question_states = tile(question_states, 0, k) 213 | question_atts = tile(question_atts, 0, k) 214 | 215 | output = self.text_decoder( 216 | input_ids, 217 | attention_mask=input_atts, 218 | encoder_hidden_states=question_states, 219 | encoder_attention_mask=question_atts, 220 | labels=targets_ids, 221 | return_dict=True, 222 | reduction="none", 223 | ) 224 | 225 | log_probs_sum = -output.loss 226 | log_probs_sum = log_probs_sum.view(num_ques, k) 227 | 228 | max_topk_ids = log_probs_sum.argmax(dim=1) 229 | max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] 230 | 231 | if return_top_probability: 232 | return max_ids, log_probs_sum.max(dim=1).values 233 | return max_ids 234 | 235 | 236 | def blip_vqa(pretrained="", **kwargs): 237 | model = BLIP_VQA(**kwargs) 238 | if pretrained: 239 | model, msg = load_checkpoint(model, pretrained) 240 | # assert(len(msg.missing_keys)==0) 241 | return model 242 | 243 | 244 | def tile(x, dim, n_tile): 245 | init_dim = x.size(dim) 246 | repeat_idx = [1] * x.dim() 247 | repeat_idx[dim] = n_tile 248 | x = x.repeat(*(repeat_idx)) 249 | order_index = torch.LongTensor( 250 | np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) 251 | ) 252 | return torch.index_select(x, dim, order_index.to(x.device)) 253 | -------------------------------------------------------------------------------- /okvqa_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm.notebook import tqdm 3 | import json 4 | from pprint import PrettyPrinter 5 | from vqa_eval_tools import VQA, VQAEval 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | 10 | pp = PrettyPrinter() 11 | 12 | # The annotations are missing an "question_type" key, so we create a new annotation file which does 13 | # have the key. We just copy the "answer_type" key to "question_type", they are the same thing, I think. 14 | annotation_file = "/net/acadia10a/data/zkhan/ok-vqa/mscoco_val2014_annotations.json" 15 | question_file = ( 16 | "/net/acadia10a/data/zkhan/ok-vqa/OpenEnded_mscoco_val2014_questions.json" 17 | ) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = ArgumentParser() 22 | parser.add_argument( 23 | "result_file", help="Path to a JSON result file generated by an evaluation." 
24 | ) 25 | args = parser.parse_args() 26 | 27 | results_file = args.result_file 28 | # results_file = '/net/acadia4a/data/zkhan/mithril/advqa-0-shot-evals/35_blip_vqa_baseline/result/vqa_result.json' 29 | 30 | advqa_obj = VQA(annotation_file=annotation_file, question_file=question_file) 31 | 32 | # We have to convert the question_id field to be an integer >.< 33 | with open(results_file, "r") as f: 34 | predicted = json.load(f) 35 | 36 | for element in predicted: 37 | element["question_id"] = int(element["question_id"]) 38 | 39 | with open(results_file, "w") as f: 40 | json.dump(predicted, f) 41 | 42 | result_obj = advqa_obj.loadRes(resFile=results_file, quesFile=question_file) 43 | 44 | advqa_eval = VQAEval(advqa_obj, result_obj, n=2) 45 | advqa_eval.evaluate() 46 | print(f"Completed evaluation of {results_file}") 47 | pp.pprint(advqa_eval.accuracy) 48 | with open(Path(results_file).parent / "okvqa_eval.json", "w") as f: 49 | json.dump(advqa_eval.accuracy, f) 50 | -------------------------------------------------------------------------------- /pathvqa_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | For AQUA (Art VQA), we don't use the VQAv2 evaluation code. 3 | That's because the VQAv2 evaluation code assumes there are 4 | multiple answers for each question, but in AQUA, there's only 5 | one answer for each question. We just do an exact match evaluation 6 | following the AQUA paper. 7 | """ 8 | import json 9 | from unittest import result 10 | from tqdm import tqdm 11 | import json 12 | from pprint import PrettyPrinter 13 | from vqa_eval_tools import VQA, VQAEval 14 | from argparse import ArgumentParser 15 | from pathlib import Path 16 | import schemas 17 | import pandas as pd 18 | 19 | 20 | pp = PrettyPrinter() 21 | 22 | annotation_file = "/net/acadia4a/data/zkhan/pathvqa/test.json" 23 | 24 | 25 | def exact_match_eval(annotation_file, result_file): 26 | with open(annotation_file, "r") as f: 27 | annotations = json.load(f) 28 | with open(result_file, "r") as f: 29 | results = json.load(f) 30 | 31 | annotations = [schemas.MinimalEvaluationRecord.parse_obj(a) for a in annotations] 32 | 33 | annotation_lookup_table = {a.question_id: a for a in annotations} 34 | evaluation_records = [] 35 | for answer_record in results: 36 | ground_truth = annotation_lookup_table[answer_record["question_id"]] 37 | # It's a list, but there's only one answer for each VQA art question. 38 | # So we just take the first one and do an exact match. 
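# (In this repo's schemas.MinimalEvaluationRecord, `answer` is already a plain string, so no
# indexing is needed here; the check below is a straight string comparison.)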
39 | true_answer = ground_truth.answer 40 | is_correct = answer_record["answer"] == true_answer 41 | question_type = ground_truth.question_type 42 | evaluation_records.append( 43 | { 44 | "question_id": answer_record["question_id"], 45 | "answer": answer_record["answer"], 46 | "question_type": question_type, 47 | "is_correct": is_correct, 48 | "true_answer": true_answer, 49 | "answer_type": ground_truth.answer_type, 50 | } 51 | ) 52 | 53 | frame = pd.DataFrame(evaluation_records) 54 | answertype_groupby = ( 55 | frame.groupby("answer_type") 56 | .apply(lambda s: s["is_correct"].sum() / len(s)) 57 | .to_frame() 58 | ) 59 | accuracies = { 60 | "overall": frame["is_correct"].sum() / len(frame), 61 | } 62 | accuracies = { 63 | **accuracies, 64 | **{_: float(answertype_groupby.loc[_]) for _ in answertype_groupby.index}, 65 | } 66 | accuracies = {k: round(v, 4) for k, v in accuracies.items()} 67 | return accuracies 68 | 69 | 70 | if __name__ == "__main__": 71 | parser = ArgumentParser() 72 | parser.add_argument( 73 | "result_file", help="Path to a JSON result file generated by an evaluation." 74 | ) 75 | args = parser.parse_args() 76 | 77 | results_file = args.result_file 78 | 79 | accuracies = exact_match_eval(annotation_file, results_file) 80 | 81 | pp.pprint(accuracies) 82 | with open(Path(results_file).parent / "pathvqa_eval.json", "w") as f: 83 | json.dump(accuracies, f) 84 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | slow: marks tests as slow (deselect with '-m "not slow"') 4 | serial -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | transformers==4.15.0 3 | fairscale==0.4.4 4 | pycocoevalcap 5 | -------------------------------------------------------------------------------- /rsvqa_lr_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | For AQUA (Art VQA), we don't use the VQAv2 evaluation code. 3 | That's because the VQAv2 evaluation code assumes there are 4 | multiple answers for each question, but in AQUA, there's only 5 | one answer for each question. We just do an exact match evaluation 6 | following the AQUA paper. 7 | """ 8 | import json 9 | from unittest import result 10 | from tqdm import tqdm 11 | import json 12 | from pprint import PrettyPrinter 13 | from vqa_eval_tools import VQA, VQAEval 14 | from argparse import ArgumentParser 15 | from pathlib import Path 16 | import schemas 17 | import pandas as pd 18 | 19 | 20 | pp = PrettyPrinter() 21 | 22 | annotation_file = "/net/acadia4a/data/zkhan/rsvqa/low_resolution/test_annotations.json" 23 | 24 | 25 | def exact_match_eval(annotation_file, result_file): 26 | with open(annotation_file, "r") as f: 27 | annotations = json.load(f) 28 | with open(result_file, "r") as f: 29 | results = json.load(f) 30 | 31 | annotations = [schemas.MinimalEvaluationRecord.parse_obj(a) for a in annotations] 32 | 33 | annotation_lookup_table = {a.question_id: a for a in annotations} 34 | evaluation_records = [] 35 | for answer_record in results: 36 | ground_truth = annotation_lookup_table[answer_record["question_id"]] 37 | # It's a list, but there's only one answer for each VQA art question. 38 | # So we just take the first one and do an exact match. 
39 | true_answer = ground_truth.answer 40 | is_correct = answer_record["answer"] == true_answer 41 | question_type = ground_truth.question_type 42 | evaluation_records.append( 43 | { 44 | "question_id": answer_record["question_id"], 45 | "answer": answer_record["answer"], 46 | "question_type": question_type, 47 | "is_correct": is_correct, 48 | "true_answer": true_answer, 49 | # 'answer_type': ground_truth.answer_type, 50 | } 51 | ) 52 | 53 | frame = pd.DataFrame(evaluation_records) 54 | answertype_groupby = ( 55 | frame.groupby("question_type") 56 | .apply(lambda s: s["is_correct"].sum() / len(s)) 57 | .to_frame() 58 | ) 59 | accuracies = { 60 | "overall": frame["is_correct"].sum() / len(frame), 61 | } 62 | accuracies = { 63 | **accuracies, 64 | **{_: float(answertype_groupby.loc[_]) for _ in answertype_groupby.index}, 65 | } 66 | accuracies = {k: round(v, 4) for k, v in accuracies.items()} 67 | return accuracies 68 | 69 | 70 | if __name__ == "__main__": 71 | parser = ArgumentParser() 72 | parser.add_argument( 73 | "result_file", help="Path to a JSON result file generated by an evaluation." 74 | ) 75 | args = parser.parse_args() 76 | 77 | results_file = args.result_file 78 | 79 | accuracies = exact_match_eval(annotation_file, results_file) 80 | 81 | pp.pprint(accuracies) 82 | with open(Path(results_file).parent / "rsvqa_lr_eval.json", "w") as f: 83 | json.dump(accuracies, f) 84 | -------------------------------------------------------------------------------- /schemas.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Optional, List, Dict, Union, NewType 3 | 4 | SemArtImageName = NewType("SemArtImageName", str) 5 | 6 | 7 | class TrainingRecord(BaseModel): 8 | dataset: str 9 | image: str 10 | question: str 11 | question_id: int 12 | answer: Optional[List[str]] = None 13 | rationales: Optional[List[str]] = None 14 | 15 | 16 | class TestingRecord(BaseModel): 17 | question_id: int 18 | question: str 19 | image: str 20 | dataset: Optional[str] 21 | original_question_id: Optional[Union[int, str]] 22 | 23 | 24 | # AOKVQA has string type question ids. 25 | class AnswerRecord(BaseModel): 26 | question_id: Union[int, str] 27 | answer: str 28 | score: float 29 | 30 | 31 | class QuestionRecord(BaseModel): 32 | # For datasets which use COCO images, the image_id is the COCO image id. 33 | # SemArt doesn't have image ids but names, so we use the names. 34 | image_id: Union[int, SemArtImageName] 35 | question: str 36 | question_id: int 37 | 38 | 39 | class VQAAnnotationSubRecord(BaseModel): 40 | answer: str 41 | answer_confidence: str 42 | answer_id: int 43 | 44 | 45 | class VQAAnnotationRecord(BaseModel): 46 | question_type: str 47 | answers: List[VQAAnnotationSubRecord] 48 | image_id: Union[int, SemArtImageName] 49 | answer_type: str 50 | question_id: int 51 | multiple_choice_answer: Optional[str] = None 52 | 53 | 54 | # This can be used for VQA datasets that only have 55 | # one ground truth answer per question, and so can't 56 | # be used easily with the VQAv2 evaluation code. 
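# A record with hypothetical values that this model would accept:
#   {"question_id": 7, "answer": "benign", "question": "is the tissue benign or malignant?",
#    "image": "images/test_0007.jpg", "answer_type": "other"}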
57 | class MinimalEvaluationRecord(BaseModel): 58 | question_id: int 59 | answer: str 60 | question: str 61 | image: Union[str, int] 62 | question_type: Optional[str] = None 63 | answer_type: Optional[str] = None 64 | dataset: Optional[str] = None 65 | -------------------------------------------------------------------------------- /setup.md: -------------------------------------------------------------------------------- 1 | ``` 2 | conda create -n blip python=3.8 3 | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 4 | pip3 install timm==0.4.12 5 | pip3 install transformers==4.15.0 6 | pip3 install fairscale==0.4.4 7 | pip3 install pycocoevalcap 8 | pip3 install jupyter 9 | pip3 install ipdb 10 | pip3 install hydra-core 11 | pip3 install ruamel.yaml 12 | pip3 install opencv-python 13 | ``` 14 | 15 | # For development 16 | ``` 17 | pip3 install pytest 18 | pip3 install pyinstrument 19 | ``` -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codezakh/SelTDA/8ca7d53fb6ef1d8dec62e52bc9a46df4a194e06b/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codezakh/SelTDA/8ca7d53fb6ef1d8dec62e52bc9a46df4a194e06b/tests/conftest.py -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | import cli 2 | import functools 3 | 4 | 5 | def test_parse_args(request, monkeypatch): 6 | with monkeypatch.context() as m: 7 | # Prevent the argparser from actually parsing the pytest args 8 | # and choking. We just override the parse_args method to always 9 | # parse the empty list. 10 | parse_args = cli.ArgumentParser.parse_args 11 | parse_empty_list = functools.partialmethod(parse_args, args=[]) 12 | m.setattr(cli.ArgumentParser, "parse_args", parse_empty_list) 13 | cli.parse_args(default_config_path="./configs/vqa.yaml") 14 | 15 | 16 | def test_cli_setup(request, tmp_path_factory, monkeypatch): 17 | with monkeypatch.context() as m: 18 | # Prevent the argparser from actually parsing the pytest args 19 | # and choking. We just override the parse_args method to always 20 | # parse the empty list. 21 | parse_args = cli.ArgumentParser.parse_args 22 | parse_empty_list = functools.partialmethod(parse_args, args=[]) 23 | m.setattr(cli.ArgumentParser, "parse_args", parse_empty_list) 24 | args, config = cli.parse_args(default_config_path="./configs/vqa.yaml") 25 | args.output_dir = str(tmp_path_factory.mktemp("test_cli_setup")) 26 | cli.setup(args, config) 27 | -------------------------------------------------------------------------------- /tests/test_generate_questions.py: -------------------------------------------------------------------------------- 1 | from generate_questions import VQARecord, AmbiguousBooleanAnswerError 2 | import pytest 3 | 4 | 5 | def test_parsing_rationale(): 6 | raw_model_output = ": what does this animal live in?. answer : trees, forest, savanna, park, forest, woods, grassland, zoo, zoo, zoo. rationale : the animals in this photo are found in the savannah. there is a small group of animals that appear to have some unique features. 
the animal on the left seems to be an elephant or tiger." 7 | record = VQARecord.build_from_raw_model_output( 8 | raw_model_output, "/fake/path/to/image.jpg", parse_rationale=True 9 | ) 10 | assert record.answer == [ 11 | _.strip() 12 | for _ in "trees, forest, savanna, park, forest, woods, grassland, zoo, zoo, zoo".split( 13 | "," 14 | ) 15 | ] 16 | assert record.question == "what does this animal live in?" 17 | assert ( 18 | record.rationale 19 | == "the animals in this photo are found in the savannah. there is a small group of animals that appear to have some unique features. the animal on the left seems to be an elephant or tiger." 20 | ) 21 | 22 | 23 | def test_parsing_output_with_no_rationale(): 24 | raw_model_output = ": what does this animal live in?. answer : trees, forest, savanna, park, forest, woods, grassland, zoo, zoo, zoo." 25 | record = VQARecord.build_from_raw_model_output( 26 | raw_model_output, "/fake/path/to/image.jpg", parse_rationale=False 27 | ) 28 | assert record.rationale is None 29 | 30 | 31 | def test_parsing_when_rationale_comes_first(): 32 | """Because we have models that generate the rationale first.""" 33 | raw_model_output = "rationale : the animals in this photo are found in the savannah. there is a small group of animals that appear to have some unique features. the animal on the left seems to be an elephant or tiger. question: what does this animal live in?. answer : trees, forest, savanna, park, forest, woods, grassland, zoo, zoo, zoo." 34 | record = VQARecord.build_from_raw_model_output( 35 | raw_model_output, "/fake/path/to/image.jpg", parse_rationale=True 36 | ) 37 | assert record.answer == [ 38 | _.strip() 39 | for _ in "trees, forest, savanna, park, forest, woods, grassland, zoo, zoo, zoo".split( 40 | "," 41 | ) 42 | ] 43 | assert record.question == "what does this animal live in?" 44 | assert ( 45 | record.rationale 46 | == "the animals in this photo are found in the savannah. there is a small group of animals that appear to have some unique features. the animal on the left seems to be an elephant or tiger." 47 | ) 48 | 49 | 50 | def test_parsing_question_when_no_colon_first(): 51 | # Models which are very good _do not_ generate an extraneous colon at the start of generated 52 | # questions. 53 | raw_model_output = " what might the person next to the suitcase be doing?. answer : waiting, waiting, waiting, waiting, waiting, waiting, waiting, waiting, waiting, waiting" 54 | record = VQARecord.build_from_raw_model_output( 55 | raw_model_output, "/fake/path/to/image.jpg", parse_rationale=False 56 | ) 57 | assert record.answer == [ 58 | _.strip() 59 | for _ in "waiting, waiting, waiting, waiting, waiting, waiting, waiting, waiting, waiting, waiting".split( 60 | "," 61 | ) 62 | ] 63 | assert record.question == "what might the person next to the suitcase be doing?" 64 | 65 | 66 | def test_parsing_question_when_ambiguous_boolean_answer(): 67 | raw_model_output = "is the person blue? 
answer: yes, no" 68 | with pytest.raises(AmbiguousBooleanAnswerError): 69 | record = VQARecord.build_from_raw_model_output( 70 | raw_model_output, "/fake/path/to/image.jpg", parse_rationale=False 71 | ) 72 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from models.blip import decoder_from_config 5 | import pytest 6 | 7 | 8 | @pytest.mark.slow 9 | def test_decoder_forward_pass(request): 10 | config = OmegaConf.load(Path(request.config.rootdir) / "configs" / "vqg.yaml") 11 | # Don't load a checkpoint, this just makes it faster. 12 | config.pretrained = None 13 | torch.hub.set_dir(config.torch_home) 14 | model = decoder_from_config(config) 15 | # Create a fake image. 16 | image = torch.rand(1, 3, config["image_size"], config["image_size"]) 17 | # Create a fake text. 18 | text = ["Do you like horses?"] 19 | model(image, text) 20 | 21 | 22 | @pytest.mark.slow 23 | def test_tokenization_length_respected(request): 24 | config = OmegaConf.load(Path(request.config.rootdir) / "configs" / "vqg.yaml") 25 | config.tokenizer_max_length = 10 26 | # Don't load a checkpoint, this just makes it faster. 27 | config.pretrained = None 28 | torch.hub.set_dir(config.torch_home) 29 | 30 | model = decoder_from_config(config) 31 | 32 | # Make a string much longer than the tokenizer maximum length 33 | # and check to make sure it gets chopped down by the tokenizer. 34 | text = "x " * (config.tokenizer_max_length**2) 35 | 36 | tokenized = model.tokenize(text) 37 | 38 | assert len(tokenized.input_ids.squeeze()) <= config.tokenizer_max_length 39 | -------------------------------------------------------------------------------- /train_vqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | """ 8 | import argparse 9 | from inspect import Attribute 10 | import os 11 | 12 | # import ruamel.yaml as yaml 13 | import numpy as np 14 | import random 15 | import time 16 | import datetime 17 | import json 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.utils.data import DataLoader 24 | import torch.backends.cudnn as cudnn 25 | import torch.distributed as dist 26 | import wandb 27 | from omegaconf import OmegaConf 28 | 29 | from models.blip_vqa import blip_vqa 30 | import utils 31 | from utils import cosine_lr_schedule 32 | from data import create_dataset, create_sampler, create_loader 33 | from data.vqa_dataset import vqa_collate_fn 34 | from data.utils import save_result 35 | import cli 36 | 37 | 38 | def train(model, data_loader, optimizer, epoch, device, wandb_logger=None): 39 | # train 40 | model.train() 41 | 42 | metric_logger = utils.MetricLogger(delimiter=" ") 43 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) 44 | metric_logger.add_meter( 45 | "loss", utils.SmoothedValue(window_size=1, fmt="{value:.4f}") 46 | ) 47 | 48 | header = "Train Epoch: [{}]".format(epoch) 49 | print_freq = 50 50 | 51 | for i, (image, question, answer, weights, n) in enumerate( 52 | metric_logger.log_every(data_loader, print_freq, header) 53 | ): 54 | image, weights = image.to(device, non_blocking=True), weights.to( 55 | device, non_blocking=True 56 | ) 57 | 58 | loss = model(image, question, answer, train=True, n=n, weights=weights) 59 | 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | 64 | metric_logger.update(loss=loss.item()) 65 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 66 | 67 | if i % print_freq == 0: 68 | if utils.is_main_process() and wandb_logger: 69 | wandb_logger.log( 70 | data={ 71 | "loss": loss.item(), 72 | "lr": optimizer.param_groups[0]["lr"], 73 | } 74 | ) 75 | 76 | # gather the stats from all processes 77 | metric_logger.synchronize_between_processes() 78 | print("Averaged stats:", metric_logger.global_avg()) 79 | return { 80 | k: "{:.3f}".format(meter.global_avg) 81 | for k, meter in metric_logger.meters.items() 82 | } 83 | 84 | 85 | @torch.no_grad() 86 | def evaluation(model, data_loader, device, config): 87 | # test 88 | model.eval() 89 | 90 | metric_logger = utils.MetricLogger(delimiter=" ") 91 | header = "Generate VQA test result:" 92 | print_freq = 50 93 | 94 | result = [] 95 | 96 | if config["inference"] == "rank": 97 | answer_list = data_loader.dataset.answer_list 98 | answer_candidates = model.tokenizer( 99 | answer_list, padding="longest", return_tensors="pt" 100 | ).to(device) 101 | answer_candidates.input_ids[:, 0] = model.tokenizer.bos_token_id 102 | 103 | for n, (image, question, question_id) in enumerate( 104 | metric_logger.log_every(data_loader, print_freq, header) 105 | ): 106 | image = image.to(device, non_blocking=True) 107 | 108 | # We'll only collect these when doing rank inference for now. 109 | 110 | if config["inference"] == "generate": 111 | answers = model(image, question, train=False, inference="generate") 112 | 113 | for answer, ques_id in zip(answers, question_id): 114 | # ques_id can be either a one-element Tensor[int] or a 115 | # string. 
We convert it to an int (not sure why), but 116 | # this means we have to handle each case separately. 117 | try: 118 | ques_id = int(ques_id.item()) 119 | except AttributeError: 120 | ques_id = int(ques_id) 121 | result.append({"question_id": ques_id, "answer": answer}) 122 | 123 | elif config["inference"] == "rank": 124 | answer_ids, answer_scores = model( 125 | image, 126 | question, 127 | answer_candidates, 128 | train=False, 129 | inference="rank", 130 | k_test=config["k_test"], 131 | return_scores=True, 132 | ) 133 | for ques_id, answer_id, answer_score in zip( 134 | question_id, answer_ids, answer_scores 135 | ): 136 | # The question id is a string in some datasets (VQAv2) 137 | # and an integer in other datasets (A-OKVQA). When it's a 138 | # an integer, it gets type casted to a tensor, which we 139 | # can't serialize to JSON without converting it back to an int. 140 | try: 141 | ques_id = int(ques_id.item()) 142 | except AttributeError: 143 | pass 144 | 145 | result.append( 146 | { 147 | "question_id": ques_id, 148 | "answer": answer_list[answer_id], 149 | "score": answer_score.item(), 150 | } 151 | ) 152 | 153 | return result 154 | 155 | 156 | def main(args, config): 157 | utils.init_distributed_mode(args) 158 | 159 | device = torch.device(args.device) 160 | 161 | # fix the seed for reproducibility 162 | seed = args.seed + utils.get_rank() 163 | torch.manual_seed(seed) 164 | np.random.seed(seed) 165 | random.seed(seed) 166 | cudnn.benchmark = True 167 | 168 | if utils.is_main_process() and config.wandb: 169 | print("Is main process, creating W&B logger.") 170 | wandb_logger = wandb.init( 171 | project="mithril-alice-valley", 172 | entity="zakh", 173 | config=OmegaConf.to_container(config), 174 | ) 175 | else: 176 | wandb_logger = None 177 | 178 | #### Dataset #### 179 | print("Creating vqa datasets") 180 | datasets = create_dataset(config.dataset_name, config) 181 | 182 | if args.distributed: 183 | num_tasks = utils.get_world_size() 184 | global_rank = utils.get_rank() 185 | samplers = create_sampler(datasets, [True, False], num_tasks, global_rank) 186 | else: 187 | samplers = [None, None] 188 | 189 | train_loader, test_loader = create_loader( 190 | datasets, 191 | samplers, 192 | batch_size=[config["batch_size_train"], config["batch_size_test"]], 193 | num_workers=[4, 4], 194 | is_trains=[True, False], 195 | collate_fns=[vqa_collate_fn, None], 196 | ) 197 | #### Model #### 198 | print("Creating model") 199 | model = blip_vqa( 200 | pretrained=config["pretrained"], 201 | image_size=config["image_size"], 202 | vit=config["vit"], 203 | vit_grad_ckpt=config["vit_grad_ckpt"], 204 | vit_ckpt_layer=config["vit_ckpt_layer"], 205 | ) 206 | 207 | model = model.to(device) 208 | 209 | model_without_ddp = model 210 | if args.distributed: 211 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 212 | model_without_ddp = model.module 213 | 214 | optimizer = torch.optim.AdamW( 215 | params=model.parameters(), 216 | lr=config["init_lr"], 217 | weight_decay=config["weight_decay"], 218 | ) 219 | 220 | best = 0 221 | best_epoch = 0 222 | 223 | print("Start training") 224 | start_time = time.time() 225 | epochs = list(range(0, config.max_epoch)) 226 | for epoch in epochs: 227 | if not args.evaluate: 228 | if args.distributed: 229 | train_loader.sampler.set_epoch(epoch) 230 | 231 | cosine_lr_schedule( 232 | optimizer, 233 | epoch, 234 | config["max_epoch"], 235 | config["init_lr"], 236 | config["min_lr"], 237 | ) 238 | 239 | train_stats = train( 240 | model, 
train_loader, optimizer, epoch, device, wandb_logger=wandb_logger 241 | ) 242 | 243 | else: 244 | break 245 | 246 | if utils.is_main_process(): 247 | log_stats = { 248 | **{f"train_{k}": v for k, v in train_stats.items()}, 249 | "epoch": epoch, 250 | } 251 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f: 252 | f.write(json.dumps(log_stats) + "\n") 253 | 254 | if config.save_last_only: 255 | should_save = epoch == epochs[-1] 256 | else: 257 | should_save = True 258 | 259 | if should_save: 260 | save_obj = { 261 | "model": model_without_ddp.state_dict(), 262 | "optimizer": optimizer.state_dict(), 263 | "config": config, 264 | "epoch": epoch, 265 | } 266 | 267 | torch.save( 268 | save_obj, 269 | os.path.join(args.output_dir, "checkpoint_%02d.pth" % epoch), 270 | ) 271 | 272 | dist.barrier() 273 | 274 | vqa_result = evaluation(model_without_ddp, test_loader, device, config) 275 | result_file = save_result(vqa_result, args.result_dir, "vqa_result") 276 | 277 | total_time = time.time() - start_time 278 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 279 | print("Training time {}".format(total_time_str)) 280 | 281 | 282 | if __name__ == "__main__": 283 | args, config = cli.parse_args(default_config_path="./configs/vqa.yaml") 284 | cli.setup(args, config) 285 | main(args, config) 286 | -------------------------------------------------------------------------------- /train_vqg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import random 5 | import time 6 | import datetime 7 | import json 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | from torch.utils.data import DataLoader 16 | import hydra 17 | from omegaconf import OmegaConf 18 | import wandb 19 | import data 20 | 21 | from models.blip import decoder_from_config 22 | import utils 23 | from utils import cosine_lr_schedule 24 | from data import create_dataset, create_sampler, create_loader 25 | from data.utils import save_result, coco_caption_eval 26 | import cli 27 | 28 | 29 | class Trainer: 30 | def __init__( 31 | self, model, data_loader, optimizer, device, wandb_logger=None, print_freq=50 32 | ) -> None: 33 | self.model = model 34 | self.data_loader = data_loader 35 | self.optimizer = optimizer 36 | self.device = device 37 | self.wandb_logger = wandb_logger 38 | self.print_freq = print_freq 39 | 40 | def train_one_epoch(self, epoch): 41 | self.model.train() 42 | 43 | self.metric_logger = utils.MetricLogger(delimiter=" ") 44 | self.metric_logger.add_meter( 45 | "lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}") 46 | ) 47 | self.metric_logger.add_meter( 48 | "loss", utils.SmoothedValue(window_size=1, fmt="{value:.4f}") 49 | ) 50 | header = "Train Caption Epoch: [{}]".format(epoch) 51 | print_freq = 50 52 | 53 | for i, (image, caption, _) in enumerate( 54 | self.metric_logger.log_every(self.data_loader, print_freq, header) 55 | ): 56 | self.train_step(image, caption, i) 57 | 58 | # gather the stats from all processes 59 | self.metric_logger.synchronize_between_processes() 60 | print("Averaged stats:", self.metric_logger.global_avg()) 61 | return { 62 | k: "{:.3f}".format(meter.global_avg) 63 | for k, meter in self.metric_logger.meters.items() 64 | } 65 | 66 | def train_step(self, image, caption, batch_idx): 67 | image = image.to(self.device) 68 | 69 | loss 
= self.model(image, caption) 70 | 71 | self.optimizer.zero_grad() 72 | loss.backward() 73 | self.optimizer.step() 74 | 75 | self.metric_logger.update(loss=loss.item()) 76 | self.metric_logger.update(lr=self.optimizer.param_groups[0]["lr"]) 77 | 78 | if batch_idx % self.print_freq == 0: 79 | if utils.is_main_process() and self.wandb_logger: 80 | self.wandb_logger.log( 81 | data={ 82 | "loss": loss.item(), 83 | "lr": self.optimizer.param_groups[0]["lr"], 84 | } 85 | ) 86 | 87 | 88 | @torch.no_grad() 89 | def evaluate(model, data_loader, device, config): 90 | # evaluate 91 | model.eval() 92 | 93 | metric_logger = utils.MetricLogger(delimiter=" ") 94 | header = "Caption generation:" 95 | print_freq = 10 96 | 97 | result = [] 98 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 99 | 100 | image = image.to(device) 101 | 102 | captions = model.generate( 103 | image, 104 | sample=False, 105 | num_beams=config["num_beams"], 106 | max_length=config["max_length"], 107 | min_length=config["min_length"], 108 | ) 109 | 110 | for caption, img_id in zip(captions, image_id): 111 | result.append({"image_id": img_id.item(), "caption": caption}) 112 | 113 | return result 114 | 115 | 116 | def main(args, config): 117 | utils.init_distributed_mode(args) 118 | 119 | device = torch.device(args.device) 120 | 121 | # fix the seed for reproducibility 122 | seed = args.seed + utils.get_rank() 123 | torch.manual_seed(seed) 124 | np.random.seed(seed) 125 | random.seed(seed) 126 | cudnn.benchmark = True 127 | 128 | #### Dataset #### 129 | print("Creating captioning dataset") 130 | train_dataset, test_dataset = create_dataset("vqg", config) 131 | 132 | if args.distributed: 133 | num_tasks = utils.get_world_size() 134 | global_rank = utils.get_rank() 135 | samplers = create_sampler( 136 | [train_dataset, test_dataset], 137 | [True, False], 138 | num_tasks, 139 | global_rank, 140 | ) 141 | else: 142 | samplers = [None, None] 143 | 144 | train_loader, test_loader = create_loader( 145 | [train_dataset, test_dataset], 146 | samplers, 147 | batch_size=[config["batch_size"]] * 3, 148 | num_workers=[4, 4, 4], 149 | is_trains=[True, False], 150 | collate_fns=[None, None], 151 | ) 152 | 153 | #### Model #### 154 | print("Creating model") 155 | model = decoder_from_config(config) 156 | 157 | if utils.is_main_process() and config.wandb: 158 | print("Is main process, creating W&B logger.") 159 | wandb_logger = wandb.init( 160 | project="mithril-alice-valley", 161 | entity="zakh", 162 | config=OmegaConf.to_container(config), 163 | ) 164 | else: 165 | wandb_logger = None 166 | 167 | model = model.to(device) 168 | 169 | model_without_ddp = model 170 | if args.distributed: 171 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 172 | model_without_ddp = model.module 173 | 174 | optimizer = torch.optim.AdamW( 175 | params=model.parameters(), 176 | lr=config["init_lr"], 177 | weight_decay=config["weight_decay"], 178 | ) 179 | 180 | best = 0 181 | best_epoch = 0 182 | 183 | trainer = Trainer( 184 | data_loader=train_loader, 185 | optimizer=optimizer, 186 | device=device, 187 | wandb_logger=wandb_logger, 188 | model=model, 189 | ) 190 | 191 | print("Start training") 192 | start_time = time.time() 193 | epochs = list(range(config.max_epoch)) 194 | for epoch in epochs: 195 | if not args.evaluate: 196 | if args.distributed: 197 | train_loader.sampler.set_epoch(epoch) 198 | 199 | cosine_lr_schedule( 200 | optimizer, 201 | epoch, 202 | config["max_epoch"], 203 | config["init_lr"], 204 
| config["min_lr"], 205 | ) 206 | 207 | # train_stats = train(model, train_loader, optimizer, epoch, device, wandb_logger=wandb_logger) 208 | train_stats = trainer.train_one_epoch(epoch=epoch) 209 | 210 | if utils.is_main_process(): 211 | 212 | if args.evaluate: 213 | pass 214 | else: 215 | 216 | if config.save_last_only: 217 | should_save = epoch == epochs[-1] 218 | else: 219 | should_save = True 220 | 221 | if should_save: 222 | save_obj = { 223 | "model": model_without_ddp.state_dict(), 224 | "optimizer": optimizer.state_dict(), 225 | "config": config, 226 | "epoch": epoch, 227 | } 228 | torch.save( 229 | save_obj, 230 | os.path.join(args.output_dir, "checkpoint_%02d.pth" % epoch), 231 | ) 232 | 233 | log_stats = { 234 | **{f"train_{k}": v for k, v in train_stats.items()}, 235 | "epoch": epoch, 236 | "best_epoch": best_epoch, 237 | } 238 | 239 | with open(os.path.join(args.output_dir, "log.txt"), "a") as f: 240 | f.write(json.dumps(log_stats) + "\n") 241 | 242 | if args.evaluate: 243 | break 244 | dist.barrier() 245 | 246 | total_time = time.time() - start_time 247 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 248 | print("Training time {}".format(total_time_str)) 249 | 250 | 251 | if __name__ == "__main__": 252 | args, config = cli.parse_args(default_config_path="configs/vqg.yaml") 253 | cli.setup(args, config) 254 | main(args, config) 255 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 5 | """Decay the learning rate""" 6 | lr = (init_lr - min_lr) * 0.5 * ( 7 | 1.0 + math.cos(math.pi * epoch / max_epoch) 8 | ) + min_lr 9 | for param_group in optimizer.param_groups: 10 | param_group["lr"] = lr 11 | 12 | 13 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 14 | """Warmup the learning rate""" 15 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 16 | for param_group in optimizer.param_groups: 17 | param_group["lr"] = lr 18 | 19 | 20 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 21 | """Decay the learning rate""" 22 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 23 | for param_group in optimizer.param_groups: 24 | param_group["lr"] = lr 25 | 26 | 27 | import numpy as np 28 | import io 29 | import os 30 | import time 31 | from collections import defaultdict, deque 32 | import datetime 33 | 34 | import torch 35 | import torch.distributed as dist 36 | 37 | 38 | class SmoothedValue(object): 39 | """Track a series of values and provide access to smoothed values over a 40 | window or the global series average. 41 | """ 42 | 43 | def __init__(self, window_size=20, fmt=None): 44 | if fmt is None: 45 | fmt = "{median:.4f} ({global_avg:.4f})" 46 | self.deque = deque(maxlen=window_size) 47 | self.total = 0.0 48 | self.count = 0 49 | self.fmt = fmt 50 | 51 | def update(self, value, n=1): 52 | self.deque.append(value) 53 | self.count += n 54 | self.total += value * n 55 | 56 | def synchronize_between_processes(self): 57 | """ 58 | Warning: does not synchronize the deque! 
59 | """ 60 | if not is_dist_avail_and_initialized(): 61 | return 62 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 63 | dist.barrier() 64 | dist.all_reduce(t) 65 | t = t.tolist() 66 | self.count = int(t[0]) 67 | self.total = t[1] 68 | 69 | @property 70 | def median(self): 71 | d = torch.tensor(list(self.deque)) 72 | return d.median().item() 73 | 74 | @property 75 | def avg(self): 76 | d = torch.tensor(list(self.deque), dtype=torch.float32) 77 | return d.mean().item() 78 | 79 | @property 80 | def global_avg(self): 81 | return self.total / self.count 82 | 83 | @property 84 | def max(self): 85 | return max(self.deque) 86 | 87 | @property 88 | def value(self): 89 | return self.deque[-1] 90 | 91 | def __str__(self): 92 | return self.fmt.format( 93 | median=self.median, 94 | avg=self.avg, 95 | global_avg=self.global_avg, 96 | max=self.max, 97 | value=self.value, 98 | ) 99 | 100 | 101 | class MetricLogger(object): 102 | def __init__(self, delimiter="\t"): 103 | self.meters = defaultdict(SmoothedValue) 104 | self.delimiter = delimiter 105 | 106 | def update(self, **kwargs): 107 | for k, v in kwargs.items(): 108 | if isinstance(v, torch.Tensor): 109 | v = v.item() 110 | assert isinstance(v, (float, int)) 111 | self.meters[k].update(v) 112 | 113 | def __getattr__(self, attr): 114 | if attr in self.meters: 115 | return self.meters[attr] 116 | if attr in self.__dict__: 117 | return self.__dict__[attr] 118 | raise AttributeError( 119 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 120 | ) 121 | 122 | def __str__(self): 123 | loss_str = [] 124 | for name, meter in self.meters.items(): 125 | loss_str.append("{}: {}".format(name, str(meter))) 126 | return self.delimiter.join(loss_str) 127 | 128 | def global_avg(self): 129 | loss_str = [] 130 | for name, meter in self.meters.items(): 131 | loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) 132 | return self.delimiter.join(loss_str) 133 | 134 | def synchronize_between_processes(self): 135 | for meter in self.meters.values(): 136 | meter.synchronize_between_processes() 137 | 138 | def add_meter(self, name, meter): 139 | self.meters[name] = meter 140 | 141 | def log_every(self, iterable, print_freq, header=None): 142 | i = 0 143 | if not header: 144 | header = "" 145 | start_time = time.time() 146 | end = time.time() 147 | iter_time = SmoothedValue(fmt="{avg:.4f}") 148 | data_time = SmoothedValue(fmt="{avg:.4f}") 149 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 150 | log_msg = [ 151 | header, 152 | "[{0" + space_fmt + "}/{1}]", 153 | "eta: {eta}", 154 | "{meters}", 155 | "time: {time}", 156 | "data: {data}", 157 | ] 158 | if torch.cuda.is_available(): 159 | log_msg.append("max mem: {memory:.0f}") 160 | log_msg = self.delimiter.join(log_msg) 161 | MB = 1024.0 * 1024.0 162 | for obj in iterable: 163 | data_time.update(time.time() - end) 164 | yield obj 165 | iter_time.update(time.time() - end) 166 | if i % print_freq == 0 or i == len(iterable) - 1: 167 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 168 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 169 | if torch.cuda.is_available(): 170 | print( 171 | log_msg.format( 172 | i, 173 | len(iterable), 174 | eta=eta_string, 175 | meters=str(self), 176 | time=str(iter_time), 177 | data=str(data_time), 178 | memory=torch.cuda.max_memory_allocated() / MB, 179 | ) 180 | ) 181 | else: 182 | print( 183 | log_msg.format( 184 | i, 185 | len(iterable), 186 | eta=eta_string, 187 | meters=str(self), 188 | 
time=str(iter_time), 189 | data=str(data_time), 190 | ) 191 | ) 192 | i += 1 193 | end = time.time() 194 | total_time = time.time() - start_time 195 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 196 | print( 197 | "{} Total time: {} ({:.4f} s / it)".format( 198 | header, total_time_str, total_time / len(iterable) 199 | ) 200 | ) 201 | 202 | 203 | class AttrDict(dict): 204 | def __init__(self, *args, **kwargs): 205 | super(AttrDict, self).__init__(*args, **kwargs) 206 | self.__dict__ = self 207 | 208 | 209 | def compute_acc(logits, label, reduction="mean"): 210 | ret = (torch.argmax(logits, dim=1) == label).float() 211 | if reduction == "none": 212 | return ret.detach() 213 | elif reduction == "mean": 214 | return ret.mean().item() 215 | 216 | 217 | def compute_n_params(model, return_str=True): 218 | tot = 0 219 | for p in model.parameters(): 220 | w = 1 221 | for x in p.shape: 222 | w *= x 223 | tot += w 224 | if return_str: 225 | if tot >= 1e6: 226 | return "{:.1f}M".format(tot / 1e6) 227 | else: 228 | return "{:.1f}K".format(tot / 1e3) 229 | else: 230 | return tot 231 | 232 | 233 | def setup_for_distributed(is_master): 234 | """ 235 | This function disables printing when not in master process 236 | """ 237 | import builtins as __builtin__ 238 | 239 | builtin_print = __builtin__.print 240 | 241 | def print(*args, **kwargs): 242 | force = kwargs.pop("force", False) 243 | if is_master or force: 244 | builtin_print(*args, **kwargs) 245 | 246 | __builtin__.print = print 247 | 248 | 249 | def is_dist_avail_and_initialized(): 250 | if not dist.is_available(): 251 | return False 252 | if not dist.is_initialized(): 253 | return False 254 | return True 255 | 256 | 257 | def get_world_size(): 258 | if not is_dist_avail_and_initialized(): 259 | return 1 260 | return dist.get_world_size() 261 | 262 | 263 | def get_rank(): 264 | if not is_dist_avail_and_initialized(): 265 | return 0 266 | return dist.get_rank() 267 | 268 | 269 | def is_main_process(): 270 | return get_rank() == 0 271 | 272 | 273 | def save_on_master(*args, **kwargs): 274 | if is_main_process(): 275 | torch.save(*args, **kwargs) 276 | 277 | 278 | def init_distributed_mode(args): 279 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 280 | args.rank = int(os.environ["RANK"]) 281 | args.world_size = int(os.environ["WORLD_SIZE"]) 282 | args.gpu = int(os.environ["LOCAL_RANK"]) 283 | elif "SLURM_PROCID" in os.environ: 284 | args.rank = int(os.environ["SLURM_PROCID"]) 285 | args.gpu = args.rank % torch.cuda.device_count() 286 | else: 287 | print("Not using distributed mode") 288 | args.distributed = False 289 | return 290 | 291 | args.distributed = True 292 | 293 | torch.cuda.set_device(args.gpu) 294 | args.dist_backend = "nccl" 295 | print( 296 | "| distributed init (rank {}, word {}): {}".format( 297 | args.rank, args.world_size, args.dist_url 298 | ), 299 | flush=True, 300 | ) 301 | torch.distributed.init_process_group( 302 | backend=args.dist_backend, 303 | init_method=args.dist_url, 304 | world_size=args.world_size, 305 | rank=args.rank, 306 | ) 307 | torch.distributed.barrier() 308 | setup_for_distributed(args.rank == 0) 309 | -------------------------------------------------------------------------------- /vqa_ce_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm.notebook import tqdm 3 | import json 4 | from pprint import PrettyPrinter 5 | from vqa_eval_tools import VQA, VQAEval 6 | from argparse import ArgumentParser 7 | from pathlib 
import Path 8 | 9 | 10 | pp = PrettyPrinter() 11 | 12 | annotation_file = ( 13 | "/net/acadia4a/data/zkhan/vqa-counterexamples/counterexamples_annotations.json" 14 | ) 15 | question_file = ( 16 | "/net/acadia4a/data/zkhan/vqa-counterexamples/counterexamples_questions.json" 17 | ) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = ArgumentParser() 22 | parser.add_argument( 23 | "result_file", help="Path to a JSON result file generated by an evaluation." 24 | ) 25 | args = parser.parse_args() 26 | 27 | results_file = args.result_file 28 | 29 | with open(annotation_file, "r") as f: 30 | annotations = json.load(f) 31 | 32 | advqa_obj = VQA(annotation_file=annotation_file, question_file=question_file) 33 | 34 | # We have to convert the question_id field to be an integer >.< 35 | with open(results_file, "r") as f: 36 | predicted = json.load(f) 37 | 38 | for element in predicted: 39 | element["question_id"] = int(element["question_id"]) 40 | 41 | with open(results_file, "w") as f: 42 | json.dump(predicted, f) 43 | 44 | result_obj = advqa_obj.loadRes(resFile=results_file, quesFile=question_file) 45 | 46 | advqa_eval = VQAEval(advqa_obj, result_obj, n=2) 47 | advqa_eval.evaluate() 48 | print(f"Completed evaluation of {results_file}") 49 | with open(Path(results_file).parent / "vqa_ce-eval.json", "w") as f: 50 | json.dump(advqa_eval.accuracy, f) 51 | advqa_eval.accuracy.pop("perQuestionType") 52 | pp.pprint(advqa_eval.accuracy) 53 | -------------------------------------------------------------------------------- /vqa_eval_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .vqa import VQA 2 | from .vqa_eval import VQAEval 3 | -------------------------------------------------------------------------------- /vqa_eval_tools/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = "aagrawal" 2 | __version__ = "0.9" 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | 24 | class VQA: 25 | def __init__(self, annotation_file=None, question_file=None): 26 | """ 27 | Constructor of VQA helper class for reading and visualizing questions and answers. 
28 | :param annotation_file (str): location of VQA annotation file 29 | :return: 30 | """ 31 | # load dataset 32 | self.dataset = {} 33 | self.questions = {} 34 | self.qa = {} 35 | self.qqa = {} 36 | self.imgToQA = {} 37 | if not annotation_file == None and not question_file == None: 38 | print("loading VQA annotations and questions into memory...") 39 | time_t = datetime.datetime.utcnow() 40 | dataset = json.load(open(annotation_file, "r")) 41 | questions = json.load(open(question_file, "r")) 42 | print(datetime.datetime.utcnow() - time_t) 43 | self.dataset = dataset 44 | self.questions = questions 45 | self.createIndex() 46 | 47 | def createIndex(self): 48 | # create index 49 | print("creating index...") 50 | imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} 51 | qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} 52 | qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} 53 | for ann in self.dataset["annotations"]: 54 | imgToQA[ann["image_id"]] += [ann] 55 | qa[ann["question_id"]] = ann 56 | for ques in self.questions["questions"]: 57 | qqa[ques["question_id"]] = ques 58 | print("index created!") 59 | 60 | # create class members 61 | self.qa = qa 62 | self.qqa = qqa 63 | self.imgToQA = imgToQA 64 | 65 | def info(self): 66 | """ 67 | Print information about the VQA annotation file. 68 | :return: 69 | """ 70 | for key, value in list(self.datset["info"].items()): 71 | print("%s: %s" % (key, value)) 72 | 73 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 74 | """ 75 | Get question ids that satisfy given filter conditions. default skips that filter 76 | :param imgIds (int array) : get question ids for given imgs 77 | quesTypes (str array) : get question ids for given question types 78 | ansTypes (str array) : get question ids for given answer types 79 | :return: ids (int array) : integer array of question ids 80 | """ 81 | imgIds = imgIds if isinstance(imgIds, list) else [imgIds] 82 | quesTypes = quesTypes if isinstance(quesTypes, list) else [quesTypes] 83 | ansTypes = ansTypes if isinstance(ansTypes, list) else [ansTypes] 84 | 85 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 86 | anns = self.dataset["annotations"] 87 | else: 88 | if not len(imgIds) == 0: 89 | anns = sum( 90 | [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], 91 | [], 92 | ) 93 | else: 94 | anns = self.dataset["annotations"] 95 | anns = ( 96 | anns 97 | if len(quesTypes) == 0 98 | else [ann for ann in anns if ann["question_type"] in quesTypes] 99 | ) 100 | anns = ( 101 | anns 102 | if len(ansTypes) == 0 103 | else [ann for ann in anns if ann["answer_type"] in ansTypes] 104 | ) 105 | ids = [ann["question_id"] for ann in anns] 106 | return ids 107 | 108 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 109 | """ 110 | Get image ids that satisfy given filter conditions. 
default skips that filter 111 | :param quesIds (int array) : get image ids for given question ids 112 | quesTypes (str array) : get image ids for given question types 113 | ansTypes (str array) : get image ids for given answer types 114 | :return: ids (int array) : integer array of image ids 115 | """ 116 | quesIds = quesIds if isinstance(quesIds, list) else [quesIds] 117 | quesTypes = quesTypes if isinstance(quesTypes, list) else [quesTypes] 118 | ansTypes = ansTypes if isinstance(ansTypes, list) else [ansTypes] 119 | 120 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 121 | anns = self.dataset["annotations"] 122 | else: 123 | if not len(quesIds) == 0: 124 | anns = sum( 125 | [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] 126 | ) 127 | else: 128 | anns = self.dataset["annotations"] 129 | anns = ( 130 | anns 131 | if len(quesTypes) == 0 132 | else [ann for ann in anns if ann["question_type"] in quesTypes] 133 | ) 134 | anns = ( 135 | anns 136 | if len(ansTypes) == 0 137 | else [ann for ann in anns if ann["answer_type"] in ansTypes] 138 | ) 139 | ids = [ann["image_id"] for ann in anns] 140 | return ids 141 | 142 | def loadQA(self, ids=[]): 143 | """ 144 | Load questions and answers with the specified question ids. 145 | :param ids (int array) : integer ids specifying question ids 146 | :return: qa (object array) : loaded qa objects 147 | """ 148 | if isinstance(ids, list): 149 | return [self.qa[id] for id in ids] 150 | elif isinstance(ids, int): 151 | return [self.qa[ids]] 152 | 153 | def showQA(self, anns): 154 | """ 155 | Display the specified annotations. 156 | :param anns (array of object): annotations to display 157 | :return: None 158 | """ 159 | if len(anns) == 0: 160 | return 0 161 | for ann in anns: 162 | quesId = ann["question_id"] 163 | print("Question: %s" % (self.qqa[quesId]["question"])) 164 | for ans in ann["answers"]: 165 | print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) 166 | 167 | def loadRes(self, resFile, quesFile): 168 | """ 169 | Load result file and return a result object. 170 | :param resFile (str) : file name of result file 171 | :return: res (obj) : result api object 172 | """ 173 | res = VQA() 174 | res.questions = json.load(open(quesFile)) 175 | # res.dataset["info"] = copy.deepcopy(self.questions["info"]) 176 | # res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) 177 | # res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) 178 | # res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) 179 | # res.dataset["license"] = copy.deepcopy(self.questions["license"]) 180 | 181 | print("Loading and preparing results... ") 182 | time_t = datetime.datetime.utcnow() 183 | anns = json.load(open(resFile)) 184 | assert isinstance(anns, list), "results is not an array of objects" 185 | annsQuesIds = [ann["question_id"] for ann in anns] 186 | assert set(annsQuesIds) == set( 187 | self.getQuesIds() 188 | ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file." 189 | for ann in anns: 190 | quesId = ann["question_id"] 191 | # We are only using this code for direct eval, so 192 | # we comment out this branch. 
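            # For context (illustrative values only, not produced by this file): since
            # only direct-answer evaluation is used, the result files passed in here are
            # plain JSON lists of predictions such as
            #     [{"question_id": 262148000, "answer": "down"},
            #      {"question_id": 262148001, "answer": "2", "score": -1.2}]
            # i.e. what save_result() writes after evaluation() in train_vqa.py; the
            # optional "score" field only appears for rank-style inference.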
193 | # if res.dataset["task_type"] == "Multiple Choice": 194 | # assert ( 195 | # ann["answer"] in self.qqa[quesId]["multiple_choices"] 196 | # ), "predicted answer is not one of the multiple choices" 197 | qaAnn = self.qa[quesId] 198 | ann["image_id"] = qaAnn["image_id"] 199 | ann["question_type"] = qaAnn["question_type"] 200 | ann["answer_type"] = qaAnn["answer_type"] 201 | print( 202 | "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) 203 | ) 204 | 205 | res.dataset["annotations"] = anns 206 | res.createIndex() 207 | return res 208 | -------------------------------------------------------------------------------- /vqa_introspect_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm.notebook import tqdm 3 | import json 4 | from pprint import PrettyPrinter 5 | from vqa_eval_tools import VQA, VQAEval 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | 10 | pp = PrettyPrinter() 11 | 12 | annotation_file = "/net/acadia4a/data/zkhan/lavis_cache/vqa-introspect/annotations/scoring_format_annotations.json" 13 | question_file = "/net/acadia4a/data/zkhan/lavis_cache/vqa-introspect/annotations/scoring_format_questions.json" 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = ArgumentParser() 18 | parser.add_argument( 19 | "result_file", help="Path to a JSON result file generated by an evaluation." 20 | ) 21 | args = parser.parse_args() 22 | 23 | results_file = args.result_file 24 | # results_file = '/net/acadia4a/data/zkhan/mithril/advqa-0-shot-evals/35_blip_vqa_baseline/result/vqa_result.json' 25 | 26 | # The annotations are missing an "question_type" key, so we create a new annotation file which does 27 | # have the key. We just copy the "answer_type" key to "question_type", they are the same thing, I think. 
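    # One way such a patched annotation file could be produced (kept as a comment so
    # this script's behaviour is unchanged; the output path is illustrative):
    #
    #     with open(annotation_file, "r") as f:
    #         raw = json.load(f)
    #     for ann in raw["annotations"]:
    #         ann.setdefault("question_type", ann["answer_type"])
    #     with open("annotations_with_qtype.json", "w") as f:
    #         json.dump(raw, f)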
28 | with open(annotation_file, "r") as f: 29 | annotations = json.load(f) 30 | 31 | vqa_obj = VQA(annotation_file=annotation_file, question_file=question_file) 32 | 33 | # We have to convert the question_id field to be an integer >.< 34 | with open(results_file, "r") as f: 35 | predicted = json.load(f) 36 | 37 | for element in predicted: 38 | element["question_id"] = int(element["question_id"]) 39 | 40 | with open(results_file, "w") as f: 41 | json.dump(predicted, f) 42 | 43 | result_obj = vqa_obj.loadRes(resFile=results_file, quesFile=question_file) 44 | 45 | vqa_eval = VQAEval(vqa_obj, result_obj, n=2) 46 | vqa_eval.evaluate() 47 | print(f"Completed evaluation of {results_file}") 48 | pp.pprint(vqa_eval.accuracy) 49 | with open(Path(results_file).parent / "vqa_rephrasings_eval.json", "w") as f: 50 | json.dump(vqa_eval.accuracy, f) 51 | -------------------------------------------------------------------------------- /vqa_rephrasings_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm.notebook import tqdm 3 | import json 4 | from pprint import PrettyPrinter 5 | from vqa_eval_tools import VQA, VQAEval 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | 10 | pp = PrettyPrinter() 11 | 12 | annotation_file = "/net/acadia4a/data/zkhan/vqa-rephrasings/v2_mscoco_valrep2014_humans_og_annotations.json" 13 | question_file = "/net/acadia4a/data/zkhan/vqa-rephrasings/v2_OpenEnded_mscoco_valrep2014_humans_og_questions.json" 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = ArgumentParser() 18 | parser.add_argument( 19 | "result_file", help="Path to a JSON result file generated by an evaluation." 20 | ) 21 | args = parser.parse_args() 22 | 23 | results_file = args.result_file 24 | # results_file = '/net/acadia4a/data/zkhan/mithril/advqa-0-shot-evals/35_blip_vqa_baseline/result/vqa_result.json' 25 | 26 | # The annotations are missing an "question_type" key, so we create a new annotation file which does 27 | # have the key. We just copy the "answer_type" key to "question_type", they are the same thing, I think. 28 | with open(annotation_file, "r") as f: 29 | annotations = json.load(f) 30 | 31 | vqa_obj = VQA(annotation_file=annotation_file, question_file=question_file) 32 | 33 | # We have to convert the question_id field to be an integer >.< 34 | with open(results_file, "r") as f: 35 | predicted = json.load(f) 36 | 37 | for element in predicted: 38 | element["question_id"] = int(element["question_id"]) 39 | 40 | with open(results_file, "w") as f: 41 | json.dump(predicted, f) 42 | 43 | result_obj = vqa_obj.loadRes(resFile=results_file, quesFile=question_file) 44 | 45 | vqa_eval = VQAEval(vqa_obj, result_obj, n=2) 46 | vqa_eval.evaluate() 47 | print(f"Completed evaluation of {results_file}") 48 | pp.pprint(vqa_eval.accuracy) 49 | with open(Path(results_file).parent / "vqa_rephrasings_eval.json", "w") as f: 50 | json.dump(vqa_eval.accuracy, f) 51 | -------------------------------------------------------------------------------- /vqav2_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm.notebook import tqdm 3 | import json 4 | from pprint import PrettyPrinter 5 | from vqa_eval_tools import VQA, VQAEval 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | 10 | pp = PrettyPrinter() 11 | 12 | # The annotations are missing an "question_type" key, so we create a new annotation file which does 13 | # have the key. 
We just copy the "answer_type" key to "question_type", they are the same thing, I think. 14 | annotation_file = "/net/acadia10a/data/zkhan/direct_answer_evaluations/v2_mscoco_val2014_annotations.json" 15 | question_file = "/net/acadia10a/data/zkhan/direct_answer_evaluations/v2_OpenEnded_mscoco_val2014_questions.json" 16 | 17 | # annotation_file = '/home/zkhan/v2_mscoco_val2014_annotations.json' 18 | # question_file = '/home/zkhan/v2_OpenEnded_mscoco_val2014_questions.json' 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = ArgumentParser() 23 | parser.add_argument( 24 | "result_file", help="Path to a JSON result file generated by an evaluation." 25 | ) 26 | args = parser.parse_args() 27 | 28 | results_file = args.result_file 29 | 30 | vqa_v2_obj = VQA(annotation_file=annotation_file, question_file=question_file) 31 | 32 | with open(results_file, "r") as f: 33 | predicted = json.load(f) 34 | 35 | for element in predicted: 36 | element["question_id"] = int(element["question_id"]) 37 | 38 | # The JSON I use for the VQAv2 validation set is missing 39 | # these two questions. It shouldn't make a big difference 40 | # in the evaluations, so we just predict a nonsense answer 41 | # for them. 42 | missing_qids = (196280004, 362391000) 43 | for missing_qid in missing_qids: 44 | predicted.append( 45 | {"question_id": missing_qid, "answer": "i forgor lol", "score": -5} 46 | ) 47 | 48 | with open(results_file, "w") as f: 49 | json.dump(predicted, f) 50 | 51 | result_obj = vqa_v2_obj.loadRes(resFile=results_file, quesFile=question_file) 52 | 53 | vqa_v2_eval = VQAEval(vqa_v2_obj, result_obj, n=2) 54 | vqa_v2_eval.evaluate() 55 | print(f"Completed evaluation of {results_file}") 56 | pp.pprint(vqa_v2_eval.accuracy) 57 | with open(Path(results_file).parent / "vqa_v2_eval.json", "w") as f: 58 | json.dump(vqa_v2_eval.accuracy, f) 59 | --------------------------------------------------------------------------------
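The direct-answer evaluation scripts above all follow the same pattern; the minimal sketch below distills it with placeholder paths (none of these paths ship with the repository), so it can be pointed at your own annotation, question, and result files:

    from pprint import PrettyPrinter
    from vqa_eval_tools import VQA, VQAEval

    annotation_file = "path/to/annotations.json"  # placeholder
    question_file = "path/to/questions.json"      # placeholder
    results_file = "path/to/vqa_result.json"      # placeholder, e.g. written by save_result()

    # Build the ground-truth index and attach the predictions to it. loadRes asserts
    # that the result file contains a prediction for every question id in the
    # annotation file, which is why the scripts above cast "question_id" to int first.
    vqa = VQA(annotation_file=annotation_file, question_file=question_file)
    result = vqa.loadRes(resFile=results_file, quesFile=question_file)

    evaluator = VQAEval(vqa, result, n=2)  # n=2, as in the scripts above
    evaluator.evaluate()
    PrettyPrinter().pprint(evaluator.accuracy)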