├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── Grounding.py ├── LICENSE.txt ├── NLVR.py ├── Pretrain.py ├── Pretrain_nlvr.py ├── README.md ├── Retrieval.py ├── SECURITY.md ├── VE.py ├── VQA.py ├── cog.yaml ├── configs ├── Grounding.yaml ├── NLVR.yaml ├── NLVR_pretrain.yaml ├── Pretrain.yaml ├── Retrieval_coco.yaml ├── Retrieval_flickr.yaml ├── VE.yaml ├── VQA.yaml └── config_bert.json ├── dataset ├── __init__.py ├── caption_dataset.py ├── grounding_dataset.py ├── nlvr_dataset.py ├── randaugment.py ├── utils.py ├── ve_dataset.py └── vqa_dataset.py ├── examples ├── image0.jpg └── visualization.png ├── img.png ├── models ├── __init__.py ├── model_nlvr.py ├── model_pretrain.py ├── model_pretrain_nlvr.py ├── model_retrieval.py ├── model_ve.py ├── model_vqa.py ├── tokenization_bert.py ├── vit.py └── xbert.py ├── optim ├── __init__.py ├── adafactor.py ├── adahessian.py ├── adamp.py ├── adamw.py ├── lookahead.py ├── nadam.py ├── novograd.py ├── nvnovograd.py ├── optim_factory.py ├── radam.py ├── rmsprop_tf.py └── sgdp.py ├── predict.py ├── refTools ├── __pycache__ │ ├── refer_python3.cpython-36.pyc │ └── refer_python3.cpython-38.pyc ├── evaluation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── refEvaluation.cpython-36.pyc │ │ └── refEvaluation.cpython-38.pyc │ ├── bleu │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── bleu.cpython-36.pyc │ │ │ ├── bleu.cpython-38.pyc │ │ │ ├── bleu_scorer.cpython-36.pyc │ │ │ └── bleu_scorer.cpython-38.pyc │ │ ├── bleu.py │ │ ├── bleu.pyc │ │ ├── bleu_scorer.py │ │ └── bleu_scorer.pyc │ ├── cider │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── cider.cpython-36.pyc │ │ │ ├── cider.cpython-38.pyc │ │ │ ├── cider_scorer.cpython-36.pyc │ │ │ └── cider_scorer.cpython-38.pyc │ │ ├── cider.py │ │ ├── cider.pyc │ │ ├── cider_scorer.py │ │ └── cider_scorer.pyc │ ├── meteor │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── meteor.cpython-36.pyc │ │ │ └── meteor.cpython-38.pyc │ │ ├── meteor-1.5.jar │ │ ├── meteor.py │ │ └── meteor.pyc │ ├── readme.txt │ ├── refEvaluation.py │ ├── refEvaluation.pyc │ ├── rouge │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── rouge.cpython-36.pyc │ │ │ └── rouge.cpython-38.pyc │ │ ├── rouge.py │ │ └── rouge.pyc │ └── tokenizer │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── ptbtokenizer.cpython-36.pyc │ │ └── ptbtokenizer.cpython-38.pyc │ │ ├── ptbtokenizer.py │ │ ├── ptbtokenizer.pyc │ │ ├── stanford-corenlp-3.4.1.jar │ │ ├── tmp37tp6xj8 │ │ ├── tmp82iqkuu0 │ │ └── tmpn19wmqte └── refer_python3.py ├── scheduler ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── cosine_lr.cpython-36.pyc │ ├── cosine_lr.cpython-38.pyc │ ├── plateau_lr.cpython-36.pyc │ ├── plateau_lr.cpython-38.pyc │ ├── scheduler.cpython-36.pyc │ ├── scheduler.cpython-38.pyc │ ├── scheduler_factory.cpython-36.pyc │ ├── scheduler_factory.cpython-38.pyc │ ├── step_lr.cpython-36.pyc │ ├── step_lr.cpython-38.pyc │ ├── tanh_lr.cpython-36.pyc │ └── tanh_lr.cpython-38.pyc ├── cosine_lr.py ├── plateau_lr.py ├── scheduler.py ├── 
scheduler_factory.py ├── step_lr.py └── tanh_lr.py ├── utils.py ├── visualization.ipynb └── vqaTools ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── vqa.cpython-36.pyc ├── vqa.cpython-38.pyc ├── vqaEval.cpython-36.pyc └── vqaEval.cpython-38.pyc ├── vqa.py └── vqaEval.py /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 
32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 
100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
13 | -------------------------------------------------------------------------------- /Pretrain_nlvr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import ruamel_yaml as yaml 4 | import numpy as np 5 | import random 6 | import time 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.data import DataLoader 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | 18 | from models.model_pretrain_nlvr import ALBEF 19 | from models.vit import interpolate_pos_embed 20 | from models.tokenization_bert import BertTokenizer 21 | 22 | import utils 23 | from dataset import create_dataset, create_sampler, create_loader 24 | from scheduler import create_scheduler 25 | from optim import create_optimizer 26 | 27 | 28 | def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config): 29 | # train 30 | model.train() 31 | 32 | metric_logger = utils.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 34 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 35 | 36 | header = 'Train Epoch: [{}]'.format(epoch) 37 | print_freq = 50 38 | step_size = 100 39 | warmup_iterations = warmup_steps*step_size 40 | 41 | if args.distributed: 42 | data_loader.sampler.set_epoch(epoch) 43 | 44 | for i, (image, text) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | 46 | optimizer.zero_grad() 47 | 48 | image = image.to(device,non_blocking=True) 49 | text_input = tokenizer(text, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device) 50 | 51 | loss = model(image, text_input) 52 | loss.backward() 53 | 54 | optimizer.step() 55 | 56 | metric_logger.update(loss=loss.item()) 57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 58 | 59 | if epoch==0 and i%step_size==0 and i<=warmup_iterations: 60 | scheduler.step(i//step_size) 61 | 62 | # gather the stats from all processes 63 | metric_logger.synchronize_between_processes() 64 | print("Averaged stats:", metric_logger.global_avg()) 65 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 66 | 67 | 68 | def main(args, config): 69 | utils.init_distributed_mode(args) 70 | 71 | device = torch.device(args.device) 72 | 73 | # fix the seed for reproducibility 74 | seed = args.seed + utils.get_rank() 75 | torch.manual_seed(seed) 76 | np.random.seed(seed) 77 | random.seed(seed) 78 | cudnn.benchmark = True 79 | 80 | start_epoch = 0 81 | max_epoch = config['schedular']['epochs'] 82 | warmup_steps = config['schedular']['warmup_epochs'] 83 | 84 | #### Dataset #### 85 | print("Creating dataset") 86 | datasets = [create_dataset('pretrain', config)] 87 | 88 | if args.distributed: 89 | num_tasks = utils.get_world_size() 90 | global_rank = utils.get_rank() 91 | samplers = create_sampler(datasets, [True], num_tasks, global_rank) 92 | else: 93 | samplers = [None] 94 | 95 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] 96 | 97 | tokenizer = BertTokenizer.from_pretrained(args.text_encoder) 98 | 99 | #### Model #### 100 | print("Creating model") 101 | model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer) 102 | 103 | model = 
model.to(device) 104 | 105 | arg_opt = utils.AttrDict(config['optimizer']) 106 | optimizer = create_optimizer(arg_opt, model) 107 | arg_sche = utils.AttrDict(config['schedular']) 108 | lr_scheduler, _ = create_scheduler(arg_sche, optimizer) 109 | 110 | if args.checkpoint: 111 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 112 | state_dict = checkpoint['model'] 113 | pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 114 | state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped 115 | 116 | for key in list(state_dict.keys()): 117 | if 'bert' in key: 118 | new_key = key.replace('bert.','') 119 | 120 | if 'layer' in new_key: 121 | keys = new_key.split('.') 122 | layer_num = int(keys[3]) 123 | # replicate the multimodal encoder's blocks for two images 124 | if layer_num>=6: 125 | new_layer_num = (layer_num-6)*2+6 126 | keys[3] = str(new_layer_num) 127 | new_key_0 = '.'.join(keys) 128 | state_dict[new_key_0] = state_dict[key] 129 | keys[3] = str(new_layer_num+1) 130 | new_key_1 = '.'.join(keys) 131 | state_dict[new_key_1] = state_dict[key] 132 | else: 133 | state_dict[new_key] = state_dict[key] 134 | else: 135 | state_dict[new_key] = state_dict[key] 136 | del state_dict[key] 137 | 138 | msg = model.load_state_dict(state_dict,strict=False) 139 | print('load checkpoint from %s'%args.checkpoint) 140 | print(msg) 141 | 142 | model_without_ddp = model 143 | if args.distributed: 144 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 145 | model_without_ddp = model.module 146 | 147 | print("Start training") 148 | start_time = time.time() 149 | 150 | for epoch in range(start_epoch, max_epoch): 151 | 152 | if epoch>0: 153 | lr_scheduler.step(epoch+warmup_steps) 154 | 155 | train_stats = train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config) 156 | 157 | if utils.is_main_process(): 158 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 159 | 'epoch': epoch, 160 | } 161 | save_obj = { 162 | 'model': model_without_ddp.state_dict(), 163 | 'optimizer': optimizer.state_dict(), 164 | 'lr_scheduler': lr_scheduler.state_dict(), 165 | 'config': config, 166 | 'epoch': epoch, 167 | } 168 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 169 | 170 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 171 | f.write(json.dumps(log_stats) + "\n") 172 | 173 | dist.barrier() 174 | 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print('Training time {}'.format(total_time_str)) 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--config', default='./configs/NLVR_pretrain.yaml') 184 | parser.add_argument('--checkpoint', default='') 185 | parser.add_argument('--output_dir', default='output/NLVR_pretrain') 186 | parser.add_argument('--text_encoder', default='bert-base-uncased') 187 | parser.add_argument('--device', default='cuda') 188 | parser.add_argument('--seed', default=42, type=int) 189 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 190 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 191 | parser.add_argument('--distributed', default=True, type=bool) 192 | args = parser.parse_args() 193 | 194 | config = yaml.load(open(args.config, 'r'), 
Loader=yaml.Loader) 195 | 196 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 197 | 198 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 199 | 200 | main(args, config) 201 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Align before Fuse: Vision and Language Representation Learning with Momentum Distillation, NeurIPS 2021 Spotlight (Salesforce Research). 2 | 3 | ## Announcement: ALBEF is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications! 4 | 5 | This is the official PyTorch implementation of the ALBEF paper [Blog]. 6 | This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k, 7 | and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released. 8 | 9 | 10 | 11 | 12 | ### Requirements: 13 | * pytorch 1.8.0 14 | * transformers 4.8.1 15 | * timm 0.4.9 16 | 17 | ### Download: 18 | 19 | * Pre-trained checkpoint [[14M](https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth)] / [[4M](https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF_4M.pth)] 20 | * Dataset json files for downstream tasks 21 | * Dataset json files for pre-training (the image paths in each json file need to be changed to your own directory) 22 | * Finetuned checkpoint for retrieval on MSCOCO 23 | * Finetuned checkpoint for retrieval on Flickr30k 24 | * Finetuned checkpoint for VQA 25 | * Finetuned checkpoint for visual grounding on RefCOCO+ 26 | 27 | ### Visualization: 28 | We provide code in visualization.ipynb to visualize the important areas in an image for each word in a text. 29 | Here is an example visualization using the visual grounding checkpoint. 30 | 31 | Try the Replicate demo here [![Replicate](https://replicate.com/salesforce/albef/badge)](https://replicate.com/salesforce/albef). 32 | 33 | 34 | 35 | ### Pre-training on custom datasets: 36 | 1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image} (see the example below). 37 | 2. In configs/Pretrain.yaml, set the paths for the json files. 38 | 3. Pre-train the model using 8 A100 GPUs: 39 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain 
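As a concrete reference, here is a minimal sketch of such a pre-training json file. The schema follows step 1 above and pretrain_dataset in dataset/caption_dataset.py (shown later in this repo): 'image' is a path the dataset can open directly, and 'caption' may be a single string or a list of strings. All file names and captions below are made up.

import json

# Hypothetical entries; every 'image' must be a path that pretrain_dataset can open directly,
# and 'caption' may be either one string or a list of alternative captions.
pretrain_data = [
    {"image": "/path/to/images/0001.jpg", "caption": "a dog playing with a ball"},
    {"image": "/path/to/images/0002.jpg", "caption": ["a red car", "a car parked on the street"]},
]

with open("my_pretrain.json", "w") as f:
    json.dump(pretrain_data, f)

Point the train_file list in configs/Pretrain.yaml at the resulting file(s).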
40 | 41 | ### Image-Text Retrieval: 42 | 43 | 1. Download MSCOCO or Flickr30k datasets from the original websites. 44 | 2. Download and extract the provided dataset json files. 45 | 3. In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path. 46 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs: 47 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
 48 | --config ./configs/Retrieval_flickr.yaml \
 49 | --output_dir output/Retrieval_flickr \
 50 | --checkpoint [Pretrained checkpoint]
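The retrieval json schema is not spelled out above, but it can be read off re_train_dataset and re_eval_dataset in dataset/caption_dataset.py (included later in this repo). Hypothetical entries (paths and ids are made up):

# Training entries pair one image with one caption and an image_id (re_train_dataset);
# 'image' is joined with the image_root set in the config.
train_entry = {"image": "flickr30k-images/36979.jpg",
               "caption": "two dogs run across the grass",
               "image_id": "36979"}

# Validation/test entries carry the full list of captions for an image (re_eval_dataset).
eval_entry = {"image": "flickr30k-images/36979.jpg",
              "caption": ["two dogs run across the grass",
                          "a pair of dogs playing outside"]}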
51 | 52 | ### VQA: 53 | 1. Download VQA v2 dataset and Visual Genome dataset from the original websites. 54 | 2. Download and extract the provided dataset json files. 55 | 3. In configs/VQA.yaml, set the paths for the json files and the image paths. 56 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs: 57 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
 58 | --config ./configs/VQA.yaml \
 59 | --output_dir output/vqa \
 60 | --checkpoint [Pretrained checkpoint]
61 | 5. Evaluate the result using the official evaluation server. 62 | 63 | ### Visual Entailment: 64 | 1. Download SNLI-VE dataset from the original website. 65 | 2. Download and extract the provided dataset json files. 66 | 3. In configs/VE.yaml, set the paths for the json files and the image path. 67 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs: 68 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
 69 | --config ./configs/VE.yaml \
 70 | --output_dir output/VE \
 71 | --checkpoint [Pretrained checkpoint]
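For reference, the SNLI-VE json schema follows ve_dataset in dataset/ve_dataset.py (included later in this repo): 'image' is the Flickr30k image id without the .jpg extension, and 'label' is one of 'entailment', 'neutral', or 'contradiction'. A hypothetical entry:

# Hypothetical SNLI-VE entry; ve_dataset builds the image path as '<image>.jpg' under image_root.
ve_entry = {"image": "2248275918",
            "sentence": "a man is riding a horse",
            "label": "entailment"}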
72 | 73 | ### Visual Grounding on RefCOCO+: 74 | 1. Download MSCOCO dataset from the original website. 75 | 2. Download and extract the provided dataset json files. 76 | 3. In configs/Grounding.yaml, set the paths for the json files and the image path. 77 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs: 78 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
 79 | --config ./configs/Grounding.yaml \
 80 | --output_dir output/RefCOCO \
 81 | --gradcam_mode itm \ 
 82 | --block_num 8 \
 83 | --checkpoint [Pretrained checkpoint]
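For reference, the RefCOCO+ json schema follows grounding_dataset in dataset/grounding_dataset.py (included later in this repo). Hypothetical entries (paths and ids are made up):

# Training entry: the referring expression is stored under 'text' (not 'caption').
train_entry = {"image": "train2014/COCO_train2014_000000581857.jpg",
               "text": "the man in the blue shirt on the left"}

# Val/test entries additionally carry the RefCOCO+ ref_id used by grounding_eval in dataset/utils.py.
test_entry = {"image": "train2014/COCO_train2014_000000581857.jpg",
              "text": "the man in the blue shirt on the left",
              "ref_id": 25}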
84 | 85 | ### NLVR2: 86 | NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model for image-pair inputs. During TA, the model receives a pair of images together with one text and learns to classify whether the text describes the first image, the second image, or neither (see models/model_pretrain_nlvr.py). In order to perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run: 87 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
 88 | --config ./configs/NLVR_pretrain.yaml \
 89 | --output_dir output/NLVR_pretrain \
 90 | --checkpoint [Pretrained checkpoint]
91 | 92 | We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps. 93 | 1. Download NLVR2 dataset from the original website. 94 | 2. Download and extract the provided dataset json files. 95 | 3. In configs/NLVR.yaml, set the paths for the json files and the image path. 96 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs: 97 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
 98 | --config ./configs/NLVR.yaml \
 99 | --output_dir output/NLVR \
100 | --checkpoint [TA pretrained checkpoint]
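For reference, the NLVR2 json schema follows nlvr_dataset in dataset/nlvr_dataset.py (included later in this repo): each entry holds a pair of image paths, the statement, and the label as the string 'True' or 'False'. A hypothetical entry (paths and text are made up):

# Hypothetical NLVR2 entry; both image paths are joined with image_root from configs/NLVR.yaml.
nlvr_entry = {"images": ["dev/dev-850-0-img0.png", "dev/dev-850-0-img1.png"],
              "sentence": "the left image contains twice the number of dogs as the right image",
              "label": "True"}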
101 | 102 | ### Citation 103 | If you find this code to be useful for your research, please consider citing. 104 |
105 | @inproceedings{ALBEF,
106 |       title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation}, 
107 |       author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
108 |       year={2021},
109 |       booktitle={NeurIPS},
110 | }
111 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | cuda: "11.1" 4 | python_version: "3.8" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "ipython==7.30.1" 10 | - "torchvision==0.11.1" 11 | - "torch==1.10.0" 12 | - "timm==0.4.12" 13 | - "transformers==4.8.1" 14 | - "Pillow==8.3.2" 15 | - "numpy==1.21.1" 16 | - "opencv-python==4.5.5.62" 17 | - "scipy==1.8.0" 18 | - "scikit_image==0.19.2" 19 | - "matplotlib==3.4.3" 20 | 21 | predict: "predict.py:Predictor" 22 | -------------------------------------------------------------------------------- /configs/Grounding.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/refcoco+_train.json'] 2 | test_file: ['data/refcoco+_val.json','data/refcoco+_test.json'] 3 | 4 | refcoco_data: 'data' 5 | det_file: 'data/refcoco+/dets.json' 6 | coco_file: 'data/refcoco+/cocos.json' 7 | 8 | image_root: '/export/share/datasets/vision/coco/images/' 9 | 10 | bert_config: 'configs/config_bert.json' 11 | 12 | image_res: 384 13 | batch_size: 32 14 | 15 | queue_size: 65536 16 | momentum: 0.995 17 | vision_width: 768 18 | embed_dim: 256 19 | temp: 0.07 20 | 21 | alpha: 0.4 22 | distill: True 23 | warm_up: True 24 | 25 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02} 26 | schedular: {sched: cosine, lr: 1e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /configs/NLVR.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/nlvr_train.json'] 2 | val_file: ['data/nlvr_dev.json'] 3 | test_file: ['data/nlvr_test.json'] 4 | 5 | image_root: '/export/share/datasets/vision/NLVR2/' 6 | 7 | image_res: 384 8 | batch_size: 16 9 | 10 | bert_config: 'configs/config_bert.json' 11 | 12 | alpha: 0.4 13 | distill: True 14 | warm_up: True 15 | eval_ema: False 16 | 17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02} 18 | schedular: {sched: cosine, lr: 2e-5, epochs: 10, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /configs/NLVR_pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/coco.json', 2 | 'data/vg.json', 3 | 'data/cc3m_train.json', 4 | 'data/cc3m_val.json', 5 | 'data/sbu.json' 6 | ] 7 | 8 | # each train_file (json) contains a python list where each item is {'image': img_path, 'caption': text or list_of_text } 9 | 10 | bert_config: 'configs/config_bert.json' 11 | 12 | image_res: 256 13 
| vision_width: 768 14 | embed_dim: 256 15 | batch_size: 64 16 | 17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02} 18 | schedular: {sched: cosine, lr: 2e-5, epochs: 1, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /configs/Pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/coco.json', 2 | 'data/vg.json', 3 | 'data/cc12m.json', 4 | 'data/cc3m_train.json', 5 | 'data/cc3m_val.json', 6 | 'data/sbu.json' 7 | ] 8 | # each train_file (json) contains a python list where each item is {'image': img_path, 'caption': text or list_of_text } 9 | bert_config: 'configs/config_bert.json' 10 | 11 | image_res: 256 12 | vision_width: 768 13 | embed_dim: 256 14 | batch_size: 64 15 | temp: 0.07 16 | mlm_probability: 0.15 17 | queue_size: 65536 18 | momentum: 0.995 19 | alpha: 0.4 20 | 21 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.02} 22 | schedular: {sched: cosine, lr: 1e-4, epochs: 30, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 20, cooldown_epochs: 0} 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /configs/Retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/coco_train.json'] 2 | val_file: 'data/coco_val.json' 3 | test_file: 'data/coco_test.json' 4 | image_root: '/export/share/datasets/vision/coco/images/' 5 | 6 | bert_config: 'configs/config_bert.json' 7 | 8 | image_res: 384 9 | batch_size_train: 32 10 | batch_size_test: 64 11 | 12 | queue_size: 65536 13 | momentum: 0.995 14 | vision_width: 768 15 | embed_dim: 256 16 | temp: 0.07 17 | k_test: 256 18 | 19 | alpha: 0.4 20 | distill: True 21 | warm_up: True 22 | 23 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02} 24 | schedular: {sched: cosine, lr: 1e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /configs/Retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/flickr30k_train.json'] 2 | val_file: 'data/flickr30k_val.json' 3 | test_file: 'data/flickr30k_test.json' 4 | image_root: '/export/share/datasets/vision/flickr30k/' #flickr30k-images/ 5 | 6 | bert_config: 'configs/config_bert.json' 7 | 8 | image_res: 384 9 | batch_size_train: 32 10 | batch_size_test: 64 11 | 12 | queue_size: 65536 13 | momentum: 0.995 14 | vision_width: 768 15 | embed_dim: 256 16 | temp: 0.07 17 | k_test: 128 18 | 19 | alpha: 0.4 20 | distill: True 21 | warm_up: True 22 | 23 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02} 24 | schedular: {sched: cosine, lr: 1e-5, epochs: 10, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /configs/VE.yaml: -------------------------------------------------------------------------------- 1 | train_file: 'data/ve_train.json' 2 | val_file: 'data/ve_dev.json' 3 | test_file: 'data/ve_test.json' 4 | 5 | image_root: '/export/home/project/SNLI-VE/data/images' 6 | 7 | image_res: 384 8 | batch_size_train: 32 9 | batch_size_test: 64 10 | 11 | alpha: 0.4 12 | 
distill: True 13 | warm_up: False 14 | 15 | bert_config: 'configs/config_bert.json' 16 | 17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02} 18 | schedular: {sched: cosine, lr: 2e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /configs/VQA.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/vqa_train.json', 2 | 'data/vqa_val.json', 3 | 'data/vg_qa.json'] 4 | 5 | test_file: ['data/vqa_test.json'] 6 | answer_list: 'data/answer_list.json' 7 | 8 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #train2014/ 9 | vg_root: '/export/share/datasets/vision/visual-genome/' #image/ 10 | 11 | image_res: 384 12 | batch_size_train: 32 13 | batch_size_test: 16 14 | k_test: 128 15 | 16 | alpha: 0.4 17 | distill: True 18 | warm_up: True 19 | 20 | eos: '[SEP]' 21 | 22 | bert_config: 'configs/config_bert.json' 23 | 24 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02} 25 | schedular: {sched: cosine, lr: 2e-5, epochs: 8, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 4, cooldown_epochs: 0} 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /configs/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 768 21 | } 22 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from PIL import Image 5 | 6 | from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset 7 | from dataset.nlvr_dataset import nlvr_dataset 8 | from dataset.ve_dataset import ve_dataset 9 | from dataset.vqa_dataset import vqa_dataset 10 | from dataset.grounding_dataset import grounding_dataset 11 | 12 | from dataset.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config): 15 | 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | pretrain_transform = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_res'],scale=(0.2, 1.0), interpolation=Image.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | train_transform = transforms.Compose([ 27 | transforms.RandomResizedCrop(config['image_res'],scale=(0.5, 1.0), interpolation=Image.BICUBIC), 28 | transforms.RandomHorizontalFlip(), 29 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness', 
30 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 31 | transforms.ToTensor(), 32 | normalize, 33 | ]) 34 | test_transform = transforms.Compose([ 35 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC), 36 | transforms.ToTensor(), 37 | normalize, 38 | ]) 39 | 40 | if dataset=='pretrain': 41 | dataset = pretrain_dataset(config['train_file'], pretrain_transform) 42 | return dataset 43 | 44 | elif dataset=='re': 45 | train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root']) 46 | val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root']) 47 | test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root']) 48 | return train_dataset, val_dataset, test_dataset 49 | 50 | elif dataset=='vqa': 51 | train_dataset = vqa_dataset(config['train_file'], train_transform, config['vqa_root'], config['vg_root'], split='train') 52 | vqa_test_dataset = vqa_dataset(config['test_file'], test_transform, config['vqa_root'], config['vg_root'], split='test', answer_list=config['answer_list']) 53 | return train_dataset, vqa_test_dataset 54 | 55 | elif dataset=='nlvr': 56 | train_dataset = nlvr_dataset(config['train_file'], train_transform, config['image_root']) 57 | val_dataset = nlvr_dataset(config['val_file'], test_transform, config['image_root']) 58 | test_dataset = nlvr_dataset(config['test_file'], test_transform, config['image_root']) 59 | return train_dataset, val_dataset, test_dataset 60 | 61 | elif dataset=='ve': 62 | train_dataset = ve_dataset(config['train_file'], train_transform, config['image_root']) 63 | val_dataset = ve_dataset(config['val_file'], test_transform, config['image_root']) 64 | test_dataset = ve_dataset(config['test_file'], test_transform, config['image_root']) 65 | return train_dataset, val_dataset, test_dataset 66 | 67 | elif dataset=='grounding': 68 | train_transform = transforms.Compose([ 69 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC), 70 | transforms.RandomHorizontalFlip(), 71 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness', 72 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 73 | transforms.ToTensor(), 74 | normalize, 75 | ]) 76 | train_dataset = grounding_dataset(config['train_file'], train_transform, config['image_root'], mode='train') 77 | test_dataset = grounding_dataset(config['test_file'], test_transform, config['image_root'], mode='test') 78 | return train_dataset, test_dataset 79 | 80 | 81 | def vqa_collate_fn(batch): 82 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], [] 83 | for image, question, answer, weights in batch: 84 | image_list.append(image) 85 | question_list.append(question) 86 | weight_list += weights 87 | answer_list += answer 88 | n.append(len(answer)) 89 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n 90 | 91 | 92 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 93 | samplers = [] 94 | for dataset,shuffle in zip(datasets,shuffles): 95 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 96 | samplers.append(sampler) 97 | return samplers 98 | 99 | 100 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 101 | loaders = [] 102 | for dataset,sampler,bs,n_worker,is_train,collate_fn in 
zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 103 | if is_train: 104 | shuffle = (sampler is None) 105 | drop_last = True 106 | else: 107 | shuffle = False 108 | drop_last = False 109 | loader = DataLoader( 110 | dataset, 111 | batch_size=bs, 112 | num_workers=n_worker, 113 | pin_memory=True, 114 | sampler=sampler, 115 | shuffle=shuffle, 116 | collate_fn=collate_fn, 117 | drop_last=drop_last, 118 | ) 119 | loaders.append(loader) 120 | return loaders -------------------------------------------------------------------------------- /dataset/caption_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from dataset.utils import pre_caption 13 | 14 | 15 | class re_train_dataset(Dataset): 16 | def __init__(self, ann_file, transform, image_root, max_words=30): 17 | self.ann = [] 18 | for f in ann_file: 19 | self.ann += json.load(open(f,'r')) 20 | self.transform = transform 21 | self.image_root = image_root 22 | self.max_words = max_words 23 | self.img_ids = {} 24 | 25 | n = 0 26 | for ann in self.ann: 27 | img_id = ann['image_id'] 28 | if img_id not in self.img_ids.keys(): 29 | self.img_ids[img_id] = n 30 | n += 1 31 | 32 | def __len__(self): 33 | return len(self.ann) 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.ann[index] 38 | 39 | image_path = os.path.join(self.image_root,ann['image']) 40 | image = Image.open(image_path).convert('RGB') 41 | image = self.transform(image) 42 | 43 | caption = pre_caption(ann['caption'], self.max_words) 44 | 45 | return image, caption, self.img_ids[ann['image_id']] 46 | 47 | 48 | 49 | class re_eval_dataset(Dataset): 50 | def __init__(self, ann_file, transform, image_root, max_words=30): 51 | self.ann = json.load(open(ann_file,'r')) 52 | self.transform = transform 53 | self.image_root = image_root 54 | self.max_words = max_words 55 | 56 | self.text = [] 57 | self.image = [] 58 | self.txt2img = {} 59 | self.img2txt = {} 60 | 61 | txt_id = 0 62 | for img_id, ann in enumerate(self.ann): 63 | self.image.append(ann['image']) 64 | self.img2txt[img_id] = [] 65 | for i, caption in enumerate(ann['caption']): 66 | self.text.append(pre_caption(caption,self.max_words)) 67 | self.img2txt[img_id].append(txt_id) 68 | self.txt2img[txt_id] = img_id 69 | txt_id += 1 70 | 71 | def __len__(self): 72 | return len(self.image) 73 | 74 | def __getitem__(self, index): 75 | 76 | image_path = os.path.join(self.image_root, self.ann[index]['image']) 77 | image = Image.open(image_path).convert('RGB') 78 | image = self.transform(image) 79 | 80 | return image, index 81 | 82 | 83 | 84 | class pretrain_dataset(Dataset): 85 | def __init__(self, ann_file, transform, max_words=30): 86 | self.ann = [] 87 | for f in ann_file: 88 | self.ann += json.load(open(f,'r')) 89 | self.transform = transform 90 | self.max_words = max_words 91 | 92 | 93 | def __len__(self): 94 | return len(self.ann) 95 | 96 | 97 | def __getitem__(self, index): 98 | 99 | ann = self.ann[index] 100 | 101 | if type(ann['caption']) == list: 102 | caption = pre_caption(random.choice(ann['caption']), self.max_words) 103 | else: 104 | caption = pre_caption(ann['caption'], self.max_words) 105 | 106 | image = Image.open(ann['image']).convert('RGB') 107 | image = self.transform(image) 108 | 109 | return image, caption 110 | 111 | 
112 | 113 | -------------------------------------------------------------------------------- /dataset/grounding_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | class grounding_dataset(Dataset): 8 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train'): 9 | self.ann = [] 10 | for f in ann_file: 11 | self.ann += json.load(open(f,'r')) 12 | self.transform = transform 13 | self.image_root = image_root 14 | self.max_words = max_words 15 | self.mode = mode 16 | 17 | if self.mode == 'train': 18 | self.img_ids = {} 19 | n = 0 20 | for ann in self.ann: 21 | img_id = ann['image'].split('/')[-1] 22 | if img_id not in self.img_ids.keys(): 23 | self.img_ids[img_id] = n 24 | n += 1 25 | 26 | 27 | def __len__(self): 28 | return len(self.ann) 29 | 30 | def __getitem__(self, index): 31 | 32 | ann = self.ann[index] 33 | 34 | image_path = os.path.join(self.image_root,ann['image']) 35 | image = Image.open(image_path).convert('RGB') 36 | image = self.transform(image) 37 | 38 | caption = pre_caption(ann['text'], self.max_words) 39 | 40 | if self.mode=='train': 41 | img_id = ann['image'].split('/')[-1] 42 | 43 | return image, caption, self.img_ids[img_id] 44 | else: 45 | return image, caption, ann['ref_id'] -------------------------------------------------------------------------------- /dataset/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class nlvr_dataset(Dataset): 9 | def __init__(self, ann_file, transform, image_root): 10 | self.ann = [] 11 | for f in ann_file: 12 | self.ann += json.load(open(f,'r')) 13 | self.transform = transform 14 | self.image_root = image_root 15 | self.max_words = 30 16 | 17 | def __len__(self): 18 | return len(self.ann) 19 | 20 | 21 | def __getitem__(self, index): 22 | 23 | ann = self.ann[index] 24 | 25 | image0_path = os.path.join(self.image_root,ann['images'][0]) 26 | image0 = Image.open(image0_path).convert('RGB') 27 | image0 = self.transform(image0) 28 | 29 | image1_path = os.path.join(self.image_root,ann['images'][1]) 30 | image1 = Image.open(image1_path).convert('RGB') 31 | image1 = self.transform(image1) 32 | 33 | sentence = pre_caption(ann['sentence'], self.max_words) 34 | 35 | if ann['label']=='True': 36 | label = 1 37 | else: 38 | label = 0 39 | 40 | return image0, image1, sentence, label -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def pre_question(question,max_ques_words): 4 | question = re.sub( 5 | r"([,.'!?\"()*#:;~])", 6 | '', 7 | question.lower(), 8 | ).replace('-', ' ').replace('/', ' ') 9 | question = question.rstrip(' ') 10 | 11 | #truncate question 12 | question_words = question.split(' ') 13 | if len(question_words)>max_ques_words: 14 | question = ' '.join(question_words[:max_ques_words]) 15 | 16 | return question 17 | 18 | 19 | def pre_caption(caption,max_words): 20 | caption = re.sub( 21 | r"([,.'!?\"()*#:;~])", 22 | '', 23 | caption.lower(), 24 | ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person') 25 | 26 | caption = re.sub( 27 | r"\s{2,}", 28 | ' ', 29 | caption, 30 | ) 31 | caption
= caption.rstrip('\n') 32 | caption = caption.strip(' ') 33 | 34 | #truncate caption 35 | caption_words = caption.split(' ') 36 | if len(caption_words)>max_words: 37 | caption = ' '.join(caption_words[:max_words]) 38 | 39 | return caption 40 | 41 | 42 | from vqaTools.vqaEval import VQAEval 43 | from refTools.evaluation.refEvaluation import RefEvaluation 44 | 45 | import json 46 | import os 47 | import numpy as np 48 | import torch 49 | import torch.distributed as dist 50 | import torch.nn.functional as F 51 | 52 | import utils 53 | from tqdm import tqdm 54 | 55 | 56 | def vqa_eval(vqa, result_file, test_ques_path): 57 | vqaRes = vqa.loadRes(result_file, test_ques_path) 58 | # create vqaEval object by taking vqa and vqaRes 59 | vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2 60 | # evaluate results 61 | vqaEval.evaluate() 62 | 63 | # print accuracies 64 | print("\n") 65 | print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall'])) 66 | print("Per Answer Type Accuracy is the following:") 67 | for ansType in vqaEval.accuracy['perAnswerType']: 68 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType])) 69 | print("\n") 70 | 71 | return vqaEval 72 | 73 | 74 | 75 | def collect_result(result, result_dir, filename, is_json=True, is_list=True): 76 | if is_json: 77 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 78 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 79 | json.dump(result,open(result_file,'w')) 80 | else: 81 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank())) 82 | final_result_file = os.path.join(result_dir, '%s.pth'%filename) 83 | torch.save(result,result_file) 84 | 85 | dist.barrier() 86 | 87 | result = None 88 | if utils.is_main_process(): 89 | # combine results from all processes 90 | if is_list: 91 | result = [] 92 | else: 93 | result = {} 94 | for rank in range(utils.get_world_size()): 95 | if is_json: 96 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 97 | res = json.load(open(result_file,'r')) 98 | else: 99 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank)) 100 | res = torch.load(result_file) 101 | if is_list: 102 | result += res 103 | else: 104 | result.update(res) 105 | 106 | return result 107 | 108 | 109 | def save_result(result, result_dir, filename, is_json=True, is_list=True): 110 | if is_json: 111 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 112 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 113 | json.dump(result,open(result_file,'w')) 114 | else: 115 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank())) 116 | final_result_file = os.path.join(result_dir, '%s.pth'%filename) 117 | torch.save(result,result_file) 118 | 119 | dist.barrier() 120 | 121 | if utils.is_main_process(): 122 | # combine results from all processes 123 | if is_list: 124 | result = [] 125 | else: 126 | result = {} 127 | for rank in range(utils.get_world_size()): 128 | if is_json: 129 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 130 | res = json.load(open(result_file,'r')) 131 | else: 132 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank)) 133 | res = torch.load(result_file) 134 | if is_list: 135 | result += res 136 | else: 137 | result.update(res) 138 | if is_json: 139 | json.dump(result,open(final_result_file,'w')) 140 | else: 
141 | torch.save(result,final_result_file) 142 | 143 | print('result file saved to %s'%final_result_file) 144 | dist.barrier() 145 | return final_result_file 146 | 147 | 148 | 149 | def grounding_eval(results,dets,cocos,refer,alpha,mask_size=24): 150 | 151 | correct_A_d, correct_B_d, correct_val_d = 0, 0, 0 152 | correct_A, correct_B, correct_val = 0, 0, 0 153 | num_A,num_B,num_val = 0,0,0 154 | 155 | for res in tqdm(results): 156 | 157 | ref_id = res['ref_id'] 158 | ref = refer.Refs[ref_id] 159 | ref_box = refer.refToAnn[ref_id]['bbox'] 160 | image = refer.Imgs[ref['image_id']] 161 | 162 | mask = res['pred'].cuda().view(1,1,mask_size,mask_size) 163 | mask = F.interpolate(mask,size = (image['height'],image['width']), mode='bicubic').squeeze() 164 | 165 | # rank detection boxes 166 | max_score = 0 167 | for det in dets[str(ref['image_id'])]: 168 | score = mask[int(det[1]):int(det[1]+det[3]),int(det[0]):int(det[0]+det[2])] 169 | area = det[2]*det[3] 170 | score = score.sum() / area**alpha 171 | if score>max_score: 172 | pred_box = det[:4] 173 | max_score = score 174 | 175 | IoU_det = computeIoU(ref_box, pred_box) 176 | 177 | if ref['split']=='testA': 178 | num_A += 1 179 | if IoU_det >= 0.5: 180 | correct_A_d += 1 181 | elif ref['split']=='testB': 182 | num_B += 1 183 | if IoU_det >= 0.5: 184 | correct_B_d += 1 185 | elif ref['split']=='val': 186 | num_val += 1 187 | if IoU_det >= 0.5: 188 | correct_val_d += 1 189 | 190 | eval_result = {'val_d':correct_val_d/num_val,'testA_d':correct_A_d/num_A,'testB_d':correct_B_d/num_B} 191 | 192 | for metric, acc in eval_result.items(): 193 | print(f'{metric}: {acc:.3f}') 194 | 195 | return eval_result 196 | 197 | 198 | 199 | # IoU function 200 | def computeIoU(box1, box2): 201 | # each box is of [x1, y1, w, h] 202 | inter_x1 = max(box1[0], box2[0]) 203 | inter_y1 = max(box1[1], box2[1]) 204 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1) 205 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1) 206 | 207 | if inter_x1 < inter_x2 and inter_y1 < inter_y2: 208 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) 209 | else: 210 | inter = 0 211 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter 212 | return float(inter)/union 213 | 214 | 215 | -------------------------------------------------------------------------------- /dataset/ve_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from dataset.utils import pre_caption 6 | 7 | 8 | class ve_dataset(Dataset): 9 | def __init__(self, ann_file, transform, image_root, max_words=30): 10 | self.ann = json.load(open(ann_file,'r')) 11 | self.transform = transform 12 | self.image_root = image_root 13 | self.max_words = max_words 14 | self.labels = {'entailment':2,'neutral':1,'contradiction':0} 15 | 16 | def __len__(self): 17 | return len(self.ann) 18 | 19 | 20 | def __getitem__(self, index): 21 | 22 | ann = self.ann[index] 23 | 24 | image_path = os.path.join(self.image_root,'%s.jpg'%ann['image']) 25 | image = Image.open(image_path).convert('RGB') 26 | image = self.transform(image) 27 | 28 | sentence = pre_caption(ann['sentence'], self.max_words) 29 | 30 | return image, sentence, self.labels[ann['label']] 31 | -------------------------------------------------------------------------------- /dataset/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 
5 | from torch.utils.data import Dataset 6 | from dataset.utils import pre_question 7 | 8 | 9 | class vqa_dataset(Dataset): 10 | def __init__(self, ann_file, transform, vqa_root, vg_root, eos='[SEP]', split="train", max_ques_words=30, answer_list=''): 11 | self.split = split 12 | self.ann = [] 13 | for f in ann_file: 14 | self.ann += json.load(open(f,'r')) 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | self.max_ques_words = max_ques_words 20 | self.eos = eos 21 | 22 | if split=='test': 23 | self.max_ques_words = 50 # do not limit question length during test 24 | self.answer_list = json.load(open(answer_list,'r')) 25 | 26 | 27 | def __len__(self): 28 | return len(self.ann) 29 | 30 | def __getitem__(self, index): 31 | 32 | ann = self.ann[index] 33 | 34 | if ann['dataset']=='vqa': 35 | image_path = os.path.join(self.vqa_root,ann['image']) 36 | elif ann['dataset']=='vg': 37 | image_path = os.path.join(self.vg_root,ann['image']) 38 | 39 | image = Image.open(image_path).convert('RGB') 40 | image = self.transform(image) 41 | 42 | if self.split == 'test': 43 | question = pre_question(ann['question'],self.max_ques_words) 44 | question_id = ann['question_id'] 45 | return image, question, question_id 46 | 47 | 48 | elif self.split=='train': 49 | 50 | question = pre_question(ann['question'],self.max_ques_words) 51 | 52 | if ann['dataset']=='vqa': 53 | 54 | answer_weight = {} 55 | for answer in ann['answer']: 56 | if answer in answer_weight.keys(): 57 | answer_weight[answer] += 1/len(ann['answer']) 58 | else: 59 | answer_weight[answer] = 1/len(ann['answer']) 60 | 61 | answers = list(answer_weight.keys()) 62 | weights = list(answer_weight.values()) 63 | 64 | elif ann['dataset']=='vg': 65 | answers = [ann['answer']] 66 | weights = [0.5] 67 | 68 | answers = [answer+self.eos for answer in answers] 69 | 70 | return image, question, answers, weights -------------------------------------------------------------------------------- /examples/image0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/examples/image0.jpg -------------------------------------------------------------------------------- /examples/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/examples/visualization.png -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/img.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/models/__init__.py -------------------------------------------------------------------------------- /models/model_nlvr.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from models.vit import VisionTransformer 3 | from models.xbert import BertConfig, BertModel 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | class ALBEF(nn.Module): 10 | def __init__(self, 11 | 
text_encoder = None, 12 | tokenizer = None, 13 | config = None, 14 | ): 15 | super().__init__() 16 | 17 | self.tokenizer = tokenizer 18 | self.distill = config['distill'] 19 | 20 | self.visual_encoder = VisionTransformer( 21 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 22 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 23 | 24 | bert_config = BertConfig.from_json_file(config['bert_config']) 25 | bert_config.num_hidden_layers = 18 26 | 27 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) 28 | self.cls_head = nn.Sequential( 29 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 30 | nn.ReLU(), 31 | nn.Linear(self.text_encoder.config.hidden_size, 2) 32 | ) 33 | 34 | self.share_cross_attention(self.text_encoder.encoder) 35 | 36 | if self.distill: 37 | self.visual_encoder_m = VisionTransformer( 38 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 39 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 40 | self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) 41 | self.share_cross_attention(self.text_encoder_m.encoder) 42 | 43 | self.cls_head_m = nn.Sequential( 44 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 45 | nn.ReLU(), 46 | nn.Linear(self.text_encoder.config.hidden_size, 2) 47 | ) 48 | 49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 50 | [self.text_encoder,self.text_encoder_m], 51 | [self.cls_head,self.cls_head_m], 52 | ] 53 | self.copy_params() 54 | self.momentum = 0.995 55 | 56 | 57 | def forward(self, image, text, targets, alpha=0, train=True): 58 | 59 | image_embeds = self.visual_encoder(image) 60 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 61 | 62 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 63 | 64 | output = self.text_encoder(text.input_ids, 65 | attention_mask = text.attention_mask, 66 | encoder_hidden_states = [image0_embeds,image1_embeds], 67 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 68 | image_atts[image0_embeds.size(0):]], 69 | return_dict = True, 70 | ) 71 | hidden_state = output.last_hidden_state[:,0,:] 72 | prediction = self.cls_head(hidden_state) 73 | 74 | if train: 75 | if self.distill: 76 | with torch.no_grad(): 77 | self._momentum_update() 78 | image_embeds_m = self.visual_encoder_m(image) 79 | image0_embeds_m, image1_embeds_m = torch.split(image_embeds_m,targets.size(0)) 80 | output_m = self.text_encoder_m(text.input_ids, 81 | attention_mask = text.attention_mask, 82 | encoder_hidden_states = [image0_embeds_m,image1_embeds_m], 83 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 84 | image_atts[image0_embeds.size(0):]], 85 | return_dict = True, 86 | ) 87 | prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:]) 88 | 89 | loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum( 90 | F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean() 91 | else: 92 | loss = F.cross_entropy(prediction, targets) 93 | return loss 94 | else: 95 | return prediction 96 | 97 | 98 | 99 | @torch.no_grad() 100 | def copy_params(self): 101 | for model_pair in self.model_pairs: 102 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 103 | param_m.data.copy_(param.data) # 
initialize 104 | param_m.requires_grad = False # not update by gradient 105 | 106 | 107 | @torch.no_grad() 108 | def _momentum_update(self): 109 | for model_pair in self.model_pairs: 110 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 111 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) 112 | 113 | 114 | def share_cross_attention(self, model): 115 | 116 | for i in range(6): 117 | layer_num = 6+i*2 118 | modules_0 = model.layer[layer_num].crossattention.self._modules 119 | modules_1 = model.layer[layer_num+1].crossattention.self._modules 120 | 121 | for name in modules_0.keys(): 122 | if 'key' in name or 'value' in name: 123 | module_0 = modules_0[name] 124 | module_1 = modules_1[name] 125 | if hasattr(module_0, "weight"): 126 | module_0.weight = module_1.weight 127 | if hasattr(module_0, "bias"): 128 | module_0.bias = module_1.bias -------------------------------------------------------------------------------- /models/model_pretrain_nlvr.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from models.vit import VisionTransformer 3 | from models.xbert import BertConfig, BertModel 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | class ALBEF(nn.Module): 10 | def __init__(self, 11 | text_encoder = None, 12 | tokenizer = None, 13 | config = None, 14 | ): 15 | super().__init__() 16 | 17 | self.tokenizer = tokenizer 18 | vision_width = config['vision_width'] 19 | embed_dim = config['embed_dim'] 20 | 21 | self.visual_encoder = VisionTransformer( 22 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 23 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 24 | 25 | bert_config = BertConfig.from_json_file(config['bert_config']) 26 | bert_config.num_hidden_layers = 18 27 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) 28 | 29 | #share the cross-attention layers for two images 30 | self.share_cross_attention(self.text_encoder.encoder) 31 | 32 | text_width = self.text_encoder.config.hidden_size 33 | self.vision_proj = nn.Linear(vision_width, embed_dim) 34 | self.text_proj = nn.Linear(text_width, embed_dim) 35 | self.temp = nn.Parameter(torch.ones([]) * 0.07) 36 | self.ta_head = nn.Linear(self.text_encoder.config.hidden_size, 3) 37 | 38 | 39 | def forward(self, image, text): 40 | image_embeds = self.visual_encoder(image) 41 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 42 | with torch.no_grad(): 43 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 44 | sim = image_feat @ image_feat.t() / 0.07 45 | weights = F.softmax(sim,dim=1) 46 | weights.fill_diagonal_(0) 47 | 48 | image_inputs = [[],[]] 49 | labels = [] 50 | for b in range(image.size(0)): 51 | if torch.rand(1)>1/3: 52 | idx = torch.multinomial(weights[b], 1).item() 53 | if torch.rand(1)>0.5: 54 | image_inputs[0].append(image_embeds[b]) 55 | image_inputs[1].append(image_embeds[idx]) 56 | labels.append(0) 57 | else: 58 | image_inputs[1].append(image_embeds[b]) 59 | image_inputs[0].append(image_embeds[idx]) 60 | labels.append(1) 61 | else: 62 | idx = torch.multinomial(weights[b], 2) 63 | image_inputs[0].append(image_embeds[idx[0]]) 64 | image_inputs[1].append(image_embeds[idx[1]]) 65 | labels.append(2) 66 | 67 | image_inputs[0] = torch.stack(image_inputs[0],dim=0) 68 | image_inputs[1] = 
torch.stack(image_inputs[1],dim=0) 69 | labels = torch.LongTensor(labels).to(image.device) 70 | 71 | output = self.text_encoder(text.input_ids, 72 | attention_mask = text.attention_mask, 73 | encoder_hidden_states = image_inputs, 74 | encoder_attention_mask = [image_atts,image_atts], 75 | return_dict = True, 76 | ) 77 | 78 | pred = self.ta_head(output.last_hidden_state[:,0,:]) 79 | loss = F.cross_entropy(pred, labels) 80 | 81 | return loss 82 | 83 | 84 | 85 | def share_cross_attention(self, model): 86 | 87 | for i in range(6): 88 | layer_num = 6+i*2 89 | modules_0 = model.layer[layer_num].crossattention.self._modules 90 | modules_1 = model.layer[layer_num+1].crossattention.self._modules 91 | 92 | for name in modules_0.keys(): 93 | if 'key' in name or 'value' in name: 94 | module_0 = modules_0[name] 95 | module_1 = modules_1[name] 96 | if hasattr(module_0, "weight"): 97 | module_0.weight = module_1.weight 98 | if hasattr(module_0, "bias"): 99 | module_0.bias = module_1.bias -------------------------------------------------------------------------------- /models/model_ve.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from models.vit import VisionTransformer 3 | from models.xbert import BertConfig, BertModel 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | class ALBEF(nn.Module): 10 | def __init__(self, 11 | text_encoder = None, 12 | tokenizer = None, 13 | config = None, 14 | ): 15 | super().__init__() 16 | 17 | self.tokenizer = tokenizer 18 | self.distill = config['distill'] 19 | 20 | self.visual_encoder = VisionTransformer( 21 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 22 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 23 | 24 | bert_config = BertConfig.from_json_file(config['bert_config']) 25 | 26 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) 27 | 28 | self.cls_head = nn.Sequential( 29 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 30 | nn.ReLU(), 31 | nn.Linear(self.text_encoder.config.hidden_size, 3) 32 | ) 33 | 34 | if self.distill: 35 | self.visual_encoder_m = VisionTransformer( 36 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, 37 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 38 | self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) 39 | self.cls_head_m = nn.Sequential( 40 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 41 | nn.ReLU(), 42 | nn.Linear(self.text_encoder.config.hidden_size, 3) 43 | ) 44 | 45 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 46 | [self.text_encoder,self.text_encoder_m], 47 | [self.cls_head,self.cls_head_m], 48 | ] 49 | self.copy_params() 50 | self.momentum = 0.995 51 | 52 | 53 | def forward(self, image, text, targets, alpha=0, train=True): 54 | 55 | image_embeds = self.visual_encoder(image) 56 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 57 | 58 | if train: 59 | output = self.text_encoder(text.input_ids, 60 | attention_mask = text.attention_mask, 61 | encoder_hidden_states = image_embeds, 62 | encoder_attention_mask = image_atts, 63 | return_dict = True 64 | ) 65 | prediction = self.cls_head(output.last_hidden_state[:,0,:]) 66 | if self.distill: 67 | with 
torch.no_grad(): 68 | self._momentum_update() 69 | image_embeds_m = self.visual_encoder_m(image) 70 | output_m = self.text_encoder_m(text.input_ids, 71 | attention_mask = text.attention_mask, 72 | encoder_hidden_states = image_embeds_m, 73 | encoder_attention_mask = image_atts, 74 | return_dict = True 75 | ) 76 | prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:]) 77 | 78 | loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum( 79 | F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean() 80 | else: 81 | loss = F.cross_entropy(prediction, targets) 82 | return loss 83 | 84 | else: 85 | output = self.text_encoder(text.input_ids, 86 | attention_mask = text.attention_mask, 87 | encoder_hidden_states = image_embeds, 88 | encoder_attention_mask = image_atts, 89 | return_dict = True 90 | ) 91 | prediction = self.cls_head(output.last_hidden_state[:,0,:]) 92 | return prediction 93 | 94 | 95 | 96 | @torch.no_grad() 97 | def copy_params(self): 98 | for model_pair in self.model_pairs: 99 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 100 | param_m.data.copy_(param.data) # initialize 101 | param_m.requires_grad = False # not update by gradient 102 | 103 | 104 | @torch.no_grad() 105 | def _momentum_update(self): 106 | for model_pair in self.model_pairs: 107 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 108 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) 109 | 110 | 111 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamp import AdamP 2 | from .adamw import AdamW 3 | from .adafactor import Adafactor 4 | from .adahessian import Adahessian 5 | from .lookahead import Lookahead 6 | from .nadam import Nadam 7 | from .novograd import NovoGrad 8 | from .nvnovograd import NvNovoGrad 9 | from .radam import RAdam 10 | from .rmsprop_tf import RMSpropTF 11 | from .sgdp import SGDP 12 | 13 | from .optim_factory import create_optimizer -------------------------------------------------------------------------------- /optim/adahessian.py: -------------------------------------------------------------------------------- 1 | """ AdaHessian Optimizer 2 | 3 | Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py 4 | Originally licensed MIT, Copyright 2020, David Samuel 5 | """ 6 | import torch 7 | 8 | 9 | class Adahessian(torch.optim.Optimizer): 10 | """ 11 | Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" 12 | 13 | Arguments: 14 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups 15 | lr (float, optional): learning rate (default: 0.1) 16 | betas ((float, float), optional): coefficients used for computing running averages of gradient and the 17 | squared hessian trace (default: (0.9, 0.999)) 18 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) 20 | hessian_power (float, optional): exponent of the hessian trace (default: 1.0) 21 | update_each (int, optional): compute the hessian trace approximation only after *this* number of steps 22 | (to save time) (default: 1) 23 | n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace 
(default: 1) 24 | """ 25 | 26 | def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, 27 | hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): 28 | if not 0.0 <= lr: 29 | raise ValueError(f"Invalid learning rate: {lr}") 30 | if not 0.0 <= eps: 31 | raise ValueError(f"Invalid epsilon value: {eps}") 32 | if not 0.0 <= betas[0] < 1.0: 33 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 34 | if not 0.0 <= betas[1] < 1.0: 35 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 36 | if not 0.0 <= hessian_power <= 1.0: 37 | raise ValueError(f"Invalid Hessian power value: {hessian_power}") 38 | 39 | self.n_samples = n_samples 40 | self.update_each = update_each 41 | self.avg_conv_kernel = avg_conv_kernel 42 | 43 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training 44 | self.seed = 2147483647 45 | self.generator = torch.Generator().manual_seed(self.seed) 46 | 47 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) 48 | super(Adahessian, self).__init__(params, defaults) 49 | 50 | for p in self.get_params(): 51 | p.hess = 0.0 52 | self.state[p]["hessian step"] = 0 53 | 54 | @property 55 | def is_second_order(self): 56 | return True 57 | 58 | def get_params(self): 59 | """ 60 | Gets all parameters in all param_groups with gradients 61 | """ 62 | 63 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad) 64 | 65 | def zero_hessian(self): 66 | """ 67 | Zeros out the accumalated hessian traces. 68 | """ 69 | 70 | for p in self.get_params(): 71 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: 72 | p.hess.zero_() 73 | 74 | @torch.no_grad() 75 | def set_hessian(self): 76 | """ 77 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. 78 | """ 79 | 80 | params = [] 81 | for p in filter(lambda p: p.grad is not None, self.get_params()): 82 | if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step 83 | params.append(p) 84 | self.state[p]["hessian step"] += 1 85 | 86 | if len(params) == 0: 87 | return 88 | 89 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device 90 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed) 91 | 92 | grads = [p.grad for p in params] 93 | 94 | for i in range(self.n_samples): 95 | # Rademacher distribution {-1.0, 1.0} 96 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] 97 | h_zs = torch.autograd.grad( 98 | grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) 99 | for h_z, z, p in zip(h_zs, zs, params): 100 | p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) 101 | 102 | @torch.no_grad() 103 | def step(self, closure=None): 104 | """ 105 | Performs a single optimization step. 
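        The Hessian diagonal is estimated in `set_hessian` with Hutchinson's trick
        (E[z * (H @ z)] over Rademacher vectors z), which differentiates the existing
        gradients a second time; the preceding backward pass must therefore keep its
        graph. Minimal illustrative usage (model/criterion/optimizer are placeholders):

            loss = criterion(model(inputs), targets)
            loss.backward(create_graph=True)  # keep the graph for the double-backward
            optimizer.step()
            optimizer.zero_grad()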
106 | Arguments: 107 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) 108 | """ 109 | 110 | loss = None 111 | if closure is not None: 112 | loss = closure() 113 | 114 | self.zero_hessian() 115 | self.set_hessian() 116 | 117 | for group in self.param_groups: 118 | for p in group['params']: 119 | if p.grad is None or p.hess is None: 120 | continue 121 | 122 | if self.avg_conv_kernel and p.dim() == 4: 123 | p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() 124 | 125 | # Perform correct stepweight decay as in AdamW 126 | p.mul_(1 - group['lr'] * group['weight_decay']) 127 | 128 | state = self.state[p] 129 | 130 | # State initialization 131 | if len(state) == 1: 132 | state['step'] = 0 133 | # Exponential moving average of gradient values 134 | state['exp_avg'] = torch.zeros_like(p) 135 | # Exponential moving average of Hessian diagonal square values 136 | state['exp_hessian_diag_sq'] = torch.zeros_like(p) 137 | 138 | exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] 139 | beta1, beta2 = group['betas'] 140 | state['step'] += 1 141 | 142 | # Decay the first and second moment running average coefficient 143 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) 144 | exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) 145 | 146 | bias_correction1 = 1 - beta1 ** state['step'] 147 | bias_correction2 = 1 - beta2 ** state['step'] 148 | 149 | k = group['hessian_power'] 150 | denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) 151 | 152 | # make update 153 | step_size = group['lr'] / bias_correction1 154 | p.addcdiv_(exp_avg, denom, value=-step_size) 155 | 156 | return loss 157 | -------------------------------------------------------------------------------- /optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class AdamP(Optimizer): 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 21 | super(AdamP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | grad = p.grad.data 66 | beta1, beta2 = group['betas'] 67 | nesterov = group['nesterov'] 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | # Adam 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | 80 | state['step'] += 1 81 | bias_correction1 = 1 - beta1 ** state['step'] 82 | bias_correction2 = 1 - beta2 ** state['step'] 83 | 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | 87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 88 | step_size = group['lr'] / bias_correction1 89 | 90 | if nesterov: 91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 92 | else: 93 | perturb = exp_avg / denom 94 | 95 | # Projection 96 | wd_ratio = 1 97 | if len(p.shape) > 1: 98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 99 | 100 | # Weight decay 101 | if group['weight_decay'] > 0: 102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) 103 | 104 | # Step 105 | p.data.add_(-step_size, perturb) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 
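    Unlike Adam, which folds the L2 penalty into the gradient, AdamW decays the weights
    directly before the adaptive update (the `p.data.mul_(1 - lr * weight_decay)` step in
    `step()` below).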
13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. 
values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 47 | loss = self.base_optimizer.step(closure) 48 | for group in self.param_groups: 49 | group['lookahead_step'] += 1 50 | if group['lookahead_step'] % group['lookahead_k'] == 0: 51 | self.update_slow(group) 52 | return loss 53 | 54 | def state_dict(self): 55 | fast_state_dict = self.base_optimizer.state_dict() 56 | slow_state = { 57 | (id(k) if isinstance(k, torch.Tensor) else k): v 58 | for k, v in self.state.items() 59 | } 60 | fast_state = fast_state_dict['state'] 61 | param_groups = 
fast_state_dict['param_groups'] 62 | return { 63 | 'state': fast_state, 64 | 'slow_state': slow_state, 65 | 'param_groups': param_groups, 66 | } 67 | 68 | def load_state_dict(self, state_dict): 69 | fast_state_dict = { 70 | 'state': state_dict['state'], 71 | 'param_groups': state_dict['param_groups'], 72 | } 73 | self.base_optimizer.load_state_dict(fast_state_dict) 74 | 75 | # We want to restore the slow state, but share param_groups reference 76 | # with base_optimizer. This is a bit redundant but least code 77 | slow_state_new = False 78 | if 'slow_state' not in state_dict: 79 | print('Loading state_dict from optimizer without Lookahead applied.') 80 | state_dict['slow_state'] = defaultdict(dict) 81 | slow_state_new = True 82 | slow_state_dict = { 83 | 'state': state_dict['slow_state'], 84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 85 | } 86 | super(Lookahead, self).load_state_dict(slow_state_dict) 87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 88 | if slow_state_new: 89 | # reapply defaults to catch missing lookahead specific ones 90 | for name, default in self.defaults.items(): 91 | for group in self.param_groups: 92 | group.setdefault(name, default) 93 | -------------------------------------------------------------------------------- /optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 
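                    # m_schedule accumulates the product of the momentum schedule mu_t and provides
                    # the Nesterov-style bias correction applied in the update below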
56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. 
- self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. - self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /optim/nvnovograd.py: -------------------------------------------------------------------------------- 1 | """ Nvidia NovoGrad Optimizer. 2 | Original impl by Nvidia from Jasper example: 3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper 4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 5 | - https://arxiv.org/abs/1905.11286 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | import math 11 | 12 | 13 | class NvNovoGrad(Optimizer): 14 | """ 15 | Implements Novograd algorithm. 16 | 17 | Args: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-3) 21 | betas (Tuple[float, float], optional): coefficients used for computing 22 | running averages of gradient and its square (default: (0.95, 0.98)) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 26 | grad_averaging: gradient averaging 27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 28 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 29 | (default: False) 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, 33 | weight_decay=0, grad_averaging=False, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, 44 | grad_averaging=grad_averaging, 45 | amsgrad=amsgrad) 46 | 47 | super(NvNovoGrad, self).__init__(params, defaults) 48 | 49 | def __setstate__(self, state): 50 | super(NvNovoGrad, self).__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault('amsgrad', False) 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
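        Note: NovoGrad keeps a single second-moment value per parameter tensor (a running
        average of the squared gradient norm), which is why `exp_avg_sq` below is
        initialised as a 0-d tensor rather than elementwise.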
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError('Sparse gradients are not supported.') 72 | amsgrad = group['amsgrad'] 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | if amsgrad: 84 | # Maintains max of all exp. moving avg. of sq. grad. values 85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | if amsgrad: 89 | max_exp_avg_sq = state['max_exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | norm = torch.sum(torch.pow(grad, 2)) 95 | 96 | if exp_avg_sq == 0: 97 | exp_avg_sq.copy_(norm) 98 | else: 99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 100 | 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) 108 | 109 | grad.div_(denom) 110 | if group['weight_decay'] != 0: 111 | grad.add_(group['weight_decay'], p.data) 112 | if group['grad_averaging']: 113 | grad.mul_(1 - beta1) 114 | exp_avg.mul_(beta1).add_(grad) 115 | 116 | p.data.add_(-group['lr'], exp_avg) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | from torch import optim as optim 6 | 7 | from .adafactor import Adafactor 8 | from .adahessian import Adahessian 9 | from .adamp import AdamP 10 | from .lookahead import Lookahead 11 | from .nadam import Nadam 12 | from .novograd import NovoGrad 13 | from .nvnovograd import NvNovoGrad 14 | from .radam import RAdam 15 | from .rmsprop_tf import RMSpropTF 16 | from .sgdp import SGDP 17 | 18 | try: 19 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 20 | has_apex = True 21 | except ImportError: 22 | has_apex = False 23 | 24 | 25 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 26 | decay = [] 27 | no_decay = [] 28 | for name, param in model.named_parameters(): 29 | if not param.requires_grad: 30 | continue # frozen weights 31 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 32 | no_decay.append(param) 33 | else: 34 | decay.append(param) 35 | return [ 36 | {'params': no_decay, 'weight_decay': 0.}, 37 | {'params': decay, 'weight_decay': weight_decay}] 38 | 39 | 40 | def create_optimizer(args, model, filter_bias_and_bn=True): 41 | opt_lower = args.opt.lower() 42 | weight_decay = args.weight_decay 43 | if weight_decay and filter_bias_and_bn: 44 | skip = {} 45 | if hasattr(model, 'no_weight_decay'): 46 | skip = model.no_weight_decay() 47 | parameters = add_weight_decay(model, weight_decay, skip) 48 | weight_decay = 0. 
49 | else: 50 | parameters = model.parameters() 51 | 52 | if 'fused' in opt_lower: 53 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 54 | 55 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 56 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 57 | opt_args['eps'] = args.opt_eps 58 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 59 | opt_args['betas'] = args.opt_betas 60 | if hasattr(args, 'opt_args') and args.opt_args is not None: 61 | opt_args.update(args.opt_args) 62 | 63 | opt_split = opt_lower.split('_') 64 | opt_lower = opt_split[-1] 65 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 66 | opt_args.pop('eps', None) 67 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 68 | elif opt_lower == 'momentum': 69 | opt_args.pop('eps', None) 70 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 71 | elif opt_lower == 'adam': 72 | optimizer = optim.Adam(parameters, **opt_args) 73 | elif opt_lower == 'adamw': 74 | optimizer = optim.AdamW(parameters, **opt_args) 75 | elif opt_lower == 'nadam': 76 | optimizer = Nadam(parameters, **opt_args) 77 | elif opt_lower == 'radam': 78 | optimizer = RAdam(parameters, **opt_args) 79 | elif opt_lower == 'adamp': 80 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 81 | elif opt_lower == 'sgdp': 82 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 83 | elif opt_lower == 'adadelta': 84 | optimizer = optim.Adadelta(parameters, **opt_args) 85 | elif opt_lower == 'adafactor': 86 | if not args.lr: 87 | opt_args['lr'] = None 88 | optimizer = Adafactor(parameters, **opt_args) 89 | elif opt_lower == 'adahessian': 90 | optimizer = Adahessian(parameters, **opt_args) 91 | elif opt_lower == 'rmsprop': 92 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 93 | elif opt_lower == 'rmsproptf': 94 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 95 | elif opt_lower == 'novograd': 96 | optimizer = NovoGrad(parameters, **opt_args) 97 | elif opt_lower == 'nvnovograd': 98 | optimizer = NvNovoGrad(parameters, **opt_args) 99 | elif opt_lower == 'fusedsgd': 100 | opt_args.pop('eps', None) 101 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 102 | elif opt_lower == 'fusedmomentum': 103 | opt_args.pop('eps', None) 104 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 105 | elif opt_lower == 'fusedadam': 106 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 107 | elif opt_lower == 'fusedadamw': 108 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 109 | elif opt_lower == 'fusedlamb': 110 | optimizer = FusedLAMB(parameters, **opt_args) 111 | elif opt_lower == 'fusednovograd': 112 | opt_args.setdefault('betas', (0.95, 0.98)) 113 | optimizer = FusedNovoGrad(parameters, **opt_args) 114 | else: 115 | assert False and "Invalid optimizer" 116 | raise ValueError 117 | 118 | if len(opt_split) > 1: 119 | if opt_split[0] == 'lookahead': 120 | optimizer = Lookahead(optimizer) 121 | 122 | return optimizer 123 | -------------------------------------------------------------------------------- /optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer. 
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | 10 | class RAdam(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(RAdam, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(RAdam, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('RAdam does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | 53 | state['step'] += 1 54 | buffered = self.buffer[int(state['step'] % 10)] 55 | if state['step'] == buffered[0]: 56 | N_sma, step_size = buffered[1], buffered[2] 57 | else: 58 | buffered[0] = state['step'] 59 | beta2_t = beta2 ** state['step'] 60 | N_sma_max = 2 / (1 - beta2) - 1 61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 62 | buffered[1] = N_sma 63 | 64 | # more conservative since it's an approximated value 65 | if N_sma >= 5: 66 | step_size = group['lr'] * math.sqrt( 67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 68 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 69 | else: 70 | step_size = group['lr'] / (1 - beta1 ** state['step']) 71 | buffered[2] = step_size 72 | 73 | if group['weight_decay'] != 0: 74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | denom = exp_avg_sq.sqrt().add_(group['eps']) 79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 80 | else: 81 | p_data_fp32.add_(-step_size, exp_avg) 82 | 83 | p.data.copy_(p_data_fp32) 84 | 85 | return loss 86 | 87 | 88 | class PlainRAdam(Optimizer): 89 | 90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 92 | 93 | super(PlainRAdam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(PlainRAdam, self).__setstate__(state) 97 | 98 | def step(self, closure=None): 99 | 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group in self.param_groups: 105 | 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data.float() 110 | if grad.is_sparse: 111 | raise RuntimeError('RAdam does not support sparse gradients') 112 | 113 | p_data_fp32 = 
p.data.float() 114 | 115 | state = self.state[p] 116 | 117 | if len(state) == 0: 118 | state['step'] = 0 119 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | beta1, beta2 = group['betas'] 127 | 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 130 | 131 | state['step'] += 1 132 | beta2_t = beta2 ** state['step'] 133 | N_sma_max = 2 / (1 - beta2) - 1 134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 135 | 136 | if group['weight_decay'] != 0: 137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 138 | 139 | # more conservative since it's an approximated value 140 | if N_sma >= 5: 141 | step_size = group['lr'] * math.sqrt( 142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 143 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 144 | denom = exp_avg_sq.sqrt().add_(group['eps']) 145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 146 | else: 147 | step_size = group['lr'] / (1 - beta1 ** state['step']) 148 | p_data_fp32.add_(-step_size, exp_avg) 149 | 150 | p.data.copy_(p_data_fp32) 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /optim/rmsprop_tf.py: -------------------------------------------------------------------------------- 1 | """ RMSProp modified to behave like Tensorflow impl 2 | 3 | Originally cut & paste from PyTorch RMSProp 4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py 5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE 6 | 7 | Modifications Copyright 2020 Ross Wightman 8 | """ 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | class RMSpropTF(Optimizer): 15 | """Implements RMSprop algorithm (TensorFlow style epsilon) 16 | 17 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 18 | and a few other modifications to closer match Tensorflow for matching hyper-params. 19 | 20 | Noteworthy changes include: 21 | 1. Epsilon applied inside square-root 22 | 2. square_avg initialized to ones 23 | 3. LR scaling of update accumulated in momentum buffer 24 | 25 | Proposed by G. Hinton in his 26 | `course `_. 27 | 28 | The centered version first appears in `Generating Sequences 29 | With Recurrent Neural Networks `_. 
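    With the default `lr_in_momentum=True` the learning rate is folded into the momentum
    buffer as in TensorFlow; set it to False to recover the PyTorch-style update in which
    the buffer is scaled by the learning rate when it is applied.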
30 | 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-2) 35 | momentum (float, optional): momentum factor (default: 0) 36 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-10) 39 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 40 | the gradient is normalized by an estimation of its variance 41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 42 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 43 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 44 | update as per defaults in Tensorflow 45 | 46 | """ 47 | 48 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, 49 | decoupled_decay=False, lr_in_momentum=True): 50 | if not 0.0 <= lr: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | if not 0.0 <= eps: 53 | raise ValueError("Invalid epsilon value: {}".format(eps)) 54 | if not 0.0 <= momentum: 55 | raise ValueError("Invalid momentum value: {}".format(momentum)) 56 | if not 0.0 <= weight_decay: 57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 58 | if not 0.0 <= alpha: 59 | raise ValueError("Invalid alpha value: {}".format(alpha)) 60 | 61 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, 62 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) 63 | super(RMSpropTF, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(RMSpropTF, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('momentum', 0) 69 | group.setdefault('centered', False) 70 | 71 | def step(self, closure=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 77 | """ 78 | loss = None 79 | if closure is not None: 80 | loss = closure() 81 | 82 | for group in self.param_groups: 83 | for p in group['params']: 84 | if p.grad is None: 85 | continue 86 | grad = p.grad.data 87 | if grad.is_sparse: 88 | raise RuntimeError('RMSprop does not support sparse gradients') 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero 95 | if group['momentum'] > 0: 96 | state['momentum_buffer'] = torch.zeros_like(p.data) 97 | if group['centered']: 98 | state['grad_avg'] = torch.zeros_like(p.data) 99 | 100 | square_avg = state['square_avg'] 101 | one_minus_alpha = 1. 
- group['alpha'] 102 | 103 | state['step'] += 1 104 | 105 | if group['weight_decay'] != 0: 106 | if 'decoupled_decay' in group and group['decoupled_decay']: 107 | p.data.add_(-group['weight_decay'], p.data) 108 | else: 109 | grad = grad.add(group['weight_decay'], p.data) 110 | 111 | # Tensorflow order of ops for updating squared avg 112 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) 113 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 114 | 115 | if group['centered']: 116 | grad_avg = state['grad_avg'] 117 | grad_avg.add_(one_minus_alpha, grad - grad_avg) 118 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original 119 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt 120 | else: 121 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt 122 | 123 | if group['momentum'] > 0: 124 | buf = state['momentum_buffer'] 125 | # Tensorflow accumulates the LR scaling in the momentum buffer 126 | if 'lr_in_momentum' in group and group['lr_in_momentum']: 127 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) 128 | p.data.add_(-buf) 129 | else: 130 | # PyTorch scales the param update by LR 131 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 132 | p.data.add_(-group['lr'], buf) 133 | else: 134 | p.data.addcdiv_(-group['lr'], grad, avg) 135 | 136 | return loss 137 | -------------------------------------------------------------------------------- /optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class SGDP(Optimizer): 17 | def __init__(self, params, lr=required, momentum=0, dampening=0, 18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 21 | super(SGDP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | weight_decay = group['weight_decay'] 62 | momentum = group['momentum'] 63 | dampening = group['dampening'] 64 | nesterov = group['nesterov'] 65 | 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['momentum'] = torch.zeros_like(p.data) 75 | 76 | # SGD 77 | buf = state['momentum'] 78 | buf.mul_(momentum).add_(1 - dampening, grad) 79 | if nesterov: 80 | d_p = grad + momentum * buf 81 | else: 82 | d_p = buf 83 | 84 | # Projection 85 | wd_ratio = 1 86 | if len(p.shape) > 1: 87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 88 | 89 | # Weight decay 90 | if weight_decay != 0: 91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 92 | 93 | # Step 94 | p.data.add_(-group['lr'], d_p) 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tempfile 3 | from functools import partial 4 | import cv2 5 | from PIL import Image 6 | import numpy as np 7 | from cog import BasePredictor, Path, Input 8 | 9 | from skimage import transform as skimage_transform 10 | from scipy.ndimage import filters 11 | from matplotlib import pyplot as plt 12 | 13 | import torch 14 | from torch import nn 15 | from torchvision import transforms 16 | 17 | from models.vit import VisionTransformer 18 | from models.xbert import BertConfig, BertModel 19 | from models.tokenization_bert import BertTokenizer 20 | 21 | 22 | class Predictor(BasePredictor): 23 | def setup(self): 24 | normalize = transforms.Normalize( 25 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 26 | ) 27 | 28 | 
self.transform = transforms.Compose( 29 | [ 30 | transforms.Resize((384, 384), interpolation=Image.BICUBIC), 31 | transforms.ToTensor(), 32 | normalize, 33 | ] 34 | ) 35 | 36 | self.tokenizer = BertTokenizer.from_pretrained("bert/bert-base-uncased") 37 | 38 | bert_config_path = "configs/config_bert.json" 39 | self.model = VL_Transformer_ITM( 40 | text_encoder="bert/bert-base-uncased", config_bert=bert_config_path 41 | ) 42 | 43 | checkpoint = torch.load("refcoco.pth", map_location="cpu") 44 | msg = self.model.load_state_dict(checkpoint, strict=False) 45 | self.model.eval() 46 | 47 | self.block_num = 8 48 | self.model.text_encoder.base_model.base_model.encoder.layer[ 49 | self.block_num 50 | ].crossattention.self.save_attention = True 51 | 52 | self.model.cuda() 53 | 54 | def predict( 55 | self, 56 | image: Path = Input(description="Input image."), 57 | caption: str = Input( 58 | description="Caption for the image. Grad-CAM visualization will be generated " 59 | "for each word in the cation." 60 | ), 61 | ) -> Path: 62 | 63 | image_pil = Image.open(str(image)).convert("RGB") 64 | img = self.transform(image_pil).unsqueeze(0) 65 | 66 | text = pre_caption(caption) 67 | text_input = self.tokenizer(text, return_tensors="pt") 68 | 69 | img = img.cuda() 70 | text_input = text_input.to(img.device) 71 | 72 | # Compute GradCAM 73 | output = self.model(img, text_input) 74 | loss = output[:, 1].sum() 75 | 76 | self.model.zero_grad() 77 | loss.backward() 78 | 79 | with torch.no_grad(): 80 | mask = text_input.attention_mask.view( 81 | text_input.attention_mask.size(0), 1, -1, 1, 1 82 | ) 83 | 84 | grads = self.model.text_encoder.base_model.base_model.encoder.layer[ 85 | self.block_num 86 | ].crossattention.self.get_attn_gradients() 87 | cams = self.model.text_encoder.base_model.base_model.encoder.layer[ 88 | self.block_num 89 | ].crossattention.self.get_attention_map() 90 | 91 | cams = cams[:, :, :, 1:].reshape(img.size(0), 12, -1, 24, 24) * mask 92 | grads = ( 93 | grads[:, :, :, 1:].clamp(0).reshape(img.size(0), 12, -1, 24, 24) * mask 94 | ) 95 | 96 | gradcam = cams * grads 97 | gradcam = gradcam[0].mean(0).cpu().detach() 98 | 99 | num_image = len(text_input.input_ids[0]) 100 | fig, ax = plt.subplots(num_image, 1, figsize=(20, 8 * num_image)) 101 | 102 | rgb_image = cv2.imread(str(image))[:, :, ::-1] 103 | rgb_image = np.float32(rgb_image) / 255 104 | 105 | ax[0].imshow(rgb_image) 106 | ax[0].set_yticks([]) 107 | ax[0].set_xticks([]) 108 | ax[0].set_xlabel("Image") 109 | 110 | for i, token_id in enumerate(text_input.input_ids[0][1:]): 111 | word = self.tokenizer.decode([token_id]) 112 | gradcam_image = getAttMap(rgb_image, gradcam[i + 1]) 113 | ax[i + 1].imshow(gradcam_image) 114 | ax[i + 1].set_yticks([]) 115 | ax[i + 1].set_xticks([]) 116 | ax[i + 1].set_xlabel(word) 117 | 118 | out_path = Path(tempfile.mkdtemp()) / "output.png" 119 | fig.savefig(str(out_path)) 120 | return out_path 121 | 122 | 123 | class VL_Transformer_ITM(nn.Module): 124 | def __init__(self, text_encoder=None, config_bert=""): 125 | super().__init__() 126 | 127 | bert_config = BertConfig.from_json_file(config_bert) 128 | 129 | self.visual_encoder = VisionTransformer( 130 | img_size=384, 131 | patch_size=16, 132 | embed_dim=768, 133 | depth=12, 134 | num_heads=12, 135 | mlp_ratio=4, 136 | qkv_bias=True, 137 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 138 | ) 139 | 140 | self.text_encoder = BertModel.from_pretrained( 141 | text_encoder, config=bert_config, add_pooling_layer=False 142 | ) 143 | 144 | self.itm_head = nn.Linear(768, 
2) 145 | 146 | def forward(self, image, text): 147 | image_embeds = self.visual_encoder(image) 148 | 149 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 150 | image.device 151 | ) 152 | 153 | output = self.text_encoder( 154 | text.input_ids, 155 | attention_mask=text.attention_mask, 156 | encoder_hidden_states=image_embeds, 157 | encoder_attention_mask=image_atts, 158 | return_dict=True, 159 | ) 160 | 161 | vl_embeddings = output.last_hidden_state[:, 0, :] 162 | vl_output = self.itm_head(vl_embeddings) 163 | return vl_output 164 | 165 | 166 | def pre_caption(caption, max_words=30): 167 | caption = ( 168 | re.sub( 169 | r"([,.'!?\"()*#:;~])", 170 | "", 171 | caption.lower(), 172 | ) 173 | .replace("-", " ") 174 | .replace("/", " ") 175 | ) 176 | 177 | caption = re.sub( 178 | r"\s{2,}", 179 | " ", 180 | caption, 181 | ) 182 | caption = caption.rstrip("\n") 183 | caption = caption.strip(" ") 184 | 185 | # truncate caption 186 | caption_words = caption.split(" ") 187 | if len(caption_words) > max_words: 188 | caption = " ".join(caption_words[:max_words]) 189 | return caption 190 | 191 | 192 | def getAttMap(img, attMap, blur=True, overlap=True): 193 | attMap -= attMap.min() 194 | if attMap.max() > 0: 195 | attMap /= attMap.max() 196 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 197 | if blur: 198 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 199 | attMap -= attMap.min() 200 | attMap /= attMap.max() 201 | cmap = plt.get_cmap("jet") 202 | attMapV = cmap(attMap) 203 | attMapV = np.delete(attMapV, 3, 2) 204 | if overlap: 205 | attMap = ( 206 | 1 * (1 - attMap ** 0.7).reshape(attMap.shape + (1,)) * img 207 | + (attMap ** 0.7).reshape(attMap.shape + (1,)) * attMapV 208 | ) 209 | return attMap 210 | -------------------------------------------------------------------------------- /refTools/__pycache__/refer_python3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/__pycache__/refer_python3.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/__pycache__/refer_python3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/__pycache__/refer_python3.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | 3 | 4 | -------------------------------------------------------------------------------- /refTools/evaluation/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/__init__.cpython-38.pyc 
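The full Predictor above needs the refcoco.pth checkpoint, a CUDA device, and the bert-base-uncased weights, but the two helpers at the end of predict.py can be exercised on their own. A minimal sketch, assuming only that the repository root is on PYTHONPATH and that predict.py's dependencies (torch, cv2, cog, scikit-image, scipy, matplotlib) are installed; the sample caption and the random arrays below are made up for illustration:

import numpy as np
from predict import pre_caption, getAttMap

# Caption cleaning: punctuation is stripped, '-' and '/' become spaces,
# repeated whitespace is collapsed, and the caption is truncated to max_words.
print(pre_caption("The small, red vase -- on the left/side!"))
# -> "the small red vase on the left side"

# Attention overlay: blend a (fake) 24x24 cross-attention map onto a (fake)
# RGB image with values in [0, 1]; the result keeps the image's spatial size.
rgb = np.random.rand(96, 128, 3).astype(np.float32)
att = np.random.rand(24, 24).astype(np.float32)
overlay = getAttMap(rgb, att, blur=True, overlap=True)
print(overlay.shape)  # (96, 128, 3): jet-colored heat map blended with the image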
-------------------------------------------------------------------------------- /refTools/evaluation/__pycache__/refEvaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/refEvaluation.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/__pycache__/refEvaluation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/refEvaluation.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__init__.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__pycache__/bleu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__pycache__/bleu.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 
6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from refTools.evaluation.bleu.bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /refTools/evaluation/bleu/bleu.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/bleu.pyc -------------------------------------------------------------------------------- /refTools/evaluation/bleu/bleu_scorer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/bleu_scorer.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/cider/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__init__.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/__pycache__/cider.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider.cpython-36.pyc 
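A minimal usage sketch for the Bleu wrapper above, assuming the repository root is on PYTHONPATH; gts maps each id to a list of reference sentences and res maps the same ids to single-element hypothesis lists (the ids and sentences below are made up):

from refTools.evaluation.bleu.bleu import Bleu

gts = {0: ["a man riding a horse", "a person rides a brown horse"],
       1: ["two dogs playing in the grass"]}
res = {0: ["a man rides a horse"],
       1: ["dogs play on grass"]}

# score holds the corpus-level Bleu_1..Bleu_4 values,
# scores holds the per-id values for each n-gram order.
score, scores = Bleu(4).compute_score(gts, res)
print(score)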
-------------------------------------------------------------------------------- /refTools/evaluation/cider/__pycache__/cider.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/__pycache__/cider_scorer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/__pycache__/cider_scorer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from refTools.evaluation.cider.cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 
42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" -------------------------------------------------------------------------------- /refTools/evaluation/cider/cider.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/cider.pyc -------------------------------------------------------------------------------- /refTools/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in xrange(1,n+1): 23 | for i in xrange(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.iteritems(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].iteritems(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /refTools/evaluation/cider/cider_scorer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/cider_scorer.pyc -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__init__.pyc -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__pycache__/meteor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/meteor.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/meteor/__pycache__/meteor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/meteor.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /refTools/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', '-norm'] 20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 21 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 22 | stdin=subprocess.PIPE, \ 23 | stdout=subprocess.PIPE, \ 24 | stderr=subprocess.PIPE) 25 | # Used to guarantee thread safety 26 | self.lock = threading.Lock() 27 | 28 | def compute_score(self, gts, res): 29 | assert(gts.keys() == res.keys()) 30 | imgIds = gts.keys() 31 | scores = [] 32 | 33 | eval_line = 'EVAL' 34 | self.lock.acquire() 35 | for i in imgIds: 36 | assert(len(res[i]) == 1) 37 | stat = self._stat(res[i][0], gts[i]) 38 | eval_line += ' ||| {}'.format(stat) 39 | 40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 41 | for i in range(0,len(imgIds)): 42 | scores.append(float(self.meteor_p.stdout.readline().strip())) 43 | score = float(self.meteor_p.stdout.readline().strip()) 44 | self.lock.release() 45 | 46 | return score, scores 47 | 48 | def method(self): 49 | return "METEOR" 50 | 51 | def _stat(self, hypothesis_str, reference_list): 52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 53 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 55 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 56 | return self.meteor_p.stdout.readline().decode().strip() 57 | 58 | def _score(self, hypothesis_str, reference_list): 59 | self.lock.acquire() 60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 61 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 62 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 63 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 64 | stats = self.meteor_p.stdout.readline().strip() 65 | eval_line = 'EVAL ||| {}'.format(stats) 66 | # EVAL ||| stats 67 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 68 | score = float(self.meteor_p.stdout.readline().strip()) 69 | self.lock.release() 70 | return score 71 | 72 | def __exit__(self): 73 | self.lock.acquire() 74 | self.meteor_p.stdin.close() 75 | self.meteor_p.wait() 76 | self.lock.release() 77 | -------------------------------------------------------------------------------- /refTools/evaluation/meteor/meteor.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/meteor.pyc -------------------------------------------------------------------------------- /refTools/evaluation/readme.txt: -------------------------------------------------------------------------------- 1 | This folder contains modified coco-caption evaluation, which is downloaded from https://github.com/tylin/coco-caption.git 2 | and refEvaluation which is to be called by the refer algorithm. 3 | 4 | More specifically, this folder contains: 5 | 1. bleu/ 6 | 2. cider/ 7 | 3. meteor/ 8 | 4. rouge/ 9 | 5. tokenizer/ 10 | 6. __init__.py 11 | 7. 
refEvaluation.py 12 | -------------------------------------------------------------------------------- /refTools/evaluation/refEvaluation.py: -------------------------------------------------------------------------------- 1 | from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer 2 | from refTools.evaluation.bleu.bleu import Bleu 3 | from refTools.evaluation.meteor.meteor import Meteor 4 | from refTools.evaluation.rouge.rouge import Rouge 5 | from refTools.evaluation.cider.cider import Cider 6 | 7 | """ 8 | Input: refer and Res = [{ref_id, sent}] 9 | 10 | Things of interest 11 | evalRefs - list of ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR'] 12 | eval - dict of {metric: score} 13 | refToEval - dict of {ref_id: ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']} 14 | """ 15 | 16 | class RefEvaluation: 17 | def __init__ (self, refer, Res): 18 | """ 19 | :param refer: refer class of current dataset 20 | :param Res: [{'ref_id', 'sent'}] 21 | """ 22 | self.evalRefs = [] 23 | self.eval = {} 24 | self.refToEval = {} 25 | self.refer = refer 26 | self.Res = Res 27 | 28 | def evaluate(self): 29 | 30 | evalRefIds = [ann['ref_id'] for ann in self.Res] 31 | 32 | refToGts = {} 33 | for ref_id in evalRefIds: 34 | ref = self.refer.Refs[ref_id] 35 | gt_sents = [sent['sent'].encode('ascii', 'ignore').decode('ascii') for sent in ref['sentences']] # up to 3 expressions 36 | refToGts[ref_id] = gt_sents 37 | refToRes = {ann['ref_id']: [ann['sent']] for ann in self.Res} 38 | 39 | print('tokenization...') 40 | tokenizer = PTBTokenizer() 41 | self.refToRes = tokenizer.tokenize(refToRes) 42 | self.refToGts = tokenizer.tokenize(refToGts) 43 | 44 | # ================================================= 45 | # Set up scorers 46 | # ================================================= 47 | print('setting up scorers...') 48 | scorers = [ 49 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 50 | (Meteor(),"METEOR"), 51 | (Rouge(), "ROUGE_L"), 52 | (Cider(), "CIDEr") 53 | ] 54 | 55 | # ================================================= 56 | # Compute scores 57 | # ================================================= 58 | for scorer, method in scorers: 59 | print('computing %s score...'%(scorer.method())) 60 | score, scores = scorer.compute_score(self.refToGts, self.refToRes) 61 | if type(method) == list: 62 | for sc, scs, m in zip(score, scores, method): 63 | self.setEval(sc, m) 64 | self.setRefToEvalRefs(scs, self.refToGts.keys(), m) 65 | print("%s: %0.3f"%(m, sc)) 66 | else: 67 | self.setEval(score, method) 68 | self.setRefToEvalRefs(scores, self.refToGts.keys(), method) 69 | print("%s: %0.3f"%(method, score)) 70 | self.setEvalRefs() 71 | 72 | def setEval(self, score, method): 73 | self.eval[method] = score 74 | 75 | def setRefToEvalRefs(self, scores, refIds, method): 76 | for refId, score in zip(refIds, scores): 77 | if not refId in self.refToEval: 78 | self.refToEval[refId] = {} 79 | self.refToEval[refId]["ref_id"] = refId 80 | self.refToEval[refId][method] = score 81 | 82 | def setEvalRefs(self): 83 | self.evalRefs = [eval for refId, eval in self.refToEval.items()] 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | import os.path as osp 89 | import sys 90 | ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) 91 | sys.path.insert(0, osp.join(ROOT_DIR, 'lib', 'datasets')) 92 | from refer import REFER 93 | 94 | # load refer of dataset 95 | dataset = 'refcoco' 96 | refer = REFER(dataset, splitBy = 'google') 97 | 98 | # mimic some 
Res 99 | val_refIds = refer.getRefIds(split='test') 100 | ref_id = 49767 101 | print("GD: %s" % refer.Refs[ref_id]['sentences']) 102 | Res = [{'ref_id': ref_id, 'sent': 'left bottle'}] 103 | 104 | # evaluate some refer expressions 105 | refEval = RefEvaluation(refer, Res) 106 | refEval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in refEval.eval.items(): 110 | print('%s: %.3f'%(metric, score)) 111 | 112 | # demo how to use evalImgs to retrieve low score result 113 | # evals = [eva for eva in refEval.evalRefs if eva['CIDEr']<30] 114 | # print 'ground truth sents' 115 | # refId = evals[0]['ref_id'] 116 | # print 'refId: %s' % refId 117 | # print [sent['sent'] for sent in refer.Refs[refId]['sentences']] 118 | # 119 | # print 'generated sent (CIDEr score %0.1f)' % (evals[0]['CIDEr']) 120 | 121 | # print refEval.refToEval[8] 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /refTools/evaluation/refEvaluation.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/refEvaluation.pyc -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__init__.pyc -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__pycache__/rouge.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/rouge.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/rouge/__pycache__/rouge.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/rouge.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/rouge/rouge.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 
96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /refTools/evaluation/rouge/rouge.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/rouge.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__init__.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 44 | tmp_file.write(sentences.encode()) 45 | tmp_file.close() 46 | 47 | # ====================================================== 48 | # tokenize sentence 49 | # ====================================================== 50 | cmd.append(os.path.basename(tmp_file.name)) 51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 52 | stdout=subprocess.PIPE) 53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 54 | token_lines = token_lines.decode() 55 | lines = token_lines.split('\n') 56 | # remove temp file 57 | os.remove(tmp_file.name) 58 | 59 | # ====================================================== 60 | # create dictionary for tokenized captions 61 | # ====================================================== 62 | for k, line in zip(image_id, lines): 63 | if not k in final_tokenized_captions_for_image: 64 | final_tokenized_captions_for_image[k] = [] 65 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 66 | if w not in PUNCTUATIONS]) 67 | final_tokenized_captions_for_image[k].append(tokenized_caption) 68 | 69 | return final_tokenized_captions_for_image 70 | -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/ptbtokenizer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/ptbtokenizer.pyc -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/tmp82iqkuu0: 
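A minimal usage sketch for the PTBTokenizer wrapper above, assuming the repository root is on PYTHONPATH and a Java runtime is on PATH (the tokenizer shells out to the bundled stanford-corenlp-3.4.1.jar); the ids and captions below are made up:

from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer

captions = {42: ["A man, riding a horse!", "Someone rides a brown horse."]}
tokenized = PTBTokenizer().tokenize(captions)
# Text is lower-cased and punctuation tokens are dropped, e.g.
# {42: ['a man riding a horse', 'someone rides a brown horse']}
print(tokenized)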
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/tmp82iqkuu0 -------------------------------------------------------------------------------- /refTools/evaluation/tokenizer/tmpn19wmqte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/tmpn19wmqte -------------------------------------------------------------------------------- /scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /scheduler/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/cosine_lr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/cosine_lr.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/cosine_lr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/cosine_lr.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/plateau_lr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/plateau_lr.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/plateau_lr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/plateau_lr.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/scheduler.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/scheduler_factory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler_factory.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/scheduler_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler_factory.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/step_lr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/step_lr.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/step_lr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/step_lr.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/tanh_lr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/tanh_lr.cpython-36.pyc -------------------------------------------------------------------------------- /scheduler/__pycache__/tanh_lr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/tanh_lr.cpython-38.pyc -------------------------------------------------------------------------------- /scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | from pdb import set_trace as breakpoint 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class CosineLRScheduler(Scheduler): 20 | """ 21 | Cosine decay with restarts. 22 | This is described in the paper https://arxiv.org/abs/1608.03983. 
23 | 24 | Inspiration from 25 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 26 | """ 27 | 28 | def __init__(self, 29 | optimizer: torch.optim.Optimizer, 30 | t_initial: int, 31 | t_mul: float = 1., 32 | lr_min: float = 0., 33 | decay_rate: float = 1., 34 | warmup_t=0, 35 | warmup_lr_init=0, 36 | warmup_prefix=True, 37 | cycle_limit=0, 38 | t_in_epochs=True, 39 | noise_range_t=None, 40 | noise_pct=0.67, 41 | noise_std=1.0, 42 | noise_seed=42, 43 | initialize=True) -> None: 44 | super().__init__( 45 | optimizer, param_group_field="lr", 46 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 47 | initialize=initialize) 48 | 49 | assert t_initial > 0 50 | assert lr_min >= 0 51 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 52 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 53 | "rate since t_initial = t_mul = eta_mul = 1.") 54 | self.t_initial = t_initial 55 | self.t_mul = t_mul 56 | self.lr_min = lr_min 57 | self.decay_rate = decay_rate 58 | self.cycle_limit = cycle_limit 59 | self.warmup_t = warmup_t 60 | self.warmup_lr_init = warmup_lr_init 61 | self.warmup_prefix = warmup_prefix 62 | self.t_in_epochs = t_in_epochs 63 | if self.warmup_t: 64 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 65 | super().update_groups(self.warmup_lr_init) 66 | else: 67 | self.warmup_steps = [1 for _ in self.base_values] 68 | 69 | def _get_lr(self, t): 70 | if t < self.warmup_t: 71 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 72 | else: 73 | if self.warmup_prefix: 74 | t = t - self.warmup_t 75 | 76 | if self.t_mul != 1: 77 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 78 | t_i = self.t_mul ** i * self.t_initial 79 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 80 | else: 81 | i = t // self.t_initial 82 | t_i = self.t_initial 83 | t_curr = t - (self.t_initial * i) 84 | 85 | gamma = self.decay_rate ** i 86 | lr_min = self.lr_min * gamma 87 | lr_max_values = [v * gamma for v in self.base_values] 88 | 89 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 90 | lrs = [ 91 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 92 | ] 93 | else: 94 | lrs = [self.lr_min for _ in self.base_values] 95 | 96 | return lrs 97 | 98 | def get_epoch_values(self, epoch: int): 99 | if self.t_in_epochs: 100 | return self._get_lr(epoch) 101 | else: 102 | return None 103 | 104 | def get_update_values(self, num_updates: int): 105 | if not self.t_in_epochs: 106 | return self._get_lr(num_updates) 107 | else: 108 | return None 109 | 110 | def get_cycle_length(self, cycles=0): 111 | if not cycles: 112 | cycles = self.cycle_limit 113 | cycles = max(1, cycles) 114 | if self.t_mul == 1.0: 115 | return self.t_initial * cycles 116 | else: 117 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 118 | -------------------------------------------------------------------------------- /scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | from .scheduler import Scheduler 10 | 11 | 12 | class PlateauLRScheduler(Scheduler): 13 | """Decay the LR by a factor every time the validation loss plateaus.""" 14 | 15 | def __init__(self, 16 | optimizer, 17 | decay_rate=0.1, 18 | patience_t=10, 19 | verbose=True, 20 | threshold=1e-4, 21 | cooldown_t=0, 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | lr_min=0, 25 | mode='max', 26 | noise_range_t=None, 27 | noise_type='normal', 28 | noise_pct=0.67, 29 | noise_std=1.0, 30 | noise_seed=None, 31 | initialize=True, 32 | ): 33 | super().__init__(optimizer, 'lr', initialize=initialize) 34 | 35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 36 | self.optimizer, 37 | patience=patience_t, 38 | factor=decay_rate, 39 | verbose=verbose, 40 | threshold=threshold, 41 | cooldown=cooldown_t, 42 | mode=mode, 43 | min_lr=lr_min 44 | ) 45 | 46 | self.noise_range = noise_range_t 47 | self.noise_pct = noise_pct 48 | self.noise_type = noise_type 49 | self.noise_std = noise_std 50 | self.noise_seed = noise_seed if noise_seed is not None else 42 51 | self.warmup_t = warmup_t 52 | self.warmup_lr_init = warmup_lr_init 53 | if self.warmup_t: 54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 55 | super().update_groups(self.warmup_lr_init) 56 | else: 57 | self.warmup_steps = [1 for _ in self.base_values] 58 | self.restore_lr = None 59 | 60 | def state_dict(self): 61 | return { 62 | 'best': self.lr_scheduler.best, 63 | 'last_epoch': self.lr_scheduler.last_epoch, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | self.lr_scheduler.best = state_dict['best'] 68 | if 'last_epoch' in state_dict: 69 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 70 | 71 | # override the base class step fn completely 72 | def step(self, epoch, metric=None): 73 | if epoch <= self.warmup_t: 74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 75 | super().update_groups(lrs) 76 | else: 77 | if self.restore_lr is not None: 78 | # restore actual LR from before our last noise perturbation before stepping base 79 | for i, param_group in enumerate(self.optimizer.param_groups): 80 | param_group['lr'] = self.restore_lr[i] 81 | self.restore_lr = None 82 | 83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 84 | 85 | if self.noise_range is not None: 86 | if isinstance(self.noise_range, (list, tuple)): 87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] 88 | else: 89 | apply_noise = epoch >= self.noise_range 90 | if apply_noise: 91 | self._apply_noise(epoch) 92 | 93 | def _apply_noise(self, epoch): 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + epoch) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | 105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 106 | # stepping of base scheduler 107 | restore_lr = [] 108 | for i, param_group in enumerate(self.optimizer.param_groups): 109 | old_lr = float(param_group['lr']) 110 | restore_lr.append(old_lr) 111 | new_lr = old_lr + old_lr * noise 112 | param_group['lr'] = new_lr 113 | self.restore_lr = restore_lr 114 | 
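# ---------------------------------------------------------------------------
# Minimal usage sketch for PlateauLRScheduler (illustrative only). It wraps
# torch.optim.lr_scheduler.ReduceLROnPlateau and adds linear warmup plus
# optional LR noise; the caller tracks the epoch count and passes it in with
# the monitored metric. `model`, `train_one_epoch` and `evaluate` below are
# hypothetical placeholders, not part of this module.
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler = PlateauLRScheduler(optimizer, decay_rate=0.1, patience_t=5,
#                                  warmup_t=2, warmup_lr_init=1e-5, mode='max')
#   for epoch in range(num_epochs):
#       train_one_epoch(model, optimizer)
#       metric = evaluate(model)           # e.g. accuracy when mode='max'
#       scheduler.step(epoch + 1, metric)  # stepped at the end of each epoch
# ---------------------------------------------------------------------------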
-------------------------------------------------------------------------------- /scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .tanh_lr import TanhLRScheduler 6 | from .step_lr import StepLRScheduler 7 | from .plateau_lr import PlateauLRScheduler 8 | 9 | 10 | def create_scheduler(args, optimizer): 11 | num_epochs = args.epochs 12 | 13 | if getattr(args, 'lr_noise', None) is not None: 14 | lr_noise = getattr(args, 'lr_noise') 15 | if isinstance(lr_noise, (list, tuple)): 16 | noise_range = [n * num_epochs for n in lr_noise] 17 | if len(noise_range) == 1: 18 | noise_range = noise_range[0] 19 | else: 20 | noise_range = lr_noise * num_epochs 21 | else: 22 | noise_range = None 23 | 24 | lr_scheduler = None 25 | if args.sched == 'cosine': 26 | lr_scheduler = CosineLRScheduler( 27 | optimizer, 28 | t_initial=num_epochs, 29 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 30 | lr_min=args.min_lr, 31 | decay_rate=args.decay_rate, 32 | warmup_lr_init=args.warmup_lr, 33 | warmup_t=args.warmup_epochs, 34 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 35 | t_in_epochs=True, 36 | noise_range_t=noise_range, 37 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 38 | noise_std=getattr(args, 
'lr_noise_std', 1.), 39 | noise_seed=getattr(args, 'seed', 42), 40 | ) 41 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 42 | elif args.sched == 'tanh': 43 | lr_scheduler = TanhLRScheduler( 44 | optimizer, 45 | t_initial=num_epochs, 46 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 47 | lr_min=args.min_lr, 48 | warmup_lr_init=args.warmup_lr, 49 | warmup_t=args.warmup_epochs, 50 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 51 | t_in_epochs=True, 52 | noise_range_t=noise_range, 53 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 54 | noise_std=getattr(args, 'lr_noise_std', 1.), 55 | noise_seed=getattr(args, 'seed', 42), 56 | ) 57 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 58 | elif args.sched == 'step': 59 | lr_scheduler = StepLRScheduler( 60 | optimizer, 61 | decay_t=args.decay_epochs, 62 | decay_rate=args.decay_rate, 63 | warmup_lr_init=args.warmup_lr, 64 | warmup_t=args.warmup_epochs, 65 | noise_range_t=noise_range, 66 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 67 | noise_std=getattr(args, 'lr_noise_std', 1.), 68 | noise_seed=getattr(args, 'seed', 42), 69 | ) 70 | elif args.sched == 'plateau': 71 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 72 | lr_scheduler = PlateauLRScheduler( 73 | optimizer, 74 | decay_rate=args.decay_rate, 75 | patience_t=args.patience_epochs, 76 | lr_min=args.min_lr, 77 | mode=mode, 78 | warmup_lr_init=args.warmup_lr, 79 | warmup_t=args.warmup_epochs, 80 | cooldown_t=0, 81 | noise_range_t=noise_range, 82 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 83 | noise_std=getattr(args, 'lr_noise_std', 1.), 84 | noise_seed=getattr(args, 'seed', 42), 85 | ) 86 | 87 | return lr_scheduler, num_epochs 88 | -------------------------------------------------------------------------------- /scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | """ TanH Scheduler 2 | 3 | TanH schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class TanhLRScheduler(Scheduler): 19 | """ 20 | Hyberbolic-Tangent decay with restarts. 
21 | This is described in the paper https://arxiv.org/abs/1806.01593 22 | """ 23 | 24 | def __init__(self, 25 | optimizer: torch.optim.Optimizer, 26 | t_initial: int, 27 | lb: float = -6., 28 | ub: float = 4., 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | assert lb < ub 50 | assert cycle_limit >= 0 51 | assert warmup_t >= 0 52 | assert warmup_lr_init >= 0 53 | self.lb = lb 54 | self.ub = ub 55 | self.t_initial = t_initial 56 | self.t_mul = t_mul 57 | self.lr_min = lr_min 58 | self.decay_rate = decay_rate 59 | self.cycle_limit = cycle_limit 60 | self.warmup_t = warmup_t 61 | self.warmup_lr_init = warmup_lr_init 62 | self.warmup_prefix = warmup_prefix 63 | self.t_in_epochs = t_in_epochs 64 | if self.warmup_t: 65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 67 | super().update_groups(self.warmup_lr_init) 68 | else: 69 | self.warmup_steps = [1 for _ in self.base_values] 70 | 71 | def _get_lr(self, t): 72 | if t < self.warmup_t: 73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 74 | else: 75 | if self.warmup_prefix: 76 | t = t - self.warmup_t 77 | 78 | if self.t_mul != 1: 79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 80 | t_i = self.t_mul ** i * self.t_initial 81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 82 | else: 83 | i = t // self.t_initial 84 | t_i = self.t_initial 85 | t_curr = t - (self.t_initial * i) 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | gamma = self.decay_rate ** i 89 | lr_min = self.lr_min * gamma 90 | lr_max_values = [v * gamma for v in self.base_values] 91 | 92 | tr = t_curr / t_i 93 | lrs = [ 94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) 95 | for lr_max in lr_max_values 96 | ] 97 | else: 98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 99 | return lrs 100 | 101 | def get_epoch_values(self, epoch: int): 102 | if self.t_in_epochs: 103 | return self._get_lr(epoch) 104 | else: 105 | return None 106 | 107 | def get_update_values(self, num_updates: int): 108 | if not self.t_in_epochs: 109 | return self._get_lr(num_updates) 110 | else: 111 | return None 112 | 113 | def get_cycle_length(self, cycles=0): 114 | if not cycles: 115 | cycles = self.cycle_limit 116 | cycles = max(1, cycles) 117 | if self.t_mul == 1.0: 118 | return self.t_initial * cycles 119 | else: 120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 121 | -------------------------------------------------------------------------------- /vqaTools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | -------------------------------------------------------------------------------- /vqaTools/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /vqaTools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vqaTools/__pycache__/vqa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqa.cpython-36.pyc -------------------------------------------------------------------------------- /vqaTools/__pycache__/vqa.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqa.cpython-38.pyc -------------------------------------------------------------------------------- /vqaTools/__pycache__/vqaEval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqaEval.cpython-36.pyc -------------------------------------------------------------------------------- /vqaTools/__pycache__/vqaEval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqaEval.cpython-38.pyc -------------------------------------------------------------------------------- /vqaTools/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 
8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(VQA.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | class VQA: 24 | def __init__(self, annotation_file=None, question_file=None): 25 | """ 26 | Constructor of VQA helper class for reading and visualizing questions and answers. 27 | :param annotation_file (str): location of VQA annotation file 28 | :return: 29 | """ 30 | # load dataset 31 | self.dataset = {} 32 | self.questions = {} 33 | self.qa = {} 34 | self.qqa = {} 35 | self.imgToQA = {} 36 | if not annotation_file == None and not question_file == None: 37 | print('loading VQA annotations and questions into memory...') 38 | time_t = datetime.datetime.utcnow() 39 | dataset = json.load(open(annotation_file, 'r')) 40 | questions = json.load(open(question_file, 'r')) 41 | self.dataset = dataset 42 | self.questions = questions 43 | self.createIndex() 44 | 45 | def createIndex(self): 46 | # create index 47 | print('creating index...') 48 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 49 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 50 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 51 | for ann in self.dataset['annotations']: 52 | imgToQA[ann['image_id']] += [ann] 53 | qa[ann['question_id']] = ann 54 | for ques in self.questions['questions']: 55 | qqa[ques['question_id']] = ques 56 | print('index created!') 57 | 58 | # create class members 59 | self.qa = qa 60 | self.qqa = qqa 61 | self.imgToQA = imgToQA 62 | 63 | def info(self): 64 | """ 65 | Print information about the VQA annotation file. 66 | :return: 67 | """ 68 | for key, value in self.dataset['info'].items(): 69 | print('%s: %s'%(key, value)) 70 | 71 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 72 | """ 73 | Get question ids that satisfy given filter conditions. default skips that filter 74 | :param imgIds (int array) : get question ids for given imgs 75 | quesTypes (str array) : get question ids for given question types 76 | ansTypes (str array) : get question ids for given answer types 77 | :return: ids (int array) : integer array of question ids 78 | """ 79 | imgIds = imgIds if type(imgIds) == list else [imgIds] 80 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 81 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 82 | 83 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 84 | anns = self.dataset['annotations'] 85 | else: 86 | if not len(imgIds) == 0: 87 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[]) 88 | else: 89 | anns = self.dataset['annotations'] 90 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 91 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 92 | ids = [ann['question_id'] for ann in anns] 93 | return ids 94 | 95 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 96 | """ 97 | Get image ids that satisfy given filter conditions. 
default skips that filter 98 | :param quesIds (int array) : get image ids for given question ids 99 | quesTypes (str array) : get image ids for given question types 100 | ansTypes (str array) : get image ids for given answer types 101 | :return: ids (int array) : integer array of image ids 102 | """ 103 | quesIds = quesIds if type(quesIds) == list else [quesIds] 104 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 105 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 106 | 107 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 108 | anns = self.dataset['annotations'] 109 | else: 110 | if not len(quesIds) == 0: 111 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[]) 112 | else: 113 | anns = self.dataset['annotations'] 114 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 115 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 116 | ids = [ann['image_id'] for ann in anns] 117 | return ids 118 | 119 | def loadQA(self, ids=[]): 120 | """ 121 | Load questions and answers with the specified question ids. 122 | :param ids (int array) : integer ids specifying question ids 123 | :return: qa (object array) : loaded qa objects 124 | """ 125 | if type(ids) == list: 126 | return [self.qa[id] for id in ids] 127 | elif type(ids) == int: 128 | return [self.qa[ids]] 129 | 130 | def showQA(self, anns): 131 | """ 132 | Display the specified annotations. 133 | :param anns (array of object): annotations to display 134 | :return: None 135 | """ 136 | if len(anns) == 0: 137 | return 0 138 | for ann in anns: 139 | quesId = ann['question_id'] 140 | print("Question: %s" %(self.qqa[quesId]['question'])) 141 | for ans in ann['answers']: 142 | print("Answer %d: %s" %(ans['answer_id'], ans['answer'])) 143 | 144 | def loadRes(self, resFile, quesFile): 145 | """ 146 | Load result file and return a result object. 147 | :param resFile (str) : file name of result file 148 | :return: res (obj) : result api object 149 | """ 150 | res = VQA() 151 | res.questions = json.load(open(quesFile)) 152 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 153 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 154 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 155 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) 156 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 157 | 158 | print('Loading and preparing results... ') 159 | time_t = datetime.datetime.utcnow() 160 | anns = json.load(open(resFile)) 161 | assert type(anns) == list, 'results is not an array of objects' 162 | annsQuesIds = [ann['question_id'] for ann in anns] 163 | assert set(annsQuesIds) == set(self.getQuesIds()), \ 164 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 
165 | for ann in anns: 166 | quesId = ann['question_id'] 167 | if res.dataset['task_type'] == 'Multiple Choice': 168 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices' 169 | qaAnn = self.qa[quesId] 170 | ann['image_id'] = qaAnn['image_id'] 171 | ann['question_type'] = qaAnn['question_type'] 172 | ann['answer_type'] = qaAnn['answer_type'] 173 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 174 | 175 | res.dataset['annotations'] = anns 176 | res.createIndex() 177 | return res 178 | --------------------------------------------------------------------------------
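A minimal usage sketch for the VQA helper above, assuming the annotation, question, and result file paths are hypothetical placeholders and that the package is importable as vqaTools; only methods defined in vqaTools/vqa.py are called.

    from vqaTools.vqa import VQA

    # Load annotations and questions (placeholder paths) and build the index.
    vqa = VQA('annotations.json', 'questions.json')
    # Filter question ids, e.g. by answer type, and inspect a few QA pairs.
    ques_ids = vqa.getQuesIds(ansTypes=['yes/no'])
    vqa.showQA(vqa.loadQA(ques_ids[:5]))
    # Wrap a result file (placeholder path) sharing the same question set.
    vqa_res = vqa.loadRes('results.json', 'questions.json')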