├── NLVR.py
├── Pretrain.py
├── Pretrain_nlvr.py
├── README.md
├── Retrieval.py
├── VQA.py
├── configs
│   ├── NLVR.yaml
│   ├── NLVR_pretrain.yaml
│   ├── Pretrain.yaml
│   ├── Retrieval_coco.yaml
│   ├── Retrieval_flickr.yaml
│   ├── VQA.yaml
│   └── config_bert.json
├── dataset
│   ├── caption_dataset.py
│   ├── handle_data.py
│   ├── nlvr_dataset.py
│   ├── randaugment.py
│   ├── sampler_for_grit.py
│   ├── utils.py
│   └── vqa_dataset.py
├── img.png
├── models
│   ├── GRIT_utils.py
│   ├── __init__.py
│   ├── model_nlvr.py
│   ├── model_pretrain_GRIT_VLP.py
│   ├── model_pretrain_nlvr.py
│   ├── model_retrieval.py
│   ├── model_vqa.py
│   ├── tokenization_bert.py
│   ├── vit.py
│   └── xbert.py
├── optim
│   ├── __init__.py
│   ├── adafactor.py
│   ├── adahessian.py
│   ├── adamp.py
│   ├── adamw.py
│   ├── lookahead.py
│   ├── nadam.py
│   ├── novograd.py
│   ├── nvnovograd.py
│   ├── optim_factory.py
│   ├── radam.py
│   ├── rmsprop_tf.py
│   └── sgdp.py
├── refTools
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── refEvaluation.cpython-36.pyc
│   │   │   ├── refEvaluation.cpython-37.pyc
│   │   │   └── refEvaluation.cpython-38.pyc
│   │   ├── bleu
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── bleu.cpython-36.pyc
│   │   │   │   ├── bleu.cpython-37.pyc
│   │   │   │   ├── bleu.cpython-38.pyc
│   │   │   │   ├── bleu_scorer.cpython-36.pyc
│   │   │   │   ├── bleu_scorer.cpython-37.pyc
│   │   │   │   └── bleu_scorer.cpython-38.pyc
│   │   │   ├── bleu.py
│   │   │   ├── bleu.pyc
│   │   │   ├── bleu_scorer.py
│   │   │   └── bleu_scorer.pyc
│   │   ├── cider
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── cider.cpython-36.pyc
│   │   │   │   ├── cider.cpython-37.pyc
│   │   │   │   ├── cider.cpython-38.pyc
│   │   │   │   ├── cider_scorer.cpython-36.pyc
│   │   │   │   ├── cider_scorer.cpython-37.pyc
│   │   │   │   └── cider_scorer.cpython-38.pyc
│   │   │   ├── cider.py
│   │   │   ├── cider.pyc
│   │   │   ├── cider_scorer.py
│   │   │   └── cider_scorer.pyc
│   │   ├── meteor
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── meteor.cpython-36.pyc
│   │   │   │   ├── meteor.cpython-37.pyc
│   │   │   │   └── meteor.cpython-38.pyc
│   │   │   ├── meteor-1.5.jar
│   │   │   ├── meteor.py
│   │   │   └── meteor.pyc
│   │   ├── readme.txt
│   │   ├── refEvaluation.py
│   │   ├── refEvaluation.pyc
│   │   ├── rouge
│   │   │   ├── __init__.py
│   │   │   ├── __init__.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── rouge.cpython-36.pyc
│   │   │   │   ├── rouge.cpython-37.pyc
│   │   │   │   └── rouge.cpython-38.pyc
│   │   │   ├── rouge.py
│   │   │   └── rouge.pyc
│   │   └── tokenizer
│   │       ├── __init__.py
│   │       ├── __init__.pyc
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── ptbtokenizer.cpython-36.pyc
│   │       │   ├── ptbtokenizer.cpython-37.pyc
│   │       │   └── ptbtokenizer.cpython-38.pyc
│   │       ├── ptbtokenizer.py
│   │       ├── ptbtokenizer.pyc
│   │       ├── stanford-corenlp-3.4.1.jar
│   │       ├── tmp37tp6xj8
│   │       ├── tmp82iqkuu0
│   │       └── tmpn19wmqte
│   └── refer_python3.py
├── scheduler
│   ├── __init__.py
│   ├── cosine_lr.py
│   ├── plateau_lr.py
│   ├── scheduler.py
│   ├── scheduler_factory.py
│   ├── step_lr.py
│   └── tanh_lr.py
├── utils.py
└── vqaTools
    ├── __init__.py
    ├── vqa.py
    └── vqaEval.py
/Pretrain_nlvr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import ruamel.yaml as yaml
4 | import numpy as np
5 | import random
6 | import time
7 | import datetime
8 | import json
9 | from pathlib import Path
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.data import DataLoader
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 |
18 | from models.model_pretrain_nlvr import ALBEF
19 | from models.vit import interpolate_pos_embed
20 | from models.tokenization_bert import BertTokenizer
21 |
22 | import utils
23 | from dataset.handle_data import create_dataset, create_sampler, create_loader
24 | from scheduler import create_scheduler
25 | from optim import create_optimizer
26 |
27 |
28 | def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
29 | # train
30 | model.train()
31 |
32 | metric_logger = utils.MetricLogger(delimiter=" ")
33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
34 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
35 |
36 | header = 'Train Epoch: [{}]'.format(epoch)
37 | print_freq = 50
38 | step_size = 100
39 | warmup_iterations = warmup_steps*step_size
40 |
41 | if args.distributed:
42 | data_loader.sampler.set_epoch(epoch)
43 |
44 | for i, (image, text) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45 |
46 | optimizer.zero_grad()
47 |
48 | image = image.to(device,non_blocking=True)
49 | text_input = tokenizer(text, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)
50 |
51 | loss = model(image, text_input)
52 | loss.backward()
53 |
54 | optimizer.step()
55 |
56 | metric_logger.update(loss=loss.item())
57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
58 |
59 | if epoch==0 and i%step_size==0 and i<=warmup_iterations:
60 | scheduler.step(i//step_size)
61 |
62 | # gather the stats from all processes
63 | metric_logger.synchronize_between_processes()
64 | print("Averaged stats:", metric_logger.global_avg())
65 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
66 |
67 |
68 | def main(args, config):
69 | utils.init_distributed_mode(args)
70 |
71 | device = torch.device(args.device)
72 |
73 | # fix the seed for reproducibility
74 | seed = args.seed + utils.get_rank()
75 | torch.manual_seed(seed)
76 | np.random.seed(seed)
77 | random.seed(seed)
78 | cudnn.benchmark = True
79 |
80 | start_epoch = 0
81 | max_epoch = config['schedular']['epochs']
82 | warmup_steps = config['schedular']['warmup_epochs']
83 |
84 | #### Dataset ####
85 | print("Creating dataset")
86 | datasets = [create_dataset('nlvr_pretrain', config)]
87 |
88 | if args.distributed:
89 | num_tasks = utils.get_world_size()
90 | global_rank = utils.get_rank()
91 | samplers = create_sampler(datasets, [True], num_tasks, global_rank)
92 | else:
93 | samplers = [None]
94 |
95 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
96 |
97 | tokenizer = BertTokenizer.from_pretrained(args.text_encoder)
98 |
99 | #### Model ####
100 | print("Creating model")
101 | model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer)
102 |
103 | model = model.to(device)
104 |
105 | arg_opt = utils.AttrDict(config['optimizer'])
106 | optimizer = create_optimizer(arg_opt, model)
107 | arg_sche = utils.AttrDict(config['schedular'])
108 | lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
109 |
110 | if args.checkpoint:
111 | checkpoint = torch.load(args.checkpoint, map_location='cpu')
112 | state_dict = checkpoint['model']
113 | pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
114 | state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
115 |
116 | for key in list(state_dict.keys()):
117 | if 'bert' in key:
118 | new_key = key.replace('bert.','')
119 |
120 | if 'layer' in new_key:
121 | keys = new_key.split('.')
122 | layer_num = int(keys[3])
123 | # replicate the multimodal encoder's blocks for two images
124 | if layer_num>=6:
125 | new_layer_num = (layer_num-6)*2+6
126 | keys[3] = str(new_layer_num)
127 | new_key_0 = '.'.join(keys)
128 | state_dict[new_key_0] = state_dict[key]
129 | keys[3] = str(new_layer_num+1)
130 | new_key_1 = '.'.join(keys)
131 | state_dict[new_key_1] = state_dict[key]
132 | else:
133 | state_dict[new_key] = state_dict[key]
134 | else:
135 | state_dict[new_key] = state_dict[key]
136 | del state_dict[key]
137 |
138 | msg = model.load_state_dict(state_dict,strict=False)
139 | print('load checkpoint from %s'%args.checkpoint)
140 | print(msg)
141 |
142 | model_without_ddp = model
143 | if args.distributed:
144 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
145 | model_without_ddp = model.module
146 |
147 | print("Start training")
148 | start_time = time.time()
149 |
150 | for epoch in range(start_epoch, max_epoch):
151 |
152 | if epoch>0:
153 | lr_scheduler.step(epoch+warmup_steps)
154 |
155 | train_stats = train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)
156 |
157 | if utils.is_main_process():
158 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
159 | 'epoch': epoch,
160 | }
161 | save_obj = {
162 | 'model': model_without_ddp.state_dict(),
163 | 'optimizer': optimizer.state_dict(),
164 | 'lr_scheduler': lr_scheduler.state_dict(),
165 | 'config': config,
166 | 'epoch': epoch,
167 | }
168 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
169 |
170 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
171 | f.write(json.dumps(log_stats) + "\n")
172 |
173 | dist.barrier()
174 |
175 | total_time = time.time() - start_time
176 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177 | print('Training time {}'.format(total_time_str))
178 |
179 |
180 |
181 | if __name__ == '__main__':
182 | parser = argparse.ArgumentParser()
183 | parser.add_argument('--config', default='./configs/NLVR_pretrain.yaml')
184 | parser.add_argument('--checkpoint', default='')
185 | parser.add_argument('--output_dir', default='output/NLVR_pretrain')
186 | parser.add_argument('--text_encoder', default='bert-base-uncased')
187 | parser.add_argument('--device', default='cuda')
188 | parser.add_argument('--seed', default=0, type=int)
189 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
190 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
191 | parser.add_argument('--distributed', default=True, type=bool)
192 | args = parser.parse_args()
193 |
194 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
195 |
196 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
197 |
198 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
199 |
200 | main(args, config)
201 |
--------------------------------------------------------------------------------
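
For reference, the checkpoint-loading block above (Pretrain_nlvr.py, lines 116–136 in the file's own numbering) expands a 12-layer pre-trained text encoder into the 18-layer NLVR encoder by duplicating each multimodal block. A quick standalone check of that index arithmetic:

```python
# The six multimodal (fusion) layers 6-11 of the 12-layer checkpoint are each copied
# into two consecutive layers of the 18-layer NLVR text encoder (cf. num_hidden_layers = 18
# in models/model_pretrain_nlvr.py); layers 0-5 keep their original index.
for layer_num in range(12):
    if layer_num >= 6:
        new_layer_num = (layer_num - 6) * 2 + 6
        print('checkpoint layer %2d -> model layers %d and %d' % (layer_num, new_layer_num, new_layer_num + 1))
    else:
        print('checkpoint layer %2d -> model layer  %d' % (layer_num, layer_num))
```
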
/README.md:
--------------------------------------------------------------------------------
1 | ## GRIT-VLP: GRouped mIni-baTch sampling for Efficient Vision-Language Pre-training
2 | This is the official PyTorch implementation of "GRIT-VLP: GRouped mIni-baTch sampling for Efficient Vision-Language Pre-training"
3 | (Accepted to ECCV 2022)
4 |
5 | You can find the implementation code for pre-training and fine-tuning GRIT-VLP here.
6 |
7 |
8 |
9 |
10 |
11 | ### Pre-training Dataset Download:
12 | - [MSCOCO (2014)](https://cocodataset.org/#download)
13 | - [Visual Genome (VG)](https://visualgenome.org/api/v0/api_home.html)
14 | - [Conceptual Captions](https://www.kaggle.com/ad271828/conceptual-captions-dataset-train-and-validation)
15 | - [SBU Captions](http://www.cs.virginia.edu/~vicente/sbucaptions/)
16 |
17 |
18 | ### Downstream-task Datasets:
19 | - [NLVR2](https://github.com/lil-lab/nlvr/tree/master/nlvr2#direct-image-download)
20 | - [Flickr30k](https://www.kaggle.com/hsankesara/flickr-image-dataset)
21 | - [VQA v2](https://visualqa.org/download.html)
22 |
23 | ### JSON Files:
24 | - Use the same json files as [ALBEF](https://github.com/salesforce/ALBEF)
25 | - Change the image paths in the json files according to your downloaded images (in CC3M and SBU, some images cannot be crawled, so account for these missing images when creating the json files)
26 |
27 |
28 | ### Requirements:
29 | * pytorch 1.8.0
30 | * transformers 4.8.1
31 | * timm 0.4.9
32 |
33 |
34 | ### Pre-training:
35 | 1. Pre-train the model using 4 A100 GPUs:
36 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain/
37 |
38 | ### Downstream tasks:
39 | 1. IRTR (MS-COCO) using 4 A100 GPUs:
40 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Retrieval.py --config ./configs/Retrieval_coco.yaml --output_dir output/Retrieval_coco/ --checkpoint [Pretrained checkpoint]
41 |
42 | 2. IRTR (Flickr) using 4 A100 GPUs:
43 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Retrieval.py --config ./configs/Retrieval_flickr.yaml --output_dir output/Retrieval_flickr/ --checkpoint [Pretrained checkpoint]
44 |
45 | 3. NLVR using 4 A100 GPUs:
46 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env Pretrain_nlvr.py --config ./configs/NLVR_pretrain.yaml --output_dir output/NLVR_pretrain/ --checkpoint [Pretrained checkpoint]
47 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env NLVR.py --config ./configs/NLVR.yaml --output_dir output/NLVR/ --checkpoint [NLVR-Pretrained checkpoint]
48 |
49 |
50 | 4. VQA using 4 A100 GPUs:
51 | python3 -m torch.distributed.launch --nproc_per_node=4 --use_env VQA.py --config ./configs/VQA.yaml --output_dir output/vqa/ --checkpoint [Pretrained checkpoint]
52 |
53 | #### If you have any questions or problems running this code, please email wotjr3868@snu.ac.kr or gxq9106@gmail.com. Thank you!
54 |
55 |
56 | ### Acknowledgement:
57 | Our implementation is largely borrowed from [ALBEF](https://github.com/salesforce/ALBEF#download), since our method is built upon it. We thank the original authors for sharing their code.
58 |
59 |
--------------------------------------------------------------------------------
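
Regarding the "JSON Files" step in the README above: a minimal sketch (hypothetical file names, assuming the ALBEF-style layout where each pre-training entry is `{'image': img_path, 'caption': ...}`) for dropping CC3M/SBU entries whose images were not crawled:

```python
import json
import os

# Hypothetical paths -- point these at your own ALBEF-style json and output location.
src = 'dataset/json_pretrain/cc3m_train.json'
dst = 'dataset/json_pretrain/cc3m_train_filtered.json'

anns = json.load(open(src, 'r'))                       # list of {'image': img_path, 'caption': ...}
kept = [a for a in anns if os.path.exists(a['image'])]  # keep entries whose image file exists
print('kept %d / %d entries' % (len(kept), len(anns)))
json.dump(kept, open(dst, 'w'))
```
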
/configs/NLVR.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['dataset/downstream_json/nlvr_train.json']
2 | val_file: ['dataset/downstream_json/nlvr_dev.json']
3 | test_file: ['dataset/downstream_json/nlvr_test.json']
4 |
5 | image_root: '/SHARE_ST/mind/dataset/nlvr2/'
6 |
7 | image_res: 384
8 | batch_size: 16
9 |
10 | bert_config: 'configs/config_bert.json'
11 |
12 | alpha: 0
13 | distill: False
14 | warm_up: True
15 | eval_ema: False
16 |
17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
18 | schedular: {sched: cosine, lr: 2e-5, epochs: 10, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/configs/NLVR_pretrain.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['dataset/json_pretrain/coco.json',
2 | 'dataset/json_pretrain/vg.json',
3 | 'dataset/json_pretrain/cc3m_train.json',
4 | 'dataset/json_pretrain/cc3m_val.json',
5 | 'dataset/json_pretrain/sbu.json',
6 | ]
7 |
8 | # each train_file (json) contains a python list where each item is {'image': img_path, 'caption': text or list_of_text }
9 |
10 | bert_config: 'configs/config_bert.json'
11 |
12 | image_res: 256
13 | vision_width: 768
14 | embed_dim: 256
15 | batch_size: 64
16 |
17 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
18 | schedular: {sched: cosine, lr: 2e-5, epochs: 1, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
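
For illustration, a hypothetical entry of such a pre-training json file, matching the comment in the pre-training configs and the fields read by `pretrain_dataset` / `nlvr_pretrain_dataset` in dataset/caption_dataset.py:

```python
# One illustrative item of e.g. dataset/json_pretrain/coco.json -- 'caption' may be a single
# string or a list of strings (one is picked at random per access); 'image' is a full image path.
example_entry = {
    'image': '/path/to/coco/train2014/COCO_train2014_000000000009.jpg',
    'caption': ['a plate of food with broccoli and meat',
                'closeup of bins of food that include broccoli and bread'],
}
```
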
/configs/Pretrain.yaml:
--------------------------------------------------------------------------------
1 | train_file: [ 'dataset/json_pretrain/coco.json',
2 | 'dataset/json_pretrain/vg.json',
3 | 'dataset/json_pretrain/cc3m_train.json',
4 | 'dataset/json_pretrain/cc3m_val.json',
5 | 'dataset/json_pretrain/sbu.json'
6 | ]
7 | # each train_file (json) contains a python list where each item is {'image': img_path, 'caption': text or list_of_text }
8 | bert_config: 'configs/config_bert.json'
9 |
10 | image_res: 256
11 | vision_width: 768 # base: 768, small: 512
12 | embed_dim: 256
13 | batch_size: 128
14 | temp: 0.07
15 | mlm_probability: 0.5
16 | queue_size: 48000
17 | search_space: 1920
18 | train_epochs: 20
19 |
20 | optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.02}
21 | schedular: {sched: cosine, lr: 1e-4, epochs: 30, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 20, cooldown_epochs: 0}
22 |
--------------------------------------------------------------------------------
/configs/Retrieval_coco.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['dataset/downstream_json/coco_train.json']
2 | val_file: 'dataset/downstream_json/coco_val.json'
3 | test_file: 'dataset/downstream_json/coco_test.json'
4 | image_root: '/dataset/coco/'
5 |
6 | bert_config: 'configs/config_bert.json'
7 |
8 | image_res: 384
9 | batch_size_train: 64
10 | batch_size_test: 96
11 |
12 | queue_size: 65280
13 | momentum: 0.995
14 | use_momentum: True
15 |
16 | vision_width: 768
17 | embed_dim: 256
18 | temp: 0.07
19 | k_test: 256
20 |
21 | alpha: 0
22 | distill: False
23 | warm_up: True
24 |
25 |
26 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
27 | schedular: {sched: cosine, lr: 1e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/configs/Retrieval_flickr.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['dataset/downstream_json/flickr30k_train.json']
2 | val_file: 'dataset/downstream_json/flickr30k_val.json'
3 | test_file: 'dataset/downstream_json/flickr30k_test.json'
4 | image_root: '/dataset/flickr30k_images'
5 |
6 | bert_config: 'configs/config_bert.json'
7 |
8 | image_res: 384
9 | batch_size_train: 64
10 | batch_size_test: 96
11 |
12 | queue_size: 65280
13 | use_momentum: True
14 | momentum: 0.995
15 |
16 | vision_width: 768
17 | embed_dim: 256
18 | temp: 0.07
19 | k_test: 128
20 |
21 | alpha: 0
22 | distill: False
23 | warm_up: True
24 |
25 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
26 | schedular: {sched: cosine, lr: 1e-5, epochs: 10, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0}
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/configs/VQA.yaml:
--------------------------------------------------------------------------------
1 | train_file: ['dataset/downstream_json/vqa_train.json',
2 | 'dataset/downstream_json/vqa_val.json',
3 | 'dataset/downstream_json/vg_qa.json']
4 |
5 | test_file: ['dataset/downstream_json/vqa_test.json']
6 | answer_list: 'dataset/downstream_json/answer_list.json'
7 |
8 | vqa_root: 'dataset/coco/' #train2014/
9 | vg_root: 'dataset/vg/' #image/
10 |
11 | image_res: 384
12 | batch_size_train: 32
13 | batch_size_test: 16
14 | k_test: 128
15 |
16 | alpha: 0
17 | distill: False
18 | warm_up: True
19 |
20 | eos: '[SEP]'
21 |
22 | bert_config: 'configs/config_bert.json'
23 |
24 | optimizer: {opt: adamW, lr: 2e-5, weight_decay: 0.02}
25 | schedular: {sched: cosine, lr: 2e-5, epochs: 8, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 4, cooldown_epochs: 0}
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/configs/config_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "fusion_layer": 6,
20 | "concat_layer": 6,
21 | "encoder_width": 768
22 | }
23 |
--------------------------------------------------------------------------------
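
The `fusion_layer`, `concat_layer`, and `encoder_width` fields are not standard HuggingFace BERT options; they are presumably consumed by the customized models/xbert.py. A minimal sketch of how the NLVR models below load this config (mirroring models/model_pretrain_nlvr.py; downloads `bert-base-uncased` weights on first use):

```python
from models.xbert import BertConfig, BertModel

bert_config = BertConfig.from_json_file('configs/config_bert.json')
bert_config.num_hidden_layers = 18   # the NLVR variants replicate the 6 fusion layers for two images
text_encoder = BertModel.from_pretrained('bert-base-uncased',
                                         config=bert_config,
                                         add_pooling_layer=False)
print(text_encoder.config.fusion_layer, text_encoder.config.num_hidden_layers)
```
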
/dataset/caption_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import json
3 | import os
4 | import random
5 |
6 | from torch.utils.data import Dataset
7 |
8 | from PIL import Image
9 | from PIL import ImageFile
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 | Image.MAX_IMAGE_PIXELS = None
12 |
13 | from dataset.utils import pre_caption
14 |
15 |
16 | class re_train_dataset(Dataset):
17 | def __init__(self, ann_file, transform, image_root, max_words=30):
18 | self.ann = []
19 | for f in ann_file:
20 | self.ann += json.load(open(f,'r'))
21 | self.transform = transform
22 | self.image_root = image_root
23 | self.max_words = max_words
24 | self.img_ids = {}
25 |
26 | n = 0
27 | for ann in self.ann:
28 | img_id = ann['image_id']
29 | if img_id not in self.img_ids.keys():
30 | self.img_ids[img_id] = n
31 | n += 1
32 |
33 | def __len__(self):
34 | return len(self.ann)
35 |
36 | def __getitem__(self, index):
37 |
38 | ann = self.ann[index]
39 |
40 | image_path = os.path.join(self.image_root,ann['image'])
41 | image = Image.open(image_path).convert('RGB')
42 | image = self.transform(image)
43 |
44 | caption = pre_caption(ann['caption'], self.max_words)
45 |
46 | return image, caption, self.img_ids[ann['image_id']]
47 |
48 |
49 |
50 | class re_eval_dataset(Dataset):
51 | def __init__(self, ann_file, transform, image_root, max_words=30):
52 | self.ann = json.load(open(ann_file,'r'))
53 | self.transform = transform
54 | self.image_root = image_root
55 | self.max_words = max_words
56 |
57 | self.text = []
58 | self.image = []
59 | self.txt2img = {}
60 | self.img2txt = {}
61 |
62 | txt_id = 0
63 | for img_id, ann in enumerate(self.ann):
64 | self.image.append(ann['image'])
65 | self.img2txt[img_id] = []
66 | for i, caption in enumerate(ann['caption']):
67 | self.text.append(pre_caption(caption,self.max_words))
68 | self.img2txt[img_id].append(txt_id)
69 | self.txt2img[txt_id] = img_id
70 | txt_id += 1
71 |
72 | def __len__(self):
73 | return len(self.image)
74 |
75 | def __getitem__(self, index):
76 |
77 | image_path = os.path.join(self.image_root, self.ann[index]['image'])
78 | image = Image.open(image_path).convert('RGB')
79 | image = self.transform(image)
80 |
81 | return image, index
82 |
83 |
84 |
85 | class pretrain_dataset(Dataset):
86 | def __init__(self, ann_file, transform, max_words=30):
87 | self.ann = []
88 | for f in ann_file:
89 | self.ann += json.load(open(f,'r'))
90 | self.transform = transform
91 | self.max_words = max_words
92 |
93 | n = 0
94 | self.ex_index_set=[]
95 | for ann in self.ann:
96 | self.ex_index_set.append(n)
97 | n += 1
98 |
99 | def __len__(self):
100 | return len(self.ann)
101 |
102 | def __getitem__(self, index):
103 |
104 | ann = self.ann[index]
105 | idx= self.ex_index_set[index]
106 |
107 | if type(ann['caption']) == list:
108 | caption = pre_caption(random.choice(ann['caption']), self.max_words)
109 | else:
110 | caption = pre_caption(ann['caption'], self.max_words)
111 |
112 | image = Image.open(ann['image']).convert('RGB')
113 | image = self.transform(image)
114 |
115 | return image, caption,idx
116 |
117 |
118 | class nlvr_pretrain_dataset(Dataset):
119 | def __init__(self, ann_file, transform, max_words=30):
120 | self.ann = []
121 | for f in ann_file:
122 | self.ann += json.load(open(f,'r'))
123 | self.transform = transform
124 | self.max_words = max_words
125 |
126 | def __len__(self):
127 | return len(self.ann)
128 |
129 |
130 | def __getitem__(self, index):
131 | ann = self.ann[index]
132 |
133 | if type(ann['caption']) == list:
134 | caption = pre_caption(random.choice(ann['caption']), self.max_words)
135 | else:
136 | caption = pre_caption(ann['caption'], self.max_words)
137 |
138 | image = Image.open(ann['image']).convert('RGB')
139 | image = self.transform(image)
140 |
141 | return image, caption
142 |
143 |
144 |
145 |
146 |
--------------------------------------------------------------------------------
/dataset/handle_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms
4 | from PIL import Image
5 |
6 | from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset, nlvr_pretrain_dataset
7 | from dataset.nlvr_dataset import nlvr_dataset
8 | from dataset.vqa_dataset import vqa_dataset
9 | import math
10 | from dataset.randaugment import RandomAugment
11 | from dataset.sampler_for_grit import Sampler_for_GRIT
12 | def create_dataset(dataset, config):
13 |
14 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
15 |
16 | pretrain_transform = transforms.Compose([
17 | transforms.RandomResizedCrop(config['image_res'],scale=(0.2, 1.0), interpolation=Image.BICUBIC),
18 | transforms.RandomHorizontalFlip(),
19 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
20 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
21 | transforms.ToTensor(),
22 | normalize,
23 | ])
24 | train_transform = transforms.Compose([
25 | transforms.RandomResizedCrop(config['image_res'],scale=(0.5, 1.0), interpolation=Image.BICUBIC),
26 | transforms.RandomHorizontalFlip(),
27 | RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
28 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
29 | transforms.ToTensor(),
30 | normalize,
31 | ])
32 | test_transform = transforms.Compose([
33 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
34 | transforms.ToTensor(),
35 | normalize,
36 | ])
37 |
38 | if dataset=='pretrain':
39 | dataset = pretrain_dataset(config['train_file'], pretrain_transform)
40 | return dataset
41 | elif dataset=='nlvr_pretrain':
42 | dataset = nlvr_pretrain_dataset(config['train_file'], pretrain_transform)
43 | return dataset
44 |
45 | elif dataset=='re':
46 | train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root'])
47 | val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root'])
48 | test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root'])
49 | return train_dataset, val_dataset, test_dataset
50 |
51 | elif dataset=='vqa':
52 | train_dataset = vqa_dataset(config['train_file'], train_transform, config['vqa_root'], config['vg_root'], split='train')
53 | vqa_test_dataset = vqa_dataset(config['test_file'], test_transform, config['vqa_root'], config['vg_root'], split='test', answer_list=config['answer_list'])
54 | return train_dataset, vqa_test_dataset
55 |
56 | elif dataset=='nlvr':
57 | train_dataset = nlvr_dataset(config['train_file'], train_transform, config['image_root'])
58 | val_dataset = nlvr_dataset(config['val_file'], test_transform, config['image_root'])
59 | test_dataset = nlvr_dataset(config['test_file'], test_transform, config['image_root'])
60 | return train_dataset, val_dataset, test_dataset
61 |
62 |
63 | def vqa_collate_fn(batch):
64 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
65 | for image, question, answer, weights in batch:
66 | image_list.append(image)
67 | question_list.append(question)
68 | weight_list += weights
69 | answer_list += answer
70 | n.append(len(answer))
71 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
72 |
73 |
74 |
75 |
76 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
77 | samplers = []
78 | for dataset,shuffle in zip(datasets,shuffles):
79 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
80 | samplers.append(sampler)
81 | return samplers
82 |
83 |
84 | def create_fixed_sampler(num_tasks, global_rank,index_set):
85 | samplers = []
86 | sampler = Sampler_for_GRIT(pre_indices=index_set, num_replicas=num_tasks, rank=global_rank)
87 | samplers.append(sampler)
88 | return samplers
89 |
90 |
91 |
92 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
93 | loaders = []
94 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
95 | if is_train:
96 | shuffle = (sampler is None)
97 | drop_last = True
98 | else:
99 | shuffle = False
100 | drop_last = False
101 | loader = DataLoader(
102 | dataset,
103 | batch_size=bs,
104 | num_workers=n_worker,
105 | pin_memory=False,
106 | sampler=sampler,
107 | shuffle=shuffle,
108 | collate_fn=collate_fn,
109 | drop_last=drop_last,
110 | )
111 | loaders.append(loader)
112 | return loaders
113 |
--------------------------------------------------------------------------------
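
A minimal single-process sketch of the dataset/loader factory above (no DistributedSampler; assumes the json paths in configs/NLVR_pretrain.yaml exist locally), mirroring the calls made in Pretrain_nlvr.py:

```python
import ruamel.yaml as yaml
from dataset.handle_data import create_dataset, create_loader

config = yaml.load(open('./configs/NLVR_pretrain.yaml', 'r'), Loader=yaml.Loader)
dataset = create_dataset('nlvr_pretrain', config)
loader = create_loader([dataset], [None],
                       batch_size=[config['batch_size']],
                       num_workers=[4], is_trains=[True], collate_fns=[None])[0]

for image, text in loader:          # nlvr_pretrain_dataset yields (image, caption) pairs
    print(image.shape, text[0])     # e.g. torch.Size([64, 3, 256, 256]) and one raw caption
    break
```
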
/dataset/nlvr_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | from dataset.utils import pre_caption
6 |
7 |
8 | class nlvr_dataset(Dataset):
9 | def __init__(self, ann_file, transform, image_root):
10 | self.ann = []
11 | for f in ann_file:
12 | self.ann += json.load(open(f,'r'))
13 | self.transform = transform
14 | self.image_root = image_root
15 | self.max_words = 30
16 |
17 | def __len__(self):
18 | return len(self.ann)
19 |
20 |
21 | def __getitem__(self, index):
22 |
23 | ann = self.ann[index]
24 |
25 | image0_path = os.path.join(self.image_root,ann['images'][0])
26 | image0 = Image.open(image0_path).convert('RGB')
27 | image0 = self.transform(image0)
28 |
29 | image1_path = os.path.join(self.image_root,ann['images'][1])
30 | image1 = Image.open(image1_path).convert('RGB')
31 | image1 = self.transform(image1)
32 |
33 | sentence = pre_caption(ann['sentence'], self.max_words)
34 |
35 | if ann['label']=='True':
36 | label = 1
37 | else:
38 | label = 0
39 |
40 | return image0, image1, sentence, label
--------------------------------------------------------------------------------
/dataset/sampler_for_grit.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import TypeVar, Optional, Iterator, List
3 |
4 | import torch
5 | from torch.utils.data import Sampler, Dataset
6 | import torch.distributed as dist
7 | import time
8 |
9 | T_co = TypeVar('T_co', covariant=True)
10 |
11 | # Just load the pre-defined index array
12 | class Sampler_for_GRIT (Sampler[T_co]):
13 | def __init__(self, pre_indices: List[int] =None, num_replicas: Optional[int] = None,
14 | rank: Optional[int] = None,) -> None:
15 | if num_replicas is None:
16 | if not dist.is_available():
17 | raise RuntimeError("Requires distributed package to be available")
18 | num_replicas = dist.get_world_size()
19 | if rank is None:
20 | if not dist.is_available():
21 | raise RuntimeError("Requires distributed package to be available")
22 | rank = dist.get_rank()
23 | if rank >= num_replicas or rank < 0:
24 | raise ValueError(
25 | "Invalid rank {}, rank should be in the interval"
26 | " [0, {}]".format(rank, num_replicas - 1))
27 | self.num_replicas = num_replicas
28 | self.rank = rank
29 | self.epoch = 0
30 | self.pre_indices= pre_indices
31 | self.index_set_len= int(len(self.pre_indices)/self.num_replicas)
32 |
33 | def __iter__(self) -> Iterator[T_co]:
34 | # subsample
35 | indices = self.pre_indices[ self.rank*self.index_set_len:(self.rank+1)*self.index_set_len]
36 |
37 | return iter(indices)
38 |
39 | def __len__(self) -> int:
40 | return self.index_set_len
41 |
42 | def set_epoch(self, epoch: int) -> None:
43 | self.epoch = epoch
44 |
--------------------------------------------------------------------------------
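
A small self-contained sketch of how Sampler_for_GRIT replays a precomputed index order and splits it contiguously across ranks (toy data, explicit num_replicas/rank so no process group is needed; in the repo the index list would come from the GRIT grouping, via create_fixed_sampler in dataset/handle_data.py):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from dataset.sampler_for_grit import Sampler_for_GRIT

dataset = TensorDataset(torch.arange(8))         # toy dataset of 8 items
pre_indices = [3, 1, 7, 5, 0, 2, 6, 4]           # a precomputed (e.g. grouped) visiting order

# With num_replicas=2, rank 0 reads the first half of the index list, rank 1 the second half.
sampler = Sampler_for_GRIT(pre_indices=pre_indices, num_replicas=2, rank=0)
print(list(sampler))                             # [3, 1, 7, 5]

loader = DataLoader(dataset, batch_size=2, sampler=sampler)
for (batch,) in loader:
    print(batch.tolist())                        # [3, 1] then [7, 5]
```
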
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def pre_question(question,max_ques_words):
4 | question = re.sub(
5 | r"([,.'!?\"()*#:;~])",
6 | '',
7 | question.lower(),
8 | ).replace('-', ' ').replace('/', ' ')
9 | question = question.rstrip(' ')
10 |
11 | #truncate question
12 | question_words = question.split(' ')
13 | if len(question_words)>max_ques_words:
14 | question = ' '.join(question_words[:max_ques_words])
15 |
16 | return question
17 |
18 |
19 | def pre_caption(caption,max_words):
20 | caption = re.sub(
21 | r"([,.'!?\"()*#:;~])",
22 | '',
23 | caption.lower(),
24 |         ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
25 |
26 | caption = re.sub(
27 | r"\s{2,}",
28 | ' ',
29 | caption,
30 | )
31 | caption = caption.rstrip('\n')
32 | caption = caption.strip(' ')
33 |
34 | #truncate caption
35 | caption_words = caption.split(' ')
36 | if len(caption_words)>max_words:
37 | caption = ' '.join(caption_words[:max_words])
38 |
39 | return caption
40 |
41 |
42 | from vqaTools.vqaEval import VQAEval
43 | from refTools.evaluation.refEvaluation import RefEvaluation
44 |
45 | import json
46 | import os
47 | import numpy as np
48 | import torch
49 | import torch.distributed as dist
50 | import torch.nn.functional as F
51 |
52 | import utils
53 | from tqdm import tqdm
54 |
55 |
56 | def vqa_eval(vqa, result_file, test_ques_path):
57 | vqaRes = vqa.loadRes(result_file, test_ques_path)
58 | # create vqaEval object by taking vqa and vqaRes
59 | vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2
60 | # evaluate results
61 | vqaEval.evaluate()
62 |
63 | # print accuracies
64 | print("\n")
65 | print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
66 | print("Per Answer Type Accuracy is the following:")
67 | for ansType in vqaEval.accuracy['perAnswerType']:
68 | print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
69 | print("\n")
70 |
71 | return vqaEval
72 |
73 |
74 |
75 | def collect_result(result, result_dir, filename, is_json=True, is_list=True):
76 | if is_json:
77 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
78 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
79 | json.dump(result,open(result_file,'w'))
80 | else:
81 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank()))
82 | final_result_file = os.path.join(result_dir, '%s.pth'%filename)
83 | torch.save(result,result_file)
84 |
85 | dist.barrier()
86 |
87 | result = None
88 | if utils.is_main_process():
89 | # combine results from all processes
90 | if is_list:
91 | result = []
92 | else:
93 | result = {}
94 | for rank in range(utils.get_world_size()):
95 | if is_json:
96 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
97 | res = json.load(open(result_file,'r'))
98 | else:
99 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank))
100 | res = torch.load(result_file)
101 | if is_list:
102 | result += res
103 | else:
104 | result.update(res)
105 |
106 | return result
107 |
108 |
109 | def save_result(result, result_dir, filename, is_json=True, is_list=True):
110 | if is_json:
111 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
112 | final_result_file = os.path.join(result_dir, '%s.json'%filename)
113 | json.dump(result,open(result_file,'w'))
114 | else:
115 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,utils.get_rank()))
116 | final_result_file = os.path.join(result_dir, '%s.pth'%filename)
117 | torch.save(result,result_file)
118 |
119 | dist.barrier()
120 |
121 | if utils.is_main_process():
122 | # combine results from all processes
123 | if is_list:
124 | result = []
125 | else:
126 | result = {}
127 | for rank in range(utils.get_world_size()):
128 | if is_json:
129 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
130 | res = json.load(open(result_file,'r'))
131 | else:
132 | result_file = os.path.join(result_dir, '%s_rank%d.pth'%(filename,rank))
133 | res = torch.load(result_file)
134 | if is_list:
135 | result += res
136 | else:
137 | result.update(res)
138 | if is_json:
139 | json.dump(result,open(final_result_file,'w'))
140 | else:
141 | torch.save(result,final_result_file)
142 |
143 | print('result file saved to %s'%final_result_file)
144 | dist.barrier()
145 | return final_result_file
146 |
147 |
148 |
149 | def grounding_eval(results,dets,cocos,refer,alpha,mask_size=24):
150 |
151 | correct_A_d, correct_B_d, correct_val_d = 0, 0, 0
152 | correct_A, correct_B, correct_val = 0, 0, 0
153 | num_A,num_B,num_val = 0,0,0
154 |
155 | for res in tqdm(results):
156 |
157 | ref_id = res['ref_id']
158 | ref = refer.Refs[ref_id]
159 | ref_box = refer.refToAnn[ref_id]['bbox']
160 | image = refer.Imgs[ref['image_id']]
161 |
162 | mask = res['pred'].cuda().view(1,1,mask_size,mask_size)
163 | mask = F.interpolate(mask,size = (image['height'],image['width']), mode='bicubic').squeeze()
164 |
165 | # rank detection boxes
166 | max_score = 0
167 | for det in dets[str(ref['image_id'])]:
168 | score = mask[int(det[1]):int(det[1]+det[3]),int(det[0]):int(det[0]+det[2])]
169 | area = det[2]*det[3]
170 | score = score.sum() / area**alpha
171 | if score>max_score:
172 | pred_box = det[:4]
173 | max_score = score
174 |
175 | IoU_det = computeIoU(ref_box, pred_box)
176 |
177 | if ref['split']=='testA':
178 | num_A += 1
179 | if IoU_det >= 0.5:
180 | correct_A_d += 1
181 | elif ref['split']=='testB':
182 | num_B += 1
183 | if IoU_det >= 0.5:
184 | correct_B_d += 1
185 | elif ref['split']=='val':
186 | num_val += 1
187 | if IoU_det >= 0.5:
188 | correct_val_d += 1
189 |
190 | eval_result = {'val_d':correct_val_d/num_val,'testA_d':correct_A_d/num_A,'testB_d':correct_B_d/num_B}
191 |
192 | for metric, acc in eval_result.items():
193 | print(f'{metric}: {acc:.3f}')
194 |
195 | return eval_result
196 |
197 |
198 |
199 | # IoU function
200 | def computeIoU(box1, box2):
201 | # each box is of [x1, y1, w, h]
202 | inter_x1 = max(box1[0], box2[0])
203 | inter_y1 = max(box1[1], box2[1])
204 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
205 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)
206 |
207 | if inter_x1 < inter_x2 and inter_y1 < inter_y2:
208 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
209 | else:
210 | inter = 0
211 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter
212 | return float(inter)/union
213 |
214 |
215 |
--------------------------------------------------------------------------------
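
Two tiny usage examples of the helpers above (run from the repository root, since importing dataset/utils.py also pulls in vqaTools and refTools at module load):

```python
from dataset.utils import pre_caption, computeIoU

# pre_caption lower-cases, strips punctuation, collapses whitespace and truncates to max_words.
print(pre_caption("A man riding a horse -- on the beach!", max_words=5))
# -> 'a man riding a horse'

# computeIoU takes [x1, y1, w, h] boxes: two 10x10 boxes overlapping in a 5x5 region,
# so IoU = 25 / (100 + 100 - 25) = 0.142...
print(computeIoU([0, 0, 10, 10], [5, 5, 10, 10]))
```
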
/dataset/vqa_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from dataset.utils import pre_question
7 |
8 |
9 | class vqa_dataset(Dataset):
10 | def __init__(self, ann_file, transform, vqa_root, vg_root, eos='[SEP]', split="train", max_ques_words=30, answer_list=''):
11 | self.split = split
12 | self.ann = []
13 | for f in ann_file:
14 | self.ann += json.load(open(f,'r'))
15 |
16 | self.transform = transform
17 | self.vqa_root = vqa_root
18 | self.vg_root = vg_root
19 | self.max_ques_words = max_ques_words
20 | self.eos = eos
21 |
22 | if split=='test':
23 | self.max_ques_words = 50 # do not limit question length during test
24 | self.answer_list = json.load(open(answer_list,'r'))
25 |
26 |
27 | def __len__(self):
28 | return len(self.ann)
29 |
30 | def __getitem__(self, index):
31 |
32 | ann = self.ann[index]
33 |
34 | if ann['dataset']=='vqa':
35 | image_path = os.path.join(self.vqa_root,ann['image'])
36 | elif ann['dataset']=='vg':
37 | image_path = os.path.join(self.vg_root,ann['image'])
38 |
39 | image = Image.open(image_path).convert('RGB')
40 | image = self.transform(image)
41 |
42 | if self.split == 'test':
43 | question = pre_question(ann['question'],self.max_ques_words)
44 | question_id = ann['question_id']
45 | return image, question, question_id
46 |
47 |
48 | elif self.split=='train':
49 |
50 | question = pre_question(ann['question'],self.max_ques_words)
51 |
52 | if ann['dataset']=='vqa':
53 |
54 | answer_weight = {}
55 | for answer in ann['answer']:
56 | if answer in answer_weight.keys():
57 | answer_weight[answer] += 1/len(ann['answer'])
58 | else:
59 | answer_weight[answer] = 1/len(ann['answer'])
60 |
61 | answers = list(answer_weight.keys())
62 | weights = list(answer_weight.values())
63 |
64 | elif ann['dataset']=='vg':
65 | answers = [ann['answer']]
66 | weights = [0.5]
67 |
68 | answers = [answer+self.eos for answer in answers]
69 |
70 | return image, question, answers, weights
--------------------------------------------------------------------------------
/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/img.png
--------------------------------------------------------------------------------
/models/GRIT_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import numpy as np
5 |
6 | class GRIT():
7 | def __init__(self,
8 | config,device,num_steps):
9 |
10 | batch_size=config['batch_size']
11 | embed_dim = config['embed_dim']
12 |
13 | self.search_space=config['search_space']
14 | total_sample_len= num_steps *batch_size
15 | self.G_index_set= list(range(total_sample_len))
16 | self.G_idx=0
17 | self.queue_size= config['queue_size']
18 | self.num_small_queue = int(config['queue_size']/self.search_space)
19 |
20 | # For simplicity
21 | assert self.queue_size % self.search_space == 0
22 | assert self.queue_size % batch_size == 0
23 | assert self.search_space % batch_size ==0
24 |
25 | self.image_queue= torch.randn(embed_dim, self.queue_size).to(device)
26 | self.text_queue= torch.randn(embed_dim, self.queue_size).to(device)
27 | self.idx_queue= torch.randn(self.queue_size).to(device)
28 | self.queue_ptr= torch.zeros(1, dtype=torch.long)
29 |
30 | @torch.no_grad()
31 | def grit_second_third_phase(self,cur_step,temp,num_steps):
32 | if self.queue_ptr[0] == 0:
33 | self.example_level_shuffle()
34 |
35 | # Divide / Grouping
36 | for q in range(self.num_small_queue):
37 | # Grouping
38 | index_m = self.grouping(self.image_queue[:,q*self.search_space:(q+1)*self.search_space ],self.text_queue[:,q*self.search_space:(q+1)*self.search_space ],self.idx_queue[q*self.search_space:(q+1)*self.search_space ],temp)
39 |
40 | # fill in G
41 | self.G_index_set[(self.G_idx)*self.search_space:(self.G_idx+1)*self.search_space] = index_m
42 | self.G_idx+=1
43 |
44 | elif cur_step== num_steps-1:
45 | # Example level Shuffle
46 | self.remaining_queue()
47 | self.example_level_shuffle()
48 | slice_queue= int (int (self.queue_ptr[0]) // self.search_space )
49 |
50 | if slice_queue >0:
51 | for q in range(slice_queue):
52 | index_m = self.grouping(self.image_queue[:,q*self.search_space:(q+1)*self.search_space],self.text_queue[:,q*self.search_space:(q+1)*self.search_space],self.idx_queue[q*self.search_space:(q+1)*self.search_space],temp)
53 |
54 | self.G_index_set[(self.G_idx)*self.search_space:(self.G_idx+1)*self.search_space] = index_m
55 | self.G_idx+=1
56 |
57 | # Remaining indices
58 | index_m = self.grouping(self.image_queue[:,slice_queue*self.search_space:self.queue_ptr[0]],self.text_queue[:,slice_queue*self.search_space:self.queue_ptr[0]],self.idx_queue[slice_queue*self.search_space:self.queue_ptr[0]],temp)
59 | self.G_index_set[(self.G_idx)*self.search_space:] = index_m
60 | self.G_idx+=1
61 |
62 |
63 |
64 | @torch.no_grad()
65 | def grouping(self,image_sub_queue,text_sub_queue,index_sub_queue,temp):
66 | sim_i2t_sg = F.softmax(image_sub_queue.detach().t() @ text_sub_queue.detach() / temp,dim=1)
67 | sim_i2t_sg.fill_diagonal_(0)
68 | sim_t2i_sg = F.softmax(text_sub_queue.detach().t() @ image_sub_queue.detach() / temp,dim=1)
69 | sim_t2i_sg.fill_diagonal_(0)
70 |
71 | bs= image_sub_queue.size()[1]
72 | I_index_set=[]
73 | start = torch.randint(low=0, high=int(bs-1),size=(1,))[0]
74 | start = start.to(image_sub_queue.device)
75 | next_i_idx=start
76 | I_index_set.append(index_sub_queue[start].to(torch.long).detach().cpu())
77 |
78 | group_iter=int((bs-1)//2)
79 |
80 | for group_idx in range(group_iter):
81 | next_t= torch.topk(sim_i2t_sg[next_i_idx],1)
82 | next_t_idx=next_t.indices[0]
83 | sim_i2t_sg[next_i_idx,:]=0
84 | sim_i2t_sg[:,next_i_idx]=0
85 | sim_t2i_sg[next_i_idx,:]=0
86 | sim_t2i_sg[:,next_i_idx]=0
87 | next_i = torch.topk(sim_t2i_sg[next_t_idx],1)
88 | next_i_idx=next_i.indices[0]
89 |
90 | sim_i2t_sg[next_t_idx,:]=0
91 | sim_i2t_sg[:,next_t_idx]=0
92 | sim_t2i_sg[next_t_idx,:]=0
93 | sim_t2i_sg[:,next_t_idx]=0
94 | I_index_set.append(index_sub_queue[next_t_idx].to(torch.long).detach().cpu())
95 | I_index_set.append(index_sub_queue[next_i_idx].to(torch.long).detach().cpu())
96 |
97 | if int((bs-1)%2) !=0:
98 | next_t_idx= torch.argmax(sim_i2t_sg[next_i_idx,:])
99 | I_index_set.append(index_sub_queue[next_t_idx].to(torch.long).detach().cpu())
100 |
101 | return I_index_set
102 |
103 |
104 | @torch.no_grad()
105 | def example_level_shuffle(self):
106 | shuffle_idx = torch.randperm(self.image_queue.shape[1])
107 | self.image_queue=self.image_queue[:,shuffle_idx].view(self.image_queue.size())
108 | self.text_queue=self.text_queue[:,shuffle_idx].view(self.text_queue.size())
109 | self.idx_queue=self.idx_queue[shuffle_idx].view(self.idx_queue.size())
110 |
111 | @torch.no_grad()
112 | def remaining_queue (self):
113 | self.image_queue = self.image_queue[:,:self.queue_ptr[0]]
114 | self.text_queue = self.text_queue[:,:self.queue_ptr[0]]
115 | self.idx_queue = self.idx_queue[:self.queue_ptr[0]]
116 |
117 |
118 |
119 | @torch.no_grad()
120 | def collecting(self, image_feat, text_feat,idx):
121 | batch_size = image_feat.shape[0]
122 | ptr = int(self.queue_ptr)
123 |
124 | # replace the keys at ptr (dequeue and enqueue)
125 | self.image_queue[:, ptr:ptr + batch_size] = image_feat.T
126 | self.text_queue[:, ptr:ptr + batch_size] = text_feat.T
127 | self.idx_queue[ ptr:ptr + batch_size] = idx.detach()
128 | ptr = (ptr + batch_size) % self.queue_size # move pointer
129 |
130 | self.queue_ptr[0] = ptr
131 |
132 |
133 | @torch.no_grad()
134 | def chunks(lst, n):
135 | """Yield successive n-sized chunks from lst."""
136 | for i in range(0, len(lst), n):
137 | yield lst[i:i + n]
138 |
139 |
140 | @torch.no_grad()
141 | def mini_batch_level_shuffle(index_set, batch_size):
142 | divided_G_index_set = list(chunks(index_set,batch_size))
143 | total_chunk_size = len(divided_G_index_set)
144 | chunk_arr = np.arange(total_chunk_size)
145 | random.shuffle(chunk_arr)
146 | shuffled_G_index_set=[]
147 | for ind in chunk_arr:
148 | shuffled_G_index_set += divided_G_index_set[ind]
149 |
150 | return shuffled_G_index_set
151 |
152 |
--------------------------------------------------------------------------------
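
The GRIT helper above only maintains the feature queues and produces a grouped index order; the actual wiring into the pre-training loop lives in Pretrain.py / model_pretrain_GRIT_VLP.py, which are not shown here. A toy, hypothetical sketch of how the pieces compose over one epoch (random stand-in features, CPU, small sizes chosen so the queue-size asserts hold):

```python
import torch
import torch.nn.functional as F
from models.GRIT_utils import GRIT, mini_batch_level_shuffle

config = {'batch_size': 8, 'embed_dim': 4, 'search_space': 16, 'queue_size': 32}
num_steps = 9                                     # steps in one (toy) epoch
grit = GRIT(config, device='cpu', num_steps=num_steps)

for step in range(num_steps):
    # Stand-ins for the normalized projection features and dataset indices of a mini-batch.
    image_feat = F.normalize(torch.randn(8, 4), dim=-1)
    text_feat = F.normalize(torch.randn(8, 4), dim=-1)
    idx = torch.arange(step * 8, (step + 1) * 8)

    grit.collecting(image_feat, text_feat, idx)                         # enqueue current batch
    grit.grit_second_third_phase(step, temp=0.07, num_steps=num_steps)  # group when the queue wraps / at epoch end

# Shuffle at the mini-batch level; the result can seed Sampler_for_GRIT for the next epoch.
next_epoch_indices = mini_batch_level_shuffle(grit.G_index_set, config['batch_size'])
print(len(next_epoch_indices))                    # 72 == num_steps * batch_size
```
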
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/models/__init__.py
--------------------------------------------------------------------------------
/models/model_nlvr.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from models.vit import VisionTransformer
3 | from models.xbert import BertConfig, BertModel
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | class ALBEF(nn.Module):
10 | def __init__(self,
11 | text_encoder = None,
12 | tokenizer = None,
13 | config = None,
14 | ):
15 | super().__init__()
16 |
17 | self.tokenizer = tokenizer
18 | self.distill = config['distill']
19 |
20 | self.visual_encoder = VisionTransformer(
21 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
22 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
23 |
24 | bert_config = BertConfig.from_json_file(config['bert_config'])
25 | bert_config.num_hidden_layers = 18
26 |
27 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
28 | self.cls_head = nn.Sequential(
29 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
30 | nn.ReLU(),
31 | nn.Linear(self.text_encoder.config.hidden_size, 2)
32 | )
33 |
34 | self.share_cross_attention(self.text_encoder.encoder)
35 |
36 | if self.distill:
37 | self.visual_encoder_m = VisionTransformer(
38 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
39 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
40 | self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
41 | self.share_cross_attention(self.text_encoder_m.encoder)
42 |
43 | self.cls_head_m = nn.Sequential(
44 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
45 | nn.ReLU(),
46 | nn.Linear(self.text_encoder.config.hidden_size, 2)
47 | )
48 |
49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
50 | [self.text_encoder,self.text_encoder_m],
51 | [self.cls_head,self.cls_head_m],
52 | ]
53 | self.copy_params()
54 | self.momentum = 0.995
55 |
56 |
57 | def forward(self, image, text, targets, alpha=0, train=True):
58 |
59 | image_embeds = self.visual_encoder(image)
60 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
61 |
62 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
63 |
64 | output = self.text_encoder(text.input_ids,
65 | attention_mask = text.attention_mask,
66 | encoder_hidden_states = [image0_embeds,image1_embeds],
67 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
68 | image_atts[image0_embeds.size(0):]],
69 | return_dict = True,
70 | )
71 | hidden_state = output.last_hidden_state[:,0,:]
72 | prediction = self.cls_head(hidden_state)
73 |
74 | if train:
75 | if self.distill:
76 | with torch.no_grad():
77 | self._momentum_update()
78 | image_embeds_m = self.visual_encoder_m(image)
79 | image0_embeds_m, image1_embeds_m = torch.split(image_embeds_m,targets.size(0))
80 | output_m = self.text_encoder_m(text.input_ids,
81 | attention_mask = text.attention_mask,
82 | encoder_hidden_states = [image0_embeds_m,image1_embeds_m],
83 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
84 | image_atts[image0_embeds.size(0):]],
85 | return_dict = True,
86 | )
87 | prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:])
88 |
89 | loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum(
90 | F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean()
91 | else:
92 | loss = F.cross_entropy(prediction, targets)
93 | return loss
94 | else:
95 | return prediction
96 |
97 |
98 |
99 | @torch.no_grad()
100 | def copy_params(self):
101 | for model_pair in self.model_pairs:
102 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
103 | param_m.data.copy_(param.data) # initialize
104 | param_m.requires_grad = False # not update by gradient
105 |
106 |
107 | @torch.no_grad()
108 | def _momentum_update(self):
109 | for model_pair in self.model_pairs:
110 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
111 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
112 |
113 |
114 | def share_cross_attention(self, model):
115 |
116 | for i in range(6):
117 | layer_num = 6+i*2
118 | modules_0 = model.layer[layer_num].crossattention.self._modules
119 | modules_1 = model.layer[layer_num+1].crossattention.self._modules
120 |
121 | for name in modules_0.keys():
122 | if 'key' in name or 'value' in name:
123 | module_0 = modules_0[name]
124 | module_1 = modules_1[name]
125 | if hasattr(module_0, "weight"):
126 | module_0.weight = module_1.weight
127 | if hasattr(module_0, "bias"):
128 | module_0.bias = module_1.bias
129 |
--------------------------------------------------------------------------------
/models/model_pretrain_nlvr.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from models.vit import VisionTransformer
3 | from models.xbert import BertConfig, BertModel
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | class ALBEF(nn.Module):
10 | def __init__(self,
11 | text_encoder = None,
12 | tokenizer = None,
13 | config = None,
14 | ):
15 | super().__init__()
16 |
17 | self.tokenizer = tokenizer
18 | vision_width = config['vision_width']
19 | embed_dim = config['embed_dim']
20 |
21 | self.visual_encoder = VisionTransformer(
22 | img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
23 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
24 |
25 | bert_config = BertConfig.from_json_file(config['bert_config'])
26 | bert_config.num_hidden_layers = 18
27 | self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
28 |
29 | #share the cross-attention layers for two images
30 | self.share_cross_attention(self.text_encoder.encoder)
31 |
32 | text_width = self.text_encoder.config.hidden_size
33 | self.vision_proj = nn.Linear(vision_width, embed_dim)
34 | self.text_proj = nn.Linear(text_width, embed_dim)
35 | self.temp = nn.Parameter(torch.ones([]) * 0.07)
36 | self.ta_head = nn.Linear(self.text_encoder.config.hidden_size, 3)
37 |
38 |
39 | def forward(self, image, text):
40 | image_embeds = self.visual_encoder(image)
41 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
42 | with torch.no_grad():
43 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
44 | sim = image_feat @ image_feat.t() / 0.07
45 | weights = F.softmax(sim,dim=1)
46 | weights.fill_diagonal_(0)
47 |
48 | image_inputs = [[],[]]
49 | labels = []
50 | for b in range(image.size(0)):
51 | if torch.rand(1)>1/3:
52 | idx = torch.multinomial(weights[b], 1).item()
53 | if torch.rand(1)>0.5:
54 | image_inputs[0].append(image_embeds[b])
55 | image_inputs[1].append(image_embeds[idx])
56 | labels.append(0)
57 | else:
58 | image_inputs[1].append(image_embeds[b])
59 | image_inputs[0].append(image_embeds[idx])
60 | labels.append(1)
61 | else:
62 | idx = torch.multinomial(weights[b], 2)
63 | image_inputs[0].append(image_embeds[idx[0]])
64 | image_inputs[1].append(image_embeds[idx[1]])
65 | labels.append(2)
66 |
67 | image_inputs[0] = torch.stack(image_inputs[0],dim=0)
68 | image_inputs[1] = torch.stack(image_inputs[1],dim=0)
69 | labels = torch.LongTensor(labels).to(image.device)
70 |
71 | output = self.text_encoder(text.input_ids,
72 | attention_mask = text.attention_mask,
73 | encoder_hidden_states = image_inputs,
74 | encoder_attention_mask = [image_atts,image_atts],
75 | return_dict = True,
76 | )
77 |
78 | pred = self.ta_head(output.last_hidden_state[:,0,:])
79 | loss = F.cross_entropy(pred, labels)
80 |
81 | return loss
82 |
83 |
84 |
85 | def share_cross_attention(self, model):
86 |
87 | for i in range(6):
88 | layer_num = 6+i*2
89 | modules_0 = model.layer[layer_num].crossattention.self._modules
90 | modules_1 = model.layer[layer_num+1].crossattention.self._modules
91 |
92 | for name in modules_0.keys():
93 | if 'key' in name or 'value' in name:
94 | module_0 = modules_0[name]
95 | module_1 = modules_1[name]
96 | if hasattr(module_0, "weight"):
97 | module_0.weight = module_1.weight
98 | if hasattr(module_0, "bias"):
99 | module_0.bias = module_1.bias
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adamp import AdamP
2 | from .adamw import AdamW
3 | from .adafactor import Adafactor
4 | from .adahessian import Adahessian
5 | from .lookahead import Lookahead
6 | from .nadam import Nadam
7 | from .novograd import NovoGrad
8 | from .nvnovograd import NvNovoGrad
9 | from .radam import RAdam
10 | from .rmsprop_tf import RMSpropTF
11 | from .sgdp import SGDP
12 |
13 | from .optim_factory import create_optimizer
--------------------------------------------------------------------------------
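
A minimal sketch of the factory entry point, mirroring how Pretrain_nlvr.py builds its optimizer from the `optimizer` block of a config (utils.AttrDict presumably just exposes the dict keys as attributes):

```python
import torch.nn as nn
import utils
from optim import create_optimizer

model = nn.Linear(768, 256)                       # stand-in for the full ALBEF/GRIT-VLP model
arg_opt = utils.AttrDict({'opt': 'adamW', 'lr': 1e-4, 'weight_decay': 0.02})  # same keys as the yaml configs
optimizer = create_optimizer(arg_opt, model)
print(optimizer)
```
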
/optim/adafactor.py:
--------------------------------------------------------------------------------
1 | """ Adafactor Optimizer
2 |
3 | Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
4 |
5 | Original header/copyright below.
6 |
7 | """
8 | # Copyright (c) Facebook, Inc. and its affiliates.
9 | #
10 | # This source code is licensed under the MIT license found in the
11 | # LICENSE file in the root directory of this source tree.
12 | import torch
13 | import math
14 |
15 |
16 | class Adafactor(torch.optim.Optimizer):
17 | """Implements Adafactor algorithm.
18 | This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
19 | (see https://arxiv.org/abs/1804.04235)
20 |
21 | Note that this optimizer internally adjusts the learning rate depending on the
22 | *scale_parameter* and *warmup_init* options and on whether an external *lr* is given.
23 |
24 | To use a manual (external) learning rate schedule, set `scale_parameter=False` and pass an
25 | explicit `lr`; the internal relative-step schedule is enabled automatically when `lr` is None.
26 |
27 | Arguments:
28 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups
29 | lr (float, optional): external learning rate; None enables the relative-step schedule (default: None)
30 | eps (float): regularization constant for the squared gradient (default: 1e-30)
31 | eps_scale (float): regularization constant for the parameter scale (default: 1e-3)
32 | clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
33 | decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
34 | betas (tuple, optional): only betas[0] is used, as the coefficient for the running average
35 | of the gradient; None disables the first moment (default: None)
36 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
37 | scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
38 | warmup_init (bool): if True, the time-dependent learning rate starts from a small warm-up
39 | value; only valid when `lr` is None (default: False)
40 |
41 | """
42 |
43 | def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
44 | decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
45 | relative_step = lr is None
46 | if warmup_init and not relative_step:
47 | raise ValueError('warmup_init requires relative_step=True')
48 |
49 | beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
50 | defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
51 | beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
52 | relative_step=relative_step, warmup_init=warmup_init)
53 | super(Adafactor, self).__init__(params, defaults)
54 |
55 | @staticmethod
56 | def _get_lr(param_group, param_state):
57 | if param_group['relative_step']:
58 | min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
59 | lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
60 | param_scale = 1.0
61 | if param_group['scale_parameter']:
62 | param_scale = max(param_group['eps_scale'], param_state['RMS'])
63 | param_group['lr'] = lr_t * param_scale
64 | return param_group['lr']
65 |
66 | @staticmethod
67 | def _get_options(param_group, param_shape):
68 | factored = len(param_shape) >= 2
69 | use_first_moment = param_group['beta1'] is not None
70 | return factored, use_first_moment
71 |
72 | @staticmethod
73 | def _rms(tensor):
74 | return tensor.norm(2) / (tensor.numel() ** 0.5)
75 |
76 | def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
77 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
78 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
79 | return torch.mul(r_factor, c_factor)
80 |
81 | def step(self, closure=None):
82 | """Performs a single optimization step.
83 | Arguments:
84 | closure (callable, optional): A closure that reevaluates the model and returns the loss.
85 | """
86 | loss = None
87 | if closure is not None:
88 | loss = closure()
89 |
90 | for group in self.param_groups:
91 | for p in group['params']:
92 | if p.grad is None:
93 | continue
94 | grad = p.grad.data
95 | if grad.dtype in {torch.float16, torch.bfloat16}:
96 | grad = grad.float()
97 | if grad.is_sparse:
98 | raise RuntimeError('Adafactor does not support sparse gradients.')
99 |
100 | state = self.state[p]
101 | grad_shape = grad.shape
102 |
103 | factored, use_first_moment = self._get_options(group, grad_shape)
104 | # State Initialization
105 | if len(state) == 0:
106 | state['step'] = 0
107 |
108 | if use_first_moment:
109 | # Exponential moving average of gradient values
110 | state['exp_avg'] = torch.zeros_like(grad)
111 | if factored:
112 | state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
113 | state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
114 | else:
115 | state['exp_avg_sq'] = torch.zeros_like(grad)
116 |
117 | state['RMS'] = 0
118 | else:
119 | if use_first_moment:
120 | state['exp_avg'] = state['exp_avg'].to(grad)
121 | if factored:
122 | state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
123 | state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
124 | else:
125 | state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
126 |
127 | p_data_fp32 = p.data
128 | if p.data.dtype in {torch.float16, torch.bfloat16}:
129 | p_data_fp32 = p_data_fp32.float()
130 |
131 | state['step'] += 1
132 | state['RMS'] = self._rms(p_data_fp32)
133 | lr_t = self._get_lr(group, state)
134 |
135 | beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
136 | update = grad ** 2 + group['eps']
137 | if factored:
138 | exp_avg_sq_row = state['exp_avg_sq_row']
139 | exp_avg_sq_col = state['exp_avg_sq_col']
140 |
141 | exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
142 | exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
143 | #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
144 | #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
145 |
146 | # Approximation of exponential moving average of square of gradient
147 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
148 | update.mul_(grad)
149 | else:
150 | exp_avg_sq = state['exp_avg_sq']
151 |
152 | exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
153 | #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
154 | update = exp_avg_sq.rsqrt().mul_(grad)
155 |
156 | update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
157 | update.mul_(lr_t)
158 |
159 | if use_first_moment:
160 | exp_avg = state['exp_avg']
161 | exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
162 | #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
163 | update = exp_avg
164 |
165 | if group['weight_decay'] != 0:
166 | p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
167 | #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
168 |
169 | p_data_fp32.add_(-update)
170 |
171 | if p.data.dtype in {torch.float16, torch.bfloat16}:
172 | p.data.copy_(p_data_fp32)
173 |
174 | return loss
--------------------------------------------------------------------------------
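A short usage sketch for the class above (assuming the repository root is on PYTHONPATH so that `from optim import Adafactor` resolves through /optim/__init__.py): with `lr=None` the optimizer runs in relative-step mode and derives its own learning rate; passing a float `lr` together with `scale_parameter=False` gives a conventional, externally scheduled optimizer.

import torch
from optim import Adafactor

model = torch.nn.Linear(10, 10)

# relative-step mode: no external learning rate, lr is computed per step
opt = Adafactor(model.parameters(), lr=None, scale_parameter=True, warmup_init=True)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()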
/optim/adahessian.py:
--------------------------------------------------------------------------------
1 | """ AdaHessian Optimizer
2 |
3 | Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
4 | Originally licensed MIT, Copyright 2020, David Samuel
5 | """
6 | import torch
7 |
8 |
9 | class Adahessian(torch.optim.Optimizer):
10 | """
11 | Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"
12 |
13 | Arguments:
14 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15 | lr (float, optional): learning rate (default: 0.1)
16 | betas ((float, float), optional): coefficients used for computing running averages of gradient and the
17 | squared hessian trace (default: (0.9, 0.999))
18 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
20 | hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
21 | update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
22 | (to save time) (default: 1)
23 | n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
24 | """
25 |
26 | def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
27 | hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
28 | if not 0.0 <= lr:
29 | raise ValueError(f"Invalid learning rate: {lr}")
30 | if not 0.0 <= eps:
31 | raise ValueError(f"Invalid epsilon value: {eps}")
32 | if not 0.0 <= betas[0] < 1.0:
33 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
34 | if not 0.0 <= betas[1] < 1.0:
35 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
36 | if not 0.0 <= hessian_power <= 1.0:
37 | raise ValueError(f"Invalid Hessian power value: {hessian_power}")
38 |
39 | self.n_samples = n_samples
40 | self.update_each = update_each
41 | self.avg_conv_kernel = avg_conv_kernel
42 |
43 | # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
44 | self.seed = 2147483647
45 | self.generator = torch.Generator().manual_seed(self.seed)
46 |
47 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
48 | super(Adahessian, self).__init__(params, defaults)
49 |
50 | for p in self.get_params():
51 | p.hess = 0.0
52 | self.state[p]["hessian step"] = 0
53 |
54 | @property
55 | def is_second_order(self):
56 | return True
57 |
58 | def get_params(self):
59 | """
60 | Gets all parameters in all param_groups with gradients
61 | """
62 |
63 | return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
64 |
65 | def zero_hessian(self):
66 | """
67 | Zeros out the accumulated hessian traces.
68 | """
69 |
70 | for p in self.get_params():
71 | if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
72 | p.hess.zero_()
73 |
74 | @torch.no_grad()
75 | def set_hessian(self):
76 | """
77 | Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
78 | """
79 |
80 | params = []
81 | for p in filter(lambda p: p.grad is not None, self.get_params()):
82 | if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
83 | params.append(p)
84 | self.state[p]["hessian step"] += 1
85 |
86 | if len(params) == 0:
87 | return
88 |
89 | if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
90 | self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
91 |
92 | grads = [p.grad for p in params]
93 |
94 | for i in range(self.n_samples):
95 | # Rademacher distribution {-1.0, 1.0}
96 | zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
97 | h_zs = torch.autograd.grad(
98 | grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
99 | for h_z, z, p in zip(h_zs, zs, params):
100 | p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
101 |
102 | @torch.no_grad()
103 | def step(self, closure=None):
104 | """
105 | Performs a single optimization step.
106 | Arguments:
107 | closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
108 | """
109 |
110 | loss = None
111 | if closure is not None:
112 | loss = closure()
113 |
114 | self.zero_hessian()
115 | self.set_hessian()
116 |
117 | for group in self.param_groups:
118 | for p in group['params']:
119 | if p.grad is None or p.hess is None:
120 | continue
121 |
122 | if self.avg_conv_kernel and p.dim() == 4:
123 | p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
124 |
125 | # Perform correct stepweight decay as in AdamW
126 | p.mul_(1 - group['lr'] * group['weight_decay'])
127 |
128 | state = self.state[p]
129 |
130 | # State initialization ('hessian step' is already present from __init__, hence len(state) == 1)
131 | if len(state) == 1:
132 | state['step'] = 0
133 | # Exponential moving average of gradient values
134 | state['exp_avg'] = torch.zeros_like(p)
135 | # Exponential moving average of Hessian diagonal square values
136 | state['exp_hessian_diag_sq'] = torch.zeros_like(p)
137 |
138 | exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
139 | beta1, beta2 = group['betas']
140 | state['step'] += 1
141 |
142 | # Decay the first and second moment running average coefficient
143 | exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
144 | exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
145 |
146 | bias_correction1 = 1 - beta1 ** state['step']
147 | bias_correction2 = 1 - beta2 ** state['step']
148 |
149 | k = group['hessian_power']
150 | denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
151 |
152 | # make update
153 | step_size = group['lr'] / bias_correction1
154 | p.addcdiv_(exp_avg, denom, value=-step_size)
155 |
156 | return loss
157 |
--------------------------------------------------------------------------------
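Because set_hessian() calls torch.autograd.grad on the existing gradients, the backward pass that produces them must keep its graph alive. A minimal usage sketch (again assuming `from optim import Adahessian` resolves through /optim/__init__.py):

import torch
from optim import Adahessian

model = torch.nn.Linear(10, 1)
opt = Adahessian(model.parameters(), lr=0.1, hessian_power=1.0)

loss = model(torch.randn(8, 10)).pow(2).mean()
# create_graph=True is required: the Hutchinson probe differentiates the gradients a second time
loss.backward(create_graph=True)
opt.step()
opt.zero_grad()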
/optim/adamp.py:
--------------------------------------------------------------------------------
1 | """
2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
16 | class AdamP(Optimizer):
17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
21 | super(AdamP, self).__init__(params, defaults)
22 |
23 | def _channel_view(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | def _layer_view(self, x):
27 | return x.view(1, -1)
28 |
29 | def _cosine_similarity(self, x, y, eps, view_func):
30 | x = view_func(x)
31 | y = view_func(y)
32 |
33 | x_norm = x.norm(dim=1).add_(eps)
34 | y_norm = y.norm(dim=1).add_(eps)
35 | dot = (x * y).sum(dim=1)
36 |
37 | return dot.abs() / x_norm / y_norm
38 |
39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
40 | wd = 1
41 | expand_size = [-1] + [1] * (len(p.shape) - 1)
42 | for view_func in [self._channel_view, self._layer_view]:
43 |
44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
45 |
46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
49 | wd = wd_ratio
50 |
51 | return perturb, wd
52 |
53 | return perturb, wd
54 |
55 | def step(self, closure=None):
56 | loss = None
57 | if closure is not None:
58 | loss = closure()
59 |
60 | for group in self.param_groups:
61 | for p in group['params']:
62 | if p.grad is None:
63 | continue
64 |
65 | grad = p.grad.data
66 | beta1, beta2 = group['betas']
67 | nesterov = group['nesterov']
68 |
69 | state = self.state[p]
70 |
71 | # State initialization
72 | if len(state) == 0:
73 | state['step'] = 0
74 | state['exp_avg'] = torch.zeros_like(p.data)
75 | state['exp_avg_sq'] = torch.zeros_like(p.data)
76 |
77 | # Adam
78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
79 |
80 | state['step'] += 1
81 | bias_correction1 = 1 - beta1 ** state['step']
82 | bias_correction2 = 1 - beta2 ** state['step']
83 |
84 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
86 |
87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
88 | step_size = group['lr'] / bias_correction1
89 |
90 | if nesterov:
91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
92 | else:
93 | perturb = exp_avg / denom
94 |
95 | # Projection
96 | wd_ratio = 1
97 | if len(p.shape) > 1:
98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
99 |
100 | # Weight decay
101 | if group['weight_decay'] > 0:
102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)
103 |
104 | # Step
105 | p.data.add_(-step_size, perturb)
106 |
107 | return loss
108 |
--------------------------------------------------------------------------------
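The projection above fires when the gradient is nearly orthogonal to the weight vector (the scale-invariant case), dropping the radial component of the update and shrinking the effective weight decay by `wd_ratio`. A usage sketch under the same import assumption as above:

import torch
from optim import AdamP

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.BatchNorm1d(10))
opt = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2,
            delta=0.1, wd_ratio=0.1, nesterov=True)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()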
/optim/adamw.py:
--------------------------------------------------------------------------------
1 | """ AdamW Optimizer
2 | Impl copied from PyTorch master
3 | """
4 | import math
5 | import torch
6 | from torch.optim.optimizer import Optimizer
7 |
8 |
9 | class AdamW(Optimizer):
10 | r"""Implements AdamW algorithm.
11 |
12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
14 |
15 | Arguments:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float, optional): learning rate (default: 1e-3)
19 | betas (Tuple[float, float], optional): coefficients used for computing
20 | running averages of gradient and its square (default: (0.9, 0.999))
21 | eps (float, optional): term added to the denominator to improve
22 | numerical stability (default: 1e-8)
23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2)
24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25 | algorithm from the paper `On the Convergence of Adam and Beyond`_
26 | (default: False)
27 |
28 | .. _Adam\: A Method for Stochastic Optimization:
29 | https://arxiv.org/abs/1412.6980
30 | .. _Decoupled Weight Decay Regularization:
31 | https://arxiv.org/abs/1711.05101
32 | .. _On the Convergence of Adam and Beyond:
33 | https://openreview.net/forum?id=ryQu7f-RZ
34 | """
35 |
36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
37 | weight_decay=1e-2, amsgrad=False):
38 | if not 0.0 <= lr:
39 | raise ValueError("Invalid learning rate: {}".format(lr))
40 | if not 0.0 <= eps:
41 | raise ValueError("Invalid epsilon value: {}".format(eps))
42 | if not 0.0 <= betas[0] < 1.0:
43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
44 | if not 0.0 <= betas[1] < 1.0:
45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
46 | defaults = dict(lr=lr, betas=betas, eps=eps,
47 | weight_decay=weight_decay, amsgrad=amsgrad)
48 | super(AdamW, self).__init__(params, defaults)
49 |
50 | def __setstate__(self, state):
51 | super(AdamW, self).__setstate__(state)
52 | for group in self.param_groups:
53 | group.setdefault('amsgrad', False)
54 |
55 | def step(self, closure=None):
56 | """Performs a single optimization step.
57 |
58 | Arguments:
59 | closure (callable, optional): A closure that reevaluates the model
60 | and returns the loss.
61 | """
62 | loss = None
63 | if closure is not None:
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | for p in group['params']:
68 | if p.grad is None:
69 | continue
70 |
71 | # Perform stepweight decay
72 | p.data.mul_(1 - group['lr'] * group['weight_decay'])
73 |
74 | # Perform optimization step
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
78 | amsgrad = group['amsgrad']
79 |
80 | state = self.state[p]
81 |
82 | # State initialization
83 | if len(state) == 0:
84 | state['step'] = 0
85 | # Exponential moving average of gradient values
86 | state['exp_avg'] = torch.zeros_like(p.data)
87 | # Exponential moving average of squared gradient values
88 | state['exp_avg_sq'] = torch.zeros_like(p.data)
89 | if amsgrad:
90 | # Maintains max of all exp. moving avg. of sq. grad. values
91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
92 |
93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
94 | if amsgrad:
95 | max_exp_avg_sq = state['max_exp_avg_sq']
96 | beta1, beta2 = group['betas']
97 |
98 | state['step'] += 1
99 | bias_correction1 = 1 - beta1 ** state['step']
100 | bias_correction2 = 1 - beta2 ** state['step']
101 |
102 | # Decay the first and second moment running average coefficient
103 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
105 | if amsgrad:
106 | # Maintains the maximum of all 2nd moment running avg. till now
107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
108 | # Use the max. for normalizing running avg. of gradient
109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
110 | else:
111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
112 |
113 | step_size = group['lr'] / bias_correction1
114 |
115 | p.data.addcdiv_(-step_size, exp_avg, denom)
116 |
117 | return loss
118 |
--------------------------------------------------------------------------------
/optim/lookahead.py:
--------------------------------------------------------------------------------
1 | """ Lookahead Optimizer Wrapper.
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | from collections import defaultdict
10 |
11 |
12 | class Lookahead(Optimizer):
13 | def __init__(self, base_optimizer, alpha=0.5, k=6):
14 | if not 0.0 <= alpha <= 1.0:
15 | raise ValueError(f'Invalid slow update rate: {alpha}')
16 | if not 1 <= k:
17 | raise ValueError(f'Invalid lookahead steps: {k}')
18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
19 | self.base_optimizer = base_optimizer
20 | self.param_groups = self.base_optimizer.param_groups
21 | self.defaults = base_optimizer.defaults
22 | self.defaults.update(defaults)
23 | self.state = defaultdict(dict)
24 | # manually add our defaults to the param groups
25 | for name, default in defaults.items():
26 | for group in self.param_groups:
27 | group.setdefault(name, default)
28 |
29 | def update_slow(self, group):
30 | for fast_p in group["params"]:
31 | if fast_p.grad is None:
32 | continue
33 | param_state = self.state[fast_p]
34 | if 'slow_buffer' not in param_state:
35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data)
36 | param_state['slow_buffer'].copy_(fast_p.data)
37 | slow = param_state['slow_buffer']
38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow)
39 | fast_p.data.copy_(slow)
40 |
41 | def sync_lookahead(self):
42 | for group in self.param_groups:
43 | self.update_slow(group)
44 |
45 | def step(self, closure=None):
46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
47 | loss = self.base_optimizer.step(closure)
48 | for group in self.param_groups:
49 | group['lookahead_step'] += 1
50 | if group['lookahead_step'] % group['lookahead_k'] == 0:
51 | self.update_slow(group)
52 | return loss
53 |
54 | def state_dict(self):
55 | fast_state_dict = self.base_optimizer.state_dict()
56 | slow_state = {
57 | (id(k) if isinstance(k, torch.Tensor) else k): v
58 | for k, v in self.state.items()
59 | }
60 | fast_state = fast_state_dict['state']
61 | param_groups = fast_state_dict['param_groups']
62 | return {
63 | 'state': fast_state,
64 | 'slow_state': slow_state,
65 | 'param_groups': param_groups,
66 | }
67 |
68 | def load_state_dict(self, state_dict):
69 | fast_state_dict = {
70 | 'state': state_dict['state'],
71 | 'param_groups': state_dict['param_groups'],
72 | }
73 | self.base_optimizer.load_state_dict(fast_state_dict)
74 |
75 | # We want to restore the slow state, but share param_groups reference
76 | # with base_optimizer. This is a bit redundant but least code
77 | slow_state_new = False
78 | if 'slow_state' not in state_dict:
79 | print('Loading state_dict from optimizer without Lookahead applied.')
80 | state_dict['slow_state'] = defaultdict(dict)
81 | slow_state_new = True
82 | slow_state_dict = {
83 | 'state': state_dict['slow_state'],
84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code
85 | }
86 | super(Lookahead, self).load_state_dict(slow_state_dict)
87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container
88 | if slow_state_new:
89 | # reapply defaults to catch missing lookahead specific ones
90 | for name, default in self.defaults.items():
91 | for group in self.param_groups:
92 | group.setdefault(name, default)
93 |
--------------------------------------------------------------------------------
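Lookahead is a wrapper, not a standalone optimizer: the inner optimizer takes `k` fast steps, then the slow weights move a fraction `alpha` of the way toward the fast weights. A minimal sketch (same import assumption as the earlier examples):

import torch
from optim import Lookahead

model = torch.nn.Linear(10, 10)
base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt = Lookahead(base, alpha=0.5, k=6)   # every 6th step pulls the slow weights along

for _ in range(12):
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    opt.step()
    base.zero_grad()   # clearing grads on the wrapped optimizer is sufficient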
/optim/nadam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Optimizer
3 |
4 |
5 | class Nadam(Optimizer):
6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
7 |
8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
9 |
10 | Arguments:
11 | params (iterable): iterable of parameters to optimize or dicts defining
12 | parameter groups
13 | lr (float, optional): learning rate (default: 2e-3)
14 | betas (Tuple[float, float], optional): coefficients used for computing
15 | running averages of gradient and its square
16 | eps (float, optional): term added to the denominator to improve
17 | numerical stability (default: 1e-8)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
20 |
21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf
22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
23 |
24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408
25 | NOTE: Has potential issues but does work well on some problems.
26 | """
27 |
28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
29 | weight_decay=0, schedule_decay=4e-3):
30 | defaults = dict(lr=lr, betas=betas, eps=eps,
31 | weight_decay=weight_decay, schedule_decay=schedule_decay)
32 | super(Nadam, self).__init__(params, defaults)
33 |
34 | def step(self, closure=None):
35 | """Performs a single optimization step.
36 |
37 | Arguments:
38 | closure (callable, optional): A closure that reevaluates the model
39 | and returns the loss.
40 | """
41 | loss = None
42 | if closure is not None:
43 | loss = closure()
44 |
45 | for group in self.param_groups:
46 | for p in group['params']:
47 | if p.grad is None:
48 | continue
49 | grad = p.grad.data
50 | state = self.state[p]
51 |
52 | # State initialization
53 | if len(state) == 0:
54 | state['step'] = 0
55 | state['m_schedule'] = 1.
56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_()
57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
58 |
59 | # Warming momentum schedule
60 | m_schedule = state['m_schedule']
61 | schedule_decay = group['schedule_decay']
62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
63 | beta1, beta2 = group['betas']
64 | eps = group['eps']
65 | state['step'] += 1
66 | t = state['step']
67 |
68 | if group['weight_decay'] != 0:
69 | grad = grad.add(group['weight_decay'], p.data)
70 |
71 | momentum_cache_t = beta1 * \
72 | (1. - 0.5 * (0.96 ** (t * schedule_decay)))
73 | momentum_cache_t_1 = beta1 * \
74 | (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
75 | m_schedule_new = m_schedule * momentum_cache_t
76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
77 | state['m_schedule'] = m_schedule_new
78 |
79 | # Decay the first and second moment running average coefficient
80 | exp_avg.mul_(beta1).add_(1. - beta1, grad)
81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad)
82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
83 | denom = exp_avg_sq_prime.sqrt_().add_(eps)
84 |
85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom)
86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom)
87 |
88 | return loss
89 |
--------------------------------------------------------------------------------
/optim/novograd.py:
--------------------------------------------------------------------------------
1 | """NovoGrad Optimizer.
2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
4 | - https://arxiv.org/abs/1905.11286
5 | """
6 |
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | import math
10 |
11 |
12 | class NovoGrad(Optimizer):
13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
15 | super(NovoGrad, self).__init__(params, defaults)
16 | self._lr = lr
17 | self._beta1 = betas[0]
18 | self._beta2 = betas[1]
19 | self._eps = eps
20 | self._wd = weight_decay
21 | self._grad_averaging = grad_averaging
22 |
23 | self._momentum_initialized = False
24 |
25 | def step(self, closure=None):
26 | loss = None
27 | if closure is not None:
28 | loss = closure()
29 |
30 | if not self._momentum_initialized:
31 | for group in self.param_groups:
32 | for p in group['params']:
33 | if p.grad is None:
34 | continue
35 | state = self.state[p]
36 | grad = p.grad.data
37 | if grad.is_sparse:
38 | raise RuntimeError('NovoGrad does not support sparse gradients')
39 |
40 | v = torch.norm(grad)**2
41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
42 | state['step'] = 0
43 | state['v'] = v
44 | state['m'] = m
45 | state['grad_ema'] = None
46 | self._momentum_initialized = True
47 |
48 | for group in self.param_groups:
49 | for p in group['params']:
50 | if p.grad is None:
51 | continue
52 | state = self.state[p]
53 | state['step'] += 1
54 |
55 | step, v, m = state['step'], state['v'], state['m']
56 | grad_ema = state['grad_ema']
57 |
58 | grad = p.grad.data
59 | g2 = torch.norm(grad)**2
60 | grad_ema = g2 if grad_ema is None else grad_ema * \
61 | self._beta2 + g2 * (1. - self._beta2)
62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
63 |
64 | if self._grad_averaging:
65 | grad *= (1. - self._beta1)
66 |
67 | g2 = torch.norm(grad)**2
68 | v = self._beta2*v + (1. - self._beta2)*g2
69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
70 | bias_correction1 = 1 - self._beta1 ** step
71 | bias_correction2 = 1 - self._beta2 ** step
72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
73 |
74 | state['v'], state['m'] = v, m
75 | state['grad_ema'] = grad_ema
76 | p.data.add_(-step_size, m)
77 | return loss
78 |
--------------------------------------------------------------------------------
/optim/nvnovograd.py:
--------------------------------------------------------------------------------
1 | """ Nvidia NovoGrad Optimizer.
2 | Original impl by Nvidia from Jasper example:
3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5 | - https://arxiv.org/abs/1905.11286
6 | """
7 |
8 | import torch
9 | from torch.optim.optimizer import Optimizer
10 | import math
11 |
12 |
13 | class NvNovoGrad(Optimizer):
14 | """
15 | Implements Novograd algorithm.
16 |
17 | Args:
18 | params (iterable): iterable of parameters to optimize or dicts defining
19 | parameter groups
20 | lr (float, optional): learning rate (default: 1e-3)
21 | betas (Tuple[float, float], optional): coefficients used for computing
22 | running averages of gradient and its square (default: (0.95, 0.98))
23 | eps (float, optional): term added to the denominator to improve
24 | numerical stability (default: 1e-8)
25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
26 | grad_averaging: gradient averaging
27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
28 | algorithm from the paper `On the Convergence of Adam and Beyond`_
29 | (default: False)
30 | """
31 |
32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
33 | weight_decay=0, grad_averaging=False, amsgrad=False):
34 | if not 0.0 <= lr:
35 | raise ValueError("Invalid learning rate: {}".format(lr))
36 | if not 0.0 <= eps:
37 | raise ValueError("Invalid epsilon value: {}".format(eps))
38 | if not 0.0 <= betas[0] < 1.0:
39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
40 | if not 0.0 <= betas[1] < 1.0:
41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
42 | defaults = dict(lr=lr, betas=betas, eps=eps,
43 | weight_decay=weight_decay,
44 | grad_averaging=grad_averaging,
45 | amsgrad=amsgrad)
46 |
47 | super(NvNovoGrad, self).__init__(params, defaults)
48 |
49 | def __setstate__(self, state):
50 | super(NvNovoGrad, self).__setstate__(state)
51 | for group in self.param_groups:
52 | group.setdefault('amsgrad', False)
53 |
54 | def step(self, closure=None):
55 | """Performs a single optimization step.
56 |
57 | Arguments:
58 | closure (callable, optional): A closure that reevaluates the model
59 | and returns the loss.
60 | """
61 | loss = None
62 | if closure is not None:
63 | loss = closure()
64 |
65 | for group in self.param_groups:
66 | for p in group['params']:
67 | if p.grad is None:
68 | continue
69 | grad = p.grad.data
70 | if grad.is_sparse:
71 | raise RuntimeError('Sparse gradients are not supported.')
72 | amsgrad = group['amsgrad']
73 |
74 | state = self.state[p]
75 |
76 | # State initialization
77 | if len(state) == 0:
78 | state['step'] = 0
79 | # Exponential moving average of gradient values
80 | state['exp_avg'] = torch.zeros_like(p.data)
81 | # Exponential moving average of squared gradient values
82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
83 | if amsgrad:
84 | # Maintains max of all exp. moving avg. of sq. grad. values
85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
86 |
87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
88 | if amsgrad:
89 | max_exp_avg_sq = state['max_exp_avg_sq']
90 | beta1, beta2 = group['betas']
91 |
92 | state['step'] += 1
93 |
94 | norm = torch.sum(torch.pow(grad, 2))
95 |
96 | if exp_avg_sq == 0:
97 | exp_avg_sq.copy_(norm)
98 | else:
99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
100 |
101 | if amsgrad:
102 | # Maintains the maximum of all 2nd moment running avg. till now
103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
104 | # Use the max. for normalizing running avg. of gradient
105 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
106 | else:
107 | denom = exp_avg_sq.sqrt().add_(group['eps'])
108 |
109 | grad.div_(denom)
110 | if group['weight_decay'] != 0:
111 | grad.add_(group['weight_decay'], p.data)
112 | if group['grad_averaging']:
113 | grad.mul_(1 - beta1)
114 | exp_avg.mul_(beta1).add_(grad)
115 |
116 | p.data.add_(-group['lr'], exp_avg)
117 |
118 | return loss
119 |
--------------------------------------------------------------------------------
/optim/optim_factory.py:
--------------------------------------------------------------------------------
1 | """ Optimizer Factory w/ Custom Weight Decay
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch
5 | from torch import optim as optim
6 |
7 | from .adafactor import Adafactor
8 | from .adahessian import Adahessian
9 | from .adamp import AdamP
10 | from .lookahead import Lookahead
11 | from .nadam import Nadam
12 | from .novograd import NovoGrad
13 | from .nvnovograd import NvNovoGrad
14 | from .radam import RAdam
15 | from .rmsprop_tf import RMSpropTF
16 | from .sgdp import SGDP
17 |
18 | try:
19 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
20 | has_apex = True
21 | except ImportError:
22 | has_apex = False
23 |
24 |
25 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
26 | decay = []
27 | no_decay = []
28 | for name, param in model.named_parameters():
29 | if not param.requires_grad:
30 | continue # frozen weights
31 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
32 | no_decay.append(param)
33 | else:
34 | decay.append(param)
35 | return [
36 | {'params': no_decay, 'weight_decay': 0.},
37 | {'params': decay, 'weight_decay': weight_decay}]
38 |
39 |
40 | def create_optimizer(args, model, filter_bias_and_bn=True):
41 | opt_lower = args.opt.lower()
42 | weight_decay = args.weight_decay
43 | if weight_decay and filter_bias_and_bn:
44 | skip = {}
45 | if hasattr(model, 'no_weight_decay'):
46 | skip = model.no_weight_decay()
47 | parameters = add_weight_decay(model, weight_decay, skip)
48 | weight_decay = 0.
49 | else:
50 | parameters = model.parameters()
51 |
52 | if 'fused' in opt_lower:
53 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
54 |
55 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
56 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
57 | opt_args['eps'] = args.opt_eps
58 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
59 | opt_args['betas'] = args.opt_betas
60 | if hasattr(args, 'opt_args') and args.opt_args is not None:
61 | opt_args.update(args.opt_args)
62 |
63 | opt_split = opt_lower.split('_')
64 | opt_lower = opt_split[-1]
65 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
66 | opt_args.pop('eps', None)
67 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
68 | elif opt_lower == 'momentum':
69 | opt_args.pop('eps', None)
70 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
71 | elif opt_lower == 'adam':
72 | optimizer = optim.Adam(parameters, **opt_args)
73 | elif opt_lower == 'adamw':
74 | optimizer = optim.AdamW(parameters, **opt_args)
75 | elif opt_lower == 'nadam':
76 | optimizer = Nadam(parameters, **opt_args)
77 | elif opt_lower == 'radam':
78 | optimizer = RAdam(parameters, **opt_args)
79 | elif opt_lower == 'adamp':
80 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
81 | elif opt_lower == 'sgdp':
82 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
83 | elif opt_lower == 'adadelta':
84 | optimizer = optim.Adadelta(parameters, **opt_args)
85 | elif opt_lower == 'adafactor':
86 | if not args.lr:
87 | opt_args['lr'] = None
88 | optimizer = Adafactor(parameters, **opt_args)
89 | elif opt_lower == 'adahessian':
90 | optimizer = Adahessian(parameters, **opt_args)
91 | elif opt_lower == 'rmsprop':
92 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
93 | elif opt_lower == 'rmsproptf':
94 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
95 | elif opt_lower == 'novograd':
96 | optimizer = NovoGrad(parameters, **opt_args)
97 | elif opt_lower == 'nvnovograd':
98 | optimizer = NvNovoGrad(parameters, **opt_args)
99 | elif opt_lower == 'fusedsgd':
100 | opt_args.pop('eps', None)
101 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
102 | elif opt_lower == 'fusedmomentum':
103 | opt_args.pop('eps', None)
104 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
105 | elif opt_lower == 'fusedadam':
106 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
107 | elif opt_lower == 'fusedadamw':
108 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
109 | elif opt_lower == 'fusedlamb':
110 | optimizer = FusedLAMB(parameters, **opt_args)
111 | elif opt_lower == 'fusednovograd':
112 | opt_args.setdefault('betas', (0.95, 0.98))
113 | optimizer = FusedNovoGrad(parameters, **opt_args)
114 | else:
115 | assert False, "Invalid optimizer"
116 | raise ValueError
117 |
118 | if len(opt_split) > 1:
119 | if opt_split[0] == 'lookahead':
120 | optimizer = Lookahead(optimizer)
121 |
122 | return optimizer
123 |
--------------------------------------------------------------------------------
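The factory only reads a handful of attributes from `args`, so any object exposing them works; a `lookahead_` prefix wraps the chosen optimizer in Lookahead. A sketch with illustrative hyper-parameter values (not taken from the repository's configs):

import torch
from types import SimpleNamespace
from optim import create_optimizer

model = torch.nn.Linear(10, 10)
args = SimpleNamespace(opt='lookahead_adamw', lr=1e-4, weight_decay=0.02,
                       momentum=0.9, opt_eps=1e-8, opt_betas=(0.9, 0.999))
optimizer = create_optimizer(args, model)   # AdamW with no-decay param groups, wrapped in Lookahead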
/optim/radam.py:
--------------------------------------------------------------------------------
1 | """RAdam Optimizer.
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
4 | """
5 | import math
6 | import torch
7 | from torch.optim.optimizer import Optimizer, required
8 |
9 |
10 | class RAdam(Optimizer):
11 |
12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
14 | self.buffer = [[None, None, None] for ind in range(10)]
15 | super(RAdam, self).__init__(params, defaults)
16 |
17 | def __setstate__(self, state):
18 | super(RAdam, self).__setstate__(state)
19 |
20 | def step(self, closure=None):
21 |
22 | loss = None
23 | if closure is not None:
24 | loss = closure()
25 |
26 | for group in self.param_groups:
27 |
28 | for p in group['params']:
29 | if p.grad is None:
30 | continue
31 | grad = p.grad.data.float()
32 | if grad.is_sparse:
33 | raise RuntimeError('RAdam does not support sparse gradients')
34 |
35 | p_data_fp32 = p.data.float()
36 |
37 | state = self.state[p]
38 |
39 | if len(state) == 0:
40 | state['step'] = 0
41 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
43 | else:
44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
46 |
47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
48 | beta1, beta2 = group['betas']
49 |
50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
51 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
52 |
53 | state['step'] += 1
54 | buffered = self.buffer[int(state['step'] % 10)]
55 | if state['step'] == buffered[0]:
56 | N_sma, step_size = buffered[1], buffered[2]
57 | else:
58 | buffered[0] = state['step']
59 | beta2_t = beta2 ** state['step']
60 | N_sma_max = 2 / (1 - beta2) - 1
61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
62 | buffered[1] = N_sma
63 |
64 | # more conservative since it's an approximated value
65 | if N_sma >= 5:
66 | step_size = group['lr'] * math.sqrt(
67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
68 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
69 | else:
70 | step_size = group['lr'] / (1 - beta1 ** state['step'])
71 | buffered[2] = step_size
72 |
73 | if group['weight_decay'] != 0:
74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
75 |
76 | # more conservative since it's an approximated value
77 | if N_sma >= 5:
78 | denom = exp_avg_sq.sqrt().add_(group['eps'])
79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
80 | else:
81 | p_data_fp32.add_(-step_size, exp_avg)
82 |
83 | p.data.copy_(p_data_fp32)
84 |
85 | return loss
86 |
87 |
88 | class PlainRAdam(Optimizer):
89 |
90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
92 |
93 | super(PlainRAdam, self).__init__(params, defaults)
94 |
95 | def __setstate__(self, state):
96 | super(PlainRAdam, self).__setstate__(state)
97 |
98 | def step(self, closure=None):
99 |
100 | loss = None
101 | if closure is not None:
102 | loss = closure()
103 |
104 | for group in self.param_groups:
105 |
106 | for p in group['params']:
107 | if p.grad is None:
108 | continue
109 | grad = p.grad.data.float()
110 | if grad.is_sparse:
111 | raise RuntimeError('RAdam does not support sparse gradients')
112 |
113 | p_data_fp32 = p.data.float()
114 |
115 | state = self.state[p]
116 |
117 | if len(state) == 0:
118 | state['step'] = 0
119 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
121 | else:
122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
124 |
125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
126 | beta1, beta2 = group['betas']
127 |
128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
129 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
130 |
131 | state['step'] += 1
132 | beta2_t = beta2 ** state['step']
133 | N_sma_max = 2 / (1 - beta2) - 1
134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
135 |
136 | if group['weight_decay'] != 0:
137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
138 |
139 | # more conservative since it's an approximated value
140 | if N_sma >= 5:
141 | step_size = group['lr'] * math.sqrt(
142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
143 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
144 | denom = exp_avg_sq.sqrt().add_(group['eps'])
145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
146 | else:
147 | step_size = group['lr'] / (1 - beta1 ** state['step'])
148 | p_data_fp32.add_(-step_size, exp_avg)
149 |
150 | p.data.copy_(p_data_fp32)
151 |
152 | return loss
153 |
--------------------------------------------------------------------------------
/optim/rmsprop_tf.py:
--------------------------------------------------------------------------------
1 | """ RMSProp modified to behave like Tensorflow impl
2 |
3 | Originally cut & paste from PyTorch RMSProp
4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
6 |
7 | Modifications Copyright 2020 Ross Wightman
8 | """
9 |
10 | import torch
11 | from torch.optim import Optimizer
12 |
13 |
14 | class RMSpropTF(Optimizer):
15 | """Implements RMSprop algorithm (TensorFlow style epsilon)
16 |
17 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
18 | and a few other modifications to closer match Tensorflow for matching hyper-params.
19 |
20 | Noteworthy changes include:
21 | 1. Epsilon applied inside square-root
22 | 2. square_avg initialized to ones
23 | 3. LR scaling of update accumulated in momentum buffer
24 |
25 | Proposed by G. Hinton in his
26 | `course <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
27 |
28 | The centered version first appears in `Generating Sequences
29 | With Recurrent Neural Networks <https://arxiv.org/abs/1308.0850>`_.
30 |
31 | Arguments:
32 | params (iterable): iterable of parameters to optimize or dicts defining
33 | parameter groups
34 | lr (float, optional): learning rate (default: 1e-2)
35 | momentum (float, optional): momentum factor (default: 0)
36 | alpha (float, optional): smoothing (decay) constant (default: 0.9)
37 | eps (float, optional): term added to the denominator to improve
38 | numerical stability (default: 1e-10)
39 | centered (bool, optional) : if ``True``, compute the centered RMSProp,
40 | the gradient is normalized by an estimation of its variance
41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
42 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
43 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
44 | update as per defaults in Tensorflow
45 |
46 | """
47 |
48 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
49 | decoupled_decay=False, lr_in_momentum=True):
50 | if not 0.0 <= lr:
51 | raise ValueError("Invalid learning rate: {}".format(lr))
52 | if not 0.0 <= eps:
53 | raise ValueError("Invalid epsilon value: {}".format(eps))
54 | if not 0.0 <= momentum:
55 | raise ValueError("Invalid momentum value: {}".format(momentum))
56 | if not 0.0 <= weight_decay:
57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
58 | if not 0.0 <= alpha:
59 | raise ValueError("Invalid alpha value: {}".format(alpha))
60 |
61 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
62 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
63 | super(RMSpropTF, self).__init__(params, defaults)
64 |
65 | def __setstate__(self, state):
66 | super(RMSpropTF, self).__setstate__(state)
67 | for group in self.param_groups:
68 | group.setdefault('momentum', 0)
69 | group.setdefault('centered', False)
70 |
71 | def step(self, closure=None):
72 | """Performs a single optimization step.
73 |
74 | Arguments:
75 | closure (callable, optional): A closure that reevaluates the model
76 | and returns the loss.
77 | """
78 | loss = None
79 | if closure is not None:
80 | loss = closure()
81 |
82 | for group in self.param_groups:
83 | for p in group['params']:
84 | if p.grad is None:
85 | continue
86 | grad = p.grad.data
87 | if grad.is_sparse:
88 | raise RuntimeError('RMSprop does not support sparse gradients')
89 | state = self.state[p]
90 |
91 | # State initialization
92 | if len(state) == 0:
93 | state['step'] = 0
94 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero
95 | if group['momentum'] > 0:
96 | state['momentum_buffer'] = torch.zeros_like(p.data)
97 | if group['centered']:
98 | state['grad_avg'] = torch.zeros_like(p.data)
99 |
100 | square_avg = state['square_avg']
101 | one_minus_alpha = 1. - group['alpha']
102 |
103 | state['step'] += 1
104 |
105 | if group['weight_decay'] != 0:
106 | if 'decoupled_decay' in group and group['decoupled_decay']:
107 | p.data.add_(-group['weight_decay'], p.data)
108 | else:
109 | grad = grad.add(group['weight_decay'], p.data)
110 |
111 | # Tensorflow order of ops for updating squared avg
112 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
113 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
114 |
115 | if group['centered']:
116 | grad_avg = state['grad_avg']
117 | grad_avg.add_(one_minus_alpha, grad - grad_avg)
118 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
119 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt
120 | else:
121 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt
122 |
123 | if group['momentum'] > 0:
124 | buf = state['momentum_buffer']
125 | # Tensorflow accumulates the LR scaling in the momentum buffer
126 | if 'lr_in_momentum' in group and group['lr_in_momentum']:
127 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
128 | p.data.add_(-buf)
129 | else:
130 | # PyTorch scales the param update by LR
131 | buf.mul_(group['momentum']).addcdiv_(grad, avg)
132 | p.data.add_(-group['lr'], buf)
133 | else:
134 | p.data.addcdiv_(-group['lr'], grad, avg)
135 |
136 | return loss
137 |
--------------------------------------------------------------------------------
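When porting hyper-parameters from a TensorFlow config, `decay` maps to `alpha` and `epsilon` maps to `eps` (applied inside the square root here, so comparatively large values such as 1e-3 are typical). A brief sketch under the same import assumption as the earlier examples:

import torch
from optim import RMSpropTF

model = torch.nn.Linear(10, 10)
# e.g. a TF-style config of decay=0.9, momentum=0.9, epsilon=1e-3
opt = RMSpropTF(model.parameters(), lr=0.016, alpha=0.9, momentum=0.9, eps=1e-3)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()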
/optim/sgdp.py:
--------------------------------------------------------------------------------
1 | """
2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
16 | class SGDP(Optimizer):
17 | def __init__(self, params, lr=required, momentum=0, dampening=0,
18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
21 | super(SGDP, self).__init__(params, defaults)
22 |
23 | def _channel_view(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | def _layer_view(self, x):
27 | return x.view(1, -1)
28 |
29 | def _cosine_similarity(self, x, y, eps, view_func):
30 | x = view_func(x)
31 | y = view_func(y)
32 |
33 | x_norm = x.norm(dim=1).add_(eps)
34 | y_norm = y.norm(dim=1).add_(eps)
35 | dot = (x * y).sum(dim=1)
36 |
37 | return dot.abs() / x_norm / y_norm
38 |
39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
40 | wd = 1
41 | expand_size = [-1] + [1] * (len(p.shape) - 1)
42 | for view_func in [self._channel_view, self._layer_view]:
43 |
44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
45 |
46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
49 | wd = wd_ratio
50 |
51 | return perturb, wd
52 |
53 | return perturb, wd
54 |
55 | def step(self, closure=None):
56 | loss = None
57 | if closure is not None:
58 | loss = closure()
59 |
60 | for group in self.param_groups:
61 | weight_decay = group['weight_decay']
62 | momentum = group['momentum']
63 | dampening = group['dampening']
64 | nesterov = group['nesterov']
65 |
66 | for p in group['params']:
67 | if p.grad is None:
68 | continue
69 | grad = p.grad.data
70 | state = self.state[p]
71 |
72 | # State initialization
73 | if len(state) == 0:
74 | state['momentum'] = torch.zeros_like(p.data)
75 |
76 | # SGD
77 | buf = state['momentum']
78 | buf.mul_(momentum).add_(1 - dampening, grad)
79 | if nesterov:
80 | d_p = grad + momentum * buf
81 | else:
82 | d_p = buf
83 |
84 | # Projection
85 | wd_ratio = 1
86 | if len(p.shape) > 1:
87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
88 |
89 | # Weight decay
90 | if weight_decay != 0:
91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
92 |
93 | # Step
94 | p.data.add_(-group['lr'], d_p)
95 |
96 | return loss
97 |
--------------------------------------------------------------------------------
/refTools/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'licheng'
2 |
3 |
4 |
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/refEvaluation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/__pycache__/refEvaluation.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/refEvaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/__pycache__/refEvaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/__pycache__/refEvaluation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/__pycache__/refEvaluation.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/bleu.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/bleu.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/bleu.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/__pycache__/bleu_scorer.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from refTools.evaluation.bleu.bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 | # default: compute BLEU score up to 4-grams
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | #score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | # return (bleu, bleu_info)
44 | return score, scores
45 |
46 | def method(self):
47 | return "Bleu"
48 |
--------------------------------------------------------------------------------
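Bleu.compute_score above expects two dicts keyed by the same ids: res maps each id to a single-element list holding the candidate sentence, and gts maps it to one or more reference sentences (normally already lower-cased and tokenized by the PTB tokenizer further below). A minimal sketch with made-up sentences:

    # Minimal sketch of the Bleu wrapper's interface; ids and sentences are illustrative.
    from refTools.evaluation.bleu.bleu import Bleu

    gts = {0: ["a man rides a horse", "a person on a horse"]}   # references
    res = {0: ["a man is riding a horse"]}                      # one candidate per id

    score, scores = Bleu(4).compute_score(gts, res)
    # score  -> [Bleu_1, Bleu_2, Bleu_3, Bleu_4] corpus-level values
    # scores -> the per-image values, one list per n-gram order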
/refTools/evaluation/bleu/bleu.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/bleu.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/bleu/bleu_scorer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/bleu/bleu_scorer.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/cider.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/cider.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/cider.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/__pycache__/cider_scorer.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from refTools.evaluation.cider.cider_scorer import CiderScorer
11 | import pdb
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 |
24 | def compute_score(self, gts, res):
25 | """
26 | Main function to compute CIDEr score
27 | :param gts (dict) : dictionary mapping each image id to a list of reference sentences
28 | res (dict) : dictionary mapping each image id to a single-element list with the candidate sentence
29 | :return: cider (float) : computed CIDEr score for the corpus
30 | """
31 |
32 | assert(gts.keys() == res.keys())
33 | imgIds = gts.keys()
34 |
35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
36 |
37 | for id in imgIds:
38 | hypo = res[id]
39 | ref = gts[id]
40 |
41 | # Sanity check.
42 | assert(type(hypo) is list)
43 | assert(len(hypo) == 1)
44 | assert(type(ref) is list)
45 | assert(len(ref) > 0)
46 |
47 | cider_scorer += (hypo[0], ref)
48 |
49 | (score, scores) = cider_scorer.compute_score()
50 |
51 | return score, scores
52 |
53 | def method(self):
54 | return "CIDEr"
--------------------------------------------------------------------------------
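Cider exposes the same compute_score(gts, res) interface as the Bleu wrapper above; the document frequencies are estimated from the references passed in, so it is meant to be run over a whole corpus rather than one image at a time. A minimal sketch (sentences are illustrative and assumed to be pre-tokenized):

    # Minimal sketch of the Cider wrapper's interface.
    from refTools.evaluation.cider.cider import Cider

    gts = {0: ["a man rides a horse", "a person on a horse"],
           1: ["a cat sits on a mat", "a cat is on the mat"]}
    res = {0: ["a man is riding a horse"],
           1: ["a cat sitting on a mat"]}

    score, scores = Cider().compute_score(gts, res)
    # score is the corpus-level CIDEr value, scores has one entry per image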
/refTools/evaluation/cider/cider.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/cider.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 | :return: term frequency vector for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 | for k in range(1,n+1):
23 | for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that BLEU
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that BLEU needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 | Compute document frequency over the reference data.
96 | This will be used to compute idf (inverse document frequency) later.
97 | The document frequency is stored in the object
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngram to vector of tfidf weights.
110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111 | The n-th entry of array denotes length of n-grams.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 | for (ngram,term_freq) in cnts.items():
119 | # give word count 1 if it doesn't appear in reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 | # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 | for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
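In the notation of counts2vec and sim above, each n-gram of a sentence receives a tf-idf weight, candidate and reference vectors are compared per n-gram order with a clipped cosine similarity and a Gaussian length penalty, and the result is averaged over orders and references and scaled by 10. A compact restatement of what the code computes, with |I| = len(self.crefs), df = self.document_frequency, l_c and l_r the lengths as accumulated in counts2vec, N = self.n and sigma = self.sigma:

    g_n^{s}(\omega) = \mathrm{tf}_s(\omega)\,\bigl(\log|I| - \log\max(1,\mathrm{df}(\omega))\bigr)

    \mathrm{CIDEr}_n(c, r) = e^{-(l_c - l_r)^2 / (2\sigma^2)}\;
        \frac{\sum_{\omega} \min\bigl(g_n^{c}(\omega),\, g_n^{r}(\omega)\bigr)\, g_n^{r}(\omega)}
             {\lVert g_n^{c}\rVert\,\lVert g_n^{r}\rVert}

    \mathrm{CIDEr}(c, R) = \frac{10}{|R|} \sum_{r \in R} \frac{1}{N} \sum_{n=1}^{N} \mathrm{CIDEr}_n(c, r)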
/refTools/evaluation/cider/cider_scorer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/cider/cider_scorer.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/meteor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__pycache__/meteor.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/meteor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__pycache__/meteor.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/__pycache__/meteor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/__pycache__/meteor.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/meteor-1.5.jar
--------------------------------------------------------------------------------
/refTools/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import subprocess
9 | import threading
10 |
11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
12 | METEOR_JAR = 'meteor-1.5.jar'
13 | # print METEOR_JAR
14 |
15 | class Meteor:
16 |
17 | def __init__(self):
18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
19 | '-', '-', '-stdio', '-l', 'en', '-norm']
20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
21 | cwd=os.path.dirname(os.path.abspath(__file__)), \
22 | stdin=subprocess.PIPE, \
23 | stdout=subprocess.PIPE, \
24 | stderr=subprocess.PIPE)
25 | # Used to guarantee thread safety
26 | self.lock = threading.Lock()
27 |
28 | def compute_score(self, gts, res):
29 | assert(gts.keys() == res.keys())
30 | imgIds = gts.keys()
31 | scores = []
32 |
33 | eval_line = 'EVAL'
34 | self.lock.acquire()
35 | for i in imgIds:
36 | assert(len(res[i]) == 1)
37 | stat = self._stat(res[i][0], gts[i])
38 | eval_line += ' ||| {}'.format(stat)
39 |
40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
41 | for i in range(0,len(imgIds)):
42 | scores.append(float(self.meteor_p.stdout.readline().strip()))
43 | score = float(self.meteor_p.stdout.readline().strip())
44 | self.lock.release()
45 |
46 | return score, scores
47 |
48 | def method(self):
49 | return "METEOR"
50 |
51 | def _stat(self, hypothesis_str, reference_list):
52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
53 | hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
55 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
56 | return self.meteor_p.stdout.readline().decode().strip()
57 |
58 | def _score(self, hypothesis_str, reference_list):
59 | self.lock.acquire()
60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
61 | hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
62 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
63 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
64 | stats = self.meteor_p.stdout.readline().decode().strip()
65 | eval_line = 'EVAL ||| {}'.format(stats)
66 | # EVAL ||| stats
67 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
68 | score = float(self.meteor_p.stdout.readline().decode().strip())
69 | self.lock.release()
70 | return score
71 |
72 | def __exit__(self):
73 | self.lock.acquire()
74 | self.meteor_p.stdin.close()
75 | self.meteor_p.wait()
76 | self.lock.release()
77 |
--------------------------------------------------------------------------------
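The Meteor wrapper above keeps a single meteor-1.5.jar process alive and talks to it over stdin/stdout: _stat sends one "SCORE ||| ref_1 ||| ... ||| ref_k ||| hypothesis" line per image and reads back a statistics line, then compute_score sends a single "EVAL ||| stats_1 ||| stats_2 ..." line and reads one score per image followed by the aggregate score. A minimal sketch of the shared scorer interface; it assumes java is on the PATH and meteor-1.5.jar sits next to meteor.py as METEOR_JAR expects:

    # Minimal sketch; requires a working java + meteor-1.5.jar setup.
    from refTools.evaluation.meteor.meteor import Meteor

    gts = {0: ["a man rides a horse", "a person on a horse"]}
    res = {0: ["a man is riding a horse"]}

    score, scores = Meteor().compute_score(gts, res)
    # score is the corpus-level METEOR value, scores has one entry per image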
/refTools/evaluation/meteor/meteor.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/meteor/meteor.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/readme.txt:
--------------------------------------------------------------------------------
1 | This folder contains a modified version of the coco-caption evaluation code, downloaded from https://github.com/tylin/coco-caption.git,
2 | and refEvaluation, which is called by the refer algorithm.
3 |
4 | More specifically, this folder contains:
5 | 1. bleu/
6 | 2. cider/
7 | 3. meteor/
8 | 4. rouge/
9 | 5. tokenizer/
10 | 6. __init__.py
11 | 7. refEvaluation.py
12 |
--------------------------------------------------------------------------------
/refTools/evaluation/refEvaluation.py:
--------------------------------------------------------------------------------
1 | from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer
2 | from refTools.evaluation.bleu.bleu import Bleu
3 | from refTools.evaluation.meteor.meteor import Meteor
4 | from refTools.evaluation.rouge.rouge import Rouge
5 | from refTools.evaluation.cider.cider import Cider
6 |
7 | """
8 | Input: refer and Res = [{ref_id, sent}]
9 |
10 | Things of interest
11 | evalRefs - list of ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']
12 | eval - dict of {metric: score}
13 | refToEval - dict of {ref_id: ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']}
14 | """
15 |
16 | class RefEvaluation:
17 | def __init__ (self, refer, Res):
18 | """
19 | :param refer: refer class of current dataset
20 | :param Res: [{'ref_id', 'sent'}]
21 | """
22 | self.evalRefs = []
23 | self.eval = {}
24 | self.refToEval = {}
25 | self.refer = refer
26 | self.Res = Res
27 |
28 | def evaluate(self):
29 |
30 | evalRefIds = [ann['ref_id'] for ann in self.Res]
31 |
32 | refToGts = {}
33 | for ref_id in evalRefIds:
34 | ref = self.refer.Refs[ref_id]
35 | gt_sents = [sent['sent'].encode('ascii', 'ignore').decode('ascii') for sent in ref['sentences']] # up to 3 expressions
36 | refToGts[ref_id] = gt_sents
37 | refToRes = {ann['ref_id']: [ann['sent']] for ann in self.Res}
38 |
39 | print('tokenization...')
40 | tokenizer = PTBTokenizer()
41 | self.refToRes = tokenizer.tokenize(refToRes)
42 | self.refToGts = tokenizer.tokenize(refToGts)
43 |
44 | # =================================================
45 | # Set up scorers
46 | # =================================================
47 | print('setting up scorers...')
48 | scorers = [
49 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
50 | (Meteor(),"METEOR"),
51 | (Rouge(), "ROUGE_L"),
52 | (Cider(), "CIDEr")
53 | ]
54 |
55 | # =================================================
56 | # Compute scores
57 | # =================================================
58 | for scorer, method in scorers:
59 | print('computing %s score...'%(scorer.method()))
60 | score, scores = scorer.compute_score(self.refToGts, self.refToRes)
61 | if type(method) == list:
62 | for sc, scs, m in zip(score, scores, method):
63 | self.setEval(sc, m)
64 | self.setRefToEvalRefs(scs, self.refToGts.keys(), m)
65 | print("%s: %0.3f"%(m, sc))
66 | else:
67 | self.setEval(score, method)
68 | self.setRefToEvalRefs(scores, self.refToGts.keys(), method)
69 | print("%s: %0.3f"%(method, score))
70 | self.setEvalRefs()
71 |
72 | def setEval(self, score, method):
73 | self.eval[method] = score
74 |
75 | def setRefToEvalRefs(self, scores, refIds, method):
76 | for refId, score in zip(refIds, scores):
77 | if not refId in self.refToEval:
78 | self.refToEval[refId] = {}
79 | self.refToEval[refId]["ref_id"] = refId
80 | self.refToEval[refId][method] = score
81 |
82 | def setEvalRefs(self):
83 | self.evalRefs = [eval for refId, eval in self.refToEval.items()]
84 |
85 |
86 | if __name__ == '__main__':
87 |
88 | import os.path as osp
89 | import sys
90 | ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
91 | sys.path.insert(0, osp.join(ROOT_DIR, 'lib', 'datasets'))
92 | from refer import REFER
93 |
94 | # load refer of dataset
95 | dataset = 'refcoco'
96 | refer = REFER(dataset, splitBy = 'google')
97 |
98 | # mimic some Res
99 | val_refIds = refer.getRefIds(split='test')
100 | ref_id = 49767
101 | print("GD: %s" % refer.Refs[ref_id]['sentences'])
102 | Res = [{'ref_id': ref_id, 'sent': 'left bottle'}]
103 |
104 | # evaluate some refer expressions
105 | refEval = RefEvaluation(refer, Res)
106 | refEval.evaluate()
107 |
108 | # print output evaluation scores
109 | for metric, score in refEval.eval.items():
110 | print('%s: %.3f'%(metric, score))
111 |
112 | # demo how to use evalImgs to retrieve low score result
113 | # evals = [eva for eva in refEval.evalRefs if eva['CIDEr']<30]
114 | # print 'ground truth sents'
115 | # refId = evals[0]['ref_id']
116 | # print 'refId: %s' % refId
117 | # print [sent['sent'] for sent in refer.Refs[refId]['sentences']]
118 | #
119 | # print 'generated sent (CIDEr score %0.1f)' % (evals[0]['CIDEr'])
120 |
121 | # print refEval.refToEval[8]
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/refTools/evaluation/refEvaluation.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/refEvaluation.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/rouge.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__pycache__/rouge.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/rouge.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__pycache__/rouge.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/__pycache__/rouge.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/__pycache__/rouge.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 | :returns: length (int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 | # vrama91: updated the value below based on discussion with Hovy
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: str : candidate sentence to be evaluated
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 | :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by evaluate_captions.py
81 | :param gts: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
82 | :param res: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 | score.append(self.calc_score(hypo, ref))
94 |
95 | # Sanity check.
96 | assert(type(hypo) is list)
97 | assert(len(hypo) == 1)
98 | assert(type(ref) is list)
99 | assert(len(ref) > 0)
100 |
101 | average_score = np.mean(np.array(score))
102 | return average_score, np.array(score)
103 |
104 | def method(self):
105 | return "Rouge"
106 |
--------------------------------------------------------------------------------
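calc_score above converts the LCS length into precision and recall against each reference, keeps the maximum of each over the references, and combines them with an F-measure weighted by beta = 1.2 in favour of recall. A small worked sketch of that final step with illustrative numbers:

    # Worked sketch of the ROUGE-L F-measure used in calc_score above.
    beta = 1.2
    lcs, len_c, len_r = 4, 6, 5                 # illustrative LCS length and token counts
    prec_max = lcs / float(len_c)               # 0.667
    rec_max = lcs / float(len_r)                # 0.8
    score = ((1 + beta**2) * prec_max * rec_max) / (rec_max + beta**2 * prec_max)
    print(round(score, 3))                      # 0.739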
/refTools/evaluation/rouge/rouge.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/rouge/rouge.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'hfang'
2 |
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__init__.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-36.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/__pycache__/ptbtokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import sys
13 | import subprocess
14 | import tempfile
15 | import itertools
16 |
17 | # path to the stanford corenlp jar
18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
19 |
20 | # punctuations to be removed from the sentences
21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
23 |
24 | class PTBTokenizer:
25 | """Python wrapper of Stanford PTBTokenizer"""
26 |
27 | def tokenize(self, captions_for_image):
28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
29 | 'edu.stanford.nlp.process.PTBTokenizer', \
30 | '-preserveLines', '-lowerCase']
31 |
32 | # ======================================================
33 | # prepare data for PTB Tokenizer
34 | # ======================================================
35 | final_tokenized_captions_for_image = {}
36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
38 |
39 | # ======================================================
40 | # save sentences to temporary file
41 | # ======================================================
42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
44 | tmp_file.write(sentences.encode())
45 | tmp_file.close()
46 |
47 | # ======================================================
48 | # tokenize sentence
49 | # ======================================================
50 | cmd.append(os.path.basename(tmp_file.name))
51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
52 | stdout=subprocess.PIPE)
53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
54 | token_lines = token_lines.decode()
55 | lines = token_lines.split('\n')
56 | # remove temp file
57 | os.remove(tmp_file.name)
58 |
59 | # ======================================================
60 | # create dictionary for tokenized captions
61 | # ======================================================
62 | for k, line in zip(image_id, lines):
63 | if not k in final_tokenized_captions_for_image:
64 | final_tokenized_captions_for_image[k] = []
65 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
66 | if w not in PUNCTUATIONS])
67 | final_tokenized_captions_for_image[k].append(tokenized_caption)
68 |
69 | return final_tokenized_captions_for_image
70 |
--------------------------------------------------------------------------------
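PTBTokenizer.tokenize above takes a dict mapping each id to a list of raw captions and returns the same structure lower-cased, PTB-tokenized, and with the punctuation in PUNCTUATIONS removed; it shells out to the Stanford CoreNLP jar, so java must be on the PATH and stanford-corenlp-3.4.1.jar must sit in this folder. A minimal sketch with made-up captions:

    # Minimal sketch; requires java and stanford-corenlp-3.4.1.jar next to ptbtokenizer.py.
    from refTools.evaluation.tokenizer.ptbtokenizer import PTBTokenizer

    captions = {0: ["A man is riding a horse."],
                1: ["Two dogs play in the park!"]}
    tokenized = PTBTokenizer().tokenize(captions)
    # e.g. {0: ['a man is riding a horse'], 1: ['two dogs play in the park']}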
/refTools/evaluation/tokenizer/ptbtokenizer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/ptbtokenizer.pyc
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/tmp82iqkuu0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/tmp82iqkuu0
--------------------------------------------------------------------------------
/refTools/evaluation/tokenizer/tmpn19wmqte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaeseokbyun/GRIT-VLP/78b832f422165d4a5a2480b3508aabdd63a4e4dc/refTools/evaluation/tokenizer/tmpn19wmqte
--------------------------------------------------------------------------------
/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_lr import CosineLRScheduler
2 | from .plateau_lr import PlateauLRScheduler
3 | from .step_lr import StepLRScheduler
4 | from .tanh_lr import TanhLRScheduler
5 | from .scheduler_factory import create_scheduler
6 |
--------------------------------------------------------------------------------
/scheduler/cosine_lr.py:
--------------------------------------------------------------------------------
1 | """ Cosine Scheduler
2 |
3 | Cosine LR schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 |
12 | from .scheduler import Scheduler
13 |
14 | from pdb import set_trace as breakpoint
15 |
16 | _logger = logging.getLogger(__name__)
17 |
18 |
19 | class CosineLRScheduler(Scheduler):
20 | """
21 | Cosine decay with restarts.
22 | This is described in the paper https://arxiv.org/abs/1608.03983.
23 |
24 | Inspiration from
25 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
26 | """
27 |
28 | def __init__(self,
29 | optimizer: torch.optim.Optimizer,
30 | t_initial: int,
31 | t_mul: float = 1.,
32 | lr_min: float = 0.,
33 | decay_rate: float = 1.,
34 | warmup_t=0,
35 | warmup_lr_init=0,
36 | warmup_prefix=True,
37 | cycle_limit=0,
38 | t_in_epochs=True,
39 | noise_range_t=None,
40 | noise_pct=0.67,
41 | noise_std=1.0,
42 | noise_seed=42,
43 | initialize=True) -> None:
44 | super().__init__(
45 | optimizer, param_group_field="lr",
46 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
47 | initialize=initialize)
48 |
49 | assert t_initial > 0
50 | assert lr_min >= 0
51 | if t_initial == 1 and t_mul == 1 and decay_rate == 1:
52 | _logger.warning("Cosine annealing scheduler will have no effect on the learning "
53 | "rate since t_initial = t_mul = eta_mul = 1.")
54 | self.t_initial = t_initial
55 | self.t_mul = t_mul
56 | self.lr_min = lr_min
57 | self.decay_rate = decay_rate
58 | self.cycle_limit = cycle_limit
59 | self.warmup_t = warmup_t
60 | self.warmup_lr_init = warmup_lr_init
61 | self.warmup_prefix = warmup_prefix
62 | self.t_in_epochs = t_in_epochs
63 | if self.warmup_t:
64 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
65 | super().update_groups(self.warmup_lr_init)
66 | else:
67 | self.warmup_steps = [1 for _ in self.base_values]
68 |
69 | def _get_lr(self, t):
70 | if t < self.warmup_t:
71 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
72 | else:
73 | if self.warmup_prefix:
74 | t = t - self.warmup_t
75 |
76 | if self.t_mul != 1:
77 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
78 | t_i = self.t_mul ** i * self.t_initial
79 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
80 | else:
81 | i = t // self.t_initial
82 | t_i = self.t_initial
83 | t_curr = t - (self.t_initial * i)
84 |
85 | gamma = self.decay_rate ** i
86 | lr_min = self.lr_min * gamma
87 | lr_max_values = [v * gamma for v in self.base_values]
88 |
89 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
90 | lrs = [
91 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
92 | ]
93 | else:
94 | lrs = [self.lr_min for _ in self.base_values]
95 |
96 | return lrs
97 |
98 | def get_epoch_values(self, epoch: int):
99 | if self.t_in_epochs:
100 | return self._get_lr(epoch)
101 | else:
102 | return None
103 |
104 | def get_update_values(self, num_updates: int):
105 | if not self.t_in_epochs:
106 | return self._get_lr(num_updates)
107 | else:
108 | return None
109 |
110 | def get_cycle_length(self, cycles=0):
111 | if not cycles:
112 | cycles = self.cycle_limit
113 | cycles = max(1, cycles)
114 | if self.t_mul == 1.0:
115 | return self.t_initial * cycles
116 | else:
117 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
118 |
--------------------------------------------------------------------------------
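CosineLRScheduler is driven explicitly by the training loop: with t_in_epochs=True (the default) it is stepped once per epoch with the epoch index, _get_lr applies the linear warmup for the first warmup_t steps, and the cosine decay (with optional restarts via t_mul / cycle_limit) takes over afterwards. A minimal sketch with illustrative hyper-parameters:

    # Minimal sketch of stepping CosineLRScheduler from a training loop.
    import torch
    from scheduler.cosine_lr import CosineLRScheduler

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    sched = CosineLRScheduler(optimizer, t_initial=30, lr_min=1e-5,
                              warmup_t=5, warmup_lr_init=1e-6)

    for epoch in range(30):
        # ... run one training epoch here ...
        sched.step(epoch + 1)   # set the learning rate for the next epoch
        print(epoch, optimizer.param_groups[0]['lr'])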
/scheduler/plateau_lr.py:
--------------------------------------------------------------------------------
1 | """ Plateau Scheduler
2 |
3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 |
9 | from .scheduler import Scheduler
10 |
11 |
12 | class PlateauLRScheduler(Scheduler):
13 | """Decay the LR by a factor every time the validation loss plateaus."""
14 |
15 | def __init__(self,
16 | optimizer,
17 | decay_rate=0.1,
18 | patience_t=10,
19 | verbose=True,
20 | threshold=1e-4,
21 | cooldown_t=0,
22 | warmup_t=0,
23 | warmup_lr_init=0,
24 | lr_min=0,
25 | mode='max',
26 | noise_range_t=None,
27 | noise_type='normal',
28 | noise_pct=0.67,
29 | noise_std=1.0,
30 | noise_seed=None,
31 | initialize=True,
32 | ):
33 | super().__init__(optimizer, 'lr', initialize=initialize)
34 |
35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
36 | self.optimizer,
37 | patience=patience_t,
38 | factor=decay_rate,
39 | verbose=verbose,
40 | threshold=threshold,
41 | cooldown=cooldown_t,
42 | mode=mode,
43 | min_lr=lr_min
44 | )
45 |
46 | self.noise_range = noise_range_t
47 | self.noise_pct = noise_pct
48 | self.noise_type = noise_type
49 | self.noise_std = noise_std
50 | self.noise_seed = noise_seed if noise_seed is not None else 42
51 | self.warmup_t = warmup_t
52 | self.warmup_lr_init = warmup_lr_init
53 | if self.warmup_t:
54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
55 | super().update_groups(self.warmup_lr_init)
56 | else:
57 | self.warmup_steps = [1 for _ in self.base_values]
58 | self.restore_lr = None
59 |
60 | def state_dict(self):
61 | return {
62 | 'best': self.lr_scheduler.best,
63 | 'last_epoch': self.lr_scheduler.last_epoch,
64 | }
65 |
66 | def load_state_dict(self, state_dict):
67 | self.lr_scheduler.best = state_dict['best']
68 | if 'last_epoch' in state_dict:
69 | self.lr_scheduler.last_epoch = state_dict['last_epoch']
70 |
71 | # override the base class step fn completely
72 | def step(self, epoch, metric=None):
73 | if epoch <= self.warmup_t:
74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
75 | super().update_groups(lrs)
76 | else:
77 | if self.restore_lr is not None:
78 | # restore actual LR from before our last noise perturbation before stepping base
79 | for i, param_group in enumerate(self.optimizer.param_groups):
80 | param_group['lr'] = self.restore_lr[i]
81 | self.restore_lr = None
82 |
83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler
84 |
85 | if self.noise_range is not None:
86 | if isinstance(self.noise_range, (list, tuple)):
87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
88 | else:
89 | apply_noise = epoch >= self.noise_range
90 | if apply_noise:
91 | self._apply_noise(epoch)
92 |
93 | def _apply_noise(self, epoch):
94 | g = torch.Generator()
95 | g.manual_seed(self.noise_seed + epoch)
96 | if self.noise_type == 'normal':
97 | while True:
98 | # resample if noise out of percent limit, brute force but shouldn't spin much
99 | noise = torch.randn(1, generator=g).item()
100 | if abs(noise) < self.noise_pct:
101 | break
102 | else:
103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104 |
105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal
106 | # stepping of base scheduler
107 | restore_lr = []
108 | for i, param_group in enumerate(self.optimizer.param_groups):
109 | old_lr = float(param_group['lr'])
110 | restore_lr.append(old_lr)
111 | new_lr = old_lr + old_lr * noise
112 | param_group['lr'] = new_lr
113 | self.restore_lr = restore_lr
114 |
--------------------------------------------------------------------------------
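PlateauLRScheduler overrides the base step() entirely and needs the monitored metric every epoch; with mode='max' the wrapped ReduceLROnPlateau multiplies the learning rate by decay_rate once the metric has not improved for patience_t epochs. A minimal sketch in which val_acc stands in for a real validation metric:

    # Minimal sketch of PlateauLRScheduler; val_acc is a placeholder metric.
    import torch
    from scheduler.plateau_lr import PlateauLRScheduler

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    sched = PlateauLRScheduler(optimizer, decay_rate=0.5, patience_t=2, mode='max')

    for epoch in range(10):
        val_acc = 0.5                        # replace with the real validation metric
        sched.step(epoch + 1, metric=val_acc)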
/scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import torch
4 |
5 |
6 | class Scheduler:
7 | """ Parameter Scheduler Base Class
8 | A scheduler base class that can be used to schedule any optimizer parameter groups.
9 |
10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called
11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
13 |
14 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
15 |
16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training
18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call.
19 |
20 | Based on ideas from:
21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
23 | """
24 |
25 | def __init__(self,
26 | optimizer: torch.optim.Optimizer,
27 | param_group_field: str,
28 | noise_range_t=None,
29 | noise_type='normal',
30 | noise_pct=0.67,
31 | noise_std=1.0,
32 | noise_seed=None,
33 | initialize: bool = True) -> None:
34 | self.optimizer = optimizer
35 | self.param_group_field = param_group_field
36 | self._initial_param_group_field = f"initial_{param_group_field}"
37 | if initialize:
38 | for i, group in enumerate(self.optimizer.param_groups):
39 | if param_group_field not in group:
40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
41 | group.setdefault(self._initial_param_group_field, group[param_group_field])
42 | else:
43 | for i, group in enumerate(self.optimizer.param_groups):
44 | if self._initial_param_group_field not in group:
45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
47 | self.metric = None # any point to having this for all?
48 | self.noise_range_t = noise_range_t
49 | self.noise_pct = noise_pct
50 | self.noise_type = noise_type
51 | self.noise_std = noise_std
52 | self.noise_seed = noise_seed if noise_seed is not None else 42
53 | self.update_groups(self.base_values)
54 |
55 | def state_dict(self) -> Dict[str, Any]:
56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
57 |
58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
59 | self.__dict__.update(state_dict)
60 |
61 | def get_epoch_values(self, epoch: int):
62 | return None
63 |
64 | def get_update_values(self, num_updates: int):
65 | return None
66 |
67 | def step(self, epoch: int, metric: float = None) -> None:
68 | self.metric = metric
69 | values = self.get_epoch_values(epoch)
70 | if values is not None:
71 | values = self._add_noise(values, epoch)
72 | self.update_groups(values)
73 |
74 | def step_update(self, num_updates: int, metric: float = None):
75 | self.metric = metric
76 | values = self.get_update_values(num_updates)
77 | if values is not None:
78 | values = self._add_noise(values, num_updates)
79 | self.update_groups(values)
80 |
81 | def update_groups(self, values):
82 | if not isinstance(values, (list, tuple)):
83 | values = [values] * len(self.optimizer.param_groups)
84 | for param_group, value in zip(self.optimizer.param_groups, values):
85 | param_group[self.param_group_field] = value
86 |
87 | def _add_noise(self, lrs, t):
88 | if self.noise_range_t is not None:
89 | if isinstance(self.noise_range_t, (list, tuple)):
90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
91 | else:
92 | apply_noise = t >= self.noise_range_t
93 | if apply_noise:
94 | g = torch.Generator()
95 | g.manual_seed(self.noise_seed + t)
96 | if self.noise_type == 'normal':
97 | while True:
98 | # resample if noise out of percent limit, brute force but shouldn't spin much
99 | noise = torch.randn(1, generator=g).item()
100 | if abs(noise) < self.noise_pct:
101 | break
102 | else:
103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104 | lrs = [v + v * noise for v in lrs]
105 | return lrs
106 |
--------------------------------------------------------------------------------
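The base class above defines the calling convention: step() is called at the end of each epoch with the next epoch index, step_update() after each optimizer update, and all counts are tracked by the training code. A minimal, illustrative sketch of that convention using this repo's StepLRScheduler; the toy model, optimizer, and loop sizes below are placeholders, not values prescribed by the repo:

import torch
from scheduler.step_lr import StepLRScheduler

model = torch.nn.Linear(10, 2)                      # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1,
                               warmup_t=5, warmup_lr_init=1e-4)

num_epochs, steps_per_epoch = 90, 100               # assumed values
for epoch in range(num_epochs):
    for it in range(steps_per_epoch):
        optimizer.step()                            # ... forward/backward elided ...
        # per-update schedules (t_in_epochs=False) are advanced here, after the update count increments
        lr_scheduler.step_update(num_updates=epoch * steps_per_epoch + it + 1)
    # per-epoch schedules are advanced at the END of the epoch, with the next epoch index
    lr_scheduler.step(epoch + 1)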
/scheduler/scheduler_factory.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from .cosine_lr import CosineLRScheduler
5 | from .tanh_lr import TanhLRScheduler
6 | from .step_lr import StepLRScheduler
7 | from .plateau_lr import PlateauLRScheduler
8 |
9 |
10 | def create_scheduler(args, optimizer):
11 | num_epochs = args.epochs
12 |
13 | if getattr(args, 'lr_noise', None) is not None:
14 | lr_noise = getattr(args, 'lr_noise')
15 | if isinstance(lr_noise, (list, tuple)):
16 | noise_range = [n * num_epochs for n in lr_noise]
17 | if len(noise_range) == 1:
18 | noise_range = noise_range[0]
19 | else:
20 | noise_range = lr_noise * num_epochs
21 | else:
22 | noise_range = None
23 |
24 | lr_scheduler = None
25 | if args.sched == 'cosine':
26 | lr_scheduler = CosineLRScheduler(
27 | optimizer,
28 | t_initial=num_epochs,
29 | t_mul=getattr(args, 'lr_cycle_mul', 1.),
30 | lr_min=args.min_lr,
31 | decay_rate=args.decay_rate,
32 | warmup_lr_init=args.warmup_lr,
33 | warmup_t=args.warmup_epochs,
34 | cycle_limit=getattr(args, 'lr_cycle_limit', 1),
35 | t_in_epochs=True,
36 | noise_range_t=noise_range,
37 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
38 | noise_std=getattr(args, 'lr_noise_std', 1.),
39 | noise_seed=getattr(args, 'seed', 42),
40 | )
41 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
42 | elif args.sched == 'tanh':
43 | lr_scheduler = TanhLRScheduler(
44 | optimizer,
45 | t_initial=num_epochs,
46 | t_mul=getattr(args, 'lr_cycle_mul', 1.),
47 | lr_min=args.min_lr,
48 | warmup_lr_init=args.warmup_lr,
49 | warmup_t=args.warmup_epochs,
50 | cycle_limit=getattr(args, 'lr_cycle_limit', 1),
51 | t_in_epochs=True,
52 | noise_range_t=noise_range,
53 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
54 | noise_std=getattr(args, 'lr_noise_std', 1.),
55 | noise_seed=getattr(args, 'seed', 42),
56 | )
57 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
58 | elif args.sched == 'step':
59 | lr_scheduler = StepLRScheduler(
60 | optimizer,
61 | decay_t=args.decay_epochs,
62 | decay_rate=args.decay_rate,
63 | warmup_lr_init=args.warmup_lr,
64 | warmup_t=args.warmup_epochs,
65 | noise_range_t=noise_range,
66 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
67 | noise_std=getattr(args, 'lr_noise_std', 1.),
68 | noise_seed=getattr(args, 'seed', 42),
69 | )
70 | elif args.sched == 'plateau':
71 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
72 | lr_scheduler = PlateauLRScheduler(
73 | optimizer,
74 | decay_rate=args.decay_rate,
75 | patience_t=args.patience_epochs,
76 | lr_min=args.min_lr,
77 | mode=mode,
78 | warmup_lr_init=args.warmup_lr,
79 | warmup_t=args.warmup_epochs,
80 | cooldown_t=0,
81 | noise_range_t=noise_range,
82 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
83 | noise_std=getattr(args, 'lr_noise_std', 1.),
84 | noise_seed=getattr(args, 'seed', 42),
85 | )
86 |
87 | return lr_scheduler, num_epochs
88 |
--------------------------------------------------------------------------------
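create_scheduler() above reads its hyperparameters off an argparse-style namespace. A minimal sketch of driving the 'cosine' branch with a plain SimpleNamespace; the field values are illustrative defaults, not settings taken from this repo's configs:

from types import SimpleNamespace
import torch
from scheduler.scheduler_factory import create_scheduler

args = SimpleNamespace(
    sched='cosine', epochs=30, min_lr=1e-6, decay_rate=1.0,
    warmup_lr=1e-5, warmup_epochs=2, cooldown_epochs=0,
)
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
lr_scheduler.step(0)   # sets the warmup start LR for the first epoch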
/scheduler/step_lr.py:
--------------------------------------------------------------------------------
1 | """ Step Scheduler
2 |
3 | Basic step LR schedule with warmup, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import math
8 | import torch
9 |
10 | from .scheduler import Scheduler
11 |
12 |
13 | class StepLRScheduler(Scheduler):
14 |     """ Step LR schedule with optional warmup and noise: the base LR is decayed by ``decay_rate`` every ``decay_t`` steps.
15 |     """
16 |
17 | def __init__(self,
18 | optimizer: torch.optim.Optimizer,
19 | decay_t: float,
20 | decay_rate: float = 1.,
21 | warmup_t=0,
22 | warmup_lr_init=0,
23 | t_in_epochs=True,
24 | noise_range_t=None,
25 | noise_pct=0.67,
26 | noise_std=1.0,
27 | noise_seed=42,
28 | initialize=True,
29 | ) -> None:
30 | super().__init__(
31 | optimizer, param_group_field="lr",
32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
33 | initialize=initialize)
34 |
35 | self.decay_t = decay_t
36 | self.decay_rate = decay_rate
37 | self.warmup_t = warmup_t
38 | self.warmup_lr_init = warmup_lr_init
39 | self.t_in_epochs = t_in_epochs
40 | if self.warmup_t:
41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
42 | super().update_groups(self.warmup_lr_init)
43 | else:
44 | self.warmup_steps = [1 for _ in self.base_values]
45 |
46 | def _get_lr(self, t):
47 | if t < self.warmup_t:
48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
49 | else:
50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
51 | return lrs
52 |
53 | def get_epoch_values(self, epoch: int):
54 | if self.t_in_epochs:
55 | return self._get_lr(epoch)
56 | else:
57 | return None
58 |
59 | def get_update_values(self, num_updates: int):
60 | if not self.t_in_epochs:
61 | return self._get_lr(num_updates)
62 | else:
63 | return None
64 |
--------------------------------------------------------------------------------
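For reference, a quick worked example of the decay rule in StepLRScheduler._get_lr above, lr(t) = base_lr * decay_rate ** (t // decay_t); the values are illustrative and no warmup is used:

# base_lr=0.1, decay_t=30, decay_rate=0.1
base_lr, decay_t, decay_rate = 0.1, 30, 0.1
for t in (0, 29, 30, 59, 60):
    print(t, base_lr * (decay_rate ** (t // decay_t)))   # 0.1, 0.1, 0.01, 0.01, 0.001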
/scheduler/tanh_lr.py:
--------------------------------------------------------------------------------
1 | """ TanH Scheduler
2 |
3 | TanH schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 |
12 | from .scheduler import Scheduler
13 |
14 |
15 | _logger = logging.getLogger(__name__)
16 |
17 |
18 | class TanhLRScheduler(Scheduler):
19 | """
20 |     Hyperbolic-Tangent decay with restarts.
21 | This is described in the paper https://arxiv.org/abs/1806.01593
22 | """
23 |
24 | def __init__(self,
25 | optimizer: torch.optim.Optimizer,
26 | t_initial: int,
27 | lb: float = -6.,
28 | ub: float = 4.,
29 | t_mul: float = 1.,
30 | lr_min: float = 0.,
31 | decay_rate: float = 1.,
32 | warmup_t=0,
33 | warmup_lr_init=0,
34 | warmup_prefix=False,
35 | cycle_limit=0,
36 | t_in_epochs=True,
37 | noise_range_t=None,
38 | noise_pct=0.67,
39 | noise_std=1.0,
40 | noise_seed=42,
41 | initialize=True) -> None:
42 | super().__init__(
43 | optimizer, param_group_field="lr",
44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
45 | initialize=initialize)
46 |
47 | assert t_initial > 0
48 | assert lr_min >= 0
49 | assert lb < ub
50 | assert cycle_limit >= 0
51 | assert warmup_t >= 0
52 | assert warmup_lr_init >= 0
53 | self.lb = lb
54 | self.ub = ub
55 | self.t_initial = t_initial
56 | self.t_mul = t_mul
57 | self.lr_min = lr_min
58 | self.decay_rate = decay_rate
59 | self.cycle_limit = cycle_limit
60 | self.warmup_t = warmup_t
61 | self.warmup_lr_init = warmup_lr_init
62 | self.warmup_prefix = warmup_prefix
63 | self.t_in_epochs = t_in_epochs
64 | if self.warmup_t:
65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
67 | super().update_groups(self.warmup_lr_init)
68 | else:
69 | self.warmup_steps = [1 for _ in self.base_values]
70 |
71 | def _get_lr(self, t):
72 | if t < self.warmup_t:
73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
74 | else:
75 | if self.warmup_prefix:
76 | t = t - self.warmup_t
77 |
78 | if self.t_mul != 1:
79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
80 | t_i = self.t_mul ** i * self.t_initial
81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
82 | else:
83 | i = t // self.t_initial
84 | t_i = self.t_initial
85 | t_curr = t - (self.t_initial * i)
86 |
87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
88 | gamma = self.decay_rate ** i
89 | lr_min = self.lr_min * gamma
90 | lr_max_values = [v * gamma for v in self.base_values]
91 |
92 | tr = t_curr / t_i
93 | lrs = [
94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
95 | for lr_max in lr_max_values
96 | ]
97 | else:
98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]
99 | return lrs
100 |
101 | def get_epoch_values(self, epoch: int):
102 | if self.t_in_epochs:
103 | return self._get_lr(epoch)
104 | else:
105 | return None
106 |
107 | def get_update_values(self, num_updates: int):
108 | if not self.t_in_epochs:
109 | return self._get_lr(num_updates)
110 | else:
111 | return None
112 |
113 | def get_cycle_length(self, cycles=0):
114 | if not cycles:
115 | cycles = self.cycle_limit
116 | cycles = max(1, cycles)
117 | if self.t_mul == 1.0:
118 | return self.t_initial * cycles
119 | else:
120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
121 |
--------------------------------------------------------------------------------
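A quick numeric check of the tanh decay curve above: with the default bounds lb=-6 and ub=4, the schedule starts near each group's lr_max and decays towards lr_min over a cycle. The values below are illustrative:

import math

lr_min, lr_max, lb, ub = 0.0, 0.1, -6.0, 4.0
for tr in (0.0, 0.5, 1.0):   # fraction of the current cycle elapsed
    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(lb * (1 - tr) + ub * tr))
    print(round(tr, 2), lr)  # ~0.1, ~0.088, ~3.4e-5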
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import os
4 | import time
5 | from collections import defaultdict, deque
6 | import datetime
7 |
8 | import torch
9 | import torch.distributed as dist
10 |
11 | class SmoothedValue(object):
12 | """Track a series of values and provide access to smoothed values over a
13 | window or the global series average.
14 | """
15 |
16 | def __init__(self, window_size=20, fmt=None):
17 | if fmt is None:
18 | fmt = "{median:.4f} ({global_avg:.4f})"
19 | self.deque = deque(maxlen=window_size)
20 | self.total = 0.0
21 | self.count = 0
22 | self.fmt = fmt
23 |
24 | def update(self, value, n=1):
25 | self.deque.append(value)
26 | self.count += n
27 | self.total += value * n
28 |
29 | def synchronize_between_processes(self):
30 | """
31 | Warning: does not synchronize the deque!
32 | """
33 | if not is_dist_avail_and_initialized():
34 | return
35 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
36 | dist.barrier()
37 | dist.all_reduce(t)
38 | t = t.tolist()
39 | self.count = int(t[0])
40 | self.total = t[1]
41 |
42 | @property
43 | def median(self):
44 | d = torch.tensor(list(self.deque))
45 | return d.median().item()
46 |
47 | @property
48 | def avg(self):
49 | d = torch.tensor(list(self.deque), dtype=torch.float32)
50 | return d.mean().item()
51 |
52 | @property
53 | def global_avg(self):
54 | return self.total / self.count
55 |
56 | @property
57 | def max(self):
58 | return max(self.deque)
59 |
60 | @property
61 | def value(self):
62 | return self.deque[-1]
63 |
64 | def __str__(self):
65 | return self.fmt.format(
66 | median=self.median,
67 | avg=self.avg,
68 | global_avg=self.global_avg,
69 | max=self.max,
70 | value=self.value)
71 |
72 |
73 | class MetricLogger(object):
74 | def __init__(self, delimiter="\t"):
75 | self.meters = defaultdict(SmoothedValue)
76 | self.delimiter = delimiter
77 |
78 | def update(self, **kwargs):
79 | for k, v in kwargs.items():
80 | if isinstance(v, torch.Tensor):
81 | v = v.item()
82 | assert isinstance(v, (float, int))
83 | self.meters[k].update(v)
84 |
85 | def __getattr__(self, attr):
86 | if attr in self.meters:
87 | return self.meters[attr]
88 | if attr in self.__dict__:
89 | return self.__dict__[attr]
90 | raise AttributeError("'{}' object has no attribute '{}'".format(
91 | type(self).__name__, attr))
92 |
93 | def __str__(self):
94 | loss_str = []
95 | for name, meter in self.meters.items():
96 | loss_str.append(
97 | "{}: {}".format(name, str(meter))
98 | )
99 | return self.delimiter.join(loss_str)
100 |
101 | def global_avg(self):
102 | loss_str = []
103 | for name, meter in self.meters.items():
104 | loss_str.append(
105 | "{}: {:.4f}".format(name, meter.global_avg)
106 | )
107 | return self.delimiter.join(loss_str)
108 |
109 | def synchronize_between_processes(self):
110 | for meter in self.meters.values():
111 | meter.synchronize_between_processes()
112 |
113 | def add_meter(self, name, meter):
114 | self.meters[name] = meter
115 |
116 | def log_every(self, iterable, print_freq, header=None):
117 | i = 0
118 | if not header:
119 | header = ''
120 | start_time = time.time()
121 | end = time.time()
122 | iter_time = SmoothedValue(fmt='{avg:.4f}')
123 | data_time = SmoothedValue(fmt='{avg:.4f}')
124 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
125 | log_msg = [
126 | header,
127 | '[{0' + space_fmt + '}/{1}]',
128 | 'eta: {eta}',
129 | '{meters}',
130 | 'time: {time}',
131 | 'data: {data}'
132 | ]
133 | if torch.cuda.is_available():
134 | log_msg.append('max mem: {memory:.0f}')
135 | log_msg = self.delimiter.join(log_msg)
136 | MB = 1024.0 * 1024.0
137 | for obj in iterable:
138 | data_time.update(time.time() - end)
139 | yield obj
140 | iter_time.update(time.time() - end)
141 | if i % print_freq == 0 or i == len(iterable) - 1:
142 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
144 | if torch.cuda.is_available():
145 | print(log_msg.format(
146 | i, len(iterable), eta=eta_string,
147 | meters=str(self),
148 | time=str(iter_time), data=str(data_time),
149 | memory=torch.cuda.max_memory_allocated() / MB))
150 | else:
151 | print(log_msg.format(
152 | i, len(iterable), eta=eta_string,
153 | meters=str(self),
154 | time=str(iter_time), data=str(data_time)))
155 | i += 1
156 | end = time.time()
157 | total_time = time.time() - start_time
158 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
159 | print('{} Total time: {} ({:.4f} s / it)'.format(
160 | header, total_time_str, total_time / len(iterable)))
161 |
162 |
163 |
164 | class AttrDict(dict):
165 | def __init__(self, *args, **kwargs):
166 | super(AttrDict, self).__init__(*args, **kwargs)
167 | self.__dict__ = self
168 |
169 |
170 | def compute_acc(logits, label, reduction='mean'):
171 | ret = (torch.argmax(logits, dim=1) == label).float()
172 | if reduction == 'none':
173 | return ret.detach()
174 | elif reduction == 'mean':
175 | return ret.mean().item()
176 |
177 | def compute_n_params(model, return_str=True):
178 | tot = 0
179 | for p in model.parameters():
180 | w = 1
181 | for x in p.shape:
182 | w *= x
183 | tot += w
184 | if return_str:
185 | if tot >= 1e6:
186 | return '{:.1f}M'.format(tot / 1e6)
187 | else:
188 | return '{:.1f}K'.format(tot / 1e3)
189 | else:
190 | return tot
191 |
192 | def setup_for_distributed(is_master):
193 | """
194 | This function disables printing when not in master process
195 | """
196 | import builtins as __builtin__
197 | builtin_print = __builtin__.print
198 |
199 | def print(*args, **kwargs):
200 | force = kwargs.pop('force', False)
201 | if is_master or force:
202 | builtin_print(*args, **kwargs)
203 |
204 | __builtin__.print = print
205 |
206 |
207 | def is_dist_avail_and_initialized():
208 | if not dist.is_available():
209 | return False
210 | if not dist.is_initialized():
211 | return False
212 | return True
213 |
214 |
215 | def get_world_size():
216 | if not is_dist_avail_and_initialized():
217 | return 1
218 | return dist.get_world_size()
219 |
220 |
221 | def get_rank():
222 | if not is_dist_avail_and_initialized():
223 | return 0
224 | return dist.get_rank()
225 |
226 |
227 | def is_main_process():
228 | return get_rank() == 0
229 |
230 |
231 | def save_on_master(*args, **kwargs):
232 | if is_main_process():
233 | torch.save(*args, **kwargs)
234 |
235 |
236 | def init_distributed_mode(args):
237 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
238 | args.rank = int(os.environ["RANK"])
239 | args.world_size = int(os.environ['WORLD_SIZE'])
240 | args.gpu = int(os.environ['LOCAL_RANK'])
241 | elif 'SLURM_PROCID' in os.environ:
242 | args.rank = int(os.environ['SLURM_PROCID'])
243 | args.gpu = args.rank % torch.cuda.device_count()
244 | else:
245 | print('Not using distributed mode')
246 | args.distributed = False
247 | return
248 |
249 | args.distributed = True
250 |
251 | torch.cuda.set_device(args.gpu)
252 | args.dist_backend = 'nccl'
253 | print('| distributed init (rank {}): {}'.format(
254 | args.rank, args.dist_url), flush=True)
255 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
256 | world_size=args.world_size, rank=args.rank)
257 | torch.distributed.barrier()
258 | setup_for_distributed(args.rank == 0)
--------------------------------------------------------------------------------
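A minimal sketch of using the MetricLogger/SmoothedValue utilities above in a training loop; the dummy batches and loss values are placeholders for illustration only:

import torch
from utils import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
dummy_batches = range(50)
for batch in metric_logger.log_every(dummy_batches, print_freq=10, header='Train:'):
    loss = torch.rand(1)                      # placeholder for an actual training step
    metric_logger.update(loss=loss.item(), lr=1e-4)
print("Averaged stats:", metric_logger.global_avg())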
/vqaTools/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 |
--------------------------------------------------------------------------------
/vqaTools/vqa.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 | __version__ = '0.9'
3 |
4 | # Interface for accessing the VQA dataset.
5 |
6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8 |
9 | # The following functions are defined:
10 | # VQA - VQA class that loads VQA annotation file and prepares data structures.
11 | # getQuesIds - Get question ids that satisfy given filter conditions.
12 | # getImgIds - Get image ids that satisfy given filter conditions.
13 | # loadQA - Load questions and answers with the specified question ids.
14 | # showQA - Display the specified questions and answers.
15 | # loadRes - Load result file and create result object.
16 |
17 | # Help on each function can be accessed by: "help(VQA.function)"
18 |
19 | import json
20 | import datetime
21 | import copy
22 |
23 | class VQA:
24 | def __init__(self, annotation_file=None, question_file=None):
25 | """
26 | Constructor of VQA helper class for reading and visualizing questions and answers.
27 | :param annotation_file (str): location of VQA annotation file
28 | :return:
29 | """
30 | # load dataset
31 | self.dataset = {}
32 | self.questions = {}
33 | self.qa = {}
34 | self.qqa = {}
35 | self.imgToQA = {}
36 |         if annotation_file is not None and question_file is not None:
37 | print('loading VQA annotations and questions into memory...')
38 | time_t = datetime.datetime.utcnow()
39 | dataset = json.load(open(annotation_file, 'r'))
40 | questions = json.load(open(question_file, 'r'))
41 | self.dataset = dataset
42 | self.questions = questions
43 | self.createIndex()
44 |
45 | def createIndex(self):
46 | # create index
47 | print('creating index...')
48 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
49 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
50 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51 | for ann in self.dataset['annotations']:
52 | imgToQA[ann['image_id']] += [ann]
53 | qa[ann['question_id']] = ann
54 | for ques in self.questions['questions']:
55 | qqa[ques['question_id']] = ques
56 | print('index created!')
57 |
58 | # create class members
59 | self.qa = qa
60 | self.qqa = qqa
61 | self.imgToQA = imgToQA
62 |
63 | def info(self):
64 | """
65 | Print information about the VQA annotation file.
66 | :return:
67 | """
68 |         for key, value in self.dataset['info'].items():
69 | print('%s: %s'%(key, value))
70 |
71 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
72 | """
73 | Get question ids that satisfy given filter conditions. default skips that filter
74 | :param imgIds (int array) : get question ids for given imgs
75 | quesTypes (str array) : get question ids for given question types
76 | ansTypes (str array) : get question ids for given answer types
77 | :return: ids (int array) : integer array of question ids
78 | """
79 | imgIds = imgIds if type(imgIds) == list else [imgIds]
80 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
81 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
82 |
83 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
84 | anns = self.dataset['annotations']
85 | else:
86 | if not len(imgIds) == 0:
87 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
88 | else:
89 | anns = self.dataset['annotations']
90 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
91 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
92 | ids = [ann['question_id'] for ann in anns]
93 | return ids
94 |
95 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
96 | """
97 | Get image ids that satisfy given filter conditions. default skips that filter
98 | :param quesIds (int array) : get image ids for given question ids
99 | quesTypes (str array) : get image ids for given question types
100 | ansTypes (str array) : get image ids for given answer types
101 | :return: ids (int array) : integer array of image ids
102 | """
103 | quesIds = quesIds if type(quesIds) == list else [quesIds]
104 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
105 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
106 |
107 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
108 | anns = self.dataset['annotations']
109 | else:
110 | if not len(quesIds) == 0:
111 |                 anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]  # self.qa maps each question id to a single annotation dict
112 | else:
113 | anns = self.dataset['annotations']
114 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
115 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
116 | ids = [ann['image_id'] for ann in anns]
117 | return ids
118 |
119 | def loadQA(self, ids=[]):
120 | """
121 | Load questions and answers with the specified question ids.
122 | :param ids (int array) : integer ids specifying question ids
123 | :return: qa (object array) : loaded qa objects
124 | """
125 | if type(ids) == list:
126 | return [self.qa[id] for id in ids]
127 | elif type(ids) == int:
128 | return [self.qa[ids]]
129 |
130 | def showQA(self, anns):
131 | """
132 | Display the specified annotations.
133 | :param anns (array of object): annotations to display
134 | :return: None
135 | """
136 | if len(anns) == 0:
137 | return 0
138 | for ann in anns:
139 | quesId = ann['question_id']
140 | print("Question: %s" %(self.qqa[quesId]['question']))
141 | for ans in ann['answers']:
142 | print("Answer %d: %s" %(ans['answer_id'], ans['answer']))
143 |
144 | def loadRes(self, resFile, quesFile):
145 | """
146 | Load result file and return a result object.
147 | :param resFile (str) : file name of result file
148 | :return: res (obj) : result api object
149 | """
150 | res = VQA()
151 | res.questions = json.load(open(quesFile))
152 | res.dataset['info'] = copy.deepcopy(self.questions['info'])
153 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
154 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
155 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
156 | res.dataset['license'] = copy.deepcopy(self.questions['license'])
157 |
158 | print('Loading and preparing results... ')
159 | time_t = datetime.datetime.utcnow()
160 | anns = json.load(open(resFile))
161 | assert type(anns) == list, 'results is not an array of objects'
162 | annsQuesIds = [ann['question_id'] for ann in anns]
163 | assert set(annsQuesIds) == set(self.getQuesIds()), \
164 |         'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
165 | for ann in anns:
166 | quesId = ann['question_id']
167 | if res.dataset['task_type'] == 'Multiple Choice':
168 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
169 | qaAnn = self.qa[quesId]
170 | ann['image_id'] = qaAnn['image_id']
171 | ann['question_type'] = qaAnn['question_type']
172 | ann['answer_type'] = qaAnn['answer_type']
173 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))
174 |
175 | res.dataset['annotations'] = anns
176 | res.createIndex()
177 | return res
178 |
--------------------------------------------------------------------------------
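An illustrative use of the VQA helper above, following the functions listed in its header comment; the annotation, question, and result file paths are placeholders, not files shipped with this repo:

from vqaTools.vqa import VQA

vqa = VQA(annotation_file='v2_mscoco_val2014_annotations.json',
          question_file='v2_OpenEnded_mscoco_val2014_questions.json')
quesIds = vqa.getQuesIds(ansTypes=['yes/no'])      # filter question ids by answer type
anns = vqa.loadQA(quesIds[:3])
vqa.showQA(anns)                                   # print a few question/answer pairs
vqaRes = vqa.loadRes('vqa_results.json', 'v2_OpenEnded_mscoco_val2014_questions.json')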
/vqaTools/vqaEval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | __author__='aagrawal'
4 |
5 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
6 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
7 | import sys
8 | import re
9 |
10 | class VQAEval:
11 | def __init__(self, vqa, vqaRes, n=2):
12 | self.n = n
13 | self.accuracy = {}
14 | self.evalQA = {}
15 | self.evalQuesType = {}
16 | self.evalAnsType = {}
17 | self.vqa = vqa
18 | self.vqaRes = vqaRes
19 | self.params = {'question_id': vqa.getQuesIds()}
20 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
21 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
22 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
23 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
24 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
25 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
27 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
28 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
29 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
30 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
31 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
32 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
33 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
34 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
35 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
36 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
37 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
38 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
39 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
40 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
41 | "youll": "you'll", "youre": "you're", "youve": "you've"}
42 | self.manualMap = { 'none': '0',
43 | 'zero': '0',
44 | 'one': '1',
45 | 'two': '2',
46 | 'three': '3',
47 | 'four': '4',
48 | 'five': '5',
49 | 'six': '6',
50 | 'seven': '7',
51 | 'eight': '8',
52 | 'nine': '9',
53 | 'ten': '10'
54 | }
55 | self.articles = ['a',
56 | 'an',
57 | 'the'
58 | ]
59 |
60 |
61 |         self.periodStrip = re.compile(r"(?<!\d)(\.)(?!\d)")  # keep periods that are part of numbers (e.g. "3.5")
62 |         self.commaStrip = re.compile(r"(\d)(,)(\d)")
63 | self.punct = [';', r"/", '[', ']', '"', '{', '}',
64 | '(', ')', '=', '+', '\\', '_', '-',
65 | '>', '<', '@', '`', ',', '?', '!']
66 |
67 |
68 | def evaluate(self, quesIds=None):
69 |         if quesIds is None:
70 | quesIds = [quesId for quesId in self.params['question_id']]
71 | gts = {}
72 | res = {}
73 | for quesId in quesIds:
74 | gts[quesId] = self.vqa.qa[quesId]
75 | res[quesId] = self.vqaRes.qa[quesId]
76 |
77 | # =================================================
78 | # Compute accuracy
79 | # =================================================
80 | accQA = []
81 | accQuesType = {}
82 | accAnsType = {}
83 | print ("computing accuracy")
84 | step = 0
85 | for quesId in quesIds:
86 | resAns = res[quesId]['answer']
87 | resAns = resAns.replace('\n', ' ')
88 | resAns = resAns.replace('\t', ' ')
89 | resAns = resAns.strip()
90 | resAns = self.processPunctuation(resAns)
91 | resAns = self.processDigitArticle(resAns)
92 | gtAcc = []
93 | gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
94 | if len(set(gtAnswers)) > 1:
95 | for ansDic in gts[quesId]['answers']:
96 | ansDic['answer'] = self.processPunctuation(ansDic['answer'])
97 | for gtAnsDatum in gts[quesId]['answers']:
98 | otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
99 | matchingAns = [item for item in otherGTAns if item['answer']==resAns]
100 | acc = min(1, float(len(matchingAns))/3)
101 | gtAcc.append(acc)
102 | quesType = gts[quesId]['question_type']
103 | ansType = gts[quesId]['answer_type']
104 | avgGTAcc = float(sum(gtAcc))/len(gtAcc)
105 | accQA.append(avgGTAcc)
106 | if quesType not in accQuesType:
107 | accQuesType[quesType] = []
108 | accQuesType[quesType].append(avgGTAcc)
109 | if ansType not in accAnsType:
110 | accAnsType[ansType] = []
111 | accAnsType[ansType].append(avgGTAcc)
112 | self.setEvalQA(quesId, avgGTAcc)
113 | self.setEvalQuesType(quesId, quesType, avgGTAcc)
114 | self.setEvalAnsType(quesId, ansType, avgGTAcc)
115 | if step%100 == 0:
116 | self.updateProgress(step/float(len(quesIds)))
117 | step = step + 1
118 |
119 | self.setAccuracy(accQA, accQuesType, accAnsType)
120 | print ("Done computing accuracy")
121 |
122 | def processPunctuation(self, inText):
123 | outText = inText
124 | for p in self.punct:
125 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
126 | outText = outText.replace(p, '')
127 | else:
128 | outText = outText.replace(p, ' ')
129 |         # re.sub takes a count, not flags, as its third positional argument
130 |         outText = self.periodStrip.sub("", outText)
131 |
132 | return outText
133 |
134 | def processDigitArticle(self, inText):
135 | outText = []
136 | tempText = inText.lower().split()
137 | for word in tempText:
138 | word = self.manualMap.setdefault(word, word)
139 | if word not in self.articles:
140 | outText.append(word)
141 | else:
142 | pass
143 | for wordId, word in enumerate(outText):
144 | if word in self.contractions:
145 | outText[wordId] = self.contractions[word]
146 | outText = ' '.join(outText)
147 | return outText
148 |
149 | def setAccuracy(self, accQA, accQuesType, accAnsType):
150 | self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
151 | self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
152 | self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
153 |
154 | def setEvalQA(self, quesId, acc):
155 | self.evalQA[quesId] = round(100*acc, self.n)
156 |
157 | def setEvalQuesType(self, quesId, quesType, acc):
158 | if quesType not in self.evalQuesType:
159 | self.evalQuesType[quesType] = {}
160 | self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
161 |
162 | def setEvalAnsType(self, quesId, ansType, acc):
163 | if ansType not in self.evalAnsType:
164 | self.evalAnsType[ansType] = {}
165 | self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
166 |
167 | def updateProgress(self, progress):
168 | barLength = 20
169 | status = ""
170 | if isinstance(progress, int):
171 | progress = float(progress)
172 | if not isinstance(progress, float):
173 | progress = 0
174 | status = "error: progress var must be float\r\n"
175 | if progress < 0:
176 | progress = 0
177 | status = "Halt...\r\n"
178 | if progress >= 1:
179 | progress = 1
180 | status = "Done...\r\n"
181 | block = int(round(barLength*progress))
182 | text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
183 | sys.stdout.write(text)
184 | sys.stdout.flush()
--------------------------------------------------------------------------------
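For clarity, a standalone restatement (not part of the repo) of the accuracy rule implemented in VQAEval.evaluate above: the predicted answer is scored against each leave-one-out subset of the ten human answers, with min(#matches/3, 1) per subset, and the ten scores are averaged:

def vqa_accuracy(pred, human_answers):
    # human_answers: list of the 10 annotator answers for one question
    accs = []
    for i in range(len(human_answers)):
        others = human_answers[:i] + human_answers[i + 1:]
        accs.append(min(1.0, others.count(pred) / 3.0))
    return sum(accs) / len(accs)

print(vqa_accuracy('yes', ['yes'] * 8 + ['no'] * 2))   # 1.0: at least 3 of the other 9 answers match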