├── Grounding_bbox.py
├── LICENSE
├── MARVL.py
├── NLVR.py
├── Pretrain.py
├── README.md
├── Retrieval.py
├── VQA.py
├── VQA_msrvtt.py
├── VQA_msvd.py
├── WIT.py
├── XGQA.py
├── XRetrieval.py
├── XVNLI.py
├── accelerators
│   ├── __init__.py
│   ├── accelerator.py
│   └── apex_ddp_accelerator.py
├── configs
│   ├── config_beit2_base.json
│   ├── config_beit2_large.json
│   ├── finetune
│   │   ├── coco_captioning_large.yaml
│   │   ├── refcoco_grounding_large.yaml
│   │   ├── vqa2_base.yaml
│   │   └── vqa2_large.yaml
│   └── pretrain
│       ├── multilingual_cclm_x2vlm_base.yaml
│       ├── multilingual_cclm_x2vlm_large.yaml
│       ├── x2vlm_base_1b.yaml
│       ├── x2vlm_base_1b_stage2.yaml
│       ├── x2vlm_base_4m.yaml
│       ├── x2vlm_large_1b.yaml
│       ├── x2vlm_large_1b_stage2.yaml
│       └── x2vlm_large_4m.yaml
├── dataset
│   ├── __init__.py
│   ├── captioning_dataset.py
│   ├── dist_dataset.py
│   ├── grounding_dataset.py
│   ├── nlvr_dataset.py
│   ├── pretrain_dataset.py
│   ├── pretrain_dataset_multilingual.py
│   ├── randaugment.py
│   ├── retrieval_dataset.py
│   ├── tokenizers
│   │   ├── __init__.py
│   │   └── bert_tokenizer_with_dropout.py
│   ├── utils.py
│   ├── vqa_dataset.py
│   ├── wit_dataset.py
│   ├── xflickrco_dataset.py
│   └── xvnli_dataset.py
├── models
│   ├── __init__.py
│   ├── beit2.py
│   ├── box_ops.py
│   ├── clip_vit.py
│   ├── model_classification.py
│   ├── model_grounding.py
│   ├── model_pretrain.py
│   ├── model_retrieval.py
│   ├── resampler.py
│   ├── swin_transformer.py
│   ├── vit.py
│   ├── xbert.py
│   ├── xroberta.py
│   └── xvlm.py
├── optim.py
├── refTools
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── bleu
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── bleu.py
│   │   │   └── bleu_scorer.py
│   │   ├── cider
│   │   │   ├── __init__.py
│   │   │   ├── cider.py
│   │   │   └── cider_scorer.py
│   │   ├── meteor
│   │   │   ├── __init__.py
│   │   │   ├── meteor-1.5.jar
│   │   │   └── meteor.py
│   │   ├── readme.txt
│   │   ├── refEvaluation.py
│   │   ├── rouge
│   │   │   ├── __init__.py
│   │   │   └── rouge.py
│   │   └── tokenizer
│   │       ├── __init__.py
│   │       ├── ptbtokenizer.py
│   │       ├── stanford-corenlp-3.4.1.jar
│   │       ├── tmp37tp6xj8
│   │       ├── tmp82iqkuu0
│   │       └── tmpn19wmqte
│   └── refer_python3.py
├── requirements.txt
├── run.py
├── scheduler.py
├── utils
│   ├── __init__.py
│   ├── bleu.py
│   ├── checkpointer.py
│   ├── cider
│   │   └── pyciderevalcap
│   │       ├── __init__.py
│   │       ├── cider
│   │       │   ├── __init__.py
│   │       │   ├── cider.py
│   │       │   └── cider_scorer.py
│   │       └── ciderD
│   │           ├── __init__.py
│   │           ├── ciderD.py
│   │           └── ciderD_scorer.py
│   ├── hdfs_io.py
│   ├── marvl_preproc.py
│   └── torch_io.py
├── vqaTools
│   ├── __init__.py
│   ├── vqa.py
│   └── vqaEval.py
├── x2vlm_github.png
└── xFlickrCO.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023, ByteDance Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/MARVL.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import math
5 |
6 | import ruamel.yaml as yaml
7 | import numpy as np
8 | import random
9 | import time
10 | import datetime
11 | import json
12 | from pathlib import Path
13 | import json
14 | import pickle
15 |
16 | import torch
17 | import torch.backends.cudnn as cudnn
18 | import torch.distributed as dist
19 |
20 | import utils
21 | from dataset import create_dataset, create_sampler, create_loader, build_tokenizer
22 | from scheduler import create_scheduler
23 | from optim import create_optimizer
24 |
25 |
26 | def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler):
27 | model.train()
28 |
29 | metric_logger = utils.MetricLogger(delimiter=" ")
30 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
31 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
32 |
33 | header = 'Train Epoch: [{}]'.format(epoch)
34 | print_freq = 50
35 | step_size = 100
36 |
37 | for i, (image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
38 | images = torch.cat([image0, image1], dim=0)
39 | images, targets = images.to(device), targets.to(device)
40 |
41 | text_inputs = tokenizer(text, padding='longest', return_tensors="pt").to(device)
42 |
43 | loss = model(images, text_inputs.input_ids, text_inputs.attention_mask, targets=targets, train=True)
44 |
45 | optimizer.zero_grad()
46 | loss.backward()
47 | optimizer.step()
48 | scheduler.step()
49 |
50 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
51 | metric_logger.update(loss=loss.item())
52 |
53 | # gather the stats from all processes
54 | metric_logger.synchronize_between_processes()
55 | print("Averaged stats:", metric_logger.global_avg())
56 | return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
57 |
58 |
59 | @torch.no_grad()
60 | def evaluate(model, data_loader, tokenizer, device):
61 | model.eval()
62 |
63 | metric_logger = utils.MetricLogger(delimiter=" ")
64 |
65 | header = 'Evaluation:'
66 | print_freq = 50
67 |
68 | for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
69 | images = torch.cat([image0, image1], dim=0)
70 | images, targets = images.to(device), targets.to(device)
71 | text_inputs = tokenizer(text, padding='longest', return_tensors="pt").to(device)
72 |
73 | prediction = model(images, text_inputs.input_ids, text_inputs.attention_mask, targets=targets, train=False)
74 |
75 | _, pred_class = prediction.max(1)
76 | accuracy = (targets == pred_class).sum() / targets.size(0)
77 |
78 | metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
79 |
80 | # gather the stats from all processes
81 | metric_logger.synchronize_between_processes()
82 | print("Averaged stats:", metric_logger.global_avg())
83 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
84 |
85 |
86 | def main(args, config):
87 | utils.init_distributed_mode(args)
88 | device = torch.device(args.device)
89 |
90 | world_size = utils.get_world_size()
91 |
92 | if args.epoch > 0:
93 | config['schedular']['epochs'] = args.epoch
94 | print(f"### set epochs to: {args.epoch}", flush=True)
95 |
96 | if args.bs > 0:
97 | config['batch_size'] = args.bs // world_size
98 |
99 | seed = args.seed + utils.get_rank()
100 | torch.manual_seed(seed)
101 | np.random.seed(seed)
102 | random.seed(seed)
103 | cudnn.benchmark = True
104 |
105 | print("Creating dataset")
106 | train_dataset, val_dataset, test_dataset_dict = create_dataset('marvl', config)
107 | datasets = [train_dataset, val_dataset]
108 |
109 | train_dataset_size = len(train_dataset)
110 | train_batch_size = config['batch_size']
111 | world_size = utils.get_world_size()
112 |
113 | if utils.is_main_process():
114 | print(f"### data {train_dataset_size}, batch size, {train_batch_size} x {world_size}")
115 | print(f"### Test: {[(k, len(dataset)) for k, dataset in test_dataset_dict.items()]}")
116 |
117 | if args.distributed:
118 | num_tasks = utils.get_world_size()
119 | global_rank = utils.get_rank()
120 | samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
121 | else:
122 | samplers = [None, None]
123 |
124 | train_loader, val_loader = create_loader(datasets, samplers, batch_size=[config['batch_size']] * 2,
125 | num_workers=[4, 4], is_trains=[True, False],
126 | collate_fns=[None, None])
127 |
128 | test_loader_dict = {}
129 | for k, v in test_dataset_dict.items():
130 | test_loader_dict[k] = create_loader([v], [None], batch_size=[config['batch_size']],
131 | num_workers=[4], is_trains=[False], collate_fns=[None])[0]
132 |
133 | print("Creating model")
134 | from models.model_classification import XVLMPlusForMARVL
135 | model = XVLMPlusForMARVL(config=config)
136 | model.load_pretrained(args.checkpoint, config, is_eval=args.evaluate)
137 | model = model.to(device)
138 | print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
139 |
140 | model_without_ddp = model
141 | if args.distributed:
142 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
143 | model_without_ddp = model.module
144 |
145 | tokenizer = build_tokenizer(config['text_encoder'])
146 |
147 | print("### output_dir, ", args.output_dir, flush=True)
148 | start_time = time.time()
149 |
150 | if args.evaluate:
151 | print("Start evaluating")
152 |
153 | acc_mean = 0
154 | for language, test_loader in test_loader_dict.items():
155 | test_stats = evaluate(model, test_loader, tokenizer, device)
156 | if utils.is_main_process():
157 | print({f'test_{language}_{k}': v for k, v in test_stats.items()}, flush=True)
158 | acc_mean += (float(test_stats['acc']) / len(test_loader_dict))  # evaluate() returns formatted strings
159 |
160 | dist.barrier()
161 |
162 | if utils.is_main_process():
163 | print("Test average accuracy: ", acc_mean, flush=True)
164 | dist.barrier()
165 |
166 | else:
167 | print("Start training")
168 | arg_opt = utils.AttrDict(config['optimizer'])
169 | optimizer = create_optimizer(arg_opt, model)
170 | arg_sche = utils.AttrDict(config['schedular'])
171 | arg_sche['step_per_epoch'] = math.ceil(train_dataset_size / (train_batch_size * world_size))
172 | lr_scheduler = create_scheduler(arg_sche, optimizer)
173 |
174 | max_epoch = config['schedular']['epochs']
175 |
176 | best = 0
177 | best_epoch = 0
178 | if 'eval_interval' not in config:
179 | config['eval_interval'] = 1
180 |
181 | for epoch in range(0, max_epoch):
182 | if args.distributed:
183 | train_loader.sampler.set_epoch(epoch)
184 | train_stats = train(model, train_loader, optimizer, tokenizer, epoch, device, lr_scheduler)
185 | if epoch >= config['start_eval']:
186 | # val_stats = evaluate(model, val_loader, tokenizer, device)
187 |
188 | acc_mean = 0
189 | for language, test_loader in test_loader_dict.items():
190 | test_stats = evaluate(model, test_loader, tokenizer, device)
191 | if utils.is_main_process():
192 | print({f'test_{language}_{k}': v for k, v in test_stats.items()}, flush=True)
193 | acc_mean += (float(test_stats['acc']) / len(test_loader_dict))
194 | dist.barrier()
195 |
196 | if utils.is_main_process():
197 | if acc_mean > best:
198 | save_obj = {
199 | 'model': model_without_ddp.state_dict(),
200 | # 'optimizer': optimizer.state_dict(),
201 | # 'lr_scheduler': lr_scheduler.state_dict(),
202 | 'config': config,
203 | # 'epoch': epoch,
204 | }
205 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
206 | best = acc_mean
207 | best_epoch = epoch
208 |
209 | print("best epoch: {:}, best test acc_mean: {:.4f}".format(best_epoch, best), flush=True)
210 |
211 | dist.barrier()
212 |
213 | total_time = time.time() - start_time
214 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
215 | print('### Time {}'.format(total_time_str))
216 |
217 |
218 | if __name__ == '__main__':
219 | parser = argparse.ArgumentParser()
220 | parser.add_argument('--checkpoint', type=str, required=True)
221 | parser.add_argument('--config', default='./configs/MARVL.yaml')
222 | parser.add_argument('--output_dir', default='output/nlvr')
223 |
224 | parser.add_argument('--device', default='cuda')
225 | parser.add_argument('--seed', default=42, type=int)
226 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
227 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
228 | parser.add_argument('--distributed', action='store_false')
229 |
230 | parser.add_argument('--load_nlvr_pretrain', action='store_true')
231 | parser.add_argument('--epoch', default=-1, type=int)
232 | parser.add_argument('--lr', default=0., type=float)
233 | parser.add_argument('--fewshot', default='', type=str)
234 | parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus")
235 | parser.add_argument('--evaluate', action='store_true')
236 |
237 | args = parser.parse_args()
238 |
239 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
240 |
241 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
242 |
243 | if args.lr != 0.:
244 | config['optimizer']['lr'] = args.lr
245 | config['schedular']['lr'] = args.lr
246 | if args.fewshot:
247 | config['train_file'][0] = config['train_file'][0].format(args.fewshot)
248 | config['val_file'][0] = config['val_file'][0].format(args.fewshot)
249 |
250 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
251 |
252 | main(args, config)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # X2-VLM: All-In-One Pre-trained Model For Vision-Language Tasks
2 |
3 |
4 |
5 |
6 |
7 | X2-VLM, with its modular architecture, performs best at both base and large scale on image-text and video-text tasks, offering a good trade-off between performance and model scale. We also show that the modular design of X2-VLM makes it highly transferable, so it can be utilized in any language or domain. For example, by simply replacing the text encoder with XLM-R, X2-VLM outperforms state-of-the-art multilingual multi-modal pre-trained models without any multilingual pre-training.
8 |
9 |
10 | - Jun 2023: Release official PyTorch implementation and checkpoints
11 | - Nov 2022: Release preprint on arXiv.
12 |
13 |
14 | X2-VLM (large, 593M params):
15 | [Cross-Modal Retrieval on Flickr30K](https://paperswithcode.com/sota/cross-modal-retrieval-on-flickr30k?p=x-2-vlm-all-in-one-pre-trained-model-for)
16 | [Cross-Modal Retrieval on COCO 2014](https://paperswithcode.com/sota/cross-modal-retrieval-on-coco-2014?p=x-2-vlm-all-in-one-pre-trained-model-for)
17 | [Visual Grounding on RefCOCO testA](https://paperswithcode.com/sota/visual-grounding-on-refcoco-testa?p=x-2-vlm-all-in-one-pre-trained-model-for)
18 | [Visual Reasoning on NLVR2 test](https://paperswithcode.com/sota/visual-reasoning-on-nlvr2-test?p=x-2-vlm-all-in-one-pre-trained-model-for)
19 | [Visual Question Answering on VQA v2 test-std](https://paperswithcode.com/sota/visual-question-answering-on-vqa-v2-test-std?p=x-2-vlm-all-in-one-pre-trained-model-for)
20 |
21 |
22 |
23 | ## Features
24 | - Support several backbones
25 | - vision encoder: beit / clip-vit / swin-transformer
26 | - text encoder: bert / roberta
27 | - Support apex O1 / O2 for pre-training
28 | - Read from and write to HDFS
29 | - Distributed training across nodes for both pre-training and fine-tuning
30 |
31 | Please read the code for more details.
32 |
33 |
34 | ## Requirements
35 | - Install python3 environment
36 | ```bash
37 | pip3 install -r requirements.txt
38 | ```
39 | - Download raw images from the corresponding websites
40 | - Download the json files we provide, which contain image read paths, captions, and/or bbox annotations
41 | - If running pre-training scripts:
42 | - install Apex
43 | - download pre-trained models for parameter initialization
44 | - image encoder: beit2
45 | - text encoder: bert
46 |
47 |
48 | ## Pretrain
49 | ```bash
50 | # X-VLM pretrain
51 | python3 run.py --task "pretrain_DIY" --dist "all" --config "configs/pretrain/x2vlm_base_4m.yaml" --output_dir "output/tmp"
52 |
53 | # CCLM multilingual multimodal pretrain
54 | python3 run.py --task "pretrain_DIY" --dist "all" --config "configs/pretrain/multilingual_cclm_x2vlm_base.yaml" --checkpoint "path/to/x2vlm_base_1b.th" --output_dir "output/tmp"
55 | ```
56 | See run.py and configs/pretrain for more details.
57 |
58 |
59 | #### Data
60 | All datasets we utilized are publicly available. Please prepare the pre-training data yourself. Read dataset/pretrain_dataset.py (specifically ImageTextJsonDataset and RegionTextJsonDataset) to see what format is needed; a minimal sketch of one image-text record is shown below.
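For illustration only, here is a minimal sketch of writing one json-lines record for ImageTextJsonDataset. The keys (`"binary"`, `"desc"`), the base64 encoding, and the shard filename are assumptions inferred from the `images` settings in configs/pretrain/*.yaml; please verify them against dataset/pretrain_dataset.py.

```python
# Hypothetical example of one pre-training record, assuming image_key="binary",
# caption_key="desc", and is_image_rpath=False (image stored as base64, not a read path).
import base64
import json

with open("example.jpg", "rb") as img, open("pretrain_shard_00000.jsonl", "a") as out:
    record = {
        "binary": base64.b64encode(img.read()).decode("utf-8"),  # base64-encoded image bytes
        "desc": "a dog playing with a ball on the grass",        # caption text
    }
    out.write(json.dumps(record) + "\n")  # one JSON object per line
```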
61 |
62 | The processed COCO & VG annotations can be downloaded [here](https://drive.google.com/drive/u/1/folders/1W4_wr53DDWLsvuSavNW1iDbSo9yXQQL1).
63 |
64 |
65 | #### Checkpoints
66 | Please make sure all parameters are loaded correctly (a small sanity-check sketch follows the links below).
67 | [X2VLM-base (4M)](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_base_4m.th)
68 | [X2VLM-large (4M)](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_large_4m.th)
69 | [X2VLM-base (1B)](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_base_1b.th)
70 | [CCLM-X2VLM-base](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/cclm_x2vlm_base.th)
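The following is a minimal sketch (not the repository's own loading code) for checking that no parameters are silently dropped when loading a checkpoint. The `"model"` wrapper key matches how the fine-tuning scripts in this repo (e.g. MARVL.py) save checkpoints; whether the released pre-trained .th files use the same wrapper is an assumption to verify.

```python
import torch

def check_loaded(model: torch.nn.Module, ckpt_path: str) -> None:
    """Load a checkpoint non-strictly and report any parameter-name mismatches."""
    state_dict = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]  # fine-tuned ckpts saved by this repo wrap weights under "model"
    msg = model.load_state_dict(state_dict, strict=False)
    print("missing keys:", msg.missing_keys)        # ideally empty, or only task-specific heads
    print("unexpected keys:", msg.unexpected_keys)  # ideally empty
```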
71 |
72 |
73 |
74 | ## Finetune
75 |
76 | #### Data
77 | All datasets are publicly available. Some datasets can be downloaded [here](https://drive.google.com/file/d/1XFz1Vtz7MCBLn4_1QEojhFJ5Iw3eH3X4/view?usp=sharing).
78 |
79 |
80 | #### Checkpoints, Configs and Logs
81 | We have released all code. However, for now we only provide some of the fine-tuned checkpoints (with their training configs and logs).
82 | [vqa-base](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_base_1b_vqa.th)
83 | [vqa-large](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_large_1b_vqa.th)
84 | [refcoco-bbox-large](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/x2vlm_large_1b_grounding.tar)
85 | It takes time for us to retrieve our previous training logs. If you need more, please submit a GitHub issue and we will get back to you.
86 | [coco-retrieval-base-rerun](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/xvlm_beit_1b_stage2_coco_rerun.th)
87 | [coco-retrieval-large-rerun](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/x2vlm_ckpts_2release/xvlm_beit_1b_large_stage2_coco_rerun.th)
88 |
89 |
90 | #### Examples
91 | ```bash
92 | # train
93 |
94 | python3 run.py --task "vqa" --dist "all" --config "configs/finetune/vqa2_large.yaml" --checkpoint "x2vlm_ckpts_2release/x2vlm_large_1b.th" --output_dir "output/tmp"
95 |
96 | python3 run.py --task "refcoco_bbox" --dist "all" --config "configs/finetune/refcoco_grounding_large.yaml" --checkpoint "x2vlm_ckpts_2release/x2vlm_large_1b.th" --output_dir "output/tmp"
97 |
98 | python3 run.py --task "coco_captioning_mlm" --dist "all" --config "configs/finetune/coco_captioning_large.yaml" --checkpoint "x2vlm_ckpts_2release/x2vlm_large_1b.th" --output_dir "output/tmp"
99 | ```
100 | We release all training code. Specify "--task" and "--config" to fine-tune on other tasks. See run.py for details.
101 |
102 |
103 |
104 |
105 | ## Citation
106 | If you find this repository useful, please consider giving it a ⭐ or citing:
107 | ```
108 | @article{zeng2022x,
109 | title={X$^2$-VLM: All-In-One Pre-trained Model For Vision-Language Tasks},
110 | author={Zeng, Yan and Zhang, Xinsong and Li, Hang and Wang, Jiawei and Zhang, Jipeng and Zhou, Wangchunshu},
111 | journal={arXiv preprint arXiv:2211.12402},
112 | year={2022}
113 | }
114 |
115 | @article{zeng2022cross,
116 | title={Cross-view language modeling: Towards unified cross-lingual cross-modal pre-training},
117 | author={Zeng, Yan and Zhou, Wangchunshu and Luo, Ao and Zhang, Xinsong},
118 | journal={arXiv preprint arXiv:2206.00621},
119 | year={2022}
120 | }
121 |
122 | @article{xvlm,
123 | title={Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts},
124 | author={Zeng, Yan and Zhang, Xinsong and Li, Hang},
125 | journal={arXiv preprint arXiv:2111.08276},
126 | year={2021}
127 | }
128 | ```
129 |
130 |
131 | ### Contact
132 | For issues using this code, please submit a GitHub issue.
--------------------------------------------------------------------------------
/accelerators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/accelerators/__init__.py
--------------------------------------------------------------------------------
/accelerators/accelerator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
3 | # Github: https://github.com/zengyan-97/X-VLM
4 | # Copyright (c) 2022, ByteDance Inc.
5 | # All rights reserved.
6 |
7 | from logging import Logger
8 |
9 | import torch
10 | from torch.optim import Optimizer
11 |
12 | Net = torch.nn.Module
13 |
14 |
15 | class Accelerator:
16 | def __init__(self, cfg, logger) -> None:
17 | self.cfg = cfg
18 | self.logger = logger
19 |
20 | def set_up(self, model: Net):
21 | raise NotImplementedError("set_up method not implemented in Accelerator, please check!")
22 |
23 | def broadcast(self):
24 | raise NotImplementedError("broadcast method not implemented in Accelerator, please check!")
25 |
26 | def backward_step(self, loss: torch.Tensor):
27 | loss.backward()
28 |
29 | def optimizer_step(self, optimizer: Optimizer, model: Net, grad_norm: float) -> float:
30 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
31 | grad_norm)
32 | return float(total_norm)
33 |
--------------------------------------------------------------------------------
/accelerators/apex_ddp_accelerator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
3 | # Github: https://github.com/zengyan-97/X-VLM
4 | # Copyright (c) 2022, ByteDance Inc.
5 | # All rights reserved.
6 |
7 | import os
8 | import random
9 | import sys
10 | from typing import Tuple, Union, Optional, Any
11 | import numpy as np
12 |
13 | import torch
14 | import torch.distributed as distributed
15 | from torch.optim import Optimizer
16 | from torch.optim.lr_scheduler import LambdaLR
17 |
18 | Net = torch.nn.Module
19 |
20 | from .accelerator import Accelerator
21 |
22 | try:
23 | from apex import amp
24 | from apex.parallel import DistributedDataParallel as Apex_DDP
25 | from apex.parallel import convert_syncbn_model
26 | except ImportError:
27 | print('no apex! Please install from https://www.github.com/nvidia/apex')
28 |
29 |
30 | class ApexDDPAccelerator(Accelerator):
31 | """
32 | ApexDDPAccelerator, use apex DistributedDataParallel
33 | """
34 |
35 | def __init__(self, cfg, logger):
36 | super().__init__(cfg, logger)
37 | self.accelerator_rng_seed = self.cfg.RNG_SEED
38 | self.accelerator_syncbn = self.cfg.SYNCBN
39 | self.accelerator_fp16_opt_level = self.cfg.FP16_OPT_LEVEL
40 | self.accelerator_fp16_loss_scale = self.cfg.FP16_LOSS_SCALE
41 |
42 | def set_up(self, model: Net, optimizer: Optimizer, lr_scheduler: LambdaLR,
43 | local_rank: int, world_size: int, rank: int) -> Tuple[Apex_DDP, Optimizer, LambdaLR]:
44 | """
45 | set up ApexDDPAccelerator, including process_group and apex_ddp
46 | """
47 | torch.backends.cudnn.benchmark = False
48 | random.seed(self.accelerator_rng_seed)
49 | np.random.seed(self.accelerator_rng_seed)
50 | torch.random.manual_seed(self.accelerator_rng_seed)
51 | torch.cuda.manual_seed_all(self.accelerator_rng_seed)
52 | master_address = os.environ.get('MASTER_ADDR', "127.0.0.1")
53 | master_port = int(os.environ.get('MASTER_PORT', 34171))
54 |
55 | torch.cuda.set_device(local_rank)
56 | model = model.cuda()
57 | if not torch.distributed.is_initialized():
58 | distributed.init_process_group(
59 | backend='nccl',
60 | init_method='tcp://{}:{}'.format(master_address, master_port),
61 | world_size=world_size,
62 | rank=rank,
63 | group_name='mtorch')
64 | print(
65 | f'ApexDDPAccelerator distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}')
66 | sys.stdout.flush()
67 |
68 | self.broadcast(model)
69 | apex_model, optimizer = self.configure_ddp(model, optimizer)
70 |
71 | if self.accelerator_syncbn:
72 | apex_model = self.configure_sync_batchnorm(apex_model)
73 | return apex_model, optimizer, lr_scheduler
74 |
75 | def broadcast(self, model: Net, src=0) -> None:
76 | for v in model.state_dict().values():
77 | distributed.broadcast(v, src)
78 |
79 | def configure_ddp(self, model: Net, optimizer: Optimizer) -> Tuple[Apex_DDP, Optimizer]:
80 | model, optimizer = amp.initialize(model, optimizer,
81 | opt_level=self.accelerator_fp16_opt_level,
82 | keep_batchnorm_fp32=None, # from True to None
83 | loss_scale=self.accelerator_fp16_loss_scale,
84 | max_loss_scale=1024.0,
85 | min_loss_scale=1.0)
86 |
87 | apex_model = Apex_DDP(model, delay_allreduce=True)
88 | self.ddp_model = apex_model
89 | return apex_model, optimizer
90 |
91 | def configure_sync_batchnorm(self, model: Net) -> Net:
92 | model = convert_syncbn_model(model)
93 | return model
94 |
95 | def backward_step(self, loss: torch.Tensor, optimizer: Optimizer):
96 | with amp.scale_loss(loss, optimizer) as scaled_loss:
97 | scaled_loss.backward()
98 |
99 | def optimizer_step(self, optimizer: Optimizer, model: Net, grad_norm: float) -> float:
100 | total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
101 | grad_norm)
102 | return float(total_norm)
103 |
--------------------------------------------------------------------------------
/configs/config_beit2_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "data/beitv2_base_patch16_224_pt1k_ft21k.pth",
3 | "vision_width": 768,
4 | "patch_size": 16
5 | }
6 |
--------------------------------------------------------------------------------
/configs/config_beit2_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "ckpt": "data/beitv2_large_patch16_224_pt1k_ft21k.pth",
3 | "vision_width": 1024,
4 | "patch_size": 16
5 | }
6 |
--------------------------------------------------------------------------------
/configs/finetune/coco_captioning_large.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/finetune/coco_karpathy/coco_karpathy_train.json']
2 | val_file: 'data/finetune/coco_karpathy/coco_karpathy_val.json'
3 | test_file: 'data/finetune/coco_karpathy/coco_karpathy_test.json'
4 |
5 | image_root: 'images/coco/'
6 | val_gt_file: 'data/finetune/coco_karpathy/coco_karpathy_val_gt.json'
7 | test_gt_file: 'data/finetune/coco_karpathy/coco_karpathy_test_gt.json'
8 |
9 | ## Vision Encoder
10 | use_beit_v2: True
11 | vision_config: 'configs/config_beit2_large.json'
12 | image_res: 384
13 | patch_size: 16
14 |
15 |
16 | ## Text Encoder (& Cross Encoder)
17 | text_encoder: 'data/bert-large-uncased'
18 | text_num_hidden_layers: 18
19 | text_fusion_start_at: 12
20 |
21 |
22 | ## Training
23 | apply_FG_free: True
24 | batch_size_train: 16 # xN A100 GPUs (exact count not recorded; likely 8 or 16)
25 | batch_size_test: 20
26 | max_tokens: 40
27 | max_words: 40
28 | label_smoothing: 0.1
29 | mask_prob: 0.6
30 | max_masks: 18
31 | mask_whole_word: True
32 | skipgram_prb: 0.2
33 | skipgram_size: 3
34 |
35 | ## generation configs
36 | max_length: 50
37 | min_length: 5
38 | num_beams: 3
39 | length_penalty: 0
40 | prompt: 'a picture of '
41 |
42 |
43 | optimizer: {opt: adamW, lr: 5e-6, weight_decay: 0.01, lr_mult: 2, vision_lr: 1e-5, text_lr: 5e-6}
44 | schedular: {sched: linear, epochs: 5, num_warmup_steps: 0.05}
45 | start_eval: 0 # epoch index
46 |
--------------------------------------------------------------------------------
/configs/finetune/refcoco_grounding_large.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/finetune/refcoco+_train.json']
2 | test_file: ['data/finetune/refcoco+_val.json','data/finetune/refcoco+_test.json']
3 |
4 | refcoco_data: 'data/finetune/'
5 | det_file: 'data/finetune/refcoco+/dets.json'
6 | coco_file: 'data/finetune/refcoco+/cocos.json'
7 |
8 | image_root: 'images/coco/'
9 |
10 | careful_hflip: True # first check whether 'left' or 'right' in captions
11 |
12 | ## Vision Encoder
13 | use_beit_v2: True
14 | vision_config: 'configs/config_beit2_large.json'
15 | image_res: 384
16 | patch_size: 16
17 |
18 |
19 | ## Text Encoder (& Cross Encoder)
20 | text_encoder: 'data/bert-large-uncased'
21 | text_num_hidden_layers: 18 # 12 + 6
22 | text_fusion_start_at: 12
23 |
24 | text_drop_path_rate: 0.1
25 | cross_drop_path_rate: 0.1
26 |
27 | ## Training
28 | batch_size: 20 # xN A100 GPUs (exact count not recorded; likely 8 or 16)
29 | max_tokens: 40
30 |
31 |
32 | ## Other Settings
33 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.01, lr_mult: 2}
34 | schedular: {sched: linear, epochs: 10, num_warmup_steps: 0.1}
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/finetune/vqa2_base.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/finetune/vqa_train.json',
2 | 'data/finetune/vqa_val.json',
3 | 'data/finetune/vg_qa.json']
4 |
5 | test_file: ['data/finetune/vqa_test.json']
6 | answer_list: 'data/finetune/answer_list.json'
7 |
8 | vqa_root: 'images/coco/'
9 | vg_root: 'images/visualgenome/'
10 |
11 | ## Vision Encoder
12 | use_beit_v2: True
13 | vision_config: 'configs/config_beit2_base.json'
14 | image_res: 768
15 | patch_size: 16
16 |
17 |
18 | ## Text Encoder (& Cross Encoder)
19 | text_encoder: 'data/bert-base-uncased'
20 | text_num_hidden_layers: 18
21 | text_fusion_start_at: 12
22 |
23 | ## Training
24 | num_dec_layers: 6
25 | large_lr_for_dec: True
26 | batch_size_train: 8 # x16 a100
27 | accumulate_steps: 1
28 | batch_size_test: 32
29 | max_tokens: 40
30 | k_test: 128
31 |
32 |
33 | ## Other Settings
34 | optimizer: {opt: adamW, lr: 4e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 2e-5, text_lr: 4e-5}
35 | schedular: {sched: linear, epochs: 5, num_warmup_steps: 0.1}
36 | start_eval: 2 # epoch index
37 |
--------------------------------------------------------------------------------
/configs/finetune/vqa2_large.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/finetune/vqa_train.json',
2 | 'data/finetune/vqa_val.json',
3 | 'data/finetune/vg_qa.json']
4 |
5 | test_file: ['data/finetune/vqa_test.json']
6 | answer_list: 'data/finetune/answer_list.json'
7 |
8 | vqa_root: 'images/coco/'
9 | vg_root: 'images/visualgenome/'
10 |
11 | ## Vision Encoder
12 | use_beit_v2: True
13 | vision_config: 'configs/config_beit2_large.json'
14 | image_res: 768
15 | patch_size: 16
16 |
17 |
18 | ## Text Encoder (& Cross Encoder)
19 | text_encoder: 'data/bert-large-uncased'
20 | text_num_hidden_layers: 18
21 | text_fusion_start_at: 12
22 |
23 | ## Training
24 | num_dec_layers: 6
25 | large_lr_for_dec: True
26 | batch_size_train: 2 # x32 a100
27 | accumulate_steps: 2
28 | batch_size_test: 32
29 | max_tokens: 40
30 | k_test: 128
31 |
32 |
33 | ## Other Settings
34 | optimizer: {opt: adamW, lr: 4e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 2e-5, text_lr: 2e-5}
35 | schedular: {sched: linear, epochs: 5, num_warmup_steps: 0.05}
36 | start_eval: 2 # epoch index
37 |
--------------------------------------------------------------------------------
/configs/pretrain/multilingual_cclm_x2vlm_base.yaml:
--------------------------------------------------------------------------------
1 | # Adapt X^2-VLM to multilingual by Cross-View Language Modeling
2 |
3 | ## Data
4 | train_file: [
5 | "path/to/cc-3m-mm-uc2",
6 | "path/to/sbu-mm",
7 | "path/to/vg-mm",
8 | "path/to/coco-mm",
9 | ] # multilingual x multimodal
10 |
11 | train_dataset_size: 5004332
12 |
13 | images: {image_key: "binary",
14 | is_image_rpath: False, # read path or base64 encoding
15 | caption_key: "caption",
16 | tokenized: False, # whether texts have been tokenized
17 | batch_size: 60, # x8 gpus
18 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
19 | iter_perc: 1.0,
20 | }
21 |
22 |
23 | train_file_regions: [
24 | 'path/to/coco_object-mm-google',
25 | 'path/to/vg_object-mm-google',
26 | 'path/to/vg_region-mm',
27 | ] # multilingual
28 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, code_switch: True,
29 | careful_hflip: True,
30 | batch_size: 60, max_images: 26, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
31 |
32 |
33 | train_file_mtext: [
34 | "path/to/wikimatrix",
35 | "path/to/wikimatrix_en_bn",
36 | ] # multilingual parallel texts
37 | mtexts: {source_key: "source_text",
38 | target_key: "target_text",
39 | tokenized: False, # whether texts have been tokenized
40 | batch_size: 60, # x8 gpus
41 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
42 | iter_perc: 1.0,
43 | max_words: 64,
44 | max_tokens: 64,
45 | mask_prob: 0.4,
46 | max_masks: 16,
47 | }
48 |
49 |
50 | ## Vision Encoder
51 | use_beit_v2: True
52 | vision_config: 'configs/config_beit2_base.json'
53 | image_res: 224
54 | patch_size: 16
55 |
56 |
57 |
58 | ## Text Encoder (& Cross Encoder)
59 | model_type: 'CrossViewLM'
60 | text_encoder: 'data/xlm-roberta-base'
61 | text_num_hidden_layers: 12
62 | cross_encoder: 'data/bert-base-uncased'
63 | cross_num_hidden_layers: 6
64 |
65 | is_xvlm_ckpt: True # is of XVLMBase or XVLMPlusBase
66 | xvlm_ckpt_text_num_hidden_layers: 12 # if is_xvlm_ckpt
67 | replace_text_encoder: True
68 |
69 |
70 | ## Training
71 | mixed_in_batch: True
72 | calc_image_bbox_loss: False
73 | embed_dim: 256
74 | temp: 0.07
75 |
76 | max_words: 30
77 | max_tokens: 30
78 | mask_prob: 0.4
79 | max_masks: 10
80 |
81 | mask_whole_word: False # not implemented
82 | skipgram_prb: 0.2
83 | skipgram_size: 3
84 |
85 |
86 | ## Other Settings
87 | ckpt_frequent_step: 50000
88 | ckpt_frequent: 100000 # epoch
89 | optimizer: {opt: adamW, lr: 4e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 2e-5, text_lr: 8e-5, cross_lr: 4e-5}
90 | schedular: {sched: linear, epochs: 39, num_warmup_steps: 1000} # 400k steps
91 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O0, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
92 |
--------------------------------------------------------------------------------
/configs/pretrain/multilingual_cclm_x2vlm_large.yaml:
--------------------------------------------------------------------------------
1 | # Adapt X^2-VLM to multilingual by Cross-View Language Modeling
2 |
3 | ## Data
4 | train_file: [
5 | "path/to/cc-3m-mm-uc2",
6 | "path/to/sbu-mm",
7 | "path/to/vg-mm",
8 | "path/to/coco-mm",
9 | ] # multilingual x multimodal
10 |
11 | train_dataset_size: 5004332
12 |
13 | images: {image_key: "binary",
14 | is_image_rpath: False, # read path or base64 encoding
15 | caption_key: "caption",
16 | tokenized: False, # whether texts have been tokenized
17 | batch_size: 30, # x16 gpus
18 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
19 | iter_perc: 1.0,
20 | }
21 |
22 |
23 | train_file_regions: [
24 | 'path/to/coco_object-mm-google',
25 | 'path/to/vg_object-mm-google',
26 | 'path/to/vg_region-mm',
27 | ] # multilingual
28 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False, code_switch: True,
29 | careful_hflip: True,
30 | batch_size: 30, max_images: 14, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
31 |
32 |
33 | train_file_mtext: [
34 | "path/to/wikimatrix",
35 | "path/to/wikimatrix_en_bn",
36 | ] # multilingual parallel texts
37 | mtexts: {source_key: "source_text",
38 | target_key: "target_text",
39 | tokenized: False, # whether texts have been tokenized
40 | batch_size: 30, # x16 gpus
41 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
42 | iter_perc: 1.0,
43 | max_words: 64,
44 | max_tokens: 64,
45 | mask_prob: 0.4,
46 | max_masks: 16,
47 | }
48 |
49 |
50 | ## Vision Encoder
51 | use_beit_v2: True
52 | vision_config: 'configs/config_beit2_large.json'
53 | image_res: 224
54 | patch_size: 16
55 |
56 |
57 | ## Text Encoder (& Cross Encoder)
58 | model_type: 'CrossViewLM'
59 | text_encoder: 'data/xlm-roberta-large'
60 | text_num_hidden_layers: 24
61 | cross_encoder: 'data/bert-large-uncased'
62 | cross_num_hidden_layers: 6
63 |
64 | is_xvlm_ckpt: True # is of XVLMBase or XVLMPlusBase
65 | xvlm_ckpt_text_num_hidden_layers: 12 # if is_xvlm_ckpt
66 | replace_text_encoder: True
67 |
68 |
69 |
70 | ## Training
71 | mixed_in_batch: True
72 | calc_image_bbox_loss: False
73 | embed_dim: 256
74 | temp: 0.07
75 |
76 | max_words: 30
77 | max_tokens: 30
78 | mask_prob: 0.4
79 | max_masks: 10
80 |
81 | mask_whole_word: False # not implemented
82 | skipgram_prb: 0.2
83 | skipgram_size: 3
84 |
85 |
86 | ## Other Settings
87 | ckpt_frequent_step: 50000
88 | ckpt_frequent: 100000 # epoch
89 | optimizer: {opt: adamW, lr: 3e-5, weight_decay: 0.01, lr_mult: 2, vision_lr: 1e-5, text_lr: 6e-5, cross_lr: 3e-5}
90 | schedular: {sched: linear, epochs: 39, num_warmup_steps: 1000} # 400k steps
91 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O0, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
92 |
--------------------------------------------------------------------------------
/configs/pretrain/x2vlm_base_1b.yaml:
--------------------------------------------------------------------------------
1 | ## Data
2 | print_broken_data: False
3 |
4 | train_file: [
5 | "path/to/laion_filtered",
6 | "path/to/laion2b_filtered",
7 | ]
8 | train_dataset_size: 1323042683 # for IterableDataset
9 |
10 |
11 | train_file_aux: [
12 | "path/to/coco_testset_filtered",
13 | "path/to/vg_testset_filtered",
14 | "path/to/sbu_bs64",
15 | "path/to/cc3m_bs64",
16 | "path/to/cc12m_bs64",
17 | ] # cleaner data
18 | aux_iter_perc: 0.15 # aux_iter_perc% iterate on train_file_aux, (1-aux_iter_perc%) iterate on train_file
19 | images: {image_key: "binary",
20 | is_image_rpath: False, # read path or base64 encoding
21 | caption_key: "desc",
22 | aux_caption_key: "desc",
23 | tokenized: False, # whether texts have been tokenized
24 | batch_size: 128, # 128 x 24 = 3072
25 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
26 | }
27 |
28 | train_file_regions: [
29 | 'path/to/coco2017_obj_rmtest_2207',
30 | 'path/to/vg_attr_obj_rmtest_2207',
31 | 'path/to/vg_region_rmtest_2207',
32 | "path/to/refcoco_region_2207",
33 | "path/to/gqa_obj_2207",
34 | "path/to/flickr_obj_2207",
35 | "path/to/openimages_v6_maxrez800_obj_region_2207",
36 | "path/to/object365_obj_2207",
37 | ] # objects & regions;
38 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False,
39 | iter_perc: 0.5, batch_size: 64, max_images: 26, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
40 |
41 |
42 | ## Vision Encoder
43 | use_beit_v2: True
44 | vision_config: 'configs/config_beit2_base.json'
45 | image_res: 224
46 | patch_size: 16
47 | local_attn_depth: -1
48 |
49 |
50 | ## Text Encoder (& Cross Encoder)
51 | text_encoder: 'data/bert-base-uncased'
52 | text_num_hidden_layers: 18 # include cross
53 | text_fusion_start_at: 12
54 |
55 |
56 | ## Training
57 | mixed_in_batch: True
58 | calc_image_bbox_loss: False
59 | embed_dim: 256
60 | temp: 0.07
61 |
62 | max_words: 30
63 | max_tokens: 30
64 | mask_prob: 0.5
65 | max_masks: 12
66 | mask_whole_word: True
67 | skipgram_prb: 0.2
68 | skipgram_size: 3
69 |
70 | stop_calc_itm: 200000 # steps; matching loss calculates hard negatives causing nan loss
71 |
72 |
73 | ## Other Settings
74 | ckpt_frequent_step: 50000
75 | ckpt_frequent: 100000000 # inf
76 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
77 | schedular: {sched: linear, lr: 1e-4, epochs: 3, num_warmup_steps: 2500}
78 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/configs/pretrain/x2vlm_base_1b_stage2.yaml:
--------------------------------------------------------------------------------
1 | ## Data
2 | train_file: [
3 | "path/to/coco_testset_filtered",
4 | "path/to/vg_testset_filtered",
5 | "path/to/sbu_bs64",
6 | "path/to/cc3m_bs64",
7 | ]
8 |
9 | train_dataset_size: 5114489
10 | images: {image_key: "binary",
11 | is_image_rpath: False, # read path or base64 encoding
12 | caption_key: "desc",
13 | tokenized: False, # whether texts have been tokenized
14 | batch_size: 64, # 64 x 16 = 1024
15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
16 | }
17 |
18 |
19 | train_file_regions: [
20 | 'path/to/coco2017_obj_rmtest_2207',
21 | 'path/to/vg_attr_obj_rmtest_2207',
22 | 'path/to/vg_region_rmtest_2207',
23 | "path/to/refcoco_region_2207",
24 | "path/to/gqa_obj_2207",
25 | "path/to/flickr_obj_2207",
26 | "path/to/openimages_v6_maxrez800_obj_region_2207",
27 | "path/to/object365_obj_2207",
28 | ] # objects & regions;
29 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False,
30 | iter_perc: 1.0, batch_size: 64, max_images: 26, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
31 |
32 |
33 | train_file_videos: [
34 | "path/to/howto100m_filtered", # 1704857
35 | "path/to/ytt180m_filtered", # 5306576
36 | ]
37 | train_file_videos_aux: [
38 | "path/to/webvid2_5m", # 2492317
39 | ] # cleaner data
40 | video_aux_iter_perc: 0.35 # aux_iter_perc% iterate on train_file_videos_aux, (1-aux_iter_perc%) iterate on train_file_videos
41 |
42 |
43 | videos: {image_key: "video_frames",
44 | is_image_rpath: False, # read path or base64 encoding
45 | caption_key: "text",
46 | tokenized: False, # whether texts have been tokenized
47 | frame_len: 3, # 5 -> 3, too slow
48 | use_random_sampling: True,
49 | combine_continuous_clips: True,
50 | mininum_frames_before_sampling: 8, # 10 -> 8 since webvid has 15+ frames per video on average
51 | batch_size: 40, # 40*16=640 64 -> 48, too slow
52 | iter_perc: 1.0,
53 | num_workers: 8, # better -> the total number of training files % (world_size * num_workers) == 0
54 | }
55 |
56 |
57 | ## Vision Encoder
58 | use_beit_v2: True
59 | vision_config: 'configs/config_beit2_base.json'
60 | image_res: 224
61 | patch_size: 16
62 | local_attn_depth: -1
63 |
64 | frame_len: 3
65 | add_frame_pos: True
66 | video_encoding: 'avgpool'
67 |
68 |
69 | ## Text Encoder (& Cross Encoder)
70 | text_encoder: 'data/bert-base-uncased'
71 | text_num_hidden_layers: 18
72 | text_fusion_start_at: 12
73 |
74 |
75 | ## Training
76 | mixed_in_batch: True
77 | calc_image_bbox_loss: False
78 | embed_dim: 256
79 | temp: 0.07
80 |
81 | max_words: 40
82 | max_tokens: 40
83 | mask_prob: 0.5
84 | max_masks: 12
85 | mask_whole_word: True
86 | skipgram_prb: 0.2
87 | skipgram_size: 3
88 |
89 |
90 | ## Other Settings
91 | ckpt_frequent_step: 50000
92 | ckpt_frequent: 1000000 # epoch
93 | optimizer: {opt: adamW, lr: 6e-5, weight_decay: 0.01, lr_mult: 2}
94 | schedular: {sched: linear, lr: 6e-5, epochs: 81, num_warmup_steps: 1000} # 400k steps, video -> 27 epochs
95 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
96 |
97 |
--------------------------------------------------------------------------------
/configs/pretrain/x2vlm_base_4m.yaml:
--------------------------------------------------------------------------------
1 | ## Data
2 | train_file: [
3 | "path/to/coco_testset_filtered",
4 | "path/to/vg_testset_filtered",
5 | "path/to/sbu_bs64",
6 | "path/to/cc3m_bs64",
7 | ]
8 |
9 | train_dataset_size: 5114489 # for IterableDataset
10 | images: {image_key: "binary",
11 | is_image_rpath: False, # read path or base64 encoding
12 | caption_key: "desc",
13 | tokenized: False, # whether texts have been tokenized
14 | batch_size: 128, # 128 x 8 = 1024
15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
16 | }
17 |
18 |
19 | train_file_regions: [
20 | 'path/to/coco2017_obj_rmtest_2207',
21 | 'path/to/vg_attr_obj_rmtest_2207',
22 | 'path/to/vg_region_rmtest_2207',
23 | "path/to/refcoco_region_2207",
24 | "path/to/gqa_obj_2207",
25 | "path/to/flickr_obj_2207",
26 | ]
27 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False,
28 | iter_perc: 1, batch_size: 128, max_images: 50, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
29 |
30 |
31 | ## Vision Encoder
32 | use_beit_v2: True
33 | vision_config: 'configs/config_beit2_base.json'
34 | image_res: 224
35 | patch_size: 16
36 | local_attn_depth: -1
37 |
38 |
39 | ## Text Encoder (& Cross Encoder)
40 | text_encoder: 'data/bert-base-uncased'
41 | text_num_hidden_layers: 18 # include cross
42 | text_fusion_start_at: 12
43 |
44 |
45 | ## Training
46 | mixed_in_batch: True
47 | calc_image_bbox_loss: False
48 | embed_dim: 256
49 | temp: 0.07
50 |
51 | max_words: 40
52 | max_tokens: 40
53 | mask_prob: 0.5
54 | max_masks: 12
55 | mask_whole_word: True
56 | skipgram_prb: 0.2
57 | skipgram_size: 3
58 |
59 |
60 | ## Other Settings
61 | ckpt_frequent_step: 50000
62 | ckpt_frequent: 1000000000 # epoch
63 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.01, lr_mult: 2}
64 | schedular: {sched: linear, lr: 1e-4, epochs: 101, num_warmup_steps: 2500} # previously ran 200k steps; now it seems we should run 500k steps
65 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/configs/pretrain/x2vlm_large_1b.yaml:
--------------------------------------------------------------------------------
1 | ## Data
2 | print_broken_data: False
3 |
4 | train_file: [
5 | "path/to/laion_filtered",
6 | "path/to/laion2b_filtered",
7 | ]
8 | train_dataset_size: 1323042683 # for IterableDataset
9 |
10 |
11 | train_file_aux: [
12 | "path/to/coco_testset_filtered",
13 | "path/to/vg_testset_filtered",
14 | "path/to/sbu_bs64",
15 | "path/to/cc3m_bs64",
16 | "path/to/cc12m_bs64",
17 | ] # cleaner data
18 | aux_iter_perc: 0.15 # aux_iter_perc% iterate on train_file_aux, (1-aux_iter_perc%) iterate on train_file
19 | images: {image_key: "binary",
20 | is_image_rpath: False, # read path or base64 encoding
21 | caption_key: "desc",
22 | aux_caption_key: "desc",
23 | tokenized: False, # whether texts have been tokenized
24 | batch_size: 64, # 128 x 24 = 3072
25 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
26 | }
27 |
28 | train_file_regions: [
29 | 'path/to/coco2017_obj_rmtest_2207',
30 | 'path/to/vg_attr_obj_rmtest_2207',
31 | 'path/to/vg_region_rmtest_2207',
32 | "path/to/refcoco_region_2207",
33 | "path/to/gqa_obj_2207",
34 | "path/to/flickr_obj_2207",
35 | "path/to/openimages_v6_maxrez800_obj_region_2207",
36 | "path/to/object365_obj_2207",
37 | ] # objects & regions;
38 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False,
39 | iter_perc: 0.5, batch_size: 32, max_images: 14, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
40 |
41 |
42 | ## Vision Encoder
43 | use_beit_v2: True
44 | vision_config: 'configs/config_beit2_large.json'
45 | image_res: 224
46 | patch_size: 16
47 | local_attn_depth: -1
48 |
49 |
50 | ## Text Encoder (& Cross Encoder)
51 | text_encoder: 'data/bert-large-uncased-12l'
52 | text_num_hidden_layers: 18 # 12 + 6
53 | text_fusion_start_at: 12
54 |
55 |
56 | ## Training
57 | mixed_in_batch: True
58 | calc_image_bbox_loss: False
59 | embed_dim: 256
60 | temp: 0.07
61 |
62 | max_words: 30
63 | max_tokens: 30
64 | mask_prob: 0.5
65 | max_masks: 12
66 | mask_whole_word: True
67 | skipgram_prb: 0.2
68 | skipgram_size: 3
69 |
70 | stop_calc_itm: 200000
71 |
72 |
73 | ## Other Settings
74 | ckpt_frequent_step: 50000
75 | ckpt_frequent: 100000000 # inf
76 | optimizer: {opt: adamW, lr: 5e-5, weight_decay: 0.01, lr_mult: 2}
77 | schedular: {sched: linear, lr: 5e-5, epochs: 3, num_warmup_steps: 2500}
78 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/configs/pretrain/x2vlm_large_1b_stage2.yaml:
--------------------------------------------------------------------------------
1 | ## Data
2 | train_file: [
3 | "path/to/coco_testset_filtered",
4 | "path/to/vg_testset_filtered",
5 | "path/to/sbu_bs64",
6 | "path/to/cc3m_bs64",
7 | ]
8 |
9 | train_dataset_size: 5114489 # for IterableDataset
10 | images: {image_key: "binary",
11 | is_image_rpath: False, # read path or base64 encoding
12 | caption_key: "desc",
13 | tokenized: False, # whether texts have been tokenized
14 | batch_size: 32, # 32 x 32 = 1024
15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
16 | }
17 |
18 |
19 | train_file_regions: [
20 | 'path/to/coco2017_obj_rmtest_2207',
21 | 'path/to/vg_attr_obj_rmtest_2207',
22 | 'path/to/vg_region_rmtest_2207',
23 | "path/to/refcoco_region_2207",
24 | "path/to/gqa_obj_2207",
25 | "path/to/flickr_obj_2207",
26 | "path/to/openimages_v6_maxrez800_obj_region_2207",
27 | "path/to/object365_obj_2207",
28 | ] # objects & regions;
29 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False,
30 | iter_perc: 1.0, batch_size: 32, max_images: 14, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
31 |
32 |
33 | train_file_videos: [
34 | "path/to/howto100m_filtered", # 1704857
35 | "path/to/ytt180m_filtered", # 5306576
36 | ]
37 | train_file_videos_aux: [
38 | "hdfs://haruna/home/byte_ailab_litg/user/wangjiawei.424/dataset/webvid2_5m", # 2492317
39 | ] # cleaner data
40 | video_aux_iter_perc: 0.35 # aux_iter_perc% iterate on train_file_videos_aux, (1-aux_iter_perc%) iterate on train_file_videos
41 |
42 | videos: {image_key: "video_frames",
43 | is_image_rpath: False, # read path or base64 encoding
44 | caption_key: "text",
45 | tokenized: False, # whether texts have been tokenized
46 | frame_len: 3, # 5 -> 3, too slow
47 | use_random_sampling: True,
48 | combine_continuous_clips: True,
49 | mininum_frames_before_sampling: 8, # 10 -> 8 since webvid has 15+ frames per video on average
50 | batch_size: 20, # 20*32=640 64 -> 48, too slow
51 | iter_perc: 1.0,
52 | num_workers: 8, # better -> the total number of training files % (world_size * num_workers) == 0
53 | }
54 |
55 |
56 | ## Vision Encoder
57 | use_beit_v2: True
58 | vision_config: 'configs/config_beit2_large.json'
59 | image_res: 224
60 | patch_size: 16
61 | local_attn_depth: -1
62 |
63 | frame_len: 3
64 | add_frame_pos: True
65 | video_encoding: 'avgpool'
66 |
67 |
68 | ## Text Encoder (& Cross Encoder)
69 | text_encoder: 'data/bert-large-uncased-12l'
70 | text_num_hidden_layers: 18 # 12 + 6
71 | text_fusion_start_at: 12
72 |
73 |
74 | ## Training
75 | mixed_in_batch: True
76 | calc_image_bbox_loss: False
77 | embed_dim: 256
78 | temp: 0.07
79 |
80 | max_words: 40
81 | max_tokens: 40
82 | mask_prob: 0.5
83 | max_masks: 12
84 | mask_whole_word: True
85 | skipgram_prb: 0.2
86 | skipgram_size: 3
87 |
88 |
89 | ## Other Settings
90 | ckpt_frequent_step: 50000
91 | ckpt_frequent: 1000000 # epoch
92 | optimizer: {opt: adamW, lr: 3e-5, weight_decay: 0.01, lr_mult: 2}
93 | schedular: {sched: linear, lr: 3e-5, epochs: 81, num_warmup_steps: 1000} # 400k steps, video -> 27 epochs
94 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
95 |
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/configs/pretrain/x2vlm_large_4m.yaml:
--------------------------------------------------------------------------------
1 | ## Data
2 | train_file: [
3 | "path/to/coco_testset_filtered",
4 | "path/to/vg_testset_filtered",
5 | "path/to/sbu_bs64",
6 | "path/to/cc3m_bs64",
7 | ]
8 |
9 | train_dataset_size: 5114489 # for IterableDataset
10 | images: {image_key: "binary",
11 | is_image_rpath: False, # read path or base64 encoding
12 | caption_key: "desc",
13 | tokenized: False, # whether texts have been tokenized
14 | batch_size: 64, # 64 x 16 = 1024
15 | num_workers: 4, # better -> the total number of training files % (world_size * num_workers) == 0
16 | }
17 |
18 |
19 | train_file_regions: [
20 | 'path/to/coco2017_obj_rmtest_2207',
21 | 'path/to/vg_attr_obj_rmtest_2207',
22 | 'path/to/vg_region_rmtest_2207',
23 | "path/to/refcoco_region_2207",
24 | "path/to/gqa_obj_2207",
25 | "path/to/flickr_obj_2207",
26 | ]
27 | regions: {image_key: "binary", is_image_rpath: False, caption_key: "caption", tokenized: False,
28 | iter_perc: 1, batch_size: 64, max_images: 25, max_regions: 5, min_perc_in_image: 0.5, num_workers: 4}
29 |
30 |
31 | ## Vision Encoder
32 | use_beit_v2: True
33 | vision_config: 'configs/config_beit2_large.json'
34 | image_res: 224
35 | patch_size: 16
36 | local_attn_depth: -1
37 |
38 |
39 | ## Text Encoder (& Cross Encoder)
40 | text_encoder: 'data/bert-large-uncased-12l'
41 | text_num_hidden_layers: 18 # 12 + 6
42 | text_fusion_start_at: 12
43 |
44 |
45 | ## Training
46 | mixed_in_batch: True
47 | calc_image_bbox_loss: False
48 | embed_dim: 256
49 | temp: 0.07
50 |
51 | max_words: 40
52 | max_tokens: 40
53 | mask_prob: 0.5
54 | max_masks: 12
55 | mask_whole_word: True
56 | skipgram_prb: 0.2
57 | skipgram_size: 3
58 |
59 |
60 | ## Other Settings
61 | ckpt_frequent_step: 50000
62 | ckpt_frequent: 1000000000 # epoch
63 | optimizer: {opt: adamW, lr: 5e-5, weight_decay: 0.01, lr_mult: 2}
64 | schedular: {sched: linear, lr: 5e-5, epochs: 101, num_warmup_steps: 2500} # we use 250k steps
65 | accelerator: {SYNCBN: false, FP16_OPT_LEVEL: O1, FP16_LOSS_SCALE: dynamic, RNG_SEED: 42, GRAD_ACCUMULATE_STEPS: 1, CLIP_GRAD_NORM: 1.0}
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/dataset/dist_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
4 | # Github: https://github.com/zengyan-97/X-VLM
5 | # Copyright (c) 2022, ByteDance Inc.
6 | # All rights reserved.
7 |
8 | import sys
9 | from typing import List, Any
10 | import warnings
11 | import random
12 | from itertools import cycle
13 | import torch
14 | from torch.utils.data import IterableDataset
15 |
16 | from utils.hdfs_io import hopen, hlist_files, hexists
17 |
18 |
19 | class DistLineReadingDataset(IterableDataset): # pylint: disable=W0223
20 | """
21 | iterate a set of folders.
22 | """
23 | def __init__(self,
24 | data_path,
25 | rank: int = 0,
26 | world_size: int = 1,
27 | shuffle: bool = False,
28 | repeat: bool = False):
29 | super().__init__()
30 | self.shuffle = shuffle
31 | self.rank = rank
32 | self.world_size = world_size
33 |
34 | if isinstance(data_path, str):
35 | data_path = data_path.split(',')
36 | elif isinstance(data_path, list):
37 | pass
38 | else:
39 | raise ValueError(data_path)
40 |
41 | for p in data_path:
42 | assert hexists(p), f"not exist {p}"
43 |
44 | self.files = hlist_files(data_path)
45 | self.files = [f for f in self.files if f.find('_SUCCESS') < 0]
46 | self.is_hdfs = data_path[0].startswith('hdfs')
47 |
48 | self.repeat = repeat
49 | print('[DATA]--all dataset containing {} files.'.format(len(self.files)))
50 | if len(self.files) % self.world_size != 0:
51 | print('[DATA]--Whole dataset file num %s cannot be evenly split across world size %s ' %
52 | (len(self.files), self.world_size))
53 | sys.stdout.flush()
54 |
55 | def generate(self):
56 | if self.world_size == 1 or len(self.files) == 1:
57 | cur_dataloader_files = self.files
58 | else:
59 | cur_dataloader_files = split_shard(
60 | self.files, self.rank, self.world_size)
61 |
62 | while True:
63 | if self.shuffle:
64 | random.shuffle(cur_dataloader_files)
65 | worker_info = torch.utils.data.get_worker_info()
66 |
67 | if worker_info is not None:
68 | if len(cur_dataloader_files) % worker_info.num_workers != 0:
69 | print('[DATA]--current dataloader %s file num %s cannot be evenly split across worker num %s ' %
70 | (self.rank, len(cur_dataloader_files), worker_info.num_workers))
71 | cur_worker_files = split_shard(
72 | cur_dataloader_files, worker_info.id, worker_info.num_workers)
73 | if worker_info.id == 0:
74 | print("[DataLoader] --> Rank:{} Workers:[{} ~ {}][{}] Size of process file:{} ...".format(
75 | self.rank, 0, worker_info.num_workers - 1, worker_info.id, len(cur_dataloader_files)))
76 | else:
77 | cur_worker_files = cur_dataloader_files
78 |
79 | if self.shuffle:
80 | random.shuffle(cur_worker_files)
81 | for filepath in cur_worker_files:
82 | if self.is_hdfs:
83 | with hopen(filepath, 'r') as reader:
84 | for line in reader:
85 | yield line.decode()
86 | continue
87 | with open(filepath, 'r') as reader:
88 | for line in reader:
89 | yield line
90 |
91 | if not self.repeat:
92 | break
93 |
94 | def __iter__(self):
95 | return self.generate()
96 |
97 |
98 | def split_shard(data: List[Any], shard_idx: int, shard_size: int):
99 | num = len(data)
100 | if num < shard_size:
101 | raise RuntimeError("num:{} < shard size:{}".format(num, shard_size))
102 | start_idx = (num * shard_idx) // shard_size
103 | end_idx = (num * (shard_idx + 1)) // shard_size
104 | return data[start_idx: end_idx]
105 |
--------------------------------------------------------------------------------
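
DistLineReadingDataset shards data twice: split_shard first slices the file list by distributed rank, then generate() slices each rank's share again across dataloader workers. A small sketch of that two-level sharding on a toy file list, reusing split_shard exactly as defined above:

from dataset.dist_dataset import split_shard

files = [f"part-{i:05d}" for i in range(8)]  # toy file names, not real data

world_size, num_workers = 2, 2
for rank in range(world_size):
    rank_files = split_shard(files, rank, world_size)                    # per-rank slice
    for worker_id in range(num_workers):
        worker_files = split_shard(rank_files, worker_id, num_workers)   # per-worker slice
        print(rank, worker_id, worker_files)
# With 8 files, 2 ranks x 2 workers each read 2 disjoint files, which is why the
# pretraining config above recommends total files % (world_size * num_workers) == 0.
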
/dataset/grounding_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import math
4 | import random
5 | from random import random as rand
6 |
7 | import torch
8 | from torch.utils.data import Dataset
9 |
10 | from torchvision.transforms.functional import hflip, resize
11 |
12 | from PIL import Image
13 | from dataset.utils import pre_caption
14 | from refTools.refer_python3 import REFER
15 |
16 |
17 | class grounding_dataset(Dataset):
18 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train'):
19 | self.ann = []
20 | for f in ann_file:
21 | self.ann += json.load(open(f, 'r'))
22 | self.transform = transform
23 | self.image_root = image_root
24 | self.max_words = max_words
25 | self.mode = mode
26 |
27 | if self.mode == 'train':
28 | self.img_ids = {}
29 | n = 0
30 | for ann in self.ann:
31 | img_id = ann['image'].split('/')[-1]
32 | if img_id not in self.img_ids.keys():
33 | self.img_ids[img_id] = n
34 | n += 1
35 |
36 | def __len__(self):
37 | return len(self.ann)
38 |
39 | def __getitem__(self, index):
40 |
41 | ann = self.ann[index]
42 |
43 | image_path = os.path.join(self.image_root, ann['image'])
44 | image = Image.open(image_path).convert('RGB')
45 | image = self.transform(image)
46 |
47 | caption = pre_caption(ann['text'], self.max_words)
48 |
49 | if self.mode == 'train':
50 | img_id = ann['image'].split('/')[-1]
51 |
52 | return image, caption, self.img_ids[img_id]
53 | else:
54 | return image, caption, ann['ref_id']
55 |
56 |
57 | class grounding_dataset_bbox(Dataset):
58 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train', config=None):
59 | self.image_res = config['image_res']
60 | self.careful_hflip = config['careful_hflip']
61 |
62 | self.ann = []
63 | for f in ann_file:
64 | self.ann += json.load(open(f, 'r'))
65 | self.transform = transform
66 | self.image_root = image_root
67 | self.max_words = max_words
68 | self.mode = mode
69 |
70 | if self.mode == 'train':
71 | self.refer = REFER(config['refcoco_data'], 'refcoco+', 'unc')
72 | self.img_ids = {}
73 | n = 0
74 | for ann in self.ann:
75 | img_id = ann['image'].split('/')[-1]
76 | if img_id not in self.img_ids.keys():
77 | self.img_ids[img_id] = n
78 | n += 1
79 |
80 | def __len__(self):
81 | return len(self.ann)
82 |
83 | def left_or_right_in_caption(self, caption):
84 | if ('left' in caption) or ('right' in caption):
85 | return True
86 |
87 | return False
88 |
89 | def __getitem__(self, index):
90 |
91 | ann = self.ann[index]
92 | caption = pre_caption(ann['text'], self.max_words)
93 |
94 | image_path = os.path.join(self.image_root, ann['image'])
95 | image = Image.open(image_path).convert('RGB')
96 | W, H = image.size
97 |
98 | if self.mode == 'train':
99 | # random crop
100 | x, y, w, h = self.refer.refToAnn[ann['ref_id']]['bbox']
101 | assert (x >= 0) and (y >= 0) and (x + w <= W) and (y + h <= H) and (w > 0) and (
102 | h > 0), "elem invalid"
103 |
104 | x0, y0 = random.randint(0, math.floor(x)), random.randint(0, math.floor(y))
105 | x1, y1 = random.randint(min(math.ceil(x + w), W), W), random.randint(min(math.ceil(y + h), H),
106 | H) # fix bug: max -> min
107 | w0, h0 = x1 - x0, y1 - y0
108 | assert (x0 >= 0) and (y0 >= 0) and (x0 + w0 <= W) and (y0 + h0 <= H) and (w0 > 0) and (
109 | h0 > 0), "elem randomcrop, invalid"
110 | image = image.crop((x0, y0, x0 + w0, y0 + h0))
111 |
112 | W, H = image.size
113 |
114 | do_hflip = False
115 | if rand() < 0.5:
116 | if self.careful_hflip and self.left_or_right_in_caption(caption):
117 | pass
118 | else:
119 | image = hflip(image)
120 | do_hflip = True
121 |
122 | image = resize(image, [self.image_res, self.image_res], interpolation=Image.BICUBIC)
123 | image = self.transform(image)
124 |
125 | # axis transform: for crop
126 | x = x - x0
127 | y = y - y0
128 |
129 | if do_hflip: # flip applied
130 | x = (W - x) - w # W is w0
131 |
132 | # resize applied
133 | x = self.image_res / W * x
134 | w = self.image_res / W * w
135 | y = self.image_res / H * y
136 | h = self.image_res / H * h
137 |
138 | center_x = x + 1 / 2 * w
139 | center_y = y + 1 / 2 * h
140 |
141 | target_bbox = torch.tensor([center_x / self.image_res, center_y / self.image_res,
142 | w / self.image_res, h / self.image_res], dtype=torch.float)
143 |
144 | return image, caption, target_bbox
145 |
146 | else:
147 | image = self.transform(image) # test_transform
148 | return image, caption, ann['ref_id']
--------------------------------------------------------------------------------
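
grounding_dataset_bbox returns the target box as (cx, cy, w, h) normalized by image_res. A short sketch of mapping such a box back to pixel corners for visualization, using box_cxcywh_to_xyxy from models/box_ops.py later in this listing; the numbers are made up:

import torch
from models.box_ops import box_cxcywh_to_xyxy

image_res = 224
target_bbox = torch.tensor([[0.5, 0.5, 0.25, 0.4]])  # (cx, cy, w, h), all in [0, 1]

xyxy = box_cxcywh_to_xyxy(target_bbox) * image_res   # back to pixel coordinates
print(xyxy)  # tensor([[ 84.0000,  67.2000, 140.0000, 156.8000]]) -> (x0, y0, x1, y1)
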
/dataset/nlvr_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | from dataset.utils import pre_caption
6 |
7 |
8 | class nlvr_dataset(Dataset):
9 | def __init__(self, ann_file, transform, image_root=None):
10 | self.ann = []
11 |
12 | if isinstance(ann_file, list):
13 | for f in ann_file:
14 | self.ann += json.load(open(f, 'r'))
15 |
16 | elif isinstance(ann_file, str):
17 | self.ann += json.load(open(ann_file, 'r'))
18 |
19 | else:
20 | raise ValueError(f"ann_file == {ann_file}")
21 |
22 | self.transform = transform
23 | self.image_root = image_root
24 | self.max_words = 30
25 |
26 | def __len__(self):
27 | return len(self.ann)
28 |
29 | def __getitem__(self, index):
30 |
31 | ann = self.ann[index]
32 |
33 | if self.image_root is None:
34 | image0_path = ann['images'][0]
35 | else:
36 | image0_path = os.path.join(self.image_root, ann['images'][0])
37 |
38 | image0 = Image.open(image0_path).convert('RGB')
39 | image0 = self.transform(image0)
40 |
41 | if self.image_root is None:
42 | image1_path = ann['images'][1]
43 | else:
44 | image1_path = os.path.join(self.image_root, ann['images'][1])
45 |
46 | image1 = Image.open(image1_path).convert('RGB')
47 | image1 = self.transform(image1)
48 |
49 | sentence = pre_caption(ann['sentence'], self.max_words)
50 |
51 | if (ann['label'] == 'True') or (ann['label'] is True):
52 | label = 1
53 |
54 | elif (ann['label'] == 'False') or (ann['label'] is False):
55 | label = 0
56 |
57 | else:
58 | raise ValueError(f"unsupported label: {ann['label']}")
59 |
60 | return image0, image1, sentence, label
61 |
--------------------------------------------------------------------------------
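
For reference, the shape of one annotation record that nlvr_dataset expects, inferred from the __getitem__ accesses above; the concrete values are invented:

example_ann = {
    "images": ["dev/dev-0-0-img0.png", "dev/dev-0-0-img1.png"],  # two image paths per example
    "sentence": "The left image contains twice the number of dogs as the right image.",
    "label": "True",  # the loader also accepts boolean True/False
}
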
/dataset/retrieval_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import io
3 | import os
4 | import torch
5 | from base64 import b64decode
6 | from torch.utils.data import Dataset
7 | import random
8 | import traceback
9 |
10 | from PIL import Image
11 | from PIL import ImageFile
12 |
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | Image.MAX_IMAGE_PIXELS = None
15 |
16 | from dataset.utils import pre_caption, sample_frame_ids
17 |
18 |
19 | class re_train_dataset(Dataset):
20 | def __init__(self, ann_file, transform, image_root='', max_words=30,
21 | index_key='image_id', vision_key='image', text_key='caption',
22 | is_video=False, frame_len=1):
23 | self.ann = []
24 | for f in ann_file:
25 | self.ann += json.load(open(f, 'r'))
26 | self.transform = transform
27 | self.image_root = image_root
28 | self.max_words = max_words
29 | self.img_ids = {}
30 |
31 | self.index_key = index_key
32 | self.vision_key = vision_key
33 | self.text_key = text_key
34 | self.is_video = is_video
35 | self.frame_len = frame_len
36 | self.training = True
37 |
38 | n = 0
39 | for ann in self.ann:
40 | img_id = ann[self.index_key]
41 | if img_id not in self.img_ids.keys():
42 | self.img_ids[img_id] = n
43 | n += 1
44 |
45 | def __len__(self):
46 | return len(self.ann)
47 |
48 | def __getitem__(self, index):
49 |
50 | ann = self.ann[index]
51 | assert isinstance(ann, dict)
52 |
53 | vision_rpath = os.path.join(self.image_root, ann[self.vision_key]) if len(self.image_root) else ann[self.vision_key]
54 |
55 | if self.is_video:
56 | frames_b64 = json.load(open(vision_rpath, 'r'))
57 |
58 | selected_indices = sample_frame_ids(len(frames_b64), self.frame_len, self.training)
59 |
60 | vision_input = []
61 | for i in selected_indices:
62 | image = Image.open(io.BytesIO(b64decode(frames_b64[i]))).convert("RGB")
63 | image = self.transform(image)
64 | vision_input.append(image)
65 |
66 | else:
67 | image = Image.open(vision_rpath).convert('RGB')
68 | vision_input = self.transform(image)
69 |
70 | caption = pre_caption(ann[self.text_key], self.max_words)
71 |
72 | return vision_input, caption, self.img_ids[ann[self.index_key]]
73 |
74 | def collate_fn(self, batch):
75 | batch_tensors = []
76 | for i, x in enumerate(zip(*batch)):
77 | if x[0] is None:
78 | batch_tensors.append(None)
79 |
80 | elif isinstance(x[0], torch.Tensor):
81 | batch_tensors.append(torch.stack(x))
82 |
83 | elif isinstance(x[0], list):
84 | assert i == 0 # video frames are always the first element in the batch tuple
85 | batch_size = len(x)
86 | frames = torch.stack(sum(x, [])) # flatten
87 | _, c, h, w = frames.shape
88 | frames = frames.reshape([batch_size, self.frame_len, c, h, w])
89 | batch_tensors.append(frames)
90 |
91 | elif isinstance(x[0], str): # should be texts, put in tokenizer afterwards
92 | batch_tensors.append(x)
93 |
94 | else:
95 | batch_tensors.append(torch.tensor(x, dtype=torch.long))
96 |
97 | return batch_tensors
98 |
99 |
100 | class re_eval_dataset(Dataset):
101 | def __init__(self, ann_file, transform, image_root, max_words=30,
102 | index_key='image_id', vision_key='image', text_key='caption',
103 | is_video=False, frame_len=1, ):
104 | self.ann = json.load(open(ann_file, 'r'))
105 | self.transform = transform
106 | self.image_root = image_root
107 | self.max_words = max_words
108 |
109 | self.text = []
110 | self.image = []
111 | self.txt2img = {}
112 | self.img2txt = {}
113 |
114 | self.index_key = index_key
115 | self.vision_key = vision_key
116 | self.text_key = text_key
117 | self.is_video = is_video
118 | self.frame_len = frame_len
119 | self.training = False
120 |
121 | txt_id = 0
122 | for img_id, ann in enumerate(self.ann):
123 | self.image.append(ann[self.vision_key])
124 | self.img2txt[img_id] = []
125 |
126 | assert isinstance(ann[self.text_key], list)
127 |
128 | for i, caption in enumerate(ann[self.text_key]):
129 | self.text.append(pre_caption(caption, self.max_words))
130 | self.img2txt[img_id].append(txt_id)
131 | self.txt2img[txt_id] = img_id
132 | txt_id += 1
133 |
134 | def __len__(self):
135 | return len(self.image)
136 |
137 | def __getitem__(self, index):
138 | if len(self.image_root):
139 | image_path = os.path.join(self.image_root, self.ann[index][self.vision_key])
140 | else:
141 | image_path = self.ann[index][self.vision_key]
142 |
143 | if self.is_video:
144 | frames_b64 = json.load(open(image_path, 'r'))
145 | selected_indices = sample_frame_ids(len(frames_b64), self.frame_len, self.training)
146 |
147 | frames = []
148 | for i in selected_indices:
149 | image = Image.open(io.BytesIO(b64decode(frames_b64[i]))).convert("RGB")
150 | image = self.transform(image)
151 | frames.append(image)
152 |
153 | frames = torch.stack(frames, dim=0) # (frame_len, 3, image_res, image_res), e.g. 384x384
154 |
155 | return frames, index
156 |
157 | else:
158 | image = Image.open(image_path).convert('RGB')
159 | image = self.transform(image)
160 |
161 | return image, index
--------------------------------------------------------------------------------
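
re_train_dataset.collate_fn treats video samples specially: each sample's list of frame tensors is flattened, stacked, and reshaped to (B, frame_len, C, H, W), while captions stay as plain strings and indices become a long tensor. The same bookkeeping on toy tensors (sizes are illustrative):

import torch

frame_len, c, h, w = 3, 3, 224, 224
batch = [
    ([torch.randn(c, h, w) for _ in range(frame_len)], "a caption", 0),
    ([torch.randn(c, h, w) for _ in range(frame_len)], "another caption", 1),
]

frames = torch.stack(sum([x[0] for x in batch], []))      # (B*frame_len, C, H, W)
frames = frames.reshape(len(batch), frame_len, c, h, w)   # (B, frame_len, C, H, W)
captions = [x[1] for x in batch]
img_ids = torch.tensor([x[2] for x in batch], dtype=torch.long)
print(frames.shape, captions, img_ids)  # torch.Size([2, 3, 3, 224, 224]) ...
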
/dataset/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import BertTokenizer, RobertaTokenizer, XLMRobertaTokenizer, AutoTokenizer
2 | from dataset.tokenizers.bert_tokenizer_with_dropout import BertTokenizerWithDropout
3 |
4 |
5 | def build_tokenizer(text_encoder: str, dropout=0):
6 | if ('bert-base-uncased' in text_encoder) or ('bert-large-uncased' in text_encoder):
7 | if dropout > 0:
8 | tokenizer = BertTokenizerWithDropout.from_pretrained(text_encoder, dropout=dropout)
9 | else:
10 | tokenizer = BertTokenizer.from_pretrained(text_encoder)
11 |
12 | elif ('xlm-roberta-base' in text_encoder) or ('xlm-roberta-large' in text_encoder):
13 | tokenizer = XLMRobertaTokenizer.from_pretrained(text_encoder)
14 |
15 | elif ('roberta-base' in text_encoder) or ('roberta-large' in text_encoder):
16 | tokenizer = RobertaTokenizer.from_pretrained(text_encoder)
17 |
18 | else:
19 | raise NotImplementedError(f"tokenizer for {text_encoder}")
20 |
21 | # always use cls and sep
22 | tokenizer.add_special_tokens({'bos_token': tokenizer.cls_token})
23 | tokenizer.add_special_tokens({'eos_token': tokenizer.sep_token})
24 |
25 | return tokenizer
--------------------------------------------------------------------------------
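
Typical use of build_tokenizer; here 'bert-base-uncased' is fetched from the HuggingFace hub purely for illustration, whereas the configs in this repo point to local copies such as data/bert-large-uncased-12l:

from dataset.tokenizers import build_tokenizer

tokenizer = build_tokenizer('bert-base-uncased')               # plain BertTokenizer
print(tokenizer.bos_token, tokenizer.eos_token)                # [CLS] [SEP]

tok_drop = build_tokenizer('bert-base-uncased', dropout=0.2)   # BertTokenizerWithDropout
print(tok_drop.tokenize("unaffable"))  # segmentation is now stochastic and varies between calls
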
/dataset/tokenizers/bert_tokenizer_with_dropout.py:
--------------------------------------------------------------------------------
1 | from transformers import BertTokenizer
2 | import random
3 |
4 | class BertTokenizerWithDropout(BertTokenizer):
5 | def __init__(
6 | self,
7 | vocab_file,
8 | do_lower_case=True,
9 | do_basic_tokenize=True,
10 | never_split=None,
11 | unk_token="[UNK]",
12 | sep_token="[SEP]",
13 | pad_token="[PAD]",
14 | cls_token="[CLS]",
15 | mask_token="[MASK]",
16 | tokenize_chinese_chars=True,
17 | **kwargs
18 | ):
19 | """Constructs a BertTokenizerWithDropout.
20 | Args:
21 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
22 | **do_lower_case**: (`optional`) boolean (default True)
23 | Whether to lower case the input
24 | Only has an effect when do_basic_tokenize=True
25 | **do_basic_tokenize**: (`optional`) boolean (default True)
26 | Whether to do basic tokenization before wordpiece.
27 | **never_split**: (`optional`) list of string
28 | List of tokens which will never be split during tokenization.
29 | Only has an effect when do_basic_tokenize=True
30 | **tokenize_chinese_chars**: (`optional`) boolean (default True)
31 | Whether to tokenize Chinese characters.
32 | This should likely be deactivated for Japanese:
33 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
34 | """
35 | super().__init__(
36 | vocab_file,
37 | do_lower_case=do_lower_case,
38 | do_basic_tokenize=do_basic_tokenize,
39 | never_split=never_split,
40 | unk_token=unk_token,
41 | sep_token=sep_token,
42 | pad_token=pad_token,
43 | cls_token=cls_token,
44 | mask_token=mask_token,
45 | tokenize_chinese_chars=tokenize_chinese_chars,
46 | **kwargs
47 | )
48 | dropout = kwargs.get("dropout", 0)
49 | assert 0 <= dropout <= 1
50 | self.wordpiece_tokenizer = WordpieceTokenizerWithDropout(vocab=self.vocab, unk_token=self.unk_token, dropout=dropout)
51 |
52 |
53 | class WordpieceTokenizerWithDropout(object):
54 | """Runs WordPiece tokenization with dropout."""
55 |
56 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100, dropout=0):
57 | self.vocab = vocab
58 | self.unk_token = unk_token
59 | self.max_input_chars_per_word = max_input_chars_per_word
60 | self.dropout = dropout
61 |
62 | def tokenize(self, text, **kwargs):
63 | """Tokenizes a piece of text into its word pieces.
64 | This uses a greedy longest-match-first algorithm to perform tokenization
65 | using the given vocabulary.
66 | For example:
67 | input = "unaffable"
68 | output = ["un", "##aff", "##able"]
69 | Args:
70 | text: A single token or whitespace separated tokens. This should have
71 | already been passed through `BasicTokenizer`.
72 | Returns:
73 | A list of wordpiece tokens.
74 | """
75 |
76 | output_tokens = []
77 | for token in whitespace_tokenize(text):
78 | chars = list(token)
79 | if len(chars) > self.max_input_chars_per_word:
80 | output_tokens.append(self.unk_token)
81 | continue
82 | if self.dropout == 1:
83 | output_tokens.append(chars[0])
84 | output_tokens.extend("##{}".format(char) for char in chars[1:])
85 | continue
86 | is_bad = False
87 | start = 0
88 | sub_tokens = []
89 | while start < len(chars):
90 | end = len(chars)
91 | cur_substr = None
92 | while start < end:
93 | substr = "".join(chars[start:end])
94 | if start > 0:
95 | substr = "##" + substr
96 | if substr in self.vocab and random.random() >= self.dropout:
97 | cur_substr = substr
98 | break
99 | end -= 1
100 | if cur_substr is None:
101 | is_bad = True
102 | break
103 | sub_tokens.append(cur_substr)
104 | start = end
105 |
106 | if is_bad:
107 | output_tokens.append(self.unk_token)
108 | else:
109 | output_tokens.extend(sub_tokens)
110 | return output_tokens
111 |
112 |
113 | def whitespace_tokenize(text):
114 | """Runs basic whitespace cleaning and splitting on a piece of text."""
115 | text = text.strip()
116 | if not text:
117 | return []
118 | tokens = text.split()
119 | return tokens
120 |
--------------------------------------------------------------------------------
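
The WordPiece dropout can be exercised directly with a toy vocabulary: with dropout=0 the tokenizer behaves like standard greedy longest-match WordPiece, while dropout>0 occasionally skips a vocabulary match and falls back to shorter pieces (or [UNK]), in the spirit of BPE-dropout. A self-contained sketch:

import random
from dataset.tokenizers.bert_tokenizer_with_dropout import WordpieceTokenizerWithDropout

vocab = {"un": 0, "##aff": 1, "##able": 2, "u": 3, "##n": 4,
         "##a": 5, "##f": 6, "##b": 7, "##l": 8, "##e": 9}

wp = WordpieceTokenizerWithDropout(vocab=vocab, unk_token="[UNK]", dropout=0.0)
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able'] -- deterministic greedy matching

random.seed(0)
wp_drop = WordpieceTokenizerWithDropout(vocab=vocab, unk_token="[UNK]", dropout=0.3)
for _ in range(3):
    print(wp_drop.tokenize("unaffable"))  # varies between calls; may also fall back to ['[UNK]']
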
/dataset/wit_dataset.py:
--------------------------------------------------------------------------------
1 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621)
2 | # Github: https://github.com/zengyan-97/CCLM
3 | # Copyright (c) 2022, ByteDance Inc.
4 | # All rights reserved.
5 |
6 | import json
7 | import os
8 | from collections import OrderedDict
9 |
10 | from torch.utils.data import Dataset
11 |
12 | from PIL import Image
13 | from PIL import ImageFile
14 | import base64
15 | import io
16 | from tqdm import tqdm
17 |
18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
19 | Image.MAX_IMAGE_PIXELS = None
20 |
21 | from dataset.utils import pre_caption
22 |
23 |
24 | class wit_train_dataset(Dataset):
25 | def __init__(self, ann_file, transform, max_words=80):
26 | self.ann = []
27 | for f in ann_file:
28 | for line in tqdm(open(f)):
29 | ann = json.loads(line)
30 | if not ann['caption_reference_description']:
31 | continue
32 | self.ann.append(ann)
33 | self.transform = transform
34 | self.max_words = max_words
35 | self.img_ids = {}
36 |
37 | n = 0
38 | for ann in self.ann:
39 | img_id = ann['image_url']
40 | if img_id not in self.img_ids.keys():
41 | self.img_ids[img_id] = n
42 | n += 1
43 |
44 | def __len__(self):
45 | return len(self.ann)
46 |
47 | def __getitem__(self, index):
48 |
49 | ann = self.ann[index]
50 |
51 | image_str = base64.b64decode(ann['image_content'])
52 | image = Image.open(io.BytesIO(image_str)).convert("RGB")
53 | image = self.transform(image)
54 | try:
55 | caption = pre_caption(ann['caption_reference_description'], self.max_words)
56 | except Exception:
57 | caption = ann['caption_reference_description']
58 |
59 | return image, caption, self.img_ids[ann['image_url']]
60 |
61 |
62 | class wit_eval_dataset(Dataset):
63 | def __init__(self, ann_file, transform, max_words=80):
64 | self.ann = []
65 | for line in open(ann_file, 'r'):
66 | ann = json.loads(line)
67 | if not ann['caption_reference_description']:
68 | continue
69 | self.ann.append(ann)
70 | self.transform = transform
71 | self.max_words = max_words
72 | self.text = []
73 | self.image = OrderedDict()
74 | self.txt2img = {}
75 | self.img2txt = {}
76 |
77 | txt_id = 0
78 | img_id = 0
79 | for ann in self.ann:
80 | if ann['image_url'] in self.image:
81 | cur_img_id = self.image[ann['image_url']][0]
82 | self.img2txt[cur_img_id].append(txt_id)
83 | self.txt2img[txt_id] = cur_img_id
84 | else:
85 | self.img2txt[img_id] = [txt_id]
86 | self.image[ann['image_url']] = (img_id, ann['image_content'])
87 | self.txt2img[txt_id] = img_id
88 | img_id += 1
89 | if ann['caption_reference_description'] == '.':
90 | self.text.append(ann['caption_reference_description'])
91 | else:
92 | self.text.append(pre_caption(ann['caption_reference_description'], self.max_words))
93 | txt_id += 1
94 |
95 | def __len__(self):
96 | return len(self.image)
97 |
98 | def __getitem__(self, index):
99 | image_str = base64.b64decode(list(self.image.values())[index][1])
100 | image = Image.open(io.BytesIO(image_str)).convert("RGB")
101 | image = self.transform(image)
102 |
103 | return image, index
104 |
--------------------------------------------------------------------------------
/dataset/xflickrco_dataset.py:
--------------------------------------------------------------------------------
1 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621)
2 | # Github: https://github.com/zengyan-97/CCLM
3 | # Copyright (c) 2022, ByteDance Inc.
4 | # All rights reserved.
5 |
6 | import json
7 | import os
8 |
9 | from torch.utils.data import Dataset
10 |
11 | from PIL import Image
12 | from PIL import ImageFile
13 |
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | Image.MAX_IMAGE_PIXELS = None
16 |
17 | from dataset.utils import pre_caption
18 |
19 |
20 | class xflickrco_train_dataset(Dataset):
21 | def __init__(self, ann_file, transform, image_root, max_words=80):
22 | self.ann = []
23 | for f in ann_file:
24 | for line in open(f):
25 | ann = json.loads(line)
26 | for i in range(len(ann['sentences'])):
27 | self.ann.append({
28 | 'caption': ann['sentences'][i],
29 | 'id': ann['id'],
30 | 'img_path': ann['img_path']})
31 | self.transform = transform
32 | self.image_root = image_root
33 | self.max_words = max_words
34 | self.img_ids = {}
35 |
36 | n = 0
37 | for ann in self.ann:
38 | img_id = ann['id']
39 | if img_id not in self.img_ids.keys():
40 | self.img_ids[img_id] = n
41 | n += 1
42 |
43 | def __len__(self):
44 | return len(self.ann)
45 |
46 | def __getitem__(self, index):
47 |
48 | ann = self.ann[index]
49 |
50 | image_path = os.path.join(self.image_root, ann['img_path'])
51 | image = Image.open(image_path).convert('RGB')
52 | image = self.transform(image)
53 |
54 | caption = pre_caption(ann['caption'], self.max_words)
55 |
56 | return image, caption, self.img_ids[ann['id']]
57 |
58 |
59 | class xflickrco_eval_dataset(Dataset):
60 | def __init__(self, ann_file, transform, image_root, max_words=80):
61 | self.ann = []
62 | for line in open(ann_file, 'r'):
63 | ann = json.loads(line)
64 |
65 | # skip samples whose captions are all empty
66 | empty = True
67 | for sent in ann['sentences']:
68 | if sent:
69 | empty = False
70 | break
71 | if empty:
72 | print(ann_file, 'has an empty caption')
73 | continue
74 |
75 | self.ann.append(ann)
76 | self.transform = transform
77 | self.image_root = image_root
78 | self.max_words = max_words
79 |
80 | self.text = []
81 | self.image = []
82 | self.txt2img = {}
83 | self.img2txt = {}
84 |
85 | txt_id = 0
86 | for img_id, ann in enumerate(self.ann):
87 | self.image.append(ann['img_path'])
88 | self.img2txt[img_id] = []
89 | for i, caption in enumerate(ann['sentences']):
90 | self.text.append(pre_caption(caption, self.max_words))
91 | self.img2txt[img_id].append(txt_id)
92 | self.txt2img[txt_id] = img_id
93 | txt_id += 1
94 |
95 | def __len__(self):
96 | return len(self.image)
97 |
98 | def __getitem__(self, index):
99 | if 'COCO' in self.image[index]:
100 | image_path = os.path.join(self.image_root['coco'], self.image[index])
101 | else:
102 | image_path = os.path.join(self.image_root['flickr30k'], self.image[index])
103 | image = Image.open(image_path).convert('RGB')
104 | image = self.transform(image)
105 |
106 | return image, index
107 |
--------------------------------------------------------------------------------
/dataset/xvnli_dataset.py:
--------------------------------------------------------------------------------
1 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621)
2 | # Github: https://github.com/zengyan-97/CCLM
3 | # Copyright (c) 2022, ByteDance Inc.
4 | # All rights reserved.
5 |
6 | import json
7 | import os
8 | from torch.utils.data import Dataset
9 | from PIL import Image
10 | from dataset.utils import pre_caption
11 |
12 |
13 | class xvnli_dataset(Dataset):
14 | def __init__(self, ann_file, transform, image_root, max_words=80):
15 | self.label_mapper = {"contradiction": 0, "entailment": 1, "neutral": 2}
16 |
17 | self.ann = []
18 |
19 | if isinstance(ann_file, str):
20 | ann_file = [ann_file]
21 |
22 | invalid_cnt = 0
23 | for f in ann_file:
24 | for line in open(f, 'r'):
25 | ann = json.loads(line)
26 | if ann['gold_label'] not in self.label_mapper:
27 | invalid_cnt += 1
28 | continue
29 | self.ann.append(ann)
30 |
31 | if not self.ann:
32 | raise ValueError(f"ann_file == {ann_file}")
33 |
34 | print('data num: ', len(self.ann))
35 | print('invalid num: ', invalid_cnt)
36 |
37 | self.transform = transform
38 | self.image_root = image_root
39 | self.max_words = max_words
40 |
41 | def __len__(self):
42 | return len(self.ann)
43 |
44 | def __getitem__(self, index):
45 | ann = self.ann[index]
46 |
47 | image_path = os.path.join(self.image_root, ann['Flikr30kID']+'.jpg')
48 | image = Image.open(image_path).convert('RGB')
49 | image = self.transform(image)
50 |
51 | sentence = pre_caption(ann['sentence2'], self.max_words)
52 |
53 | label = self.label_mapper[ann['gold_label']]
54 |
55 | return image, sentence, label
56 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.xvlm import XVLMBase
2 | from models.xvlm import build_mlp
3 | from models.xvlm import load_pretrained
--------------------------------------------------------------------------------
/models/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x): # this one is used
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30 |
31 | wh = (rb - lt).clamp(min=0) # [N,M,2]
32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33 |
34 | union = area1[:, None] + area2 - inter
35 |
36 | iou = inter / union
37 | return iou, union
38 |
39 |
40 | def generalized_box_iou(boxes1, boxes2):
41 | """
42 | Generalized IoU from https://giou.stanford.edu/
43 |
44 | The boxes should be in [x0, y0, x1, y1] format
45 |
46 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
47 | and M = len(boxes2)
48 | """
49 | iou, union = box_iou(boxes1, boxes2)
50 |
51 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
52 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
53 |
54 | wh = (rb - lt).clamp(min=0) # [N,M,2]
55 | area = wh[:, :, 0] * wh[:, :, 1]
56 |
57 | return iou - (area - union) / area
58 |
59 |
60 |
--------------------------------------------------------------------------------
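
A quick sanity check of the two IoU helpers above: identical boxes give IoU and GIoU of 1, while disjoint boxes have zero IoU but a negative GIoU that penalizes the empty enclosing area:

import torch
from models.box_ops import box_iou, generalized_box_iou

boxes1 = torch.tensor([[0., 0., 10., 10.]])
boxes2 = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])

iou, _ = box_iou(boxes1, boxes2)
giou = generalized_box_iou(boxes1, boxes2)
print(iou)   # tensor([[1., 0.]])
print(giou)  # tensor([[ 1.0000, -0.7778]]) -- second pair penalized by the empty enclosing box
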
/models/model_classification.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import json
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.nn import MSELoss
9 |
10 | from einops import rearrange
11 |
12 | from models.xvlm import XVLMBase, XVLMPlusBase
13 | from models.xvlm import build_mlp
14 |
15 |
16 | class XVLMForClassification(XVLMBase):
17 | def __init__(self, config):
18 | super().__init__(config, load_vision_params=False, load_text_params=False,
19 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False)
20 |
21 | feature_dim = self.vision_width if config.get('task_name') == 'imagenet' else self.text_width
22 | self.cls_head = build_mlp(input_dim=feature_dim, output_dim=config['num_labels'])
23 |
24 | def forward(self, image, text_ids, text_atts, targets=None, train=True):
25 | if image is None:
26 | output_cls = self.get_text_embeds_12L(text_ids, text_atts)[:, 0, :]
27 |
28 | elif text_ids is None:
29 | image_embeds, _ = self.get_vision_embeds(image)
30 | output_cls = image_embeds[:, 0, :]
31 |
32 | else:
33 | image_embeds, image_atts = self.get_vision_embeds(image)
34 |
35 | output_cls = self.get_cross_embeds(image_embeds, image_atts,
36 | text_ids=text_ids, text_atts=text_atts)[:, 0, :]
37 |
38 | prediction = self.cls_head(output_cls)
39 | if prediction.shape[-1] == 1:
40 | # We are doing regression
41 | loss_fct = MSELoss()
42 | return loss_fct(prediction.view(-1), targets.view(-1)) if train else prediction
43 |
44 | return F.cross_entropy(prediction, targets, ignore_index=-100) if train else prediction
45 |
46 |
47 | class XVLMForVQAClassification(XVLMBase):
48 | def __init__(self, config):
49 | super().__init__(config, load_vision_params=False, load_text_params=False,
50 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False)
51 |
52 | self.cls_head = build_mlp(input_dim=self.text_width, output_dim=config['num_labels'])
53 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()]
54 |
55 | def forward(self, image, text_ids, text_atts, targets=None, k=None, weights=None, train=True, answer_pred=None,
56 | return_logits=False):
57 |
58 | image_embeds, image_atts = self.get_vision_embeds(image)
59 | output_cls = self.get_cross_embeds(image_embeds, image_atts,
60 | text_ids=text_ids, text_atts=text_atts)[:, 0, :]
61 |
62 | prediction = self.cls_head(output_cls)
63 | if train:
64 | if answer_pred is not None:
65 | self.criterion = nn.KLDivLoss(reduction='none')
66 | log_probs = F.log_softmax(prediction, -1)
67 | answer_label = F.softmax(answer_pred, dim=-1)
68 | loss = self.criterion(log_probs, answer_label)
69 | loss = loss.sum() / image.size(0)
70 | return loss
71 |
72 | p_states = []
73 | for b, n in enumerate(k):
74 | p_states = p_states + [prediction[b]] * n
75 |
76 | p_states = torch.stack(p_states, 0)
77 |
78 | loss = F.cross_entropy(p_states, targets, ignore_index=-100, reduction='none')
79 |
80 | loss = weights * loss
81 | loss = loss.sum() / image.size(0)
82 | if return_logits:
83 | return loss, prediction
84 | return loss
85 | else:
86 | return prediction
87 |
88 |
89 | class XVLMForNLVR(XVLMBase):
90 | """
91 | Follow VLMo
92 | """
93 | def __init__(self, config):
94 | super().__init__(config, load_vision_params=False, load_text_params=False,
95 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False,
96 | config_text=None)
97 |
98 | self.cls_head = build_mlp(input_dim=self.text_width * 2, output_dim=2)
99 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()]
100 |
101 | def forward(self, image, text_ids, text_atts, targets, train=True):
102 | image_embeds, image_atts = self.get_vision_embeds(image)
103 | image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))
104 |
105 | output_cls_image1 = self.get_cross_embeds(image0_embeds, image_atts[:image0_embeds.size(0)],
106 | text_ids=text_ids, text_atts=text_atts)[:, 0, :]
107 |
108 | output_cls_image2 = self.get_cross_embeds(image1_embeds, image_atts[image0_embeds.size(0):],
109 | text_ids=text_ids, text_atts=text_atts)[:, 0, :]
110 |
111 | output_cls = torch.cat((output_cls_image1, output_cls_image2), dim=-1)
112 |
113 | assert output_cls.shape[-1] == self.text_width * 2
114 |
115 | prediction = self.cls_head(output_cls)
116 |
117 | return F.cross_entropy(prediction, targets) if train else prediction
118 |
119 |
120 | class XVLMPlus4XVNLI(XVLMPlusBase):
121 | def __init__(self, config):
122 | super().__init__(config, load_vision_params=False, load_text_params=False,
123 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False)
124 |
125 | self.cls_head = build_mlp(input_dim=self.text_width, output_dim=config['num_labels'])
126 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()]
127 |
128 | def forward(self, image, text_ids, text_atts, targets, train=True):
129 | image_embeds, image_atts = self.get_vision_embeds(image)
130 | text_embeds = self.get_cross_embeds(image_embeds, image_atts, text_ids, text_atts=text_atts)
131 | prediction = self.cls_head(text_embeds[:, 0, :])
132 |
133 | return F.cross_entropy(prediction, targets) if train else prediction
134 |
135 |
136 | class XVLMPlusForMARVL(XVLMPlusBase):
137 | def __init__(self, config):
138 | super().__init__(config, load_vision_params=False, load_text_params=False,
139 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False)
140 |
141 | self.cls_head = build_mlp(input_dim=self.text_width * 2, output_dim=2)
142 | self.init_params = ['cls_head.' + n for n, _ in self.cls_head.named_parameters()]
143 |
144 | def forward(self, image, text_ids, text_atts, targets, train=True):
145 | image_embeds, image_atts = self.get_vision_embeds(image)
146 | image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))
147 |
148 | output_cls_image1 = self.get_cross_embeds(image0_embeds, image_atts[:image0_embeds.size(0)],
149 | text_ids=text_ids, text_atts=text_atts)[:, 0, :]
150 |
151 | output_cls_image2 = self.get_cross_embeds(image1_embeds, image_atts[image0_embeds.size(0):],
152 | text_ids=text_ids, text_atts=text_atts)[:, 0, :]
153 |
154 | output_cls = torch.cat((output_cls_image1, output_cls_image2), dim=-1)
155 |
156 | assert output_cls.shape[-1] == self.text_width * 2
157 |
158 | prediction = self.cls_head(output_cls)
159 |
160 | return F.cross_entropy(prediction, targets) if train else prediction
161 |
162 |
--------------------------------------------------------------------------------
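
In the train branch of XVLMForVQAClassification.forward, each question's logits are repeated once per annotated answer (k[b] copies), cross-entropy is computed per answer, and per-answer weights rescale the sum. That bookkeeping in isolation, with invented numbers:

import torch
import torch.nn.functional as F

batch_size, num_labels = 2, 5
prediction = torch.randn(batch_size, num_labels)   # stand-in for the cls_head output
k = [2, 3]                                         # number of annotated answers per question
targets = torch.tensor([1, 4, 0, 0, 2])            # one label per annotated answer (sum(k) entries)
weights = torch.tensor([0.6, 0.4, 0.5, 0.3, 0.2])  # per-answer weights

p_states = torch.stack([prediction[b] for b, n in enumerate(k) for _ in range(n)], 0)
loss = F.cross_entropy(p_states, targets, ignore_index=-100, reduction='none')
loss = (weights * loss).sum() / batch_size
print(p_states.shape, loss)  # torch.Size([5, 5]) and a scalar loss
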
/models/model_grounding.py:
--------------------------------------------------------------------------------
1 | from models import XVLMBase, load_pretrained
2 |
3 |
4 | class XVLMForGrounding(XVLMBase):
5 | def __init__(self, config):
6 | super().__init__(config, load_vision_params=False, load_text_params=False,
7 | use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=True)
8 | self.init_params = []
9 |
10 | def load_pretrained(self, ckpt_rpath, config, load_bbox_pretrain=False, is_eval=False):
11 | print("### load_bbox_pretrain, ", load_bbox_pretrain)
12 | state_dict = load_pretrained(self, ckpt_rpath, config, is_eval=is_eval, load_text=True)
13 | msg = self.load_state_dict(state_dict, strict=False)
14 | print('load checkpoint from %s' % ckpt_rpath)
15 | print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p])
16 | print("unexpected_keys: ", msg.unexpected_keys)
17 |
18 | def forward(self, image, text_ids, text_atts, target_bbox=None):
19 | image_embeds, _ = self.get_vision_embeds(image)
20 | text_embeds = self.get_text_embeds(text_ids, text_atts)
21 |
22 | output_coord = self.predict_bbox(image_embeds, text_embeds, text_atts)
23 | # output_coord & target_bbox: 64, 4
24 |
25 | if target_bbox is None:
26 | return output_coord
27 |
28 | loss_bbox, loss_giou = self.get_bbox_loss(output_coord, target_bbox)
29 |
30 | return output_coord, loss_bbox, loss_giou
31 |
32 |
--------------------------------------------------------------------------------
/models/model_pretrain.py:
--------------------------------------------------------------------------------
1 | # X^2-VLM: All-In-One Pre-trained Model For Vision-Language Tasks (https://arxiv.org/abs/2211.12402)
2 | # Github: https://github.com/zengyan-97/X2-VLM
3 | # Copyright (c) 2022, ByteDance Inc.
4 | # All rights reserved.
5 |
6 | # Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training (https://arxiv.org/abs/2206.00621)
7 | # Github: https://github.com/zengyan-97/CCLM
8 | # Copyright (c) 2022, ByteDance Inc.
9 | # All rights reserved.
10 |
11 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
12 | # Github: https://github.com/zengyan-97/X-VLM
13 | # Copyright (c) 2022, ByteDance Inc.
14 | # All rights reserved.
15 |
16 | import os
17 | import json
18 | import torch
19 | from einops import rearrange
20 |
21 | from models.xvlm import XVLMBase, XVLMPlusBase, VanillaConfig
22 |
23 |
24 | class XVLM(XVLMBase):
25 | def __init__(self, config, load_vision_params=True, load_text_params=True, pretraining=True):
26 | super().__init__(config, load_vision_params=load_vision_params, load_text_params=load_text_params,
27 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=True, use_bbox_loss=True,
28 | config_text=None, pretraining=pretraining)
29 |
30 | def forward_multimodal(self, image, text_ids, text_atts, text_ids_masked=None, masked_pos=None, masked_ids=None,
31 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None,
32 | ret_bbox_loss=False, ret_match_loss=True):
33 |
34 | if ret_bbox_loss:
35 | image_embeds, image_atts, image_embeds_fullatts = \
36 | self.get_vision_embeds(image, image_atts=image_atts, idx_to_group_img=idx_to_group_img)
37 | else:
38 | image_embeds, image_atts = self.get_vision_embeds(image)
39 |
40 | text_embeds = self.get_text_embeds(text_ids, text_atts)
41 |
42 | # with torch.no_grad():  # fix: the temperature clamp was moved into the batch-iteration loop, so it runs once per iteration
43 | # self.temp.clamp_(0.001, 0.5)
44 |
45 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
46 |
47 | loss_itc = self.get_contrastive_loss(image_feat, text_feat)
48 |
49 | if ret_match_loss:
50 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat)
51 | else:
52 | loss_itm = torch.tensor(0.0)
53 |
54 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids)
55 |
56 | loss = {'loss_itc': loss_itc, 'loss_itm': loss_itm, 'loss_mlm': loss_mlm}
57 |
58 | if ret_bbox_loss:
59 | output_coord = self.predict_bbox(image_embeds_fullatts, text_embeds, text_atts)
60 | loss_bbox, loss_giou = self.get_bbox_loss(output_coord, target_bbox, is_image=is_image)
61 |
62 | loss['loss_bbox'] = loss_bbox
63 | loss['loss_giou'] = loss_giou
64 |
65 | return loss
66 |
67 | def forward_text(self, text_ids=None, text_atts=None,
68 | text_ids_masked=None, masked_pos=None, masked_ids=None):
69 |
70 | loss = self.get_mlm_loss(text_ids_masked, text_atts, None, None, masked_pos, masked_ids)
71 |
72 | return {'loss_mlm': loss}
73 |
74 | def forward(self, image=None, text_ids=None, text_atts=None,
75 | text_ids_masked=None, masked_pos=None, masked_ids=None,
76 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None,
77 | ret_bbox_loss=False, ret_match_loss=True):
78 |
79 | if image is None: # text
80 | loss = self.forward_text(text_ids, text_atts, text_ids_masked,
81 | masked_pos, masked_ids)
82 |
83 | else:
84 | loss = self.forward_multimodal(image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids,
85 | image_atts, idx_to_group_img, target_bbox, is_image, ret_bbox_loss,
86 | ret_match_loss=ret_match_loss)
87 |
88 | return loss
89 |
90 |
91 | class XVLMPlus(XVLMPlusBase):
92 | def __init__(self, config, use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=True, use_bbox_loss=True,
93 | load_vision_params=True, load_text_params=True, load_cross_params=True, pretraining=True):
94 | super().__init__(config, use_contrastive_loss=use_contrastive_loss, use_matching_loss=use_matching_loss,
95 | use_mlm_loss=use_mlm_loss, use_bbox_loss=use_bbox_loss,
96 | load_vision_params=load_vision_params, load_text_params=load_text_params, load_cross_params=load_cross_params,
97 | pretraining=pretraining)
98 |
99 | def forward_multimodal(self, image, text_ids, text_atts, text_ids_masked=None, masked_pos=None, masked_ids=None,
100 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None,
101 | ret_bbox_loss=False, ret_match_loss=True):
102 |
103 | if ret_bbox_loss:
104 | image_embeds, image_atts, image_embeds_fullatts = \
105 | self.get_vision_embeds(image, image_atts=image_atts, idx_to_group_img=idx_to_group_img)
106 | else:
107 | image_embeds, image_atts = self.get_vision_embeds(image)
108 |
109 | text_embeds = self.get_text_embeds(text_ids, text_atts)
110 |
111 | # with torch.no_grad():  # fix: the temperature clamp was moved into the batch-iteration loop, so it runs once per iteration
112 | # self.temp.clamp_(0.001, 0.5)
113 |
114 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
115 |
116 | loss_itc = self.get_contrastive_loss(image_feat, text_feat)
117 |
118 | if ret_match_loss:
119 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat)
120 | else:
121 | loss_itm = torch.tensor(0.0)
122 |
123 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids)
124 |
125 | loss = {'loss_itc': loss_itc, 'loss_itm': loss_itm, 'loss_mlm': loss_mlm}
126 |
127 | if ret_bbox_loss:
128 | output_coord = self.predict_bbox(image_embeds_fullatts, text_embeds, text_atts)
129 | loss_bbox, loss_giou = self.get_bbox_loss(output_coord, target_bbox, is_image=is_image)
130 |
131 | loss['loss_bbox'] = loss_bbox
132 | loss['loss_giou'] = loss_giou
133 |
134 | return loss
135 |
136 | def forward(self, image=None, text_ids=None, text_atts=None,
137 | text_ids_masked=None, masked_pos=None, masked_ids=None,
138 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None,
139 | ret_bbox_loss=False, ret_match_loss=True):
140 |
141 | loss = self.forward_multimodal(image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids,
142 | image_atts, idx_to_group_img, target_bbox, is_image, ret_bbox_loss,
143 | ret_match_loss=ret_match_loss)
144 |
145 | return loss
146 |
147 |
148 | class CrossViewLM(XVLMPlus): # Multilingual x Multimodal Pre-training
149 | """
150 | Cross-View Language Modeling: Towards Unified Cross-Lingual Cross-Modal Pre-training
151 | https://arxiv.org/abs/2206.00621
152 | """
153 | def __init__(self, config, use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=True, use_bbox_loss=True,
154 | load_vision_params=True, load_text_params=True, load_cross_params=True, pretraining=True):
155 | super().__init__(config, use_contrastive_loss=use_contrastive_loss, use_matching_loss=use_matching_loss,
156 | use_mlm_loss=use_mlm_loss, use_bbox_loss=use_bbox_loss,
157 | load_vision_params=load_vision_params, load_text_params=load_text_params, load_cross_params=load_cross_params,
158 | pretraining=pretraining)
159 |
160 | def forward_para_text(self, text_ids=None, text_atts=None,
161 | text_ids_masked=None, text_atts_masked=None, masked_pos=None, masked_ids=None,
162 | text_ids_2=None, text_atts_2=None, text_ids_masked_2=None, masked_pos_2=None, masked_ids_2=None):
163 |
164 | text_embeds = self.get_text_embeds(text_ids, text_atts)
165 | text_embeds_2 = self.get_text_embeds(text_ids_2, text_atts_2)
166 |
167 | # with torch.no_grad():
168 | # self.temp.clamp_(0.001, 0.5)
169 |
170 | text_feat = self.get_features(text_embeds=text_embeds)
171 | text_feat_2 = self.get_features(text_embeds=text_embeds_2)
172 |
173 | loss_ttc = self.get_contrastive_loss(text_feat, text_feat_2)
174 | loss_ttm = self.get_matching_loss(text_embeds, text_atts, text_feat, text_embeds_2, text_atts_2, text_feat_2)
175 |
176 | loss_mlm = self.get_mlm_loss(text_ids_masked, text_atts, text_embeds_2, text_atts_2, masked_pos, masked_ids)
177 |
178 | loss = {'loss_ttc': loss_ttc, 'loss_ttm': loss_ttm, 'loss_mlm': loss_mlm}
179 |
180 | return loss
181 |
182 | def forward(self, image=None, text_ids=None, text_atts=None,
183 | text_ids_masked=None, text_atts_masked=None, masked_pos=None, masked_ids=None,
184 | text_ids_2=None, text_atts_2=None, text_ids_masked_2=None, masked_pos_2=None, masked_ids_2=None,
185 | image_atts=None, idx_to_group_img=None, target_bbox=None, is_image=None, ret_bbox_loss=False, ret_match_loss=True):
186 |
187 | if image is None: # parallel text
188 | loss = self.forward_para_text(text_ids, text_atts, text_ids_masked, text_atts_masked, masked_pos, masked_ids,
189 | text_ids_2, text_atts_2, text_ids_masked_2, masked_pos_2, masked_ids_2)
190 |
191 | else:
192 | loss = self.forward_multimodal(image, text_ids, text_atts, text_ids_masked, masked_pos, masked_ids,
193 | image_atts, idx_to_group_img, target_bbox, is_image, ret_bbox_loss,
194 | ret_match_loss=ret_match_loss)
195 |
196 | return loss
197 |
--------------------------------------------------------------------------------
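
Both pretraining forwards above return a dict of losses ('loss_itc', 'loss_itm', 'loss_mlm', plus 'loss_bbox'/'loss_giou' when ret_bbox_loss is set). A hedged sketch of how a training step could reduce such a dict to a single scalar; the actual loop lives in Pretrain.py and may weight the terms differently:

import torch

loss_dict = {  # stand-in values; normally produced by model(image, text_ids, ...)
    'loss_itc': torch.tensor(1.2, requires_grad=True),
    'loss_itm': torch.tensor(0.8, requires_grad=True),
    'loss_mlm': torch.tensor(2.1, requires_grad=True),
}

total_loss = sum(loss_dict.values())  # unweighted sum, for illustration only
total_loss.backward()
print({k: round(float(v), 3) for k, v in loss_dict.items()}, float(total_loss))
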
/models/model_retrieval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from models.xvlm import XVLMBase, XVLMPlusBase
4 |
5 |
6 | class XVLMForRetrieval(XVLMBase):
7 | def __init__(self, config):
8 | super().__init__(config, load_vision_params=False, load_text_params=False,
9 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=False, use_bbox_loss=False)
10 |
11 | self.num_attention_heads = self.text_encoder.config.num_attention_heads
12 | self.init_params = []
13 |
14 | def forward(self, image, text_ids, text_atts, idx=None):
15 | image_embeds, image_atts = self.get_vision_embeds(image)
16 | text_embeds = self.get_text_embeds(text_ids, text_atts)
17 |
18 | with torch.no_grad():
19 | self.temp.clamp_(0.001, 0.5)
20 |
21 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
22 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx)
23 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=idx)
24 |
25 | return loss_itc, loss_itm
26 |
27 |
28 | class XVLMPlusForRetrieval(XVLMPlusBase):
29 | def __init__(self, config):
30 | super().__init__(config, load_vision_params=False, load_text_params=False, load_cross_params=False,
31 | use_contrastive_loss=True, use_matching_loss=True, use_mlm_loss=False, use_bbox_loss=False)
32 |
33 | self.num_attention_heads = self.text_encoder.config.num_attention_heads
34 | self.init_params = []
35 |
36 | def forward(self, image, text_ids, text_atts, idx=None):
37 | image_embeds, image_atts = self.get_vision_embeds(image)
38 | text_embeds = self.get_text_embeds(text_ids, text_atts)
39 |
40 | with torch.no_grad():
41 | self.temp.clamp_(0.001, 0.5)
42 |
43 | image_feat, text_feat = self.get_features(image_embeds, text_embeds)
44 | loss_itc = self.get_contrastive_loss(image_feat, text_feat, idx=idx)
45 | loss_itm = self.get_matching_loss(image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=idx)
46 |
47 | return loss_itc, loss_itm
48 |
--------------------------------------------------------------------------------
/models/resampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange, repeat
4 | from einops_exts import rearrange_many, repeat_many
5 |
6 |
7 | def FeedForward(dim, mult=4):
8 | inner_dim = int(dim * mult)
9 | return nn.Sequential(
10 | nn.LayerNorm(dim),
11 | nn.Linear(dim, inner_dim, bias=False),
12 | nn.GELU(),
13 | nn.Linear(inner_dim, dim, bias=False)
14 | )
15 |
16 |
17 | class PerceiverAttention(nn.Module):
18 | def __init__(
19 | self,
20 | *,
21 | dim,
22 | dim_head=64,
23 | heads=8
24 | ):
25 | super().__init__()
26 | self.scale = dim_head ** -0.5
27 | self.heads = heads
28 | inner_dim = dim_head * heads
29 |
30 | self.norm_media = nn.LayerNorm(dim)
31 | self.norm_latents = nn.LayerNorm(dim)
32 |
33 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
34 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
35 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
36 |
37 | def forward(self, x, latents):
38 | """
39 | einstein notation
40 | b - batch
41 | t - time
42 | n - sequence
43 | d - dimension
44 | """
45 | x = self.norm_media(x)
46 | latents = self.norm_latents(latents)
47 |
48 | b, m, h = *x.shape[:2], self.heads
49 |
50 | q = self.to_q(latents)
51 |
52 | # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
53 | kv_input = torch.cat((x, latents), dim=-2)
54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
55 |
56 | q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h=h)
57 |
58 | q = q * self.scale
59 |
60 | # attention
61 |
62 | sim = einsum('... i d, ... j d -> ... i j', q, k)
63 |
64 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
65 | attn = sim.softmax(dim=-1)
66 |
67 | out = einsum('... i j, ... j d -> ... i d', attn, v)
68 | out = rearrange(out, 'b h t n d -> b t n (h d)', h=h)
69 | return self.to_out(out)
70 |
71 |
72 | class PerceiverResampler(nn.Module):
73 | def __init__(
74 | self,
75 | *,
76 | dim,
77 | depth,
78 | dim_head=64,
79 | heads=8,
80 | num_latents=64,
81 | num_time_embeds=4,
82 | ff_mult=4,
83 | num_img_latents=-1,
84 | ):
85 | super().__init__()
86 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
87 | self.img_latents = nn.Parameter(torch.randn(num_img_latents, dim)) if num_img_latents > 0 else None
88 |
89 | # self.time_pos_emb = nn.Parameter(torch.randn(num_time_embeds, 1, dim))  # the temporal positional embedding is added elsewhere in this codebase
90 |
91 | self.layers = nn.ModuleList([])
92 | for _ in range(depth):
93 | self.layers.append(nn.ModuleList([
94 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
95 | FeedForward(dim=dim, mult=ff_mult)
96 | ]))
97 |
98 | self.norm = nn.LayerNorm(dim)
99 |
100 | def forward(self, x, mode='video'):
101 | if x.ndim == 3:
102 | x = rearrange(x, 'b n d -> b 1 n d')
103 |
104 | # times = x.shape[1]
105 | # x = x + self.time_pos_emb[:times]
106 |
107 | if mode == 'video':
108 | latents = self.latents
109 | elif mode == 'image':
110 | latents = self.img_latents
111 | else:
112 | raise ValueError(f"mode == {mode}")
113 |
114 | latents = repeat(latents, 'n d -> b m n d', b=x.shape[0], m=x.shape[1])
115 |
116 | for attn, ff in self.layers:
117 | latents = attn(x, latents) + latents
118 | latents = ff(latents) + latents
119 |
120 | return self.norm(latents)
121 |
--------------------------------------------------------------------------------
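
Shape behaviour of PerceiverResampler: a (B, T, N, D) stack of per-frame features is compressed to (B, T, num_latents, D), and a 3-D image input is treated as T=1. A small sketch with illustrative dimensions (num_img_latents must be positive for mode='image'):

import torch
from models.resampler import PerceiverResampler  # needs einops and einops_exts installed

resampler = PerceiverResampler(dim=64, depth=2, dim_head=16, heads=4,
                               num_latents=8, num_img_latents=8)

video_feats = torch.randn(2, 3, 50, 64)            # (B=2, T=3 frames, N=50 tokens, D=64)
print(resampler(video_feats, mode='video').shape)  # torch.Size([2, 3, 8, 64])

image_feats = torch.randn(2, 50, 64)               # (B, N, D) is unsqueezed to T=1 internally
print(resampler(image_feats, mode='image').shape)  # torch.Size([2, 1, 8, 64])
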
/optim.py:
--------------------------------------------------------------------------------
1 | from transformers.optimization import AdamW
2 | import torch
3 |
4 |
5 | def is_vision(args, param: str):
6 | if hasattr(args, 'vision_lr') and ('vision_encoder' in param):
7 | return True
8 | else:
9 | return False
10 |
11 |
12 | def is_text(args, param: str):
13 | if hasattr(args, 'text_lr') and ('text_encoder' in param):
14 | return True
15 | else:
16 | return False
17 |
18 |
19 | def is_cross(args, param: str):
20 | if hasattr(args, 'cross_lr') and ('cross_encoder' in param): # only supports XVLMPlusBase
21 | return True
22 | else:
23 | return False
24 |
25 |
26 | def create_optimizer(args, model):
27 | lr = args.lr
28 | wd = args.weight_decay
29 | lr_mult = getattr(args, 'lr_mult', 1)
30 | print(f"### lr: {args.lr}, lr_mult: {lr_mult}")
31 |
32 | optimizer_grouped_parameters = [
33 | {"params": [], "weight_decay": wd, "lr": lr},
34 | {"params": [], "weight_decay": 0.0, "lr": lr},
35 | {"params": [], "weight_decay": wd, "lr": lr * lr_mult},
36 | {"params": [], "weight_decay": 0.0, "lr": lr * lr_mult}
37 | ]
38 |
39 | if hasattr(args, 'vision_lr'):
40 | # 4 & 5
41 | print("### vision_lr: ", args.vision_lr)
42 | optimizer_grouped_parameters.append({"params": [], "weight_decay": wd, "lr": args.vision_lr})
43 | optimizer_grouped_parameters.append({"params": [], "weight_decay": 0.0, "lr": args.vision_lr})
44 |
45 | # 6 & 7
46 | assert hasattr(args, 'text_lr')
47 | print("### text_lr: ", args.text_lr)
48 | optimizer_grouped_parameters.append({"params": [], "weight_decay": wd, "lr": args.text_lr})
49 | optimizer_grouped_parameters.append({"params": [], "weight_decay": 0.0, "lr": args.text_lr})
50 |
51 | # 8 & 9
52 | if not hasattr(args, 'cross_lr'):
53 | args.cross_lr = args.text_lr
54 |
55 | print("### cross_lr: ", args.cross_lr, flush=True)
56 | optimizer_grouped_parameters.append({"params": [], "weight_decay": wd, "lr": args.cross_lr})
57 | optimizer_grouped_parameters.append({"params": [], "weight_decay": 0.0, "lr": args.cross_lr})
58 |
59 | no_decay = {"bias",
60 | "LayerNorm.bias",
61 | "LayerNorm.weight",
62 | "norm.bias",
63 | "norm.weight",
64 | "norm1.bias",
65 | "norm1.weight",
66 | "norm2.bias",
67 | "norm2.weight"}
68 |
69 | if hasattr(model, 'init_params'):
70 | large_lr = model.init_params
71 | print("### model has 'init_params', ", len(large_lr))
72 | else:
73 | large_lr = {}
74 |
75 | for n, p in model.named_parameters():
76 | if not p.requires_grad:
77 | continue # frozen weights
78 |
79 | if any(nd in n for nd in no_decay):
80 | if is_vision(args, n):
81 | optimizer_grouped_parameters[5]['params'].append(p)
82 | elif is_text(args, n):
83 | optimizer_grouped_parameters[7]['params'].append(p)
84 | elif is_cross(args, n):
85 | optimizer_grouped_parameters[9]['params'].append(p)
86 | elif n in large_lr:
87 | optimizer_grouped_parameters[3]['params'].append(p)
88 | else:
89 | optimizer_grouped_parameters[1]['params'].append(p)
90 | else: # decay
91 | if is_vision(args, n):
92 | optimizer_grouped_parameters[4]['params'].append(p)
93 | elif is_text(args, n):
94 | optimizer_grouped_parameters[6]['params'].append(p)
95 | elif is_cross(args, n):
96 | optimizer_grouped_parameters[8]['params'].append(p)
97 | elif n in large_lr:
98 | optimizer_grouped_parameters[2]['params'].append(p)
99 | else:
100 | optimizer_grouped_parameters[0]['params'].append(p)
101 |
102 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98))
103 |
104 | return optimizer
105 |
106 |
107 | class LARS(torch.optim.Optimizer):
108 | """
109 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
110 | """
111 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
112 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
113 | super().__init__(params, defaults)
114 |
115 | @torch.no_grad()
116 | def step(self):
117 | for g in self.param_groups:
118 | for p in g['params']:
119 | dp = p.grad
120 |
121 | if dp is None:
122 | continue
123 |
124 | if p.ndim > 1: # if not normalization gamma/beta or bias
125 | dp = dp.add(p, alpha=g['weight_decay'])
126 | param_norm = torch.norm(p)
127 | update_norm = torch.norm(dp)
128 | one = torch.ones_like(param_norm)
129 | q = torch.where(param_norm > 0.,
130 | torch.where(update_norm > 0,
131 | (g['trust_coefficient'] * param_norm / update_norm), one),
132 | one)
133 | dp = dp.mul(q)
134 |
135 | param_state = self.state[p]
136 | if 'mu' not in param_state:
137 | param_state['mu'] = torch.zeros_like(p)
138 | mu = param_state['mu']
139 | mu.mul_(g['momentum']).add_(dp)
140 | p.add_(mu, alpha=-g['lr'])
141 |
--------------------------------------------------------------------------------
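
A minimal usage sketch for the LARS optimizer defined above (not part of the repo; the toy model and hyperparameters are illustrative, and it assumes this file is importable as the top-level optim module):

    import torch
    import torch.nn as nn
    from optim import LARS   # the class defined above

    # hypothetical two-layer model, just to exercise the optimizer
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # weight decay and trust-ratio scaling are skipped automatically for 1-D
    # parameters (biases, LayerNorm weights), per the p.ndim > 1 check in step()
    optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-6, momentum=0.9)

    x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    loss = nn.functional.cross_entropy(model(x), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
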
/refTools/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'licheng'
2 |
3 |
4 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from refTools.evaluation.bleu.bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 |         # compute BLEU score up to 4-grams by default
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | #score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | # return (bleu, bleu_info)
44 | return score, scores
45 |
46 | def method(self):
47 | return "Bleu"
48 |
--------------------------------------------------------------------------------
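
A small usage sketch for this wrapper (image ids and captions are made up): gts maps each image id to a list of reference captions, and res maps the same ids to a single-element list holding the hypothesis.

    from refTools.evaluation.bleu.bleu import Bleu

    gts = {"img1": ["a man riding a horse", "a person on a horse"],
           "img2": ["two dogs playing in the snow"]}
    res = {"img1": ["a man rides a horse"],
           "img2": ["dogs play in snow"]}

    scorer = Bleu(4)
    score, scores = scorer.compute_score(gts, res)
    # score  -> corpus-level [Bleu_1, Bleu_2, Bleu_3, Bleu_4]
    # scores -> per-image lists, one per n-gram order
    print(dict(zip(["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"], score)))
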
/refTools/evaluation/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in range(1,n+1):
30 | for i in range(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.items():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 |     ## lhuang: N.B.: leave reflen computation to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, refparam, eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 |
64 | reflen, refmaxcounts = refparam[0], refparam[1]
65 | testlen, counts = precook(test, n, True)
66 |
67 | result = {}
68 |
69 | # Calculate effective reference sentence length.
70 |
71 | if eff == "closest":
72 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
73 | else: ## i.e., "average" or "shortest" or None
74 | result["reflen"] = reflen
75 |
76 | result["testlen"] = testlen
77 |
78 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
79 |
80 | result['correct'] = [0]*n
81 | for (ngram, count) in counts.items():
82 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
83 |
84 | return result
85 |
86 | class BleuScorer(object):
87 | """Bleu scorer.
88 | """
89 |
90 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
91 | # special_reflen is used in oracle (proportional effective ref len for a node).
92 |
93 | def copy(self):
94 | ''' copy the refs.'''
95 | new = BleuScorer(n=self.n)
96 | new.ctest = copy.copy(self.ctest)
97 | new.crefs = copy.copy(self.crefs)
98 | new._score = None
99 | return new
100 |
101 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
102 | ''' singular instance '''
103 |
104 | self.n = n
105 | self.crefs = []
106 | self.ctest = []
107 | self.cook_append(test, refs)
108 | self.special_reflen = special_reflen
109 |
110 | def cook_append(self, test, refs):
111 | '''called by constructor and __iadd__ to avoid creating new instances.'''
112 |
113 | if refs is not None:
114 | self.crefs.append(cook_refs(refs))
115 | if test is not None:
116 | cooked_test = cook_test(test, self.crefs[-1])
117 | self.ctest.append(cooked_test) ## N.B.: -1
118 | else:
119 | self.ctest.append(None) # lens of crefs and ctest have to match
120 |
121 | self._score = None ## need to recompute
122 |
123 | def ratio(self, option=None):
124 | self.compute_score(option=option)
125 | return self._ratio
126 |
127 | def score_ratio(self, option=None):
128 | '''return (bleu, len_ratio) pair'''
129 | return (self.fscore(option=option), self.ratio(option=option))
130 |
131 | def score_ratio_str(self, option=None):
132 | return "%.4f (%.2f)" % self.score_ratio(option)
133 |
134 | def reflen(self, option=None):
135 | self.compute_score(option=option)
136 | return self._reflen
137 |
138 | def testlen(self, option=None):
139 | self.compute_score(option=option)
140 | return self._testlen
141 |
142 | def retest(self, new_test):
143 | if type(new_test) is str:
144 | new_test = [new_test]
145 | assert len(new_test) == len(self.crefs), new_test
146 | self.ctest = []
147 | for t, rs in zip(new_test, self.crefs):
148 | self.ctest.append(cook_test(t, rs))
149 | self._score = None
150 |
151 | return self
152 |
153 | def rescore(self, new_test):
154 | ''' replace test(s) with new test(s), and returns the new score.'''
155 |
156 | return self.retest(new_test).compute_score()
157 |
158 | def size(self):
159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
160 | return len(self.crefs)
161 |
162 | def __iadd__(self, other):
163 | '''add an instance (e.g., from another sentence).'''
164 |
165 | if type(other) is tuple:
166 | ## avoid creating new BleuScorer instances
167 | self.cook_append(other[0], other[1])
168 | else:
169 | assert self.compatible(other), "incompatible BLEUs."
170 | self.ctest.extend(other.ctest)
171 | self.crefs.extend(other.crefs)
172 | self._score = None ## need to recompute
173 |
174 | return self
175 |
176 | def compatible(self, other):
177 | return isinstance(other, BleuScorer) and self.n == other.n
178 |
179 | def single_reflen(self, option="average"):
180 | return self._single_reflen(self.crefs[0][0], option)
181 |
182 | def _single_reflen(self, reflens, option=None, testlen=None):
183 |
184 | if option == "shortest":
185 | reflen = min(reflens)
186 | elif option == "average":
187 | reflen = float(sum(reflens))/len(reflens)
188 | elif option == "closest":
189 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
190 | else:
191 | assert False, "unsupported reflen option %s" % option
192 |
193 | return reflen
194 |
195 | def recompute_score(self, option=None, verbose=0):
196 | self._score = None
197 | return self.compute_score(option, verbose)
198 |
199 | def compute_score(self, option=None, verbose=0):
200 | n = self.n
201 | small = 1e-9
202 | tiny = 1e-15 ## so that if guess is 0 still return 0
203 | bleu_list = [[] for _ in range(n)]
204 |
205 | if self._score is not None:
206 | return self._score
207 |
208 | if option is None:
209 | option = "average" if len(self.crefs) == 1 else "closest"
210 |
211 | self._testlen = 0
212 | self._reflen = 0
213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
214 |
215 | # for each sentence
216 | for comps in self.ctest:
217 | testlen = comps['testlen']
218 | self._testlen += testlen
219 |
220 | if self.special_reflen is None: ## need computation
221 | reflen = self._single_reflen(comps['reflen'], option, testlen)
222 | else:
223 | reflen = self.special_reflen
224 |
225 | self._reflen += reflen
226 |
227 | for key in ['guess','correct']:
228 | for k in range(n):
229 | totalcomps[key][k] += comps[key][k]
230 |
231 | # append per image bleu score
232 | bleu = 1.
233 | for k in range(n):
234 | bleu *= (float(comps['correct'][k]) + tiny) \
235 | /(float(comps['guess'][k]) + small)
236 | bleu_list[k].append(bleu ** (1./(k+1)))
237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
238 | if ratio < 1:
239 | for k in range(n):
240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
241 |
242 | if verbose > 1:
243 | print(comps, reflen)
244 |
245 | totalcomps['reflen'] = self._reflen
246 | totalcomps['testlen'] = self._testlen
247 |
248 | bleus = []
249 | bleu = 1.
250 | for k in range(n):
251 | bleu *= float(totalcomps['correct'][k] + tiny) \
252 | / (totalcomps['guess'][k] + small)
253 | bleus.append(bleu ** (1./(k+1)))
254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
255 | if ratio < 1:
256 | for k in range(n):
257 | bleus[k] *= math.exp(1 - 1/ratio)
258 |
259 | if verbose > 0:
260 | print(totalcomps)
261 | print("ratio:", ratio)
262 |
263 | self._score = bleus
264 | return self._score, bleu_list
265 |
--------------------------------------------------------------------------------
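
BleuScorer can also be driven incrementally through the += operator, which is exactly how the Bleu wrapper above feeds it; a brief sketch with invented sentences:

    from refTools.evaluation.bleu.bleu_scorer import BleuScorer

    scorer = BleuScorer(n=4)
    scorer += ("the cat sat on the mat",
               ["a cat is sitting on the mat", "there is a cat on the mat"])
    scorer += ("a dog in the park", ["a dog runs in the park"])

    # 'closest' picks, per segment, the reference length nearest to the hypothesis
    score, per_segment = scorer.compute_score(option='closest')
    print(score)   # [Bleu_1 ... Bleu_4] over both segments
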
/refTools/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from refTools.evaluation.cider.cider_scorer import CiderScorer
11 | import pdb
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 |
24 | def compute_score(self, gts, res):
25 | """
26 | Main function to compute CIDEr score
27 |         :param hypo_for_image (dict) : dictionary with key "image id" and value "tokenized hypothesis / candidate sentence"
28 |                ref_for_image (dict) : dictionary with key "image id" and value "tokenized reference sentences"
29 | :return: cider (float) : computed CIDEr score for the corpus
30 | """
31 |
32 | assert(gts.keys() == res.keys())
33 | imgIds = gts.keys()
34 |
35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
36 |
37 | for id in imgIds:
38 | hypo = res[id]
39 | ref = gts[id]
40 |
41 | # Sanity check.
42 | assert(type(hypo) is list)
43 | assert(len(hypo) == 1)
44 | assert(type(ref) is list)
45 | assert(len(ref) > 0)
46 |
47 | cider_scorer += (hypo[0], ref)
48 |
49 | (score, scores) = cider_scorer.compute_score()
50 |
51 | return score, scores
52 |
53 | def method(self):
54 | return "CIDEr"
--------------------------------------------------------------------------------
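
This wrapper follows the same gts/res calling convention as Bleu above; a toy sketch with fabricated ids and captions (with so few references the corpus IDF is nearly degenerate, so meaningful scores require a full evaluation set):

    from refTools.evaluation.cider.cider import Cider

    gts = {"img1": ["a man riding a horse", "a person on a horse"],
           "img2": ["two dogs playing in the snow"]}
    res = {"img1": ["a man rides a horse"],
           "img2": ["dogs play in snow"]}

    cider = Cider()
    corpus_score, per_image = cider.compute_score(gts, res)
    print("CIDEr:", corpus_score)
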
/refTools/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 |     :return: term frequency vector for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 |     for k in range(1,n+1):
23 |         for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that BLEU
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that BLEU needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 | Compute term frequency for reference data.
96 | This will be used to compute idf (inverse document frequency later)
97 | The term frequency is stored in the object
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 |             for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngram to vector of tfidf weights.
110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111 | The n-th entry of array denotes length of n-grams.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 |             for (ngram,term_freq) in cnts.items():
119 | # give word count 1 if it doesn't appear in reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 |             # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 |                 for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/meteor/meteor-1.5.jar
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import subprocess
9 | import threading
10 |
11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
12 | METEOR_JAR = 'meteor-1.5.jar'
13 | # print METEOR_JAR
14 |
15 | class Meteor:
16 |
17 | def __init__(self):
18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
19 | '-', '-', '-stdio', '-l', 'en', '-norm']
20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
21 | cwd=os.path.dirname(os.path.abspath(__file__)), \
22 | stdin=subprocess.PIPE, \
23 | stdout=subprocess.PIPE, \
24 | stderr=subprocess.PIPE)
25 | # Used to guarantee thread safety
26 | self.lock = threading.Lock()
27 |
28 | def compute_score(self, gts, res):
29 | assert(gts.keys() == res.keys())
30 | imgIds = gts.keys()
31 | scores = []
32 |
33 | eval_line = 'EVAL'
34 | self.lock.acquire()
35 | for i in imgIds:
36 | assert(len(res[i]) == 1)
37 | stat = self._stat(res[i][0], gts[i])
38 | eval_line += ' ||| {}'.format(stat)
39 |
40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
41 | for i in range(0,len(imgIds)):
42 | scores.append(float(self.meteor_p.stdout.readline().strip()))
43 | score = float(self.meteor_p.stdout.readline().strip())
44 | self.lock.release()
45 |
46 | return score, scores
47 |
48 | def method(self):
49 | return "METEOR"
50 |
51 | def _stat(self, hypothesis_str, reference_list):
52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
53 |         hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
55 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
56 | return self.meteor_p.stdout.readline().decode().strip()
57 |
58 | def _score(self, hypothesis_str, reference_list):
59 | self.lock.acquire()
60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
61 |         hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
62 |         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
63 |         self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
64 |         stats = self.meteor_p.stdout.readline().decode().strip()
65 |         eval_line = 'EVAL ||| {}'.format(stats)
66 |         # EVAL ||| stats
67 |         self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
68 |         score = float(self.meteor_p.stdout.readline().strip())
69 | self.lock.release()
70 | return score
71 |
72 | def __exit__(self):
73 | self.lock.acquire()
74 | self.meteor_p.stdin.close()
75 | self.meteor_p.wait()
76 | self.lock.release()
77 |
--------------------------------------------------------------------------------
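
Meteor shells out to meteor-1.5.jar, so the sketch below assumes a Java runtime on PATH and the jar sitting next to meteor.py, as the module expects; ids and captions are invented:

    from refTools.evaluation.meteor.meteor import Meteor

    gts = {"img1": ["a man riding a horse", "a person on a horse"]}
    res = {"img1": ["a man rides a horse"]}

    meteor = Meteor()                      # spawns the Java scorer process
    score, scores = meteor.compute_score(gts, res)
    print("METEOR:", score)
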
/refTools/evaluation/readme.txt:
--------------------------------------------------------------------------------
1 | This folder contains a modified coco-caption evaluation toolkit, downloaded from https://github.com/tylin/coco-caption.git,
2 | and refEvaluation.py, which is called by the refer algorithm.
3 |
4 | More specifically, this folder contains:
5 | 1. bleu/
6 | 2. cider/
7 | 3. meteor/
8 | 4. rouge/
9 | 5. tokenizer/
10 | 6. __init__.py
11 | 7. refEvaluation.py
12 |
--------------------------------------------------------------------------------
/refTools/evaluation/refEvaluation.py:
--------------------------------------------------------------------------------
1 | from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer
2 | from refTools.evaluation.bleu.bleu import Bleu
3 | from refTools.evaluation.meteor.meteor import Meteor
4 | from refTools.evaluation.rouge.rouge import Rouge
5 | from refTools.evaluation.cider.cider import Cider
6 |
7 | """
8 | Input: refer and Res = [{ref_id, sent}]
9 |
10 | Things of interest
11 | evalRefs - list of ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']
12 | eval - dict of {metric: score}
13 | refToEval - dict of {ref_id: ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']}
14 | """
15 |
16 | class RefEvaluation:
17 | def __init__ (self, refer, Res):
18 | """
19 | :param refer: refer class of current dataset
20 | :param Res: [{'ref_id', 'sent'}]
21 | """
22 | self.evalRefs = []
23 | self.eval = {}
24 | self.refToEval = {}
25 | self.refer = refer
26 | self.Res = Res
27 |
28 | def evaluate(self):
29 |
30 | evalRefIds = [ann['ref_id'] for ann in self.Res]
31 |
32 | refToGts = {}
33 | for ref_id in evalRefIds:
34 | ref = self.refer.Refs[ref_id]
35 | gt_sents = [sent['sent'].encode('ascii', 'ignore').decode('ascii') for sent in ref['sentences']] # up to 3 expressions
36 | refToGts[ref_id] = gt_sents
37 | refToRes = {ann['ref_id']: [ann['sent']] for ann in self.Res}
38 |
39 | print('tokenization...')
40 | tokenizer = PTBTokenizer()
41 | self.refToRes = tokenizer.tokenize(refToRes)
42 | self.refToGts = tokenizer.tokenize(refToGts)
43 |
44 | # =================================================
45 | # Set up scorers
46 | # =================================================
47 | print('setting up scorers...')
48 | scorers = [
49 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
50 | (Meteor(),"METEOR"),
51 | (Rouge(), "ROUGE_L"),
52 | (Cider(), "CIDEr")
53 | ]
54 |
55 | # =================================================
56 | # Compute scores
57 | # =================================================
58 | for scorer, method in scorers:
59 | print('computing %s score...'%(scorer.method()))
60 | score, scores = scorer.compute_score(self.refToGts, self.refToRes)
61 | if type(method) == list:
62 | for sc, scs, m in zip(score, scores, method):
63 | self.setEval(sc, m)
64 | self.setRefToEvalRefs(scs, self.refToGts.keys(), m)
65 | print("%s: %0.3f"%(m, sc))
66 | else:
67 | self.setEval(score, method)
68 | self.setRefToEvalRefs(scores, self.refToGts.keys(), method)
69 | print("%s: %0.3f"%(method, score))
70 | self.setEvalRefs()
71 |
72 | def setEval(self, score, method):
73 | self.eval[method] = score
74 |
75 | def setRefToEvalRefs(self, scores, refIds, method):
76 | for refId, score in zip(refIds, scores):
77 | if not refId in self.refToEval:
78 | self.refToEval[refId] = {}
79 | self.refToEval[refId]["ref_id"] = refId
80 | self.refToEval[refId][method] = score
81 |
82 | def setEvalRefs(self):
83 | self.evalRefs = [eval for refId, eval in self.refToEval.items()]
84 |
85 |
86 | if __name__ == '__main__':
87 |
88 | import os.path as osp
89 | import sys
90 | ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
91 | sys.path.insert(0, osp.join(ROOT_DIR, 'lib', 'datasets'))
92 | from refer import REFER
93 |
94 | # load refer of dataset
95 | dataset = 'refcoco'
96 | refer = REFER(dataset, splitBy = 'google')
97 |
98 | # mimic some Res
99 | val_refIds = refer.getRefIds(split='test')
100 | ref_id = 49767
101 | print("GD: %s" % refer.Refs[ref_id]['sentences'])
102 | Res = [{'ref_id': ref_id, 'sent': 'left bottle'}]
103 |
104 | # evaluate some refer expressions
105 | refEval = RefEvaluation(refer, Res)
106 | refEval.evaluate()
107 |
108 | # print output evaluation scores
109 | for metric, score in refEval.eval.items():
110 | print('%s: %.3f'%(metric, score))
111 |
112 | # demo how to use evalImgs to retrieve low score result
113 | # evals = [eva for eva in refEval.evalRefs if eva['CIDEr']<30]
114 | # print 'ground truth sents'
115 | # refId = evals[0]['ref_id']
116 | # print 'refId: %s' % refId
117 | # print [sent['sent'] for sent in refer.Refs[refId]['sentences']]
118 | #
119 | # print 'generated sent (CIDEr score %0.1f)' % (evals[0]['CIDEr'])
120 |
121 | # print refEval.refToEval[8]
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 |     :returns: length (int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 |         # vrama91: updated the value below based on discussion with Hovy
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: str : candidate sentence to be evaluated
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 |         :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by evaluate_captions.py
81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 | score.append(self.calc_score(hypo, ref))
94 |
95 | # Sanity check.
96 | assert(type(hypo) is list)
97 | assert(len(hypo) == 1)
98 | assert(type(ref) is list)
99 | assert(len(ref) > 0)
100 |
101 | average_score = np.mean(np.array(score))
102 | return average_score, np.array(score)
103 |
104 | def method(self):
105 | return "Rouge"
106 |
--------------------------------------------------------------------------------
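
For completeness, calc_score can be exercised directly on one candidate and its references (toy strings below); compute_score wraps the same call over a gts/res pair of dicts:

    from refTools.evaluation.rouge.rouge import Rouge

    rouge = Rouge()
    # single-element candidate list plus a list of reference sentences
    score = rouge.calc_score(["a man rides a horse"],
                             ["a man riding a horse", "a person on a horse"])
    print("ROUGE_L:", score)   # LCS-based F-measure with beta = 1.2
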
/refTools/evaluation/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'hfang'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import sys
13 | import subprocess
14 | import tempfile
15 | import itertools
16 |
17 | # path to the stanford corenlp jar
18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
19 |
20 | # punctuations to be removed from the sentences
21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
23 |
24 | class PTBTokenizer:
25 | """Python wrapper of Stanford PTBTokenizer"""
26 |
27 | def tokenize(self, captions_for_image):
28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
29 | 'edu.stanford.nlp.process.PTBTokenizer', \
30 | '-preserveLines', '-lowerCase']
31 |
32 | # ======================================================
33 | # prepare data for PTB Tokenizer
34 | # ======================================================
35 | final_tokenized_captions_for_image = {}
36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
38 |
39 | # ======================================================
40 | # save sentences to temporary file
41 | # ======================================================
42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
44 | tmp_file.write(sentences.encode())
45 | tmp_file.close()
46 |
47 | # ======================================================
48 | # tokenize sentence
49 | # ======================================================
50 | cmd.append(os.path.basename(tmp_file.name))
51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
52 | stdout=subprocess.PIPE)
53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
54 | token_lines = token_lines.decode()
55 | lines = token_lines.split('\n')
56 | # remove temp file
57 | os.remove(tmp_file.name)
58 |
59 | # ======================================================
60 | # create dictionary for tokenized captions
61 | # ======================================================
62 | for k, line in zip(image_id, lines):
63 | if not k in final_tokenized_captions_for_image:
64 | final_tokenized_captions_for_image[k] = []
65 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
66 | if w not in PUNCTUATIONS])
67 | final_tokenized_captions_for_image[k].append(tokenized_caption)
68 |
69 | return final_tokenized_captions_for_image
70 |
--------------------------------------------------------------------------------
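
PTBTokenizer takes a dict mapping each image id to a list of raw captions and requires a Java runtime plus the stanford-corenlp-3.4.1.jar shipped in this folder; a hedged sketch with invented captions:

    from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer

    captions = {"img1": ["A man, riding a horse!"],
                "img2": ["Two dogs... playing in the snow."]}

    tokenizer = PTBTokenizer()
    tokenized = tokenizer.tokenize(captions)
    # punctuation removed and text lower-cased, e.g.
    # {"img1": ["a man riding a horse"], "img2": ["two dogs playing in the snow"]}
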
/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/tmp82iqkuu0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/tokenizer/tmp82iqkuu0
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/tmpn19wmqte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/refTools/evaluation/tokenizer/tmpn19wmqte
--------------------------------------------------------------------------------
/refTools/refer_python3.py:
--------------------------------------------------------------------------------
1 | __author__ = 'licheng'
2 |
3 | """
4 | This interface provides access to four datasets:
5 | 1) refclef
6 | 2) refcoco
7 | 3) refcoco+
8 | 4) refcocog
9 | split by unc and google
10 |
11 | The following API functions are defined:
12 | REFER - REFER api class
13 | getRefIds - get ref ids that satisfy given filter conditions.
14 | getAnnIds - get ann ids that satisfy given filter conditions.
15 | getImgIds - get image ids that satisfy given filter conditions.
16 | getCatIds - get category ids that satisfy given filter conditions.
17 | loadRefs - load refs with the specified ref ids.
18 | loadAnns - load anns with the specified ann ids.
19 | loadImgs - load images with the specified image ids.
20 | loadCats - load category names with the specified category ids.
21 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
22 | """
23 |
24 | import sys
25 | import os.path as osp
26 | import json
27 | import _pickle as pickle
28 | import time
29 | import itertools
30 | import skimage.io as io
31 | import matplotlib.pyplot as plt
32 | from matplotlib.collections import PatchCollection
33 | from matplotlib.patches import Polygon, Rectangle
34 | from pprint import pprint
35 | import numpy as np
36 | # import cv2
37 | # from skimage.measure import label, regionprops
38 |
39 | class REFER:
40 |
41 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
42 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
43 | # also provide dataset name and splitBy information
44 | # e.g., dataset = 'refcoco', splitBy = 'unc'
45 | print('loading dataset %s into memory...' % dataset)
46 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
47 | self.DATA_DIR = osp.join(data_root, dataset)
48 | if dataset in ['refcoco', 'refcoco+', 'refcocog']:
49 | self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
50 | elif dataset == 'refclef':
51 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
52 | else:
53 | print('No refer dataset is called [%s]' % dataset)
54 | sys.exit()
55 |
56 | # load refs from data/dataset/refs(dataset).json
57 | tic = time.time()
58 | ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
59 | self.data = {}
60 | self.data['dataset'] = dataset
61 | self.data['refs'] = pickle.load(open(ref_file, 'rb'))
62 |
63 | # load annotations from data/dataset/instances.json
64 | instances_file = osp.join(self.DATA_DIR, 'instances.json')
65 | instances = json.load(open(instances_file, 'r'))
66 | self.data['images'] = instances['images']
67 | self.data['annotations'] = instances['annotations']
68 | self.data['categories'] = instances['categories']
69 |
70 | # create index
71 | self.createIndex()
72 | print('DONE (t=%.2fs)' % (time.time()-tic))
73 |
74 | def createIndex(self):
75 | # create sets of mapping
76 | # 1) Refs: {ref_id: ref}
77 | # 2) Anns: {ann_id: ann}
78 | # 3) Imgs: {image_id: image}
79 | # 4) Cats: {category_id: category_name}
80 | # 5) Sents: {sent_id: sent}
81 | # 6) imgToRefs: {image_id: refs}
82 | # 7) imgToAnns: {image_id: anns}
83 | # 8) refToAnn: {ref_id: ann}
84 | # 9) annToRef: {ann_id: ref}
85 | # 10) catToRefs: {category_id: refs}
86 | # 11) sentToRef: {sent_id: ref}
87 | # 12) sentToTokens: {sent_id: tokens}
88 | print('creating index...')
89 | # fetch info from instances
90 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
91 | for ann in self.data['annotations']:
92 | Anns[ann['id']] = ann
93 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
94 | for img in self.data['images']:
95 | Imgs[img['id']] = img
96 | for cat in self.data['categories']:
97 | Cats[cat['id']] = cat['name']
98 |
99 | # fetch info from refs
100 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
101 | Sents, sentToRef, sentToTokens = {}, {}, {}
102 | for ref in self.data['refs']:
103 | # ids
104 | ref_id = ref['ref_id']
105 | ann_id = ref['ann_id']
106 | category_id = ref['category_id']
107 | image_id = ref['image_id']
108 |
109 | # add mapping related to ref
110 | Refs[ref_id] = ref
111 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
112 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
113 | refToAnn[ref_id] = Anns[ann_id]
114 | annToRef[ann_id] = ref
115 |
116 | # add mapping of sent
117 | for sent in ref['sentences']:
118 | Sents[sent['sent_id']] = sent
119 | sentToRef[sent['sent_id']] = ref
120 | sentToTokens[sent['sent_id']] = sent['tokens']
121 |
122 | # create class members
123 | self.Refs = Refs
124 | self.Anns = Anns
125 | self.Imgs = Imgs
126 | self.Cats = Cats
127 | self.Sents = Sents
128 | self.imgToRefs = imgToRefs
129 | self.imgToAnns = imgToAnns
130 | self.refToAnn = refToAnn
131 | self.annToRef = annToRef
132 | self.catToRefs = catToRefs
133 | self.sentToRef = sentToRef
134 | self.sentToTokens = sentToTokens
135 | print('index created.')
136 |
137 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
138 | image_ids = image_ids if type(image_ids) == list else [image_ids]
139 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
140 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
141 |
142 | if len(image_ids)==len(cat_ids)==len(ref_ids)==len(split)==0:
143 | refs = self.data['refs']
144 | else:
145 | if not len(image_ids) == 0:
146 | refs = [self.imgToRefs[image_id] for image_id in image_ids]
147 | else:
148 | refs = self.data['refs']
149 | if not len(cat_ids) == 0:
150 | refs = [ref for ref in refs if ref['category_id'] in cat_ids]
151 | if not len(ref_ids) == 0:
152 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
153 | if not len(split) == 0:
154 | if split in ['testA', 'testB', 'testC']:
155 | refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ...
156 | elif split in ['testAB', 'testBC', 'testAC']:
157 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess...
158 | elif split == 'test':
159 | refs = [ref for ref in refs if 'test' in ref['split']]
160 | elif split == 'train' or split == 'val':
161 | refs = [ref for ref in refs if ref['split'] == split]
162 | else:
163 | print('No such split [%s]' % split)
164 | sys.exit()
165 | ref_ids = [ref['ref_id'] for ref in refs]
166 | return ref_ids
167 |
168 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
169 | image_ids = image_ids if type(image_ids) == list else [image_ids]
170 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
171 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
172 |
173 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
174 | ann_ids = [ann['id'] for ann in self.data['annotations']]
175 | else:
176 | if not len(image_ids) == 0:
177 | lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
178 | anns = list(itertools.chain.from_iterable(lists))
179 | else:
180 | anns = self.data['annotations']
181 | if not len(cat_ids) == 0:
182 | anns = [ann for ann in anns if ann['category_id'] in cat_ids]
183 | ann_ids = [ann['id'] for ann in anns]
184 | if not len(ref_ids) == 0:
185 | ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
186 | return ann_ids
187 |
188 | def getImgIds(self, ref_ids=[]):
189 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
190 |
191 | if not len(ref_ids) == 0:
192 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
193 | else:
194 | image_ids = self.Imgs.keys()
195 | return image_ids
196 |
197 | def getCatIds(self):
198 | return self.Cats.keys()
199 |
200 | def loadRefs(self, ref_ids=[]):
201 | if type(ref_ids) == list:
202 | return [self.Refs[ref_id] for ref_id in ref_ids]
203 | elif type(ref_ids) == int:
204 | return [self.Refs[ref_ids]]
205 |
206 | def loadAnns(self, ann_ids=[]):
207 | if type(ann_ids) == list:
208 | return [self.Anns[ann_id] for ann_id in ann_ids]
209 |         elif type(ann_ids) == int or type(ann_ids) == str:
210 | return [self.Anns[ann_ids]]
211 |
212 | def loadImgs(self, image_ids=[]):
213 | if type(image_ids) == list:
214 | return [self.Imgs[image_id] for image_id in image_ids]
215 | elif type(image_ids) == int:
216 | return [self.Imgs[image_ids]]
217 |
218 | def loadCats(self, cat_ids=[]):
219 | if type(cat_ids) == list:
220 | return [self.Cats[cat_id] for cat_id in cat_ids]
221 | elif type(cat_ids) == int:
222 | return [self.Cats[cat_ids]]
223 |
224 | def getRefBox(self, ref_id):
225 | ref = self.Refs[ref_id]
226 | ann = self.refToAnn[ref_id]
227 | return ann['bbox'] # [x, y, w, h]
228 |
229 |
230 |
231 | if __name__ == '__main__':
232 |     refer = REFER(data_root='data', dataset='refcocog', splitBy='google')  # data_root: placeholder for the folder holding refcoco/refcoco+/refcocog
233 | ref_ids = refer.getRefIds()
234 | print(len(ref_ids))
235 |
236 | print(len(refer.Imgs))
237 | print(len(refer.imgToRefs))
238 |
239 | ref_ids = refer.getRefIds(split='train')
240 | print('There are %s training referred objects.' % len(ref_ids))
241 |
242 | for ref_id in ref_ids:
243 | ref = refer.loadRefs(ref_id)[0]
244 | if len(ref['sentences']) < 2:
245 | continue
246 |
247 | pprint(ref)
248 | print('The label is %s.' % refer.Cats[ref['category_id']])
249 | plt.figure()
250 | refer.showRef(ref, seg_box='box')
251 | plt.show()
252 |
253 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.4.9
2 | transformers==4.12.5
3 | ruamel_yaml
4 | opencv-python
5 | scikit-image
6 | matplotlib
7 | pycocotools
8 | pycocoevalcap
9 | datasets
10 | sentencepiece
11 | accelerate
12 | scikit-learn
13 | einops
14 | einops_exts
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import LambdaLR
2 |
3 |
4 | def create_scheduler(args, optimizer):
5 | if 'num_training_steps' not in args:
6 | args['num_training_steps'] = args['epochs'] * args['step_per_epoch']
7 | print("### num_training_steps, ", args['num_training_steps'], flush=True)
8 |
9 | if 'min_rate' not in args:
10 | args['min_rate'] = 0.0
11 |
12 | if isinstance(args['num_warmup_steps'], float):
13 | assert 0 <= args['num_warmup_steps'] < 1
14 | args['num_warmup_steps'] = int(args['num_training_steps'] * args['num_warmup_steps'])
15 | print("### num_warmup_steps, ", args['num_warmup_steps'], flush=True)
16 |
17 | if args.sched == 'linear':
18 | def lr_lambda(current_step: int):
19 | if current_step < args.num_warmup_steps:
20 | return float(current_step) / float(max(1, args.num_warmup_steps))
21 | return max(
22 | args['min_rate'], float(args.num_training_steps - (1-args['min_rate'])*current_step) / float(
23 | max(1, args.num_training_steps - args.num_warmup_steps))
24 | )
25 |
26 | lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
27 |
28 | else:
29 | raise NotImplementedError(f"args.sched == {args.sched}")
30 |
31 | return lr_scheduler
32 |
--------------------------------------------------------------------------------
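
create_scheduler reads some fields with item access (args['num_training_steps']) and others with attribute access (args.sched), so it expects a config object that supports both. A sketch under that assumption, using a hypothetical AttrDict helper:

    import torch
    from scheduler import create_scheduler

    class AttrDict(dict):
        # hypothetical helper: dict keys readable as attributes
        __getattr__ = dict.__getitem__
        __setattr__ = dict.__setitem__

    sched_cfg = AttrDict(sched='linear', epochs=10, step_per_epoch=100,
                         num_warmup_steps=0.1)   # float -> fraction of all steps

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = create_scheduler(sched_cfg, optimizer)

    for _ in range(sched_cfg['num_training_steps']):  # filled in by create_scheduler
        optimizer.step()
        lr_scheduler.step()
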
/utils/checkpointer.py:
--------------------------------------------------------------------------------
1 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
2 | # Github: https://github.com/zengyan-97/X-VLM
3 | # Copyright (c) 2022, ByteDance Inc.
4 | # All rights reserved.
5 |
6 | from typing import Union, Dict, List, Tuple, Any, Callable
7 | import logging
8 | import os
9 | import re
10 | import time
11 |
12 | import torch
13 |
14 | from utils.hdfs_io import hexists, hmkdir, hcopy
15 | from utils.torch_io import save as hdfs_torch_save
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | class Checkpointer:
20 | def __init__(self,
21 | serialization_dir: str = ".output") -> None:
22 | self._serialization_dir = serialization_dir
23 | if not hexists(self._serialization_dir):
24 | hmkdir(self._serialization_dir)
25 |
26 | def save_checkpoint(self,
27 | epoch: Union[int, str],
28 | model_state: Dict[str, Any],
29 | training_states: Dict[str, Any],
30 | step: int = -1) -> None:
31 | """
32 | Save ckpt to local or HDFS
33 | """
34 | if step > 0:
35 | model_path = os.path.join(
36 | self._serialization_dir, "model_state_step_{}.th".format(step))
37 | hdfs_torch_save(model_state, model_path)
38 |
39 | else:
40 | model_path = os.path.join(
41 | self._serialization_dir, "model_state_epoch_{}.th".format(epoch))
42 |
43 | training_path = os.path.join(self._serialization_dir,
44 | "training_state_latest.th")
45 | hdfs_torch_save(model_state, model_path)
46 | hdfs_torch_save({**training_states, "epoch": epoch}, training_path)
47 |
--------------------------------------------------------------------------------
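
A brief usage sketch for Checkpointer; the output directory and the states passed in are illustrative:

    import torch
    from utils.checkpointer import Checkpointer

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    checkpointer = Checkpointer(serialization_dir=".output")

    # per-epoch save: writes model_state_epoch_3.th and training_state_latest.th
    checkpointer.save_checkpoint(epoch=3,
                                 model_state=model.state_dict(),
                                 training_states={"optimizer": optimizer.state_dict()})

    # mid-epoch save: writes model_state_step_1500.th only
    checkpointer.save_checkpoint(epoch=3,
                                 model_state=model.state_dict(),
                                 training_states={},
                                 step=1500)
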
/utils/cider/pyciderevalcap/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/utils/cider/pyciderevalcap/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/utils/cider/pyciderevalcap/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | #
4 | # Description: Describes the class to compute the CIDEr
5 | # (Consensus-Based Image Description Evaluation) Metric
6 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
7 | #
8 | # Creation Date: Sun Feb 8 14:16:54 2015
9 | #
10 | # Authors: Ramakrishna Vedantam and
11 | # Tsung-Yi Lin
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | from .cider_scorer import CiderScorer
17 |
18 |
19 | class Cider:
20 | """
21 | Main Class to compute the CIDEr metric
22 |
23 | """
24 | def __init__(self, n=4, df="corpus"):
25 | """
26 | Initialize the CIDEr scoring function
27 | : param n (int): n-gram size
28 | : param df (string): specifies where to get the IDF values from
29 | takes values 'corpus', 'coco-train'
30 | : return: None
31 | """
32 | # set cider to sum over 1 to 4-grams
33 | self._n = n
34 | self._df = df
35 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
36 |
37 | def compute_score(self, gts, res):
38 | """
39 | Main function to compute CIDEr score
40 | : param gts (dict) : {image:tokenized reference sentence}
41 | : param res (dict) : {image:tokenized candidate sentence}
42 | : return: cider (float) : computed CIDEr score for the corpus
43 | """
44 |
45 | # clear all the previous hypos and refs
46 | self.cider_scorer.clear()
47 |
48 | for res_id in res:
49 |
50 | hypo = res_id['caption']
51 | ref = gts[res_id['image_id']]
52 |
53 | # Sanity check.
54 | assert(type(hypo) is list)
55 | assert(len(hypo) == 1)
56 | assert(type(ref) is list)
57 | assert(len(ref) > 0)
58 | self.cider_scorer += (hypo[0], ref)
59 |
60 | (score, scores) = self.cider_scorer.compute_score()
61 |
62 | return score, scores
63 |
64 | def method(self):
65 | return "CIDEr"
66 |
--------------------------------------------------------------------------------
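
Unlike the refTools version above, this Cider variant takes res as a list of {'image_id', 'caption'} records rather than a dict keyed by image id; a toy sketch with the default corpus-level IDF:

    from utils.cider.pyciderevalcap.cider.cider import Cider

    gts = {"img1": ["a man riding a horse", "a person on a horse"],
           "img2": ["two dogs playing in the snow"]}
    res = [{"image_id": "img1", "caption": ["a man rides a horse"]},
           {"image_id": "img2", "caption": ["dogs play in snow"]}]

    cider = Cider(n=4, df="corpus")
    corpus_score, per_image = cider.compute_score(gts, res)
    print("CIDEr:", corpus_score)
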
/utils/cider/pyciderevalcap/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import copy
9 | import six
10 | from six.moves import cPickle
11 | from collections import defaultdict
12 | import numpy as np
13 | import math
14 | import os
15 |
16 | def precook(s, n=4, out=False):
17 | """
18 | Takes a string as input and returns an object that can be given to
19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
20 | can take string arguments as well.
21 | :param s: string : sentence to be converted into ngrams
22 | :param n: int : number of ngrams for which representation is calculated
23 |     :return: term frequency vector for occurring ngrams
24 | """
25 | words = s.split()
26 | counts = defaultdict(int)
27 | for k in range(1,n+1):
28 | for i in range(len(words)-k+1):
29 | ngram = tuple(words[i:i+k])
30 | counts[ngram] += 1
31 | return counts
32 |
33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
34 | '''Takes a list of reference sentences for a single segment
35 | and returns an object that encapsulates everything that BLEU
36 | needs to know about them.
37 | :param refs: list of string : reference sentences for some image
38 | :param n: int : number of ngrams for which (ngram) representation is calculated
39 | :return: result (list of dict)
40 | '''
41 | return [precook(ref, n) for ref in refs]
42 |
43 | def cook_test(test, n=4):
44 | '''Takes a test sentence and returns an object that
45 | encapsulates everything that BLEU needs to know about it.
46 | :param test: list of string : hypothesis sentence for some image
47 | :param n: int : number of ngrams for which (ngram) representation is calculated
48 | :return: result (dict)
49 | '''
50 | return precook(test, n, True)
51 |
52 | class CiderScorer(object):
53 | """CIDEr scorer.
54 | """
55 |
56 | def copy(self):
57 | ''' copy the refs.'''
58 | new = CiderScorer(n=self.n)
59 | new.ctest = copy.copy(self.ctest)
60 | new.crefs = copy.copy(self.crefs)
61 | return new
62 |
63 | def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
64 | ''' singular instance '''
65 | self.n = n
66 | self.sigma = sigma
67 | self.crefs = []
68 | self.ctest = []
69 | self.df_mode = df_mode
70 | self.ref_len = None
71 | if self.df_mode != "corpus":
72 | pkl_file = cPickle.load(open(os.path.join('data', df_mode + '.p'),'rb'), **(dict(encoding='latin1') if six.PY3 else {}))
73 | self.ref_len = np.log(float(pkl_file['ref_len']))
74 | self.document_frequency = pkl_file['document_frequency']
75 | self.cook_append(test, refs)
76 |
77 | def clear(self):
78 | self.crefs = []
79 | self.ctest = []
80 |
81 | def cook_append(self, test, refs):
82 | '''called by constructor and __iadd__ to avoid creating new instances.'''
83 |
84 | if refs is not None:
85 | self.crefs.append(cook_refs(refs))
86 | if test is not None:
87 | self.ctest.append(cook_test(test)) ## N.B.: -1
88 | else:
89 | self.ctest.append(None) # lens of crefs and ctest have to match
90 |
91 | def size(self):
92 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
93 | return len(self.crefs)
94 |
95 | def __iadd__(self, other):
96 | '''add an instance (e.g., from another sentence).'''
97 |
98 | if type(other) is tuple:
99 | ## avoid creating new CiderScorer instances
100 | self.cook_append(other[0], other[1])
101 | else:
102 | self.ctest.extend(other.ctest)
103 | self.crefs.extend(other.crefs)
104 |
105 | return self
106 | def compute_doc_freq(self):
107 | '''
108 | Compute document frequency over the reference data.
109 | This will be used to compute idf (inverse document frequency) later.
110 | The document frequency is stored in the object.
111 | :return: None
112 | '''
113 | for refs in self.crefs:
114 | # refs, k ref captions of one image
115 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
116 | self.document_frequency[ngram] += 1
117 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
118 |
119 | def compute_cider(self):
120 | def counts2vec(cnts):
121 | """
122 | Function maps counts of ngram to vector of tfidf weights.
123 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
124 | The n-th entry of array denotes length of n-grams.
125 | :param cnts:
126 | :return: vec (array of dict), norm (array of float), length (int)
127 | """
128 | vec = [defaultdict(float) for _ in range(self.n)]
129 | length = 0
130 | norm = [0.0 for _ in range(self.n)]
131 | for (ngram,term_freq) in cnts.items():
132 | # assign document frequency 1 to ngrams absent from the reference corpus (so the idf term stays non-negative)
133 | df = np.log(max(1.0, self.document_frequency[ngram]))
134 | # ngram index
135 | n = len(ngram)-1
136 | # tf (term_freq) * idf (precomputed idf) for n-grams
137 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
138 | # compute norm for the vector. the norm will be used for
139 | # computing similarity
140 | norm[n] += pow(vec[n][ngram], 2)
141 |
142 | if n == 1:
143 | length += term_freq
144 | norm = [np.sqrt(n) for n in norm]
145 | return vec, norm, length
146 |
147 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
148 | '''
149 | Compute the cosine similarity of two vectors.
150 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
151 | :param vec_ref: array of dictionary for vector corresponding to reference
152 | :param norm_hyp: array of float for vector corresponding to hypothesis
153 | :param norm_ref: array of float for vector corresponding to reference
154 | :param length_hyp: int containing length of hypothesis
155 | :param length_ref: int containing length of reference
156 | :return: array of score for each n-grams cosine similarity
157 | '''
158 | delta = float(length_hyp - length_ref)
159 | # measure cosine similarity
160 | val = np.array([0.0 for _ in range(self.n)])
161 | for n in range(self.n):
162 | # ngram
163 | for (ngram,count) in vec_hyp[n].items():
164 | val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram]
165 |
166 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
167 | val[n] /= (norm_hyp[n]*norm_ref[n])
168 |
169 | assert(not math.isnan(val[n]))
170 | return val
171 |
172 | # compute log reference length
173 | if self.df_mode == "corpus":
174 | self.ref_len = np.log(float(len(self.crefs)))
175 |
176 | scores = []
177 | for test, refs in zip(self.ctest, self.crefs):
178 | # compute vector for test captions
179 | vec, norm, length = counts2vec(test)
180 | # compute vector for ref captions
181 | score = np.array([0.0 for _ in range(self.n)])
182 | for ref in refs:
183 | vec_ref, norm_ref, length_ref = counts2vec(ref)
184 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
185 | # change by vrama91 - mean of ngram scores, instead of sum
186 | score_avg = np.mean(score)
187 | # divide by number of references
188 | score_avg /= len(refs)
189 | # multiply score by 10
190 | score_avg *= 10.0
191 | # append score of an image to the score list
192 | scores.append(score_avg)
193 | return scores
194 |
195 | def compute_score(self, option=None, verbose=0):
196 | # compute idf
197 | if self.df_mode == "corpus":
198 | self.document_frequency = defaultdict(float)
199 | self.compute_doc_freq()
200 | # assert to check document frequency
201 | assert(len(self.ctest) >= max(self.document_frequency.values()))
202 | # import json for now and write the corresponding files
203 | # compute cider score
204 | score = self.compute_cider()
205 | # debug
206 | # print score
207 | return np.mean(np.array(score)), np.array(score)
208 |
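A minimal usage sketch of CiderScorer in "corpus" mode, assuming the repository root is on PYTHONPATH; the captions below are invented for illustration.

from utils.cider.pyciderevalcap.cider.cider_scorer import CiderScorer

scorer = CiderScorer(df_mode="corpus", n=4)
scorer += ("a man riding a horse", ["a man rides a horse", "a person on a horse"])
scorer += ("a cat sitting on a mat", ["a cat sits on a mat", "a cat on the mat"])

# In "corpus" mode the document frequencies and the log reference length are
# derived from the accumulated references inside compute_score().
mean_score, per_image_scores = scorer.compute_score()
print(mean_score, per_image_scores)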
--------------------------------------------------------------------------------
/utils/cider/pyciderevalcap/ciderD/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/utils/cider/pyciderevalcap/ciderD/ciderD.py:
--------------------------------------------------------------------------------
1 | # Filename: ciderD.py
2 | #
3 | # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | from .ciderD_scorer import CiderScorer
14 | import pdb
15 |
16 | class CiderD:
17 | """
18 | Main Class to compute the CIDEr-D metric
19 |
20 | """
21 | def __init__(self, n=4, sigma=6.0, df="corpus"):
22 | # set cider to sum over 1 to 4-grams
23 | self._n = n
24 | # set the standard deviation parameter for gaussian penalty
25 | self._sigma = sigma
26 | # set where to compute document frequencies from
27 | self._df = df
28 | self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
29 |
30 | def compute_score(self, gts, res):
31 | """
32 | Main function to compute CIDEr score
33 | :param gts (dict) : image_id -> list of reference captions
34 | res (list of dict) : each entry holds an 'image_id' and a single-element 'caption' list
35 | :return: cider (float) : computed CIDEr-D score for the corpus
36 | """
37 |
38 | # clear all the previous hypos and refs
39 | tmp_cider_scorer = self.cider_scorer.copy_empty()
40 | tmp_cider_scorer.clear()
41 | for res_id in res:
42 |
43 | hypo = res_id['caption']
44 | ref = gts[res_id['image_id']]
45 |
46 | # Sanity check.
47 | assert(type(hypo) is list)
48 | assert(len(hypo) == 1)
49 | assert(type(ref) is list)
50 | assert(len(ref) > 0)
51 | tmp_cider_scorer += (hypo[0], ref)
52 |
53 | (score, scores) = tmp_cider_scorer.compute_score()
54 |
55 | return score, scores
56 |
57 | def method(self):
58 | return "CIDEr-D"
59 |
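A minimal usage sketch of CiderD, assuming the repository root is on PYTHONPATH; the image ids and captions are invented, and the gts/res shapes follow the sanity checks in compute_score above.

from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD

gts = {
    0: ["a man rides a horse", "a person riding a horse"],
    1: ["a cat sits on a mat", "a cat on the mat"],
}
# Each result entry carries one image_id and a single-element caption list.
res = [
    {"image_id": 0, "caption": ["a man riding a horse"]},
    {"image_id": 1, "caption": ["a cat sitting on a mat"]},
]

scorer = CiderD(n=4, sigma=6.0, df="corpus")
corpus_score, per_image_scores = scorer.compute_score(gts, res)
print(corpus_score, per_image_scores)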
--------------------------------------------------------------------------------
/utils/cider/pyciderevalcap/ciderD/ciderD_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import copy
9 | from collections import defaultdict
10 | import numpy as np
11 | import pdb
12 | import math
13 | import six
14 | from six.moves import cPickle
15 | import os
16 |
17 | def precook(s, n=4, out=False):
18 | """
19 | Takes a string as input and returns an object that can be given to
20 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
21 | can take string arguments as well.
22 | :param s: string : sentence to be converted into ngrams
23 | :param n: int : number of ngrams for which representation is calculated
25 | :return: term frequency vector for occurring ngrams
25 | """
26 | words = s.split()
27 | counts = defaultdict(int)
28 | for k in range(1,n+1):
29 | for i in range(len(words)-k+1):
30 | ngram = tuple(words[i:i+k])
31 | counts[ngram] += 1
32 | return counts
33 |
34 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
35 | '''Takes a list of reference sentences for a single segment
36 | and returns an object that encapsulates everything that BLEU
37 | needs to know about them.
38 | :param refs: list of string : reference sentences for some image
39 | :param n: int : number of ngrams for which (ngram) representation is calculated
40 | :return: result (list of dict)
41 | '''
42 | return [precook(ref, n) for ref in refs]
43 |
44 | def cook_test(test, n=4):
45 | '''Takes a test sentence and returns an object that
46 | encapsulates everything that BLEU needs to know about it.
47 | :param test: string : hypothesis sentence for some image
48 | :param n: int : number of ngrams for which (ngram) representation is calculated
49 | :return: result (dict)
50 | '''
51 | return precook(test, n, True)
52 |
53 | class CiderScorer(object):
54 | """CIDEr scorer.
55 | """
56 |
57 | def copy(self):
58 | ''' copy the refs.'''
59 | new = CiderScorer(n=self.n)
60 | new.ctest = copy.copy(self.ctest)
61 | new.crefs = copy.copy(self.crefs)
62 | return new
63 |
64 | def copy_empty(self):
65 | new = CiderScorer(df_mode="corpus", n=self.n, sigma=self.sigma)
66 | new.df_mode = self.df_mode
67 | new.ref_len = self.ref_len
68 | new.document_frequency = self.document_frequency
69 | return new
70 |
71 | def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
72 | ''' singular instance '''
73 | self.n = n
74 | self.sigma = sigma
75 | self.crefs = []
76 | self.ctest = []
77 | self.df_mode = df_mode
78 | self.ref_len = None
79 | if self.df_mode != "corpus":
80 | pkl_file = cPickle.load(open(df_mode,'rb'), **(dict(encoding='latin1') if six.PY3 else {}))
81 | self.ref_len = np.log(float(pkl_file['ref_len']))
82 | self.document_frequency = pkl_file['document_frequency']
83 | else:
84 | self.document_frequency = None
85 | self.cook_append(test, refs)
86 |
87 | def clear(self):
88 | self.crefs = []
89 | self.ctest = []
90 |
91 | def cook_append(self, test, refs):
92 | '''called by constructor and __iadd__ to avoid creating new instances.'''
93 |
94 | if refs is not None:
95 | self.crefs.append(cook_refs(refs))
96 | if test is not None:
97 | self.ctest.append(cook_test(test)) ## N.B.: -1
98 | else:
99 | self.ctest.append(None) # lens of crefs and ctest have to match
100 |
101 | def size(self):
102 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
103 | return len(self.crefs)
104 |
105 | def __iadd__(self, other):
106 | '''add an instance (e.g., from another sentence).'''
107 |
108 | if type(other) is tuple:
109 | ## avoid creating new CiderScorer instances
110 | self.cook_append(other[0], other[1])
111 | else:
112 | self.ctest.extend(other.ctest)
113 | self.crefs.extend(other.crefs)
114 |
115 | return self
116 | def compute_doc_freq(self):
117 | '''
118 | Compute document frequency over the reference data.
119 | This will be used to compute idf (inverse document frequency) later.
120 | The document frequency is stored in the object.
121 | :return: None
122 | '''
123 | for refs in self.crefs:
124 | # refs, k ref captions of one image
125 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
126 | self.document_frequency[ngram] += 1
127 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
128 |
129 | def compute_cider(self):
130 | def counts2vec(cnts):
131 | """
132 | Function maps counts of ngram to vector of tfidf weights.
133 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
134 | The n-th entry of array denotes length of n-grams.
135 | :param cnts:
136 | :return: vec (array of dict), norm (array of float), length (int)
137 | """
138 | vec = [defaultdict(float) for _ in range(self.n)]
139 | length = 0
140 | norm = [0.0 for _ in range(self.n)]
141 | for (ngram,term_freq) in cnts.items():
142 | # assign document frequency 1 to ngrams absent from the reference corpus (so the idf term stays non-negative)
143 | df = np.log(max(1.0, self.document_frequency[ngram]))
144 | # ngram index
145 | n = len(ngram)-1
146 | # tf (term_freq) * idf (precomputed idf) for n-grams
147 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
148 | # compute norm for the vector. the norm will be used for computing similarity
149 | norm[n] += pow(vec[n][ngram], 2)
150 |
151 | if n == 1:
152 | length += term_freq
153 | norm = [np.sqrt(n) for n in norm]
154 | return vec, norm, length
155 |
156 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
157 | '''
158 | Compute the cosine similarity of two vectors.
159 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
160 | :param vec_ref: array of dictionary for vector corresponding to reference
161 | :param norm_hyp: array of float for vector corresponding to hypothesis
162 | :param norm_ref: array of float for vector corresponding to reference
163 | :param length_hyp: int containing length of hypothesis
164 | :param length_ref: int containing length of reference
165 | :return: array of score for each n-grams cosine similarity
166 | '''
167 | delta = float(length_hyp - length_ref)
168 | # measure cosine similarity
169 | val = np.array([0.0 for _ in range(self.n)])
170 | for n in range(self.n):
171 | # ngram
172 | for (ngram,count) in vec_hyp[n].items():
173 | # vrama91 : added clipping
174 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
175 |
176 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
177 | val[n] /= (norm_hyp[n]*norm_ref[n])
178 |
179 | assert(not math.isnan(val[n]))
180 | # vrama91: added a length based gaussian penalty
181 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
182 | return val
183 |
184 | # compute log reference length
185 | if self.df_mode == "corpus":
186 | self.ref_len = np.log(float(len(self.crefs)))
187 | #elif self.df_mode == "coco-val-df":
188 | # if coco option selected, use length of coco-val set
189 | # self.ref_len = np.log(float(40504))
190 |
191 | scores = []
192 | for test, refs in zip(self.ctest, self.crefs):
193 | # compute vector for test captions
194 | vec, norm, length = counts2vec(test)
195 | # compute vector for ref captions
196 | score = np.array([0.0 for _ in range(self.n)])
197 | for ref in refs:
198 | vec_ref, norm_ref, length_ref = counts2vec(ref)
199 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
200 | # change by vrama91 - mean of ngram scores, instead of sum
201 | score_avg = np.mean(score)
202 | # divide by number of references
203 | score_avg /= len(refs)
204 | # multiply score by 10
205 | score_avg *= 10.0
206 | # append score of an image to the score list
207 | scores.append(score_avg)
208 | return scores
209 |
210 | def compute_score(self, option=None, verbose=0):
211 | # compute idf
212 | if self.df_mode == "corpus":
213 | self.document_frequency = defaultdict(float)
214 | self.compute_doc_freq()
215 | # assert to check document frequency
216 | assert(len(self.ctest) >= max(self.document_frequency.values()))
217 | # import json for now and write the corresponding files
218 | # compute cider score
219 | score = self.compute_cider()
220 | # debug
221 | # print score
222 | return np.mean(np.array(score)), np.array(score)
223 |
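A small numeric illustration (not part of the scorer) of the length-based Gaussian penalty applied in sim() above, exp(-delta^2 / (2 * sigma^2)) with the default sigma = 6.0.

import numpy as np

sigma = 6.0
for delta in (0, 3, 6, 12):  # |hypothesis length - reference length|
    penalty = np.e ** (-(delta ** 2) / (2 * sigma ** 2))
    print("delta=%2d  penalty=%.3f" % (delta, penalty))
# delta=0 -> 1.000, delta=6 -> 0.607, delta=12 -> 0.135: large length mismatches
# are damped, in addition to the tf-idf clipping min(vec_hyp, vec_ref) above.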
--------------------------------------------------------------------------------
/utils/hdfs_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
4 | # Github: https://github.com/zengyan-97/X-VLM
5 | # Copyright (c) 2022, ByteDance Inc.
6 | # All rights reserved.
7 |
8 | import sys
9 | from typing import IO, Any, List
10 |
11 | import shutil
12 | import subprocess
13 | from contextlib import contextmanager
14 | import os
15 | import glob
16 | import threading
17 |
18 | HADOOP_BIN = 'HADOOP_ROOT_LOGGER=ERROR,console /SET/PATH/TO/hadoop/bin/hdfs'
19 |
20 | __all__ = ['hlist_files', 'hopen', 'hexists', 'hmkdir']
21 |
22 |
23 | @contextmanager # type: ignore
24 | def hopen(hdfs_path: str, mode: str = "r") -> IO[Any]:
25 | """
26 | open a file on hdfs with contextmanager.
27 |
28 | Args:
29 | mode (str): supports ["r", "w", "a"/"wa"]
30 | """
31 | pipe = None
32 | if mode.startswith("r"):
33 | pipe = subprocess.Popen(
34 | "{} dfs -text {}".format(HADOOP_BIN, hdfs_path), shell=True, stdout=subprocess.PIPE)
35 | yield pipe.stdout
36 | pipe.stdout.close() # type: ignore
37 | pipe.wait()
38 | return
39 | if mode == "wa" or mode == "a":
40 | pipe = subprocess.Popen(
41 | "{} dfs -appendToFile - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE)
42 | yield pipe.stdin
43 | pipe.stdin.close() # type: ignore
44 | pipe.wait()
45 | return
46 | if mode.startswith("w"):
47 | pipe = subprocess.Popen(
48 | "{} dfs -put -f - {}".format(HADOOP_BIN, hdfs_path), shell=True, stdin=subprocess.PIPE)
49 | yield pipe.stdin
50 | pipe.stdin.close() # type: ignore
51 | pipe.wait()
52 | return
53 | raise RuntimeError("unsupported io mode: {}".format(mode))
54 |
55 |
56 | def hlist_files(folders: List[str]) -> List[str]:
57 | files = []
58 | for folder in folders:
59 | if folder.startswith('hdfs'):
60 | pipe = subprocess.Popen("{} dfs -ls {}".format(HADOOP_BIN, folder), shell=True,
61 | stdout=subprocess.PIPE)
62 | # output, _ = pipe.communicate()
63 | for line in pipe.stdout: # type: ignore
64 | line = line.strip()
65 | # drwxr-xr-x - user group 4 file
66 | if len(line.split()) < 5:
67 | continue
68 | files.append(line.split()[-1].decode("utf8"))
69 | pipe.stdout.close() # type: ignore
70 | pipe.wait()
71 | else:
72 | if os.path.isdir(folder):
73 | files.extend([os.path.join(folder, d) for d in os.listdir(folder)])
74 | elif os.path.isfile(folder):
75 | files.append(folder)
76 | else:
77 | print('Path {} is invalid'.format(folder))
78 | sys.stdout.flush()
79 |
80 | return files
81 |
82 |
83 | def hexists(file_path: str) -> bool:
84 | """ hdfs capable to check whether a file_path is exists """
85 | if file_path.startswith('hdfs'):
86 | return os.system("{} dfs -test -e {}".format(HADOOP_BIN, file_path)) == 0
87 | return os.path.exists(file_path)
88 |
89 |
90 | def hmkdir(file_path: str) -> bool:
91 | """ hdfs mkdir """
92 | if file_path.startswith('hdfs'):
93 | os.system("{} dfs -mkdir -p {}".format(HADOOP_BIN, file_path)) # exist ok
94 | else:
95 | if not os.path.exists(file_path):
96 | os.makedirs(file_path, exist_ok=True)
97 | return True
98 |
99 |
100 | def hcopy(from_path: str, to_path: str) -> bool:
101 | """ hdfs copy """
102 | if to_path.startswith("hdfs"):
103 | if from_path.startswith("hdfs"):
104 | os.system("{} dfs -cp -f {} {}".format(HADOOP_BIN, from_path, to_path))
105 | else:
106 | os.system("{} dfs -copyFromLocal -f {} {}".format(HADOOP_BIN, from_path, to_path))
107 | else:
108 | if from_path.startswith("hdfs"):
109 | os.system("{} dfs -text {} > {}".format(HADOOP_BIN, from_path, to_path))
110 | else:
111 | shutil.copy(from_path, to_path)
112 | return True
113 |
114 |
115 | def hcountline(path):
116 | '''
117 | count the number of lines in a file
118 | '''
119 | count = 0
120 | if path.startswith('hdfs'):
121 | with hopen(path, 'r') as f:
122 | for line in f:
123 | count += 1
124 | else:
125 | with open(path, 'r') as f:
126 | for line in f:
127 | count += 1
128 | return count
129 |
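A usage sketch with hypothetical local paths; hdfs:// paths additionally require the HADOOP_BIN placeholder above to point at a real hdfs binary, and the repository root is assumed to be on PYTHONPATH.

from utils.hdfs_io import hopen, hlist_files, hexists, hmkdir, hcountline

if hexists('data/train.json'):
    print(hcountline('data/train.json'), 'lines')

for path in hlist_files(['data/']):
    print(path)

hmkdir('output/')

# hdfs reads/writes go through a subprocess pipe, so write bytes:
# with hopen('hdfs://some/dir/result.txt', 'w') as writer:
#     writer.write(b'hello\n')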
--------------------------------------------------------------------------------
/utils/marvl_preproc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 |
5 | def marvl_preproc(ipath, opath):
6 | if not os.path.exists(opath):
7 | os.makedirs(opath)
8 | # test
9 | root = os.path.join(ipath, 'zero_shot/annotations')
10 | for fp in os.listdir(root):
11 | with open(os.path.join(root, fp)) as f, open(os.path.join(opath, fp[:-1]), 'w') as wf:
12 | data = []
13 | for l in f:
14 | d = json.loads(l)
15 | data.append({
16 | 'sentence': d['caption'],
17 | 'label': d['label'],
18 | 'images': ['images/marvl_official/{}/images/{}/{}'.format(d['language'], d['left_img'].split('-')[0], d['left_img']),
19 | 'images/marvl_official/{}/images/{}/{}'.format(d['language'], d['right_img'].split('-')[0], d['right_img'])]
20 | })
21 | json.dump(data, wf)
22 | # few shot
23 | root = os.path.join(ipath, 'few_shot/annotations')
24 | for fp in os.listdir(root):
25 | with open(os.path.join(root, fp)) as f, open(os.path.join(opath, fp[:-1]), 'w') as wf:
26 | data = []
27 | for l in f:
28 | d = json.loads(l)
29 | data.append({
30 | 'sentence': d['caption'],
31 | 'label': d['label'],
32 | 'images': ['images/marvl_fewshot/{}/all/{}'.format(d['language'], d['left_img'].split('/')[-1]),
33 | 'images/marvl_fewshot/{}/all/{}'.format(d['language'], d['right_img'].split('/')[-1])]
34 | })
35 | json.dump(data, wf)
36 |
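A usage sketch with hypothetical paths, assuming the input directory contains zero_shot/annotations/*.jsonl and few_shot/annotations/*.jsonl from the official MaRVL release; each .jsonl is rewritten as a .json list of NLVR-style entries (fp[:-1] drops the trailing 'l' from the file name).

from utils.marvl_preproc import marvl_preproc

marvl_preproc('data/marvl_official_download', 'data/marvl_converted')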
--------------------------------------------------------------------------------
/utils/torch_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276)
4 | # Github: https://github.com/zengyan-97/X-VLM
5 | # Copyright (c) 2022, ByteDance Inc.
6 | # All rights reserved.
7 |
8 | import io
9 | import torch
10 |
11 | from .hdfs_io import hopen
12 |
13 |
14 | def load(filepath: str, **kwargs):
15 | """ load model """
16 | if not filepath.startswith("hdfs://"):
17 | return torch.load(filepath, **kwargs)
18 | with hopen(filepath, "rb") as reader:
19 | accessor = io.BytesIO(reader.read())
20 | state_dict = torch.load(accessor, **kwargs)
21 | del accessor
22 | return state_dict
23 |
24 |
25 | def save(obj, filepath: str, **kwargs):
26 | """ save model """
27 | if filepath.startswith("hdfs://"):
28 | with hopen(filepath, "wb") as writer:
29 | torch.save(obj, writer, **kwargs)
30 | else:
31 | torch.save(obj, filepath, **kwargs)
32 |
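A minimal usage sketch, assuming the repository root is on PYTHONPATH; the checkpoint name is arbitrary.

import torch

from utils.torch_io import load, save

state = {'step': 0, 'weight': torch.zeros(2, 2)}
save(state, 'checkpoint.th')                           # local path -> plain torch.save
restored = load('checkpoint.th', map_location='cpu')   # local path -> plain torch.load

# hdfs:// paths are streamed through utils.hdfs_io.hopen instead, e.g.
# save(state, 'hdfs://some/dir/checkpoint.th')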
--------------------------------------------------------------------------------
/vqaTools/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 |
--------------------------------------------------------------------------------
/vqaTools/vqa.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 | __version__ = '0.9'
3 |
4 | # Interface for accessing the VQA dataset.
5 |
6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8 |
9 | # The following functions are defined:
10 | # VQA - VQA class that loads VQA annotation file and prepares data structures.
11 | # getQuesIds - Get question ids that satisfy given filter conditions.
12 | # getImgIds - Get image ids that satisfy given filter conditions.
13 | # loadQA - Load questions and answers with the specified question ids.
14 | # showQA - Display the specified questions and answers.
15 | # loadRes - Load result file and create result object.
16 |
17 | # Help on each function can be accessed by: "help(VQA.function)"
18 |
19 | import json
20 | import datetime
21 | import copy
22 |
23 | class VQA:
24 | def __init__(self, annotation_file=None, question_file=None):
25 | """
26 | Constructor of VQA helper class for reading and visualizing questions and answers.
27 | :param annotation_file (str): location of VQA annotation file
28 | :return:
29 | """
30 | # load dataset
31 | self.dataset = {}
32 | self.questions = {}
33 | self.qa = {}
34 | self.qqa = {}
35 | self.imgToQA = {}
36 | if annotation_file is not None and question_file is not None:
37 | print('loading VQA annotations and questions into memory...')
38 | time_t = datetime.datetime.utcnow()
39 | dataset = json.load(open(annotation_file, 'r'))
40 | questions = json.load(open(question_file, 'r'))
41 | self.dataset = dataset
42 | self.questions = questions
43 | self.createIndex()
44 |
45 | def createIndex(self):
46 | # create index
47 | print('creating index...')
48 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
49 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
50 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51 | for ann in self.dataset['annotations']:
52 | imgToQA[ann['image_id']] += [ann]
53 | qa[ann['question_id']] = ann
54 | for ques in self.questions['questions']:
55 | qqa[ques['question_id']] = ques
56 | print('index created!')
57 |
58 | # create class members
59 | self.qa = qa
60 | self.qqa = qqa
61 | self.imgToQA = imgToQA
62 |
63 | def info(self):
64 | """
65 | Print information about the VQA annotation file.
66 | :return:
67 | """
68 | for key, value in self.dataset['info'].items():
69 | print('%s: %s'%(key, value))
70 |
71 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
72 | """
73 | Get question ids that satisfy given filter conditions. default skips that filter
74 | :param imgIds (int array) : get question ids for given imgs
75 | quesTypes (str array) : get question ids for given question types
76 | ansTypes (str array) : get question ids for given answer types
77 | :return: ids (int array) : integer array of question ids
78 | """
79 | imgIds = imgIds if type(imgIds) == list else [imgIds]
80 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
81 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
82 |
83 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
84 | anns = self.dataset['annotations']
85 | else:
86 | if not len(imgIds) == 0:
87 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
88 | else:
89 | anns = self.dataset['annotations']
90 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
91 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
92 | ids = [ann['question_id'] for ann in anns]
93 | return ids
94 |
95 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
96 | """
97 | Get image ids that satisfy given filter conditions. default skips that filter
98 | :param quesIds (int array) : get image ids for given question ids
99 | quesTypes (str array) : get image ids for given question types
100 | ansTypes (str array) : get image ids for given answer types
101 | :return: ids (int array) : integer array of image ids
102 | """
103 | quesIds = quesIds if type(quesIds) == list else [quesIds]
104 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
105 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
106 |
107 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
108 | anns = self.dataset['annotations']
109 | else:
110 | if not len(quesIds) == 0:
111 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[])
112 | else:
113 | anns = self.dataset['annotations']
114 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
115 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
116 | ids = [ann['image_id'] for ann in anns]
117 | return ids
118 |
119 | def loadQA(self, ids=[]):
120 | """
121 | Load questions and answers with the specified question ids.
122 | :param ids (int array) : integer ids specifying question ids
123 | :return: qa (object array) : loaded qa objects
124 | """
125 | if type(ids) == list:
126 | return [self.qa[id] for id in ids]
127 | elif type(ids) == int:
128 | return [self.qa[ids]]
129 |
130 | def showQA(self, anns):
131 | """
132 | Display the specified annotations.
133 | :param anns (array of object): annotations to display
134 | :return: None
135 | """
136 | if len(anns) == 0:
137 | return 0
138 | for ann in anns:
139 | quesId = ann['question_id']
140 | print("Question: %s" %(self.qqa[quesId]['question']))
141 | for ans in ann['answers']:
142 | print("Answer %d: %s" %(ans['answer_id'], ans['answer']))
143 |
144 | def loadRes(self, resFile, quesFile):
145 | """
146 | Load result file and return a result object.
147 | :param resFile (str) : file name of result file
148 | :return: res (obj) : result api object
149 | """
150 | res = VQA()
151 | res.questions = json.load(open(quesFile))
152 | res.dataset['info'] = copy.deepcopy(self.questions['info'])
153 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
154 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
155 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
156 | res.dataset['license'] = copy.deepcopy(self.questions['license'])
157 |
158 | print('Loading and preparing results... ')
159 | time_t = datetime.datetime.utcnow()
160 | anns = json.load(open(resFile))
161 | assert type(anns) == list, 'results is not an array of objects'
162 | annsQuesIds = [ann['question_id'] for ann in anns]
163 | assert set(annsQuesIds) == set(self.getQuesIds()), \
164 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
165 | for ann in anns:
166 | quesId = ann['question_id']
167 | if res.dataset['task_type'] == 'Multiple Choice':
168 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
169 | qaAnn = self.qa[quesId]
170 | ann['image_id'] = qaAnn['image_id']
171 | ann['question_type'] = qaAnn['question_type']
172 | ann['answer_type'] = qaAnn['answer_type']
173 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))
174 |
175 | res.dataset['annotations'] = anns
176 | res.createIndex()
177 | return res
178 |
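A usage sketch of the VQA helper, assuming the repository root is on PYTHONPATH; the file names are hypothetical stand-ins for the official VQA annotation, question, and result files.

from vqaTools.vqa import VQA

vqa = VQA('v2_mscoco_val2014_annotations.json',
          'v2_OpenEnded_mscoco_val2014_questions.json')

yes_no_ids = vqa.getQuesIds(ansTypes=['yes/no'])   # filter question ids by answer type
anns = vqa.loadQA(yes_no_ids[:3])
vqa.showQA(anns)

# Wrap a result file (a list of {"question_id": ..., "answer": ...} entries)
# so it can be evaluated against the ground truth:
vqa_res = vqa.loadRes('results.json',
                      'v2_OpenEnded_mscoco_val2014_questions.json')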
--------------------------------------------------------------------------------
/vqaTools/vqaEval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | __author__='aagrawal'
4 |
5 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
6 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
7 | import sys
8 | import re
9 |
10 | class VQAEval:
11 | def __init__(self, vqa, vqaRes, n=2):
12 | self.n = n
13 | self.accuracy = {}
14 | self.evalQA = {}
15 | self.evalQuesType = {}
16 | self.evalAnsType = {}
17 | self.vqa = vqa
18 | self.vqaRes = vqaRes
19 | self.params = {'question_id': vqa.getQuesIds()}
20 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
21 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
22 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
23 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
24 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
25 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
27 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
28 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
29 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
30 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
31 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
32 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
33 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
34 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
35 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
36 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
37 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
38 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
39 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
40 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
41 | "youll": "you'll", "youre": "you're", "youve": "you've"}
42 | self.manualMap = { 'none': '0',
43 | 'zero': '0',
44 | 'one': '1',
45 | 'two': '2',
46 | 'three': '3',
47 | 'four': '4',
48 | 'five': '5',
49 | 'six': '6',
50 | 'seven': '7',
51 | 'eight': '8',
52 | 'nine': '9',
53 | 'ten': '10'
54 | }
55 | self.articles = ['a',
56 | 'an',
57 | 'the'
58 | ]
59 |
60 |
61 | self.periodStrip = re.compile(r"(?<!\d)(\.)(?!\d)")  # strip periods that are not part of a decimal number
62 | self.commaStrip = re.compile("(\d)(,)(\d)")
63 | self.punct = [';', r"/", '[', ']', '"', '{', '}',
64 | '(', ')', '=', '+', '\\', '_', '-',
65 | '>', '<', '@', '`', ',', '?', '!']
66 |
67 |
68 | def evaluate(self, quesIds=None):
69 | if quesIds is None:
70 | quesIds = [quesId for quesId in self.params['question_id']]
71 | gts = {}
72 | res = {}
73 | for quesId in quesIds:
74 | gts[quesId] = self.vqa.qa[quesId]
75 | res[quesId] = self.vqaRes.qa[quesId]
76 |
77 | # =================================================
78 | # Compute accuracy
79 | # =================================================
80 | accQA = []
81 | accQuesType = {}
82 | accAnsType = {}
83 | print ("computing accuracy")
84 | step = 0
85 | for quesId in quesIds:
86 | resAns = res[quesId]['answer']
87 | resAns = resAns.replace('\n', ' ')
88 | resAns = resAns.replace('\t', ' ')
89 | resAns = resAns.strip()
90 | resAns = self.processPunctuation(resAns)
91 | resAns = self.processDigitArticle(resAns)
92 | gtAcc = []
93 | gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
94 | if len(set(gtAnswers)) > 1:
95 | for ansDic in gts[quesId]['answers']:
96 | ansDic['answer'] = self.processPunctuation(ansDic['answer'])
97 | for gtAnsDatum in gts[quesId]['answers']:
98 | otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
99 | matchingAns = [item for item in otherGTAns if item['answer']==resAns]
100 | acc = min(1, float(len(matchingAns))/3)
101 | gtAcc.append(acc)
102 | quesType = gts[quesId]['question_type']
103 | ansType = gts[quesId]['answer_type']
104 | avgGTAcc = float(sum(gtAcc))/len(gtAcc)
105 | accQA.append(avgGTAcc)
106 | if quesType not in accQuesType:
107 | accQuesType[quesType] = []
108 | accQuesType[quesType].append(avgGTAcc)
109 | if ansType not in accAnsType:
110 | accAnsType[ansType] = []
111 | accAnsType[ansType].append(avgGTAcc)
112 | self.setEvalQA(quesId, avgGTAcc)
113 | self.setEvalQuesType(quesId, quesType, avgGTAcc)
114 | self.setEvalAnsType(quesId, ansType, avgGTAcc)
115 | if step%100 == 0:
116 | self.updateProgress(step/float(len(quesIds)))
117 | step = step + 1
118 |
119 | self.setAccuracy(accQA, accQuesType, accAnsType)
120 | print ("Done computing accuracy")
121 |
122 | def processPunctuation(self, inText):
123 | outText = inText
124 | for p in self.punct:
125 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
126 | outText = outText.replace(p, '')
127 | else:
128 | outText = outText.replace(p, ' ')
129 | # Pattern.sub takes (repl, string, count); flags belong in re.compile,
130 | # so no flag argument is passed here.
131 | outText = self.periodStrip.sub("", outText)
132 | return outText
133 |
134 | def processDigitArticle(self, inText):
135 | outText = []
136 | tempText = inText.lower().split()
137 | for word in tempText:
138 | word = self.manualMap.setdefault(word, word)
139 | if word not in self.articles:
140 | outText.append(word)
141 | else:
142 | pass
143 | for wordId, word in enumerate(outText):
144 | if word in self.contractions:
145 | outText[wordId] = self.contractions[word]
146 | outText = ' '.join(outText)
147 | return outText
148 |
149 | def setAccuracy(self, accQA, accQuesType, accAnsType):
150 | self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
151 | self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
152 | self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
153 |
154 | def setEvalQA(self, quesId, acc):
155 | self.evalQA[quesId] = round(100*acc, self.n)
156 |
157 | def setEvalQuesType(self, quesId, quesType, acc):
158 | if quesType not in self.evalQuesType:
159 | self.evalQuesType[quesType] = {}
160 | self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
161 |
162 | def setEvalAnsType(self, quesId, ansType, acc):
163 | if ansType not in self.evalAnsType:
164 | self.evalAnsType[ansType] = {}
165 | self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
166 |
167 | def updateProgress(self, progress):
168 | barLength = 20
169 | status = ""
170 | if isinstance(progress, int):
171 | progress = float(progress)
172 | if not isinstance(progress, float):
173 | progress = 0
174 | status = "error: progress var must be float\r\n"
175 | if progress < 0:
176 | progress = 0
177 | status = "Halt...\r\n"
178 | if progress >= 1:
179 | progress = 1
180 | status = "Done...\r\n"
181 | block = int(round(barLength*progress))
182 | text = "\rFinished Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
183 | sys.stdout.write(text)
184 | sys.stdout.flush()
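A usage sketch combining VQA and VQAEval, assuming the repository root is on PYTHONPATH; the file names are hypothetical.

from vqaTools.vqa import VQA
from vqaTools.vqaEval import VQAEval

vqa = VQA('annotations.json', 'questions.json')
vqa_res = vqa.loadRes('results.json', 'questions.json')

vqa_eval = VQAEval(vqa, vqa_res, n=2)   # n = number of decimal places reported
vqa_eval.evaluate()

print(vqa_eval.accuracy['overall'])
print(vqa_eval.accuracy['perQuestionType'])
print(vqa_eval.accuracy['perAnswerType'])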
--------------------------------------------------------------------------------
/x2vlm_github.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengyan-97/X2-VLM/ac040c831b74088c7989aa06f03114479e522293/x2vlm_github.png
--------------------------------------------------------------------------------