├── Grounding_bbox.py ├── LICENSE ├── MARVL.py ├── NLVR.py ├── Pretrain.py ├── README.md ├── Retrieval.py ├── VQA.py ├── VQA_msrvtt.py ├── VQA_msvd.py ├── WIT.py ├── XGQA.py ├── XRetrieval.py ├── XVNLI.py ├── accelerators ├── __init__.py ├── accelerator.py └── apex_ddp_accelerator.py ├── configs ├── config_beit2_base.json ├── config_beit2_large.json ├── finetune │ ├── coco_captioning_large.yaml │ ├── refcoco_grounding_large.yaml │ ├── vqa2_base.yaml │ └── vqa2_large.yaml └── pretrain │ ├── multilingual_cclm_x2vlm_base.yaml │ ├── multilingual_cclm_x2vlm_large.yaml │ ├── x2vlm_base_1b.yaml │ ├── x2vlm_base_1b_stage2.yaml │ ├── x2vlm_base_4m.yaml │ ├── x2vlm_large_1b.yaml │ ├── x2vlm_large_1b_stage2.yaml │ └── x2vlm_large_4m.yaml ├── dataset ├── __init__.py ├── captioning_dataset.py ├── dist_dataset.py ├── grounding_dataset.py ├── nlvr_dataset.py ├── pretrain_dataset.py ├── pretrain_dataset_multilingual.py ├── randaugment.py ├── retrieval_dataset.py ├── tokenizers │ ├── __init__.py │ └── bert_tokenizer_with_dropout.py ├── utils.py ├── vqa_dataset.py ├── wit_dataset.py ├── xflickrco_dataset.py └── xvnli_dataset.py ├── models ├── __init__.py ├── beit2.py ├── box_ops.py ├── clip_vit.py ├── model_classification.py ├── model_grounding.py ├── model_pretrain.py ├── model_retrieval.py ├── resampler.py ├── swin_transformer.py ├── vit.py ├── xbert.py ├── xroberta.py └── xvlm.py ├── optim.py ├── refTools ├── evaluation │ ├── __init__.py │ ├── bleu │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── meteor │ │ ├── __init__.py │ │ ├── meteor-1.5.jar │ │ └── meteor.py │ ├── readme.txt │ ├── refEvaluation.py │ ├── rouge │ │ ├── __init__.py │ │ └── rouge.py │ └── tokenizer │ │ ├── __init__.py │ │ ├── ptbtokenizer.py │ │ ├── stanford-corenlp-3.4.1.jar │ │ ├── tmp37tp6xj8 │ │ ├── tmp82iqkuu0 │ │ └── tmpn19wmqte └── refer_python3.py ├── requirements.txt ├── run.py ├── scheduler.py ├── utils ├── __init__.py ├── bleu.py ├── checkpointer.py ├── cider │ └── pyciderevalcap │ │ ├── __init__.py │ │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ │ └── ciderD │ │ ├── __init__.py │ │ ├── ciderD.py │ │ └── ciderD_scorer.py ├── hdfs_io.py ├── marvl_preproc.py └── torch_io.py ├── vqaTools ├── __init__.py ├── vqa.py └── vqaEval.py ├── x2vlm_github.png └── xFlickrCO.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023, ByteDance Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /MARVL.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import math 5 | 6 | import ruamel.yaml as yaml 7 | import numpy as np 8 | import random 9 | import time 10 | import datetime 11 | import json 12 | from pathlib import Path 13 | import json 14 | import pickle 15 | 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | 20 | import utils 21 | from dataset import create_dataset, create_sampler, create_loader, build_tokenizer 22 | from scheduler import create_scheduler 23 | from optim import create_optimizer 24 | 25 | 26 | def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler): 27 | model.train() 28 | 29 | metric_logger = utils.MetricLogger(delimiter=" ") 30 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 31 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 32 | 33 | header = 'Train Epoch: [{}]'.format(epoch) 34 | print_freq = 50 35 | step_size = 100 36 | 37 | for i, (image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 38 | images = torch.cat([image0, image1], dim=0) 39 | images, targets = images.to(device), targets.to(device) 40 | 41 | text_inputs = tokenizer(text, padding='longest', return_tensors="pt").to(device) 42 | 43 | loss = model(images, text_inputs.input_ids, text_inputs.attention_mask, targets=targets, train=True) 44 | 45 | optimizer.zero_grad() 46 | loss.backward() 47 | optimizer.step() 48 | scheduler.step() 49 | 50 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 51 | metric_logger.update(loss=loss.item()) 52 | 53 | # gather the stats from all processes 54 | metric_logger.synchronize_between_processes() 55 | print("Averaged stats:", metric_logger.global_avg()) 56 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 57 | 58 | 59 | @torch.no_grad() 60 | def evaluate(model, data_loader, tokenizer, device): 61 | model.eval() 62 | 63 | metric_logger = utils.MetricLogger(delimiter=" ") 64 | 65 | header = 'Evaluation:' 66 | print_freq = 50 67 | 68 | for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header): 69 | images = torch.cat([image0, image1], dim=0) 70 | images, targets = images.to(device), targets.to(device) 71 | text_inputs = tokenizer(text, padding='longest', return_tensors="pt").to(device) 72 | 73 | prediction = model(images, text_inputs.input_ids, text_inputs.attention_mask, targets=targets, train=False) 74 | 75 | _, pred_class = prediction.max(1) 76 | accuracy = (targets == pred_class).sum() / targets.size(0) 77 | 78 | metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0)) 79 | 80 | # gather the stats from all processes 81 | metric_logger.synchronize_between_processes() 82 | print("Averaged 
stats:", metric_logger.global_avg()) 83 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 84 | 85 | 86 | def main(args, config): 87 | utils.init_distributed_mode(args) 88 | device = torch.device(args.device) 89 | 90 | world_size = utils.get_world_size() 91 | 92 | if args.epoch > 0: 93 | config['schedular']['epochs'] = args.epoch 94 | print(f"### set epochs to: {args.epoch}", flush=True) 95 | 96 | if args.bs > 0: 97 | config['batch_size'] = args.bs // world_size 98 | 99 | seed = args.seed + utils.get_rank() 100 | torch.manual_seed(seed) 101 | np.random.seed(seed) 102 | random.seed(seed) 103 | cudnn.benchmark = True 104 | 105 | print("Creating dataset") 106 | train_dataset, val_dataset, test_dataset_dict = create_dataset('marvl', config) 107 | datasets = [train_dataset, val_dataset] 108 | 109 | train_dataset_size = len(train_dataset) 110 | train_batch_size = config['batch_size'] 111 | world_size = utils.get_world_size() 112 | 113 | if utils.is_main_process(): 114 | print(f"### data {train_dataset_size}, batch size, {train_batch_size} x {world_size}") 115 | print(f"### Test: {[(k, len(dataset)) for k, dataset in test_dataset_dict.items()]}") 116 | 117 | if args.distributed: 118 | num_tasks = utils.get_world_size() 119 | global_rank = utils.get_rank() 120 | samplers = create_sampler(datasets, [True, False], num_tasks, global_rank) 121 | else: 122 | samplers = [None, None] 123 | 124 | train_loader, val_loader = create_loader(datasets, samplers, batch_size=[config['batch_size']] * 2, 125 | num_workers=[4, 4], is_trains=[True, False], 126 | collate_fns=[None, None]) 127 | 128 | test_loader_dict = {} 129 | for k, v in test_dataset_dict.items(): 130 | test_loader_dict[k] = create_loader([v], [None], batch_size=[config['batch_size']], 131 | num_workers=[4], is_trains=[False], collate_fns=[None])[0] 132 | 133 | print("Creating model") 134 | from models.model_classification import XVLMPlusForMARVL 135 | model = XVLMPlusForMARVL(config=config) 136 | model.load_pretrained(args.checkpoint, config, is_eval=args.evaluate) 137 | model = model.to(device) 138 | print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad)) 139 | 140 | model_without_ddp = model 141 | if args.distributed: 142 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 143 | model_without_ddp = model.module 144 | 145 | tokenizer = build_tokenizer(config['text_encoder']) 146 | 147 | print("### output_dir, ", args.output_dir, flush=True) 148 | start_time = time.time() 149 | 150 | if args.evaluate: 151 | print("Start evaluating") 152 | 153 | acc_mean = 0 154 | for language, test_loader in test_loader_dict.items(): 155 | test_stats = evaluate(model, test_loader, tokenizer, device) 156 | if utils.is_main_process(): 157 | print({f'test_{language}_{k}': v for k, v in test_stats.items()}, flush=True) 158 | acc_mean += (test_stats['acc'] / len(test_loader_dict)) 159 | 160 | dist.barrier() 161 | 162 | if utils.is_main_process(): 163 | print("Test average accuracy: ", acc_mean, flush=True) 164 | dist.barrier() 165 | 166 | else: 167 | print("Start training") 168 | arg_opt = utils.AttrDict(config['optimizer']) 169 | optimizer = create_optimizer(arg_opt, model) 170 | arg_sche = utils.AttrDict(config['schedular']) 171 | arg_sche['step_per_epoch'] = math.ceil(train_dataset_size / (train_batch_size * world_size)) 172 | lr_scheduler = create_scheduler(arg_sche, optimizer) 173 | 174 | max_epoch = config['schedular']['epochs'] 175 | 176 | best = 0 177 | 
best_epoch = 0 178 | if 'eval_interval' not in config: 179 | config['eval_interval'] = 1 180 | 181 | for epoch in range(0, max_epoch): 182 | if args.distributed: 183 | train_loader.sampler.set_epoch(epoch) 184 | train_stats = train(model, train_loader, optimizer, tokenizer, epoch, device, lr_scheduler) 185 | if epoch >= config['start_eval']: 186 | # val_stats = evaluate(model, val_loader, tokenizer, device) 187 | 188 | acc_mean = 0 189 | for language, test_loader in test_loader_dict.items(): 190 | test_stats = evaluate(model, test_loader, tokenizer, device) 191 | if utils.is_main_process(): 192 | print({f'test_{language}_{k}': v for k, v in test_stats.items()}, flush=True) 193 | acc_mean += (float(test_stats['acc']) / len(test_loader_dict)) 194 | dist.barrier() 195 | 196 | if utils.is_main_process(): 197 | if acc_mean > best: 198 | save_obj = { 199 | 'model': model_without_ddp.state_dict(), 200 | # 'optimizer': optimizer.state_dict(), 201 | # 'lr_scheduler': lr_scheduler.state_dict(), 202 | 'config': config, 203 | # 'epoch': epoch, 204 | } 205 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 206 | best = acc_mean 207 | best_epoch = epoch 208 | 209 | print("best epoch: {:}, best test acc_mean: {:.4f}".format(best_epoch, best), flush=True) 210 | 211 | dist.barrier() 212 | 213 | total_time = time.time() - start_time 214 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 215 | print('### Time {}'.format(total_time_str)) 216 | 217 | 218 | if __name__ == '__main__': 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument('--checkpoint', type=str, required=True) 221 | parser.add_argument('--config', default='./configs/MARVL.yaml') 222 | parser.add_argument('--output_dir', default='output/nlvr') 223 | 224 | parser.add_argument('--device', default='cuda') 225 | parser.add_argument('--seed', default=42, type=int) 226 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 227 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 228 | parser.add_argument('--distributed', action='store_false') 229 | 230 | parser.add_argument('--load_nlvr_pretrain', action='store_true') 231 | parser.add_argument('--epoch', default=-1, type=int) 232 | parser.add_argument('--lr', default=0., type=float) 233 | parser.add_argument('--fewshot', default='', type=str) 234 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus") 235 | parser.add_argument('--evaluate', action='store_true') 236 | 237 | args = parser.parse_args() 238 | 239 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 240 | 241 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 242 | 243 | if args.lr != 0.: 244 | config['optimizer']['lr'] = args.lr 245 | config['schedular']['lr'] = args.lr 246 | if args.fewshot: 247 | config['train_file'][0] = config['train_file'][0].format(args.fewshot) 248 | config['val_file'][0] = config['val_file'][0].format(args.fewshot) 249 | 250 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 251 | 252 | main(args, config) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # X2-VLM: All-In-One Pre-trained Model For Vision-Language Tasks 2 | 3 |
4 | 5 |
6 | 7 | With its modular architecture, X2-VLM performs best at both base and large scale on image-text and video-text tasks, striking a good trade-off between performance and model scale. We also show that the modular design makes X2-VLM highly transferable to other languages and domains: for example, by simply replacing the text encoder with XLM-R, X2-VLM outperforms state-of-the-art multilingual multi-modal pre-trained models without any multilingual pre-training. 8 | 9 | 10 | - Jun 2023: Released official PyTorch implementation and checkpoints 11 | - Nov 2022: Released the preprint on arXiv 12 | 13 | 14 | X2-VLM (large, 593M params): 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/x-2-vlm-all-in-one-pre-trained-model-for/cross-modal-retrieval-on-flickr30k)](https://paperswithcode.com/sota/cross-modal-retrieval-on-flickr30k?p=x-2-vlm-all-in-one-pre-trained-model-for) 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/x-2-vlm-all-in-one-pre-trained-model-for/cross-modal-retrieval-on-coco-2014)](https://paperswithcode.com/sota/cross-modal-retrieval-on-coco-2014?p=x-2-vlm-all-in-one-pre-trained-model-for) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/x-2-vlm-all-in-one-pre-trained-model-for/visual-grounding-on-refcoco-testa)](https://paperswithcode.com/sota/visual-grounding-on-refcoco-testa?p=x-2-vlm-all-in-one-pre-trained-model-for) 18 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/x-2-vlm-all-in-one-pre-trained-model-for/visual-reasoning-on-nlvr2-test)](https://paperswithcode.com/sota/visual-reasoning-on-nlvr2-test?p=x-2-vlm-all-in-one-pre-trained-model-for) 19 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/x-2-vlm-all-in-one-pre-trained-model-for/visual-question-answering-on-vqa-v2-test-std)](https://paperswithcode.com/sota/visual-question-answering-on-vqa-v2-test-std?p=x-2-vlm-all-in-one-pre-trained-model-for) 20 | 21 | 22 | 23 | ## Features 24 | - Support for several backbones 25 | - vision encoder: beit / clip-vit / swin-transformer 26 | - text encoder: bert / roberta 27 | - Support for Apex O1 / O2 mixed-precision pre-training 28 | - Read from and write to HDFS 29 | - Distributed training across nodes for both pre-training and fine-tuning 30 | 31 | Please read the code for more details. 32 | 33 | 34 | ## Requirements 35 | - Install the Python 3 environment 36 | ```angular2html 37 | pip3 install -r requirements.txt 38 | ``` 39 | - Download raw images from the corresponding websites 40 | - Download the json files we provided, which contain image read paths, captions, and/or bbox annotations 41 | - If running the pre-training scripts: 42 | - Install Apex 43 | - Download the pre-trained models for parameter initialization 44 | - vision encoder: beit2 45 | - text encoder: bert 46 | 47 | 48 | ## Pretrain 49 | ```angular2html 50 | # X2-VLM pretrain 51 | python3 run.py --task "pretrain_DIY" --dist "all" --config "configs/pretrain/x2vlm_base_4m.yaml" --output_dir "output/tmp" 52 | 53 | # CCLM multilingual multimodal pretrain 54 | python3 run.py --task "pretrain_DIY" --dist "all" --config "configs/pretrain/multilingual_cclm_x2vlm_base.yaml" --checkpoint "path/to/x2vlm_base_1b.th" --output_dir "output/tmp" 55 | ``` 56 | See run.py and configs/pretrain for more details. 57 | 58 | 59 | #### Data 60 | All datasets we utilized are publicly available. Please prepare the pre-training data yourself. 
Read the code in dataset/pretrain_dataset.py (more specifically, ImageTextJsonDataset & RegionTextJsonDataset) to see what format is needed. 61 | 62 | The processed COCO & VG annotations can be downloaded [here](https://drive.google.com/drive/u/1/folders/1W4_wr53DDWLsvuSavNW1iDbSo9yXQQL1). 63 | 64 | 65 | #### Checkpoints 66 | Please make sure all parameters are loaded correctly. 67 | [X2VLM-base (4M)](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_base_4m.th) 68 | [X2VLM-large (4M)](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_large_4m.th) 69 | [X2VLM-base (1B)](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_base_1b.th) 70 | [CCLM-X2VLM-base](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/cclm_x2vlm_base.th) 71 | 72 | 73 | 74 | ## Finetune 75 | 76 | #### Data 77 | All datasets are publicly available. Some datasets can be downloaded [here](https://drive.google.com/file/d/1XFz1Vtz7MCBLn4_1QEojhFJ5Iw3eH3X4/view?usp=sharing). 78 | 79 | 80 | #### Checkpoints, Configs and Logs 81 | We have released all the code. However, for now we only provide some of the fine-tuned checkpoints (with their training configs and logs). 82 | [vqa-base](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_base_1b_vqa.th) 83 | [vqa-large](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_large_1b_vqa.th) 84 | [refcoco-bbox-large](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_large_1b_grounding.tar) 85 | It takes time for us to retrieve our previous training logs. If you need more, please submit a GitHub issue and we will get back to you. 86 | [coco-retrieval-base-rerun](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/xvlm_beit_1b_stage2_coco_rerun.th) 87 | [coco-retrieval-large-rerun](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/xvlm_beit_1b_large_stage2_coco_rerun.th) 88 | 89 | 90 | #### Examples 91 | ```angular2html 92 | # train 93 | 94 | python3 run.py --task "vqa" --dist "all" --config "configs/finetune/vqa2_large.yaml" --checkpoint "x2vlm_ckpts_2release/x2vlm_large_1b.th" --output_dir "output/tmp" 95 | 96 | python3 run.py --task "refcoco_bbox" --dist "all" --config "configs/finetune/refcoco_grounding_large.yaml" --checkpoint "x2vlm_ckpts_2release/x2vlm_large_1b.th" --output_dir "output/tmp" 97 | 98 | python3 run.py --task "coco_captioning_mlm" --dist "all" --config "configs/finetune/coco_captioning_large.yaml" --checkpoint "x2vlm_ckpts_2release/x2vlm_large_1b.th" --output_dir "output/tmp" 99 | ``` 100 | We release all the training code. Specify "--task" and "--config" to fine-tune on other tasks. See run.py for details.
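To verify that "all parameters are loaded correctly" when starting from a released checkpoint, a quick coverage check can be run before launching a job. This is only an illustrative sketch, not part of the repo: the training scripts load weights through `model.load_pretrained(...)`, which may remap or resize some keys, so treat missing/unexpected keys here as hints rather than hard errors. The assumed checkpoint layout (either a bare state_dict or a dict with a `model` key) follows what the training scripts in this repo save.

```python
# Hedged sketch: report how well a downloaded .th checkpoint covers a model's parameters.
# `model` can be any instantiated model from models/ (e.g., XVLMPlusForMARVL(config=config)).
import torch

def check_checkpoint_coverage(model, ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}", missing[:5])
    print(f"unexpected keys: {len(unexpected)}", unexpected[:5])
    return missing, unexpected
```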
101 | 102 | 103 | 104 | 105 | ## Citation 106 | If you find this repository useful, please considering giving ⭐ or citing: 107 | ``` 108 | @article{zeng2022x, 109 | title={X $\^{} 2$-VLM: All-In-One Pre-trained Model For Vision-Language Tasks}, 110 | author={Zeng, Yan and Zhang, Xinsong and Li, Hang and Wang, Jiawei and Zhang, Jipeng and Zhou, Wangchunshu}, 111 | journal={arXiv preprint arXiv:2211.12402}, 112 | year={2022} 113 | } 114 | 115 | @article{zeng2022cross, 116 | title={Cross-view language modeling: Towards unified cross-lingual cross-modal pre-training}, 117 | author={Zeng, Yan and Zhou, Wangchunshu and Luo, Ao and Zhang, Xinsong}, 118 | journal={arXiv preprint arXiv:2206.00621}, 119 | year={2022} 120 | } 121 | 122 | @article{xvlm, 123 | title={Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts}, 124 | author={Zeng, Yan and Zhang, Xinsong and Li, Hang}, 125 | journal={arXiv preprint arXiv:2111.08276}, 126 | year={2021} 127 | } 128 | ``` 129 | 130 | 131 | ### Contact 132 | For issues using this code, please submit a GitHub issue. -------------------------------------------------------------------------------- /accelerators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/accelerators/__init__.py -------------------------------------------------------------------------------- /accelerators/accelerator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 3 | # Github: https://github.com/zengyan-97/X-VLM 4 | # Copyright (c) 2022, ByteDance Inc. 5 | # All rights reserved. 6 | 7 | from logging import Logger 8 | 9 | import torch 10 | from torch.optim import Optimizer 11 | 12 | Net = torch.nn.Module 13 | 14 | 15 | class Accelerator: 16 | def __init__(self, cfg, logger) -> None: 17 | self.cfg = cfg 18 | self.logger = logger 19 | 20 | def set_up(self, model: Net): 21 | raise NotImplementedError("Set Up method not implement in Accelerator, please check! ") 22 | 23 | def broadcast(self): 24 | raise NotImplementedError("Broadcast method not implement in Accelerator, please check! ") 25 | 26 | def backward_step(self, loss: torch.Tensor): 27 | loss.backward() 28 | 29 | def optimizer_step(self, optimizer: Optimizer, model: Net, grad_norm: float) -> float: 30 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 31 | grad_norm) 32 | return float(total_norm) 33 | -------------------------------------------------------------------------------- /accelerators/apex_ddp_accelerator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 3 | # Github: https://github.com/zengyan-97/X-VLM 4 | # Copyright (c) 2022, ByteDance Inc. 5 | # All rights reserved. 
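# ApexDDPAccelerator in this file drives mixed-precision distributed training with NVIDIA Apex:
# - set_up() seeds all RNGs, pins the process to its local GPU, initializes the NCCL process group
#   (via MASTER_ADDR / MASTER_PORT) if it is not initialized yet, broadcasts parameters from rank 0,
#   wraps model + optimizer with amp.initialize() at the configured FP16 opt level / loss scale,
#   wraps the model in apex DistributedDataParallel (delay_allreduce=True), and optionally converts
#   BatchNorm layers to SyncBN.
# - backward_step() backpropagates through an amp-scaled loss.
# - optimizer_step() clips gradients of the amp master parameters and returns the total norm;
#   the caller is still responsible for invoking optimizer.step().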
6 | 7 | import os 8 | import random 9 | import sys 10 | from typing import Tuple, Union, Optional, Any 11 | import numpy as np 12 | 13 | import torch 14 | import torch.distributed as distributed 15 | from torch.optim import Optimizer 16 | from torch.optim.lr_scheduler import LambdaLR 17 | 18 | Net = torch.nn.Module 19 | 20 | from .accelerator import Accelerator 21 | 22 | try: 23 | from apex import amp 24 | from apex.parallel import DistributedDataParallel as Apex_DDP 25 | from apex.parallel import convert_syncbn_model 26 | except ImportError: 27 | print('no apex! Please install from https://www.github.com/nvidia/apex') 28 | 29 | 30 | class ApexDDPAccelerator(Accelerator): 31 | """ 32 | ApexDDPAccelerator, use apex DistributedDataParallel 33 | """ 34 | 35 | def __init__(self, cfg, logger): 36 | super().__init__(cfg, logger) 37 | self.accelerator_rng_seed = self.cfg.RNG_SEED 38 | self.accelerator_syncbn = self.cfg.SYNCBN 39 | self.accelerator_fp16_opt_level = self.cfg.FP16_OPT_LEVEL 40 | self.accelerator_fp16_loss_scale = self.cfg.FP16_LOSS_SCALE 41 | 42 | def set_up(self, model: Net, optimizer: Optimizer, lr_scheduler: LambdaLR, 43 | local_rank: int, world_size: int, rank: int) -> Tuple[Apex_DDP, Optimizer, LambdaLR]: 44 | """ 45 | set up ApexDDPAccelerator, including process_group and apex_ddp 46 | """ 47 | torch.backends.cudnn.benchmark = False 48 | random.seed(self.accelerator_rng_seed) 49 | np.random.seed(self.accelerator_rng_seed) 50 | torch.random.manual_seed(self.accelerator_rng_seed) 51 | torch.cuda.manual_seed_all(self.accelerator_rng_seed) 52 | master_address = os.environ.get('MASTER_ADDR', "127.0.0.1") 53 | master_port = int(os.environ.get('MASTER_PORT', 34171)) 54 | 55 | torch.cuda.set_device(local_rank) 56 | model = model.cuda() 57 | if not torch.distributed.is_initialized(): 58 | distributed.init_process_group( 59 | backend='nccl', 60 | init_method='tcp://{}:{}'.format(master_address, master_port), 61 | world_size=world_size, 62 | rank=rank, 63 | group_name='mtorch') 64 | print( 65 | f'ApexDDPAccelerator distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}') 66 | sys.stdout.flush() 67 | 68 | self.broadcast(model) 69 | apex_model, optimizer = self.configure_ddp(model, optimizer) 70 | 71 | if self.accelerator_syncbn: 72 | apex_model = self.configure_sync_batchnorm(apex_model) 73 | return apex_model, optimizer, lr_scheduler 74 | 75 | def broadcast(self, model: Net, src=0) -> None: 76 | for v in model.state_dict().values(): 77 | distributed.broadcast(v, src) 78 | 79 | def configure_ddp(self, model: Net, optimizer: Optimizer) -> Tuple[Apex_DDP, Optimizer]: 80 | model, optimizer = amp.initialize(model, optimizer, 81 | opt_level=self.accelerator_fp16_opt_level, 82 | keep_batchnorm_fp32=None, # from True to None 83 | loss_scale=self.accelerator_fp16_loss_scale, 84 | max_loss_scale=1024.0, 85 | min_loss_scale=1.0) 86 | 87 | apex_model = Apex_DDP(model, delay_allreduce=True) 88 | self.ddp_model = apex_model 89 | return apex_model, optimizer 90 | 91 | def configure_sync_batchnorm(self, model: Net) -> Net: 92 | model = convert_syncbn_model(model) 93 | return model 94 | 95 | def backward_step(self, loss: torch.Tensor, optimizer: Optimizer): 96 | with amp.scale_loss(loss, optimizer) as scaled_loss: 97 | scaled_loss.backward() 98 | 99 | def optimizer_step(self, optimizer: Optimizer, model: Net, grad_norm: float) -> float: 100 | total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 101 | grad_norm) 102 | return float(total_norm) 103 | 
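For reference, the accelerator interface defined above is driven roughly in this order during training. The snippet below is only a hedged illustration: `accelerator_cfg`, `logger`, `model`, `optimizer`, `lr_scheduler`, `data_loader`, and the rank variables are placeholders, and the actual pre-training loop lives in the training code, not here.

```python
# Illustrative driver for ApexDDPAccelerator (placeholders only, not the repo's real loop).
# accelerator_cfg is assumed to expose RNG_SEED, SYNCBN, FP16_OPT_LEVEL, FP16_LOSS_SCALE as attributes.
accelerator = ApexDDPAccelerator(accelerator_cfg, logger)
model, optimizer, lr_scheduler = accelerator.set_up(model, optimizer, lr_scheduler,
                                                    local_rank, world_size, rank)

for batch in data_loader:
    loss = model(**batch)
    optimizer.zero_grad()
    accelerator.backward_step(loss, optimizer)                   # amp-scaled backward pass
    accelerator.optimizer_step(optimizer, model, grad_norm=1.0)  # clips amp master-param gradients
    optimizer.step()                                             # the accelerator does not call step() itself
    lr_scheduler.step()
```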
-------------------------------------------------------------------------------- /configs/config_beit2_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "data/beitv2_base_patch16_224_pt1k_ft21k.pth", 3 | "vision_width": 768, 4 | "patch_size": 16 5 | } 6 | -------------------------------------------------------------------------------- /configs/config_beit2_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "data/beitv2_large_patch16_224_pt1k_ft21k.pth", 3 | "vision_width": 1024, 4 | "patch_size": 16 5 | } 6 | -------------------------------------------------------------------------------- /configs/finetune/coco_captioning_large.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/finetune/coco_karpathy/coco_karpathy_train.json'] 2 | val_file: 'data/finetune/coco_karpathy/coco_karpathy_val.json' 3 | test_file: 'data/finetune/coco_karpathy/coco_karpathy_test.json' 4 | 5 | image_root: 'images/coco/' 6 | val_gt_file: 'data/finetune/coco_karpathy/coco_karpathy_val_gt.json' 7 | test_gt_file: 'data/finetune/coco_karpathy/coco_karpathy_test_gt.json' 8 | 9 | ## Vision Encoder 10 | use_beit_v2: True 11 | vision_config: 'configs/config_beit2_large.json' 12 | image_res: 384 13 | patch_size: 16 14 | 15 | 16 | ## Text Encoder (& Cross Encoder) 17 | text_encoder: 'data/bert-large-uncased' 18 | text_num_hidden_layers: 18 19 | text_fusion_start_at: 12 20 | 21 | 22 | ## Training 23 | apply_FG_free: True 24 | batch_size_train: 16 # xN A100s, i don't remember how many GPUs i used... (i guess either 8 or 16) 25 | batch_size_test: 20 26 | max_tokens: 40 27 | max_words: 40 28 | label_smoothing: 0.1 29 | mask_prob: 0.6 30 | max_masks: 18 31 | mask_whole_word: True 32 | skipgram_prb: 0.2 33 | skipgram_size: 3 34 | 35 | ## generation configs 36 | max_length: 50 37 | min_length: 5 38 | num_beams: 3 39 | length_penalty: 0 40 | prompt: 'a picture of ' 41 | 42 | 43 | optimizer: {opt: adamW, lr: 5e-6, weight_decay: 0.01, lr_mult: 2, vision_lr: 1e-5, text_lr: 5e-6} 44 | schedular: {sched: linear, epochs: 5, num_warmup_steps: 0.05} 45 | start_eval: 0 # epoch index 46 | -------------------------------------------------------------------------------- /configs/finetune/refcoco_grounding_large.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/finetune/refcoco+_train.json'] 2 | test_file: ['data/finetune/refcoco+_val.json','data/finetune/refcoco+_test.json'] 3 | 4 | refcoco_data: 'data/finetune/' 5 | det_file: 'data/finetune/refcoco+/dets.json' 6 | coco_file: 'data/finetune/refcoco+/cocos.json' 7 | 8 | image_root: 'images/coco/' 9 | 10 | careful_hflip: True # first check whether 'left' or 'right' in captions 11 | 12 | ## Vision Encoder 13 | use_beit_v2: True 14 | vision_config: 'configs/config_beit2_large.json' 15 | image_res: 384 16 | patch_size: 16 17 | 18 | 19 | ## Text Encoder (& Cross Encoder) 20 | text_encoder: 'data/bert-large-uncased' 21 | text_num_hidden_layers: 18 # 12 + 6 22 | text_fusion_start_at: 12 23 | 24 | text_drop_path_rate: 0.1 25 | cross_drop_path_rate: 0.1 26 | 27 | ## Training 28 | batch_size: 20 # xN A100s, i don't remember how many GPUs i used... 
(i guess either 8 or 16) 29 | max_tokens: 40 30 | 31 | 32 | ## Other Settings 33 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.01, lr_mult: 2} 34 | schedular: {sched: linear, epochs: 10, num_warmup_steps: 0.1} 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/finetune/vqa2_base.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/finetune/vqa_train.json', 2 | 'data/finetune/vqa_val.json', 3 | 'data/finetune/vg_qa.json'] 4 | 5 | test_file: ['data/finetune/vqa_test.json'] 6 | answer_list: 'data/finetune/answer_list.json' 7 | 8 | vqa_root: 'images/coco/' 9 | vg_root: 'images/visualgenome/' 10 | 11 | ## Vision Encoder 12 | use_beit_v2: True 13 | vision_config: 'configs/config_beit2_base.json' 14 | image_res: 768 15 | patch_size: 16 16 | 17 | 18 | ## Text Encoder (& Cross Encoder) 19 | text_encoder: 'data/bert-base-uncased' 20 | text_num_hidden_layers: 18 21 | text_fusion_start_at: 12 22 | 23 | ## Training 24 | num_dec_layers: 6 25 | large_lr_for_dec: True 26 | batch_size_train: 8 # x16 a100 27 | accumulate_steps: 1 28 | batch_size_test: 32 29 | max_tokens: 40 30 | k_test: 128 31 | 32 | 33 | ## Other Settings 34 | optimizer: {opt: adamW, lr: 4e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 2e-5, text_lr: 4e-5} 35 | schedular: {sched: linear, epochs: 5, num_warmup_steps: 0.1} 36 | start_eval: 2 # epoch index 37 | -------------------------------------------------------------------------------- /configs/finetune/vqa2_large.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/finetune/vqa_train.json', 2 | 'data/finetune/vqa_val.json', 3 | 'data/finetune/vg_qa.json'] 4 | 5 | test_file: ['data/finetune/vqa_test.json'] 6 | answer_list: 'data/finetune/answer_list.json' 7 | 8 | vqa_root: 'images/coco/' 9 | vg_root: 'images/visualgenome/' 10 | 11 | ## Vision Encoder 12 | use_beit_v2: True 13 | vision_config: 'configs/config_beit2_large.json' 14 | image_res: 768 15 | patch_size: 16 16 | 17 | 18 | ## Text Encoder (& Cross Encoder) 19 | text_encoder: 'data/bert-large-uncased' 20 | text_num_hidden_layers: 18 21 | text_fusion_start_at: 12 22 | 23 | ## Training 24 | num_dec_layers: 6 25 | large_lr_for_dec: True 26 | batch_size_train: 2 # x32 a100 27 | accumulate_steps: 2 28 | batch_size_test: 32 29 | max_tokens: 40 30 | k_test: 128 31 | 32 | 33 | ## Other Settings 34 | optimizer: {opt: adamW, lr: 4e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 2e-5, text_lr: 2e-5} 35 | schedular: {sched: linear, epochs: 5, num_warmup_steps: 0.05} 36 | start_eval: 2 # epoch index 37 | -------------------------------------------------------------------------------- /configs/pretrain/multilingual_cclm_x2vlm_base.yaml: -------------------------------------------------------------------------------- 1 | # Adapt X^2-VLM to multilingual by Cross-View Language Modeling 2 | 3 | ## Data 4 | train_file: [ 5 | "path/to/cc-3m-mm-uc2", 6 | "path/to/sbu-mm", 7 | "path/to/vg-mm", 8 | "path/to/coco-mm", 9 | ] # multilingual x multimodal 10 | 11 | train_dataset_size: 5004332 12 | 13 | images: {image_key: "binary", 14 | is_image_rpath: False, # read path or base64 encoding 15 | caption_key: "caption", 16 | tokenized: False, # whether texts have been tokenized 17 | batch_size: 60, # x8 gpus 18 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 19 | iter_perc: 1.0, 20 | } 21 | 22 | 23 | 
train_file_regions: [ 24 | 'path/to/coco_object-mm-google', 25 | 'path/to/vg_object-mm-google', 26 | 'path/to/vg_region-mm', 27 | ] # multilingual 28 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, code_switch: True, 29 | careful_hflip: True, 30 | batch_size: 60, max_images: 26, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 31 | 32 | 33 | train_file_mtext: [ 34 | "path/to/wikimatrix", 35 | "path/to/wikimatrix_en_bn", 36 | ] # multilingual parallel texts 37 | mtexts: {source_key: "source_text", 38 | target_key: "target_text", 39 | tokenized: False, # whether texts have been tokenized 40 | batch_size: 60, # x8 gpus 41 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 42 | iter_perc: 1.0, 43 | max_words: 64, 44 | max_tokens: 64, 45 | mask_prob: 0.4, 46 | max_masks: 16, 47 | } 48 | 49 | 50 | ## Vision Encoder 51 | use_beit_v2: True 52 | vision_config: 'configs/config_beit2_base.json' 53 | image_res: 224 54 | patch_size: 16 55 | 56 | 57 | 58 | ## Text Encoder (& Cross Encoder) 59 | model_type: 'CrossViewLM' 60 | text_encoder: 'data/xlm-roberta-base' 61 | text_num_hidden_layers: 12 62 | cross_encoder: 'data/bert-base-uncased' 63 | cross_num_hidden_layers: 6 64 | 65 | is_xvlm_ckpt: True # is of XVLMBase or XVLMPlusBase 66 | xvlm_ckpt_text_num_hidden_layers: 12 # if is_xvlm_ckpt 67 | replace_text_encoder: True 68 | 69 | 70 | ## Training 71 | mixed_in_batch: True 72 | calc_image_bbox_loss: False 73 | embed_dim: 256 74 | temp: 0.07 75 | 76 | max_words: 30 77 | max_tokens: 30 78 | mask_prob: 0.4 79 | max_masks: 10 80 | 81 | mask_whole_word: False # not implemented 82 | skipgram_prb: 0.2 83 | skipgram_size: 3 84 | 85 | 86 | ## Other Settings 87 | ckpt_frequent_step: 50000 88 | ckpt_frequent: 100000 # epoch 89 | optimizer: {opt: adamW, lr: 4e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 2e-5, text_lr: 8e-5, cross_lr: 4e-5} 90 | schedular: {sched: linear, epochs: 39, num_warmup_steps: 1000} # 400k steps 91 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O0, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 92 | -------------------------------------------------------------------------------- /configs/pretrain/multilingual_cclm_x2vlm_large.yaml: -------------------------------------------------------------------------------- 1 | # Adapt X^2-VLM to multilingual by Cross-View Language Modeling 2 | 3 | ## Data 4 | train_file: [ 5 | "path/to/cc-3m-mm-uc2", 6 | "path/to/sbu-mm", 7 | "path/to/vg-mm", 8 | "path/to/coco-mm", 9 | ] # multilingual x multimodal 10 | 11 | train_dataset_size: 5004332 12 | 13 | images: {image_key: "binary", 14 | is_image_rpath: False, # read path or base64 encoding 15 | caption_key: "caption", 16 | tokenized: False, # whether texts have been tokenized 17 | batch_size: 30, # x16 gpus 18 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 19 | iter_perc: 1.0, 20 | } 21 | 22 | 23 | train_file_regions: [ 24 | 'path/to/coco_object-mm-google', 25 | 'path/to/vg_object-mm-google', 26 | 'path/to/vg_region-mm', 27 | ] # multilingual 28 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, code_switch: True, 29 | careful_hflip: True, 30 | batch_size: 30, max_images: 14, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 31 | 32 | 33 | train_file_mtext: [ 34 | "path/to/wikimatrix", 35 | "path/to/wikimatrix_en_bn", 36 | ] # multilingual parallel texts 37 | 
mtexts: {source_key: "source_text", 38 | target_key: "target_text", 39 | tokenized: False, # whether texts have been tokenized 40 | batch_size: 30, # x16 gpus 41 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 42 | iter_perc: 1.0, 43 | max_words: 64, 44 | max_tokens: 64, 45 | mask_prob: 0.4, 46 | max_masks: 16, 47 | } 48 | 49 | 50 | ## Vision Encoder 51 | use_beit_v2: True 52 | vision_config: 'configs/config_beit2_large.json' 53 | image_res: 224 54 | patch_size: 16 55 | 56 | 57 | ## Text Encoder (& Cross Encoder) 58 | model_type: 'CrossViewLM' 59 | text_encoder: 'data/xlm-roberta-large' 60 | text_num_hidden_layers: 24 61 | cross_encoder: 'data/bert-large-uncased' 62 | cross_num_hidden_layers: 6 63 | 64 | is_xvlm_ckpt: True # is of XVLMBase or XVLMPlusBase 65 | xvlm_ckpt_text_num_hidden_layers: 12 # if is_xvlm_ckpt 66 | replace_text_encoder: True 67 | 68 | 69 | 70 | ## Training 71 | mixed_in_batch: True 72 | calc_image_bbox_loss: False 73 | embed_dim: 256 74 | temp: 0.07 75 | 76 | max_words: 30 77 | max_tokens: 30 78 | mask_prob: 0.4 79 | max_masks: 10 80 | 81 | mask_whole_word: False # not implemented 82 | skipgram_prb: 0.2 83 | skipgram_size: 3 84 | 85 | 86 | ## Other Settings 87 | ckpt_frequent_step: 50000 88 | ckpt_frequent: 100000 # epoch 89 | optimizer: {opt: adamW, lr: 3e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 1e-5, text_lr: 6e-5, cross_lr: 3e-5} 90 | schedular: {sched: linear, epochs: 39, num_warmup_steps: 1000} # 400k steps 91 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O0, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 92 | -------------------------------------------------------------------------------- /configs/pretrain/x2vlm_base_1b.yaml: -------------------------------------------------------------------------------- 1 | ## Data 2 | print_broken_data: False 3 | 4 | train_file: [ 5 | "path/to/laion_filtered", 6 | "path/to/laion2b_filtered", 7 | ] 8 | train_dataset_size: 1323042683 # for IterableDataset 9 | 10 | 11 | train_file_aux: [ 12 | "path/to/coco_testset_filtered", 13 | "path/to/vg_testset_filtered", 14 | "path/to/sbu_bs64", 15 | "path/to/cc3m_bs64", 16 | "path/to/cc12m_bs64", 17 | ] # cleaner data 18 | aux_iter_perc: 0.15 # aux_iter_perc% iterate on train_file_aux, (1-aux_iter_perc%) iterate on train_file 19 | images: {image_key: "binary", 20 | is_image_rpath: False, # read path or base64 encoding 21 | caption_key: "desc", 22 | aux_caption_key: "desc", 23 | tokenized: False, # whether texts have been tokenized 24 | batch_size: 128, # 128 x 24 = 3072 25 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 26 | } 27 | 28 | train_file_regions: [ 29 | 'path/to/coco2017_obj_rmtest_2207', 30 | 'path/to/vg_attr_obj_rmtest_2207', 31 | 'path/to/vg_region_rmtest_2207', 32 | "path/to/refcoco_region_2207", 33 | "path/to/gqa_obj_2207", 34 | "path/to/flickr_obj_2207", 35 | "path/to/openimages_v6_maxrez800_obj_region_2207", 36 | "path/to/object365_obj_2207", 37 | ] # objects & regions; 38 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, 39 | iter_perc: 0.5, batch_size: 64, max_images: 26, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 40 | 41 | 42 | ## Vision Encoder 43 | use_beit_v2: True 44 | vision_config: 'configs/config_beit2_base.json' 45 | image_res: 224 46 | patch_size: 16 47 | local_attn_depth: -1 48 | 49 | 50 | ## Text Encoder (& Cross Encoder) 51 | text_encoder: 
'data/bert-base-uncased' 52 | text_num_hidden_layers: 18 # include cross 53 | text_fusion_start_at: 12 54 | 55 | 56 | ## Training 57 | mixed_in_batch: True 58 | calc_image_bbox_loss: False 59 | embed_dim: 256 60 | temp: 0.07 61 | 62 | max_words: 30 63 | max_tokens: 30 64 | mask_prob: 0.5 65 | max_masks: 12 66 | mask_whole_word: True 67 | skipgram_prb: 0.2 68 | skipgram_size: 3 69 | 70 | stop_calc_itm: 200000 # steps; matching loss calculates hard negatives causing nan loss 71 | 72 | 73 | ## Other Settings 74 | ckpt_frequent_step: 50000 75 | ckpt_frequent: 100000000 # inf 76 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 77 | schedular: {sched: linear, lr: 1e-4, epochs: 3, num_warmup_steps: 2500} 78 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/pretrain/x2vlm_base_1b_stage2.yaml: -------------------------------------------------------------------------------- 1 | ## Data 2 | train_file: [ 3 | "path/to/coco_testset_filtered", 4 | "path/to/vg_testset_filtered", 5 | "path/to/sbu_bs64", 6 | "path/to/cc3m_bs64", 7 | ] 8 | 9 | train_dataset_size: 5114489 10 | images: {image_key: "binary", 11 | is_image_rpath: False, # read path or base64 encoding 12 | caption_key: "desc", 13 | tokenized: False, # whether texts have been tokenized 14 | batch_size: 64, # 64 x 16 = 1024 15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 16 | } 17 | 18 | 19 | train_file_regions: [ 20 | 'path/to/coco2017_obj_rmtest_2207', 21 | 'path/to/vg_attr_obj_rmtest_2207', 22 | 'path/to/vg_region_rmtest_2207', 23 | "path/to/refcoco_region_2207", 24 | "path/to/gqa_obj_2207", 25 | "path/to/flickr_obj_2207", 26 | "path/to/openimages_v6_maxrez800_obj_region_2207", 27 | "path/to/object365_obj_2207", 28 | ] # objects & regions; 29 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, 30 | iter_perc: 1.0, batch_size: 64, max_images: 26, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 31 | 32 | 33 | train_file_videos: [ 34 | "path/to/howto100m_filtered", # 1704857 35 | "path/to/ytt180m_filtered", # 5306576 36 | ] 37 | train_file_videos_aux: [ 38 | "path/to/webvid2_5m", # 2492317 39 | ] # cleaner data 40 | video_aux_iter_perc: 0.35 # aux_iter_perc% iterate on train_file_videos_aux, (1-aux_iter_perc%) iterate on train_file_videos 41 | 42 | 43 | videos: {image_key: "video_frames", 44 | is_image_rpath: False, # read path or base64 encoding 45 | caption_key: "text", 46 | tokenized: False, # whether texts have been tokenized 47 | frame_len: 3, # 5 -> 3, too slow 48 | use_random_sampling: True, 49 | combine_continuous_clips: True, 50 | mininum_frames_before_sampling: 8, # 10 -> 8 since webvid has 15+ frames per video on average 51 | batch_size: 40, # 40*16=640 64 -> 48, too slow 52 | iter_perc: 1.0, 53 | num_workers: 8, # better -> the total number of training files % (world_size * num_workers) == 0 54 | } 55 | 56 | 57 | ## Vision Encoder 58 | use_beit_v2: True 59 | vision_config: 'configs/config_beit2_base.json' 60 | image_res: 224 61 | patch_size: 16 62 | local_attn_depth: -1 63 | 64 | frame_len: 3 65 | add_frame_pos: True 66 | video_encoding: 'avgpool' 67 | 68 | 69 | ## Text Encoder (& Cross Encoder) 70 | text_encoder: 'data/bert-base-uncased' 71 | text_num_hidden_layers: 18 72 | text_fusion_start_at: 
12 73 | 74 | 75 | ## Training 76 | mixed_in_batch: True 77 | calc_image_bbox_loss: False 78 | embed_dim: 256 79 | temp: 0.07 80 | 81 | max_words: 40 82 | max_tokens: 40 83 | mask_prob: 0.5 84 | max_masks: 12 85 | mask_whole_word: True 86 | skipgram_prb: 0.2 87 | skipgram_size: 3 88 | 89 | 90 | ## Other Settings 91 | ckpt_frequent_step: 50000 92 | ckpt_frequent: 1000000 # epoch 93 | optimizer: {opt: adamW, lr: 6e-5, weight_decay: 0.01, lr_mult: 2} 94 | schedular: {sched: linear, lr: 6e-5, epochs: 81, num_warmup_steps: 1000} # 400k steps, video -> 27 epochs 95 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 96 | 97 | -------------------------------------------------------------------------------- /configs/pretrain/x2vlm_base_4m.yaml: -------------------------------------------------------------------------------- 1 | ## Data 2 | train_file: [ 3 | "path/to/coco_testset_filtered", 4 | "path/to/vg_testset_filtered", 5 | "path/to/sbu_bs64", 6 | "path/to/cc3m_bs64", 7 | ] 8 | 9 | train_dataset_size: 5114489 # for IterableDataset 10 | images: {image_key: "binary", 11 | is_image_rpath: False, # read path or base64 encoding 12 | caption_key: "desc", 13 | tokenized: False, # whether texts have been tokenized 14 | batch_size: 128, # 128 x 8 = 1024 15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 16 | } 17 | 18 | 19 | train_file_regions: [ 20 | 'path/to/coco2017_obj_rmtest_2207', 21 | 'path/to/vg_attr_obj_rmtest_2207', 22 | 'path/to/vg_region_rmtest_2207', 23 | "path/to/refcoco_region_2207", 24 | "path/to/gqa_obj_2207", 25 | "path/to/flickr_obj_2207", 26 | ] 27 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, 28 | iter_perc: 1, batch_size: 128, max_images: 50, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 29 | 30 | 31 | ## Vision Encoder 32 | use_beit_v2: True 33 | vision_config: 'configs/config_beit2_base.json' 34 | image_res: 224 35 | patch_size: 16 36 | local_attn_depth: -1 37 | 38 | 39 | ## Text Encoder (& Cross Encoder) 40 | text_encoder: 'data/bert-base-uncased' 41 | text_num_hidden_layers: 18 # include cross 42 | text_fusion_start_at: 12 43 | 44 | 45 | ## Training 46 | mixed_in_batch: True 47 | calc_image_bbox_loss: False 48 | embed_dim: 256 49 | temp: 0.07 50 | 51 | max_words: 40 52 | max_tokens: 40 53 | mask_prob: 0.5 54 | max_masks: 12 55 | mask_whole_word: True 56 | skipgram_prb: 0.2 57 | skipgram_size: 3 58 | 59 | 60 | ## Other Settings 61 | ckpt_frequent_step: 50000 62 | ckpt_frequent: 1000000000 # epoch 63 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2} 64 | schedular: {sched: linear, lr: 1e-4, epochs: 101, num_warmup_steps: 2500} # 之前是跑 200k steps, 现在感觉要跑 500k steps 65 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /configs/pretrain/x2vlm_large_1b.yaml: -------------------------------------------------------------------------------- 1 | ## Data 2 | print_broken_data: False 3 | 4 | train_file: [ 5 | "path/to/laion_filtered", 6 | "path/to/laion2b_filtered", 7 | ] 8 | train_dataset_size: 1323042683 # for IterableDataset 9 | 10 | 11 | train_file_aux: [ 12 | "path/to/coco_testset_filtered", 13 | "path/to/vg_testset_filtered", 14 | "path/to/sbu_bs64", 15 | 
"path/to/cc3m_bs64", 16 | "path/to/cc12m_bs64", 17 | ] # cleaner data 18 | aux_iter_perc: 0.15 # aux_iter_perc% iterate on train_file_aux, (1-aux_iter_perc%) iterate on train_file 19 | images: {image_key: "binary", 20 | is_image_rpath: False, # read path or base64 encoding 21 | caption_key: "desc", 22 | aux_caption_key: "desc", 23 | tokenized: False, # whether texts have been tokenized 24 | batch_size: 64, # 128 x 24 = 3072 25 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 26 | } 27 | 28 | train_file_regions: [ 29 | 'path/to/coco2017_obj_rmtest_2207', 30 | 'path/to/vg_attr_obj_rmtest_2207', 31 | 'path/to/vg_region_rmtest_2207', 32 | "path/to/refcoco_region_2207", 33 | "path/to/gqa_obj_2207", 34 | "path/to/flickr_obj_2207", 35 | "path/to/openimages_v6_maxrez800_obj_region_2207", 36 | "path/to/object365_obj_2207", 37 | ] # objects & regions; 38 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, 39 | iter_perc: 0.5, batch_size: 32, max_images: 14, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 40 | 41 | 42 | ## Vision Encoder 43 | use_beit_v2: True 44 | vision_config: 'configs/config_beit2_large.json' 45 | image_res: 224 46 | patch_size: 16 47 | local_attn_depth: -1 48 | 49 | 50 | ## Text Encoder (& Cross Encoder) 51 | text_encoder: 'data/bert-large-uncased-12l' 52 | text_num_hidden_layers: 18 # 12 + 6 53 | text_fusion_start_at: 12 54 | 55 | 56 | ## Training 57 | mixed_in_batch: True 58 | calc_image_bbox_loss: False 59 | embed_dim: 256 60 | temp: 0.07 61 | 62 | max_words: 30 63 | max_tokens: 30 64 | mask_prob: 0.5 65 | max_masks: 12 66 | mask_whole_word: True 67 | skipgram_prb: 0.2 68 | skipgram_size: 3 69 | 70 | stop_calc_itm: 200000 71 | 72 | 73 | ## Other Settings 74 | ckpt_frequent_step: 50000 75 | ckpt_frequent: 100000000 # inf 76 | optimizer: {opt: adamW, lr: 5e-5, weight_decay: 0.01, lr_mult: 2} 77 | schedular: {sched: linear, lr: 5e-5, epochs: 3, num_warmup_steps: 2500} 78 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /configs/pretrain/x2vlm_large_1b_stage2.yaml: -------------------------------------------------------------------------------- 1 | ## Data 2 | train_file: [ 3 | "path/to/coco_testset_filtered", 4 | "path/to/vg_testset_filtered", 5 | "path/to/sbu_bs64", 6 | "path/to/cc3m_bs64", 7 | ] 8 | 9 | train_dataset_size: 5114489 # for IterableDataset 10 | images: {image_key: "binary", 11 | is_image_rpath: False, # read path or base64 encoding 12 | caption_key: "desc", 13 | tokenized: False, # whether texts have been tokenized 14 | batch_size: 32, # 32 x 32 = 1024 15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 16 | } 17 | 18 | 19 | train_file_regions: [ 20 | 'path/to/coco2017_obj_rmtest_2207', 21 | 'path/to/vg_attr_obj_rmtest_2207', 22 | 'path/to/vg_region_rmtest_2207', 23 | "path/to/refcoco_region_2207", 24 | "path/to/gqa_obj_2207", 25 | "path/to/flickr_obj_2207", 26 | "path/to/openimages_v6_maxrez800_obj_region_2207", 27 | "path/to/object365_obj_2207", 28 | ] # objects & regions; 29 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, 30 | iter_perc: 1.0, batch_size: 32, max_images: 14, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 31 | 32 | 33 | 
train_file_videos: [ 34 | "path/to/howto100m_filtered", # 1704857 35 | "path/to/ytt180m_filtered", # 5306576 36 | ] 37 | train_file_videos_aux: [ 38 | "hdfs://haruna/home/byte_ailab_litg/user/wangjiawei.424/dataset/webvid2_5m", # 2492317 39 | ] # cleaner data 40 | video_aux_iter_perc: 0.35 # aux_iter_perc% iterate on train_file_videos_aux, (1-aux_iter_perc%) iterate on train_file_videos 41 | 42 | videos: {image_key: "video_frames", 43 | is_image_rpath: False, # read path or base64 encoding 44 | caption_key: "text", 45 | tokenized: False, # whether texts have been tokenized 46 | frame_len: 3, # 5 -> 3, too slow 47 | use_random_sampling: True, 48 | combine_continuous_clips: True, 49 | mininum_frames_before_sampling: 8, # 10 -> 8 since webvid has 15+ frames per video on average 50 | batch_size: 20, # 20*32=640 64 -> 48, too slow 51 | iter_perc: 1.0, 52 | num_workers: 8, # better -> the total number of training files % (world_size * num_workers) == 0 53 | } 54 | 55 | 56 | ## Vision Encoder 57 | use_beit_v2: True 58 | vision_config: 'configs/config_beit2_large.json' 59 | image_res: 224 60 | patch_size: 16 61 | local_attn_depth: -1 62 | 63 | frame_len: 3 64 | add_frame_pos: True 65 | video_encoding: 'avgpool' 66 | 67 | 68 | ## Text Encoder (& Cross Encoder) 69 | text_encoder: 'data/bert-large-uncased-12l' 70 | text_num_hidden_layers: 18 # 12 + 6 71 | text_fusion_start_at: 12 72 | 73 | 74 | ## Training 75 | mixed_in_batch: True 76 | calc_image_bbox_loss: False 77 | embed_dim: 256 78 | temp: 0.07 79 | 80 | max_words: 40 81 | max_tokens: 40 82 | mask_prob: 0.5 83 | max_masks: 12 84 | mask_whole_word: True 85 | skipgram_prb: 0.2 86 | skipgram_size: 3 87 | 88 | 89 | ## Other Settings 90 | ckpt_frequent_step: 50000 91 | ckpt_frequent: 1000000 # epoch 92 | optimizer: {opt: adamW, lr: 3e-5, weight_decay: 0.01, lr_mult: 2} 93 | schedular: {sched: linear, lr: 3e-5, epochs: 81, num_warmup_steps: 1000} # 400k steps, video -> 27 epochs 94 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /configs/pretrain/x2vlm_large_4m.yaml: -------------------------------------------------------------------------------- 1 | ## Data 2 | train_file: [ 3 | "path/to/coco_testset_filtered", 4 | "path/to/vg_testset_filtered", 5 | "path/to/sbu_bs64", 6 | "path/to/cc3m_bs64", 7 | ] 8 | 9 | train_dataset_size: 5114489 # for IterableDataset 10 | images: {image_key: "binary", 11 | is_image_rpath: False, # read path or base64 encoding 12 | caption_key: "desc", 13 | tokenized: False, # whether texts have been tokenized 14 | batch_size: 64, # 128 x 16 = 1024 15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0 16 | } 17 | 18 | 19 | train_file_regions: [ 20 | 'path/to/coco2017_obj_rmtest_2207', 21 | 'path/to/vg_attr_obj_rmtest_2207', 22 | 'path/to/vg_region_rmtest_2207', 23 | "path/to/refcoco_region_2207", 24 | "path/to/gqa_obj_2207", 25 | "path/to/flickr_obj_2207", 26 | ] 27 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, 28 | iter_perc: 1, batch_size: 64, max_images: 25, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4} 29 | 30 | 31 | ## Vision Encoder 32 | use_beit_v2: True 33 | vision_config: 'configs/config_beit2_large.json' 34 | image_res: 224 35 | patch_size: 16 36 | local_attn_depth: -1 37 | 38 | 39 | ## Text Encoder (& 
Cross Encoder) 40 | text_encoder: 'data/bert-large-uncased-12l' 41 | text_num_hidden_layers: 18 # 12 + 6 42 | text_fusion_start_at: 12 43 | 44 | 45 | ## Training 46 | mixed_in_batch: True 47 | calc_image_bbox_loss: False 48 | embed_dim: 256 49 | temp: 0.07 50 | 51 | max_words: 40 52 | max_tokens: 40 53 | mask_prob: 0.5 54 | max_masks: 12 55 | mask_whole_word: True 56 | skipgram_prb: 0.2 57 | skipgram_size: 3 58 | 59 | 60 | ## Other Settings 61 | ckpt_frequent_step: 50000 62 | ckpt_frequent: 1000000000 # epoch 63 | optimizer: {opt: adamW, lr: 5e-5, weight_decay: 0.01, lr_mult: 2} 64 | schedular: {sched: linear, lr: 5e-5, epochs: 101, num_warmup_steps: 2500} # we use 250k steps 65 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0} 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /dataset/dist_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 4 | # Github: https://github.com/zengyan-97/X-VLM 5 | # Copyright (c) 2022, ByteDance Inc. 6 | # All rights reserved. 7 | 8 | import sys 9 | from typing import List, Any 10 | import warnings 11 | import random 12 | from itertools import cycle 13 | import torch 14 | from torch.utils.data import IterableDataset 15 | 16 | from utils.hdfs_io import hopen, hlist_files, hexists 17 | 18 | 19 | class DistLineReadingDataset(IterableDataset): # pylint: disable=W0223 20 | """ 21 | iterate a set of folders. 22 | """ 23 | def __init__(self, 24 | data_path, 25 | rank: int = 0, 26 | world_size: int = 1, 27 | shuffle: bool = False, 28 | repeat: bool = False): 29 | super().__init__() 30 | self.shuffle = shuffle 31 | self.rank = rank 32 | self.world_size = world_size 33 | 34 | if isinstance(data_path, str): 35 | data_path = data_path.split(',') 36 | elif isinstance(data_path, list): 37 | pass 38 | else: 39 | raise ValueError(data_path) 40 | 41 | for p in data_path: 42 | assert hexists(p), f"not exist {p}" 43 | 44 | self.files = hlist_files(data_path) 45 | self.files = [f for f in self.files if f.find('_SUCCESS') < 0] 46 | self.is_hdfs = data_path[0].startswith('hdfs') 47 | 48 | self.repeat = repeat 49 | print('[DATA]--all dataset containing {} files.'.format(len(self.files))) 50 | if len(self.files) % self.world_size != 0: 51 | print('[DATA]--Whole dataset file num %s cannot split to worldsize %s ' % 52 | (len(self.files), self.world_size)) 53 | sys.stdout.flush() 54 | 55 | def generate(self): 56 | if self.world_size == 1 or len(self.files) == 1: 57 | cur_dataloader_files = self.files 58 | else: 59 | cur_dataloader_files = split_shard( 60 | self.files, self.rank, self.world_size) 61 | 62 | while True: 63 | if self.shuffle: 64 | random.shuffle(cur_dataloader_files) 65 | worker_info = torch.utils.data.get_worker_info() 66 | 67 | if worker_info is not None: 68 | if len(cur_dataloader_files) % worker_info.num_workers != 0: 69 | print('[DATA]--current dataloader %s file num %s cannot split to worker_num %s ' % 70 | (self.rank, len(cur_dataloader_files), worker_info.num_workers)) 71 | cur_worker_files = split_shard( 72 | cur_dataloader_files, worker_info.id, worker_info.num_workers) 73 | if worker_info.id == 0: 74 | print("[DataLoader] --> Rank:{} Workers:[{} ~ {}][{}] Size of process file:{} ...".format( 75 | 
self.rank, 0, worker_info.num_workers - 1, worker_info.id, len(cur_dataloader_files))) 76 | else: 77 | cur_worker_files = cur_dataloader_files 78 | 79 | if self.shuffle: 80 | random.shuffle(cur_worker_files) 81 | for filepath in cur_worker_files: 82 | if self.is_hdfs: 83 | with hopen(filepath, 'r') as reader: 84 | for line in reader: 85 | yield line.decode() 86 | continue 87 | with open(filepath, 'r') as reader: 88 | for line in reader: 89 | yield line 90 | 91 | if not self.repeat: 92 | break 93 | 94 | def __iter__(self): 95 | return self.generate() 96 | 97 | 98 | def split_shard(data: List[Any], shard_idx: int, shard_size: int): 99 | num = len(data) 100 | if num < shard_size: 101 | raise RuntimeError("num:{} < shard size:{}".format(num, shard_size)) 102 | start_idx = (num * shard_idx) // shard_size 103 | end_idx = (num * (shard_idx + 1)) // shard_size 104 | return data[start_idx: end_idx] 105 | -------------------------------------------------------------------------------- /dataset/grounding_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import math 4 | import random 5 | from random import random as rand 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | from torchvision.transforms.functional import hflip, resize 11 | 12 | from PIL import Image 13 | from dataset.utils import pre_caption 14 | from refTools.refer_python3 import REFER 15 | 16 | 17 | class grounding_dataset(Dataset): 18 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train'): 19 | self.ann = [] 20 | for f in ann_file: 21 | self.ann += json.load(open(f, 'r')) 22 | self.transform = transform 23 | self.image_root = image_root 24 | self.max_words = max_words 25 | self.mode = mode 26 | 27 | if self.mode == 'train': 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.ann: 31 | img_id = ann['image'].split('/')[-1] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.ann) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.ann[index] 42 | 43 | image_path = os.path.join(self.image_root, ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = pre_caption(ann['text'], self.max_words) 48 | 49 | if self.mode == 'train': 50 | img_id = ann['image'].split('/')[-1] 51 | 52 | return image, caption, self.img_ids[img_id] 53 | else: 54 | return image, caption, ann['ref_id'] 55 | 56 | 57 | class grounding_dataset_bbox(Dataset): 58 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train', config=None): 59 | self.image_res = config['image_res'] 60 | self.careful_hflip = config['careful_hflip'] 61 | 62 | self.ann = [] 63 | for f in ann_file: 64 | self.ann += json.load(open(f, 'r')) 65 | self.transform = transform 66 | self.image_root = image_root 67 | self.max_words = max_words 68 | self.mode = mode 69 | 70 | if self.mode == 'train': 71 | self.refer = REFER(config['refcoco_data'], 'refcoco+', 'unc') 72 | self.img_ids = {} 73 | n = 0 74 | for ann in self.ann: 75 | img_id = ann['image'].split('/')[-1] 76 | if img_id not in self.img_ids.keys(): 77 | self.img_ids[img_id] = n 78 | n += 1 79 | 80 | def __len__(self): 81 | return len(self.ann) 82 | 83 | def left_or_right_in_caption(self, caption): 84 | if ('left' in caption) or ('right' in caption): 85 | return True 86 | 87 | return False 88 | 89 | def __getitem__(self, index): 90 | 91 | ann = 
self.ann[index] 92 | caption = pre_caption(ann['text'], self.max_words) 93 | 94 | image_path = os.path.join(self.image_root, ann['image']) 95 | image = Image.open(image_path).convert('RGB') 96 | W, H = image.size 97 | 98 | if self.mode == 'train': 99 | # random crop 100 | x, y, w, h = self.refer.refToAnn[ann['ref_id']]['bbox'] 101 | assert (x >= 0) and (y >= 0) and (x + w <= W) and (y + h <= H) and (w > 0) and ( 102 | h > 0), "elem invalid" 103 | 104 | x0, y0 = random.randint(0, math.floor(x)), random.randint(0, math.floor(y)) 105 | x1, y1 = random.randint(min(math.ceil(x + w), W), W), random.randint(min(math.ceil(y + h), H), 106 | H) # fix bug: max -> min 107 | w0, h0 = x1 - x0, y1 - y0 108 | assert (x0 >= 0) and (y0 >= 0) and (x0 + w0 <= W) and (y0 + h0 <= H) and (w0 > 0) and ( 109 | h0 > 0), "elem randomcrop, invalid" 110 | image = image.crop((x0, y0, x0 + w0, y0 + h0)) 111 | 112 | W, H = image.size 113 | 114 | do_hflip = False 115 | if rand() < 0.5: 116 | if self.careful_hflip and self.left_or_right_in_caption(caption): 117 | pass 118 | else: 119 | image = hflip(image) 120 | do_hflip = True 121 | 122 | image = resize(image, [self.image_res, self.image_res], interpolation=Image.BICUBIC) 123 | image = self.transform(image) 124 | 125 | # axis transform: for crop 126 | x = x - x0 127 | y = y - y0 128 | 129 | if do_hflip: # flipped applied 130 | x = (W - x) - w # W is w0 131 | 132 | # resize applied 133 | x = self.image_res / W * x 134 | w = self.image_res / W * w 135 | y = self.image_res / H * y 136 | h = self.image_res / H * h 137 | 138 | center_x = x + 1 / 2 * w 139 | center_y = y + 1 / 2 * h 140 | 141 | target_bbox = torch.tensor([center_x / self.image_res, center_y / self.image_res, 142 | w / self.image_res, h / self.image_res], dtype=torch.float) 143 | 144 | return image, caption, target_bbox 145 | 146 | else: 147 | image = self.transform(image) # test_transform 148 | return image, caption, ann['ref_id'] -------------------------------------------------------------------------------- /dataset/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class nlvr_dataset(Dataset): 9 | def __init__(self, ann_file, transform, image_root=None): 10 | self.ann = [] 11 | 12 | if isinstance(ann_file, list): 13 | for f in ann_file: 14 | self.ann += json.load(open(f, 'r')) 15 | 16 | elif isinstance(ann_file, str): 17 | self.ann += json.load(open(ann_file, 'r')) 18 | 19 | else: 20 | raise ValueError(f"ann_file == {ann_file}") 21 | 22 | self.transform = transform 23 | self.image_root = image_root 24 | self.max_words = 30 25 | 26 | def __len__(self): 27 | return len(self.ann) 28 | 29 | def __getitem__(self, index): 30 | 31 | ann = self.ann[index] 32 | 33 | if self.image_root is None: 34 | image0_path = ann['images'][0] 35 | else: 36 | image0_path = os.path.join(self.image_root, ann['images'][0]) 37 | 38 | image0 = Image.open(image0_path).convert('RGB') 39 | image0 = self.transform(image0) 40 | 41 | if self.image_root is None: 42 | image1_path = ann['images'][1] 43 | else: 44 | image1_path = os.path.join(self.image_root, ann['images'][1]) 45 | 46 | image1 = Image.open(image1_path).convert('RGB') 47 | image1 = self.transform(image1) 48 | 49 | sentence = pre_caption(ann['sentence'], self.max_words) 50 | 51 | if (ann['label'] == 'True') or (ann['label'] is True): 52 | label = 1 53 | 54 | elif (ann['label'] == 
'False') or (ann['label'] is False): 55 | label = 0 56 | 57 | else: 58 | raise ValueError(f"unsupported label: {ann['label']}") 59 | 60 | return image0, image1, sentence, label 61 | -------------------------------------------------------------------------------- /dataset/retrieval_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import io 3 | import os 4 | import torch 5 | from base64 import b64decode 6 | from torch.utils.data import Dataset 7 | import random 8 | import traceback 9 | 10 | from PIL import Image 11 | from PIL import ImageFile 12 | 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | Image.MAX_IMAGE_PIXELS = None 15 | 16 | from dataset.utils import pre_caption, sample_frame_ids 17 | 18 | 19 | class re_train_dataset(Dataset): 20 | def __init__(self, ann_file, transform, image_root='', max_words=30, 21 | index_key='image_id', vision_key='image', text_key='caption', 22 | is_video=False, frame_len=1): 23 | self.ann = [] 24 | for f in ann_file: 25 | self.ann += json.load(open(f, 'r')) 26 | self.transform = transform 27 | self.image_root = image_root 28 | self.max_words = max_words 29 | self.img_ids = {} 30 | 31 | self.index_key = index_key 32 | self.vision_key = vision_key 33 | self.text_key = text_key 34 | self.is_video = is_video 35 | self.frame_len = frame_len 36 | self.training = True 37 | 38 | n = 0 39 | for ann in self.ann: 40 | img_id = ann[self.index_key] 41 | if img_id not in self.img_ids.keys(): 42 | self.img_ids[img_id] = n 43 | n += 1 44 | 45 | def __len__(self): 46 | return len(self.ann) 47 | 48 | def __getitem__(self, index): 49 | 50 | ann = self.ann[index] 51 | assert isinstance(ann, dict) 52 | 53 | vision_rpath = os.path.join(self.image_root, ann[self.vision_key]) if len(self.image_root) else ann[self.vision_key] 54 | 55 | if self.is_video: 56 | frames_b64 = json.load(open(vision_rpath, 'r')) 57 | 58 | selected_indices = sample_frame_ids(len(frames_b64), self.frame_len, self.training) 59 | 60 | vision_input = [] 61 | for i in selected_indices: 62 | image = Image.open(io.BytesIO(b64decode(frames_b64[i]))).convert("RGB") 63 | image = self.transform(image) 64 | vision_input.append(image) 65 | 66 | else: 67 | image = Image.open(vision_rpath).convert('RGB') 68 | vision_input = self.transform(image) 69 | 70 | caption = pre_caption(ann[self.text_key], self.max_words) 71 | 72 | return vision_input, caption, self.img_ids[ann[self.index_key]] 73 | 74 | def collate_fn(self, batch): 75 | batch_tensors = [] 76 | for i, x in enumerate(zip(*batch)): 77 | if x[0] is None: 78 | batch_tensors.append(None) 79 | 80 | elif isinstance(x[0], torch.Tensor): 81 | batch_tensors.append(torch.stack(x)) 82 | 83 | elif isinstance(x[0], list): 84 | assert i == 0 # # frames !!! 
always first 85 | batch_size = len(x) 86 | frames = torch.stack(sum(x, [])) # flatten 87 | _, c, h, w = frames.shape 88 | frames = frames.reshape([batch_size, self.frame_len, c, h, w]) 89 | batch_tensors.append(frames) 90 | 91 | elif isinstance(x[0], str): # should be texts, put in tokenizer afterwards 92 | batch_tensors.append(x) 93 | 94 | else: 95 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 96 | 97 | return batch_tensors 98 | 99 | 100 | class re_eval_dataset(Dataset): 101 | def __init__(self, ann_file, transform, image_root, max_words=30, 102 | index_key='image_id', vision_key='image', text_key='caption', 103 | is_video=False, frame_len=1, ): 104 | self.ann = json.load(open(ann_file, 'r')) 105 | self.transform = transform 106 | self.image_root = image_root 107 | self.max_words = max_words 108 | 109 | self.text = [] 110 | self.image = [] 111 | self.txt2img = {} 112 | self.img2txt = {} 113 | 114 | self.index_key = index_key 115 | self.vision_key = vision_key 116 | self.text_key = text_key 117 | self.is_video = is_video 118 | self.frame_len = frame_len 119 | self.training = False 120 | 121 | txt_id = 0 122 | for img_id, ann in enumerate(self.ann): 123 | self.image.append(ann[self.vision_key]) 124 | self.img2txt[img_id] = [] 125 | 126 | assert isinstance(ann[self.text_key], list) 127 | 128 | for i, caption in enumerate(ann[self.text_key]): 129 | self.text.append(pre_caption(caption, self.max_words)) 130 | self.img2txt[img_id].append(txt_id) 131 | self.txt2img[txt_id] = img_id 132 | txt_id += 1 133 | 134 | def __len__(self): 135 | return len(self.image) 136 | 137 | def __getitem__(self, index): 138 | if len(self.image_root): 139 | image_path = os.path.join(self.image_root, self.ann[index][self.vision_key]) 140 | else: 141 | image_path = self.ann[index][self.vision_key] 142 | 143 | if self.is_video: 144 | frames_b64 = json.load(open(image_path, 'r')) 145 | selected_indices = sample_frame_ids(len(frames_b64), self.frame_len, self.training) 146 | 147 | frames = [] 148 | for i in selected_indices: 149 | image = Image.open(io.BytesIO(b64decode(frames_b64[i]))).convert("RGB") 150 | image = self.transform(image) 151 | frames.append(image) 152 | 153 | frames = torch.stack(frames, dim=0) # (frame_len, 3, 384, 384) 154 | 155 | return frames, index 156 | 157 | else: 158 | image = Image.open(image_path).convert('RGB') 159 | image = self.transform(image) 160 | 161 | return image, index -------------------------------------------------------------------------------- /dataset/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, RobertaTokenizer, XLMRobertaTokenizer, AutoTokenizer 2 | from dataset.tokenizers.bert_tokenizer_with_dropout import BertTokenizerWithDropout 3 | 4 | 5 | def build_tokenizer(text_encoder: str, dropout=0): 6 | if ('bert-base-uncased' in text_encoder) or ('bert-large-uncased' in text_encoder): 7 | if dropout > 0: 8 | tokenizer = BertTokenizerWithDropout.from_pretrained(text_encoder, dropout=dropout) 9 | else: 10 | tokenizer = BertTokenizer.from_pretrained(text_encoder) 11 | 12 | elif ('xlm-roberta-base' in text_encoder) or ('xlm-roberta-large' in text_encoder): 13 | tokenizer = XLMRobertaTokenizer.from_pretrained(text_encoder) 14 | 15 | elif ('roberta-base' in text_encoder) or ('roberta-large' in text_encoder): 16 | tokenizer = RobertaTokenizer.from_pretrained(text_encoder) 17 | 18 | else: 19 | raise NotImplementedError(f"tokenizer for {text_encoder}") 20 | 21 | # always use cls 
and sep 22 | tokenizer.add_special_tokens({'bos_token': tokenizer.cls_token}) 23 | tokenizer.add_special_tokens({'eos_token': tokenizer.sep_token}) 24 | 25 | return tokenizer -------------------------------------------------------------------------------- /dataset/tokenizers/bert_tokenizer_with_dropout.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | import random 3 | 4 | class BertTokenizerWithDropout(BertTokenizer): 5 | def __init__( 6 | self, 7 | vocab_file, 8 | do_lower_case=True, 9 | do_basic_tokenize=True, 10 | never_split=None, 11 | unk_token="[UNK]", 12 | sep_token="[SEP]", 13 | pad_token="[PAD]", 14 | cls_token="[CLS]", 15 | mask_token="[MASK]", 16 | tokenize_chinese_chars=True, 17 | **kwargs 18 | ): 19 | """Constructs a BertTokenizerWithDropout. 20 | Args: 21 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 22 | **do_lower_case**: (`optional`) boolean (default True) 23 | Whether to lower case the input 24 | Only has an effect when do_basic_tokenize=True 25 | **do_basic_tokenize**: (`optional`) boolean (default True) 26 | Whether to do basic tokenization before wordpiece. 27 | **never_split**: (`optional`) list of string 28 | List of tokens which will never be split during tokenization. 29 | Only has an effect when do_basic_tokenize=True 30 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 31 | Whether to tokenize Chinese characters. 32 | This should likely be deactivated for Japanese: 33 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 34 | """ 35 | super().__init__( 36 | vocab_file, 37 | do_lower_case=do_lower_case, 38 | do_basic_tokenize=do_basic_tokenize, 39 | never_split=never_split, 40 | unk_token=unk_token, 41 | sep_token=sep_token, 42 | pad_token=pad_token, 43 | cls_token=cls_token, 44 | mask_token=mask_token, 45 | tokenize_chinese_chars=tokenize_chinese_chars, 46 | **kwargs 47 | ) 48 | dropout = kwargs.get("dropout", 0) 49 | assert 0<=dropout<=1 50 | self.wordpiece_tokenizer = WordpieceTokenizerWithDropout(vocab=self.vocab, unk_token=self.unk_token, dropout=dropout) 51 | 52 | 53 | class WordpieceTokenizerWithDropout(object): 54 | """Runs WordPiece tokenization with dropout.""" 55 | 56 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100, dropout=0): 57 | self.vocab = vocab 58 | self.unk_token = unk_token 59 | self.max_input_chars_per_word = max_input_chars_per_word 60 | self.dropout = dropout 61 | 62 | def tokenize(self, text, **kwargs): 63 | """Tokenizes a piece of text into its word pieces. 64 | This uses a greedy longest-match-first algorithm to perform tokenization 65 | using the given vocabulary. 66 | For example: 67 | input = "unaffable" 68 | output = ["un", "##aff", "##able"] 69 | Args: 70 | text: A single token or whitespace separated tokens. This should have 71 | already been passed through `BasicTokenizer`. 72 | Returns: 73 | A list of wordpiece tokens. 
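Note on dropout (example splits are illustrative and depend on the vocabulary): with
0 < dropout < 1, every in-vocabulary match may be skipped with probability `dropout`
(the `random.random() >= self.dropout` check below), so repeated calls on the same word
can yield different, typically longer, segmentations, e.g. "unaffable" -> ["un", "##aff", "##able"]
on one call and ["un", "##a", "##ff", "##able"] on another. With dropout == 1 every word
is split character by character.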
74 | """ 75 | 76 | output_tokens = [] 77 | for token in whitespace_tokenize(text): 78 | chars = list(token) 79 | if len(chars) > self.max_input_chars_per_word: 80 | output_tokens.append(self.unk_token) 81 | continue 82 | if self.dropout == 1: 83 | output_tokens.append(chars[0]) 84 | output_tokens.extend("##{}".format(char) for char in chars[1:]) 85 | continue 86 | is_bad = False 87 | start = 0 88 | sub_tokens = [] 89 | while start < len(chars): 90 | end = len(chars) 91 | cur_substr = None 92 | while start < end: 93 | substr = "".join(chars[start:end]) 94 | if start > 0: 95 | substr = "##" + substr 96 | if substr in self.vocab and random.random() >= self.dropout: 97 | cur_substr = substr 98 | break 99 | end -= 1 100 | if cur_substr is None: 101 | is_bad = True 102 | break 103 | sub_tokens.append(cur_substr) 104 | start = end 105 | 106 | if is_bad: 107 | output_tokens.append(self.unk_token) 108 | else: 109 | output_tokens.extend(sub_tokens) 110 | return output_tokens 111 | 112 | 113 | def whitespace_tokenize(text): 114 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 115 | text = text.strip() 116 | if not text: 117 | return [] 118 | tokens = text.split() 119 | return tokens 120 | -------------------------------------------------------------------------------- /dataset/wit_dataset.py: -------------------------------------------------------------------------------- 1 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621) 2 | # Github: https://github.com/zengyan-97/CCLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | import json 7 | import os 8 | from collections import OrderedDict 9 | 10 | from torch.utils.data import Dataset 11 | 12 | from PIL import Image 13 | from PIL import ImageFile 14 | import base64 15 | import io 16 | from tqdm import tqdm 17 | 18 | ImageFile.LOAD_TRUNCATED_IMAGES = True 19 | Image.MAX_IMAGE_PIXELS = None 20 | 21 | from dataset.utils import pre_caption 22 | 23 | 24 | class wit_train_dataset(Dataset): 25 | def __init__(self, ann_file, transform, max_words=80): 26 | self.ann = [] 27 | for f in ann_file: 28 | for line in tqdm(open(f)): 29 | ann = json.loads(line) 30 | if not ann['caption_reference_description']: 31 | continue 32 | self.ann.append(ann) 33 | self.transform = transform 34 | self.max_words = max_words 35 | self.img_ids = {} 36 | 37 | n = 0 38 | for ann in self.ann: 39 | img_id = ann['image_url'] 40 | if img_id not in self.img_ids.keys(): 41 | self.img_ids[img_id] = n 42 | n += 1 43 | 44 | def __len__(self): 45 | return len(self.ann) 46 | 47 | def __getitem__(self, index): 48 | 49 | ann = self.ann[index] 50 | 51 | image_str = base64.b64decode(ann['image_content']) 52 | image = Image.open(io.BytesIO(image_str)).convert("RGB") 53 | image = self.transform(image) 54 | try: 55 | caption = pre_caption(ann['caption_reference_description'], self.max_words) 56 | except Exception: 57 | caption = ann['caption_reference_description'] 58 | 59 | return image, caption, self.img_ids[ann['image_url']] 60 | 61 | 62 | class wit_eval_dataset(Dataset): 63 | def __init__(self, ann_file, transform, max_words=80): 64 | self.ann = [] 65 | for line in open(ann_file, 'r'): 66 | ann = json.loads(line) 67 | if not ann['caption_reference_description']: 68 | continue 69 | self.ann.append(ann) 70 | self.transform = transform 71 | self.max_words = max_words 72 | self.text = [] 73 | self.image = OrderedDict() 74 | self.txt2img = {} 75 | self.img2txt = {} 76 | 77 | txt_id = 0 
78 | img_id = 0 79 | for ann in self.ann: 80 | if ann['image_url'] in self.image: 81 | cur_img_id = self.image[ann['image_url']][0] 82 | self.img2txt[cur_img_id].append(txt_id) 83 | self.txt2img[txt_id] = cur_img_id 84 | else: 85 | self.img2txt[img_id] = [txt_id] 86 | self.image[ann['image_url']] = (img_id, ann['image_content']) 87 | self.txt2img[txt_id] = img_id 88 | img_id += 1 89 | if ann['caption_reference_description'] == '.': 90 | self.text.append(ann['caption_reference_description']) 91 | else: 92 | self.text.append(pre_caption(ann['caption_reference_description'], self.max_words)) 93 | txt_id += 1 94 | 95 | def __len__(self): 96 | return len(self.image) 97 | 98 | def __getitem__(self, index): 99 | image_str = base64.b64decode(list(self.image.values())[index][1]) 100 | image = Image.open(io.BytesIO(image_str)).convert("RGB") 101 | image = self.transform(image) 102 | 103 | return image, index 104 | -------------------------------------------------------------------------------- /dataset/xflickrco_dataset.py: -------------------------------------------------------------------------------- 1 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621) 2 | # Github: https://github.com/zengyan-97/CCLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | import json 7 | import os 8 | 9 | from torch.utils.data import Dataset 10 | 11 | from PIL import Image 12 | from PIL import ImageFile 13 | 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | Image.MAX_IMAGE_PIXELS = None 16 | 17 | from dataset.utils import pre_caption 18 | 19 | 20 | class xflickrco_train_dataset(Dataset): 21 | def __init__(self, ann_file, transform, image_root, max_words=80): 22 | self.ann = [] 23 | for f in ann_file: 24 | for line in open(f): 25 | ann = json.loads(line) 26 | for i in range(len(ann['sentences'])): 27 | self.ann.append({ 28 | 'caption': ann['sentences'][i], 29 | 'id': ann['id'], 30 | 'img_path': ann['img_path']}) 31 | self.transform = transform 32 | self.image_root = image_root 33 | self.max_words = max_words 34 | self.img_ids = {} 35 | 36 | n = 0 37 | for ann in self.ann: 38 | img_id = ann['id'] 39 | if img_id not in self.img_ids.keys(): 40 | self.img_ids[img_id] = n 41 | n += 1 42 | 43 | def __len__(self): 44 | return len(self.ann) 45 | 46 | def __getitem__(self, index): 47 | 48 | ann = self.ann[index] 49 | 50 | image_path = os.path.join(self.image_root, ann['img_path']) 51 | image = Image.open(image_path).convert('RGB') 52 | image = self.transform(image) 53 | 54 | caption = pre_caption(ann['caption'], self.max_words) 55 | 56 | return image, caption, self.img_ids[ann['id']] 57 | 58 | 59 | class xflickrco_eval_dataset(Dataset): 60 | def __init__(self, ann_file, transform, image_root, max_words=80): 61 | self.ann = [] 62 | for line in open(ann_file, 'r'): 63 | ann = json.loads(line) 64 | 65 | # judge if caption if empty 66 | empty = True 67 | for sent in ann['sentences']: 68 | if sent: 69 | empty = False 70 | break 71 | if empty: 72 | print(ann_file, 'has a empty caption') 73 | continue 74 | 75 | self.ann.append(ann) 76 | self.transform = transform 77 | self.image_root = image_root 78 | self.max_words = max_words 79 | 80 | self.text = [] 81 | self.image = [] 82 | self.txt2img = {} 83 | self.img2txt = {} 84 | 85 | txt_id = 0 86 | for img_id, ann in enumerate(self.ann): 87 | self.image.append(ann['img_path']) 88 | self.img2txt[img_id] = [] 89 | for i, caption in enumerate(ann['sentences']): 90 | 
self.text.append(pre_caption(caption, self.max_words)) 91 | self.img2txt[img_id].append(txt_id) 92 | self.txt2img[txt_id] = img_id 93 | txt_id += 1 94 | 95 | def __len__(self): 96 | return len(self.image) 97 | 98 | def __getitem__(self, index): 99 | if 'COCO' in self.image[index]: 100 | image_path = os.path.join(self.image_root['coco'], self.image[index]) 101 | else: 102 | image_path = os.path.join(self.image_root['flickr30k'], self.image[index]) 103 | image = Image.open(image_path).convert('RGB') 104 | image = self.transform(image) 105 | 106 | return image, index 107 | -------------------------------------------------------------------------------- /dataset/xvnli_dataset.py: -------------------------------------------------------------------------------- 1 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621) 2 | # Github: https://github.com/zengyan-97/CCLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | import json 7 | import os 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | from dataset.utils import pre_caption 11 | 12 | 13 | class xvnli_dataset(Dataset): 14 | def __init__(self, ann_file, transform, image_root, max_words=80): 15 | self.label_mapper = {"contradiction": 0, "entailment": 1, "neutral": 2} 16 | 17 | self.ann = [] 18 | 19 | if type(ann_file) == str: 20 | ann_file = [ann_file] 21 | 22 | invalid_cnt = 0 23 | for f in ann_file: 24 | for line in open(f, 'r'): 25 | ann = json.loads(line) 26 | if ann['gold_label'] not in self.label_mapper: 27 | invalid_cnt += 1 28 | continue 29 | self.ann.append(ann) 30 | 31 | if not self.ann: 32 | raise ValueError(f"ann_file == {ann_file}") 33 | 34 | print('data num: ', len(self.ann)) 35 | print('invalid num: ', invalid_cnt) 36 | 37 | self.transform = transform 38 | self.image_root = image_root 39 | self.max_words = max_words 40 | 41 | def __len__(self): 42 | return len(self.ann) 43 | 44 | def __getitem__(self, index): 45 | ann = self.ann[index] 46 | 47 | image_path = os.path.join(self.image_root, ann['Flikr30kID']+'.jpg') 48 | image = Image.open(image_path).convert('RGB') 49 | image = self.transform(image) 50 | 51 | sentence = pre_caption(ann['sentence2'], self.max_words) 52 | 53 | label = self.label_mapper[ann['gold_label']] 54 | 55 | return image, sentence, label 56 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.xvlm import XVLMBase 2 | from models.xvlm import build_mlp 3 | from models.xvlm import load_pretrained -------------------------------------------------------------------------------- /models/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
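Two box formats appear in this repo: normalized (cx, cy, w, h), as built for the grounding
targets in dataset/grounding_dataset.py, and corner format (x0, y0, x1, y1), which
generalized_box_iou below expects. As a small illustrative example, box_cxcywh_to_xyxy
maps [0.5, 0.5, 0.2, 0.4] to [0.4, 0.3, 0.6, 0.7].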
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): # 这个用了 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | iou, union = box_iou(boxes1, boxes2) 50 | 51 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 52 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 53 | 54 | wh = (rb - lt).clamp(min=0) # [N,M,2] 55 | area = wh[:, :, 0] * wh[:, :, 1] 56 | 57 | return iou - (area - union) / area 58 | 59 | 60 | -------------------------------------------------------------------------------- /models/model_classification.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import json 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.nn import MSELoss 9 | 10 | from einops import rearrange 11 | 12 | from models.xvlm import XVLMBase, XVLMPlusBase 13 | from models.xvlm import build_mlp 14 | 15 | 16 | class XVLMForClassification(XVLMBase): 17 | def __init__(self, config): 18 | super().__init__(config, load_vision_params=False, load_text_params=False, 19 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False) 20 | 21 | feature_dim = self.vision_width if config.get('task_name') == 'imagenet' else self.text_width 22 | self.cls_head = build_mlp(input_dim=feature_dim, output_dim=config['num_labels']) 23 | 24 | def forward(self, image, text_ids, text_atts, targets=None, train=True): 25 | if image is None: 26 | output_cls = self.get_text_embeds_12L(text_ids, text_atts)[:, 0, :] 27 | 28 | elif text_ids is None: 29 | image_embeds, _ = self.get_vision_embeds(image) 30 | output_cls = image_embeds[:, 0, :] 31 | 32 | else: 33 | image_embeds, image_atts = self.get_vision_embeds(image) 34 | 35 | output_cls = self.get_cross_embeds(image_embeds, image_atts, 36 | text_ids=text_ids, text_atts=text_atts)[:, 0, :] 37 | 38 | prediction = self.cls_head(output_cls) 39 | if prediction.shape[-1] == 1: 40 | # We are doing regression 41 | loss_fct = MSELoss() 42 | return loss_fct(prediction.view(-1), targets.view(-1)) if train else prediction 43 | 44 | return F.cross_entropy(prediction, targets, ignore_index=-100) if train else prediction 45 | 46 | 47 | class XVLMForVQAClassification(XVLMBase): 48 | def __init__(self, config): 49 | super().__init__(config, load_vision_params=False, load_text_params=False, 50 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, 
use_bbox_loss=False) 51 | 52 | self.cls_head = build_mlp(input_dim=self.text_width, output_dim=config['num_labels']) 53 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()] 54 | 55 | def forward(self, image, text_ids, text_atts, targets=None, k=None, weights=None, train=True, answer_pred=None, 56 | return_logits=False): 57 | 58 | image_embeds, image_atts = self.get_vision_embeds(image) 59 | output_cls = self.get_cross_embeds(image_embeds, image_atts, 60 | text_ids=text_ids, text_atts=text_atts)[:, 0, :] 61 | 62 | prediction = self.cls_head(output_cls) 63 | if train: 64 | if answer_pred is not None: 65 | self.criterion = nn.KLDivLoss(reduction='none') 66 | log_probs = F.log_softmax(prediction, -1) 67 | answer_label = F.softmax(answer_pred, dim=-1) 68 | loss = self.criterion(log_probs, answer_label) 69 | loss = loss.sum() / image.size(0) 70 | return loss 71 | 72 | p_states = [] 73 | for b, n in enumerate(k): 74 | p_states = p_states + [prediction[b]] * n 75 | 76 | p_states = torch.stack(p_states, 0) 77 | 78 | loss = F.cross_entropy(p_states, targets, ignore_index=-100, reduction='none') 79 | 80 | loss = weights * loss 81 | loss = loss.sum() / image.size(0) 82 | if return_logits: 83 | return loss, prediction 84 | return loss 85 | else: 86 | return prediction 87 | 88 | 89 | class XVLMForNLVR(XVLMBase): 90 | """ 91 | Follow VLMo 92 | """ 93 | def __init__(self, config): 94 | super().__init__(config, load_vision_params=False, load_text_params=False, 95 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False, 96 | config_text=None) 97 | 98 | self.cls_head = build_mlp(input_dim=self.text_width * 2, output_dim=2) 99 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()] 100 | 101 | def forward(self, image, text_ids, text_atts, targets, train=True): 102 | image_embeds, image_atts = self.get_vision_embeds(image) 103 | image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0)) 104 | 105 | output_cls_image1 = self.get_cross_embeds(image0_embeds, image_atts[:image0_embeds.size(0)], 106 | text_ids=text_ids, text_atts=text_atts)[:, 0, :] 107 | 108 | output_cls_image2 = self.get_cross_embeds(image1_embeds, image_atts[image0_embeds.size(0):], 109 | text_ids=text_ids, text_atts=text_atts)[:, 0, :] 110 | 111 | output_cls = torch.cat((output_cls_image1, output_cls_image2), dim=-1) 112 | 113 | assert output_cls.shape[-1] == self.text_width * 2 114 | 115 | prediction = self.cls_head(output_cls) 116 | 117 | return F.cross_entropy(prediction, targets) if train else prediction 118 | 119 | 120 | class XVLMPlus4XVNLI(XVLMPlusBase): 121 | def __init__(self, config): 122 | super().__init__(config, load_vision_params=False, load_text_params=False, 123 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False) 124 | 125 | self.cls_head = build_mlp(input_dim=self.text_width, output_dim=config['num_labels']) 126 | self.init_params = ['cls_head.' 
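# 'init_params' marks the freshly initialized head: optim.create_optimizer puts any parameter
# named here into the lr * lr_mult groups, so the new classifier trains with a larger learning
# rate than the pretrained encoders (same pattern as the other classification heads above).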
+ n for n, _ in self.cls_head.named_parameters()] 127 | 128 | def forward(self, image, text_ids, text_atts, targets, train=True): 129 | image_embeds, image_atts = self.get_vision_embeds(image) 130 | text_embeds = self.get_cross_embeds(image_embeds, image_atts, text_ids, text_atts=text_atts) 131 | prediction = self.cls_head(text_embeds[:, 0, :]) 132 | 133 | return F.cross_entropy(prediction, targets) if train else prediction 134 | 135 | 136 | class XVLMPlusForMARVL(XVLMPlusBase): 137 | def __init__(self, config): 138 | super().__init__(config, load_vision_params=False, load_text_params=False, 139 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False) 140 | 141 | self.cls_head = build_mlp(input_dim=self.text_width * 2, output_dim=2) 142 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()] 143 | 144 | def forward(self, image, text_ids, text_atts, targets, train=True): 145 | image_embeds, image_atts = self.get_vision_embeds(image) 146 | image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0)) 147 | 148 | output_cls_image1 = self.get_cross_embeds(image0_embeds, image_atts[:image0_embeds.size(0)], 149 | text_ids=text_ids, text_atts=text_atts)[:, 0, :] 150 | 151 | output_cls_image2 = self.get_cross_embeds(image1_embeds, image_atts[image0_embeds.size(0):], 152 | text_ids=text_ids, text_atts=text_atts)[:, 0, :] 153 | 154 | output_cls = torch.cat((output_cls_image1, output_cls_image2), dim=-1) 155 | 156 | assert output_cls.shape[-1] == self.text_width * 2 157 | 158 | prediction = self.cls_head(output_cls) 159 | 160 | return F.cross_entropy(prediction, targets) if train else prediction 161 | 162 | -------------------------------------------------------------------------------- /models/model_grounding.py: -------------------------------------------------------------------------------- 1 | from models import XVLMBase, load_pretrained 2 | 3 | 4 | class XVLMForGrounding(XVLMBase): 5 | def __init__(self, config): 6 | super().__init__(config, load_vision_params=False, load_text_params=False, 7 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=True) 8 | self.init_params = [] 9 | 10 | def load_pretrained(self, ckpt_rpath, config, load_bbox_pretrain=False, is_eval=False): 11 | print("### load_bbox_pretrain, ", load_bbox_pretrain) 12 | state_dict = load_pretrained(self, ckpt_rpath, config, is_eval=is_eval, load_text=True) 13 | msg = self.load_state_dict(state_dict, strict=False) 14 | print('load checkpoint from %s' % ckpt_rpath) 15 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p]) 16 | print("unexpected_keys: ", msg.unexpected_keys) 17 | 18 | def forward(self, image, text_ids, text_atts, target_bbox=None): 19 | image_embeds, _ = self.get_vision_embeds(image) 20 | text_embeds = self.get_text_embeds(text_ids, text_atts) 21 | 22 | output_coord = self.predict_bbox(image_embeds, text_embeds, text_atts) 23 | # output_coord & target_bbox: 64, 4 24 | 25 | if target_bbox is None: 26 | return output_coord 27 | 28 | loss_bbox, loss_giou = self.get_bbox_loss(output_coord, target_bbox) 29 | 30 | return output_coord, loss_bbox, loss_giou 31 | 32 | -------------------------------------------------------------------------------- /models/model_pretrain.py: -------------------------------------------------------------------------------- 1 | # X^2-VLM: All-In-One Pre-trained Model For Vision-Language Tasks (https://arxiv.org/abs/2211.12402) 2 | # Github: 
https://github.com/zengyan-97/X2-VLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 5 | 6 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621) 7 | # Github: https://github.com/zengyan-97/CCLM 8 | # Copyright (c) 2022, ByteDance Inc. 9 | # All rights reserved. 10 | 11 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 12 | # Github: https://github.com/zengyan-97/X-VLM 13 | # Copyright (c) 2022, ByteDance Inc. 14 | # All rights reserved. 15 | 16 | import os 17 | import json 18 | import torch 19 | from einops import rearrange 20 | 21 | from models.xvlm import XVLMBase, XVLMPlusBase, VanillaConfig 22 | 23 | 24 | class XVLM(XVLMBase): 25 | def __init__(self, config, load_vision_params=True, load_text_params=True, pretraining=True): 26 | super().__init__(config, load_vision_params=load_vision_params, load_text_params=load_text_params, 27 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=True, use_bbox_loss=True, 28 | config_text=None, pretraining=pretraining) 29 | 30 | def forward_multimodal(self, image, text_ids, text_atts, text_ids_masked=None, masked_pos=None, masked_ids=None, 31 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None, 32 | ret_bbox_loss=False, ret_match_loss=True): 33 | 34 | if ret_bbox_loss: 35 | image_embeds, image_atts, image_embeds_fullatts = \ 36 | self.get_vision_embeds(image, image_atts=image_atts, idx_to_group_img=idx_to_group_img) 37 | else: 38 | image_embeds, image_atts = self.get_vision_embeds(image) 39 | 40 | text_embeds = self.get_text_embeds(text_ids, text_atts) 41 | 42 | # with torch.no_grad(): # fix: i put it in batch iteration, so once a iteration 43 | # self.temp.clamp_(0.001, 0.5) 44 | 45 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 46 | 47 | loss_itc = self.get_contrastive_loss(image_feat, text_feat) 48 | 49 | if ret_match_loss: 50 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat) 51 | else: 52 | loss_itm = torch.tensor(0.0) 53 | 54 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids) 55 | 56 | loss = {'loss_itc': loss_itc, 'loss_itm': loss_itm, 'loss_mlm': loss_mlm} 57 | 58 | if ret_bbox_loss: 59 | output_coord = self.predict_bbox(image_embeds_fullatts, text_embeds, text_atts) 60 | loss_bbox, loss_giou = self.get_bbox_loss(output_coord, target_bbox, is_image=is_image) 61 | 62 | loss['loss_bbox'] = loss_bbox 63 | loss['loss_giou'] = loss_giou 64 | 65 | return loss 66 | 67 | def forward_text(self, text_ids=None, text_atts=None, 68 | text_ids_masked=None, masked_pos=None, masked_ids=None): 69 | 70 | loss = self.get_mlm_loss(text_ids_masked, text_atts, None, None, masked_pos, masked_ids) 71 | 72 | return {'loss_mlm': loss} 73 | 74 | def forward(self, image=None, text_ids=None, text_atts=None, 75 | text_ids_masked=None, masked_pos=None, masked_ids=None, 76 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None, 77 | ret_bbox_loss=False, ret_match_loss=True): 78 | 79 | if image is None: # text 80 | loss = self.forward_text(text_ids, text_atts, text_ids_masked, 81 | masked_pos, masked_ids) 82 | 83 | else: 84 | loss = self.forward_multimodal(image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids, 85 | image_atts, idx_to_group_img, target_bbox, is_image, ret_bbox_loss, 86 | 
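# forward_multimodal returns a dict of 'loss_itc', 'loss_itm', 'loss_mlm' and, when
# ret_bbox_loss is set, 'loss_bbox'/'loss_giou'; summing (and optionally weighting) these
# terms is assumed to happen in the pretraining loop, which is not shown in this file.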
ret_match_loss=ret_match_loss) 87 | 88 | return loss 89 | 90 | 91 | class XVLMPlus(XVLMPlusBase): 92 | def __init__(self, config, use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=True, use_bbox_loss=True, 93 | load_vision_params=True, load_text_params=True, load_cross_params=True, pretraining=True): 94 | super().__init__(config, use_contrastive_loss=use_contrastive_loss, use_matching_loss=use_matching_loss, 95 | use_mlm_loss=use_mlm_loss, use_bbox_loss=use_bbox_loss, 96 | load_vision_params=load_vision_params, load_text_params=load_text_params, load_cross_params=load_cross_params, 97 | pretraining=pretraining) 98 | 99 | def forward_multimodal(self, image, text_ids, text_atts, text_ids_masked=None, masked_pos=None, masked_ids=None, 100 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None, 101 | ret_bbox_loss=False, ret_match_loss=True): 102 | 103 | if ret_bbox_loss: 104 | image_embeds, image_atts, image_embeds_fullatts = \ 105 | self.get_vision_embeds(image, image_atts=image_atts, idx_to_group_img=idx_to_group_img) 106 | else: 107 | image_embeds, image_atts = self.get_vision_embeds(image) 108 | 109 | text_embeds = self.get_text_embeds(text_ids, text_atts) 110 | 111 | # with torch.no_grad(): # fix: i put it in batch iteration, so once a iteration 112 | # self.temp.clamp_(0.001, 0.5) 113 | 114 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 115 | 116 | loss_itc = self.get_contrastive_loss(image_feat, text_feat) 117 | 118 | if ret_match_loss: 119 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat) 120 | else: 121 | loss_itm = torch.tensor(0.0) 122 | 123 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids) 124 | 125 | loss = {'loss_itc': loss_itc, 'loss_itm': loss_itm, 'loss_mlm': loss_mlm} 126 | 127 | if ret_bbox_loss: 128 | output_coord = self.predict_bbox(image_embeds_fullatts, text_embeds, text_atts) 129 | loss_bbox, loss_giou = self.get_bbox_loss(output_coord, target_bbox, is_image=is_image) 130 | 131 | loss['loss_bbox'] = loss_bbox 132 | loss['loss_giou'] = loss_giou 133 | 134 | return loss 135 | 136 | def forward(self, image=None, text_ids=None, text_atts=None, 137 | text_ids_masked=None, masked_pos=None, masked_ids=None, 138 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None, 139 | ret_bbox_loss=False, ret_match_loss=True): 140 | 141 | loss = self.forward_multimodal(image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids, 142 | image_atts, idx_to_group_img, target_bbox, is_image, ret_bbox_loss, 143 | ret_match_loss=ret_match_loss) 144 | 145 | return loss 146 | 147 | 148 | class CrossViewLM(XVLMPlus): # Multilingual x Multimodal Pre-training 149 | """ 150 | Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training 151 | https://arxiv.org/abs/2206.00621 152 | """ 153 | def __init__(self, config, use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=True, use_bbox_loss=True, 154 | load_vision_params=True, load_text_params=True, load_cross_params=True, pretraining=True): 155 | super().__init__(config, use_contrastive_loss=use_contrastive_loss, use_matching_loss=use_matching_loss, 156 | use_mlm_loss=use_mlm_loss, use_bbox_loss=use_bbox_loss, 157 | load_vision_params=load_vision_params, load_text_params=load_text_params, load_cross_params=load_cross_params, 158 | pretraining=pretraining) 159 | 160 | def forward_para_text(self, text_ids=None, 
text_atts=None, 161 | text_ids_masked=None, text_atts_masked=None, masked_pos=None, masked_ids=None, 162 | text_ids_2=None, text_atts_2=None, text_ids_masked_2=None, masked_pos_2=None, masked_ids_2=None): 163 | 164 | text_embeds = self.get_text_embeds(text_ids, text_atts) 165 | text_embeds_2 = self.get_text_embeds(text_ids_2, text_atts_2) 166 | 167 | # with torch.no_grad(): 168 | # self.temp.clamp_(0.001, 0.5) 169 | 170 | text_feat = self.get_features(text_embeds=text_embeds) 171 | text_feat_2 = self.get_features(text_embeds=text_embeds_2) 172 | 173 | loss_ttc = self.get_contrastive_loss(text_feat, text_feat_2) 174 | loss_ttm = self.get_matching_loss(text_embeds, text_atts, text_feat, text_embeds_2, text_atts_2, text_feat_2) 175 | 176 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, text_embeds_2, text_atts_2, masked_pos, masked_ids) 177 | 178 | loss = {'loss_ttc': loss_ttc, 'loss_ttm': loss_ttm, 'loss_mlm': loss_mlm} 179 | 180 | return loss 181 | 182 | def forward(self, image=None, text_ids=None, text_atts=None, 183 | text_ids_masked=None, text_atts_masked=None, masked_pos=None, masked_ids=None, 184 | text_ids_2=None, text_atts_2=None, text_ids_masked_2=None, masked_pos_2=None, masked_ids_2=None, 185 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None, ret_bbox_loss=False, ret_match_loss=True): 186 | 187 | if image is None: # parallel text 188 | loss = self.forward_para_text(text_ids, text_atts, text_ids_masked, text_atts_masked, masked_pos, masked_ids, 189 | text_ids_2, text_atts_2, text_ids_masked_2, masked_pos_2, masked_ids_2) 190 | 191 | else: 192 | loss = self.forward_multimodal(image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids, 193 | image_atts, idx_to_group_img, target_bbox, is_image, ret_bbox_loss, 194 | ret_match_loss=ret_match_loss) 195 | 196 | return loss 197 | -------------------------------------------------------------------------------- /models/model_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models.xvlm import XVLMBase, XVLMPlusBase 4 | 5 | 6 | class XVLMForRetrieval(XVLMBase): 7 | def __init__(self, config): 8 | super().__init__(config, load_vision_params=False, load_text_params=False, 9 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=False, use_bbox_loss=False) 10 | 11 | self.num_attention_heads = self.text_encoder.config.num_attention_heads 12 | self.init_params = [] 13 | 14 | def forward(self, image, text_ids, text_atts, idx=None): 15 | image_embeds, image_atts = self.get_vision_embeds(image) 16 | text_embeds = self.get_text_embeds(text_ids, text_atts) 17 | 18 | with torch.no_grad(): 19 | self.temp.clamp_(0.001, 0.5) 20 | 21 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 22 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx) 23 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=idx) 24 | 25 | return loss_itc, loss_itm 26 | 27 | 28 | class XVLMPlusForRetrieval(XVLMPlusBase): 29 | def __init__(self, config): 30 | super().__init__(config, load_vision_params=False, load_text_params=False, load_cross_params=False, 31 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=False, use_bbox_loss=False) 32 | 33 | self.num_attention_heads = self.text_encoder.config.num_attention_heads 34 | self.init_params = [] 35 | 36 | def forward(self, image, text_ids, text_atts, idx=None): 37 | 
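# Same recipe as XVLMForRetrieval.forward above: clamp the learnable temperature to
# [0.001, 0.5], then compute the in-batch contrastive loss (ITC) on the image/text features
# and the image-text matching loss (ITM) on the cross-encoder output. A typical call, with
# the equal weighting being an assumption since the fine-tuning script is not shown here:
#     loss_itc, loss_itm = model(image, text_ids, text_atts, idx=image_ids)
#     loss = loss_itc + loss_itm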
image_embeds, image_atts = self.get_vision_embeds(image) 38 | text_embeds = self.get_text_embeds(text_ids, text_atts) 39 | 40 | with torch.no_grad(): 41 | self.temp.clamp_(0.001, 0.5) 42 | 43 | image_feat, text_feat = self.get_features(image_embeds, text_embeds) 44 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx) 45 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=idx) 46 | 47 | return loss_itc, loss_itm 48 | -------------------------------------------------------------------------------- /models/resampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange, repeat 4 | from einops_exts import rearrange_many, repeat_many 5 | 6 | 7 | def FeedForward(dim, mult=4): 8 | inner_dim = int(dim * mult) 9 | return nn.Sequential( 10 | nn.LayerNorm(dim), 11 | nn.Linear(dim, inner_dim, bias=False), 12 | nn.GELU(), 13 | nn.Linear(inner_dim, dim, bias=False) 14 | ) 15 | 16 | 17 | class PerceiverAttention(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | dim, 22 | dim_head=64, 23 | heads=8 24 | ): 25 | super().__init__() 26 | self.scale = dim_head ** -0.5 27 | self.heads = heads 28 | inner_dim = dim_head * heads 29 | 30 | self.norm_media = nn.LayerNorm(dim) 31 | self.norm_latents = nn.LayerNorm(dim) 32 | 33 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 34 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 35 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 36 | 37 | def forward(self, x, latents): 38 | """ 39 | einstein notation 40 | b - batch 41 | t - time 42 | n - sequence 43 | d - dimension 44 | """ 45 | x = self.norm_media(x) 46 | latents = self.norm_latents(latents) 47 | 48 | b, m, h = *x.shape[:2], self.heads 49 | 50 | q = self.to_q(latents) 51 | 52 | # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | 56 | q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h=h) 57 | 58 | q = q * self.scale 59 | 60 | # attention 61 | 62 | sim = einsum('... i d, ... j d -> ... i j', q, k) 63 | 64 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 65 | attn = sim.softmax(dim=-1) 66 | 67 | out = einsum('... i j, ... j d -> ... 
i d', attn, v) 68 | out = rearrange(out, 'b h t n d -> b t n (h d)', h=h) 69 | return self.to_out(out) 70 | 71 | 72 | class PerceiverResampler(nn.Module): 73 | def __init__( 74 | self, 75 | *, 76 | dim, 77 | depth, 78 | dim_head=64, 79 | heads=8, 80 | num_latents=64, 81 | num_time_embeds=4, 82 | ff_mult=4, 83 | num_img_latents=-1, 84 | ): 85 | super().__init__() 86 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 87 | self.img_latents = nn.Parameter(torch.randn(num_img_latents, dim)) if num_img_latents > 0 else None 88 | 89 | # self.time_pos_emb = nn.Parameter(torch.randn(num_time_embeds, 1, dim)) # will be added properly in my code 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append(nn.ModuleList([ 94 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 95 | FeedForward(dim=dim, mult=ff_mult) 96 | ])) 97 | 98 | self.norm = nn.LayerNorm(dim) 99 | 100 | def forward(self, x, mode='video'): 101 | if x.ndim == 3: 102 | x = rearrange(x, 'b n d -> b 1 n d') 103 | 104 | # times = x.shape[1] 105 | # x = x + self.time_pos_emb[:times] 106 | 107 | if mode == 'video': 108 | latents = self.latents 109 | elif mode == 'image': 110 | latents = self.img_latents 111 | else: 112 | raise ValueError(f"mode == {mode}") 113 | 114 | latents = repeat(latents, 'n d -> b m n d', b=x.shape[0], m=x.shape[1]) 115 | 116 | for attn, ff in self.layers: 117 | latents = attn(x, latents) + latents 118 | latents = ff(latents) + latents 119 | 120 | return self.norm(latents) 121 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | from transformers.optimization import AdamW 2 | import torch 3 | 4 | 5 | def is_vision(args, param: str): 6 | if hasattr(args, 'vision_lr') and ('vision_encoder' in param): 7 | return True 8 | else: 9 | return False 10 | 11 | 12 | def is_text(args, param: str): 13 | if hasattr(args, 'text_lr') and ('text_encoder' in param): 14 | return True 15 | else: 16 | return False 17 | 18 | 19 | def is_cross(args, param: str): 20 | if hasattr(args, 'cross_lr') and ('cross_encoder' in param): # only supports XVLMPlusBase 21 | return True 22 | else: 23 | return False 24 | 25 | 26 | def create_optimizer(args, model): 27 | lr = args.lr 28 | wd = args.weight_decay 29 | lr_mult = getattr(args, 'lr_mult', 1) 30 | print(f"### lr: {args.lr}, lr_mult: {lr_mult}") 31 | 32 | optimizer_grouped_parameters = [ 33 | {"params": [], "weight_decay": wd, "lr": lr}, 34 | {"params": [], "weight_decay": 0.0, "lr": lr}, 35 | {"params": [], "weight_decay": wd, "lr": lr * lr_mult}, 36 | {"params": [], "weight_decay": 0.0, "lr": lr * lr_mult} 37 | ] 38 | 39 | if hasattr(args, 'vision_lr'): 40 | # 4 & 5 41 | print("### vision_lr: ", args.vision_lr) 42 | optimizer_grouped_parameters.append({"params": [], "weight_decay": wd, "lr": args.vision_lr}) 43 | optimizer_grouped_parameters.append({"params": [], "weight_decay": 0.0, "lr": args.vision_lr}) 44 | 45 | # 6 & 7 46 | assert hasattr(args, 'text_lr') 47 | print("### text_lr: ", args.text_lr) 48 | optimizer_grouped_parameters.append({"params": [], "weight_decay": wd, "lr": args.text_lr}) 49 | optimizer_grouped_parameters.append({"params": [], "weight_decay": 0.0, "lr": args.text_lr}) 50 | 51 | # 8 & 9 52 | if not hasattr(args, 'cross_lr'): 53 | args.cross_lr = args.text_lr 54 | 55 | print("### cross_lr: ", args.cross_lr, flush=True) 56 | optimizer_grouped_parameters.append({"params": [], "weight_decay": wd, "lr": args.cross_lr}) 57 
| optimizer_grouped_parameters.append({"params": [], "weight_decay": 0.0, "lr": args.cross_lr}) 58 | 59 | no_decay = {"bias", 60 | "LayerNorm.bias", 61 | "LayerNorm.weight", 62 | "norm.bias", 63 | "norm.weight", 64 | "norm1.bias", 65 | "norm1.weight", 66 | "norm2.bias", 67 | "norm2.weight"} 68 | 69 | if hasattr(model, 'init_params'): 70 | large_lr = model.init_params 71 | print("### model has 'init_params', ", len(large_lr)) 72 | else: 73 | large_lr = {} 74 | 75 | for n, p in model.named_parameters(): 76 | if not p.requires_grad: 77 | continue # frozen weights 78 | 79 | if any(nd in n for nd in no_decay): 80 | if is_vision(args, n): 81 | optimizer_grouped_parameters[5]['params'].append(p) 82 | elif is_text(args, n): 83 | optimizer_grouped_parameters[7]['params'].append(p) 84 | elif is_cross(args, n): 85 | optimizer_grouped_parameters[9]['params'].append(p) 86 | elif n in large_lr: 87 | optimizer_grouped_parameters[3]['params'].append(p) 88 | else: 89 | optimizer_grouped_parameters[1]['params'].append(p) 90 | else: # decay 91 | if is_vision(args, n): 92 | optimizer_grouped_parameters[4]['params'].append(p) 93 | elif is_text(args, n): 94 | optimizer_grouped_parameters[6]['params'].append(p) 95 | elif is_cross(args, n): 96 | optimizer_grouped_parameters[8]['params'].append(p) 97 | elif n in large_lr: 98 | optimizer_grouped_parameters[2]['params'].append(p) 99 | else: 100 | optimizer_grouped_parameters[0]['params'].append(p) 101 | 102 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)) 103 | 104 | return optimizer 105 | 106 | 107 | class LARS(torch.optim.Optimizer): 108 | """ 109 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 110 | """ 111 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 112 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 113 | super().__init__(params, defaults) 114 | 115 | @torch.no_grad() 116 | def step(self): 117 | for g in self.param_groups: 118 | for p in g['params']: 119 | dp = p.grad 120 | 121 | if dp is None: 122 | continue 123 | 124 | if p.ndim > 1: # if not normalization gamma/beta or bias 125 | dp = dp.add(p, alpha=g['weight_decay']) 126 | param_norm = torch.norm(p) 127 | update_norm = torch.norm(dp) 128 | one = torch.ones_like(param_norm) 129 | q = torch.where(param_norm > 0., 130 | torch.where(update_norm > 0, 131 | (g['trust_coefficient'] * param_norm / update_norm), one), 132 | one) 133 | dp = dp.mul(q) 134 | 135 | param_state = self.state[p] 136 | if 'mu' not in param_state: 137 | param_state['mu'] = torch.zeros_like(p) 138 | mu = param_state['mu'] 139 | mu.mul_(g['momentum']).add_(dp) 140 | p.add_(mu, alpha=-g['lr']) 141 | -------------------------------------------------------------------------------- /refTools/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | 3 | 4 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from refTools.evaluation.bleu.bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.items(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, refparam, eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | reflen, refmaxcounts = refparam[0], refparam[1] 65 | testlen, counts = precook(test, n, True) 66 | 67 | result = {} 68 | 69 | # Calculate effective reference sentence length. 70 | 71 | if eff == "closest": 72 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 73 | else: ## i.e., "average" or "shortest" or None 74 | result["reflen"] = reflen 75 | 76 | result["testlen"] = testlen 77 | 78 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 79 | 80 | result['correct'] = [0]*n 81 | for (ngram, count) in counts.items(): 82 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 83 | 84 | return result 85 | 86 | class BleuScorer(object): 87 | """Bleu scorer. 88 | """ 89 | 90 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 91 | # special_reflen is used in oracle (proportional effective ref len for a node). 
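# crefs holds cooked references from cook_refs(): (list of reference lengths, clipped max n-gram counts);
# ctest holds cooked hypotheses from cook_test(): dicts with 'testlen', 'reflen', 'guess' and 'correct' counts.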
92 | 93 | def copy(self): 94 | ''' copy the refs.''' 95 | new = BleuScorer(n=self.n) 96 | new.ctest = copy.copy(self.ctest) 97 | new.crefs = copy.copy(self.crefs) 98 | new._score = None 99 | return new 100 | 101 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 102 | ''' singular instance ''' 103 | 104 | self.n = n 105 | self.crefs = [] 106 | self.ctest = [] 107 | self.cook_append(test, refs) 108 | self.special_reflen = special_reflen 109 | 110 | def cook_append(self, test, refs): 111 | '''called by constructor and __iadd__ to avoid creating new instances.''' 112 | 113 | if refs is not None: 114 | self.crefs.append(cook_refs(refs)) 115 | if test is not None: 116 | cooked_test = cook_test(test, self.crefs[-1]) 117 | self.ctest.append(cooked_test) ## N.B.: -1 118 | else: 119 | self.ctest.append(None) # lens of crefs and ctest have to match 120 | 121 | self._score = None ## need to recompute 122 | 123 | def ratio(self, option=None): 124 | self.compute_score(option=option) 125 | return self._ratio 126 | 127 | def score_ratio(self, option=None): 128 | '''return (bleu, len_ratio) pair''' 129 | return (self.fscore(option=option), self.ratio(option=option)) 130 | 131 | def score_ratio_str(self, option=None): 132 | return "%.4f (%.2f)" % self.score_ratio(option) 133 | 134 | def reflen(self, option=None): 135 | self.compute_score(option=option) 136 | return self._reflen 137 | 138 | def testlen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._testlen 141 | 142 | def retest(self, new_test): 143 | if type(new_test) is str: 144 | new_test = [new_test] 145 | assert len(new_test) == len(self.crefs), new_test 146 | self.ctest = [] 147 | for t, rs in zip(new_test, self.crefs): 148 | self.ctest.append(cook_test(t, rs)) 149 | self._score = None 150 | 151 | return self 152 | 153 | def rescore(self, new_test): 154 | ''' replace test(s) with new test(s), and returns the new score.''' 155 | 156 | return self.retest(new_test).compute_score() 157 | 158 | def size(self): 159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 160 | return len(self.crefs) 161 | 162 | def __iadd__(self, other): 163 | '''add an instance (e.g., from another sentence).''' 164 | 165 | if type(other) is tuple: 166 | ## avoid creating new BleuScorer instances 167 | self.cook_append(other[0], other[1]) 168 | else: 169 | assert self.compatible(other), "incompatible BLEUs." 
170 | self.ctest.extend(other.ctest) 171 | self.crefs.extend(other.crefs) 172 | self._score = None ## need to recompute 173 | 174 | return self 175 | 176 | def compatible(self, other): 177 | return isinstance(other, BleuScorer) and self.n == other.n 178 | 179 | def single_reflen(self, option="average"): 180 | return self._single_reflen(self.crefs[0][0], option) 181 | 182 | def _single_reflen(self, reflens, option=None, testlen=None): 183 | 184 | if option == "shortest": 185 | reflen = min(reflens) 186 | elif option == "average": 187 | reflen = float(sum(reflens))/len(reflens) 188 | elif option == "closest": 189 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 190 | else: 191 | assert False, "unsupported reflen option %s" % option 192 | 193 | return reflen 194 | 195 | def recompute_score(self, option=None, verbose=0): 196 | self._score = None 197 | return self.compute_score(option, verbose) 198 | 199 | def compute_score(self, option=None, verbose=0): 200 | n = self.n 201 | small = 1e-9 202 | tiny = 1e-15 ## so that if guess is 0 still return 0 203 | bleu_list = [[] for _ in range(n)] 204 | 205 | if self._score is not None: 206 | return self._score 207 | 208 | if option is None: 209 | option = "average" if len(self.crefs) == 1 else "closest" 210 | 211 | self._testlen = 0 212 | self._reflen = 0 213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 214 | 215 | # for each sentence 216 | for comps in self.ctest: 217 | testlen = comps['testlen'] 218 | self._testlen += testlen 219 | 220 | if self.special_reflen is None: ## need computation 221 | reflen = self._single_reflen(comps['reflen'], option, testlen) 222 | else: 223 | reflen = self.special_reflen 224 | 225 | self._reflen += reflen 226 | 227 | for key in ['guess','correct']: 228 | for k in range(n): 229 | totalcomps[key][k] += comps[key][k] 230 | 231 | # append per image bleu score 232 | bleu = 1. 233 | for k in range(n): 234 | bleu *= (float(comps['correct'][k]) + tiny) \ 235 | /(float(comps['guess'][k]) + small) 236 | bleu_list[k].append(bleu ** (1./(k+1))) 237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 238 | if ratio < 1: 239 | for k in range(n): 240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 241 | 242 | if verbose > 1: 243 | print(comps, reflen) 244 | 245 | totalcomps['reflen'] = self._reflen 246 | totalcomps['testlen'] = self._testlen 247 | 248 | bleus = [] 249 | bleu = 1. 
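# Corpus-level BLEU_k below: running product of smoothed clipped n-gram precisions taken to the
# 1/(k+1) power (k is 0-indexed, so this is the geometric mean), then scaled by the brevity
# penalty exp(1 - 1/ratio) when the total hypothesis length falls short of the effective reference length.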
250 | for k in range(n): 251 | bleu *= float(totalcomps['correct'][k] + tiny) \ 252 | / (totalcomps['guess'][k] + small) 253 | bleus.append(bleu ** (1./(k+1))) 254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 255 | if ratio < 1: 256 | for k in range(n): 257 | bleus[k] *= math.exp(1 - 1/ratio) 258 | 259 | if verbose > 0: 260 | print(totalcomps) 261 | print("ratio:", ratio) 262 | 263 | self._score = bleus 264 | return self._score, bleu_list 265 | -------------------------------------------------------------------------------- /refTools/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from refTools.evaluation.cider.cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" -------------------------------------------------------------------------------- /refTools/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 
16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occurring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams.
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure cosine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 |
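A minimal usage sketch for the coco-caption style scorers above, which RefEvaluation composes (this snippet is not part of the repository; the ids and captions are invented for illustration). Bleu and Cider share the same interface: compute_score takes two dicts keyed by the same ids, each mapping an id to a list of tokenized sentences, with exactly one hypothesis sentence per id in res.

from refTools.evaluation.bleu.bleu import Bleu
from refTools.evaluation.cider.cider import Cider

# toy ground truths and results, keyed by the same (made-up) ids
gts = {0: ['a man riding a horse', 'a person rides a horse'],
       1: ['two dogs play in the snow']}
res = {0: ['a man rides a horse'],
       1: ['dogs playing in snow']}

bleu_scores, _ = Bleu(4).compute_score(gts, res)   # [BLEU-1, BLEU-2, BLEU-3, BLEU-4]
cider_score, _ = Cider().compute_score(gts, res)   # corpus-level CIDEr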
-------------------------------------------------------------------------------- /refTools/evaluation/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /refTools/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', '-norm'] 20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 21 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 22 | stdin=subprocess.PIPE, \ 23 | stdout=subprocess.PIPE, \ 24 | stderr=subprocess.PIPE) 25 | # Used to guarantee thread safety 26 | self.lock = threading.Lock() 27 | 28 | def compute_score(self, gts, res): 29 | assert(gts.keys() == res.keys()) 30 | imgIds = gts.keys() 31 | scores = [] 32 | 33 | eval_line = 'EVAL' 34 | self.lock.acquire() 35 | for i in imgIds: 36 | assert(len(res[i]) == 1) 37 | stat = self._stat(res[i][0], gts[i]) 38 | eval_line += ' ||| {}'.format(stat) 39 | 40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 41 | for i in range(0,len(imgIds)): 42 | scores.append(float(self.meteor_p.stdout.readline().strip())) 43 | score = float(self.meteor_p.stdout.readline().strip()) 44 | self.lock.release() 45 | 46 | return score, scores 47 | 48 | def method(self): 49 | return "METEOR" 50 | 51 | def _stat(self, hypothesis_str, reference_list): 52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 53 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 55 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 56 | return self.meteor_p.stdout.readline().decode().strip() 57 | 58 | def _score(self, hypothesis_str, reference_list): 59 | self.lock.acquire() 60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 61 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 62 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 63 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 64 | stats = self.meteor_p.stdout.readline().strip() 65 | eval_line = 'EVAL ||| {}'.format(stats) 66 | # EVAL ||| stats 67 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 68 | score = float(self.meteor_p.stdout.readline().strip()) 69 | self.lock.release() 70 | return score 71 | 72 | def __exit__(self): 73 | self.lock.acquire() 74 | self.meteor_p.stdin.close() 75 | self.meteor_p.wait() 76 | self.lock.release() 77 | -------------------------------------------------------------------------------- /refTools/evaluation/readme.txt: -------------------------------------------------------------------------------- 1 | This folder contains modified coco-caption evaluation, which is downloaded from 
https://github.com/tylin/coco-caption.git 2 | and refEvaluation which is to be called by the refer algorithm. 3 | 4 | More specifically, this folder contains: 5 | 1. bleu/ 6 | 2. cider/ 7 | 3. meteor/ 8 | 4. rouge/ 9 | 5. tokenizer/ 10 | 6. __init__.py 11 | 7. refEvaluation.py 12 | -------------------------------------------------------------------------------- /refTools/evaluation/refEvaluation.py: -------------------------------------------------------------------------------- 1 | from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer 2 | from refTools.evaluation.bleu.bleu import Bleu 3 | from refTools.evaluation.meteor.meteor import Meteor 4 | from refTools.evaluation.rouge.rouge import Rouge 5 | from refTools.evaluation.cider.cider import Cider 6 | 7 | """ 8 | Input: refer and Res = [{ref_id, sent}] 9 | 10 | Things of interest 11 | evalRefs - list of ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR'] 12 | eval - dict of {metric: score} 13 | refToEval - dict of {ref_id: ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']} 14 | """ 15 | 16 | class RefEvaluation: 17 | def __init__ (self, refer, Res): 18 | """ 19 | :param refer: refer class of current dataset 20 | :param Res: [{'ref_id', 'sent'}] 21 | """ 22 | self.evalRefs = [] 23 | self.eval = {} 24 | self.refToEval = {} 25 | self.refer = refer 26 | self.Res = Res 27 | 28 | def evaluate(self): 29 | 30 | evalRefIds = [ann['ref_id'] for ann in self.Res] 31 | 32 | refToGts = {} 33 | for ref_id in evalRefIds: 34 | ref = self.refer.Refs[ref_id] 35 | gt_sents = [sent['sent'].encode('ascii', 'ignore').decode('ascii') for sent in ref['sentences']] # up to 3 expressions 36 | refToGts[ref_id] = gt_sents 37 | refToRes = {ann['ref_id']: [ann['sent']] for ann in self.Res} 38 | 39 | print('tokenization...') 40 | tokenizer = PTBTokenizer() 41 | self.refToRes = tokenizer.tokenize(refToRes) 42 | self.refToGts = tokenizer.tokenize(refToGts) 43 | 44 | # ================================================= 45 | # Set up scorers 46 | # ================================================= 47 | print('setting up scorers...') 48 | scorers = [ 49 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 50 | (Meteor(),"METEOR"), 51 | (Rouge(), "ROUGE_L"), 52 | (Cider(), "CIDEr") 53 | ] 54 | 55 | # ================================================= 56 | # Compute scores 57 | # ================================================= 58 | for scorer, method in scorers: 59 | print('computing %s score...'%(scorer.method())) 60 | score, scores = scorer.compute_score(self.refToGts, self.refToRes) 61 | if type(method) == list: 62 | for sc, scs, m in zip(score, scores, method): 63 | self.setEval(sc, m) 64 | self.setRefToEvalRefs(scs, self.refToGts.keys(), m) 65 | print("%s: %0.3f"%(m, sc)) 66 | else: 67 | self.setEval(score, method) 68 | self.setRefToEvalRefs(scores, self.refToGts.keys(), method) 69 | print("%s: %0.3f"%(method, score)) 70 | self.setEvalRefs() 71 | 72 | def setEval(self, score, method): 73 | self.eval[method] = score 74 | 75 | def setRefToEvalRefs(self, scores, refIds, method): 76 | for refId, score in zip(refIds, scores): 77 | if not refId in self.refToEval: 78 | self.refToEval[refId] = {} 79 | self.refToEval[refId]["ref_id"] = refId 80 | self.refToEval[refId][method] = score 81 | 82 | def setEvalRefs(self): 83 | self.evalRefs = [eval for refId, eval in self.refToEval.items()] 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | import os.path as osp 89 | import sys 90 | ROOT_DIR = 
osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) 91 | sys.path.insert(0, osp.join(ROOT_DIR, 'lib', 'datasets')) 92 | from refer import REFER 93 | 94 | # load refer of dataset 95 | dataset = 'refcoco' 96 | refer = REFER(dataset, splitBy = 'google') 97 | 98 | # mimic some Res 99 | val_refIds = refer.getRefIds(split='test') 100 | ref_id = 49767 101 | print("GD: %s" % refer.Refs[ref_id]['sentences']) 102 | Res = [{'ref_id': ref_id, 'sent': 'left bottle'}] 103 | 104 | # evaluate some refer expressions 105 | refEval = RefEvaluation(refer, Res) 106 | refEval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in refEval.eval.items(): 110 | print('%s: %.3f'%(metric, score)) 111 | 112 | # demo how to use evalImgs to retrieve low score result 113 | # evals = [eva for eva in refEval.evalRefs if eva['CIDEr']<30] 114 | # print 'ground truth sents' 115 | # refId = evals[0]['ref_id'] 116 | # print 'refId: %s' % refId 117 | # print [sent['sent'] for sent in refer.Refs[refId]['sentences']] 118 | # 119 | # print 'generated sent (CIDEr score %0.1f)' % (evals[0]['CIDEr']) 120 | 121 | # print refEval.refToEval[8] 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | 
assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
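# Note : Like the METEOR wrapper above, this tokenizer shells out to Java, so a JRE must be
#        available and the bundled stanford-corenlp jar runnable from this directory.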
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 44 | tmp_file.write(sentences.encode()) 45 | tmp_file.close() 46 | 47 | # ====================================================== 48 | # tokenize sentence 49 | # ====================================================== 50 | cmd.append(os.path.basename(tmp_file.name)) 51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 52 | stdout=subprocess.PIPE) 53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 54 | token_lines = token_lines.decode() 55 | lines = token_lines.split('\n') 56 | # remove temp file 57 | os.remove(tmp_file.name) 58 | 59 | # ====================================================== 60 | # create dictionary for tokenized captions 61 | # ====================================================== 62 | for k, line in zip(image_id, lines): 63 | if not k in final_tokenized_captions_for_image: 64 | final_tokenized_captions_for_image[k] = [] 65 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 66 | if w not in PUNCTUATIONS]) 67 | final_tokenized_captions_for_image[k].append(tokenized_caption) 68 | 69 | return final_tokenized_captions_for_image 70 | -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/tmp82iqkuu0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/tokenizer/tmp82iqkuu0 -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/tmpn19wmqte: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/tokenizer/tmpn19wmqte -------------------------------------------------------------------------------- /refTools/refer_python3.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | 3 | """ 4 | This interface provides access to four datasets: 5 | 1) refclef 6 | 2) refcoco 7 | 3) refcoco+ 8 | 4) refcocog 9 | split by unc and google 10 | 11 | The following API functions are defined: 12 | REFER - REFER api class 13 | getRefIds - get ref ids that satisfy given filter conditions. 14 | getAnnIds - get ann ids that satisfy given filter conditions. 15 | getImgIds - get image ids that satisfy given filter conditions. 16 | getCatIds - get category ids that satisfy given filter conditions. 17 | loadRefs - load refs with the specified ref ids. 18 | loadAnns - load anns with the specified ann ids. 19 | loadImgs - load images with the specified image ids. 20 | loadCats - load category names with the specified category ids. 21 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id 22 | """ 23 | 24 | import sys 25 | import os.path as osp 26 | import json 27 | import _pickle as pickle 28 | import time 29 | import itertools 30 | import skimage.io as io 31 | import matplotlib.pyplot as plt 32 | from matplotlib.collections import PatchCollection 33 | from matplotlib.patches import Polygon, Rectangle 34 | from pprint import pprint 35 | import numpy as np 36 | # import cv2 37 | # from skimage.measure import label, regionprops 38 | 39 | class REFER: 40 | 41 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'): 42 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog 43 | # also provide dataset name and splitBy information 44 | # e.g., dataset = 'refcoco', splitBy = 'unc' 45 | print('loading dataset %s into memory...' 
% dataset) 46 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) 47 | self.DATA_DIR = osp.join(data_root, dataset) 48 | if dataset in ['refcoco', 'refcoco+', 'refcocog']: 49 | self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014') 50 | elif dataset == 'refclef': 51 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12') 52 | else: 53 | print('No refer dataset is called [%s]' % dataset) 54 | sys.exit() 55 | 56 | # load refs from data/dataset/refs(dataset).json 57 | tic = time.time() 58 | ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p') 59 | self.data = {} 60 | self.data['dataset'] = dataset 61 | self.data['refs'] = pickle.load(open(ref_file, 'rb')) 62 | 63 | # load annotations from data/dataset/instances.json 64 | instances_file = osp.join(self.DATA_DIR, 'instances.json') 65 | instances = json.load(open(instances_file, 'r')) 66 | self.data['images'] = instances['images'] 67 | self.data['annotations'] = instances['annotations'] 68 | self.data['categories'] = instances['categories'] 69 | 70 | # create index 71 | self.createIndex() 72 | print('DONE (t=%.2fs)' % (time.time()-tic)) 73 | 74 | def createIndex(self): 75 | # create sets of mapping 76 | # 1) Refs: {ref_id: ref} 77 | # 2) Anns: {ann_id: ann} 78 | # 3) Imgs: {image_id: image} 79 | # 4) Cats: {category_id: category_name} 80 | # 5) Sents: {sent_id: sent} 81 | # 6) imgToRefs: {image_id: refs} 82 | # 7) imgToAnns: {image_id: anns} 83 | # 8) refToAnn: {ref_id: ann} 84 | # 9) annToRef: {ann_id: ref} 85 | # 10) catToRefs: {category_id: refs} 86 | # 11) sentToRef: {sent_id: ref} 87 | # 12) sentToTokens: {sent_id: tokens} 88 | print('creating index...') 89 | # fetch info from instances 90 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 91 | for ann in self.data['annotations']: 92 | Anns[ann['id']] = ann 93 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann] 94 | for img in self.data['images']: 95 | Imgs[img['id']] = img 96 | for cat in self.data['categories']: 97 | Cats[cat['id']] = cat['name'] 98 | 99 | # fetch info from refs 100 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 101 | Sents, sentToRef, sentToTokens = {}, {}, {} 102 | for ref in self.data['refs']: 103 | # ids 104 | ref_id = ref['ref_id'] 105 | ann_id = ref['ann_id'] 106 | category_id = ref['category_id'] 107 | image_id = ref['image_id'] 108 | 109 | # add mapping related to ref 110 | Refs[ref_id] = ref 111 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 112 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 113 | refToAnn[ref_id] = Anns[ann_id] 114 | annToRef[ann_id] = ref 115 | 116 | # add mapping of sent 117 | for sent in ref['sentences']: 118 | Sents[sent['sent_id']] = sent 119 | sentToRef[sent['sent_id']] = ref 120 | sentToTokens[sent['sent_id']] = sent['tokens'] 121 | 122 | # create class members 123 | self.Refs = Refs 124 | self.Anns = Anns 125 | self.Imgs = Imgs 126 | self.Cats = Cats 127 | self.Sents = Sents 128 | self.imgToRefs = imgToRefs 129 | self.imgToAnns = imgToAnns 130 | self.refToAnn = refToAnn 131 | self.annToRef = annToRef 132 | self.catToRefs = catToRefs 133 | self.sentToRef = sentToRef 134 | self.sentToTokens = sentToTokens 135 | print('index created.') 136 | 137 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 138 | image_ids = image_ids if type(image_ids) == list else [image_ids] 139 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 140 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 141 | 142 | if 
len(image_ids)==len(cat_ids)==len(ref_ids)==len(split)==0: 143 | refs = self.data['refs'] 144 | else: 145 | if not len(image_ids) == 0: 146 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 147 | else: 148 | refs = self.data['refs'] 149 | if not len(cat_ids) == 0: 150 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 151 | if not len(ref_ids) == 0: 152 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 153 | if not len(split) == 0: 154 | if split in ['testA', 'testB', 'testC']: 155 | refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ... 156 | elif split in ['testAB', 'testBC', 'testAC']: 157 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess... 158 | elif split == 'test': 159 | refs = [ref for ref in refs if 'test' in ref['split']] 160 | elif split == 'train' or split == 'val': 161 | refs = [ref for ref in refs if ref['split'] == split] 162 | else: 163 | print('No such split [%s]' % split) 164 | sys.exit() 165 | ref_ids = [ref['ref_id'] for ref in refs] 166 | return ref_ids 167 | 168 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): 169 | image_ids = image_ids if type(image_ids) == list else [image_ids] 170 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 171 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 172 | 173 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 174 | ann_ids = [ann['id'] for ann in self.data['annotations']] 175 | else: 176 | if not len(image_ids) == 0: 177 | lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns] 178 | anns = list(itertools.chain.from_iterable(lists)) 179 | else: 180 | anns = self.data['annotations'] 181 | if not len(cat_ids) == 0: 182 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 183 | ann_ids = [ann['id'] for ann in anns] 184 | if not len(ref_ids) == 0: 185 | ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) 186 | return ann_ids 187 | 188 | def getImgIds(self, ref_ids=[]): 189 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 190 | 191 | if not len(ref_ids) == 0: 192 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 193 | else: 194 | image_ids = self.Imgs.keys() 195 | return image_ids 196 | 197 | def getCatIds(self): 198 | return self.Cats.keys() 199 | 200 | def loadRefs(self, ref_ids=[]): 201 | if type(ref_ids) == list: 202 | return [self.Refs[ref_id] for ref_id in ref_ids] 203 | elif type(ref_ids) == int: 204 | return [self.Refs[ref_ids]] 205 | 206 | def loadAnns(self, ann_ids=[]): 207 | if type(ann_ids) == list: 208 | return [self.Anns[ann_id] for ann_id in ann_ids] 209 | elif type(ann_ids) == int or type(ann_ids) == unicode: 210 | return [self.Anns[ann_ids]] 211 | 212 | def loadImgs(self, image_ids=[]): 213 | if type(image_ids) == list: 214 | return [self.Imgs[image_id] for image_id in image_ids] 215 | elif type(image_ids) == int: 216 | return [self.Imgs[image_ids]] 217 | 218 | def loadCats(self, cat_ids=[]): 219 | if type(cat_ids) == list: 220 | return [self.Cats[cat_id] for cat_id in cat_ids] 221 | elif type(cat_ids) == int: 222 | return [self.Cats[cat_ids]] 223 | 224 | def getRefBox(self, ref_id): 225 | ref = self.Refs[ref_id] 226 | ann = self.refToAnn[ref_id] 227 | return ann['bbox'] # [x, y, w, h] 228 | 229 | 230 | 231 | if __name__ == '__main__': 232 | refer = REFER(dataset='refcocog', splitBy='google') 233 | ref_ids = refer.getRefIds() 234 | 
print(len(ref_ids)) 235 | 236 | print(len(refer.Imgs)) 237 | print(len(refer.imgToRefs)) 238 | 239 | ref_ids = refer.getRefIds(split='train') 240 | print('There are %s training referred objects.' % len(ref_ids)) 241 | 242 | for ref_id in ref_ids: 243 | ref = refer.loadRefs(ref_id)[0] 244 | if len(ref['sentences']) < 2: 245 | continue 246 | 247 | pprint(ref) 248 | print('The label is %s.' % refer.Cats[ref['category_id']]) 249 | plt.figure() 250 | refer.showRef(ref, seg_box='box') 251 | plt.show() 252 | 253 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.9 2 | transformers==4.12.5 3 | ruamel_yaml 4 | opencv-python 5 | scikit-image 6 | matplotlib 7 | pycocotools 8 | pycocoevalcap 9 | datasets 10 | sentencepiece 11 | accelerate 12 | scikit-learn 13 | einops 14 | einops_exts -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | 4 | def create_scheduler(args, optimizer): 5 | if 'num_training_steps' not in args: 6 | args['num_training_steps'] = args['epochs'] * args['step_per_epoch'] 7 | print("### num_training_steps, ", args['num_training_steps'], flush=True) 8 | 9 | if 'min_rate' not in args: 10 | args['min_rate'] = 0.0 11 | 12 | if isinstance(args['num_warmup_steps'], float): 13 | assert 0 <= args['num_warmup_steps'] < 1 14 | args['num_warmup_steps'] = int(args['num_training_steps'] * args['num_warmup_steps']) 15 | print("### num_warmup_steps, ", args['num_warmup_steps'], flush=True) 16 | 17 | if args.sched == 'linear': 18 | def lr_lambda(current_step: int): 19 | if current_step < args.num_warmup_steps: 20 | return float(current_step) / float(max(1, args.num_warmup_steps)) 21 | return max( 22 | args['min_rate'], float(args.num_training_steps - (1-args['min_rate'])*current_step) / float( 23 | max(1, args.num_training_steps - args.num_warmup_steps)) 24 | ) 25 | 26 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) 27 | 28 | else: 29 | raise NotImplementedError(f"args.sched == {args.sched}") 30 | 31 | return lr_scheduler 32 | -------------------------------------------------------------------------------- /utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 2 | # Github: https://github.com/zengyan-97/X-VLM 3 | # Copyright (c) 2022, ByteDance Inc. 4 | # All rights reserved. 
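# The Checkpointer below writes model_state_step_{step}.th when called with step > 0, and otherwise
# model_state_epoch_{epoch}.th together with training_state_latest.th; saving goes through
# utils.torch_io.save, so serialization_dir may be a local directory or an HDFS path.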
5 | 6 | from typing import Union, Dict, List, Tuple, Any, Callable 7 | import logging 8 | import os 9 | import re 10 | import time 11 | 12 | import torch 13 | 14 | from utils.hdfs_io import hexists, hmkdir, hcopy 15 | from utils.torch_io import save as hdfs_torch_save 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class Checkpointer: 20 | def __init__(self, 21 | serialization_dir: str = ".output") -> None: 22 | self._serialization_dir = serialization_dir 23 | if not hexists(self._serialization_dir): 24 | hmkdir(self._serialization_dir) 25 | 26 | def save_checkpoint(self, 27 | epoch: Union[int, str], 28 | model_state: Dict[str, Any], 29 | training_states: Dict[str, Any], 30 | step: int = -1) -> None: 31 | """ 32 | Save ckpt to local or HDFS 33 | """ 34 | if step > 0: 35 | model_path = os.path.join( 36 | self._serialization_dir, "model_state_step_{}.th".format(step)) 37 | hdfs_torch_save(model_state, model_path) 38 | 39 | else: 40 | model_path = os.path.join( 41 | self._serialization_dir, "model_state_epoch_{}.th".format(epoch)) 42 | 43 | training_path = os.path.join(self._serialization_dir, 44 | "training_state_latest.th") 45 | hdfs_torch_save(model_state, model_path) 46 | hdfs_torch_save({**training_states, "epoch": epoch}, training_path) 47 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # 4 | # Description: Describes the class to compute the CIDEr 5 | # (Consensus-Based Image Description Evaluation) Metric 6 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 7 | # 8 | # Creation Date: Sun Feb 8 14:16:54 2015 9 | # 10 | # Authors: Ramakrishna Vedantam and 11 | # Tsung-Yi Lin 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | from .cider_scorer import CiderScorer 17 | 18 | 19 | class Cider: 20 | """ 21 | Main Class to compute the CIDEr metric 22 | 23 | """ 24 | def __init__(self, n=4, df="corpus"): 25 | """ 26 | Initialize the CIDEr scoring function 27 | : param n (int): n-gram size 28 | : param df (string): specifies where to get the IDF values from 29 | takes values 'corpus', 'coco-train' 30 | : return: None 31 | """ 32 | # set cider to sum over 1 to 4-grams 33 | self._n = n 34 | self._df = df 35 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 36 | 37 | def compute_score(self, gts, res): 38 | """ 39 | Main function to compute CIDEr score 40 | : param gts (dict) : {image:tokenized reference sentence} 41 | : param res (dict) : {image:tokenized candidate sentence} 42 | : return: cider (float) : computed CIDEr score for the corpus 43 | """ 44 | 45 | # clear all the previous hypos and refs 46 | self.cider_scorer.clear() 47 | 48 | for res_id in res: 49 | 50 | hypo = res_id['caption'] 51 | ref = gts[res_id['image_id']] 52 | 53 | # Sanity check. 
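# (Input format differs from refTools/evaluation/cider: here res is a list of
# {'image_id': ..., 'caption': [...]} entries and gts maps image_id to its list of reference captions.)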
54 | assert(type(hypo) is list) 55 | assert(len(hypo) == 1) 56 | assert(type(ref) is list) 57 | assert(len(ref) > 0) 58 | self.cider_scorer += (hypo[0], ref) 59 | 60 | (score, scores) = self.cider_scorer.compute_score() 61 | 62 | return score, scores 63 | 64 | def method(self): 65 | return "CIDEr" 66 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | import six 10 | from six.moves import cPickle 11 | from collections import defaultdict 12 | import numpy as np 13 | import math 14 | import os 15 | 16 | def precook(s, n=4, out=False): 17 | """ 18 | Takes a string as input and returns an object that can be given to 19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 20 | can take string arguments as well. 21 | :param s: string : sentence to be converted into ngrams 22 | :param n: int : number of ngrams for which representation is calculated 23 | :return: term frequency vector for occuring ngrams 24 | """ 25 | words = s.split() 26 | counts = defaultdict(int) 27 | for k in range(1,n+1): 28 | for i in range(len(words)-k+1): 29 | ngram = tuple(words[i:i+k]) 30 | counts[ngram] += 1 31 | return counts 32 | 33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 34 | '''Takes a list of reference sentences for a single segment 35 | and returns an object that encapsulates everything that BLEU 36 | needs to know about them. 37 | :param refs: list of string : reference sentences for some image 38 | :param n: int : number of ngrams for which (ngram) representation is calculated 39 | :return: result (list of dict) 40 | ''' 41 | return [precook(ref, n) for ref in refs] 42 | 43 | def cook_test(test, n=4): 44 | '''Takes a test sentence and returns an object that 45 | encapsulates everything that BLEU needs to know about it. 46 | :param test: list of string : hypothesis sentence for some image 47 | :param n: int : number of ngrams for which (ngram) representation is calculated 48 | :return: result (dict) 49 | ''' 50 | return precook(test, n, True) 51 | 52 | class CiderScorer(object): 53 | """CIDEr scorer. 
54 | """ 55 | 56 | def copy(self): 57 | ''' copy the refs.''' 58 | new = CiderScorer(n=self.n) 59 | new.ctest = copy.copy(self.ctest) 60 | new.crefs = copy.copy(self.crefs) 61 | return new 62 | 63 | def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0): 64 | ''' singular instance ''' 65 | self.n = n 66 | self.sigma = sigma 67 | self.crefs = [] 68 | self.ctest = [] 69 | self.df_mode = df_mode 70 | self.ref_len = None 71 | if self.df_mode != "corpus": 72 | pkl_file = cPickle.load(open(os.path.join('data', df_mode + '.p'),'rb'), **(dict(encoding='latin1') if six.PY3 else {})) 73 | self.ref_len = np.log(float(pkl_file['ref_len'])) 74 | self.document_frequency = pkl_file['document_frequency'] 75 | self.cook_append(test, refs) 76 | 77 | def clear(self): 78 | self.crefs = [] 79 | self.ctest = [] 80 | 81 | def cook_append(self, test, refs): 82 | '''called by constructor and __iadd__ to avoid creating new instances.''' 83 | 84 | if refs is not None: 85 | self.crefs.append(cook_refs(refs)) 86 | if test is not None: 87 | self.ctest.append(cook_test(test)) ## N.B.: -1 88 | else: 89 | self.ctest.append(None) # lens of crefs and ctest have to match 90 | 91 | def size(self): 92 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 93 | return len(self.crefs) 94 | 95 | def __iadd__(self, other): 96 | '''add an instance (e.g., from another sentence).''' 97 | 98 | if type(other) is tuple: 99 | ## avoid creating new CiderScorer instances 100 | self.cook_append(other[0], other[1]) 101 | else: 102 | self.ctest.extend(other.ctest) 103 | self.crefs.extend(other.crefs) 104 | 105 | return self 106 | def compute_doc_freq(self): 107 | ''' 108 | Compute term frequency for reference data. 109 | This will be used to compute idf (inverse document frequency later) 110 | The term frequency is stored in the object 111 | :return: None 112 | ''' 113 | for refs in self.crefs: 114 | # refs, k ref captions of one image 115 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 116 | self.document_frequency[ngram] += 1 117 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 118 | 119 | def compute_cider(self): 120 | def counts2vec(cnts): 121 | """ 122 | Function maps counts of ngram to vector of tfidf weights. 123 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 124 | The n-th entry of array denotes length of n-grams. 125 | :param cnts: 126 | :return: vec (array of dict), norm (array of float), length (int) 127 | """ 128 | vec = [defaultdict(float) for _ in range(self.n)] 129 | length = 0 130 | norm = [0.0 for _ in range(self.n)] 131 | for (ngram,term_freq) in cnts.items(): 132 | # give word count 1 if it doesn't appear in reference corpus 133 | df = np.log(max(1.0, self.document_frequency[ngram])) 134 | # ngram index 135 | n = len(ngram)-1 136 | # tf (term_freq) * idf (precomputed idf) for n-grams 137 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 138 | # compute norm for the vector. the norm will be used for 139 | # computing similarity 140 | norm[n] += pow(vec[n][ngram], 2) 141 | 142 | if n == 1: 143 | length += term_freq 144 | norm = [np.sqrt(n) for n in norm] 145 | return vec, norm, length 146 | 147 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 148 | ''' 149 | Compute the cosine similarity of two vectors. 
150 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 151 | :param vec_ref: array of dictionary for vector corresponding to reference 152 | :param norm_hyp: array of float for vector corresponding to hypothesis 153 | :param norm_ref: array of float for vector corresponding to reference 154 | :param length_hyp: int containing length of hypothesis 155 | :param length_ref: int containing length of reference 156 | :return: array of score for each n-grams cosine similarity 157 | ''' 158 | delta = float(length_hyp - length_ref) 159 | # measure consine similarity 160 | val = np.array([0.0 for _ in range(self.n)]) 161 | for n in range(self.n): 162 | # ngram 163 | for (ngram,count) in vec_hyp[n].items(): 164 | val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram] 165 | 166 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 167 | val[n] /= (norm_hyp[n]*norm_ref[n]) 168 | 169 | assert(not math.isnan(val[n])) 170 | return val 171 | 172 | # compute log reference length 173 | if self.df_mode == "corpus": 174 | self.ref_len = np.log(float(len(self.crefs))) 175 | 176 | scores = [] 177 | for test, refs in zip(self.ctest, self.crefs): 178 | # compute vector for test captions 179 | vec, norm, length = counts2vec(test) 180 | # compute vector for ref captions 181 | score = np.array([0.0 for _ in range(self.n)]) 182 | for ref in refs: 183 | vec_ref, norm_ref, length_ref = counts2vec(ref) 184 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 185 | # change by vrama91 - mean of ngram scores, instead of sum 186 | score_avg = np.mean(score) 187 | # divide by number of references 188 | score_avg /= len(refs) 189 | # multiply score by 10 190 | score_avg *= 10.0 191 | # append score of an image to the score list 192 | scores.append(score_avg) 193 | return scores 194 | 195 | def compute_score(self, option=None, verbose=0): 196 | # compute idf 197 | if self.df_mode == "corpus": 198 | self.document_frequency = defaultdict(float) 199 | self.compute_doc_freq() 200 | # assert to check document frequency 201 | assert(len(self.ctest) >= max(self.document_frequency.values())) 202 | # import json for now and write the corresponding files 203 | # compute cider score 204 | score = self.compute_cider() 205 | # debug 206 | # print score 207 | return np.mean(np.array(score)), np.array(score) 208 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/ciderD/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/ciderD/ciderD.py: -------------------------------------------------------------------------------- 1 | # Filename: ciderD.py 2 | # 3 | # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | from .ciderD_scorer import CiderScorer 14 | import pdb 15 | 16 | class CiderD: 17 | """ 18 | Main Class to compute the CIDEr metric 19 | 20 | """ 21 | def __init__(self, n=4, sigma=6.0, df="corpus"): 22 | # set cider to sum over 1 to 4-grams 23 | self._n = n 24 | # set the standard deviation parameter for gaussian 
penalty 25 | self._sigma = sigma 26 | # set where to compute document frequencies from 27 | self._df = df 28 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) 29 | 30 | def compute_score(self, gts, res): 31 | """ 32 | Main function to compute CIDEr score 33 | :param gts (dict) : dictionary mapping image_id to a list of tokenized reference sentences 34 | res (list of dict) : each entry has keys 'image_id' and 'caption' (a single-element list holding the tokenized hypothesis) 35 | :return: cider (float) : computed CIDEr score for the corpus 36 | """ 37 | 38 | # clear all the previous hypos and refs 39 | tmp_cider_scorer = self.cider_scorer.copy_empty() 40 | tmp_cider_scorer.clear() 41 | for res_id in res: 42 | 43 | hypo = res_id['caption'] 44 | ref = gts[res_id['image_id']] 45 | 46 | # Sanity check. 47 | assert(type(hypo) is list) 48 | assert(len(hypo) == 1) 49 | assert(type(ref) is list) 50 | assert(len(ref) > 0) 51 | tmp_cider_scorer += (hypo[0], ref) 52 | 53 | (score, scores) = tmp_cider_scorer.compute_score() 54 | 55 | return score, scores 56 | 57 | def method(self): 58 | return "CIDEr-D" 59 | -------------------------------------------------------------------------------- /utils/cider/pyciderevalcap/ciderD/ciderD_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import copy 9 | from collections import defaultdict 10 | import numpy as np 11 | import pdb 12 | import math 13 | import six 14 | from six.moves import cPickle 15 | import os 16 | 17 | def precook(s, n=4, out=False): 18 | """ 19 | Takes a string as input and returns an object that can be given to 20 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 21 | can take string arguments as well. 22 | :param s: string : sentence to be converted into ngrams 23 | :param n: int : number of ngrams for which representation is calculated 24 | :return: term frequency vector for occurring ngrams 25 | """ 26 | words = s.split() 27 | counts = defaultdict(int) 28 | for k in range(1,n+1): 29 | for i in range(len(words)-k+1): 30 | ngram = tuple(words[i:i+k]) 31 | counts[ngram] += 1 32 | return counts 33 | 34 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 35 | '''Takes a list of reference sentences for a single segment 36 | and returns an object that encapsulates everything that BLEU 37 | needs to know about them. 38 | :param refs: list of string : reference sentences for some image 39 | :param n: int : number of ngrams for which (ngram) representation is calculated 40 | :return: result (list of dict) 41 | ''' 42 | return [precook(ref, n) for ref in refs] 43 | 44 | def cook_test(test, n=4): 45 | '''Takes a test sentence and returns an object that 46 | encapsulates everything that BLEU needs to know about it. 47 | :param test: list of string : hypothesis sentence for some image 48 | :param n: int : number of ngrams for which (ngram) representation is calculated 49 | :return: result (dict) 50 | ''' 51 | return precook(test, n, True) 52 | 53 | class CiderScorer(object): 54 | """CIDEr scorer.
55 | """ 56 | 57 | def copy(self): 58 | ''' copy the refs.''' 59 | new = CiderScorer(n=self.n) 60 | new.ctest = copy.copy(self.ctest) 61 | new.crefs = copy.copy(self.crefs) 62 | return new 63 | 64 | def copy_empty(self): 65 | new = CiderScorer(df_mode="corpus", n=self.n, sigma=self.sigma) 66 | new.df_mode = self.df_mode 67 | new.ref_len = self.ref_len 68 | new.document_frequency = self.document_frequency 69 | return new 70 | 71 | def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0): 72 | ''' singular instance ''' 73 | self.n = n 74 | self.sigma = sigma 75 | self.crefs = [] 76 | self.ctest = [] 77 | self.df_mode = df_mode 78 | self.ref_len = None 79 | if self.df_mode != "corpus": 80 | pkl_file = cPickle.load(open(df_mode,'rb'), **(dict(encoding='latin1') if six.PY3 else {})) 81 | self.ref_len = np.log(float(pkl_file['ref_len'])) 82 | self.document_frequency = pkl_file['document_frequency'] 83 | else: 84 | self.document_frequency = None 85 | self.cook_append(test, refs) 86 | 87 | def clear(self): 88 | self.crefs = [] 89 | self.ctest = [] 90 | 91 | def cook_append(self, test, refs): 92 | '''called by constructor and __iadd__ to avoid creating new instances.''' 93 | 94 | if refs is not None: 95 | self.crefs.append(cook_refs(refs)) 96 | if test is not None: 97 | self.ctest.append(cook_test(test)) ## N.B.: -1 98 | else: 99 | self.ctest.append(None) # lens of crefs and ctest have to match 100 | 101 | def size(self): 102 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 103 | return len(self.crefs) 104 | 105 | def __iadd__(self, other): 106 | '''add an instance (e.g., from another sentence).''' 107 | 108 | if type(other) is tuple: 109 | ## avoid creating new CiderScorer instances 110 | self.cook_append(other[0], other[1]) 111 | else: 112 | self.ctest.extend(other.ctest) 113 | self.crefs.extend(other.crefs) 114 | 115 | return self 116 | def compute_doc_freq(self): 117 | ''' 118 | Compute term frequency for reference data. 119 | This will be used to compute idf (inverse document frequency later) 120 | The term frequency is stored in the object 121 | :return: None 122 | ''' 123 | for refs in self.crefs: 124 | # refs, k ref captions of one image 125 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 126 | self.document_frequency[ngram] += 1 127 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 128 | 129 | def compute_cider(self): 130 | def counts2vec(cnts): 131 | """ 132 | Function maps counts of ngram to vector of tfidf weights. 133 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 134 | The n-th entry of array denotes length of n-grams. 135 | :param cnts: 136 | :return: vec (array of dict), norm (array of float), length (int) 137 | """ 138 | vec = [defaultdict(float) for _ in range(self.n)] 139 | length = 0 140 | norm = [0.0 for _ in range(self.n)] 141 | for (ngram,term_freq) in cnts.items(): 142 | # give word count 1 if it doesn't appear in reference corpus 143 | df = np.log(max(1.0, self.document_frequency[ngram])) 144 | # ngram index 145 | n = len(ngram)-1 146 | # tf (term_freq) * idf (precomputed idf) for n-grams 147 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 148 | # compute norm for the vector. 
the norm will be used for computing similarity 149 | norm[n] += pow(vec[n][ngram], 2) 150 | 151 | if n == 1: 152 | length += term_freq 153 | norm = [np.sqrt(n) for n in norm] 154 | return vec, norm, length 155 | 156 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 157 | ''' 158 | Compute the cosine similarity of two vectors. 159 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 160 | :param vec_ref: array of dictionary for vector corresponding to reference 161 | :param norm_hyp: array of float for vector corresponding to hypothesis 162 | :param norm_ref: array of float for vector corresponding to reference 163 | :param length_hyp: int containing length of hypothesis 164 | :param length_ref: int containing length of reference 165 | :return: array of score for each n-grams cosine similarity 166 | ''' 167 | delta = float(length_hyp - length_ref) 168 | # measure consine similarity 169 | val = np.array([0.0 for _ in range(self.n)]) 170 | for n in range(self.n): 171 | # ngram 172 | for (ngram,count) in vec_hyp[n].items(): 173 | # vrama91 : added clipping 174 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 175 | 176 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 177 | val[n] /= (norm_hyp[n]*norm_ref[n]) 178 | 179 | assert(not math.isnan(val[n])) 180 | # vrama91: added a length based gaussian penalty 181 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 182 | return val 183 | 184 | # compute log reference length 185 | if self.df_mode == "corpus": 186 | self.ref_len = np.log(float(len(self.crefs))) 187 | #elif self.df_mode == "coco-val-df": 188 | # if coco option selected, use length of coco-val set 189 | # self.ref_len = np.log(float(40504)) 190 | 191 | scores = [] 192 | for test, refs in zip(self.ctest, self.crefs): 193 | # compute vector for test captions 194 | vec, norm, length = counts2vec(test) 195 | # compute vector for ref captions 196 | score = np.array([0.0 for _ in range(self.n)]) 197 | for ref in refs: 198 | vec_ref, norm_ref, length_ref = counts2vec(ref) 199 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 200 | # change by vrama91 - mean of ngram scores, instead of sum 201 | score_avg = np.mean(score) 202 | # divide by number of references 203 | score_avg /= len(refs) 204 | # multiply score by 10 205 | score_avg *= 10.0 206 | # append score of an image to the score list 207 | scores.append(score_avg) 208 | return scores 209 | 210 | def compute_score(self, option=None, verbose=0): 211 | # compute idf 212 | if self.df_mode == "corpus": 213 | self.document_frequency = defaultdict(float) 214 | self.compute_doc_freq() 215 | # assert to check document frequency 216 | assert(len(self.ctest) >= max(self.document_frequency.values())) 217 | # import json for now and write the corresponding files 218 | # compute cider score 219 | score = self.compute_cider() 220 | # debug 221 | # print score 222 | return np.mean(np.array(score)), np.array(score) 223 | -------------------------------------------------------------------------------- /utils/hdfs_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 4 | # Github: https://github.com/zengyan-97/X-VLM 5 | # Copyright (c) 2022, ByteDance Inc. 6 | # All rights reserved. 
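# Usage sketch (illustrative only): the helpers below shell out to the `hdfs dfs` CLI via
# subprocess, so HADOOP_BIN must first point at a working hdfs binary; the hdfs:// paths
# in this sketch are placeholders.
#
#     if hexists("hdfs://bucket/data/train.jsonl"):
#         with hopen("hdfs://bucket/data/train.jsonl", "r") as f:
#             n_lines = sum(1 for _ in f)  # hopen in "r" mode yields raw bytes lines
#     for path in hlist_files(["hdfs://bucket/data/"]):
#         print(path)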
7 | 8 | import sys 9 | from typing import IO, Any, List 10 | 11 | import shutil 12 | import subprocess 13 | from contextlib import contextmanager 14 | import os 15 | import glob 16 | import threading 17 | 18 | HADOOP_BIN = 'HADOOP_ROOT_LOGGER=ERROR,console /SET/PATH/TO/hadoop/bin/hdfs' 19 | 20 | __all__ = ['hlist_files', 'hopen', 'hexists', 'hmkdir'] 21 | 22 | 23 | @contextmanager # type: ignore 24 | def hopen(hdfs_path: str, mode: str = "r") -> IO[Any]: 25 | """ 26 | open a file on hdfs with contextmanager. 27 | 28 | Args: 29 | mode (str): supports ["r", "w", "wa"] 30 | """ 31 | pipe = None 32 | if mode.startswith("r"): 33 | pipe = subprocess.Popen( 34 | "{} dfs -text {}".format(HADOOP_BIN, hdfs_path), shell=True, stdout=subprocess.PIPE) 35 | yield pipe.stdout 36 | pipe.stdout.close() # type: ignore 37 | pipe.wait() 38 | return 39 | if mode == "wa" or mode == "a": 40 | pipe = subprocess.Popen( 41 | "{} dfs -appendToFile - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE) 42 | yield pipe.stdin 43 | pipe.stdin.close() # type: ignore 44 | pipe.wait() 45 | return 46 | if mode.startswith("w"): 47 | pipe = subprocess.Popen( 48 | "{} dfs -put -f - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE) 49 | yield pipe.stdin 50 | pipe.stdin.close() # type: ignore 51 | pipe.wait() 52 | return 53 | raise RuntimeError("unsupported io mode: {}".format(mode)) 54 | 55 | 56 | def hlist_files(folders: List[str]) -> List[str]: 57 | files = [] 58 | for folder in folders: 59 | if folder.startswith('hdfs'): 60 | pipe = subprocess.Popen("{} dfs -ls {}".format(HADOOP_BIN, folder), shell=True, 61 | stdout=subprocess.PIPE) 62 | # output, _ = pipe.communicate() 63 | for line in pipe.stdout: # type: ignore 64 | line = line.strip() 65 | # drwxr-xr-x - user group 4 file 66 | if len(line.split()) < 5: 67 | continue 68 | files.append(line.split()[-1].decode("utf8")) 69 | pipe.stdout.close() # type: ignore 70 | pipe.wait() 71 | else: 72 | if os.path.isdir(folder): 73 | files.extend([os.path.join(folder, d) for d in os.listdir(folder)]) 74 | elif os.path.isfile(folder): 75 | files.append(folder) 76 | else: 77 | print('Path {} is invalid'.format(folder)) 78 | sys.stdout.flush() 79 | 80 | return files 81 | 82 | 83 | def hexists(file_path: str) -> bool: 84 | """ hdfs capable to check whether a file_path is exists """ 85 | if file_path.startswith('hdfs'): 86 | return os.system("{} dfs -test -e {}".format(HADOOP_BIN, file_path)) == 0 87 | return os.path.exists(file_path) 88 | 89 | 90 | def hmkdir(file_path: str) -> bool: 91 | """ hdfs mkdir """ 92 | if file_path.startswith('hdfs'): 93 | os.system("{} dfs -mkdir -p {}".format(HADOOP_BIN, file_path)) # exist ok 94 | else: 95 | if not os.path.exists(file_path): 96 | os.makedirs(file_path, exist_ok=True) 97 | return True 98 | 99 | 100 | def hcopy(from_path: str, to_path: str) -> bool: 101 | """ hdfs copy """ 102 | if to_path.startswith("hdfs"): 103 | if from_path.startswith("hdfs"): 104 | os.system("{} dfs -cp -f {} {}".format(HADOOP_BIN, from_path, to_path)) 105 | else: 106 | os.system("{} dfs -copyFromLocal -f {} {}".format(HADOOP_BIN, from_path, to_path)) 107 | else: 108 | if from_path.startswith("hdfs"): 109 | os.system("{} dfs -text {} > {}".format(HADOOP_BIN, from_path, to_path)) 110 | else: 111 | shutil.copy(from_path, to_path) 112 | return True 113 | 114 | 115 | def hcountline(path): 116 | ''' 117 | count line in file 118 | ''' 119 | count = 0 120 | if path.startswith('hdfs'): 121 | with hopen(path, 'r') as f: 122 | for line in f: 123 | 
count += 1 124 | else: 125 | with open(path, 'r') as f: 126 | for line in f: 127 | count += 1 128 | return count 129 | -------------------------------------------------------------------------------- /utils/marvl_preproc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | def marvl_preproc(ipath, opath): 6 | if not os.path.exists(opath): 7 | os.makedirs(opath) 8 | # test 9 | root = os.path.join(ipath, 'zero_shot/annotations') 10 | for fp in os.listdir(root): 11 | with open(os.path.join(root, fp)) as f, open(os.path.join(opath, fp[:-1]), 'w') as wf: 12 | data = [] 13 | for l in f: 14 | d = json.loads(l) 15 | data.append({ 16 | 'sentence': d['caption'], 17 | 'label': d['label'], 18 | 'images': ['images/marvl_official/{}/images/{}/{}'.format(d['language'], d['left_img'].split('-')[0], d['left_img']), 19 | 'images/marvl_official/{}/images/{}/{}'.format(d['language'], d['right_img'].split('-')[0], d['right_img'])] 20 | }) 21 | json.dump(data, wf) 22 | # few shot 23 | root = os.path.join(ipath, 'few_shot/annotations') 24 | for fp in os.listdir(root): 25 | with open(os.path.join(root, fp)) as f, open(os.path.join(opath, fp[:-1]), 'w') as wf: 26 | data = [] 27 | for l in f: 28 | d = json.loads(l) 29 | data.append({ 30 | 'sentence': d['caption'], 31 | 'label': d['label'], 32 | 'images': ['images/marvl_fewshot/{}/all/{}'.format(d['language'], d['left_img'].split('/')[-1]), 33 | 'images/marvl_fewshot/{}/all/{}'.format(d['language'], d['right_img'].split('/')[-1])] 34 | }) 35 | json.dump(data, wf) 36 | -------------------------------------------------------------------------------- /utils/torch_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 4 | # Github: https://github.com/zengyan-97/X-VLM 5 | # Copyright (c) 2022, ByteDance Inc. 6 | # All rights reserved. 7 | 8 | import io 9 | import torch 10 | 11 | from .hdfs_io import hopen 12 | 13 | 14 | def load(filepath: str, **kwargs): 15 | """ load model """ 16 | if not filepath.startswith("hdfs://"): 17 | return torch.load(filepath, **kwargs) 18 | with hopen(filepath, "rb") as reader: 19 | accessor = io.BytesIO(reader.read()) 20 | state_dict = torch.load(accessor, **kwargs) 21 | del accessor 22 | return state_dict 23 | 24 | 25 | def save(obj, filepath: str, **kwargs): 26 | """ save model """ 27 | if filepath.startswith("hdfs://"): 28 | with hopen(filepath, "wb") as writer: 29 | torch.save(obj, writer, **kwargs) 30 | else: 31 | torch.save(obj, filepath, **kwargs) 32 | -------------------------------------------------------------------------------- /vqaTools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | -------------------------------------------------------------------------------- /vqaTools/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 
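# A minimal usage sketch (illustrative; the file names and image id below are placeholders):
#
#     vqa = VQA('v2_mscoco_val2014_annotations.json', 'v2_OpenEnded_mscoco_val2014_questions.json')
#     quesIds = vqa.getQuesIds(imgIds=[262148])
#     anns = vqa.loadQA(quesIds)
#     vqa.showQA(anns)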
8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(VQA.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | class VQA: 24 | def __init__(self, annotation_file=None, question_file=None): 25 | """ 26 | Constructor of VQA helper class for reading and visualizing questions and answers. 27 | :param annotation_file (str): location of VQA annotation file; question_file (str): location of VQA question file 28 | :return: 29 | """ 30 | # load dataset 31 | self.dataset = {} 32 | self.questions = {} 33 | self.qa = {} 34 | self.qqa = {} 35 | self.imgToQA = {} 36 | if not annotation_file == None and not question_file == None: 37 | print('loading VQA annotations and questions into memory...') 38 | time_t = datetime.datetime.utcnow() 39 | dataset = json.load(open(annotation_file, 'r')) 40 | questions = json.load(open(question_file, 'r')) 41 | self.dataset = dataset 42 | self.questions = questions 43 | self.createIndex() 44 | 45 | def createIndex(self): 46 | # create index 47 | print('creating index...') 48 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 49 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 50 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 51 | for ann in self.dataset['annotations']: 52 | imgToQA[ann['image_id']] += [ann] 53 | qa[ann['question_id']] = ann 54 | for ques in self.questions['questions']: 55 | qqa[ques['question_id']] = ques 56 | print('index created!') 57 | 58 | # create class members 59 | self.qa = qa 60 | self.qqa = qqa 61 | self.imgToQA = imgToQA 62 | 63 | def info(self): 64 | """ 65 | Print information about the VQA annotation file. 66 | :return: 67 | """ 68 | for key, value in self.dataset['info'].items(): 69 | print('%s: %s'%(key, value)) 70 | 71 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 72 | """ 73 | Get question ids that satisfy given filter conditions. default skips that filter 74 | :param imgIds (int array) : get question ids for given imgs 75 | quesTypes (str array) : get question ids for given question types 76 | ansTypes (str array) : get question ids for given answer types 77 | :return: ids (int array) : integer array of question ids 78 | """ 79 | imgIds = imgIds if type(imgIds) == list else [imgIds] 80 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 81 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 82 | 83 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 84 | anns = self.dataset['annotations'] 85 | else: 86 | if not len(imgIds) == 0: 87 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[]) 88 | else: 89 | anns = self.dataset['annotations'] 90 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 91 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 92 | ids = [ann['question_id'] for ann in anns] 93 | return ids 94 | 95 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 96 | """ 97 | Get image ids that satisfy given filter conditions.
default skips that filter 98 | :param quesIds (int array) : get image ids for given question ids 99 | quesTypes (str array) : get image ids for given question types 100 | ansTypes (str array) : get image ids for given answer types 101 | :return: ids (int array) : integer array of image ids 102 | """ 103 | quesIds = quesIds if type(quesIds) == list else [quesIds] 104 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 105 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 106 | 107 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 108 | anns = self.dataset['annotations'] 109 | else: 110 | if not len(quesIds) == 0: 111 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[]) 112 | else: 113 | anns = self.dataset['annotations'] 114 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 115 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 116 | ids = [ann['image_id'] for ann in anns] 117 | return ids 118 | 119 | def loadQA(self, ids=[]): 120 | """ 121 | Load questions and answers with the specified question ids. 122 | :param ids (int array) : integer ids specifying question ids 123 | :return: qa (object array) : loaded qa objects 124 | """ 125 | if type(ids) == list: 126 | return [self.qa[id] for id in ids] 127 | elif type(ids) == int: 128 | return [self.qa[ids]] 129 | 130 | def showQA(self, anns): 131 | """ 132 | Display the specified annotations. 133 | :param anns (array of object): annotations to display 134 | :return: None 135 | """ 136 | if len(anns) == 0: 137 | return 0 138 | for ann in anns: 139 | quesId = ann['question_id'] 140 | print("Question: %s" %(self.qqa[quesId]['question'])) 141 | for ans in ann['answers']: 142 | print("Answer %d: %s" %(ans['answer_id'], ans['answer'])) 143 | 144 | def loadRes(self, resFile, quesFile): 145 | """ 146 | Load result file and return a result object. 147 | :param resFile (str) : file name of result file 148 | :return: res (obj) : result api object 149 | """ 150 | res = VQA() 151 | res.questions = json.load(open(quesFile)) 152 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 153 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 154 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 155 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) 156 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 157 | 158 | print('Loading and preparing results... ') 159 | time_t = datetime.datetime.utcnow() 160 | anns = json.load(open(resFile)) 161 | assert type(anns) == list, 'results is not an array of objects' 162 | annsQuesIds = [ann['question_id'] for ann in anns] 163 | assert set(annsQuesIds) == set(self.getQuesIds()), \ 164 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 
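# copy image id, question type and answer type over from the ground-truth annotation so the
# result entries can be indexed and filtered like regular annotations by createIndex() below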
165 | for ann in anns: 166 | quesId = ann['question_id'] 167 | if res.dataset['task_type'] == 'Multiple Choice': 168 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices' 169 | qaAnn = self.qa[quesId] 170 | ann['image_id'] = qaAnn['image_id'] 171 | ann['question_type'] = qaAnn['question_type'] 172 | ann['answer_type'] = qaAnn['answer_type'] 173 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 174 | 175 | res.dataset['annotations'] = anns 176 | res.createIndex() 177 | return res 178 | -------------------------------------------------------------------------------- /vqaTools/vqaEval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | __author__='aagrawal' 4 | 5 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 6 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). 7 | import sys 8 | import re 9 | 10 | class VQAEval: 11 | def __init__(self, vqa, vqaRes, n=2): 12 | self.n = n 13 | self.accuracy = {} 14 | self.evalQA = {} 15 | self.evalQuesType = {} 16 | self.evalAnsType = {} 17 | self.vqa = vqa 18 | self.vqaRes = vqaRes 19 | self.params = {'question_id': vqa.getQuesIds()} 20 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", 21 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", 22 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", 23 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", 24 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 25 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 27 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", 28 | "shes": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", 29 | "somebodyd": "somebody'd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", 30 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", 31 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", 32 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", 33 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", 34 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", 35 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", 36 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", 37 |
"whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 38 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 39 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 40 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", 41 | "youll": "you'll", "youre": "you're", "youve": "you've"} 42 | self.manualMap = { 'none': '0', 43 | 'zero': '0', 44 | 'one': '1', 45 | 'two': '2', 46 | 'three': '3', 47 | 'four': '4', 48 | 'five': '5', 49 | 'six': '6', 50 | 'seven': '7', 51 | 'eight': '8', 52 | 'nine': '9', 53 | 'ten': '10' 54 | } 55 | self.articles = ['a', 56 | 'an', 57 | 'the' 58 | ] 59 | 60 | 61 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 62 | self.commaStrip = re.compile("(\d)(,)(\d)") 63 | self.punct = [';', r"/", '[', ']', '"', '{', '}', 64 | '(', ')', '=', '+', '\\', '_', '-', 65 | '>', '<', '@', '`', ',', '?', '!'] 66 | 67 | 68 | def evaluate(self, quesIds=None): 69 | if quesIds == None: 70 | quesIds = [quesId for quesId in self.params['question_id']] 71 | gts = {} 72 | res = {} 73 | for quesId in quesIds: 74 | gts[quesId] = self.vqa.qa[quesId] 75 | res[quesId] = self.vqaRes.qa[quesId] 76 | 77 | # ================================================= 78 | # Compute accuracy 79 | # ================================================= 80 | accQA = [] 81 | accQuesType = {} 82 | accAnsType = {} 83 | print ("computing accuracy") 84 | step = 0 85 | for quesId in quesIds: 86 | resAns = res[quesId]['answer'] 87 | resAns = resAns.replace('\n', ' ') 88 | resAns = resAns.replace('\t', ' ') 89 | resAns = resAns.strip() 90 | resAns = self.processPunctuation(resAns) 91 | resAns = self.processDigitArticle(resAns) 92 | gtAcc = [] 93 | gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] 94 | if len(set(gtAnswers)) > 1: 95 | for ansDic in gts[quesId]['answers']: 96 | ansDic['answer'] = self.processPunctuation(ansDic['answer']) 97 | for gtAnsDatum in gts[quesId]['answers']: 98 | otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] 99 | matchingAns = [item for item in otherGTAns if item['answer']==resAns] 100 | acc = min(1, float(len(matchingAns))/3) 101 | gtAcc.append(acc) 102 | quesType = gts[quesId]['question_type'] 103 | ansType = gts[quesId]['answer_type'] 104 | avgGTAcc = float(sum(gtAcc))/len(gtAcc) 105 | accQA.append(avgGTAcc) 106 | if quesType not in accQuesType: 107 | accQuesType[quesType] = [] 108 | accQuesType[quesType].append(avgGTAcc) 109 | if ansType not in accAnsType: 110 | accAnsType[ansType] = [] 111 | accAnsType[ansType].append(avgGTAcc) 112 | self.setEvalQA(quesId, avgGTAcc) 113 | self.setEvalQuesType(quesId, quesType, avgGTAcc) 114 | self.setEvalAnsType(quesId, ansType, avgGTAcc) 115 | if step%100 == 0: 116 | self.updateProgress(step/float(len(quesIds))) 117 | step = step + 1 118 | 119 | self.setAccuracy(accQA, accQuesType, accAnsType) 120 | print ("Done computing accuracy") 121 | 122 | def processPunctuation(self, inText): 123 | outText = inText 124 | for p in self.punct: 125 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): 126 | outText = outText.replace(p, '') 127 | else: 128 | outText = outText.replace(p, ' ') 129 | outText = self.periodStrip.sub("", 130 | outText, 131 | re.UNICODE) 132 | return outText 133 | 134 | def 
processDigitArticle(self, inText): 135 | outText = [] 136 | tempText = inText.lower().split() 137 | for word in tempText: 138 | word = self.manualMap.setdefault(word, word) 139 | if word not in self.articles: 140 | outText.append(word) 141 | else: 142 | pass 143 | for wordId, word in enumerate(outText): 144 | if word in self.contractions: 145 | outText[wordId] = self.contractions[word] 146 | outText = ' '.join(outText) 147 | return outText 148 | 149 | def setAccuracy(self, accQA, accQuesType, accAnsType): 150 | self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) 151 | self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} 152 | self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} 153 | 154 | def setEvalQA(self, quesId, acc): 155 | self.evalQA[quesId] = round(100*acc, self.n) 156 | 157 | def setEvalQuesType(self, quesId, quesType, acc): 158 | if quesType not in self.evalQuesType: 159 | self.evalQuesType[quesType] = {} 160 | self.evalQuesType[quesType][quesId] = round(100*acc, self.n) 161 | 162 | def setEvalAnsType(self, quesId, ansType, acc): 163 | if ansType not in self.evalAnsType: 164 | self.evalAnsType[ansType] = {} 165 | self.evalAnsType[ansType][quesId] = round(100*acc, self.n) 166 | 167 | def updateProgress(self, progress): 168 | barLength = 20 169 | status = "" 170 | if isinstance(progress, int): 171 | progress = float(progress) 172 | if not isinstance(progress, float): 173 | progress = 0 174 | status = "error: progress var must be float\r\n" 175 | if progress < 0: 176 | progress = 0 177 | status = "Halt...\r\n" 178 | if progress >= 1: 179 | progress = 1 180 | status = "Done...\r\n" 181 | block = int(round(barLength*progress)) 182 | text = "\rFinished Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status) 183 | sys.stdout.write(text) 184 | sys.stdout.flush() -------------------------------------------------------------------------------- /x2vlm_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/x2vlm_github.png --------------------------------------------------------------------------------
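A minimal end-to-end sketch of how the two vqaTools modules above are typically combined (the JSON file names are placeholders for the standard VQA v2 annotation, question and result files):

    from vqaTools.vqa import VQA
    from vqaTools.vqaEval import VQAEval

    annFile = 'v2_mscoco_val2014_annotations.json'            # placeholder path
    quesFile = 'v2_OpenEnded_mscoco_val2014_questions.json'   # placeholder path
    resFile = 'vqa_val_results.json'                          # placeholder path

    vqa = VQA(annFile, quesFile)
    vqaRes = vqa.loadRes(resFile, quesFile)
    vqaEval = VQAEval(vqa, vqaRes, n=2)   # n: decimal places for reported accuracies
    vqaEval.evaluate()
    print(vqaEval.accuracy['overall'], vqaEval.accuracy['perAnswerType'])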