├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── Grounding.py
├── LICENSE.txt
├── NLVR.py
├── Pretrain.py
├── Pretrain_nlvr.py
├── README.md
├── Retrieval.py
├── SECURITY.md
├── VE.py
├── VQA.py
├── cog.yaml
├── configs
│   ├── Grounding.yaml
│   ├── NLVR.yaml
│   ├── NLVR_pretrain.yaml
│   ├── Pretrain.yaml
│   ├── Retrieval_coco.yaml
│   ├── Retrieval_flickr.yaml
│   ├── VE.yaml
│   ├── VQA.yaml
│   └── config_bert.json
├── dataset
│   ├── __init__.py
│   ├── caption_dataset.py
│   ├── grounding_dataset.py
│   ├── nlvr_dataset.py
│   ├── randaugment.py
│   ├── utils.py
│   ├── ve_dataset.py
│   └── vqa_dataset.py
├── examples
│   ├── image0.jpg
│   └── visualization.png
├── img.png
├── models
│   ├── __init__.py
│   ├── model_nlvr.py
│   ├── model_pretrain.py
│   ├── model_pretrain_nlvr.py
│   ├── model_retrieval.py
│   ├── model_ve.py
│   ├── model_vqa.py
│   ├── tokenization_bert.py
│   ├── vit.py
│   └── xbert.py
├── optim
│   ├── __init__.py
│   ├── adafactor.py
│   ├── adahessian.py
│   ├── adamp.py
│   ├── adamw.py
│   ├── lookahead.py
│   ├── nadam.py
│   ├── novograd.py
│   ├── nvnovograd.py
│   ├── optim_factory.py
│   ├── radam.py
│   ├── rmsprop_tf.py
│   └── sgdp.py
├── predict.py
├── refTools
│   ├── __pycache__
│   │   ├── refer_python3.cpython-36.pyc
│   │   └── refer_python3.cpython-38.pyc
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── refEvaluation.cpython-36.pyc
│   │   │   └── refEvaluation.cpython-38.pyc
│   │   ├── bleu
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── bleu.cpython-36.pyc
│   │   │   │   ├── bleu.cpython-38.pyc
│   │   │   │   ├── bleu_scorer.cpython-36.pyc
│   │   │   │   └── bleu_scorer.cpython-38.pyc
│   │   │   ├── bleu.py
│   │   │   ├── bleu.pyc
│   │   │   ├── bleu_scorer.py
│   │   │   └── bleu_scorer.pyc
│   │   ├── cider
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── cider.cpython-36.pyc
│   │   │   │   ├── cider.cpython-38.pyc
│   │   │   │   ├── cider_scorer.cpython-36.pyc
│   │   │   │   └── cider_scorer.cpython-38.pyc
│   │   │   ├── cider.py
│   │   │   ├── cider.pyc
│   │   │   ├── cider_scorer.py
│   │   │   └── cider_scorer.pyc
│   │   ├── meteor
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── meteor.cpython-36.pyc
│   │   │   │   └── meteor.cpython-38.pyc
│   │   │   ├── meteor-1.5.jar
│   │   │   ├── meteor.py
│   │   │   └── meteor.pyc
│   │   ├── readme.txt
│   │   ├── refEvaluation.py
│   │   ├── refEvaluation.pyc
│   │   ├── rouge
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── rouge.cpython-36.pyc
│   │   │   │   └── rouge.cpython-38.pyc
│   │   │   ├── rouge.py
│   │   │   └── rouge.pyc
│   │   └── tokenizer
│   │       ├── __init__.py
│   │       ├── __init__.pyc
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── ptbtokenizer.cpython-36.pyc
│   │       │   └── ptbtokenizer.cpython-38.pyc
│   │       ├── ptbtokenizer.py
│   │       ├── ptbtokenizer.pyc
│   │       ├── stanford-corenlp-3.4.1.jar
│   │       ├── tmp37tp6xj8
│   │       ├── tmp82iqkuu0
│   │       └── tmpn19wmqte
│   └── refer_python3.py
├── scheduler
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── cosine_lr.cpython-36.pyc
│   │   ├── cosine_lr.cpython-38.pyc
│   │   ├── plateau_lr.cpython-36.pyc
│   │   ├── plateau_lr.cpython-38.pyc
│   │   ├── scheduler.cpython-36.pyc
│   │   ├── scheduler.cpython-38.pyc
│   │   ├── scheduler_factory.cpython-36.pyc
│   │   ├── scheduler_factory.cpython-38.pyc
│   │   ├── step_lr.cpython-36.pyc
│   │   ├── step_lr.cpython-38.pyc
│   │   ├── tanh_lr.cpython-36.pyc
│   │   └── tanh_lr.cpython-38.pyc
│   ├── cosine_lr.py
│   ├── plateau_lr.py
│   ├── scheduler.py
│   ├── scheduler_factory.py
│   ├── step_lr.py
│   └── tanh_lr.py
├── utils.py
├── visualization.ipynb
└── vqaTools
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── vqa.cpython-36.pyc
    │   ├── vqa.cpython-38.pyc
    │   ├── vqaEval.cpython-36.pyc
    │   └── vqaEval.cpython-38.pyc
    ├── vqa.py
    └── vqaEval.py
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Salesforce Open Source Community Code of Conduct
2 |
3 | ## About the Code of Conduct
4 |
5 | Equality is a core value at Salesforce. We believe a diverse and inclusive
6 | community fosters innovation and creativity, and are committed to building a
7 | culture where everyone feels included.
8 |
9 | Salesforce open-source projects are committed to providing a friendly, safe, and
10 | welcoming environment for all, regardless of gender identity and expression,
11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12 | race, age, religion, level of experience, education, socioeconomic status, or
13 | other similar personal characteristics.
14 |
15 | The goal of this code of conduct is to specify a baseline standard of behavior so
16 | that people with different social values and communication styles can work
17 | together effectively, productively, and respectfully in our open source community.
18 | It also establishes a mechanism for reporting issues and resolving conflicts.
19 |
20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21 | in a Salesforce open-source project may be reported by contacting the Salesforce
22 | Open Source Conduct Committee at ossconduct@salesforce.com.
23 |
24 | ## Our Pledge
25 |
26 | In the interest of fostering an open and welcoming environment, we as
27 | contributors and maintainers pledge to making participation in our project and
28 | our community a harassment-free experience for everyone, regardless of gender
29 | identity and expression, sexual orientation, disability, physical appearance,
30 | body size, ethnicity, nationality, race, age, religion, level of experience, education,
31 | socioeconomic status, or other similar personal characteristics.
32 |
33 | ## Our Standards
34 |
35 | Examples of behavior that contributes to creating a positive environment
36 | include:
37 |
38 | * Using welcoming and inclusive language
39 | * Being respectful of differing viewpoints and experiences
40 | * Gracefully accepting constructive criticism
41 | * Focusing on what is best for the community
42 | * Showing empathy toward other community members
43 |
44 | Examples of unacceptable behavior by participants include:
45 |
46 | * The use of sexualized language or imagery and unwelcome sexual attention or
47 | advances
48 | * Personal attacks, insulting/derogatory comments, or trolling
49 | * Public or private harassment
50 | * Publishing, or threatening to publish, others' private information—such as
51 | a physical or electronic address—without explicit permission
52 | * Other conduct which could reasonably be considered inappropriate in a
53 | professional setting
54 | * Advocating for or encouraging any of the above behaviors
55 |
56 | ## Our Responsibilities
57 |
58 | Project maintainers are responsible for clarifying the standards of acceptable
59 | behavior and are expected to take appropriate and fair corrective action in
60 | response to any instances of unacceptable behavior.
61 |
62 | Project maintainers have the right and responsibility to remove, edit, or
63 | reject comments, commits, code, wiki edits, issues, and other contributions
64 | that are not aligned with this Code of Conduct, or to ban temporarily or
65 | permanently any contributor for other behaviors that they deem inappropriate,
66 | threatening, offensive, or harmful.
67 |
68 | ## Scope
69 |
70 | This Code of Conduct applies both within project spaces and in public spaces
71 | when an individual is representing the project or its community. Examples of
72 | representing a project or community include using an official project email
73 | address, posting via an official social media account, or acting as an appointed
74 | representative at an online or offline event. Representation of a project may be
75 | further defined and clarified by project maintainers.
76 |
77 | ## Enforcement
78 |
79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
80 | reported by contacting the Salesforce Open Source Conduct Committee
81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82 | and will result in a response that is deemed necessary and appropriate to the
83 | circumstances. The committee is obligated to maintain confidentiality with
84 | regard to the reporter of an incident. Further details of specific enforcement
85 | policies may be posted separately.
86 |
87 | Project maintainers who do not follow or enforce the Code of Conduct in good
88 | faith may face temporary or permanent repercussions as determined by other
89 | members of the project's leadership and the Salesforce Open Source Conduct
90 | Committee.
91 |
92 | ## Attribution
93 |
94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96 | It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98 |
99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100 |
101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
102 | [golang-coc]: https://golang.org/conduct
103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
106 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, Salesforce.com, Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/Pretrain_nlvr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import ruamel_yaml as yaml
4 | import numpy as np
5 | import random
6 | import time
7 | import datetime
8 | import json
9 | from pathlib import Path
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.data import DataLoader
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 |
18 | from models.model_pretrain_nlvr import ALBEF
19 | from models.vit import interpolate_pos_embed
20 | from models.tokenization_bert import BertTokenizer
21 |
22 | import utils
23 | from dataset import create_dataset, create_sampler, create_loader
24 | from scheduler import create_scheduler
25 | from optim import create_optimizer
26 |
27 |
28 | def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
29 | # train
30 | model.train()
31 |
32 | metric_logger = utils.MetricLogger(delimiter=" ")
33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
34 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
35 |
36 | header = 'Train Epoch: [{}]'.format(epoch)
37 | print_freq = 50
38 | step_size = 100
39 | warmup_iterations = warmup_steps*step_size
40 |
41 | if args.distributed:
42 | data_loader.sampler.set_epoch(epoch)
43 |
44 | for i, (image, text) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45 |
46 | optimizer.zero_grad()
47 |
48 | image = image.to(device,non_blocking=True)
49 | text_input = tokenizer(text, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)
50 |
51 | loss = model(image, text_input)
52 | loss.backward()
53 |
54 | optimizer.step()
55 |
56 | metric_logger.update(loss=loss.item())
57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
58 |
59 | if epoch==0 and i%step_size==0 and i<=warmup_iterations:
60 | scheduler.step(i//step_size)
61 |
62 | # gather the stats from all processes
63 | metric_logger.synchronize_between_processes()
64 | print("Averaged stats:", metric_logger.global_avg())
65 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
66 |
67 |
68 | def main(args, config):
69 | utils.init_distributed_mode(args)
70 |
71 | device = torch.device(args.device)
72 |
73 | # fix the seed for reproducibility
74 | seed = args.seed + utils.get_rank()
75 | torch.manual_seed(seed)
76 | np.random.seed(seed)
77 | random.seed(seed)
78 | cudnn.benchmark = True
79 |
80 | start_epoch = 0
81 | max_epoch = config['schedular']['epochs']
82 | warmup_steps = config['schedular']['warmup_epochs']
83 |
84 | #### Dataset ####
85 | print("Creating dataset")
86 | datasets = [create_dataset('pretrain', config)]
87 |
88 | if args.distributed:
89 | num_tasks = utils.get_world_size()
90 | global_rank = utils.get_rank()
91 | samplers = create_sampler(datasets, [True], num_tasks, global_rank)
92 | else:
93 | samplers = [None]
94 |
95 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
96 |
97 | tokenizer = BertTokenizer.from_pretrained(args.text_encoder)
98 |
99 | #### Model ####
100 | print("Creating model")
101 | model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer)
102 |
103 | model = model.to(device)
104 |
105 | arg_opt = utils.AttrDict(config['optimizer'])
106 | optimizer = create_optimizer(arg_opt, model)
107 | arg_sche = utils.AttrDict(config['schedular'])
108 | lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
109 |
110 | if args.checkpoint:
111 | checkpoint = torch.load(args.checkpoint, map_location='cpu')
112 | state_dict = checkpoint['model']
113 | pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
114 | state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
115 |
116 | for key in list(state_dict.keys()):
117 | if 'bert' in key:
118 | new_key = key.replace('bert.','')
119 |
120 | if 'layer' in new_key:
121 | keys = new_key.split('.')
122 | layer_num = int(keys[3])
123 | # replicate the multimodal encoder's blocks for two images
124 | if layer_num>=6:
125 | new_layer_num = (layer_num-6)*2+6
126 | keys[3] = str(new_layer_num)
127 | new_key_0 = '.'.join(keys)
128 | state_dict[new_key_0] = state_dict[key]
129 | keys[3] = str(new_layer_num+1)
130 | new_key_1 = '.'.join(keys)
131 | state_dict[new_key_1] = state_dict[key]
132 | else:
133 | state_dict[new_key] = state_dict[key]
134 | else:
135 | state_dict[new_key] = state_dict[key]
136 | del state_dict[key]
137 |
138 | msg = model.load_state_dict(state_dict,strict=False)
139 | print('load checkpoint from %s'%args.checkpoint)
140 | print(msg)
141 |
142 | model_without_ddp = model
143 | if args.distributed:
144 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
145 | model_without_ddp = model.module
146 |
147 | print("Start training")
148 | start_time = time.time()
149 |
150 | for epoch in range(start_epoch, max_epoch):
151 |
152 | if epoch>0:
153 | lr_scheduler.step(epoch+warmup_steps)
154 |
155 | train_stats = train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)
156 |
157 | if utils.is_main_process():
158 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
159 | 'epoch': epoch,
160 | }
161 | save_obj = {
162 | 'model': model_without_ddp.state_dict(),
163 | 'optimizer': optimizer.state_dict(),
164 | 'lr_scheduler': lr_scheduler.state_dict(),
165 | 'config': config,
166 | 'epoch': epoch,
167 | }
168 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
169 |
170 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
171 | f.write(json.dumps(log_stats) + "\n")
172 |
173 | dist.barrier()
174 |
175 | total_time = time.time() - start_time
176 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177 | print('Training time {}'.format(total_time_str))
178 |
179 |
180 |
181 | if __name__ == '__main__':
182 | parser = argparse.ArgumentParser()
183 | parser.add_argument('--config', default='./configs/NLVR_pretrain.yaml')
184 | parser.add_argument('--checkpoint', default='')
185 | parser.add_argument('--output_dir', default='output/NLVR_pretrain')
186 | parser.add_argument('--text_encoder', default='bert-base-uncased')
187 | parser.add_argument('--device', default='cuda')
188 | parser.add_argument('--seed', default=42, type=int)
189 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
190 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
191 | parser.add_argument('--distributed', default=True, type=bool)
192 | args = parser.parse_args()
193 |
194 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
195 |
196 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
197 |
198 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
199 |
200 | main(args, config)
201 |
--------------------------------------------------------------------------------
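Note on the checkpoint-loading block in main() above: because the NLVR2 model takes an image pair, each multimodal (post-fusion) text-encoder layer of the pre-trained checkpoint is duplicated when the state dict is remapped. A small illustration of the index remapping (the key name below is only an example):

    # layers 0-5 are copied as-is; a layer k >= 6 is copied to layers (k-6)*2+6 and (k-6)*2+7
    key = 'text_encoder.bert.encoder.layer.8.output.dense.weight'   # example key
    keys = key.replace('bert.', '').split('.')
    layer_num = int(keys[3])                  # 8
    new_layer_num = (layer_num - 6) * 2 + 6   # 10; the same weights also go to layer 11
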
/README.md:
--------------------------------------------------------------------------------
1 | ## Align before Fuse: Vision and Language Representation Learning with Momentum Distillation, NeurIPS 2021 Spotlight (Salesforce Research).
2 |
3 | ## Announcement: ALBEF is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!
4 |
5 | This is the official PyTorch implementation of the ALBEF paper [Blog].
6 | This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k,
7 | and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released.
8 |
9 |
10 |
11 |
12 | ### Requirements:
13 | * pytorch 1.8.0
14 | * transformers 4.8.1
15 | * timm 0.4.9
16 |
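A quick way to confirm the environment matches these versions (a sketch, assuming the usual import names torch, transformers, and timm):

    import torch, transformers, timm
    print(torch.__version__, transformers.__version__, timm.__version__)  # expect 1.8.0, 4.8.1, 0.4.9
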
17 | ### Download:
18 |
19 | * Pre-trained checkpoint [[14M](https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth)] / [[4M](https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF_4M.pth)]
20 | * Dataset json files for downstream tasks
21 | * Dataset json files for pre-training (the image paths in each json file need to be changed to your own directory)
22 | * Finetuned checkpoint for retrieval on MSCOCO
23 | * Finetuned checkpoint for retrieval on Flickr30k
24 | * Finetuned checkpoint for VQA
25 | * Finetuned checkpoint for visual grounding on RefCOCO+
26 |
27 | ### Visualization:
28 | We provide code in visualization.ipynb to visualize the important areas in an image for each word in a text.
29 | Here is an example visualization using the visual grounding checkpoint.
30 |
31 | Try the Replicate demo [here](https://replicate.com/salesforce/albef).
32 |
33 |
34 |
35 | ### Pre-training on custom datasets:
36 | 1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image} (a minimal sketch follows the command below).
37 | 2. In configs/Pretrain.yaml, set the paths for the json files.
38 | 3. Pre-train the model using 8 A100 GPUs:
39 |
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain
40 |
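As a reference for step 1, here is a minimal sketch of writing such a json file (paths, captions, and the output file name are placeholders); a caption may also be a list of strings, as noted in configs/Pretrain.yaml:

    import json

    samples = [
        {'image': '/path/to/images/0001.jpg', 'caption': 'a dog running on the beach'},
        {'image': '/path/to/images/0002.jpg', 'caption': ['a red car', 'a car parked on the street']},
    ]
    with open('data/my_pretrain.json', 'w') as f:   # hypothetical file name
        json.dump(samples, f)
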
41 | ### Image-Text Retrieval:
42 |
43 | 1. Download MSCOCO or Flickr30k datasets from the original websites.
44 | 2. Download and extract the provided dataset json files.
45 | 3. In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path.
46 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
47 | python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
48 | --config ./configs/Retrieval_flickr.yaml \
49 | --output_dir output/Retrieval_flickr \
50 | --checkpoint [Pretrained checkpoint]
51 |
52 | ### VQA:
53 | 1. Download VQA v2 dataset and Visual Genome dataset from the original websites.
54 | 2. Download and extract the provided dataset json files.
55 | 3. In configs/VQA.yaml, set the paths for the json files and the image paths.
56 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
57 | python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
58 | --config ./configs/VQA.yaml \
59 | --output_dir output/vqa \
60 | --checkpoint [Pretrained checkpoint]
61 | 5. Evaluate the result using the official evaluation server.
62 |
63 | ### Visual Entailment:
64 | 1. Download SNLI-VE dataset from the original website.
65 | 2. Download and extract the provided dataset json files.
66 | 3. In configs/VE.yaml, set the paths for the json files and the image path.
67 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
68 | python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
69 | --config ./configs/VE.yaml \
70 | --output_dir output/VE \
71 | --checkpoint [Pretrained checkpoint]
72 |
73 | ### Visual Grounding on RefCOCO+:
74 | 1. Download MSCOCO dataset from the original website.
75 | 2. Download and extract the provided dataset json files.
76 | 3. In configs/Grounding.yaml, set the paths for the json files and the image path.
77 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
78 | python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
79 | --config ./configs/Grounding.yaml \
80 | --output_dir output/RefCOCO \
81 | --gradcam_mode itm \
82 | --block_num 8 \
83 | --checkpoint [Pretrained checkpoint]
84 |
85 | ### NLVR2:
86 | NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model for image-pair inputs. In order to perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run:
87 | python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
88 | --config ./configs/NLVR_pretrain.yaml \
89 | --output_dir output/NLVR_pretrain \
90 | --checkpoint [Pretrained checkpoint]
91 |
92 | We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps.
93 | 1. Download NLVR2 dataset from the original website.
94 | 2. Download and extract the provided dataset json files.
95 | 3. In configs/NLVR.yaml, set the paths for the json files and the image path.
96 | 4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
97 | python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
98 | --config ./configs/NLVR.yaml \
99 | --output_dir output/NLVR \
100 | --checkpoint [TA pretrained checkpoint]
101 |
102 | ### Citation
103 | If you find this code to be useful for your research, please consider citing.
104 |
105 | @inproceedings{ALBEF,
106 | title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation},
107 | author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
108 | year={2021},
109 | booktitle={NeurIPS},
110 | }
111 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as can be, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
8 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | cuda: "11.1"
4 | python_version: "3.8"
5 | system_packages:
6 | - "libgl1-mesa-glx"
7 | - "libglib2.0-0"
8 | python_packages:
9 | - "ipython==7.30.1"
10 | - "torchvision==0.11.1"
11 | - "torch==1.10.0"
12 | - "timm==0.4.12"
13 | - "transformers==4.8.1"
14 | - "Pillow==8.3.2"
15 | - "numpy==1.21.1"
16 | - "opencv-python==4.5.5.62"
17 | - "scipy==1.8.0"
18 | - "scikit_image==0.19.2"
19 | - "matplotlib==3.4.3"
20 |
21 | predict: "predict.py:Predictor"
22 |
--------------------------------------------------------------------------------
/configs/Grounding.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/refcoco+_train.json']
2 | test_file: ['data/refcoco+_val.json','data/refcoco+_test.json']
3 |
4 | refcoco_data: 'data'
5 | det_file: 'data/refcoco+/dets.json'
6 | coco_file: 'data/refcoco+/cocos.json'
7 |
8 | image_root: '/export/share/datasets/vision/coco/images/'
9 |
10 | bert_config: 'configs/config_bert.json'
11 |
12 | image_res: 384
13 | batch_size: 32
14 |
15 | queue_size: 65536
16 | momentum: 0.995
17 | vision_width: 768
18 | embed_dim: 256
19 | temp: 0.07
20 |
21 | alpha: 0.4
22 | distill: True
23 | warm_up: True
24 |
25 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
26 | schedular: {sched: cosine, lr: 1e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/configs/NLVR.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/nlvr_train.json']
2 | val_file: ['data/nlvr_dev.json']
3 | test_file: ['data/nlvr_test.json']
4 |
5 | image_root: '/export/share/datasets/vision/NLVR2/'
6 |
7 | image_res: 384
8 | batch_size: 16
9 |
10 | bert_config: 'configs/config_bert.json'
11 |
12 | alpha: 0.4
13 | distill: True
14 | warm_up: True
15 | eval_ema: False
16 |
17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
18 | schedular: {sched: cosine, lr: 2e-5, epochs: 10, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/configs/NLVR_pretrain.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/coco.json',
2 | 'data/vg.json',
3 | 'data/cc3m_train.json',
4 | 'data/cc3m_val.json',
5 | 'data/sbu.json'
6 | ]
7 |
8 | # each train_file (json) contains a python list where each item is {'image': img_path, 'caption': text or list_of_text }
9 |
10 | bert_config: 'configs/config_bert.json'
11 |
12 | image_res: 256
13 | vision_width: 768
14 | embed_dim: 256
15 | batch_size: 64
16 |
17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
18 | schedular: {sched: cosine, lr: 2e-5, epochs: 1, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/configs/Pretrain.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/coco.json',
2 | 'data/vg.json',
3 | 'data/cc12m.json',
4 | 'data/cc3m_train.json',
5 | 'data/cc3m_val.json',
6 | 'data/sbu.json'
7 | ]
8 | # each train_file (json) contains a python list where each item is {'image': img_path, 'caption': text or list_of_text }
9 | bert_config: 'configs/config_bert.json'
10 |
11 | image_res: 256
12 | vision_width: 768
13 | embed_dim: 256
14 | batch_size: 64
15 | temp: 0.07
16 | mlm_probability: 0.15
17 | queue_size: 65536
18 | momentum: 0.995
19 | alpha: 0.4
20 |
21 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.02}
22 | schedular: {sched: cosine, lr: 1e-4, epochs: 30, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 20, cooldown_epochs: 0}
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/configs/Retrieval_coco.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/coco_train.json']
2 | val_file: 'data/coco_val.json'
3 | test_file: 'data/coco_test.json'
4 | image_root: '/export/share/datasets/vision/coco/images/'
5 |
6 | bert_config: 'configs/config_bert.json'
7 |
8 | image_res: 384
9 | batch_size_train: 32
10 | batch_size_test: 64
11 |
12 | queue_size: 65536
13 | momentum: 0.995
14 | vision_width: 768
15 | embed_dim: 256
16 | temp: 0.07
17 | k_test: 256
18 |
19 | alpha: 0.4
20 | distill: True
21 | warm_up: True
22 |
23 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
24 | schedular: {sched: cosine, lr: 1e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/configs/Retrieval_flickr.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/flickr30k_train.json']
2 | val_file: 'data/flickr30k_val.json'
3 | test_file: 'data/flickr30k_test.json'
4 | image_root: '/export/share/datasets/vision/flickr30k/' #flickr30k-images/
5 |
6 | bert_config: 'configs/config_bert.json'
7 |
8 | image_res: 384
9 | batch_size_train: 32
10 | batch_size_test: 64
11 |
12 | queue_size: 65536
13 | momentum: 0.995
14 | vision_width: 768
15 | embed_dim: 256
16 | temp: 0.07
17 | k_test: 128
18 |
19 | alpha: 0.4
20 | distill: True
21 | warm_up: True
22 |
23 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
24 | schedular: {sched: cosine, lr: 1e-5, epochs: 10, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/configs/VE.yaml:
--------------------------------------------------------------------------------
1 | train_file: 'data/ve_train.json'
2 | val_file: 'data/ve_dev.json'
3 | test_file: 'data/ve_test.json'
4 |
5 | image_root: '/export/home/project/SNLI-VE/data/images'
6 |
7 | image_res: 384
8 | batch_size_train: 32
9 | batch_size_test: 64
10 |
11 | alpha: 0.4
12 | distill: True
13 | warm_up: False
14 |
15 | bert_config: 'configs/config_bert.json'
16 |
17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
18 | schedular: {sched: cosine, lr: 2e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/configs/VQA.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['data/vqa_train.json',
2 | 'data/vqa_val.json',
3 | 'data/vg_qa.json']
4 |
5 | test_file: ['data/vqa_test.json']
6 | answer_list: 'data/answer_list.json'
7 |
8 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #train2014/
9 | vg_root: '/export/share/datasets/vision/visual-genome/' #image/
10 |
11 | image_res: 384
12 | batch_size_train: 32
13 | batch_size_test: 16
14 | k_test: 128
15 |
16 | alpha: 0.4
17 | distill: True
18 | warm_up: True
19 |
20 | eos: '[SEP]'
21 |
22 | bert_config: 'configs/config_bert.json'
23 |
24 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
25 | schedular: {sched: cosine, lr: 2e-5, epochs: 8, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 4, cooldown_epochs: 0}
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/configs/config_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "fusion_layer": 6,
20 | "encoder_width": 768
21 | }
22 |
--------------------------------------------------------------------------------
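The two non-standard keys at the end (fusion_layer, encoder_width) configure the cross-modal fusion in the modified BERT (models/xbert.py, not shown in this excerpt): with fusion_layer set to 6, the first 6 BERT layers act as the text-only encoder and the remaining layers cross-attend to image features. A sketch of how the scripts consume this file, mirroring models/model_nlvr.py:

    from models.xbert import BertConfig, BertModel

    bert_config = BertConfig.from_json_file('configs/config_bert.json')
    bert_config.num_hidden_layers = 18   # NLVR2 only: 6 text-only + 2x6 duplicated fusion layers
    text_encoder = BertModel.from_pretrained('bert-base-uncased', config=bert_config, add_pooling_layer=False)
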
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms
4 | from PIL import Image
5 |
6 | from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset
7 | from dataset.nlvr_dataset import nlvr_dataset
8 | from dataset.ve_dataset import ve_dataset
9 | from dataset.vqa_dataset import vqa_dataset
10 | from dataset.grounding_dataset import grounding_dataset
11 |
12 | from dataset.randaugment import RandomAugment
13 |
14 | def create_dataset(dataset, config):
15 |
16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
17 |
18 | pretrain_transform = transforms.Compose([
19 | transforms.RandomResizedCrop(config['image_res'],scale=(0.2, 1.0), interpolation=Image.BICUBIC),
20 | transforms.RandomHorizontalFlip(),
21 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
23 | transforms.ToTensor(),
24 | normalize,
25 | ])
26 | train_transform = transforms.Compose([
27 | transforms.RandomResizedCrop(config['image_res'],scale=(0.5, 1.0), interpolation=Image.BICUBIC),
28 | transforms.RandomHorizontalFlip(),
29 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
30 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
31 | transforms.ToTensor(),
32 | normalize,
33 | ])
34 | test_transform = transforms.Compose([
35 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
36 | transforms.ToTensor(),
37 | normalize,
38 | ])
39 |
40 | if dataset=='pretrain':
41 | dataset = pretrain_dataset(config['train_file'], pretrain_transform)
42 | return dataset
43 |
44 | elif dataset=='re':
45 | train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root'])
46 | val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root'])
47 | test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root'])
48 | return train_dataset, val_dataset, test_dataset
49 |
50 | elif dataset=='vqa':
51 | train_dataset = vqa_dataset(config['train_file'], train_transform, config['vqa_root'], config['vg_root'], split='train')
52 | vqa_test_dataset = vqa_dataset(config['test_file'], test_transform, config['vqa_root'], config['vg_root'], split='test', answer_list=config['answer_list'])
53 | return train_dataset, vqa_test_dataset
54 |
55 | elif dataset=='nlvr':
56 | train_dataset = nlvr_dataset(config['train_file'], train_transform, config['image_root'])
57 | val_dataset = nlvr_dataset(config['val_file'], test_transform, config['image_root'])
58 | test_dataset = nlvr_dataset(config['test_file'], test_transform, config['image_root'])
59 | return train_dataset, val_dataset, test_dataset
60 |
61 | elif dataset=='ve':
62 | train_dataset = ve_dataset(config['train_file'], train_transform, config['image_root'])
63 | val_dataset = ve_dataset(config['val_file'], test_transform, config['image_root'])
64 | test_dataset = ve_dataset(config['test_file'], test_transform, config['image_root'])
65 | return train_dataset, val_dataset, test_dataset
66 |
67 | elif dataset=='grounding':
68 | train_transform = transforms.Compose([
69 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
70 | transforms.RandomHorizontalFlip(),
71 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
72 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
73 | transforms.ToTensor(),
74 | normalize,
75 | ])
76 | train_dataset = grounding_dataset(config['train_file'], train_transform, config['image_root'], mode='train')
77 | test_dataset = grounding_dataset(config['test_file'], test_transform, config['image_root'], mode='test')
78 | return train_dataset, test_dataset
79 |
80 |
81 | def vqa_collate_fn(batch):
82 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
83 | for image, question, answer, weights in batch:
84 | image_list.append(image)
85 | question_list.append(question)
86 | weight_list += weights
87 | answer_list += answer
88 | n.append(len(answer))
89 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
90 |
91 |
92 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
93 | samplers = []
94 | for dataset,shuffle in zip(datasets,shuffles):
95 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
96 | samplers.append(sampler)
97 | return samplers
98 |
99 |
100 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
101 | loaders = []
102 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
103 | if is_train:
104 | shuffle = (sampler is None)
105 | drop_last = True
106 | else:
107 | shuffle = False
108 | drop_last = False
109 | loader = DataLoader(
110 | dataset,
111 | batch_size=bs,
112 | num_workers=n_worker,
113 | pin_memory=True,
114 | sampler=sampler,
115 | shuffle=shuffle,
116 | collate_fn=collate_fn,
117 | drop_last=drop_last,
118 | )
119 | loaders.append(loader)
120 | return loaders
--------------------------------------------------------------------------------
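For reference, a minimal single-process sketch of how these helpers are chained (no DistributedSampler; `config` is assumed to be a loaded yaml such as configs/Retrieval_coco.yaml):

    from dataset import create_dataset, create_loader

    train_dataset, val_dataset, test_dataset = create_dataset('re', config)
    train_loader, val_loader, test_loader = create_loader(
        [train_dataset, val_dataset, test_dataset],
        samplers=[None, None, None],
        batch_size=[config['batch_size_train'], config['batch_size_test'], config['batch_size_test']],
        num_workers=[4, 4, 4],
        is_trains=[True, False, False],
        collate_fns=[None, None, None])
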
/dataset/caption_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | from torch.utils.data import Dataset
6 |
7 | from PIL import Image
8 | from PIL import ImageFile
9 | ImageFile.LOAD_TRUNCATED_IMAGES = True
10 | Image.MAX_IMAGE_PIXELS = None
11 |
12 | from dataset.utils import pre_caption
13 |
14 |
15 | class re_train_dataset(Dataset):
16 | def __init__(self, ann_file, transform, image_root, max_words=30):
17 | self.ann = []
18 | for f in ann_file:
19 | self.ann += json.load(open(f,'r'))
20 | self.transform = transform
21 | self.image_root = image_root
22 | self.max_words = max_words
23 | self.img_ids = {}
24 |
25 | n = 0
26 | for ann in self.ann:
27 | img_id = ann['image_id']
28 | if img_id not in self.img_ids.keys():
29 | self.img_ids[img_id] = n
30 | n += 1
31 |
32 | def __len__(self):
33 | return len(self.ann)
34 |
35 | def __getitem__(self, index):
36 |
37 | ann = self.ann[index]
38 |
39 | image_path = os.path.join(self.image_root,ann['image'])
40 | image = Image.open(image_path).convert('RGB')
41 | image = self.transform(image)
42 |
43 | caption = pre_caption(ann['caption'], self.max_words)
44 |
45 | return image, caption, self.img_ids[ann['image_id']]
46 |
47 |
48 |
49 | class re_eval_dataset(Dataset):
50 | def __init__(self, ann_file, transform, image_root, max_words=30):
51 | self.ann = json.load(open(ann_file,'r'))
52 | self.transform = transform
53 | self.image_root = image_root
54 | self.max_words = max_words
55 |
56 | self.text = []
57 | self.image = []
58 | self.txt2img = {}
59 | self.img2txt = {}
60 |
61 | txt_id = 0
62 | for img_id, ann in enumerate(self.ann):
63 | self.image.append(ann['image'])
64 | self.img2txt[img_id] = []
65 | for i, caption in enumerate(ann['caption']):
66 | self.text.append(pre_caption(caption,self.max_words))
67 | self.img2txt[img_id].append(txt_id)
68 | self.txt2img[txt_id] = img_id
69 | txt_id += 1
70 |
71 | def __len__(self):
72 | return len(self.image)
73 |
74 | def __getitem__(self, index):
75 |
76 | image_path = os.path.join(self.image_root, self.ann[index]['image'])
77 | image = Image.open(image_path).convert('RGB')
78 | image = self.transform(image)
79 |
80 | return image, index
81 |
82 |
83 |
84 | class pretrain_dataset(Dataset):
85 | def __init__(self, ann_file, transform, max_words=30):
86 | self.ann = []
87 | for f in ann_file:
88 | self.ann += json.load(open(f,'r'))
89 | self.transform = transform
90 | self.max_words = max_words
91 |
92 |
93 | def __len__(self):
94 | return len(self.ann)
95 |
96 |
97 | def __getitem__(self, index):
98 |
99 | ann = self.ann[index]
100 |
101 | if type(ann['caption']) == list:
102 | caption = pre_caption(random.choice(ann['caption']), self.max_words)
103 | else:
104 | caption = pre_caption(ann['caption'], self.max_words)
105 |
106 | image = Image.open(ann['image']).convert('RGB')
107 | image = self.transform(image)
108 |
109 | return image, caption
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------
/dataset/grounding_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | from dataset.utils import pre_caption
6 |
7 | class grounding_dataset(Dataset):
8 | def __init__(self, ann_file, transform, image_root, max_words=30, mode='train'):
9 | self.ann = []
10 | for f in ann_file:
11 | self.ann += json.load(open(f,'r'))
12 | self.transform = transform
13 | self.image_root = image_root
14 | self.max_words = max_words
15 | self.mode = mode
16 |
17 | if self.mode == 'train':
18 | self.img_ids = {}
19 | n = 0
20 | for ann in self.ann:
21 | img_id = ann['image'].split('/')[-1]
22 | if img_id not in self.img_ids.keys():
23 | self.img_ids[img_id] = n
24 | n += 1
25 |
26 |
27 | def __len__(self):
28 | return len(self.ann)
29 |
30 | def __getitem__(self, index):
31 |
32 | ann = self.ann[index]
33 |
34 | image_path = os.path.join(self.image_root,ann['image'])
35 | image = Image.open(image_path).convert('RGB')
36 | image = self.transform(image)
37 |
38 | caption = pre_caption(ann['text'], self.max_words)
39 |
40 | if self.mode=='train':
41 | img_id = ann['image'].split('/')[-1]
42 |
43 | return image, caption, self.img_ids[img_id]
44 | else:
45 | return image, caption, ann['ref_id']
--------------------------------------------------------------------------------
/dataset/nlvr_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | from dataset.utils import pre_caption
6 |
7 |
8 | class nlvr_dataset(Dataset):
9 | def __init__(self, ann_file, transform, image_root):
10 | self.ann = []
11 | for f in ann_file:
12 | self.ann += json.load(open(f,'r'))
13 | self.transform = transform
14 | self.image_root = image_root
15 | self.max_words = 30
16 |
17 | def __len__(self):
18 | return len(self.ann)
19 |
20 |
21 | def __getitem__(self, index):
22 |
23 | ann = self.ann[index]
24 |
25 | image0_path = os.path.join(self.image_root,ann['images'][0])
26 | image0 = Image.open(image0_path).convert('RGB')
27 | image0 = self.transform(image0)
28 |
29 | image1_path = os.path.join(self.image_root,ann['images'][1])
30 | image1 = Image.open(image1_path).convert('RGB')
31 | image1 = self.transform(image1)
32 |
33 | sentence = pre_caption(ann['sentence'], self.max_words)
34 |
35 | if ann['label']=='True':
36 | label = 1
37 | else:
38 | label = 0
39 |
40 | return image0, image1, sentence, label
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def pre_question(question,max_ques_words):
4 | question = re.sub(
5 | r"([,.'!?\"()*#:;~])",
6 | '',
7 | question.lower(),
8 | ).replace('-', ' ').replace('/', ' ')
9 | question = question.rstrip(' ')
10 |
11 | #truncate question
12 | question_words = question.split(' ')
13 | if len(question_words)>max_ques_words:
14 | question = ' '.join(question_words[:max_ques_words])
15 |
16 | return question
17 |
18 |
19 | def pre_caption(caption,max_words):
20 | caption = re.sub(
21 | r"([,.'!?\"()*#:;~])",
22 | '',
23 | caption.lower(),
24 | ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
25 |
26 | caption = re.sub(
27 | r"\s{2,}",
28 | ' ',
29 | caption,
30 | )
31 | caption = caption.rstrip('\n')
32 | caption = caption.strip(' ')
33 |
34 | #truncate caption
35 | caption_words = caption.split(' ')
36 | if len(caption_words)>max_words:
37 | caption = ' '.join(caption_words[:max_words])
38 |
39 | return caption
40 |
41 |
42 | from vqaTools.vqaEval import VQAEval
43 | from refTools.evaluation.refEvaluation import RefEvaluation
44 |
45 | import json
46 | import os
47 | import numpy as np
48 | import torch
49 | import torch.distributed as dist
50 | import torch.nn.functional as F
51 |
52 | import utils
53 | from tqdm import tqdm
54 |
55 |
56 | def vqa_eval(vqa, result_file, test_ques_path):
57 | vqaRes = vqa.loadRes(result_file, test_ques_path)
58 | # create vqaEval object by taking vqa and vqaRes
59 | vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2
60 | # evaluate results
61 | vqaEval.evaluate()
62 |
63 | # print accuracies
64 | print("\n")
65 | print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
66 | print("Per Answer Type Accuracy is the following:")
67 | for ansType in vqaEval.accuracy['perAnswerType']:
68 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
69 | print("\n")
70 |
71 | return vqaEval
72 |
73 |
74 |
75 | def collect_result(result, result_dir, filename, is_json=True, is_list=True):
76 | if is_json:
77 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
78 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
79 | json.dump(result,open(result_file,'w'))
80 | else:
81 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank()))
82 | final_result_file = os.path.join(result_dir, '%s.pth'%filename)
83 | torch.save(result,result_file)
84 |
85 | dist.barrier()
86 |
87 | result = None
88 | if utils.is_main_process():
89 | # combine results from all processes
90 | if is_list:
91 | result = []
92 | else:
93 | result = {}
94 | for rank in range(utils.get_world_size()):
95 | if is_json:
96 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
97 | res = json.load(open(result_file,'r'))
98 | else:
99 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank))
100 | res = torch.load(result_file)
101 | if is_list:
102 | result += res
103 | else:
104 | result.update(res)
105 |
106 | return result
107 |
108 |
109 | def save_result(result, result_dir, filename, is_json=True, is_list=True):
110 | if is_json:
111 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
112 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
113 | json.dump(result,open(result_file,'w'))
114 | else:
115 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank()))
116 | final_result_file = os.path.join(result_dir, '%s.pth'%filename)
117 | torch.save(result,result_file)
118 |
119 | dist.barrier()
120 |
121 | if utils.is_main_process():
122 | # combine results from all processes
123 | if is_list:
124 | result = []
125 | else:
126 | result = {}
127 | for rank in range(utils.get_world_size()):
128 | if is_json:
129 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
130 | res = json.load(open(result_file,'r'))
131 | else:
132 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank))
133 | res = torch.load(result_file)
134 | if is_list:
135 | result += res
136 | else:
137 | result.update(res)
138 | if is_json:
139 | json.dump(result,open(final_result_file,'w'))
140 | else:
141 | torch.save(result,final_result_file)
142 |
143 | print('result file saved to %s'%final_result_file)
144 | dist.barrier()
145 | return final_result_file
146 |
147 |
148 |
149 | def grounding_eval(results,dets,cocos,refer,alpha,mask_size=24):
150 |
151 | correct_A_d, correct_B_d, correct_val_d = 0, 0, 0
152 | correct_A, correct_B, correct_val = 0, 0, 0
153 | num_A,num_B,num_val = 0,0,0
154 |
155 | for res in tqdm(results):
156 |
157 | ref_id = res['ref_id']
158 | ref = refer.Refs[ref_id]
159 | ref_box = refer.refToAnn[ref_id]['bbox']
160 | image = refer.Imgs[ref['image_id']]
161 |
162 | mask = res['pred'].cuda().view(1,1,mask_size,mask_size)
163 | mask = F.interpolate(mask,size = (image['height'],image['width']), mode='bicubic').squeeze()
164 |
165 | # rank detection boxes
166 | max_score = 0
167 | for det in dets[str(ref['image_id'])]:
168 | score = mask[int(det[1]):int(det[1]+det[3]),int(det[0]):int(det[0]+det[2])]
169 | area = det[2]*det[3]
170 | score = score.sum() / area**alpha
171 | if score>max_score:
172 | pred_box = det[:4]
173 | max_score = score
174 |
175 | IoU_det = computeIoU(ref_box, pred_box)
176 |
177 | if ref['split']=='testA':
178 | num_A += 1
179 | if IoU_det >= 0.5:
180 | correct_A_d += 1
181 | elif ref['split']=='testB':
182 | num_B += 1
183 | if IoU_det >= 0.5:
184 | correct_B_d += 1
185 | elif ref['split']=='val':
186 | num_val += 1
187 | if IoU_det >= 0.5:
188 | correct_val_d += 1
189 |
190 | eval_result = {'val_d':correct_val_d/num_val,'testA_d':correct_A_d/num_A,'testB_d':correct_B_d/num_B}
191 |
192 | for metric, acc in eval_result.items():
193 | print(f'{metric}: {acc:.3f}')
194 |
195 | return eval_result
196 |
197 |
198 |
199 | # IoU function
200 | def computeIoU(box1, box2):
201 | # each box is of [x1, y1, w, h]
202 | inter_x1 = max(box1[0], box2[0])
203 | inter_y1 = max(box1[1], box2[1])
204 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
205 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)
206 |
207 | if inter_x1 < inter_x2 and inter_y1 < inter_y2:
208 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
209 | else:
210 | inter = 0
211 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter
212 | return float(inter)/union
213 |
214 |
215 |
--------------------------------------------------------------------------------
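Quick sanity check of the two pure helpers above (a sketch, run from the repo root; results worked out by hand for illustration):

    from dataset.utils import pre_caption, computeIoU

    pre_caption('A man, riding-a horse!', max_words=30)   # -> 'a man riding a horse'
    computeIoU([0, 0, 10, 10], [5, 5, 10, 10])            # -> 25/175 ≈ 0.143 (boxes are [x, y, w, h])
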
/dataset/ve_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | from dataset.utils import pre_caption
6 |
7 |
8 | class ve_dataset(Dataset):
9 | def __init__(self, ann_file, transform, image_root, max_words=30):
10 | self.ann = json.load(open(ann_file,'r'))
11 | self.transform = transform
12 | self.image_root = image_root
13 | self.max_words = max_words
14 | self.labels = {'entailment':2,'neutral':1,'contradiction':0}
15 |
16 | def __len__(self):
17 | return len(self.ann)
18 |
19 |
20 | def __getitem__(self, index):
21 |
22 | ann = self.ann[index]
23 |
24 | image_path = os.path.join(self.image_root,'%s.jpg'%ann['image'])
25 | image = Image.open(image_path).convert('RGB')
26 | image = self.transform(image)
27 |
28 | sentence = pre_caption(ann['sentence'], self.max_words)
29 |
30 | return image, sentence, self.labels[ann['label']]
31 |
--------------------------------------------------------------------------------
/dataset/vqa_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from dataset.utils import pre_question
7 |
8 |
9 | class vqa_dataset(Dataset):
10 | def __init__(self, ann_file, transform, vqa_root, vg_root, eos='[SEP]', split="train", max_ques_words=30, answer_list=''):
11 | self.split = split
12 | self.ann = []
13 | for f in ann_file:
14 | self.ann += json.load(open(f,'r'))
15 |
16 | self.transform = transform
17 | self.vqa_root = vqa_root
18 | self.vg_root = vg_root
19 | self.max_ques_words = max_ques_words
20 | self.eos = eos
21 |
22 | if split=='test':
23 | self.max_ques_words = 50 # do not limit question length during test
24 | self.answer_list = json.load(open(answer_list,'r'))
25 |
26 |
27 | def __len__(self):
28 | return len(self.ann)
29 |
30 | def __getitem__(self, index):
31 |
32 | ann = self.ann[index]
33 |
34 | if ann['dataset']=='vqa':
35 | image_path = os.path.join(self.vqa_root,ann['image'])
36 | elif ann['dataset']=='vg':
37 | image_path = os.path.join(self.vg_root,ann['image'])
38 |
39 | image = Image.open(image_path).convert('RGB')
40 | image = self.transform(image)
41 |
42 | if self.split == 'test':
43 | question = pre_question(ann['question'],self.max_ques_words)
44 | question_id = ann['question_id']
45 | return image, question, question_id
46 |
47 |
48 | elif self.split=='train':
49 |
50 | question = pre_question(ann['question'],self.max_ques_words)
51 |
52 | if ann['dataset']=='vqa':
53 |
54 | answer_weight = {}
55 | for answer in ann['answer']:
56 | if answer in answer_weight.keys():
57 | answer_weight[answer] += 1/len(ann['answer'])
58 | else:
59 | answer_weight[answer] = 1/len(ann['answer'])
60 |
61 | answers = list(answer_weight.keys())
62 | weights = list(answer_weight.values())
63 |
64 | elif ann['dataset']=='vg':
65 | answers = [ann['answer']]
66 | weights = [0.5]
67 |
68 | answers = [answer+self.eos for answer in answers]
69 |
70 | return image, question, answers, weights
--------------------------------------------------------------------------------
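Note on the answer weighting above: for a 'vqa' sample each annotated answer contributes 1/len(answers), and duplicates accumulate, so frequent answers receive larger weights. A toy illustration:

    # ann['answer'] = ['blue', 'blue', 'red']  ->  answers = ['blue', 'red'], weights = [2/3, 1/3]
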
/examples/image0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/examples/image0.jpg
--------------------------------------------------------------------------------
/examples/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/examples/visualization.png
--------------------------------------------------------------------------------
/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/img.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/models/__init__.py
--------------------------------------------------------------------------------
/models/model_nlvr.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from models.vit import VisionTransformer
3 | from models.xbert import BertConfig, BertModel
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | class ALBEF(nn.Module):
10 | def __init__(self,
11 | text_encoder = None,
12 | tokenizer = None,
13 | config = None,
14 | ):
15 | super().__init__()
16 |
17 | self.tokenizer = tokenizer
18 | self.distill = config['distill']
19 |
20 | self.visual_encoder = VisionTransformer(
21 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
22 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
23 |
24 | bert_config = BertConfig.from_json_file(config['bert_config'])
25 | bert_config.num_hidden_layers = 18
26 |
27 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
28 | self.cls_head = nn.Sequential(
29 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
30 | nn.ReLU(),
31 | nn.Linear(self.text_encoder.config.hidden_size, 2)
32 | )
33 |
34 | self.share_cross_attention(self.text_encoder.encoder)
35 |
36 | if self.distill:
37 | self.visual_encoder_m = VisionTransformer(
38 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
39 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
40 | self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
41 | self.share_cross_attention(self.text_encoder_m.encoder)
42 |
43 | self.cls_head_m = nn.Sequential(
44 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
45 | nn.ReLU(),
46 | nn.Linear(self.text_encoder.config.hidden_size, 2)
47 | )
48 |
49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
50 | [self.text_encoder,self.text_encoder_m],
51 | [self.cls_head,self.cls_head_m],
52 | ]
53 | self.copy_params()
54 | self.momentum = 0.995
55 |
56 |
57 | def forward(self, image, text, targets, alpha=0, train=True):
58 |
59 | image_embeds = self.visual_encoder(image)
60 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
61 |
62 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
63 |
64 | output = self.text_encoder(text.input_ids,
65 | attention_mask = text.attention_mask,
66 | encoder_hidden_states = [image0_embeds,image1_embeds],
67 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
68 | image_atts[image0_embeds.size(0):]],
69 | return_dict = True,
70 | )
71 | hidden_state = output.last_hidden_state[:,0,:]
72 | prediction = self.cls_head(hidden_state)
73 |
74 | if train:
75 | if self.distill:
76 | with torch.no_grad():
77 | self._momentum_update()
78 | image_embeds_m = self.visual_encoder_m(image)
79 | image0_embeds_m, image1_embeds_m = torch.split(image_embeds_m,targets.size(0))
80 | output_m = self.text_encoder_m(text.input_ids,
81 | attention_mask = text.attention_mask,
82 | encoder_hidden_states = [image0_embeds_m,image1_embeds_m],
83 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
84 | image_atts[image0_embeds.size(0):]],
85 | return_dict = True,
86 | )
87 | prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:])
88 |
89 | loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum(
90 | F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean()
91 | else:
92 | loss = F.cross_entropy(prediction, targets)
93 | return loss
94 | else:
95 | return prediction
96 |
97 |
98 |
99 | @torch.no_grad()
100 | def copy_params(self):
101 | for model_pair in self.model_pairs:
102 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
103 | param_m.data.copy_(param.data) # initialize
104 | param_m.requires_grad = False # not update by gradient
105 |
106 |
107 | @torch.no_grad()
108 | def _momentum_update(self):
109 | for model_pair in self.model_pairs:
110 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
111 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
112 |
113 |
114 | def share_cross_attention(self, model):
115 |
116 | for i in range(6):
117 | layer_num = 6+i*2
118 | modules_0 = model.layer[layer_num].crossattention.self._modules
119 | modules_1 = model.layer[layer_num+1].crossattention.self._modules
120 |
121 | for name in modules_0.keys():
122 | if 'key' in name or 'value' in name:
123 | module_0 = modules_0[name]
124 | module_1 = modules_1[name]
125 | if hasattr(module_0, "weight"):
126 | module_0.weight = module_1.weight
127 | if hasattr(module_0, "bias"):
128 | module_0.bias = module_1.bias
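
A hedged usage sketch (not part of this file): NLVR2 provides two images per sentence, and the training script is assumed to concatenate the two image batches along dim 0 so that the forward pass above can split them again with `torch.split`. All names below (`model`, `tokenizer`, `image0`, `image1`, `sentences`, `labels`, `device`) are placeholders:

    images = torch.cat([image0, image1], dim=0)                        # (2B, 3, H, W)
    text = tokenizer(sentences, padding='longest', return_tensors='pt').to(device)
    labels = labels.to(device)                                         # (B,) in {0, 1}
    loss = model(images, text, targets=labels, alpha=0.4, train=True)  # distilled CE loss
    logits = model(images, text, targets=labels, train=False)          # (B, 2) at eval time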
--------------------------------------------------------------------------------
/models/model_pretrain_nlvr.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from models.vit import VisionTransformer
3 | from models.xbert import BertConfig, BertModel
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | class ALBEF(nn.Module):
10 | def __init__(self,
11 | text_encoder = None,
12 | tokenizer = None,
13 | config = None,
14 | ):
15 | super().__init__()
16 |
17 | self.tokenizer = tokenizer
18 | vision_width = config['vision_width']
19 | embed_dim = config['embed_dim']
20 |
21 | self.visual_encoder = VisionTransformer(
22 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
23 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
24 |
25 | bert_config = BertConfig.from_json_file(config['bert_config'])
26 | bert_config.num_hidden_layers = 18
27 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
28 |
29 | #share the cross-attention layers for two images
30 | self.share_cross_attention(self.text_encoder.encoder)
31 |
32 | text_width = self.text_encoder.config.hidden_size
33 | self.vision_proj = nn.Linear(vision_width, embed_dim)
34 | self.text_proj = nn.Linear(text_width, embed_dim)
35 | self.temp = nn.Parameter(torch.ones([]) * 0.07)
36 | self.ta_head = nn.Linear(self.text_encoder.config.hidden_size, 3)
37 |
38 |
39 | def forward(self, image, text):
40 | image_embeds = self.visual_encoder(image)
41 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
42 | with torch.no_grad():
43 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
44 | sim = image_feat @ image_feat.t() / 0.07
45 | weights = F.softmax(sim,dim=1)
46 | weights.fill_diagonal_(0)
47 |
48 | image_inputs = [[],[]]
49 | labels = []
50 | for b in range(image.size(0)):
51 | if torch.rand(1)>1/3:
52 | idx = torch.multinomial(weights[b], 1).item()
53 | if torch.rand(1)>0.5:
54 | image_inputs[0].append(image_embeds[b])
55 | image_inputs[1].append(image_embeds[idx])
56 | labels.append(0)
57 | else:
58 | image_inputs[1].append(image_embeds[b])
59 | image_inputs[0].append(image_embeds[idx])
60 | labels.append(1)
61 | else:
62 | idx = torch.multinomial(weights[b], 2)
63 | image_inputs[0].append(image_embeds[idx[0]])
64 | image_inputs[1].append(image_embeds[idx[1]])
65 | labels.append(2)
66 |
67 | image_inputs[0] = torch.stack(image_inputs[0],dim=0)
68 | image_inputs[1] = torch.stack(image_inputs[1],dim=0)
69 | labels = torch.LongTensor(labels).to(image.device)
70 |
71 | output = self.text_encoder(text.input_ids,
72 | attention_mask = text.attention_mask,
73 | encoder_hidden_states = image_inputs,
74 | encoder_attention_mask = [image_atts,image_atts],
75 | return_dict = True,
76 | )
77 |
78 | pred = self.ta_head(output.last_hidden_state[:,0,:])
79 | loss = F.cross_entropy(pred, labels)
80 |
81 | return loss
82 |
83 |
84 |
85 | def share_cross_attention(self, model):
86 |
87 | for i in range(6):
88 | layer_num = 6+i*2
89 | modules_0 = model.layer[layer_num].crossattention.self._modules
90 | modules_1 = model.layer[layer_num+1].crossattention.self._modules
91 |
92 | for name in modules_0.keys():
93 | if 'key' in name or 'value' in name:
94 | module_0 = modules_0[name]
95 | module_1 = modules_1[name]
96 | if hasattr(module_0, "weight"):
97 | module_0.weight = module_1.weight
98 | if hasattr(module_0, "bias"):
99 | module_0.bias = module_1.bias
--------------------------------------------------------------------------------
/models/model_ve.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from models.vit import VisionTransformer
3 | from models.xbert import BertConfig, BertModel
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | class ALBEF(nn.Module):
10 | def __init__(self,
11 | text_encoder = None,
12 | tokenizer = None,
13 | config = None,
14 | ):
15 | super().__init__()
16 |
17 | self.tokenizer = tokenizer
18 | self.distill = config['distill']
19 |
20 | self.visual_encoder = VisionTransformer(
21 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
22 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
23 |
24 | bert_config = BertConfig.from_json_file(config['bert_config'])
25 |
26 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
27 |
28 | self.cls_head = nn.Sequential(
29 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
30 | nn.ReLU(),
31 | nn.Linear(self.text_encoder.config.hidden_size, 3)
32 | )
33 |
34 | if self.distill:
35 | self.visual_encoder_m = VisionTransformer(
36 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
37 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
38 | self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
39 | self.cls_head_m = nn.Sequential(
40 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
41 | nn.ReLU(),
42 | nn.Linear(self.text_encoder.config.hidden_size, 3)
43 | )
44 |
45 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
46 | [self.text_encoder,self.text_encoder_m],
47 | [self.cls_head,self.cls_head_m],
48 | ]
49 | self.copy_params()
50 | self.momentum = 0.995
51 |
52 |
53 | def forward(self, image, text, targets, alpha=0, train=True):
54 |
55 | image_embeds = self.visual_encoder(image)
56 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
57 |
58 | if train:
59 | output = self.text_encoder(text.input_ids,
60 | attention_mask = text.attention_mask,
61 | encoder_hidden_states = image_embeds,
62 | encoder_attention_mask = image_atts,
63 | return_dict = True
64 | )
65 | prediction = self.cls_head(output.last_hidden_state[:,0,:])
66 | if self.distill:
67 | with torch.no_grad():
68 | self._momentum_update()
69 | image_embeds_m = self.visual_encoder_m(image)
70 | output_m = self.text_encoder_m(text.input_ids,
71 | attention_mask = text.attention_mask,
72 | encoder_hidden_states = image_embeds_m,
73 | encoder_attention_mask = image_atts,
74 | return_dict = True
75 | )
76 | prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:])
77 |
78 | loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum(
79 | F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean()
80 | else:
81 | loss = F.cross_entropy(prediction, targets)
82 | return loss
83 |
84 | else:
85 | output = self.text_encoder(text.input_ids,
86 | attention_mask = text.attention_mask,
87 | encoder_hidden_states = image_embeds,
88 | encoder_attention_mask = image_atts,
89 | return_dict = True
90 | )
91 | prediction = self.cls_head(output.last_hidden_state[:,0,:])
92 | return prediction
93 |
94 |
95 |
96 | @torch.no_grad()
97 | def copy_params(self):
98 | for model_pair in self.model_pairs:
99 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
100 | param_m.data.copy_(param.data) # initialize
101 | param_m.requires_grad = False # not update by gradient
102 |
103 |
104 | @torch.no_grad()
105 | def _momentum_update(self):
106 | for model_pair in self.model_pairs:
107 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
108 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
109 |
110 |
111 |
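
For clarity, the distillation objective in the training branch above is a convex mix of the hard cross-entropy and a soft cross-entropy against the momentum model's distribution. A restatement using the same tensors (`prediction`, `prediction_m`, `targets`, `alpha` from `forward`, with `torch`/`F` as imported above):

    hard = F.cross_entropy(prediction, targets)
    soft = -torch.sum(F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), dim=1).mean()
    loss = (1 - alpha) * hard + alpha * soft   # identical to the expression in forward()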
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adamp import AdamP
2 | from .adamw import AdamW
3 | from .adafactor import Adafactor
4 | from .adahessian import Adahessian
5 | from .lookahead import Lookahead
6 | from .nadam import Nadam
7 | from .novograd import NovoGrad
8 | from .nvnovograd import NvNovoGrad
9 | from .radam import RAdam
10 | from .rmsprop_tf import RMSpropTF
11 | from .sgdp import SGDP
12 |
13 | from .optim_factory import create_optimizer
--------------------------------------------------------------------------------
/optim/adahessian.py:
--------------------------------------------------------------------------------
1 | """ AdaHessian Optimizer
2 |
3 | Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
4 | Originally licensed MIT, Copyright 2020, David Samuel
5 | """
6 | import torch
7 |
8 |
9 | class Adahessian(torch.optim.Optimizer):
10 | """
11 | Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"
12 |
13 | Arguments:
14 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15 | lr (float, optional): learning rate (default: 0.1)
16 | betas ((float, float), optional): coefficients used for computing running averages of gradient and the
17 | squared hessian trace (default: (0.9, 0.999))
18 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
20 | hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
21 | update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
22 | (to save time) (default: 1)
23 | n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
24 | """
25 |
26 | def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
27 | hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
28 | if not 0.0 <= lr:
29 | raise ValueError(f"Invalid learning rate: {lr}")
30 | if not 0.0 <= eps:
31 | raise ValueError(f"Invalid epsilon value: {eps}")
32 | if not 0.0 <= betas[0] < 1.0:
33 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
34 | if not 0.0 <= betas[1] < 1.0:
35 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
36 | if not 0.0 <= hessian_power <= 1.0:
37 | raise ValueError(f"Invalid Hessian power value: {hessian_power}")
38 |
39 | self.n_samples = n_samples
40 | self.update_each = update_each
41 | self.avg_conv_kernel = avg_conv_kernel
42 |
43 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
44 | self.seed = 2147483647
45 | self.generator = torch.Generator().manual_seed(self.seed)
46 |
47 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
48 | super(Adahessian, self).__init__(params, defaults)
49 |
50 | for p in self.get_params():
51 | p.hess = 0.0
52 | self.state[p]["hessian step"] = 0
53 |
54 | @property
55 | def is_second_order(self):
56 | return True
57 |
58 | def get_params(self):
59 | """
60 | Gets all parameters in all param_groups with gradients
61 | """
62 |
63 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
64 |
65 | def zero_hessian(self):
66 | """
67 | Zeros out the accumulated hessian traces.
68 | """
69 |
70 | for p in self.get_params():
71 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
72 | p.hess.zero_()
73 |
74 | @torch.no_grad()
75 | def set_hessian(self):
76 | """
77 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
78 | """
79 |
80 | params = []
81 | for p in filter(lambda p: p.grad is not None, self.get_params()):
82 | if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
83 | params.append(p)
84 | self.state[p]["hessian step"] += 1
85 |
86 | if len(params) == 0:
87 | return
88 |
89 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
90 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
91 |
92 | grads = [p.grad for p in params]
93 |
94 | for i in range(self.n_samples):
95 | # Rademacher distribution {-1.0, 1.0}
96 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
97 | h_zs = torch.autograd.grad(
98 | grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
99 | for h_z, z, p in zip(h_zs, zs, params):
100 | p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
101 |
102 | @torch.no_grad()
103 | def step(self, closure=None):
104 | """
105 | Performs a single optimization step.
106 | Arguments:
107 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
108 | """
109 |
110 | loss = None
111 | if closure is not None:
112 | loss = closure()
113 |
114 | self.zero_hessian()
115 | self.set_hessian()
116 |
117 | for group in self.param_groups:
118 | for p in group['params']:
119 | if p.grad is None or p.hess is None:
120 | continue
121 |
122 | if self.avg_conv_kernel and p.dim() == 4:
123 | p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
124 |
125 | # Perform correct stepweight decay as in AdamW
126 | p.mul_(1 - group['lr'] * group['weight_decay'])
127 |
128 | state = self.state[p]
129 |
130 | # State initialization
131 | if len(state) == 1:
132 | state['step'] = 0
133 | # Exponential moving average of gradient values
134 | state['exp_avg'] = torch.zeros_like(p)
135 | # Exponential moving average of Hessian diagonal square values
136 | state['exp_hessian_diag_sq'] = torch.zeros_like(p)
137 |
138 | exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
139 | beta1, beta2 = group['betas']
140 | state['step'] += 1
141 |
142 | # Decay the first and second moment running average coefficient
143 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
144 | exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
145 |
146 | bias_correction1 = 1 - beta1 ** state['step']
147 | bias_correction2 = 1 - beta2 ** state['step']
148 |
149 | k = group['hessian_power']
150 | denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
151 |
152 | # make update
153 | step_size = group['lr'] / bias_correction1
154 | p.addcdiv_(exp_avg, denom, value=-step_size)
155 |
156 | return loss
157 |
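
A minimal usage sketch (`model`, `criterion`, `inputs`, `targets` are placeholders): `set_hessian()` differentiates the existing gradients again with `torch.autograd.grad`, so the backward pass must keep its graph via `create_graph=True`; otherwise `step()` cannot form the Hutchinson products:

    optimizer = Adahessian(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward(create_graph=True)   # second-order graph needed for the Hessian trace
    optimizer.step()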
--------------------------------------------------------------------------------
/optim/adamp.py:
--------------------------------------------------------------------------------
1 | """
2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
16 | class AdamP(Optimizer):
17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
21 | super(AdamP, self).__init__(params, defaults)
22 |
23 | def _channel_view(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | def _layer_view(self, x):
27 | return x.view(1, -1)
28 |
29 | def _cosine_similarity(self, x, y, eps, view_func):
30 | x = view_func(x)
31 | y = view_func(y)
32 |
33 | x_norm = x.norm(dim=1).add_(eps)
34 | y_norm = y.norm(dim=1).add_(eps)
35 | dot = (x * y).sum(dim=1)
36 |
37 | return dot.abs() / x_norm / y_norm
38 |
39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
40 | wd = 1
41 | expand_size = [-1] + [1] * (len(p.shape) - 1)
42 | for view_func in [self._channel_view, self._layer_view]:
43 |
44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
45 |
46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
49 | wd = wd_ratio
50 |
51 | return perturb, wd
52 |
53 | return perturb, wd
54 |
55 | def step(self, closure=None):
56 | loss = None
57 | if closure is not None:
58 | loss = closure()
59 |
60 | for group in self.param_groups:
61 | for p in group['params']:
62 | if p.grad is None:
63 | continue
64 |
65 | grad = p.grad.data
66 | beta1, beta2 = group['betas']
67 | nesterov = group['nesterov']
68 |
69 | state = self.state[p]
70 |
71 | # State initialization
72 | if len(state) == 0:
73 | state['step'] = 0
74 | state['exp_avg'] = torch.zeros_like(p.data)
75 | state['exp_avg_sq'] = torch.zeros_like(p.data)
76 |
77 | # Adam
78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
79 |
80 | state['step'] += 1
81 | bias_correction1 = 1 - beta1 ** state['step']
82 | bias_correction2 = 1 - beta2 ** state['step']
83 |
84 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
86 |
87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
88 | step_size = group['lr'] / bias_correction1
89 |
90 | if nesterov:
91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
92 | else:
93 | perturb = exp_avg / denom
94 |
95 | # Projection
96 | wd_ratio = 1
97 | if len(p.shape) > 1:
98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
99 |
100 | # Weight decay
101 | if group['weight_decay'] > 0:
102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)
103 |
104 | # Step
105 | p.data.add_(-step_size, perturb)
106 |
107 | return loss
108 |
--------------------------------------------------------------------------------
/optim/adamw.py:
--------------------------------------------------------------------------------
1 | """ AdamW Optimizer
2 | Impl copied from PyTorch master
3 | """
4 | import math
5 | import torch
6 | from torch.optim.optimizer import Optimizer
7 |
8 |
9 | class AdamW(Optimizer):
10 | r"""Implements AdamW algorithm.
11 |
12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
14 |
15 | Arguments:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float, optional): learning rate (default: 1e-3)
19 | betas (Tuple[float, float], optional): coefficients used for computing
20 | running averages of gradient and its square (default: (0.9, 0.999))
21 | eps (float, optional): term added to the denominator to improve
22 | numerical stability (default: 1e-8)
23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2)
24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25 | algorithm from the paper `On the Convergence of Adam and Beyond`_
26 | (default: False)
27 |
28 | .. _Adam\: A Method for Stochastic Optimization:
29 | https://arxiv.org/abs/1412.6980
30 | .. _Decoupled Weight Decay Regularization:
31 | https://arxiv.org/abs/1711.05101
32 | .. _On the Convergence of Adam and Beyond:
33 | https://openreview.net/forum?id=ryQu7f-RZ
34 | """
35 |
36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
37 | weight_decay=1e-2, amsgrad=False):
38 | if not 0.0 <= lr:
39 | raise ValueError("Invalid learning rate: {}".format(lr))
40 | if not 0.0 <= eps:
41 | raise ValueError("Invalid epsilon value: {}".format(eps))
42 | if not 0.0 <= betas[0] < 1.0:
43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
44 | if not 0.0 <= betas[1] < 1.0:
45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
46 | defaults = dict(lr=lr, betas=betas, eps=eps,
47 | weight_decay=weight_decay, amsgrad=amsgrad)
48 | super(AdamW, self).__init__(params, defaults)
49 |
50 | def __setstate__(self, state):
51 | super(AdamW, self).__setstate__(state)
52 | for group in self.param_groups:
53 | group.setdefault('amsgrad', False)
54 |
55 | def step(self, closure=None):
56 | """Performs a single optimization step.
57 |
58 | Arguments:
59 | closure (callable, optional): A closure that reevaluates the model
60 | and returns the loss.
61 | """
62 | loss = None
63 | if closure is not None:
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | for p in group['params']:
68 | if p.grad is None:
69 | continue
70 |
71 | # Perform stepweight decay
72 | p.data.mul_(1 - group['lr'] * group['weight_decay'])
73 |
74 | # Perform optimization step
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
78 | amsgrad = group['amsgrad']
79 |
80 | state = self.state[p]
81 |
82 | # State initialization
83 | if len(state) == 0:
84 | state['step'] = 0
85 | # Exponential moving average of gradient values
86 | state['exp_avg'] = torch.zeros_like(p.data)
87 | # Exponential moving average of squared gradient values
88 | state['exp_avg_sq'] = torch.zeros_like(p.data)
89 | if amsgrad:
90 | # Maintains max of all exp. moving avg. of sq. grad. values
91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
92 |
93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
94 | if amsgrad:
95 | max_exp_avg_sq = state['max_exp_avg_sq']
96 | beta1, beta2 = group['betas']
97 |
98 | state['step'] += 1
99 | bias_correction1 = 1 - beta1 ** state['step']
100 | bias_correction2 = 1 - beta2 ** state['step']
101 |
102 | # Decay the first and second moment running average coefficient
103 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
105 | if amsgrad:
106 | # Maintains the maximum of all 2nd moment running avg. till now
107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
108 | # Use the max. for normalizing running avg. of gradient
109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
110 | else:
111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
112 |
113 | step_size = group['lr'] / bias_correction1
114 |
115 | p.data.addcdiv_(-step_size, exp_avg, denom)
116 |
117 | return loss
118 |
--------------------------------------------------------------------------------
/optim/lookahead.py:
--------------------------------------------------------------------------------
1 | """ Lookahead Optimizer Wrapper.
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | from collections import defaultdict
10 |
11 |
12 | class Lookahead(Optimizer):
13 | def __init__(self, base_optimizer, alpha=0.5, k=6):
14 | if not 0.0 <= alpha <= 1.0:
15 | raise ValueError(f'Invalid slow update rate: {alpha}')
16 | if not 1 <= k:
17 | raise ValueError(f'Invalid lookahead steps: {k}')
18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
19 | self.base_optimizer = base_optimizer
20 | self.param_groups = self.base_optimizer.param_groups
21 | self.defaults = base_optimizer.defaults
22 | self.defaults.update(defaults)
23 | self.state = defaultdict(dict)
24 | # manually add our defaults to the param groups
25 | for name, default in defaults.items():
26 | for group in self.param_groups:
27 | group.setdefault(name, default)
28 |
29 | def update_slow(self, group):
30 | for fast_p in group["params"]:
31 | if fast_p.grad is None:
32 | continue
33 | param_state = self.state[fast_p]
34 | if 'slow_buffer' not in param_state:
35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data)
36 | param_state['slow_buffer'].copy_(fast_p.data)
37 | slow = param_state['slow_buffer']
38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow)
39 | fast_p.data.copy_(slow)
40 |
41 | def sync_lookahead(self):
42 | for group in self.param_groups:
43 | self.update_slow(group)
44 |
45 | def step(self, closure=None):
46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
47 | loss = self.base_optimizer.step(closure)
48 | for group in self.param_groups:
49 | group['lookahead_step'] += 1
50 | if group['lookahead_step'] % group['lookahead_k'] == 0:
51 | self.update_slow(group)
52 | return loss
53 |
54 | def state_dict(self):
55 | fast_state_dict = self.base_optimizer.state_dict()
56 | slow_state = {
57 | (id(k) if isinstance(k, torch.Tensor) else k): v
58 | for k, v in self.state.items()
59 | }
60 | fast_state = fast_state_dict['state']
61 | param_groups = fast_state_dict['param_groups']
62 | return {
63 | 'state': fast_state,
64 | 'slow_state': slow_state,
65 | 'param_groups': param_groups,
66 | }
67 |
68 | def load_state_dict(self, state_dict):
69 | fast_state_dict = {
70 | 'state': state_dict['state'],
71 | 'param_groups': state_dict['param_groups'],
72 | }
73 | self.base_optimizer.load_state_dict(fast_state_dict)
74 |
75 | # We want to restore the slow state, but share param_groups reference
76 | # with base_optimizer. This is a bit redundant but least code
77 | slow_state_new = False
78 | if 'slow_state' not in state_dict:
79 | print('Loading state_dict from optimizer without Lookahead applied.')
80 | state_dict['slow_state'] = defaultdict(dict)
81 | slow_state_new = True
82 | slow_state_dict = {
83 | 'state': state_dict['slow_state'],
84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code
85 | }
86 | super(Lookahead, self).load_state_dict(slow_state_dict)
87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container
88 | if slow_state_new:
89 | # reapply defaults to catch missing lookahead specific ones
90 | for name, default in self.defaults.items():
91 | for group in self.param_groups:
92 | group.setdefault(name, default)
93 |
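
A hedged usage sketch (`model`, `compute_loss`, `num_steps` are placeholders): Lookahead wraps any base optimizer; the fast weights are updated every step and pulled toward the slow buffer every k-th step:

    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base, alpha=0.5, k=6)
    for step in range(num_steps):
        base.zero_grad()
        loss = compute_loss()
        loss.backward()
        optimizer.step()   # fast update; slow-buffer sync every k-th step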
--------------------------------------------------------------------------------
/optim/nadam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Optimizer
3 |
4 |
5 | class Nadam(Optimizer):
6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
7 |
8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
9 |
10 | Arguments:
11 | params (iterable): iterable of parameters to optimize or dicts defining
12 | parameter groups
13 | lr (float, optional): learning rate (default: 2e-3)
14 | betas (Tuple[float, float], optional): coefficients used for computing
15 | running averages of gradient and its square
16 | eps (float, optional): term added to the denominator to improve
17 | numerical stability (default: 1e-8)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
20 |
21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf
22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
23 |
24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408
25 | NOTE: Has potential issues but does work well on some problems.
26 | """
27 |
28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
29 | weight_decay=0, schedule_decay=4e-3):
30 | defaults = dict(lr=lr, betas=betas, eps=eps,
31 | weight_decay=weight_decay, schedule_decay=schedule_decay)
32 | super(Nadam, self).__init__(params, defaults)
33 |
34 | def step(self, closure=None):
35 | """Performs a single optimization step.
36 |
37 | Arguments:
38 | closure (callable, optional): A closure that reevaluates the model
39 | and returns the loss.
40 | """
41 | loss = None
42 | if closure is not None:
43 | loss = closure()
44 |
45 | for group in self.param_groups:
46 | for p in group['params']:
47 | if p.grad is None:
48 | continue
49 | grad = p.grad.data
50 | state = self.state[p]
51 |
52 | # State initialization
53 | if len(state) == 0:
54 | state['step'] = 0
55 | state['m_schedule'] = 1.
56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_()
57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
58 |
59 | # Warming momentum schedule
60 | m_schedule = state['m_schedule']
61 | schedule_decay = group['schedule_decay']
62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
63 | beta1, beta2 = group['betas']
64 | eps = group['eps']
65 | state['step'] += 1
66 | t = state['step']
67 |
68 | if group['weight_decay'] != 0:
69 | grad = grad.add(group['weight_decay'], p.data)
70 |
71 | momentum_cache_t = beta1 * \
72 | (1. - 0.5 * (0.96 ** (t * schedule_decay)))
73 | momentum_cache_t_1 = beta1 * \
74 | (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
75 | m_schedule_new = m_schedule * momentum_cache_t
76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
77 | state['m_schedule'] = m_schedule_new
78 |
79 | # Decay the first and second moment running average coefficient
80 | exp_avg.mul_(beta1).add_(1. - beta1, grad)
81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad)
82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
83 | denom = exp_avg_sq_prime.sqrt_().add_(eps)
84 |
85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom)
86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom)
87 |
88 | return loss
89 |
--------------------------------------------------------------------------------
/optim/novograd.py:
--------------------------------------------------------------------------------
1 | """NovoGrad Optimizer.
2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
4 | - https://arxiv.org/abs/1905.11286
5 | """
6 |
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | import math
10 |
11 |
12 | class NovoGrad(Optimizer):
13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
15 | super(NovoGrad, self).__init__(params, defaults)
16 | self._lr = lr
17 | self._beta1 = betas[0]
18 | self._beta2 = betas[1]
19 | self._eps = eps
20 | self._wd = weight_decay
21 | self._grad_averaging = grad_averaging
22 |
23 | self._momentum_initialized = False
24 |
25 | def step(self, closure=None):
26 | loss = None
27 | if closure is not None:
28 | loss = closure()
29 |
30 | if not self._momentum_initialized:
31 | for group in self.param_groups:
32 | for p in group['params']:
33 | if p.grad is None:
34 | continue
35 | state = self.state[p]
36 | grad = p.grad.data
37 | if grad.is_sparse:
38 | raise RuntimeError('NovoGrad does not support sparse gradients')
39 |
40 | v = torch.norm(grad)**2
41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
42 | state['step'] = 0
43 | state['v'] = v
44 | state['m'] = m
45 | state['grad_ema'] = None
46 | self._momentum_initialized = True
47 |
48 | for group in self.param_groups:
49 | for p in group['params']:
50 | if p.grad is None:
51 | continue
52 | state = self.state[p]
53 | state['step'] += 1
54 |
55 | step, v, m = state['step'], state['v'], state['m']
56 | grad_ema = state['grad_ema']
57 |
58 | grad = p.grad.data
59 | g2 = torch.norm(grad)**2
60 | grad_ema = g2 if grad_ema is None else grad_ema * \
61 | self._beta2 + g2 * (1. - self._beta2)
62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
63 |
64 | if self._grad_averaging:
65 | grad *= (1. - self._beta1)
66 |
67 | g2 = torch.norm(grad)**2
68 | v = self._beta2*v + (1. - self._beta2)*g2
69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
70 | bias_correction1 = 1 - self._beta1 ** step
71 | bias_correction2 = 1 - self._beta2 ** step
72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
73 |
74 | state['v'], state['m'] = v, m
75 | state['grad_ema'] = grad_ema
76 | p.data.add_(-step_size, m)
77 | return loss
78 |
--------------------------------------------------------------------------------
/optim/nvnovograd.py:
--------------------------------------------------------------------------------
1 | """ Nvidia NovoGrad Optimizer.
2 | Original impl by Nvidia from Jasper example:
3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5 | - https://arxiv.org/abs/1905.11286
6 | """
7 |
8 | import torch
9 | from torch.optim.optimizer import Optimizer
10 | import math
11 |
12 |
13 | class NvNovoGrad(Optimizer):
14 | """
15 | Implements Novograd algorithm.
16 |
17 | Args:
18 | params (iterable): iterable of parameters to optimize or dicts defining
19 | parameter groups
20 | lr (float, optional): learning rate (default: 1e-3)
21 | betas (Tuple[float, float], optional): coefficients used for computing
22 | running averages of gradient and its square (default: (0.95, 0.98))
23 | eps (float, optional): term added to the denominator to improve
24 | numerical stability (default: 1e-8)
25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
26 | grad_averaging: gradient averaging
27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
28 | algorithm from the paper `On the Convergence of Adam and Beyond`_
29 | (default: False)
30 | """
31 |
32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
33 | weight_decay=0, grad_averaging=False, amsgrad=False):
34 | if not 0.0 <= lr:
35 | raise ValueError("Invalid learning rate: {}".format(lr))
36 | if not 0.0 <= eps:
37 | raise ValueError("Invalid epsilon value: {}".format(eps))
38 | if not 0.0 <= betas[0] < 1.0:
39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
40 | if not 0.0 <= betas[1] < 1.0:
41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
42 | defaults = dict(lr=lr, betas=betas, eps=eps,
43 | weight_decay=weight_decay,
44 | grad_averaging=grad_averaging,
45 | amsgrad=amsgrad)
46 |
47 | super(NvNovoGrad, self).__init__(params, defaults)
48 |
49 | def __setstate__(self, state):
50 | super(NvNovoGrad, self).__setstate__(state)
51 | for group in self.param_groups:
52 | group.setdefault('amsgrad', False)
53 |
54 | def step(self, closure=None):
55 | """Performs a single optimization step.
56 |
57 | Arguments:
58 | closure (callable, optional): A closure that reevaluates the model
59 | and returns the loss.
60 | """
61 | loss = None
62 | if closure is not None:
63 | loss = closure()
64 |
65 | for group in self.param_groups:
66 | for p in group['params']:
67 | if p.grad is None:
68 | continue
69 | grad = p.grad.data
70 | if grad.is_sparse:
71 | raise RuntimeError('Sparse gradients are not supported.')
72 | amsgrad = group['amsgrad']
73 |
74 | state = self.state[p]
75 |
76 | # State initialization
77 | if len(state) == 0:
78 | state['step'] = 0
79 | # Exponential moving average of gradient values
80 | state['exp_avg'] = torch.zeros_like(p.data)
81 | # Exponential moving average of squared gradient values
82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
83 | if amsgrad:
84 | # Maintains max of all exp. moving avg. of sq. grad. values
85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
86 |
87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
88 | if amsgrad:
89 | max_exp_avg_sq = state['max_exp_avg_sq']
90 | beta1, beta2 = group['betas']
91 |
92 | state['step'] += 1
93 |
94 | norm = torch.sum(torch.pow(grad, 2))
95 |
96 | if exp_avg_sq == 0:
97 | exp_avg_sq.copy_(norm)
98 | else:
99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
100 |
101 | if amsgrad:
102 | # Maintains the maximum of all 2nd moment running avg. till now
103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
104 | # Use the max. for normalizing running avg. of gradient
105 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
106 | else:
107 | denom = exp_avg_sq.sqrt().add_(group['eps'])
108 |
109 | grad.div_(denom)
110 | if group['weight_decay'] != 0:
111 | grad.add_(group['weight_decay'], p.data)
112 | if group['grad_averaging']:
113 | grad.mul_(1 - beta1)
114 | exp_avg.mul_(beta1).add_(grad)
115 |
116 | p.data.add_(-group['lr'], exp_avg)
117 |
118 | return loss
119 |
--------------------------------------------------------------------------------
/optim/optim_factory.py:
--------------------------------------------------------------------------------
1 | """ Optimizer Factory w/ Custom Weight Decay
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch
5 | from torch import optim as optim
6 |
7 | from .adafactor import Adafactor
8 | from .adahessian import Adahessian
9 | from .adamp import AdamP
10 | from .lookahead import Lookahead
11 | from .nadam import Nadam
12 | from .novograd import NovoGrad
13 | from .nvnovograd import NvNovoGrad
14 | from .radam import RAdam
15 | from .rmsprop_tf import RMSpropTF
16 | from .sgdp import SGDP
17 |
18 | try:
19 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
20 | has_apex = True
21 | except ImportError:
22 | has_apex = False
23 |
24 |
25 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
26 | decay = []
27 | no_decay = []
28 | for name, param in model.named_parameters():
29 | if not param.requires_grad:
30 | continue # frozen weights
31 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
32 | no_decay.append(param)
33 | else:
34 | decay.append(param)
35 | return [
36 | {'params': no_decay, 'weight_decay': 0.},
37 | {'params': decay, 'weight_decay': weight_decay}]
38 |
39 |
40 | def create_optimizer(args, model, filter_bias_and_bn=True):
41 | opt_lower = args.opt.lower()
42 | weight_decay = args.weight_decay
43 | if weight_decay and filter_bias_and_bn:
44 | skip = {}
45 | if hasattr(model, 'no_weight_decay'):
46 | skip = model.no_weight_decay()
47 | parameters = add_weight_decay(model, weight_decay, skip)
48 | weight_decay = 0.
49 | else:
50 | parameters = model.parameters()
51 |
52 | if 'fused' in opt_lower:
53 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
54 |
55 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
56 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
57 | opt_args['eps'] = args.opt_eps
58 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
59 | opt_args['betas'] = args.opt_betas
60 | if hasattr(args, 'opt_args') and args.opt_args is not None:
61 | opt_args.update(args.opt_args)
62 |
63 | opt_split = opt_lower.split('_')
64 | opt_lower = opt_split[-1]
65 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
66 | opt_args.pop('eps', None)
67 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
68 | elif opt_lower == 'momentum':
69 | opt_args.pop('eps', None)
70 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
71 | elif opt_lower == 'adam':
72 | optimizer = optim.Adam(parameters, **opt_args)
73 | elif opt_lower == 'adamw':
74 | optimizer = optim.AdamW(parameters, **opt_args)
75 | elif opt_lower == 'nadam':
76 | optimizer = Nadam(parameters, **opt_args)
77 | elif opt_lower == 'radam':
78 | optimizer = RAdam(parameters, **opt_args)
79 | elif opt_lower == 'adamp':
80 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
81 | elif opt_lower == 'sgdp':
82 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
83 | elif opt_lower == 'adadelta':
84 | optimizer = optim.Adadelta(parameters, **opt_args)
85 | elif opt_lower == 'adafactor':
86 | if not args.lr:
87 | opt_args['lr'] = None
88 | optimizer = Adafactor(parameters, **opt_args)
89 | elif opt_lower == 'adahessian':
90 | optimizer = Adahessian(parameters, **opt_args)
91 | elif opt_lower == 'rmsprop':
92 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
93 | elif opt_lower == 'rmsproptf':
94 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
95 | elif opt_lower == 'novograd':
96 | optimizer = NovoGrad(parameters, **opt_args)
97 | elif opt_lower == 'nvnovograd':
98 | optimizer = NvNovoGrad(parameters, **opt_args)
99 | elif opt_lower == 'fusedsgd':
100 | opt_args.pop('eps', None)
101 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
102 | elif opt_lower == 'fusedmomentum':
103 | opt_args.pop('eps', None)
104 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
105 | elif opt_lower == 'fusedadam':
106 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
107 | elif opt_lower == 'fusedadamw':
108 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
109 | elif opt_lower == 'fusedlamb':
110 | optimizer = FusedLAMB(parameters, **opt_args)
111 | elif opt_lower == 'fusednovograd':
112 | opt_args.setdefault('betas', (0.95, 0.98))
113 | optimizer = FusedNovoGrad(parameters, **opt_args)
114 | else:
115 | # `assert False and "msg"` always fails without showing the message; raise a descriptive error instead
116 | raise ValueError(f"Invalid optimizer: {opt_lower}")
117 |
118 | if len(opt_split) > 1:
119 | if opt_split[0] == 'lookahead':
120 | optimizer = Lookahead(optimizer)
121 |
122 | return optimizer
123 |
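
A hedged usage sketch: `create_optimizer` reads its hyper-parameters from an argparse-style namespace, and prefixing the name with `lookahead_` wraps the result (e.g. `lookahead_adamw`). The values below are illustrative and `model` is a placeholder:

    from types import SimpleNamespace
    args = SimpleNamespace(opt='adamw', lr=1e-4, weight_decay=0.02, momentum=0.9,
                           opt_eps=None, opt_betas=None, opt_args=None)
    optimizer = create_optimizer(args, model)   # 1-d params and biases get weight_decay=0 via add_weight_decay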
--------------------------------------------------------------------------------
/optim/radam.py:
--------------------------------------------------------------------------------
1 | """RAdam Optimizer.
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
4 | """
5 | import math
6 | import torch
7 | from torch.optim.optimizer import Optimizer, required
8 |
9 |
10 | class RAdam(Optimizer):
11 |
12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
14 | self.buffer = [[None, None, None] for ind in range(10)]
15 | super(RAdam, self).__init__(params, defaults)
16 |
17 | def __setstate__(self, state):
18 | super(RAdam, self).__setstate__(state)
19 |
20 | def step(self, closure=None):
21 |
22 | loss = None
23 | if closure is not None:
24 | loss = closure()
25 |
26 | for group in self.param_groups:
27 |
28 | for p in group['params']:
29 | if p.grad is None:
30 | continue
31 | grad = p.grad.data.float()
32 | if grad.is_sparse:
33 | raise RuntimeError('RAdam does not support sparse gradients')
34 |
35 | p_data_fp32 = p.data.float()
36 |
37 | state = self.state[p]
38 |
39 | if len(state) == 0:
40 | state['step'] = 0
41 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
43 | else:
44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
46 |
47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
48 | beta1, beta2 = group['betas']
49 |
50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
51 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
52 |
53 | state['step'] += 1
54 | buffered = self.buffer[int(state['step'] % 10)]
55 | if state['step'] == buffered[0]:
56 | N_sma, step_size = buffered[1], buffered[2]
57 | else:
58 | buffered[0] = state['step']
59 | beta2_t = beta2 ** state['step']
60 | N_sma_max = 2 / (1 - beta2) - 1
61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
62 | buffered[1] = N_sma
63 |
64 | # more conservative since it's an approximated value
65 | if N_sma >= 5:
66 | step_size = group['lr'] * math.sqrt(
67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
68 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
69 | else:
70 | step_size = group['lr'] / (1 - beta1 ** state['step'])
71 | buffered[2] = step_size
72 |
73 | if group['weight_decay'] != 0:
74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
75 |
76 | # more conservative since it's an approximated value
77 | if N_sma >= 5:
78 | denom = exp_avg_sq.sqrt().add_(group['eps'])
79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
80 | else:
81 | p_data_fp32.add_(-step_size, exp_avg)
82 |
83 | p.data.copy_(p_data_fp32)
84 |
85 | return loss
86 |
87 |
88 | class PlainRAdam(Optimizer):
89 |
90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
92 |
93 | super(PlainRAdam, self).__init__(params, defaults)
94 |
95 | def __setstate__(self, state):
96 | super(PlainRAdam, self).__setstate__(state)
97 |
98 | def step(self, closure=None):
99 |
100 | loss = None
101 | if closure is not None:
102 | loss = closure()
103 |
104 | for group in self.param_groups:
105 |
106 | for p in group['params']:
107 | if p.grad is None:
108 | continue
109 | grad = p.grad.data.float()
110 | if grad.is_sparse:
111 | raise RuntimeError('RAdam does not support sparse gradients')
112 |
113 | p_data_fp32 = p.data.float()
114 |
115 | state = self.state[p]
116 |
117 | if len(state) == 0:
118 | state['step'] = 0
119 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
121 | else:
122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
124 |
125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
126 | beta1, beta2 = group['betas']
127 |
128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
129 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
130 |
131 | state['step'] += 1
132 | beta2_t = beta2 ** state['step']
133 | N_sma_max = 2 / (1 - beta2) - 1
134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
135 |
136 | if group['weight_decay'] != 0:
137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
138 |
139 | # more conservative since it's an approximated value
140 | if N_sma >= 5:
141 | step_size = group['lr'] * math.sqrt(
142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
143 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
144 | denom = exp_avg_sq.sqrt().add_(group['eps'])
145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
146 | else:
147 | step_size = group['lr'] / (1 - beta1 ** state['step'])
148 | p_data_fp32.add_(-step_size, exp_avg)
149 |
150 | p.data.copy_(p_data_fp32)
151 |
152 | return loss
153 |
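
For reference, the `N_sma >= 5` branch above implements RAdam's variance-rectified step size; written out with $\rho$ for `N_sma` (a restatement of the code, not new behaviour):

    \rho_\infty = \frac{2}{1-\beta_2} - 1, \qquad
    \rho_t = \rho_\infty - \frac{2\, t\, \beta_2^{t}}{1-\beta_2^{t}}, \qquad
    \text{step\_size} = \frac{\eta}{1-\beta_1^{t}}
      \sqrt{(1-\beta_2^{t})\,\frac{(\rho_t-4)(\rho_t-2)\,\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\,\rho_t}}
      \quad \text{when } \rho_t \ge 5 .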
--------------------------------------------------------------------------------
/optim/rmsprop_tf.py:
--------------------------------------------------------------------------------
1 | """ RMSProp modified to behave like Tensorflow impl
2 |
3 | Originally cut & paste from PyTorch RMSProp
4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
6 |
7 | Modifications Copyright 2020 Ross Wightman
8 | """
9 |
10 | import torch
11 | from torch.optim import Optimizer
12 |
13 |
14 | class RMSpropTF(Optimizer):
15 | """Implements RMSprop algorithm (TensorFlow style epsilon)
16 |
17 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
18 | and a few other modifications to closer match Tensorflow for matching hyper-params.
19 |
20 | Noteworthy changes include:
21 | 1. Epsilon applied inside square-root
22 | 2. square_avg initialized to ones
23 | 3. LR scaling of update accumulated in momentum buffer
24 |
25 | Proposed by G. Hinton in his
26 | `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
27 |
28 | The centered version first appears in `Generating Sequences
29 | With Recurrent Neural Networks <https://arxiv.org/abs/1308.0850>`_.
30 |
31 | Arguments:
32 | params (iterable): iterable of parameters to optimize or dicts defining
33 | parameter groups
34 | lr (float, optional): learning rate (default: 1e-2)
35 | momentum (float, optional): momentum factor (default: 0)
36 | alpha (float, optional): smoothing (decay) constant (default: 0.9)
37 | eps (float, optional): term added to the denominator to improve
38 | numerical stability (default: 1e-10)
39 | centered (bool, optional) : if ``True``, compute the centered RMSProp,
40 | the gradient is normalized by an estimation of its variance
41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
42 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
43 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
44 | update as per defaults in Tensorflow
45 |
46 | """
47 |
48 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
49 | decoupled_decay=False, lr_in_momentum=True):
50 | if not 0.0 <= lr:
51 | raise ValueError("Invalid learning rate: {}".format(lr))
52 | if not 0.0 <= eps:
53 | raise ValueError("Invalid epsilon value: {}".format(eps))
54 | if not 0.0 <= momentum:
55 | raise ValueError("Invalid momentum value: {}".format(momentum))
56 | if not 0.0 <= weight_decay:
57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
58 | if not 0.0 <= alpha:
59 | raise ValueError("Invalid alpha value: {}".format(alpha))
60 |
61 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
62 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
63 | super(RMSpropTF, self).__init__(params, defaults)
64 |
65 | def __setstate__(self, state):
66 | super(RMSpropTF, self).__setstate__(state)
67 | for group in self.param_groups:
68 | group.setdefault('momentum', 0)
69 | group.setdefault('centered', False)
70 |
71 | def step(self, closure=None):
72 | """Performs a single optimization step.
73 |
74 | Arguments:
75 | closure (callable, optional): A closure that reevaluates the model
76 | and returns the loss.
77 | """
78 | loss = None
79 | if closure is not None:
80 | loss = closure()
81 |
82 | for group in self.param_groups:
83 | for p in group['params']:
84 | if p.grad is None:
85 | continue
86 | grad = p.grad.data
87 | if grad.is_sparse:
88 | raise RuntimeError('RMSprop does not support sparse gradients')
89 | state = self.state[p]
90 |
91 | # State initialization
92 | if len(state) == 0:
93 | state['step'] = 0
94 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero
95 | if group['momentum'] > 0:
96 | state['momentum_buffer'] = torch.zeros_like(p.data)
97 | if group['centered']:
98 | state['grad_avg'] = torch.zeros_like(p.data)
99 |
100 | square_avg = state['square_avg']
101 | one_minus_alpha = 1. - group['alpha']
102 |
103 | state['step'] += 1
104 |
105 | if group['weight_decay'] != 0:
106 | if 'decoupled_decay' in group and group['decoupled_decay']:
107 | p.data.add_(-group['weight_decay'], p.data)
108 | else:
109 | grad = grad.add(group['weight_decay'], p.data)
110 |
111 | # Tensorflow order of ops for updating squared avg
112 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
113 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
114 |
115 | if group['centered']:
116 | grad_avg = state['grad_avg']
117 | grad_avg.add_(one_minus_alpha, grad - grad_avg)
118 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
119 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt
120 | else:
121 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt
122 |
123 | if group['momentum'] > 0:
124 | buf = state['momentum_buffer']
125 | # Tensorflow accumulates the LR scaling in the momentum buffer
126 | if 'lr_in_momentum' in group and group['lr_in_momentum']:
127 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
128 | p.data.add_(-buf)
129 | else:
130 | # PyTorch scales the param update by LR
131 | buf.mul_(group['momentum']).addcdiv_(grad, avg)
132 | p.data.add_(-group['lr'], buf)
133 | else:
134 | p.data.addcdiv_(-group['lr'], grad, avg)
135 |
136 | return loss
137 |
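A minimal usage sketch for the optimizer above, assuming RMSpropTF can be imported from optim.rmsprop_tf as the repo layout suggests; it only exercises the constructor shown above and a single step:

import torch
import torch.nn as nn
from optim.rmsprop_tf import RMSpropTF  # assumed import path based on the repo layout

model = nn.Linear(10, 2)
# defaults follow TensorFlow: eps inside the sqrt, square_avg initialized to ones,
# and the learning rate folded into the momentum buffer (lr_in_momentum=True)
optimizer = RMSpropTF(model.parameters(), lr=1e-2, momentum=0.9, eps=1e-10)

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()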
--------------------------------------------------------------------------------
/optim/sgdp.py:
--------------------------------------------------------------------------------
1 | """
2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
16 | class SGDP(Optimizer):
17 | def __init__(self, params, lr=required, momentum=0, dampening=0,
18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
21 | super(SGDP, self).__init__(params, defaults)
22 |
23 | def _channel_view(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | def _layer_view(self, x):
27 | return x.view(1, -1)
28 |
29 | def _cosine_similarity(self, x, y, eps, view_func):
30 | x = view_func(x)
31 | y = view_func(y)
32 |
33 | x_norm = x.norm(dim=1).add_(eps)
34 | y_norm = y.norm(dim=1).add_(eps)
35 | dot = (x * y).sum(dim=1)
36 |
37 | return dot.abs() / x_norm / y_norm
38 |
39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
40 | wd = 1
41 | expand_size = [-1] + [1] * (len(p.shape) - 1)
42 | for view_func in [self._channel_view, self._layer_view]:
43 |
44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
45 |
46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
49 | wd = wd_ratio
50 |
51 | return perturb, wd
52 |
53 | return perturb, wd
54 |
55 | def step(self, closure=None):
56 | loss = None
57 | if closure is not None:
58 | loss = closure()
59 |
60 | for group in self.param_groups:
61 | weight_decay = group['weight_decay']
62 | momentum = group['momentum']
63 | dampening = group['dampening']
64 | nesterov = group['nesterov']
65 |
66 | for p in group['params']:
67 | if p.grad is None:
68 | continue
69 | grad = p.grad.data
70 | state = self.state[p]
71 |
72 | # State initialization
73 | if len(state) == 0:
74 | state['momentum'] = torch.zeros_like(p.data)
75 |
76 | # SGD
77 | buf = state['momentum']
78 | buf.mul_(momentum).add_(1 - dampening, grad)
79 | if nesterov:
80 | d_p = grad + momentum * buf
81 | else:
82 | d_p = buf
83 |
84 | # Projection
85 | wd_ratio = 1
86 | if len(p.shape) > 1:
87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
88 |
89 | # Weight decay
90 | if weight_decay != 0:
91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
92 |
93 | # Step
94 | p.data.add_(-group['lr'], d_p)
95 |
96 | return loss
97 |
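A small usage sketch, assuming SGDP is importable from optim.sgdp as in the layout above; delta and wd_ratio control the projection that slows weight-norm growth for multi-dimensional parameters:

import torch
import torch.nn as nn
from optim.sgdp import SGDP  # assumed import path

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 5))
optimizer = SGDP(model.parameters(), lr=0.1, momentum=0.9, nesterov=True,
                 weight_decay=1e-4, delta=0.1, wd_ratio=0.1)

images, targets = torch.randn(2, 3, 32, 32), torch.randint(0, 5, (2,))
loss = nn.functional.cross_entropy(model(images), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # projection applies only to parameters with more than one dimension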
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import re
2 | import tempfile
3 | from functools import partial
4 | import cv2
5 | from PIL import Image
6 | import numpy as np
7 | from cog import BasePredictor, Path, Input
8 |
9 | from skimage import transform as skimage_transform
10 | from scipy.ndimage import filters
11 | from matplotlib import pyplot as plt
12 |
13 | import torch
14 | from torch import nn
15 | from torchvision import transforms
16 |
17 | from models.vit import VisionTransformer
18 | from models.xbert import BertConfig, BertModel
19 | from models.tokenization_bert import BertTokenizer
20 |
21 |
22 | class Predictor(BasePredictor):
23 | def setup(self):
24 | normalize = transforms.Normalize(
25 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
26 | )
27 |
28 | self.transform = transforms.Compose(
29 | [
30 | transforms.Resize((384, 384), interpolation=Image.BICUBIC),
31 | transforms.ToTensor(),
32 | normalize,
33 | ]
34 | )
35 |
36 | self.tokenizer = BertTokenizer.from_pretrained("bert/bert-base-uncased")
37 |
38 | bert_config_path = "configs/config_bert.json"
39 | self.model = VL_Transformer_ITM(
40 | text_encoder="bert/bert-base-uncased", config_bert=bert_config_path
41 | )
42 |
43 | checkpoint = torch.load("refcoco.pth", map_location="cpu")
44 | msg = self.model.load_state_dict(checkpoint, strict=False)
45 | self.model.eval()
46 |
47 | self.block_num = 8
48 | self.model.text_encoder.base_model.base_model.encoder.layer[
49 | self.block_num
50 | ].crossattention.self.save_attention = True
51 |
52 | self.model.cuda()
53 |
54 | def predict(
55 | self,
56 | image: Path = Input(description="Input image."),
57 | caption: str = Input(
58 | description="Caption for the image. Grad-CAM visualization will be generated "
59 |         "for each word in the caption."
60 | ),
61 | ) -> Path:
62 |
63 | image_pil = Image.open(str(image)).convert("RGB")
64 | img = self.transform(image_pil).unsqueeze(0)
65 |
66 | text = pre_caption(caption)
67 | text_input = self.tokenizer(text, return_tensors="pt")
68 |
69 | img = img.cuda()
70 | text_input = text_input.to(img.device)
71 |
72 | # Compute GradCAM
73 | output = self.model(img, text_input)
74 | loss = output[:, 1].sum()
75 |
76 | self.model.zero_grad()
77 | loss.backward()
78 |
79 | with torch.no_grad():
80 | mask = text_input.attention_mask.view(
81 | text_input.attention_mask.size(0), 1, -1, 1, 1
82 | )
83 |
84 | grads = self.model.text_encoder.base_model.base_model.encoder.layer[
85 | self.block_num
86 | ].crossattention.self.get_attn_gradients()
87 | cams = self.model.text_encoder.base_model.base_model.encoder.layer[
88 | self.block_num
89 | ].crossattention.self.get_attention_map()
90 |
91 | cams = cams[:, :, :, 1:].reshape(img.size(0), 12, -1, 24, 24) * mask
92 | grads = (
93 | grads[:, :, :, 1:].clamp(0).reshape(img.size(0), 12, -1, 24, 24) * mask
94 | )
95 |
96 | gradcam = cams * grads
97 | gradcam = gradcam[0].mean(0).cpu().detach()
98 |
99 | num_image = len(text_input.input_ids[0])
100 | fig, ax = plt.subplots(num_image, 1, figsize=(20, 8 * num_image))
101 |
102 | rgb_image = cv2.imread(str(image))[:, :, ::-1]
103 | rgb_image = np.float32(rgb_image) / 255
104 |
105 | ax[0].imshow(rgb_image)
106 | ax[0].set_yticks([])
107 | ax[0].set_xticks([])
108 | ax[0].set_xlabel("Image")
109 |
110 | for i, token_id in enumerate(text_input.input_ids[0][1:]):
111 | word = self.tokenizer.decode([token_id])
112 | gradcam_image = getAttMap(rgb_image, gradcam[i + 1])
113 | ax[i + 1].imshow(gradcam_image)
114 | ax[i + 1].set_yticks([])
115 | ax[i + 1].set_xticks([])
116 | ax[i + 1].set_xlabel(word)
117 |
118 | out_path = Path(tempfile.mkdtemp()) / "output.png"
119 | fig.savefig(str(out_path))
120 | return out_path
121 |
122 |
123 | class VL_Transformer_ITM(nn.Module):
124 | def __init__(self, text_encoder=None, config_bert=""):
125 | super().__init__()
126 |
127 | bert_config = BertConfig.from_json_file(config_bert)
128 |
129 | self.visual_encoder = VisionTransformer(
130 | img_size=384,
131 | patch_size=16,
132 | embed_dim=768,
133 | depth=12,
134 | num_heads=12,
135 | mlp_ratio=4,
136 | qkv_bias=True,
137 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
138 | )
139 |
140 | self.text_encoder = BertModel.from_pretrained(
141 | text_encoder, config=bert_config, add_pooling_layer=False
142 | )
143 |
144 | self.itm_head = nn.Linear(768, 2)
145 |
146 | def forward(self, image, text):
147 | image_embeds = self.visual_encoder(image)
148 |
149 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
150 | image.device
151 | )
152 |
153 | output = self.text_encoder(
154 | text.input_ids,
155 | attention_mask=text.attention_mask,
156 | encoder_hidden_states=image_embeds,
157 | encoder_attention_mask=image_atts,
158 | return_dict=True,
159 | )
160 |
161 | vl_embeddings = output.last_hidden_state[:, 0, :]
162 | vl_output = self.itm_head(vl_embeddings)
163 | return vl_output
164 |
165 |
166 | def pre_caption(caption, max_words=30):
167 | caption = (
168 | re.sub(
169 | r"([,.'!?\"()*#:;~])",
170 | "",
171 | caption.lower(),
172 | )
173 | .replace("-", " ")
174 | .replace("/", " ")
175 | )
176 |
177 | caption = re.sub(
178 | r"\s{2,}",
179 | " ",
180 | caption,
181 | )
182 | caption = caption.rstrip("\n")
183 | caption = caption.strip(" ")
184 |
185 | # truncate caption
186 | caption_words = caption.split(" ")
187 | if len(caption_words) > max_words:
188 | caption = " ".join(caption_words[:max_words])
189 | return caption
190 |
191 |
192 | def getAttMap(img, attMap, blur=True, overlap=True):
193 | attMap -= attMap.min()
194 | if attMap.max() > 0:
195 | attMap /= attMap.max()
196 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
197 | if blur:
198 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
199 | attMap -= attMap.min()
200 | attMap /= attMap.max()
201 | cmap = plt.get_cmap("jet")
202 | attMapV = cmap(attMap)
203 | attMapV = np.delete(attMapV, 3, 2)
204 | if overlap:
205 | attMap = (
206 | 1 * (1 - attMap ** 0.7).reshape(attMap.shape + (1,)) * img
207 | + (attMap ** 0.7).reshape(attMap.shape + (1,)) * attMapV
208 | )
209 | return attMap
210 |
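A quick illustration of the caption normalisation done by pre_caption above (hypothetical captions; importing predict requires cog, cv2 and the other module-level dependencies to be installed):

from predict import pre_caption

print(pre_caption("A man, riding a RED horse!"))  # -> "a man riding a red horse"
print(pre_caption("left-hand / bottom shelf"))    # -> "left hand bottom shelf"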
--------------------------------------------------------------------------------
/refTools/__pycache__/refer_python3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/__pycache__/refer_python3.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/__pycache__/refer_python3.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/__pycache__/refer_python3.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'licheng'
2 |
3 |
4 |
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/refEvaluation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/refEvaluation.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/refEvaluation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/__pycache__/refEvaluation.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from refTools.evaluation.bleu.bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 |         # by default, compute BLEU score up to 4-grams
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | #score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | # return (bleu, bleu_info)
44 | return score, scores
45 |
46 | def method(self):
47 | return "Bleu"
48 |
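A hedged usage sketch for the wrapper above (import path assumed from the repo layout): keys are arbitrary image ids, res maps each id to exactly one candidate sentence and gts to one or more references:

from refTools.evaluation.bleu.bleu import Bleu

gts = {"0": ["a dog runs across the grass", "the dog is running on grass"]}
res = {"0": ["a dog running on the grass"]}
score, scores = Bleu(4).compute_score(gts, res)
# score is a list [Bleu_1, ..., Bleu_4]; scores holds the per-image values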
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/bleu.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/bleu.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/bleu_scorer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/bleu/bleu_scorer.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from refTools.evaluation.cider.cider_scorer import CiderScorer
11 | import pdb
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 |
24 | def compute_score(self, gts, res):
25 | """
26 | Main function to compute CIDEr score
27 |         :param gts (dict) : dictionary with image id as key and a list of reference sentences as value
28 |                res (dict) : dictionary with image id as key and a single-element list holding the candidate sentence as value
29 | :return: cider (float) : computed CIDEr score for the corpus
30 | """
31 |
32 | assert(gts.keys() == res.keys())
33 | imgIds = gts.keys()
34 |
35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
36 |
37 | for id in imgIds:
38 | hypo = res[id]
39 | ref = gts[id]
40 |
41 | # Sanity check.
42 | assert(type(hypo) is list)
43 | assert(len(hypo) == 1)
44 | assert(type(ref) is list)
45 | assert(len(ref) > 0)
46 |
47 | cider_scorer += (hypo[0], ref)
48 |
49 | (score, scores) = cider_scorer.compute_score()
50 |
51 | return score, scores
52 |
53 | def method(self):
54 | return "CIDEr"
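Usage mirrors the BLEU wrapper; a minimal sketch with an assumed import path. Note that the idf statistics are computed from the references passed in, so scores are only meaningful over a reasonably sized corpus:

from refTools.evaluation.cider.cider import Cider

gts = {"0": ["a dog runs across the grass", "the dog is running on grass"],
       "1": ["two people ride horses on the beach"]}
res = {"0": ["a dog running on the grass"],
       "1": ["people riding horses on a beach"]}
corpus_score, per_image_scores = Cider().compute_score(gts, res)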
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/cider.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 |     :return: term frequency vector for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 |     for k in range(1,n+1):
23 |         for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that BLEU
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that BLEU needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 |         Compute document frequency over the reference data.
96 |         This will later be used to compute idf (inverse document frequency).
97 |         The document frequency is stored in the object
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 |             for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngram to vector of tfidf weights.
110 |                 The function returns vec, an array of dictionaries that map each n-gram to its tf-idf weight.
111 |                 The n-th entry of the array holds the weights for (n+1)-grams.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 |             for (ngram,term_freq) in cnts.items():
119 | # give word count 1 if it doesn't appear in reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 |             # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 |                 for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider_scorer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/cider/cider_scorer.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/meteor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/meteor.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/meteor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/__pycache__/meteor.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/meteor-1.5.jar
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import subprocess
9 | import threading
10 |
11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
12 | METEOR_JAR = 'meteor-1.5.jar'
13 | # print METEOR_JAR
14 |
15 | class Meteor:
16 |
17 | def __init__(self):
18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
19 | '-', '-', '-stdio', '-l', 'en', '-norm']
20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
21 | cwd=os.path.dirname(os.path.abspath(__file__)), \
22 | stdin=subprocess.PIPE, \
23 | stdout=subprocess.PIPE, \
24 | stderr=subprocess.PIPE)
25 | # Used to guarantee thread safety
26 | self.lock = threading.Lock()
27 |
28 | def compute_score(self, gts, res):
29 | assert(gts.keys() == res.keys())
30 | imgIds = gts.keys()
31 | scores = []
32 |
33 | eval_line = 'EVAL'
34 | self.lock.acquire()
35 | for i in imgIds:
36 | assert(len(res[i]) == 1)
37 | stat = self._stat(res[i][0], gts[i])
38 | eval_line += ' ||| {}'.format(stat)
39 |
40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
41 | for i in range(0,len(imgIds)):
42 |             scores.append(float(self.meteor_p.stdout.readline().decode().strip()))
43 |         score = float(self.meteor_p.stdout.readline().decode().strip())
44 | self.lock.release()
45 |
46 | return score, scores
47 |
48 | def method(self):
49 | return "METEOR"
50 |
51 | def _stat(self, hypothesis_str, reference_list):
52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
53 |         hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
55 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
56 | return self.meteor_p.stdout.readline().decode().strip()
57 |
58 | def _score(self, hypothesis_str, reference_list):
59 | self.lock.acquire()
60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
61 |         hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
62 |         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
63 |         self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
64 |         stats = self.meteor_p.stdout.readline().decode().strip()
65 |         eval_line = 'EVAL ||| {}'.format(stats)
66 |         # EVAL ||| stats
67 |         self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
68 |         score = float(self.meteor_p.stdout.readline().decode().strip())
69 | self.lock.release()
70 | return score
71 |
72 | def __exit__(self):
73 | self.lock.acquire()
74 | self.meteor_p.stdin.close()
75 | self.meteor_p.wait()
76 | self.lock.release()
77 |
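The wrapper above shells out to meteor-1.5.jar over stdin/stdout; a usage sketch under the assumption that java is on PATH and the jar sits next to meteor.py as noted in the file:

from refTools.evaluation.meteor.meteor import Meteor

gts = {"0": ["a dog runs across the grass", "the dog is running on grass"]}
res = {"0": ["a dog running on the grass"]}
score, scores = Meteor().compute_score(gts, res)  # corpus METEOR and per-image scores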
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/meteor/meteor.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/readme.txt:
--------------------------------------------------------------------------------
1 | This folder contains a modified version of the coco-caption evaluation code, downloaded from https://github.com/tylin/coco-caption.git,
2 | and refEvaluation, which is called by the refer algorithm.
3 |
4 | More specifically, this folder contains:
5 | 1. bleu/
6 | 2. cider/
7 | 3. meteor/
8 | 4. rouge/
9 | 5. tokenizer/
10 | 6. __init__.py
11 | 7. refEvaluation.py
12 |
--------------------------------------------------------------------------------
/refTools/evaluation/refEvaluation.py:
--------------------------------------------------------------------------------
1 | from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer
2 | from refTools.evaluation.bleu.bleu import Bleu
3 | from refTools.evaluation.meteor.meteor import Meteor
4 | from refTools.evaluation.rouge.rouge import Rouge
5 | from refTools.evaluation.cider.cider import Cider
6 |
7 | """
8 | Input: refer and Res = [{ref_id, sent}]
9 |
10 | Things of interest
11 | evalRefs - list of ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']
12 | eval - dict of {metric: score}
13 | refToEval - dict of {ref_id: ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']}
14 | """
15 |
16 | class RefEvaluation:
17 | def __init__ (self, refer, Res):
18 | """
19 | :param refer: refer class of current dataset
20 | :param Res: [{'ref_id', 'sent'}]
21 | """
22 | self.evalRefs = []
23 | self.eval = {}
24 | self.refToEval = {}
25 | self.refer = refer
26 | self.Res = Res
27 |
28 | def evaluate(self):
29 |
30 | evalRefIds = [ann['ref_id'] for ann in self.Res]
31 |
32 | refToGts = {}
33 | for ref_id in evalRefIds:
34 | ref = self.refer.Refs[ref_id]
35 | gt_sents = [sent['sent'].encode('ascii', 'ignore').decode('ascii') for sent in ref['sentences']] # up to 3 expressions
36 | refToGts[ref_id] = gt_sents
37 | refToRes = {ann['ref_id']: [ann['sent']] for ann in self.Res}
38 |
39 | print('tokenization...')
40 | tokenizer = PTBTokenizer()
41 | self.refToRes = tokenizer.tokenize(refToRes)
42 | self.refToGts = tokenizer.tokenize(refToGts)
43 |
44 | # =================================================
45 | # Set up scorers
46 | # =================================================
47 | print('setting up scorers...')
48 | scorers = [
49 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
50 | (Meteor(),"METEOR"),
51 | (Rouge(), "ROUGE_L"),
52 | (Cider(), "CIDEr")
53 | ]
54 |
55 | # =================================================
56 | # Compute scores
57 | # =================================================
58 | for scorer, method in scorers:
59 | print('computing %s score...'%(scorer.method()))
60 | score, scores = scorer.compute_score(self.refToGts, self.refToRes)
61 | if type(method) == list:
62 | for sc, scs, m in zip(score, scores, method):
63 | self.setEval(sc, m)
64 | self.setRefToEvalRefs(scs, self.refToGts.keys(), m)
65 | print("%s: %0.3f"%(m, sc))
66 | else:
67 | self.setEval(score, method)
68 | self.setRefToEvalRefs(scores, self.refToGts.keys(), method)
69 | print("%s: %0.3f"%(method, score))
70 | self.setEvalRefs()
71 |
72 | def setEval(self, score, method):
73 | self.eval[method] = score
74 |
75 | def setRefToEvalRefs(self, scores, refIds, method):
76 | for refId, score in zip(refIds, scores):
77 | if not refId in self.refToEval:
78 | self.refToEval[refId] = {}
79 | self.refToEval[refId]["ref_id"] = refId
80 | self.refToEval[refId][method] = score
81 |
82 | def setEvalRefs(self):
83 | self.evalRefs = [eval for refId, eval in self.refToEval.items()]
84 |
85 |
86 | if __name__ == '__main__':
87 |
88 | import os.path as osp
89 | import sys
90 | ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
91 | sys.path.insert(0, osp.join(ROOT_DIR, 'lib', 'datasets'))
92 | from refer import REFER
93 |
94 | # load refer of dataset
95 | dataset = 'refcoco'
96 | refer = REFER(dataset, splitBy = 'google')
97 |
98 | # mimic some Res
99 | val_refIds = refer.getRefIds(split='test')
100 | ref_id = 49767
101 | print("GD: %s" % refer.Refs[ref_id]['sentences'])
102 | Res = [{'ref_id': ref_id, 'sent': 'left bottle'}]
103 |
104 | # evaluate some refer expressions
105 | refEval = RefEvaluation(refer, Res)
106 | refEval.evaluate()
107 |
108 | # print output evaluation scores
109 | for metric, score in refEval.eval.items():
110 | print('%s: %.3f'%(metric, score))
111 |
112 | # demo how to use evalImgs to retrieve low score result
113 | # evals = [eva for eva in refEval.evalRefs if eva['CIDEr']<30]
114 | # print 'ground truth sents'
115 | # refId = evals[0]['ref_id']
116 | # print 'refId: %s' % refId
117 | # print [sent['sent'] for sent in refer.Refs[refId]['sentences']]
118 | #
119 | # print 'generated sent (CIDEr score %0.1f)' % (evals[0]['CIDEr'])
120 |
121 | # print refEval.refToEval[8]
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/refTools/evaluation/refEvaluation.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/refEvaluation.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/rouge.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/rouge.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/rouge.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/__pycache__/rouge.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 |     :returns: length (int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 |         # vrama91: updated the value below based on discussion with Hovy
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: str : candidate sentence to be evaluated
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 |         :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by evaluate_captions.py
81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 | score.append(self.calc_score(hypo, ref))
94 |
95 | # Sanity check.
96 | assert(type(hypo) is list)
97 | assert(len(hypo) == 1)
98 | assert(type(ref) is list)
99 | assert(len(ref) > 0)
100 |
101 | average_score = np.mean(np.array(score))
102 | return average_score, np.array(score)
103 |
104 | def method(self):
105 | return "Rouge"
106 |
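A small worked sketch for the ROUGE-L scorer above (assumed import path). With beta = 1.2 the single-reference example below has LCS "cat on the mat" (length 4), precision 4/5, recall 4/6, and an F-score of roughly 0.72:

from refTools.evaluation.rouge.rouge import Rouge

candidate = ["a cat on the mat"]          # exactly one tokenized candidate sentence
references = ["the cat is on the mat"]    # one or more tokenized references
print(Rouge().calc_score(candidate, references))  # ~0.72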
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/rouge.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/rouge/rouge.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'hfang'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import sys
13 | import subprocess
14 | import tempfile
15 | import itertools
16 |
17 | # path to the stanford corenlp jar
18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
19 |
20 | # punctuations to be removed from the sentences
21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
23 |
24 | class PTBTokenizer:
25 | """Python wrapper of Stanford PTBTokenizer"""
26 |
27 | def tokenize(self, captions_for_image):
28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
29 | 'edu.stanford.nlp.process.PTBTokenizer', \
30 | '-preserveLines', '-lowerCase']
31 |
32 | # ======================================================
33 | # prepare data for PTB Tokenizer
34 | # ======================================================
35 | final_tokenized_captions_for_image = {}
36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
38 |
39 | # ======================================================
40 | # save sentences to temporary file
41 | # ======================================================
42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
44 | tmp_file.write(sentences.encode())
45 | tmp_file.close()
46 |
47 | # ======================================================
48 | # tokenize sentence
49 | # ======================================================
50 | cmd.append(os.path.basename(tmp_file.name))
51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
52 | stdout=subprocess.PIPE)
53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
54 | token_lines = token_lines.decode()
55 | lines = token_lines.split('\n')
56 | # remove temp file
57 | os.remove(tmp_file.name)
58 |
59 | # ======================================================
60 | # create dictionary for tokenized captions
61 | # ======================================================
62 | for k, line in zip(image_id, lines):
63 | if not k in final_tokenized_captions_for_image:
64 | final_tokenized_captions_for_image[k] = []
65 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
66 | if w not in PUNCTUATIONS])
67 | final_tokenized_captions_for_image[k].append(tokenized_caption)
68 |
69 | return final_tokenized_captions_for_image
70 |
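A hedged usage sketch for the tokenizer wrapper above; it requires java on PATH and stanford-corenlp-3.4.1.jar in the same directory, and returns, for each key, a list of lowercased captions with the punctuation listed above stripped:

from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer

captions = {"0": ["A man, riding a horse!", "Someone rides a brown horse."]}
tokenized = PTBTokenizer().tokenize(captions)
# e.g. {"0": ["a man riding a horse", "someone rides a brown horse"]}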
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/ptbtokenizer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/ptbtokenizer.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/tmp82iqkuu0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/tmp82iqkuu0
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/tmpn19wmqte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/refTools/evaluation/tokenizer/tmpn19wmqte
--------------------------------------------------------------------------------
/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_lr import CosineLRScheduler
2 | from .plateau_lr import PlateauLRScheduler
3 | from .step_lr import StepLRScheduler
4 | from .tanh_lr import TanhLRScheduler
5 | from .scheduler_factory import create_scheduler
6 |
--------------------------------------------------------------------------------
/scheduler/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/cosine_lr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/cosine_lr.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/cosine_lr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/cosine_lr.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/plateau_lr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/plateau_lr.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/plateau_lr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/plateau_lr.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/scheduler_factory.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler_factory.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/scheduler_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/scheduler_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/step_lr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/step_lr.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/step_lr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/step_lr.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/tanh_lr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/tanh_lr.cpython-36.pyc
--------------------------------------------------------------------------------
/scheduler/__pycache__/tanh_lr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/scheduler/__pycache__/tanh_lr.cpython-38.pyc
--------------------------------------------------------------------------------
/scheduler/cosine_lr.py:
--------------------------------------------------------------------------------
1 | """ Cosine Scheduler
2 |
3 | Cosine LR schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 |
12 | from .scheduler import Scheduler
13 |
14 | from pdb import set_trace as breakpoint  # debugging alias; not used anywhere in this module
15 |
16 | _logger = logging.getLogger(__name__)
17 |
18 |
19 | class CosineLRScheduler(Scheduler):
20 | """
21 | Cosine decay with restarts.
22 | This is described in the paper https://arxiv.org/abs/1608.03983.
23 |
24 | Inspiration from
25 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
26 | """
27 |
28 | def __init__(self,
29 | optimizer: torch.optim.Optimizer,
30 | t_initial: int,
31 | t_mul: float = 1.,
32 | lr_min: float = 0.,
33 | decay_rate: float = 1.,
34 | warmup_t=0,
35 | warmup_lr_init=0,
36 | warmup_prefix=True,
37 | cycle_limit=0,
38 | t_in_epochs=True,
39 | noise_range_t=None,
40 | noise_pct=0.67,
41 | noise_std=1.0,
42 | noise_seed=42,
43 | initialize=True) -> None:
44 | super().__init__(
45 | optimizer, param_group_field="lr",
46 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
47 | initialize=initialize)
48 |
49 | assert t_initial > 0
50 | assert lr_min >= 0
51 | if t_initial == 1 and t_mul == 1 and decay_rate == 1:
52 | _logger.warning("Cosine annealing scheduler will have no effect on the learning "
53 |                            "rate since t_initial = t_mul = decay_rate = 1.")
54 | self.t_initial = t_initial
55 | self.t_mul = t_mul
56 | self.lr_min = lr_min
57 | self.decay_rate = decay_rate
58 | self.cycle_limit = cycle_limit
59 | self.warmup_t = warmup_t
60 | self.warmup_lr_init = warmup_lr_init
61 | self.warmup_prefix = warmup_prefix
62 | self.t_in_epochs = t_in_epochs
63 | if self.warmup_t:
64 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
65 | super().update_groups(self.warmup_lr_init)
66 | else:
67 | self.warmup_steps = [1 for _ in self.base_values]
68 |
69 | def _get_lr(self, t):
70 | if t < self.warmup_t:
71 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
72 | else:
73 | if self.warmup_prefix:
74 | t = t - self.warmup_t
75 |
76 | if self.t_mul != 1:
77 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
78 | t_i = self.t_mul ** i * self.t_initial
79 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
80 | else:
81 | i = t // self.t_initial
82 | t_i = self.t_initial
83 | t_curr = t - (self.t_initial * i)
84 |
85 | gamma = self.decay_rate ** i
86 | lr_min = self.lr_min * gamma
87 | lr_max_values = [v * gamma for v in self.base_values]
88 |
89 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
90 | lrs = [
91 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
92 | ]
93 | else:
94 | lrs = [self.lr_min for _ in self.base_values]
95 |
96 | return lrs
97 |
98 | def get_epoch_values(self, epoch: int):
99 | if self.t_in_epochs:
100 | return self._get_lr(epoch)
101 | else:
102 | return None
103 |
104 | def get_update_values(self, num_updates: int):
105 | if not self.t_in_epochs:
106 | return self._get_lr(num_updates)
107 | else:
108 | return None
109 |
110 | def get_cycle_length(self, cycles=0):
111 | if not cycles:
112 | cycles = self.cycle_limit
113 | cycles = max(1, cycles)
114 | if self.t_mul == 1.0:
115 | return self.t_initial * cycles
116 | else:
117 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
118 |
--------------------------------------------------------------------------------
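Usage note (not a file from the repository): CosineLRScheduler is driven entirely by an epoch index supplied from outside, so a toy optimizer is enough to see the warmup and cosine phases. A minimal sketch assuming the default t_in_epochs=True; the model, learning rates and epoch counts below are illustrative only.

    import torch
    from scheduler import CosineLRScheduler  # re-exported by scheduler/__init__.py

    model = torch.nn.Linear(4, 2)                              # toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    sched = CosineLRScheduler(optimizer, t_initial=10, lr_min=1e-5,
                              warmup_t=2, warmup_lr_init=1e-4)

    for epoch in range(12):
        # ... train one epoch ...
        sched.step(epoch + 1)                                  # the caller tracks and passes the epoch count
        print(epoch, optimizer.param_groups[0]['lr'])          # linear warmup for 2 epochs, then cosine decay
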
/scheduler/plateau_lr.py:
--------------------------------------------------------------------------------
1 | """ Plateau Scheduler
2 |
3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 |
9 | from .scheduler import Scheduler
10 |
11 |
12 | class PlateauLRScheduler(Scheduler):
13 | """Decay the LR by a factor every time the validation loss plateaus."""
14 |
15 | def __init__(self,
16 | optimizer,
17 | decay_rate=0.1,
18 | patience_t=10,
19 | verbose=True,
20 | threshold=1e-4,
21 | cooldown_t=0,
22 | warmup_t=0,
23 | warmup_lr_init=0,
24 | lr_min=0,
25 | mode='max',
26 | noise_range_t=None,
27 | noise_type='normal',
28 | noise_pct=0.67,
29 | noise_std=1.0,
30 | noise_seed=None,
31 | initialize=True,
32 | ):
33 | super().__init__(optimizer, 'lr', initialize=initialize)
34 |
35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
36 | self.optimizer,
37 | patience=patience_t,
38 | factor=decay_rate,
39 | verbose=verbose,
40 | threshold=threshold,
41 | cooldown=cooldown_t,
42 | mode=mode,
43 | min_lr=lr_min
44 | )
45 |
46 | self.noise_range = noise_range_t
47 | self.noise_pct = noise_pct
48 | self.noise_type = noise_type
49 | self.noise_std = noise_std
50 | self.noise_seed = noise_seed if noise_seed is not None else 42
51 | self.warmup_t = warmup_t
52 | self.warmup_lr_init = warmup_lr_init
53 | if self.warmup_t:
54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
55 | super().update_groups(self.warmup_lr_init)
56 | else:
57 | self.warmup_steps = [1 for _ in self.base_values]
58 | self.restore_lr = None
59 |
60 | def state_dict(self):
61 | return {
62 | 'best': self.lr_scheduler.best,
63 | 'last_epoch': self.lr_scheduler.last_epoch,
64 | }
65 |
66 | def load_state_dict(self, state_dict):
67 | self.lr_scheduler.best = state_dict['best']
68 | if 'last_epoch' in state_dict:
69 | self.lr_scheduler.last_epoch = state_dict['last_epoch']
70 |
71 | # override the base class step fn completely
72 | def step(self, epoch, metric=None):
73 | if epoch <= self.warmup_t:
74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
75 | super().update_groups(lrs)
76 | else:
77 | if self.restore_lr is not None:
78 | # restore actual LR from before our last noise perturbation before stepping base
79 | for i, param_group in enumerate(self.optimizer.param_groups):
80 | param_group['lr'] = self.restore_lr[i]
81 | self.restore_lr = None
82 |
83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler
84 |
85 | if self.noise_range is not None:
86 | if isinstance(self.noise_range, (list, tuple)):
87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
88 | else:
89 | apply_noise = epoch >= self.noise_range
90 | if apply_noise:
91 | self._apply_noise(epoch)
92 |
93 | def _apply_noise(self, epoch):
94 | g = torch.Generator()
95 | g.manual_seed(self.noise_seed + epoch)
96 | if self.noise_type == 'normal':
97 | while True:
98 | # resample if noise out of percent limit, brute force but shouldn't spin much
99 | noise = torch.randn(1, generator=g).item()
100 | if abs(noise) < self.noise_pct:
101 | break
102 | else:
103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104 |
105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal
106 | # stepping of base scheduler
107 | restore_lr = []
108 | for i, param_group in enumerate(self.optimizer.param_groups):
109 | old_lr = float(param_group['lr'])
110 | restore_lr.append(old_lr)
111 | new_lr = old_lr + old_lr * noise
112 | param_group['lr'] = new_lr
113 | self.restore_lr = restore_lr
114 |
--------------------------------------------------------------------------------
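Usage note (not a repository file): the plateau scheduler wraps torch's ReduceLROnPlateau and overrides step() to also take the monitored metric. A hedged sketch; the constant placeholder metric below stands in for a real validation loss, and mode='min' matches a loss-style metric (the factory below picks 'min' vs 'max' the same way).

    import torch
    from scheduler import PlateauLRScheduler

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    sched = PlateauLRScheduler(optimizer, decay_rate=0.5, patience_t=3,
                               mode='min', warmup_t=1, warmup_lr_init=1e-6)

    for epoch in range(1, 21):
        # ... train one epoch ...
        val_loss = 1.0                       # placeholder; pass the real validation loss here
        sched.step(epoch, metric=val_loss)   # warmup while epoch <= warmup_t, then ReduceLROnPlateau decays the LR
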
/scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import torch
4 |
5 |
6 | class Scheduler:
7 | """ Parameter Scheduler Base Class
8 | A scheduler base class that can be used to schedule any optimizer parameter groups.
9 |
10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called
11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
13 |
14 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
15 |
16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training
18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call.
19 |
20 | Based on ideas from:
21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
23 | """
24 |
25 | def __init__(self,
26 | optimizer: torch.optim.Optimizer,
27 | param_group_field: str,
28 | noise_range_t=None,
29 | noise_type='normal',
30 | noise_pct=0.67,
31 | noise_std=1.0,
32 | noise_seed=None,
33 | initialize: bool = True) -> None:
34 | self.optimizer = optimizer
35 | self.param_group_field = param_group_field
36 | self._initial_param_group_field = f"initial_{param_group_field}"
37 | if initialize:
38 | for i, group in enumerate(self.optimizer.param_groups):
39 | if param_group_field not in group:
40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
41 | group.setdefault(self._initial_param_group_field, group[param_group_field])
42 | else:
43 | for i, group in enumerate(self.optimizer.param_groups):
44 | if self._initial_param_group_field not in group:
45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
47 | self.metric = None # any point to having this for all?
48 | self.noise_range_t = noise_range_t
49 | self.noise_pct = noise_pct
50 | self.noise_type = noise_type
51 | self.noise_std = noise_std
52 | self.noise_seed = noise_seed if noise_seed is not None else 42
53 | self.update_groups(self.base_values)
54 |
55 | def state_dict(self) -> Dict[str, Any]:
56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
57 |
58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
59 | self.__dict__.update(state_dict)
60 |
61 | def get_epoch_values(self, epoch: int):
62 | return None
63 |
64 | def get_update_values(self, num_updates: int):
65 | return None
66 |
67 | def step(self, epoch: int, metric: float = None) -> None:
68 | self.metric = metric
69 | values = self.get_epoch_values(epoch)
70 | if values is not None:
71 | values = self._add_noise(values, epoch)
72 | self.update_groups(values)
73 |
74 | def step_update(self, num_updates: int, metric: float = None):
75 | self.metric = metric
76 | values = self.get_update_values(num_updates)
77 | if values is not None:
78 | values = self._add_noise(values, num_updates)
79 | self.update_groups(values)
80 |
81 | def update_groups(self, values):
82 | if not isinstance(values, (list, tuple)):
83 | values = [values] * len(self.optimizer.param_groups)
84 | for param_group, value in zip(self.optimizer.param_groups, values):
85 | param_group[self.param_group_field] = value
86 |
87 | def _add_noise(self, lrs, t):
88 | if self.noise_range_t is not None:
89 | if isinstance(self.noise_range_t, (list, tuple)):
90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
91 | else:
92 | apply_noise = t >= self.noise_range_t
93 | if apply_noise:
94 | g = torch.Generator()
95 | g.manual_seed(self.noise_seed + t)
96 | if self.noise_type == 'normal':
97 | while True:
98 | # resample if noise out of percent limit, brute force but shouldn't spin much
99 | noise = torch.randn(1, generator=g).item()
100 | if abs(noise) < self.noise_pct:
101 | break
102 | else:
103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104 | lrs = [v + v * noise for v in lrs]
105 | return lrs
106 |
--------------------------------------------------------------------------------
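The base-class docstring above prescribes a specific calling convention: the training loop owns both counters and passes them in explicitly, calling step_update() after every optimizer update and step() at the end of every epoch; whichever of the two matches t_in_epochs produces values, and the other is a no-op. A runnable sketch of that protocol (illustrative data and model, not repository code):

    import torch
    from scheduler import CosineLRScheduler   # any Scheduler subclass follows the same protocol

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    sched = CosineLRScheduler(optimizer, t_initial=5, t_in_epochs=True)

    num_updates = 0
    for epoch in range(5):
        for _ in range(10):                    # stands in for iterating a real data loader
            x = torch.randn(16, 8)
            loss = model(x).pow(2).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            num_updates += 1
            sched.step_update(num_updates)     # yields per-update LRs only when t_in_epochs=False; a no-op here
        sched.step(epoch + 1)                  # computes next epoch's LR and writes it into the param groups
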
/scheduler/scheduler_factory.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from .cosine_lr import CosineLRScheduler
5 | from .tanh_lr import TanhLRScheduler
6 | from .step_lr import StepLRScheduler
7 | from .plateau_lr import PlateauLRScheduler
8 |
9 |
10 | def create_scheduler(args, optimizer):
11 | num_epochs = args.epochs
12 |
13 | if getattr(args, 'lr_noise', None) is not None:
14 | lr_noise = getattr(args, 'lr_noise')
15 | if isinstance(lr_noise, (list, tuple)):
16 | noise_range = [n * num_epochs for n in lr_noise]
17 | if len(noise_range) == 1:
18 | noise_range = noise_range[0]
19 | else:
20 | noise_range = lr_noise * num_epochs
21 | else:
22 | noise_range = None
23 |
24 | lr_scheduler = None
25 | if args.sched == 'cosine':
26 | lr_scheduler = CosineLRScheduler(
27 | optimizer,
28 | t_initial=num_epochs,
29 | t_mul=getattr(args, 'lr_cycle_mul', 1.),
30 | lr_min=args.min_lr,
31 | decay_rate=args.decay_rate,
32 | warmup_lr_init=args.warmup_lr,
33 | warmup_t=args.warmup_epochs,
34 | cycle_limit=getattr(args, 'lr_cycle_limit', 1),
35 | t_in_epochs=True,
36 | noise_range_t=noise_range,
37 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
38 | noise_std=getattr(args, 'lr_noise_std', 1.),
39 | noise_seed=getattr(args, 'seed', 42),
40 | )
41 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
42 | elif args.sched == 'tanh':
43 | lr_scheduler = TanhLRScheduler(
44 | optimizer,
45 | t_initial=num_epochs,
46 | t_mul=getattr(args, 'lr_cycle_mul', 1.),
47 | lr_min=args.min_lr,
48 | warmup_lr_init=args.warmup_lr,
49 | warmup_t=args.warmup_epochs,
50 | cycle_limit=getattr(args, 'lr_cycle_limit', 1),
51 | t_in_epochs=True,
52 | noise_range_t=noise_range,
53 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
54 | noise_std=getattr(args, 'lr_noise_std', 1.),
55 | noise_seed=getattr(args, 'seed', 42),
56 | )
57 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
58 | elif args.sched == 'step':
59 | lr_scheduler = StepLRScheduler(
60 | optimizer,
61 | decay_t=args.decay_epochs,
62 | decay_rate=args.decay_rate,
63 | warmup_lr_init=args.warmup_lr,
64 | warmup_t=args.warmup_epochs,
65 | noise_range_t=noise_range,
66 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
67 | noise_std=getattr(args, 'lr_noise_std', 1.),
68 | noise_seed=getattr(args, 'seed', 42),
69 | )
70 | elif args.sched == 'plateau':
71 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
72 | lr_scheduler = PlateauLRScheduler(
73 | optimizer,
74 | decay_rate=args.decay_rate,
75 | patience_t=args.patience_epochs,
76 | lr_min=args.min_lr,
77 | mode=mode,
78 | warmup_lr_init=args.warmup_lr,
79 | warmup_t=args.warmup_epochs,
80 | cooldown_t=0,
81 | noise_range_t=noise_range,
82 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
83 | noise_std=getattr(args, 'lr_noise_std', 1.),
84 | noise_seed=getattr(args, 'seed', 42),
85 | )
86 |
87 | return lr_scheduler, num_epochs
88 |
--------------------------------------------------------------------------------
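create_scheduler reads everything through attribute access on args, so any attribute-style object works: an argparse.Namespace, a SimpleNamespace, or a small config wrapper. A hedged sketch of the cosine branch with illustrative values; the 'step' and 'plateau' branches additionally expect decay_epochs and patience_epochs respectively.

    import torch
    from types import SimpleNamespace
    from scheduler import create_scheduler

    args = SimpleNamespace(sched='cosine', epochs=30, min_lr=1e-6, decay_rate=1.0,
                           warmup_lr=1e-5, warmup_epochs=5, cooldown_epochs=0)

    optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-4)
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # num_epochs comes back as the cycle length plus cooldown_epochs, which may differ from args.epochs
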
/scheduler/step_lr.py:
--------------------------------------------------------------------------------
1 | """ Step Scheduler
2 |
3 | Basic step LR schedule with warmup, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import math
8 | import torch
9 |
10 | from .scheduler import Scheduler
11 |
12 |
13 | class StepLRScheduler(Scheduler):
14 |     """Decay the LR by a factor of decay_rate every decay_t epochs (or updates),
15 |     with optional warmup and noise."""
16 |
17 | def __init__(self,
18 | optimizer: torch.optim.Optimizer,
19 | decay_t: float,
20 | decay_rate: float = 1.,
21 | warmup_t=0,
22 | warmup_lr_init=0,
23 | t_in_epochs=True,
24 | noise_range_t=None,
25 | noise_pct=0.67,
26 | noise_std=1.0,
27 | noise_seed=42,
28 | initialize=True,
29 | ) -> None:
30 | super().__init__(
31 | optimizer, param_group_field="lr",
32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
33 | initialize=initialize)
34 |
35 | self.decay_t = decay_t
36 | self.decay_rate = decay_rate
37 | self.warmup_t = warmup_t
38 | self.warmup_lr_init = warmup_lr_init
39 | self.t_in_epochs = t_in_epochs
40 | if self.warmup_t:
41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
42 | super().update_groups(self.warmup_lr_init)
43 | else:
44 | self.warmup_steps = [1 for _ in self.base_values]
45 |
46 | def _get_lr(self, t):
47 | if t < self.warmup_t:
48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
49 | else:
50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
51 | return lrs
52 |
53 | def get_epoch_values(self, epoch: int):
54 | if self.t_in_epochs:
55 | return self._get_lr(epoch)
56 | else:
57 | return None
58 |
59 | def get_update_values(self, num_updates: int):
60 | if not self.t_in_epochs:
61 | return self._get_lr(num_updates)
62 | else:
63 | return None
64 |
--------------------------------------------------------------------------------
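A small worked example (illustrative, not repository code): with decay_t=30 and decay_rate=0.1 the schedule is a staircase that divides the LR by 10 every 30 epochs.

    import torch
    from scheduler import StepLRScheduler

    optimizer = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
    sched = StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1)

    for epoch in (0, 29, 30, 60):
        sched.step(epoch)
        print(epoch, optimizer.param_groups[0]['lr'])   # 0.1, 0.1, 0.01, 0.001 (up to float rounding)
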
/scheduler/tanh_lr.py:
--------------------------------------------------------------------------------
1 | """ TanH Scheduler
2 |
3 | TanH schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 |
12 | from .scheduler import Scheduler
13 |
14 |
15 | _logger = logging.getLogger(__name__)
16 |
17 |
18 | class TanhLRScheduler(Scheduler):
19 | """
20 |     Hyperbolic-Tangent decay with restarts.
21 | This is described in the paper https://arxiv.org/abs/1806.01593
22 | """
23 |
24 | def __init__(self,
25 | optimizer: torch.optim.Optimizer,
26 | t_initial: int,
27 | lb: float = -6.,
28 | ub: float = 4.,
29 | t_mul: float = 1.,
30 | lr_min: float = 0.,
31 | decay_rate: float = 1.,
32 | warmup_t=0,
33 | warmup_lr_init=0,
34 | warmup_prefix=False,
35 | cycle_limit=0,
36 | t_in_epochs=True,
37 | noise_range_t=None,
38 | noise_pct=0.67,
39 | noise_std=1.0,
40 | noise_seed=42,
41 | initialize=True) -> None:
42 | super().__init__(
43 | optimizer, param_group_field="lr",
44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
45 | initialize=initialize)
46 |
47 | assert t_initial > 0
48 | assert lr_min >= 0
49 | assert lb < ub
50 | assert cycle_limit >= 0
51 | assert warmup_t >= 0
52 | assert warmup_lr_init >= 0
53 | self.lb = lb
54 | self.ub = ub
55 | self.t_initial = t_initial
56 | self.t_mul = t_mul
57 | self.lr_min = lr_min
58 | self.decay_rate = decay_rate
59 | self.cycle_limit = cycle_limit
60 | self.warmup_t = warmup_t
61 | self.warmup_lr_init = warmup_lr_init
62 | self.warmup_prefix = warmup_prefix
63 | self.t_in_epochs = t_in_epochs
64 | if self.warmup_t:
65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
67 | super().update_groups(self.warmup_lr_init)
68 | else:
69 | self.warmup_steps = [1 for _ in self.base_values]
70 |
71 | def _get_lr(self, t):
72 | if t < self.warmup_t:
73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
74 | else:
75 | if self.warmup_prefix:
76 | t = t - self.warmup_t
77 |
78 | if self.t_mul != 1:
79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
80 | t_i = self.t_mul ** i * self.t_initial
81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
82 | else:
83 | i = t // self.t_initial
84 | t_i = self.t_initial
85 | t_curr = t - (self.t_initial * i)
86 |
87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
88 | gamma = self.decay_rate ** i
89 | lr_min = self.lr_min * gamma
90 | lr_max_values = [v * gamma for v in self.base_values]
91 |
92 | tr = t_curr / t_i
93 | lrs = [
94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
95 | for lr_max in lr_max_values
96 | ]
97 | else:
98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]
99 | return lrs
100 |
101 | def get_epoch_values(self, epoch: int):
102 | if self.t_in_epochs:
103 | return self._get_lr(epoch)
104 | else:
105 | return None
106 |
107 | def get_update_values(self, num_updates: int):
108 | if not self.t_in_epochs:
109 | return self._get_lr(num_updates)
110 | else:
111 | return None
112 |
113 | def get_cycle_length(self, cycles=0):
114 | if not cycles:
115 | cycles = self.cycle_limit
116 | cycles = max(1, cycles)
117 | if self.t_mul == 1.0:
118 | return self.t_initial * cycles
119 | else:
120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
121 |
--------------------------------------------------------------------------------
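The shape of the tanh curve in _get_lr is easiest to see by evaluating it directly; here tr = t_curr / t_i is the fraction of the current cycle that has elapsed. A standalone sketch using the default lb/ub from the constructor and illustrative LR bounds:

    import math

    lb, ub, lr_min, lr_max = -6.0, 4.0, 0.0, 1e-3
    lr_at = lambda tr: lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(lb * (1 - tr) + ub * tr))
    print(lr_at(0.0), lr_at(0.5), lr_at(1.0))   # ~lr_max at the start of a cycle, ~lr_min at the end
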
/vqaTools/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 |
--------------------------------------------------------------------------------
/vqaTools/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/vqaTools/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/vqaTools/__pycache__/vqa.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqa.cpython-36.pyc
--------------------------------------------------------------------------------
/vqaTools/__pycache__/vqa.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqa.cpython-38.pyc
--------------------------------------------------------------------------------
/vqaTools/__pycache__/vqaEval.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqaEval.cpython-36.pyc
--------------------------------------------------------------------------------
/vqaTools/__pycache__/vqaEval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/ALBEF/b9727e43c3040491774d1b22cc27718aa7772fac/vqaTools/__pycache__/vqaEval.cpython-38.pyc
--------------------------------------------------------------------------------
/vqaTools/vqa.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 | __version__ = '0.9'
3 |
4 | # Interface for accessing the VQA dataset.
5 |
6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8 |
9 | # The following functions are defined:
10 | # VQA - VQA class that loads VQA annotation file and prepares data structures.
11 | # getQuesIds - Get question ids that satisfy given filter conditions.
12 | # getImgIds - Get image ids that satisfy given filter conditions.
13 | # loadQA - Load questions and answers with the specified question ids.
14 | # showQA - Display the specified questions and answers.
15 | # loadRes - Load result file and create result object.
16 |
17 | # Help on each function can be accessed by: "help(VQA.function)"
18 |
19 | import json
20 | import datetime
21 | import copy
22 |
23 | class VQA:
24 | def __init__(self, annotation_file=None, question_file=None):
25 | """
26 | Constructor of VQA helper class for reading and visualizing questions and answers.
27 | :param annotation_file (str): location of VQA annotation file
28 | :return:
29 | """
30 | # load dataset
31 | self.dataset = {}
32 | self.questions = {}
33 | self.qa = {}
34 | self.qqa = {}
35 | self.imgToQA = {}
36 |         if annotation_file is not None and question_file is not None:
37 | print('loading VQA annotations and questions into memory...')
38 | time_t = datetime.datetime.utcnow()
39 | dataset = json.load(open(annotation_file, 'r'))
40 | questions = json.load(open(question_file, 'r'))
41 | self.dataset = dataset
42 | self.questions = questions
43 | self.createIndex()
44 |
45 | def createIndex(self):
46 | # create index
47 | print('creating index...')
48 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
49 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
50 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51 | for ann in self.dataset['annotations']:
52 | imgToQA[ann['image_id']] += [ann]
53 | qa[ann['question_id']] = ann
54 | for ques in self.questions['questions']:
55 | qqa[ques['question_id']] = ques
56 | print('index created!')
57 |
58 | # create class members
59 | self.qa = qa
60 | self.qqa = qqa
61 | self.imgToQA = imgToQA
62 |
63 | def info(self):
64 | """
65 | Print information about the VQA annotation file.
66 | :return:
67 | """
68 |         for key, value in self.dataset['info'].items():
69 | print('%s: %s'%(key, value))
70 |
71 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
72 | """
73 | Get question ids that satisfy given filter conditions. default skips that filter
74 | :param imgIds (int array) : get question ids for given imgs
75 | quesTypes (str array) : get question ids for given question types
76 | ansTypes (str array) : get question ids for given answer types
77 | :return: ids (int array) : integer array of question ids
78 | """
79 | imgIds = imgIds if type(imgIds) == list else [imgIds]
80 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
81 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
82 |
83 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
84 | anns = self.dataset['annotations']
85 | else:
86 | if not len(imgIds) == 0:
87 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
88 | else:
89 | anns = self.dataset['annotations']
90 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
91 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
92 | ids = [ann['question_id'] for ann in anns]
93 | return ids
94 |
95 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
96 | """
97 | Get image ids that satisfy given filter conditions. default skips that filter
98 | :param quesIds (int array) : get image ids for given question ids
99 | quesTypes (str array) : get image ids for given question types
100 | ansTypes (str array) : get image ids for given answer types
101 | :return: ids (int array) : integer array of image ids
102 | """
103 | quesIds = quesIds if type(quesIds) == list else [quesIds]
104 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
105 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
106 |
107 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
108 | anns = self.dataset['annotations']
109 | else:
110 | if not len(quesIds) == 0:
111 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[])
112 | else:
113 | anns = self.dataset['annotations']
114 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
115 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
116 | ids = [ann['image_id'] for ann in anns]
117 | return ids
118 |
119 | def loadQA(self, ids=[]):
120 | """
121 | Load questions and answers with the specified question ids.
122 | :param ids (int array) : integer ids specifying question ids
123 | :return: qa (object array) : loaded qa objects
124 | """
125 | if type(ids) == list:
126 | return [self.qa[id] for id in ids]
127 | elif type(ids) == int:
128 | return [self.qa[ids]]
129 |
130 | def showQA(self, anns):
131 | """
132 | Display the specified annotations.
133 | :param anns (array of object): annotations to display
134 | :return: None
135 | """
136 | if len(anns) == 0:
137 | return 0
138 | for ann in anns:
139 | quesId = ann['question_id']
140 | print("Question: %s" %(self.qqa[quesId]['question']))
141 | for ans in ann['answers']:
142 | print("Answer %d: %s" %(ans['answer_id'], ans['answer']))
143 |
144 | def loadRes(self, resFile, quesFile):
145 | """
146 | Load result file and return a result object.
147 | :param resFile (str) : file name of result file
148 | :return: res (obj) : result api object
149 | """
150 | res = VQA()
151 | res.questions = json.load(open(quesFile))
152 | res.dataset['info'] = copy.deepcopy(self.questions['info'])
153 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
154 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
155 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
156 | res.dataset['license'] = copy.deepcopy(self.questions['license'])
157 |
158 | print('Loading and preparing results... ')
159 | time_t = datetime.datetime.utcnow()
160 | anns = json.load(open(resFile))
161 | assert type(anns) == list, 'results is not an array of objects'
162 | annsQuesIds = [ann['question_id'] for ann in anns]
163 | assert set(annsQuesIds) == set(self.getQuesIds()), \
164 |             'Results do not correspond to the current VQA set. Either the results do not have predictions for all question ids in the annotation file, or there is at least one question id that does not belong to the question ids in the annotation file.'
165 | for ann in anns:
166 | quesId = ann['question_id']
167 | if res.dataset['task_type'] == 'Multiple Choice':
168 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
169 | qaAnn = self.qa[quesId]
170 | ann['image_id'] = qaAnn['image_id']
171 | ann['question_type'] = qaAnn['question_type']
172 | ann['answer_type'] = qaAnn['answer_type']
173 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))
174 |
175 | res.dataset['annotations'] = anns
176 | res.createIndex()
177 | return res
178 |
--------------------------------------------------------------------------------
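A hedged usage sketch of the VQA helper (none of the file names below ship with this repository; substitute the VQA annotation, question and result JSON files you actually have):

    from vqaTools.vqa import VQA

    vqa = VQA('v2_mscoco_val2014_annotations.json',            # hypothetical annotation file
              'v2_OpenEnded_mscoco_val2014_questions.json')    # hypothetical question file
    quesIds = vqa.getQuesIds(ansTypes=['yes/no'])              # imgIds / quesTypes filter the same way
    anns = vqa.loadQA(quesIds[:3])
    vqa.showQA(anns)                                           # prints each question with its annotated answers
    vqaRes = vqa.loadRes('results.json',                       # hypothetical result file
                         'v2_OpenEnded_mscoco_val2014_questions.json')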