├── misc ├── __init__.py ├── resnet_utils.py ├── loss_wrapper.py ├── resnet.py ├── rewards.py └── utils.py ├── vis ├── imgs │ └── dummy └── index.html ├── scripts ├── copy_model.sh ├── make_bu_data.py ├── prepro_reference_json.py ├── prepro_feats.py ├── prepro_ngrams.py ├── dump_to_lmdb.py ├── build_bpe_subword_nmt.py └── prepro_labels.py ├── test-last.sh ├── .gitmodules ├── test-best.sh ├── ADVANCED.md ├── LICENSE ├── train.sh ├── train-wo-refining.sh ├── models ├── __init__.py ├── AttEnsemble.py ├── ShowTellModel.py ├── FCModel.py ├── AoAModel.py ├── OldModel.py ├── CaptionModel.py └── TransformerModel.py ├── eval.py ├── README.md ├── eval_ensemble.py ├── data └── README.md ├── dataloaderraw.py ├── eval_utils.py ├── train.py ├── dataloader.py └── opts.py /misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vis/imgs/dummy: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/copy_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cp -r log_$1 log_$2 4 | cd log_$2 5 | mv infos_$1-best.pkl infos_$2-best.pkl 6 | mv infos_$1.pkl infos_$2.pkl 7 | cd ../ -------------------------------------------------------------------------------- /test-last.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$3 python eval.py --dump_images 0 --dump_json 1 --num_images -1 --model $1/model.pth --infos_path $1/infos_$2.pkl --language_eval 1 --beam_size $4 --batch_size 100 --split test -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cider"] 2 | path = cider 3 | url = https://github.com/ruotianluo/cider.git 4 | [submodule "coco-caption"] 5 | path = coco-caption 6 | url = https://github.com/ruotianluo/coco-caption.git 7 | -------------------------------------------------------------------------------- /test-best.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$3 python eval.py --dump_images 0 --dump_json 1 --num_images -1 --model $1/model-best.pth --infos_path $1/infos_$2-best.pkl --language_eval 1 --image_root /mnt/hl/dataset/coco/train2014/ --beam_size $4 --batch_size 100 --split test -------------------------------------------------------------------------------- /ADVANCED.md: -------------------------------------------------------------------------------- 1 | # Advanced 2 | 3 | ## Ensemble 4 | 5 | The current ensemble only supports models that are subclasses of AttModel. Here is an example of the script to run ensemble models. `eval_ensemble.py` assumes the models are saved under `log_$id`.
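For illustration, assume three hypothetical ids `model1`, `model2` and `model3` (matching the `--ids` argument in the command below); each checkpoint directory is then expected to contain the files written during training:

```
log_model1/
├── model.pth
├── model-best.pth
├── infos_model1.pkl
└── infos_model1-best.pkl
```

Passing an id as `model1-best` makes the script load `model-best.pth` and `infos_model1-best.pkl` from the same directory instead.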
6 | 7 | ``` 8 | python eval_ensemble.py --dump_json 0 --ids model1,model2,model3 --weights 0.3,0.3,0.3 --batch_size 1 --dump_images 0 --num_images 5000 --split test --language_eval 1 --beam_size 5 --temperature 1.0 --sample_method greedy --max_length 30 9 | ``` 10 | 11 | ## Batch normalization 12 | 13 | ## Box feature -------------------------------------------------------------------------------- /misc/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class myResnet(nn.Module): 6 | def __init__(self, resnet): 7 | super(myResnet, self).__init__() 8 | self.resnet = resnet 9 | 10 | def forward(self, img, att_size=14): 11 | x = img.unsqueeze(0) 12 | 13 | x = self.resnet.conv1(x) 14 | x = self.resnet.bn1(x) 15 | x = self.resnet.relu(x) 16 | x = self.resnet.maxpool(x) 17 | 18 | x = self.resnet.layer1(x) 19 | x = self.resnet.layer2(x) 20 | x = self.resnet.layer3(x) 21 | x = self.resnet.layer4(x) 22 | 23 | fc = x.mean(3).mean(2).squeeze() 24 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) 25 | 26 | return fc, att 27 | 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ruotian(RT) Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /misc/loss_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import misc.utils as utils 3 | from misc.rewards import init_scorer, get_self_critical_reward 4 | 5 | class LossWrapper(torch.nn.Module): 6 | def __init__(self, model, opt): 7 | super(LossWrapper, self).__init__() 8 | self.opt = opt 9 | self.model = model 10 | if opt.label_smoothing > 0: 11 | self.crit = utils.LabelSmoothing(smoothing=opt.label_smoothing) 12 | else: 13 | self.crit = utils.LanguageModelCriterion() 14 | self.rl_crit = utils.RewardCriterion() 15 | 16 | def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, 17 | sc_flag): 18 | out = {} 19 | if not sc_flag: 20 | loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]) 21 | else: 22 | self.model.eval() 23 | with torch.no_grad(): 24 | greedy_res, _ = self.model(fc_feats, att_feats, att_masks, mode='sample') 25 | self.model.train() 26 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_method':'sample'}, mode='sample') 27 | gts = [gts[_] for _ in gt_indices.tolist()] 28 | reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt) 29 | reward = torch.from_numpy(reward).float().to(gen_result.device) 30 | loss = self.rl_crit(sample_logprobs, gen_result.data, reward) 31 | out['reward'] = reward[:,0].mean() 32 | out['loss'] = loss 33 | return out 34 | -------------------------------------------------------------------------------- /scripts/make_bu_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import base64 7 | import numpy as np 8 | import csv 9 | import sys 10 | import zlib 11 | import time 12 | import mmap 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | # output_dir 18 | parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory') 19 | parser.add_argument('--output_dir', default='data/cocobu', help='output feature files') 20 | 21 | args = parser.parse_args() 22 | 23 | csv.field_size_limit(sys.maxsize) 24 | 25 | 26 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] 27 | infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', 28 | 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\ 29 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \ 30 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1'] 31 | 32 | os.makedirs(args.output_dir+'_att') 33 | os.makedirs(args.output_dir+'_fc') 34 | os.makedirs(args.output_dir+'_box') 35 | 36 | for infile in infiles: 37 | print('Reading ' + infile) 38 | with open(os.path.join(args.downloaded_feats, infile), "r+b") as tsv_in_file: 39 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 40 | for item in reader: 41 | item['image_id'] = int(item['image_id']) 42 | item['num_boxes'] = int(item['num_boxes']) 43 | for field in ['boxes', 'features']: 44 | item[field] = np.frombuffer(base64.decodestring(item[field]), 45 | dtype=np.float32).reshape((item['num_boxes'],-1)) 46 | np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features']) 47 | np.save(os.path.join(args.output_dir+'_fc', 
str(item['image_id'])), item['features'].mean(0)) 48 | np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes']) 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /vis/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | neuraltalk2 results visualization 7 | 8 | 42 | 43 | 44 |
45 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | id="aoanet" 2 | if [ ! -f log/log_$id/infos_$id.pkl ]; then 3 | start_from="" 4 | else 5 | start_from="--start_from log/log_$id" 6 | fi 7 | python train.py --id $id \ 8 | --caption_model aoa \ 9 | --refine 1 \ 10 | --refine_aoa 1 \ 11 | --use_ff 0 \ 12 | --decoder_type AoA \ 13 | --use_multi_head 2 \ 14 | --num_heads 8 \ 15 | --multi_head_scale 1 \ 16 | --mean_feats 1 \ 17 | --ctx_drop 1 \ 18 | --dropout_aoa 0.3 \ 19 | --label_smoothing 0.2 \ 20 | --input_json data/cocotalk.json \ 21 | --input_label_h5 data/cocotalk_label.h5 \ 22 | --input_fc_dir data/cocobu_fc \ 23 | --input_att_dir data/cocobu_att \ 24 | --input_box_dir data/cocobu_box \ 25 | --seq_per_img 5 \ 26 | --batch_size 10 \ 27 | --beam_size 1 \ 28 | --learning_rate 2e-4 \ 29 | --num_layers 2 \ 30 | --input_encoding_size 1024 \ 31 | --rnn_size 1024 \ 32 | --learning_rate_decay_start 0 \ 33 | --scheduled_sampling_start 0 \ 34 | --checkpoint_path log/log_$id \ 35 | $start_from \ 36 | --save_checkpoint_every 6000 \ 37 | --language_eval 1 \ 38 | --val_images_use -1 \ 39 | --max_epochs 25 \ 40 | --scheduled_sampling_increase_every 5 \ 41 | --scheduled_sampling_max_prob 0.5 \ 42 | --learning_rate_decay_every 3 43 | 44 | python train.py --id $id \ 45 | --caption_model aoa \ 46 | --refine 1 \ 47 | --refine_aoa 1 \ 48 | --use_ff 0 \ 49 | --decoder_type AoA \ 50 | --use_multi_head 2 \ 51 | --num_heads 8 \ 52 | --multi_head_scale 1 \ 53 | --mean_feats 1 \ 54 | --ctx_drop 1 \ 55 | --dropout_aoa 0.3 \ 56 | --input_json data/cocotalk.json \ 57 | --input_label_h5 data/cocotalk_label.h5 \ 58 | --input_fc_dir data/cocobu_fc \ 59 | --input_att_dir data/cocobu_att \ 60 | --input_box_dir data/cocobu_box \ 61 | --seq_per_img 5 \ 62 | --batch_size 10 \ 63 | --beam_size 1 \ 64 | --num_layers 2 \ 65 | --input_encoding_size 1024 \ 66 | --rnn_size 1024 \ 67 | --language_eval 1 \ 68 | --val_images_use -1 \ 69 | --save_checkpoint_every 3000 \ 70 | --start_from log/log_$id \ 71 | --checkpoint_path log/log_$id"_rl" \ 72 | --learning_rate 2e-5 \ 73 | --max_epochs 40 \ 74 | --self_critical_after 0 \ 75 | --learning_rate_decay_start -1 \ 76 | --scheduled_sampling_start -1 \ 77 | --reduce_on_plateau -------------------------------------------------------------------------------- /train-wo-refining.sh: -------------------------------------------------------------------------------- 1 | id="aoanet-wo-refinng" 2 | if [ ! 
-f log/log_$id/infos_$id.pkl ]; then 3 | start_from="" 4 | else 5 | start_from="--start_from log/log_$id" 6 | fi 7 | python train.py --id $id \ 8 | --caption_model aoa \ 9 | --refine 0 \ 10 | --refine_aoa 0 \ 11 | --use_ff 0 \ 12 | --decoder_type AoA \ 13 | --use_multi_head 2 \ 14 | --num_heads 8 \ 15 | --multi_head_scale 1 \ 16 | --mean_feats 1 \ 17 | --ctx_drop 0 \ 18 | --dropout_aoa 0.3 \ 19 | --label_smoothing 0 \ 20 | --input_json data/cocotalk.json \ 21 | --input_label_h5 data/cocotalk_label.h5 \ 22 | --input_fc_dir data/cocobu_fc \ 23 | --input_att_dir data/cocobu_att \ 24 | --input_box_dir data/cocobu_box \ 25 | --seq_per_img 5 \ 26 | --batch_size 10 \ 27 | --beam_size 1 \ 28 | --learning_rate 2e-4 \ 29 | --num_layers 2 \ 30 | --input_encoding_size 1024 \ 31 | --rnn_size 1024 \ 32 | --learning_rate_decay_start 0 \ 33 | --scheduled_sampling_start 0 \ 34 | --checkpoint_path log/log_$id \ 35 | $start_from \ 36 | --save_checkpoint_every 6000 \ 37 | --language_eval 1 \ 38 | --val_images_use -1 \ 39 | --max_epochs 35 \ 40 | --scheduled_sampling_increase_every 5 \ 41 | --scheduled_sampling_max_prob 0.5 \ 42 | --learning_rate_decay_every 3 43 | 44 | python train.py --id $id \ 45 | --caption_model aoa \ 46 | --refine 0 \ 47 | --refine_aoa 0 \ 48 | --use_ff 0 \ 49 | --decoder_type AoA \ 50 | --use_multi_head 2 \ 51 | --num_heads 8 \ 52 | --multi_head_scale 1 \ 53 | --mean_feats 1 \ 54 | --ctx_drop 0 \ 55 | --dropout_aoa 0.3 \ 56 | --input_json data/cocotalk.json \ 57 | --input_label_h5 data/cocotalk_label.h5 \ 58 | --input_fc_dir data/cocobu_fc \ 59 | --input_att_dir data/cocobu_att \ 60 | --input_box_dir data/cocobu_box \ 61 | --seq_per_img 5 \ 62 | --batch_size 10 \ 63 | --beam_size 1 \ 64 | --num_layers 2 \ 65 | --input_encoding_size 1024 \ 66 | --rnn_size 1024 \ 67 | --language_eval 1 \ 68 | --val_images_use -1 \ 69 | --save_checkpoint_every 3000 \ 70 | --start_from log/log_$id \ 71 | --checkpoint_path log/log_$id"_rl" \ 72 | --learning_rate 2e-5 \ 73 | --max_epochs 60 \ 74 | --self_critical_after 0 \ 75 | --learning_rate_decay_start -1 \ 76 | --scheduled_sampling_start -1 \ 77 | --reduce_on_plateau 78 | -------------------------------------------------------------------------------- /misc/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.resnet 4 | from torchvision.models.resnet import BasicBlock, Bottleneck 5 | 6 | class ResNet(torchvision.models.resnet.ResNet): 7 | def __init__(self, block, layers, num_classes=1000): 8 | super(ResNet, self).__init__(block, layers, num_classes) 9 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 10 | for i in range(2, 5): 11 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) 12 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) 13 | 14 | def resnet18(pretrained=False): 15 | """Constructs a ResNet-18 model. 16 | 17 | Args: 18 | pretrained (bool): If True, returns a model pre-trained on ImageNet 19 | """ 20 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 21 | if pretrained: 22 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 23 | return model 24 | 25 | 26 | def resnet34(pretrained=False): 27 | """Constructs a ResNet-34 model. 
28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 33 | if pretrained: 34 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 35 | return model 36 | 37 | 38 | def resnet50(pretrained=False): 39 | """Constructs a ResNet-50 model. 40 | 41 | Args: 42 | pretrained (bool): If True, returns a model pre-trained on ImageNet 43 | """ 44 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 45 | if pretrained: 46 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 47 | return model 48 | 49 | 50 | def resnet101(pretrained=False): 51 | """Constructs a ResNet-101 model. 52 | 53 | Args: 54 | pretrained (bool): If True, returns a model pre-trained on ImageNet 55 | """ 56 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 57 | if pretrained: 58 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 59 | return model 60 | 61 | 62 | def resnet152(pretrained=False): 63 | """Constructs a ResNet-152 model. 64 | 65 | Args: 66 | pretrained (bool): If True, returns a model pre-trained on ImageNet 67 | """ 68 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 69 | if pretrained: 70 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 71 | return model -------------------------------------------------------------------------------- /misc/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import time 7 | import misc.utils as utils 8 | from collections import OrderedDict 9 | import torch 10 | 11 | import sys 12 | sys.path.append("cider") 13 | from pyciderevalcap.ciderD.ciderD import CiderD 14 | sys.path.append("coco-caption") 15 | from pycocoevalcap.bleu.bleu import Bleu 16 | 17 | CiderD_scorer = None 18 | Bleu_scorer = None 19 | #CiderD_scorer = CiderD(df='corpus') 20 | 21 | def init_scorer(cached_tokens): 22 | global CiderD_scorer 23 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) 24 | global Bleu_scorer 25 | Bleu_scorer = Bleu_scorer or Bleu(4) 26 | 27 | def array_to_str(arr): 28 | out = '' 29 | for i in range(len(arr)): 30 | out += str(arr[i]) + ' ' 31 | if arr[i] == 0: 32 | break 33 | return out.strip() 34 | 35 | def get_self_critical_reward(greedy_res, data_gts, gen_result, opt): 36 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 37 | seq_per_img = batch_size // len(data_gts) 38 | 39 | res = OrderedDict() 40 | 41 | gen_result = gen_result.data.cpu().numpy() 42 | greedy_res = greedy_res.data.cpu().numpy() 43 | for i in range(batch_size): 44 | res[i] = [array_to_str(gen_result[i])] 45 | for i in range(batch_size): 46 | res[batch_size + i] = [array_to_str(greedy_res[i])] 47 | 48 | gts = OrderedDict() 49 | for i in range(len(data_gts)): 50 | gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] 51 | 52 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] 53 | res__ = {i: res[i] for i in range(2 * batch_size)} 54 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)} 55 | if opt.cider_reward_weight > 0: 56 | _, cider_scores = CiderD_scorer.compute_score(gts, res_) 57 | print('Cider scores:', _) 58 | else: 59 | cider_scores = 0 60 | if opt.bleu_reward_weight > 0: 61 | _, bleu_scores = Bleu_scorer.compute_score(gts, res__) 62 | bleu_scores = np.array(bleu_scores[3]) 63 | print('Bleu scores:', _[3]) 
64 | else: 65 | bleu_scores = 0 66 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores 67 | 68 | scores = scores[:batch_size] - scores[batch_size:] 69 | 70 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 71 | 72 | return rewards 73 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | import misc.utils as utils 10 | import torch 11 | 12 | from .ShowTellModel import ShowTellModel 13 | from .FCModel import FCModel 14 | from .OldModel import ShowAttendTellModel, AllImgModel 15 | from .AttModel import * 16 | from .TransformerModel import TransformerModel 17 | from .AoAModel import AoAModel 18 | 19 | def setup(opt): 20 | if opt.caption_model == 'fc': 21 | model = FCModel(opt) 22 | elif opt.caption_model == 'language_model': 23 | model = LMModel(opt) 24 | elif opt.caption_model == 'newfc': 25 | model = NewFCModel(opt) 26 | elif opt.caption_model == 'show_tell': 27 | model = ShowTellModel(opt) 28 | # Att2in model in self-critical 29 | elif opt.caption_model == 'att2in': 30 | model = Att2inModel(opt) 31 | # Att2in model with two-layer MLP img embedding and word embedding 32 | elif opt.caption_model == 'att2in2': 33 | model = Att2in2Model(opt) 34 | elif opt.caption_model == 'att2all2': 35 | model = Att2all2Model(opt) 36 | # Adaptive Attention model from Knowing when to look 37 | elif opt.caption_model == 'adaatt': 38 | model = AdaAttModel(opt) 39 | # Adaptive Attention with maxout lstm 40 | elif opt.caption_model == 'adaattmo': 41 | model = AdaAttMOModel(opt) 42 | # Top-down attention model 43 | elif opt.caption_model == 'topdown': 44 | model = TopDownModel(opt) 45 | # StackAtt 46 | elif opt.caption_model == 'stackatt': 47 | model = StackAttModel(opt) 48 | # DenseAtt 49 | elif opt.caption_model == 'denseatt': 50 | model = DenseAttModel(opt) 51 | # Transformer 52 | elif opt.caption_model == 'transformer': 53 | model = TransformerModel(opt) 54 | # AoANet 55 | elif opt.caption_model == 'aoa': 56 | model = AoAModel(opt) 57 | else: 58 | raise Exception("Caption model not supported: {}".format(opt.caption_model)) 59 | 60 | # check compatibility if training is continued from previously saved model 61 | if vars(opt).get('start_from', None) is not None: 62 | # check if all necessary files exist 63 | assert os.path.isdir(opt.start_from)," %s must be a a path" % opt.start_from 64 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from 65 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) 66 | 67 | return model 68 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import opts 13 | import models 14 | from dataloader import * 15 | from dataloaderraw import * 16 | import eval_utils 17 | import argparse 18 | import misc.utils as utils 19 | import torch 20 | 21 | # Input arguments 
and options 22 | parser = argparse.ArgumentParser() 23 | # Input paths 24 | parser.add_argument('--model', type=str, default='', 25 | help='path to model to evaluate') 26 | parser.add_argument('--cnn_model', type=str, default='resnet101', 27 | help='resnet101, resnet152') 28 | parser.add_argument('--infos_path', type=str, default='', 29 | help='path to infos to evaluate') 30 | opts.add_eval_options(parser) 31 | 32 | opt = parser.parse_args() 33 | 34 | # Load infos 35 | with open(opt.infos_path, 'rb') as f: 36 | infos = utils.pickle_load(f) 37 | 38 | # override and collect parameters 39 | replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] 40 | ignore = ['start_from'] 41 | 42 | for k in vars(infos['opt']).keys(): 43 | if k in replace: 44 | setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) 45 | elif k not in ignore: 46 | if not k in vars(opt): 47 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model 48 | 49 | vocab = infos['vocab'] # ix -> word mapping 50 | 51 | # Setup the model 52 | opt.vocab = vocab 53 | model = models.setup(opt) 54 | del opt.vocab 55 | model.load_state_dict(torch.load(opt.model)) 56 | model.cuda() 57 | model.eval() 58 | crit = utils.LanguageModelCriterion() 59 | 60 | # Create the Data Loader instance 61 | if len(opt.image_folder) == 0: 62 | loader = DataLoader(opt) 63 | else: 64 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 65 | 'coco_json': opt.coco_json, 66 | 'batch_size': opt.batch_size, 67 | 'cnn_model': opt.cnn_model}) 68 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 69 | # So make sure to use the vocab in infos file. 70 | loader.ix_to_word = infos['vocab'] 71 | 72 | 73 | # Set sample options 74 | opt.datset = opt.input_json 75 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 76 | vars(opt)) 77 | 78 | print('loss: ', loss) 79 | if lang_stats: 80 | print(lang_stats) 81 | 82 | if opt.dump_json == 1: 83 | # dump the json 84 | json.dump(split_predictions, open('vis/vis.json', 'w')) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attention on Attention for Image Captioning 2 | 3 | This repository includes the implementation for [Attention on Attention for Image Captioning](https://arxiv.org/abs/1908.06954). 4 | 5 | ## Requirements 6 | 7 | - Python 3.6 8 | - Java 1.8.0 9 | - PyTorch 1.0 10 | - cider (already been added as a submodule) 11 | - coco-caption (already been added as a submodule) 12 | - tensorboardX 13 | 14 | 15 | ## Training AoANet 16 | 17 | ### Prepare data 18 | 19 | See details in `data/README.md`. 20 | 21 | (**notes:** Set `word_count_threshold` in `scripts/prepro_labels.py` to 4 to generate a vocabulary of size 10,369.) 22 | 23 | You should also preprocess the dataset and get the cache for calculating cider score for [SCST](https://arxiv.org/abs/1612.00563): 24 | 25 | ```bash 26 | $ python scripts/prepro_ngrams.py --input_json data/dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train 27 | ``` 28 | ### Start training 29 | 30 | ```bash 31 | $ CUDA_VISIBLE_DEVICES=0 sh train.sh 32 | ``` 33 | 34 | See `opts.py` for the options. (You can download the pretrained models from [here](https://drive.google.com/drive/folders/1ab0iPNyxdVm79ml-oozsIlH7H6t6dIVl?usp=sharing).) 
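`train.sh` simply calls `train.py` twice: first with the cross-entropy objective, then again for SCST fine-tuning (note the `--self_critical_after 0` flag and the `log/log_aoanet_rl` checkpoint path in its second half). As a minimal sketch, a stripped-down first-stage run using only the core flags from `train.sh` would look roughly like this; the AoA-specific options are left at their defaults here, so treat it as an illustration rather than the recommended configuration:

```bash
python train.py --id aoanet --caption_model aoa \
    --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 \
    --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att \
    --batch_size 10 --learning_rate 2e-4 --max_epochs 25 \
    --checkpoint_path log/log_aoanet --save_checkpoint_every 6000 --language_eval 1
```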
35 | 36 | 37 | ### Evaluation 38 | 39 | ```bash 40 | $ CUDA_VISIBLE_DEVICES=0 python eval.py --model log/log_aoanet_rl/model.pth --infos_path log/log_aoanet_rl/infos_aoanet.pkl --dump_images 0 --dump_json 1 --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test 41 | ``` 42 | 43 | ### Performance 44 | You will get the scores close to below after training under xe loss for 25 epochs: 45 | ```python 46 | {'Bleu_1': 0.7729384559899702, 'Bleu_2': 0.6163398035383025, 'Bleu_3': 0.4790123137715982, 'Bleu_4': 0.36944349063530374, 'METEOR': 0.2848188431924821, 'ROUGE_L': 0.5729849683867054, 'CIDEr': 1.1842173801790759, 'SPICE': 0.21650786258302354} 47 | ``` 48 | (**notes:** You can enlarge `--max_epochs` in `train.sh` to train the model for more epochs and improve the scores.) 49 | 50 | after training under SCST loss for another 15 epochs, you will get: 51 | ```python 52 | {'Bleu_1': 0.8054903453672397, 'Bleu_2': 0.6523038976984842, 'Bleu_3': 0.5096621263772566, 'Bleu_4': 0.39140307771618477, 'METEOR': 0.29011216375635934, 'ROUGE_L': 0.5890369750273199, 'CIDEr': 1.2892294296245852, 'SPICE': 0.22680092759866174} 53 | ``` 54 | 55 | 56 | ## Reference 57 | 58 | If you find this repo helpful, please consider citing: 59 | 60 | ``` 61 | @inproceedings{huang2019attention, 62 | title={Attention on Attention for Image Captioning}, 63 | author={Huang, Lun and Wang, Wenmin and Chen, Jie and Wei, Xiao-Yong}, 64 | booktitle={International Conference on Computer Vision}, 65 | year={2019} 66 | } 67 | ``` 68 | 69 | ## Acknowledgements 70 | 71 | This repository is based on [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch), and you may refer to it for more details about the code. 72 | -------------------------------------------------------------------------------- /eval_ensemble.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import opts 13 | import models 14 | from dataloader import * 15 | from dataloaderraw import * 16 | import eval_utils 17 | import argparse 18 | import misc.utils as utils 19 | import torch 20 | 21 | # Input arguments and options 22 | parser = argparse.ArgumentParser() 23 | # Input paths 24 | parser.add_argument('--ids', nargs='+', required=True, help='id of the models to ensemble') 25 | parser.add_argument('--weights', nargs='+', required=False, default=None, help='id of the models to ensemble') 26 | # parser.add_argument('--models', nargs='+', required=True 27 | # help='path to model to evaluate') 28 | # parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate') 29 | opts.add_eval_options(parser) 30 | 31 | opt = parser.parse_args() 32 | 33 | model_infos = [] 34 | model_paths = [] 35 | for id in opt.ids: 36 | if '-' in id: 37 | id, app = id.split('-') 38 | app = '-'+app 39 | else: 40 | app = '' 41 | model_infos.append(utils.pickle_load(open('log_%s/infos_%s%s.pkl' %(id, id, app), 'rb'))) 42 | model_paths.append('log_%s/model%s.pth' %(id,app)) 43 | 44 | # Load one infos 45 | infos = model_infos[0] 46 | 47 | # override and collect parameters 48 | replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] 49 | for k in replace: 50 | setattr(opt, k, getattr(opt, k) or 
getattr(infos['opt'], k, '')) 51 | 52 | vars(opt).update({k: vars(infos['opt'])[k] for k in vars(infos['opt']).keys() if k not in vars(opt)}) # copy over options from model 53 | 54 | 55 | opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos]) 56 | assert max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]), 'Not support different norm_att_feat' 57 | assert max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]), 'Not support different norm_box_feat' 58 | 59 | vocab = infos['vocab'] # ix -> word mapping 60 | 61 | # Setup the model 62 | from models.AttEnsemble import AttEnsemble 63 | 64 | _models = [] 65 | for i in range(len(model_infos)): 66 | model_infos[i]['opt'].start_from = None 67 | model_infos[i]['opt'].vocab = vocab 68 | tmp = models.setup(model_infos[i]['opt']) 69 | tmp.load_state_dict(torch.load(model_paths[i])) 70 | _models.append(tmp) 71 | 72 | if opt.weights is not None: 73 | opt.weights = [float(_) for _ in opt.weights] 74 | model = AttEnsemble(_models, weights=opt.weights) 75 | model.seq_length = opt.max_length 76 | model.cuda() 77 | model.eval() 78 | crit = utils.LanguageModelCriterion() 79 | 80 | # Create the Data Loader instance 81 | if len(opt.image_folder) == 0: 82 | loader = DataLoader(opt) 83 | else: 84 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 85 | 'coco_json': opt.coco_json, 86 | 'batch_size': opt.batch_size, 87 | 'cnn_model': opt.cnn_model}) 88 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 89 | # So make sure to use the vocab in infos file. 90 | loader.ix_to_word = infos['vocab'] 91 | 92 | opt.id = '+'.join([_+str(__) for _,__ in zip(opt.ids, opt.weights)]) 93 | # Set sample options 94 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 95 | vars(opt)) 96 | 97 | print('loss: ', loss) 98 | if lang_stats: 99 | print(lang_stats) 100 | 101 | if opt.dump_json == 1: 102 | # dump the json 103 | json.dump(split_predictions, open('vis/vis.json', 'w')) 104 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Prepare data 2 | 3 | ## COCO 4 | 5 | ### Download COCO captions and preprocess them 6 | 7 | Download preprocessed coco captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) from Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it in to `data/`. This file provides preprocessed captions and also standard train-val-test splits. 8 | 9 | Then do: 10 | 11 | ```bash 12 | $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk 13 | ``` 14 | 15 | `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json` and discretized caption data are dumped into `data/cocotalk_label.h5`. 16 | 17 | ### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up feature) 18 | 19 | Download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. 
You should put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`. 20 | 21 | Then: 22 | 23 | ``` 24 | $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT 25 | ``` 26 | 27 | 28 | `prepro_feats.py` extracts the ResNet-101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files total about 200GB. 29 | 30 | (Check the prepro scripts for more options, like other resnet models or other attention sizes.) 31 | 32 | **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset. 33 | 34 | ### Download Bottom-up features (Skip if you are using resnet features) 35 | 36 | Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed features. 37 | 38 | For example: 39 | ``` 40 | mkdir data/bu_data; cd data/bu_data 41 | wget https://storage.googleapis.com/bottom-up-attention/trainval.zip 42 | unzip trainval.zip 43 | 44 | ``` 45 | 46 | Then: 47 | 48 | ```bash 49 | python scripts/make_bu_data.py --output_dir data/cocobu 50 | ``` 51 | 52 | This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the remaining steps and replace every `cocotalk` with `cocobu`. 53 | 54 | ## Flickr30k 55 | 56 | The procedure is similar. 57 | 58 | ``` 59 | python scripts/prepro_labels.py --input_json data/dataset_flickr30k.json --output_json data/f30ktalk.json --output_h5 data/f30ktalk 60 | 61 | python scripts/prepro_ngrams.py --input_json data/dataset_flickr30k.json --dict_json data/f30ktalk.json --output_pkl data/f30k-train --split train 62 | ``` 63 | 64 | The following generates the coco-like annotation file for evaluation using coco-caption. 65 | 66 | ``` 67 | python scripts/prepro_reference_json.py --input_json data/dataset_flickr30k.json --output_json data/f30k_captions4eval.json 68 | ``` 69 | 70 | ### Feature extraction 71 | 72 | For resnet features, you can do the same thing as for COCO. 73 | 74 | For bottom-up features, you can download them from [link](https://github.com/kuanghuei/SCAN): 75 | 76 | `wget https://scanproject.blob.core.windows.net/scan-data/data.zip` 77 | 78 | and then convert them to a pth file using the following script: 79 | 80 | ``` 81 | import numpy as np 82 | import os 83 | import torch 84 | from tqdm import tqdm 85 | 86 | out = {} 87 | def transform(id_file, feat_file): 88 | ids = open(id_file, 'r').readlines() 89 | ids = [_.strip('\n') for _ in ids] 90 | feats = np.load(feat_file) 91 | assert feats.shape[0] == len(ids) 92 | for _id, _feat in tqdm(zip(ids, feats)): 93 | out[str(_id)] = _feat 94 | 95 | transform('dev_ids.txt', 'dev_ims.npy') 96 | transform('train_ids.txt', 'train_ims.npy') 97 | transform('test_ids.txt', 'test_ims.npy') 98 | 99 | torch.save(out, 'f30kbu_att.pth') 100 | ``` -------------------------------------------------------------------------------- /scripts/prepro_reference_json.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 4 | 5 | Input: json file that has the form 6 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
7 | example element in this list would look like 8 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 9 | 10 | This script reads this json, does some basic preprocessing on the captions 11 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 12 | 13 | Output: a json file and an hdf5 file 14 | The hdf5 file contains several fields: 15 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 16 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 17 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 18 | first and last indices (in range 1..M) of labels for each image 19 | /label_length stores the length of the sequence for each of the M sequences 20 | 21 | The json file has a dict that contains: 22 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 23 | - an 'images' field that is a list holding auxiliary information for each image, 24 | such as in particular the 'split' it was assigned to. 25 | """ 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | import os 32 | import json 33 | import argparse 34 | import sys 35 | import hashlib 36 | from random import shuffle, seed 37 | 38 | 39 | def main(params): 40 | 41 | imgs = json.load(open(params['input_json'][0], 'r'))['images'] 42 | # tmp = [] 43 | # for k in imgs.keys(): 44 | # for img in imgs[k]: 45 | # img['filename'] = img['image_id'] # k+'/'+img['image_id'] 46 | # img['image_id'] = int( 47 | # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint) 48 | # tmp.append(img) 49 | # imgs = tmp 50 | 51 | # create output json file 52 | out = {u'info': {u'description': u'This is stable 1.0 version of the 2014 MS COCO dataset.', u'url': u'http://mscoco.org', u'version': u'1.0', u'year': 2014, u'contributor': u'Microsoft COCO group', u'date_created': u'2015-01-27 09:11:52.357475'}, u'licenses': [{u'url': u'http://creativecommons.org/licenses/by-nc-sa/2.0/', u'id': 1, u'name': u'Attribution-NonCommercial-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nc/2.0/', u'id': 2, u'name': u'Attribution-NonCommercial License'}, {u'url': u'http://creativecommons.org/licenses/by-nc-nd/2.0/', u'id': 3, u'name': u'Attribution-NonCommercial-NoDerivs License'}, {u'url': u'http://creativecommons.org/licenses/by/2.0/', u'id': 4, u'name': u'Attribution License'}, {u'url': u'http://creativecommons.org/licenses/by-sa/2.0/', u'id': 5, u'name': u'Attribution-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nd/2.0/', u'id': 6, u'name': u'Attribution-NoDerivs License'}, {u'url': u'http://flickr.com/commons/usage/', u'id': 7, u'name': u'No known copyright restrictions'}, {u'url': u'http://www.usa.gov/copyright.shtml', u'id': 8, u'name': u'United States Government Work'}], u'type': u'captions'} 53 | out.update({'images': [], 'annotations': []}) 54 | 55 | cnt = 0 56 | empty_cnt = 0 57 | for i, img in enumerate(imgs): 58 | if img['split'] == 'train': 59 | continue 60 | 
out['images'].append( 61 | {u'id': img.get('cocoid', img['imgid'])}) 62 | for j, s in enumerate(img['sentences']): 63 | if len(s) == 0: 64 | continue 65 | s = ' '.join(s['tokens']) 66 | out['annotations'].append( 67 | {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt}) 68 | cnt += 1 69 | 70 | json.dump(out, open(params['output_json'], 'w')) 71 | print('wrote ', params['output_json']) 72 | 73 | 74 | if __name__ == "__main__": 75 | 76 | parser = argparse.ArgumentParser() 77 | 78 | # input json 79 | parser.add_argument('--input_json', nargs='+', required=True, 80 | help='input json file to process into hdf5') 81 | parser.add_argument('--output_json', default='data.json', 82 | help='output json file') 83 | 84 | args = parser.parse_args() 85 | params = vars(args) # convert to ordinary dict 86 | print('parsed input parameters:') 87 | print(json.dumps(params, indent=2)) 88 | main(params) 89 | 90 | -------------------------------------------------------------------------------- /scripts/prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | from six.moves import cPickle 38 | import numpy as np 39 | import torch 40 | import torchvision.models as models 41 | import skimage.io 42 | 43 | from torchvision import transforms as trn 44 | preprocess = trn.Compose([ 45 | #trn.ToTensor(), 46 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 47 | ]) 48 | 49 | from misc.resnet_utils import myResnet 50 | import misc.resnet as resnet 51 | 52 | def main(params): 53 | net = getattr(resnet, params['model'])() 54 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 55 | my_resnet = myResnet(net) 56 | my_resnet.cuda() 57 | my_resnet.eval() 58 | 59 | imgs = json.load(open(params['input_json'], 'r')) 60 | imgs = imgs['images'] 61 | N = len(imgs) 62 | 63 | seed(123) # make reproducible 64 | 65 | dir_fc = params['output_dir']+'_fc' 66 | dir_att = params['output_dir']+'_att' 67 | if not os.path.isdir(dir_fc): 68 | os.mkdir(dir_fc) 69 | if not os.path.isdir(dir_att): 70 | os.mkdir(dir_att) 71 | 72 | for i,img in enumerate(imgs): 73 | # load the image 74 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 75 | # handle grayscale input images 76 | if len(I.shape) == 2: 77 | I = I[:,:,np.newaxis] 78 | I = np.concatenate((I,I,I), axis=2) 79 | 80 | I = I.astype('float32')/255.0 81 | I = torch.from_numpy(I.transpose([2,0,1])).cuda() 82 | I = preprocess(I) 83 | with torch.no_grad(): 84 | tmp_fc, tmp_att = my_resnet(I, params['att_size']) 85 | # write to pkl 86 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 87 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 88 | 89 | if i % 1000 == 0: 90 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 91 | print('wrote ', params['output_dir']) 92 | 93 | if __name__ == "__main__": 94 | 95 | parser = argparse.ArgumentParser() 96 | 97 | # input json 98 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 99 | parser.add_argument('--output_dir', default='data', help='output h5 file') 100 | 101 | # options 102 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 103 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 104 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 105 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 106 | 107 | args = parser.parse_args() 108 | params = vars(args) # convert to ordinary dict 109 | print('parsed input parameters:') 110 | print(json.dumps(params, indent = 2)) 111 | main(params) 112 | -------------------------------------------------------------------------------- /dataloaderraw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | import torch 11 | import skimage 12 | import skimage.io 13 | 
import scipy.misc 14 | 15 | from torchvision import transforms as trn 16 | preprocess = trn.Compose([ 17 | #trn.ToTensor(), 18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 19 | ]) 20 | 21 | from misc.resnet_utils import myResnet 22 | import misc.resnet 23 | 24 | class DataLoaderRaw(): 25 | 26 | def __init__(self, opt): 27 | self.opt = opt 28 | self.coco_json = opt.get('coco_json', '') 29 | self.folder_path = opt.get('folder_path', '') 30 | 31 | self.batch_size = opt.get('batch_size', 1) 32 | self.seq_per_img = 1 33 | 34 | # Load resnet 35 | self.cnn_model = opt.get('cnn_model', 'resnet101') 36 | self.my_resnet = getattr(misc.resnet, self.cnn_model)() 37 | self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth')) 38 | self.my_resnet = myResnet(self.my_resnet) 39 | self.my_resnet.cuda() 40 | self.my_resnet.eval() 41 | 42 | 43 | 44 | # load the json file which contains additional information about the dataset 45 | print('DataLoaderRaw loading images from folder: ', self.folder_path) 46 | 47 | self.files = [] 48 | self.ids = [] 49 | 50 | print(len(self.coco_json)) 51 | if len(self.coco_json) > 0: 52 | print('reading from ' + opt.coco_json) 53 | # read in filenames from the coco-style json file 54 | self.coco_annotation = json.load(open(self.coco_json)) 55 | for k,v in enumerate(self.coco_annotation['images']): 56 | fullpath = os.path.join(self.folder_path, v['file_name']) 57 | self.files.append(fullpath) 58 | self.ids.append(v['id']) 59 | else: 60 | # read in all the filenames from the folder 61 | print('listing all images in directory ' + self.folder_path) 62 | def isImage(f): 63 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM'] 64 | for ext in supportedExt: 65 | start_idx = f.rfind(ext) 66 | if start_idx >= 0 and start_idx + len(ext) == len(f): 67 | return True 68 | return False 69 | 70 | n = 1 71 | for root, dirs, files in os.walk(self.folder_path, topdown=False): 72 | for file in files: 73 | fullpath = os.path.join(self.folder_path, file) 74 | if isImage(fullpath): 75 | self.files.append(fullpath) 76 | self.ids.append(str(n)) # just order them sequentially 77 | n = n + 1 78 | 79 | self.N = len(self.files) 80 | print('DataLoaderRaw found ', self.N, ' images') 81 | 82 | self.iterator = 0 83 | 84 | def get_batch(self, split, batch_size=None): 85 | batch_size = batch_size or self.batch_size 86 | 87 | # pick an index of the datapoint to load next 88 | fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32') 89 | att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32') 90 | max_index = self.N 91 | wrapped = False 92 | infos = [] 93 | 94 | for i in range(batch_size): 95 | ri = self.iterator 96 | ri_next = ri + 1 97 | if ri_next >= max_index: 98 | ri_next = 0 99 | wrapped = True 100 | # wrap back around 101 | self.iterator = ri_next 102 | 103 | img = skimage.io.imread(self.files[ri]) 104 | 105 | if len(img.shape) == 2: 106 | img = img[:,:,np.newaxis] 107 | img = np.concatenate((img, img, img), axis=2) 108 | 109 | img = img[:,:,:3].astype('float32')/255.0 110 | img = torch.from_numpy(img.transpose([2,0,1])).cuda() 111 | img = preprocess(img) 112 | with torch.no_grad(): 113 | tmp_fc, tmp_att = self.my_resnet(img) 114 | 115 | fc_batch[i] = tmp_fc.data.cpu().float().numpy() 116 | att_batch[i] = tmp_att.data.cpu().float().numpy() 117 | 118 | info_struct = {} 119 | info_struct['id'] = self.ids[ri] 120 | info_struct['file_path'] = self.files[ri] 121 | infos.append(info_struct) 122 | 123 | data = {} 124 | 
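        # Pack the batch dict returned to the caller (e.g. eval_utils.eval_split):
        #   fc_feats:  (batch_size, 2048) pooled ResNet feature per image
        #   att_feats: (batch_size, 14*14, 2048) spatial features, flattened below
        #   att_masks: None, since raw images always yield the full 14x14 grid
        #   bounds:    iterator bookkeeping; infos: per-image id and file_path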
data['fc_feats'] = fc_batch 125 | data['att_feats'] = att_batch.reshape(batch_size, -1, 2048) 126 | data['att_masks'] = None 127 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped} 128 | data['infos'] = infos 129 | 130 | return data 131 | 132 | def reset_iterator(self, split): 133 | self.iterator = 0 134 | 135 | def get_vocab_size(self): 136 | return len(self.ix_to_word) 137 | 138 | def get_vocab(self): 139 | return self.ix_to_word 140 | -------------------------------------------------------------------------------- /models/AttEnsemble.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.autograd import * 24 | import misc.utils as utils 25 | 26 | from .CaptionModel import CaptionModel 27 | from .AttModel import pack_wrapper, AttModel 28 | 29 | class AttEnsemble(AttModel): 30 | def __init__(self, models, weights=None): 31 | CaptionModel.__init__(self) 32 | # super(AttEnsemble, self).__init__() 33 | 34 | self.models = nn.ModuleList(models) 35 | self.vocab_size = models[0].vocab_size 36 | self.seq_length = models[0].seq_length 37 | self.bad_endings_ix = models[0].bad_endings_ix 38 | self.ss_prob = 0 39 | weights = weights or [1.0] * len(self.models) 40 | self.register_buffer('weights', torch.tensor(weights)) 41 | 42 | def init_hidden(self, batch_size): 43 | state = [m.init_hidden(batch_size) for m in self.models] 44 | return self.pack_state(state) 45 | 46 | def pack_state(self, state): 47 | self.state_lengths = [len(_) for _ in state] 48 | return sum([list(_) for _ in state], []) 49 | 50 | def unpack_state(self, state): 51 | out = [] 52 | for l in self.state_lengths: 53 | out.append(state[:l]) 54 | state = state[l:] 55 | return out 56 | 57 | def embed(self, it): 58 | return [m.embed(it) for m in self.models] 59 | 60 | def core(self, *args): 61 | return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))]) 62 | 63 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state): 64 | # 'it' contains a word index 65 | xt = self.embed(it) 66 | 67 | state = self.unpack_state(state) 68 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) 69 | logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log() 70 | 71 | return logprobs, self.pack_state(state) 72 | 73 | def _prepare_feature(self, *args): 74 | return tuple(zip(*[m._prepare_feature(*args) for m in self.models])) 75 | 76 | # def _prepare_feature(self, 
fc_feats, att_feats, att_masks): 77 | 78 | # att_feats, att_masks = self.clip_att(att_feats, att_masks) 79 | 80 | # # embed fc and att feats 81 | # fc_feats = [m.fc_embed(fc_feats) for m in self.models] 82 | # att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models] 83 | 84 | # # Project the attention feats first to reduce memory and computation comsumptions. 85 | # p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)] 86 | 87 | # return fc_feats, att_feats, p_att_feats, [att_masks] * len(self.models) 88 | 89 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 90 | beam_size = opt.get('beam_size', 10) 91 | batch_size = fc_feats.size(0) 92 | 93 | fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 94 | 95 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 96 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 97 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 98 | # lets process every image independently for now, for simplicity 99 | 100 | self.done_beams = [[] for _ in range(batch_size)] 101 | for k in range(batch_size): 102 | state = self.init_hidden(beam_size) 103 | tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)] 104 | tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 105 | tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 106 | tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)] 107 | 108 | it = fc_feats[0].data.new(beam_size).long().zero_() 109 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) 110 | 111 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) 112 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 113 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 114 | # return the samples and their log likelihoods 115 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 116 | # return the samples and their log likelihoods -------------------------------------------------------------------------------- /scripts/prepro_ngrams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, maps the caption tokens through the vocabulary stored in 10 | --dict_json (replacing out-of-vocabulary words with the special UNK token), and counts 11 | n-gram document frequencies over the reference captions. 12 | 13 | Output: two pickle files, <output_pkl>-words.p and <output_pkl>-idxs.p 14 | Each pickle holds a dict with two entries: 15 | 'document_frequency' maps each n-gram (n = 1..4) to the number of images whose 16 | reference captions contain it (over word strings and word indices, respectively) 17 | 'ref_len' is the number of images that were counted 18 | 19 | These statistics are used for CIDEr-based reward computation during self-critical 20 | training. 21 | 22 | If the dictionary json was built by scripts/build_bpe_subword_nmt.py (i.e. it contains 23 | a 'bpe' field), the captions are first re-segmented with the stored BPE codes. 24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from six.moves import cPickle 30 | import misc.utils as utils 31 | from collections import defaultdict 32 | 33 | def precook(s, n=4, out=False): 34 | """ 35 | Takes a string as input and returns an object that can be given to 36 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 37 | can take string arguments as well. 38 | :param s: string : sentence to be converted into ngrams 39 | :param n: int : number of ngrams for which representation is calculated 40 | :return: term frequency vector for occurring ngrams 41 | """ 42 | words = s.split() 43 | counts = defaultdict(int) 44 | for k in range(1,n+1): 45 | for i in range(len(words)-k+1): 46 | ngram = tuple(words[i:i+k]) 47 | counts[ngram] += 1 48 | return counts 49 | 50 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 51 | '''Takes a list of reference sentences for a single segment 52 | and returns an object that encapsulates everything that BLEU 53 | needs to know about them. 54 | :param refs: list of string : reference sentences for some image 55 | :param n: int : number of ngrams for which (ngram) representation is calculated 56 | :return: result (list of dict) 57 | ''' 58 | return [precook(ref, n) for ref in refs] 59 | 60 | def create_crefs(refs): 61 | crefs = [] 62 | for ref in refs: 63 | # ref is a list of 5 captions 64 | crefs.append(cook_refs(ref)) 65 | return crefs 66 | 67 | def compute_doc_freq(crefs): 68 | ''' 69 | Compute term frequency for reference data.
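For example, if the unigram ('man',) appears in the reference captions of three different images, its document frequency ends up as 3.0; each image contributes at most 1, no matter how often the n-gram repeats within that image's captions.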
70 | This will be used to compute idf (inverse document frequency later) 71 | The term frequency is stored in the object 72 | :return: None 73 | ''' 74 | document_frequency = defaultdict(float) 75 | for refs in crefs: 76 | # refs, k ref captions of one image 77 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 78 | document_frequency[ngram] += 1 79 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 80 | return document_frequency 81 | 82 | def build_dict(imgs, wtoi, params): 83 | wtoi[''] = 0 84 | 85 | count_imgs = 0 86 | 87 | refs_words = [] 88 | refs_idxs = [] 89 | for img in imgs: 90 | if (params['split'] == img['split']) or \ 91 | (params['split'] == 'train' and img['split'] == 'restval') or \ 92 | (params['split'] == 'all'): 93 | #(params['split'] == 'val' and img['split'] == 'restval') or \ 94 | ref_words = [] 95 | ref_idxs = [] 96 | for sent in img['sentences']: 97 | if hasattr(params, 'bpe'): 98 | sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ') 99 | tmp_tokens = sent['tokens'] + [''] 100 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] 101 | ref_words.append(' '.join(tmp_tokens)) 102 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) 103 | refs_words.append(ref_words) 104 | refs_idxs.append(ref_idxs) 105 | count_imgs += 1 106 | print('total imgs:', count_imgs) 107 | 108 | ngram_words = compute_doc_freq(create_crefs(refs_words)) 109 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs)) 110 | return ngram_words, ngram_idxs, count_imgs 111 | 112 | def main(params): 113 | 114 | imgs = json.load(open(params['input_json'], 'r')) 115 | dict_json = json.load(open(params['dict_json'], 'r')) 116 | itow = dict_json['ix_to_word'] 117 | wtoi = {w:i for i,w in itow.items()} 118 | 119 | # Load bpe 120 | if 'bpe' in dict_json: 121 | import tempfile 122 | import codecs 123 | codes_f = tempfile.NamedTemporaryFile(delete=False) 124 | codes_f.close() 125 | with open(codes_f.name, 'w') as f: 126 | f.write(dict_json['bpe']) 127 | with codecs.open(codes_f.name, encoding='UTF-8') as codes: 128 | bpe = apply_bpe.BPE(codes) 129 | params.bpe = bpe 130 | 131 | imgs = imgs['images'] 132 | 133 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) 134 | 135 | utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','w')) 136 | utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','w')) 137 | 138 | if __name__ == "__main__": 139 | 140 | parser = argparse.ArgumentParser() 141 | 142 | # input json 143 | parser.add_argument('--input_json', default='/home-nfs/rluo/rluo/nips/code/prepro/dataset_coco.json', help='input json file to process into hdf5') 144 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') 145 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') 146 | parser.add_argument('--split', default='all', help='test, val, train, all') 147 | args = parser.parse_args() 148 | params = vars(args) # convert to ordinary dict 149 | 150 | main(params) 151 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import numpy as 
np 9 | import json 10 | from json import encoder 11 | import random 12 | import string 13 | import time 14 | import os 15 | import sys 16 | import misc.utils as utils 17 | 18 | bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] 19 | bad_endings += ['the'] 20 | 21 | def count_bad(sen): 22 | sen = sen.split(' ') 23 | if sen[-1] in bad_endings: 24 | return 1 25 | else: 26 | return 0 27 | 28 | def language_eval(dataset, preds, model_id, split): 29 | import sys 30 | sys.path.append("coco-caption") 31 | if 'coco' in dataset: 32 | annFile = 'coco-caption/annotations/captions_val2014.json' 33 | elif 'flickr30k' in dataset or 'f30k' in dataset: 34 | annFile = 'coco-caption/f30k_captions4eval.json' 35 | from pycocotools.coco import COCO 36 | from pycocoevalcap.eval import COCOEvalCap 37 | 38 | # encoder.FLOAT_REPR = lambda o: format(o, '.3f') 39 | 40 | if not os.path.isdir('eval_results'): 41 | os.mkdir('eval_results') 42 | cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json') 43 | 44 | coco = COCO(annFile) 45 | valids = coco.getImgIds() 46 | 47 | # filter results to only those in MSCOCO validation set (will be about a third) 48 | preds_filt = [p for p in preds if p['image_id'] in valids] 49 | print('using %d/%d predictions' % (len(preds_filt), len(preds))) 50 | json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 51 | 52 | cocoRes = coco.loadRes(cache_path) 53 | cocoEval = COCOEvalCap(coco, cocoRes) 54 | cocoEval.params['image_id'] = cocoRes.getImgIds() 55 | cocoEval.evaluate() 56 | 57 | # create output dictionary 58 | out = {} 59 | for metric, score in cocoEval.eval.items(): 60 | out[metric] = score 61 | 62 | imgToEval = cocoEval.imgToEval 63 | for p in preds_filt: 64 | image_id, caption = p['image_id'], p['caption'] 65 | imgToEval[image_id]['caption'] = caption 66 | 67 | out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt)) 68 | outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json') 69 | with open(outfile_path, 'w') as outfile: 70 | json.dump({'overall': out, 'imgToEval': imgToEval}, outfile) 71 | 72 | return out 73 | 74 | def eval_split(model, crit, loader, eval_kwargs={}): 75 | verbose = eval_kwargs.get('verbose', True) 76 | verbose_beam = eval_kwargs.get('verbose_beam', 1) 77 | verbose_loss = eval_kwargs.get('verbose_loss', 1) 78 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) 79 | split = eval_kwargs.get('split', 'val') 80 | lang_eval = eval_kwargs.get('language_eval', 0) 81 | dataset = eval_kwargs.get('dataset', 'coco') 82 | beam_size = eval_kwargs.get('beam_size', 1) 83 | remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) 84 | os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration 85 | 86 | # Make sure in the evaluation mode 87 | model.eval() 88 | 89 | loader.reset_iterator(split) 90 | 91 | n = 0 92 | loss = 0 93 | loss_sum = 0 94 | loss_evals = 1e-8 95 | predictions = [] 96 | while True: 97 | data = loader.get_batch(split) 98 | n = n + loader.batch_size 99 | 100 | if data.get('labels', None) is not None and verbose_loss: 101 | # forward the model to get loss 102 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] 103 | tmp = [_.cuda() if _ is not None else _ for _ in tmp] 104 | fc_feats, att_feats, labels, 
masks, att_masks = tmp 105 | 106 | with torch.no_grad(): 107 | loss = crit(model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]).item() 108 | loss_sum = loss_sum + loss 109 | loss_evals = loss_evals + 1 110 | 111 | # forward the model to also get generated samples for each image 112 | # Only leave one feature for each image, in case duplicate sample 113 | tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 114 | data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 115 | data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img] if data['att_masks'] is not None else None] 116 | tmp = [_.cuda() if _ is not None else _ for _ in tmp] 117 | fc_feats, att_feats, att_masks = tmp 118 | # forward the model to also get generated samples for each image 119 | with torch.no_grad(): 120 | seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data 121 | 122 | # Print beam search 123 | if beam_size > 1 and verbose_beam: 124 | for i in range(loader.batch_size): 125 | print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) 126 | print('--' * 10) 127 | sents = utils.decode_sequence(loader.get_vocab(), seq) 128 | 129 | for k, sent in enumerate(sents): 130 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent} 131 | if eval_kwargs.get('dump_path', 0) == 1: 132 | entry['file_name'] = data['infos'][k]['file_path'] 133 | predictions.append(entry) 134 | if eval_kwargs.get('dump_images', 0) == 1: 135 | # dump the raw image to vis/ folder 136 | cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross 137 | print(cmd) 138 | os.system(cmd) 139 | 140 | if verbose: 141 | print('image %s: %s' %(entry['image_id'], entry['caption'])) 142 | 143 | # if we wrapped around the split or used up val imgs budget then bail 144 | ix0 = data['bounds']['it_pos_now'] 145 | ix1 = data['bounds']['it_max'] 146 | if num_images != -1: 147 | ix1 = min(ix1, num_images) 148 | for i in range(n - ix1): 149 | predictions.pop() 150 | 151 | if verbose: 152 | print('evaluating validation preformance... 
%d/%d (%f)' %(ix0 - 1, ix1, loss)) 153 | 154 | if data['bounds']['wrapped']: 155 | break 156 | if num_images >= 0 and n >= num_images: 157 | break 158 | 159 | lang_stats = None 160 | if lang_eval == 1: 161 | lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split) 162 | 163 | # Switch back to training mode 164 | model.train() 165 | return loss_sum/loss_evals, predictions, lang_stats 166 | -------------------------------------------------------------------------------- /scripts/dump_to_lmdb.py: -------------------------------------------------------------------------------- 1 | # copy from https://github.com/Lyken17/Efficient-PyTorch/tools 2 | 3 | import os 4 | import os.path as osp 5 | import os, sys 6 | import os.path as osp 7 | from PIL import Image 8 | import six 9 | import string 10 | 11 | import lmdb 12 | import pickle 13 | import tqdm 14 | import numpy as np 15 | import argparse 16 | import json 17 | 18 | import torch 19 | import torch.utils.data as data 20 | from torch.utils.data import DataLoader 21 | 22 | import csv 23 | csv.field_size_limit(sys.maxsize) 24 | FIELDNAMES = ['image_id', 'status'] 25 | 26 | class FolderLMDB(data.Dataset): 27 | def __init__(self, db_path, fn_list=None): 28 | self.db_path = db_path 29 | self.env = lmdb.open(db_path, subdir=osp.isdir(db_path), 30 | readonly=True, lock=False, 31 | readahead=False, meminit=False) 32 | if fn_list is not None: 33 | self.length = len(fn_list) 34 | self.keys = fn_list 35 | else: 36 | raise Error 37 | 38 | def __getitem__(self, index): 39 | env = self.env 40 | with env.begin(write=False) as txn: 41 | byteflow = txn.get(self.keys[index]) 42 | 43 | # load image 44 | imgbuf = byteflow 45 | buf = six.BytesIO() 46 | buf.write(imgbuf) 47 | buf.seek(0) 48 | try: 49 | if args.extension == '.npz': 50 | feat = np.load(buf)['feat'] 51 | else: 52 | feat = np.load(buf) 53 | except Exception as e: 54 | print self.keys[index], e 55 | return None 56 | 57 | return feat 58 | 59 | def __len__(self): 60 | return self.length 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + ' (' + self.db_path + ')' 64 | 65 | 66 | def make_dataset(dir, extension): 67 | images = [] 68 | dir = os.path.expanduser(dir) 69 | for root, _, fnames in sorted(os.walk(dir)): 70 | for fname in sorted(fnames): 71 | if has_file_allowed_extension(fname, [extension]): 72 | path = os.path.join(root, fname) 73 | images.append(path) 74 | 75 | return images 76 | 77 | 78 | def raw_reader(path): 79 | with open(path, 'rb') as f: 80 | bin_data = f.read() 81 | return bin_data 82 | 83 | 84 | def raw_npz_reader(path): 85 | with open(path, 'rb') as f: 86 | bin_data = f.read() 87 | try: 88 | npz_data = np.load(six.BytesIO(bin_data))['feat'] 89 | except Exception as e: 90 | print path 91 | npz_data = None 92 | print e 93 | return bin_data, npz_data 94 | 95 | 96 | def raw_npy_reader(path): 97 | with open(path, 'rb') as f: 98 | bin_data = f.read() 99 | try: 100 | npy_data = np.load(six.BytesIO(bin_data)) 101 | except Exception as e: 102 | print path 103 | npy_data = None 104 | print e 105 | return bin_data, npy_data 106 | 107 | 108 | class Folder(data.Dataset): 109 | 110 | def __init__(self, root, loader, extension, fn_list=None): 111 | super(Folder, self).__init__() 112 | self.root = root 113 | if fn_list: 114 | samples = [os.path.join(root, str(_)+extension) for _ in fn_list] 115 | else: 116 | samples = make_dataset(self.root, extention) 117 | 118 | self.loader = loader 119 | self.extension = extension 120 | self.samples = samples 121 | 122 | def 
__getitem__(self, index): 123 | """ 124 | Args: 125 | index (int): Index 126 | Returns: 127 | tuple: (sample, target) where target is class_index of the target class. 128 | """ 129 | path = self.samples[index] 130 | sample = self.loader(path) 131 | 132 | return (path.split('/')[-1].split('.')[0],) + sample 133 | 134 | def __len__(self): 135 | return len(self.samples) 136 | 137 | 138 | def folder2lmdb(dpath, fn_list, write_frequency=5000): 139 | directory = osp.expanduser(osp.join(dpath)) 140 | print("Loading dataset from %s" % directory) 141 | if args.extension == '.npz': 142 | dataset = Folder(directory, loader=raw_npz_reader, extension='.npz', 143 | fn_list=fn_list) 144 | else: 145 | dataset = Folder(directory, loader=raw_npy_reader, extension='.npy', 146 | fn_list=fn_list) 147 | data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x) 148 | 149 | # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1])) 150 | lmdb_path = osp.join("%s.lmdb" % (directory)) 151 | isdir = os.path.isdir(lmdb_path) 152 | 153 | print("Generate LMDB to %s" % lmdb_path) 154 | db = lmdb.open(lmdb_path, subdir=isdir, 155 | map_size=1099511627776 * 2, readonly=False, 156 | meminit=False, map_async=True) 157 | 158 | txn = db.begin(write=True) 159 | 160 | tsvfile = open(args.output_file, 'ab') 161 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 162 | names = [] 163 | for idx, data in enumerate(tqdm.tqdm(data_loader)): 164 | # print(type(data), data) 165 | name, byte, npz = data[0] 166 | if npz is not None: 167 | txn.put(name, byte) 168 | names.append({'image_id': name, 'status': str(npz is not None)}) 169 | if idx % write_frequency == 0: 170 | print("[%d/%d]" % (idx, len(data_loader))) 171 | print('writing') 172 | txn.commit() 173 | txn = db.begin(write=True) 174 | # write in tsv 175 | for name in names: 176 | writer.writerow(name) 177 | names = [] 178 | tsvfile.flush() 179 | print('writing finished') 180 | 181 | # finish iterating through dataset 182 | txn.commit() 183 | for name in names: 184 | writer.writerow(name) 185 | tsvfile.flush() 186 | tsvfile.close() 187 | 188 | print("Flushing database ...") 189 | db.sync() 190 | db.close() 191 | 192 | def parse_args(): 193 | """ 194 | Parse input arguments 195 | """ 196 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') 197 | # parser.add_argument('--json) 198 | parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str) 199 | parser.add_argument('--output_file', default='.dump_cache.tsv', type=str) 200 | parser.add_argument('--folder', default='./data/cocobu_att', type=str) 201 | parser.add_argument('--extension', default='.npz', type=str) 202 | 203 | args = parser.parse_args() 204 | return args 205 | 206 | if __name__ == "__main__": 207 | global args 208 | args = parse_args() 209 | 210 | args.output_file += args.folder.split('/')[-1] 211 | if args.folder.find('/') > 0: 212 | args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file 213 | print(args.output_file) 214 | 215 | img_list = json.load(open(args.input_json, 'r'))['images'] 216 | fn_list = [str(_['cocoid']) for _ in img_list] 217 | found_ids = set() 218 | try: 219 | with open(args.output_file, 'r') as tsvfile: 220 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 221 | for item in reader: 222 | if item['status'] == 'True': 223 | found_ids.add(item['image_id']) 224 | except: 225 | pass 226 | fn_list = [_ for _ in fn_list if _ not in found_ids] 227 | 
folder2lmdb(args.folder, fn_list) 228 | 229 | # Test existing. 230 | found_ids = set() 231 | with open(args.output_file, 'r') as tsvfile: 232 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) 233 | for item in reader: 234 | if item['status'] == 'True': 235 | found_ids.add(item['image_id']) 236 | 237 | folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids)) 238 | data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x) 239 | for data in tqdm.tqdm(data_loader): 240 | assert data[0] is not None -------------------------------------------------------------------------------- /models/ShowTellModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | import misc.utils as utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class ShowTellModel(CaptionModel): 14 | def __init__(self, opt): 15 | super(ShowTellModel, self).__init__() 16 | self.vocab_size = opt.vocab_size 17 | self.input_encoding_size = opt.input_encoding_size 18 | self.rnn_type = opt.rnn_type 19 | self.rnn_size = opt.rnn_size 20 | self.num_layers = opt.num_layers 21 | self.drop_prob_lm = opt.drop_prob_lm 22 | self.seq_length = opt.seq_length 23 | self.fc_feat_size = opt.fc_feat_size 24 | 25 | self.ss_prob = 0.0 # Schedule sampling probability 26 | 27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 31 | self.dropout = nn.Dropout(self.drop_prob_lm) 32 | 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | initrange = 0.1 37 | self.embed.weight.data.uniform_(-initrange, initrange) 38 | self.logit.bias.data.fill_(0) 39 | self.logit.weight.data.uniform_(-initrange, initrange) 40 | 41 | def init_hidden(self, bsz): 42 | weight = next(self.parameters()).data 43 | if self.rnn_type == 'lstm': 44 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 45 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 46 | else: 47 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 48 | 49 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 50 | batch_size = fc_feats.size(0) 51 | state = self.init_hidden(batch_size) 52 | outputs = [] 53 | 54 | for i in range(seq.size(1)): 55 | if i == 0: 56 | xt = self.img_embed(fc_feats) 57 | else: 58 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 59 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 60 | sample_mask = sample_prob < self.ss_prob 61 | if sample_mask.sum() == 0: 62 | it = seq[:, i-1].clone() 63 | else: 64 | sample_ind = sample_mask.nonzero().view(-1) 65 | it = seq[:, i-1].data.clone() 66 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 67 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 68 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 69 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 70 | 
else: 71 | it = seq[:, i-1].clone() 72 | # break if all the sequences end 73 | if i >= 2 and seq[:, i-1].data.sum() == 0: 74 | break 75 | xt = self.embed(it) 76 | 77 | output, state = self.core(xt.unsqueeze(0), state) 78 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 79 | outputs.append(output) 80 | 81 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 82 | 83 | def get_logprobs_state(self, it, state): 84 | # 'it' contains a word index 85 | xt = self.embed(it) 86 | 87 | output, state = self.core(xt.unsqueeze(0), state) 88 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 89 | 90 | return logprobs, state 91 | 92 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 93 | beam_size = opt.get('beam_size', 10) 94 | batch_size = fc_feats.size(0) 95 | 96 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 97 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 98 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 99 | # lets process every image independently for now, for simplicity 100 | 101 | self.done_beams = [[] for _ in range(batch_size)] 102 | for k in range(batch_size): 103 | state = self.init_hidden(beam_size) 104 | for t in range(2): 105 | if t == 0: 106 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 107 | elif t == 1: # input 108 | it = fc_feats.data.new(beam_size).long().zero_() 109 | xt = self.embed(it) 110 | 111 | output, state = self.core(xt.unsqueeze(0), state) 112 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 113 | 114 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 115 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 116 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 117 | # return the samples and their log likelihoods 118 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 119 | 120 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 121 | sample_method = opt.get('sample_method', 'greedy') 122 | beam_size = opt.get('beam_size', 1) 123 | temperature = opt.get('temperature', 1.0) 124 | if beam_size > 1: 125 | return self.sample_beam(fc_feats, att_feats, opt) 126 | 127 | batch_size = fc_feats.size(0) 128 | state = self.init_hidden(batch_size) 129 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 130 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 131 | for t in range(self.seq_length + 2): 132 | if t == 0: 133 | xt = self.img_embed(fc_feats) 134 | else: 135 | if t == 1: # input 136 | it = fc_feats.data.new(batch_size).long().zero_() 137 | xt = self.embed(it) 138 | 139 | output, state = self.core(xt.unsqueeze(0), state) 140 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 141 | 142 | # sample the next word 143 | if t == self.seq_length + 1: # skip if we achieve maximum length 144 | break 145 | if sample_method == 'greedy': 146 | sampleLogprobs, it = torch.max(logprobs.data, 1) 147 | it = it.view(-1).long() 148 | else: 149 | if temperature == 1.0: 150 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 151 | else: 152 | # scale logprobs by temperature 153 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 154 | it = torch.multinomial(prob_prev, 1).cuda() 155 | sampleLogprobs = 
logprobs.gather(1, it) # gather the logprobs at sampled positions 156 | it = it.view(-1).long() # and flatten indices for downstream processing 157 | 158 | if t >= 1: 159 | # stop when all finished 160 | if t == 1: 161 | unfinished = it > 0 162 | else: 163 | unfinished = unfinished * (it > 0) 164 | it = it * unfinished.type_as(it) 165 | seq[:,t-1] = it #seq[t] the input of t+2 time step 166 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 167 | if unfinished.sum() == 0: 168 | break 169 | 170 | return seq, seqLogprobs -------------------------------------------------------------------------------- /scripts/build_bpe_subword_nmt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
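Compared with scripts/prepro_labels.py, this variant first learns a BPE segmentation of the captions with subword-nmt, encodes the captions as sub-word units, and stores the learned BPE codes under a 'bpe' key of the output json, so that downstream scripts (e.g. scripts/prepro_ngrams.py) can re-apply the same segmentation.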
24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | import numpy as np 38 | import torch 39 | import torchvision.models as models 40 | import skimage.io 41 | from PIL import Image 42 | 43 | import codecs 44 | import tempfile 45 | from subword_nmt import learn_bpe, apply_bpe 46 | 47 | # python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe 48 | 49 | def build_vocab(imgs, params): 50 | # count up the number of words 51 | captions = [] 52 | for img in imgs: 53 | for sent in img['sentences']: 54 | captions.append(' '.join(sent['tokens'])) 55 | captions='\n'.join(captions) 56 | all_captions = tempfile.NamedTemporaryFile(delete=False) 57 | all_captions.close() 58 | with open(all_captions.name, 'w') as txt_file: 59 | txt_file.write(captions) 60 | 61 | # 62 | codecs_output = tempfile.NamedTemporaryFile(delete=False) 63 | codecs_output.close() 64 | with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output: 65 | learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count']) 66 | 67 | with codecs.open(codecs_output.name, encoding='UTF-8') as codes: 68 | bpe = apply_bpe.BPE(codes) 69 | 70 | tmp = tempfile.NamedTemporaryFile(delete=False) 71 | tmp.close() 72 | 73 | tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8') 74 | 75 | for _, img in enumerate(imgs): 76 | img['final_captions'] = [] 77 | for sent in img['sentences']: 78 | txt = ' '.join(sent['tokens']) 79 | txt = bpe.segment(txt).strip() 80 | img['final_captions'].append(txt.split(' ')) 81 | tmpout.write(txt) 82 | tmpout.write('\n') 83 | if _ < 20: 84 | print(txt) 85 | 86 | tmpout.close() 87 | tmpin = codecs.open(tmp.name, encoding='UTF-8') 88 | 89 | vocab = learn_bpe.get_vocabulary(tmpin) 90 | vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True) 91 | 92 | # Always insert UNK 93 | print('inserting the special UNK token') 94 | vocab.append('UNK') 95 | 96 | print('Vocab size:', len(vocab)) 97 | 98 | os.remove(all_captions.name) 99 | with open(codecs_output.name, 'r') as codes: 100 | bpe = codes.read() 101 | os.remove(codecs_output.name) 102 | os.remove(tmp.name) 103 | 104 | return vocab, bpe 105 | 106 | def encode_captions(imgs, params, wtoi): 107 | """ 108 | encode all captions into one large array, which will be 1-indexed. 109 | also produces label_start_ix and label_end_ix which store 1-indexed 110 | and inclusive (Lua-style) pointers to the first and last caption for 111 | each image in the dataset. 
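Note that 'final_captions' here consist of BPE sub-word tokens, so max_length is measured in sub-word units rather than whole words.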
112 | """ 113 | 114 | max_length = params['max_length'] 115 | N = len(imgs) 116 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 117 | 118 | label_arrays = [] 119 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 120 | label_end_ix = np.zeros(N, dtype='uint32') 121 | label_length = np.zeros(M, dtype='uint32') 122 | caption_counter = 0 123 | counter = 1 124 | for i,img in enumerate(imgs): 125 | n = len(img['final_captions']) 126 | assert n > 0, 'error: some image has no captions' 127 | 128 | Li = np.zeros((n, max_length), dtype='uint32') 129 | for j,s in enumerate(img['final_captions']): 130 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 131 | caption_counter += 1 132 | for k,w in enumerate(s): 133 | if k < max_length: 134 | Li[j,k] = wtoi[w] 135 | 136 | # note: word indices are 1-indexed, and captions are padded with zeros 137 | label_arrays.append(Li) 138 | label_start_ix[i] = counter 139 | label_end_ix[i] = counter + n - 1 140 | 141 | counter += n 142 | 143 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 144 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 145 | assert np.all(label_length > 0), 'error: some caption had no words?' 146 | 147 | print('encoded captions to array of size ', L.shape) 148 | return L, label_start_ix, label_end_ix, label_length 149 | 150 | def main(params): 151 | 152 | imgs = json.load(open(params['input_json'], 'r')) 153 | imgs = imgs['images'] 154 | 155 | seed(123) # make reproducible 156 | 157 | # create the vocab 158 | vocab, bpe = build_vocab(imgs, params) 159 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 160 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 161 | 162 | # encode captions in large arrays, ready to ship to hdf5 file 163 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 164 | 165 | # create output h5 file 166 | N = len(imgs) 167 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 168 | f_lb.create_dataset("labels", dtype='uint32', data=L) 169 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 170 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 171 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 172 | f_lb.close() 173 | 174 | # create output json file 175 | out = {} 176 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 177 | out['images'] = [] 178 | out['bpe'] = bpe 179 | for i,img in enumerate(imgs): 180 | 181 | jimg = {} 182 | jimg['split'] = img['split'] 183 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need 184 | if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 185 | 186 | if params['images_root'] != '': 187 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 188 | jimg['width'], jimg['height'] = _img.size 189 | 190 | out['images'].append(jimg) 191 | 192 | json.dump(out, open(params['output_json'], 'w')) 193 | print('wrote ', params['output_json']) 194 | 195 | if __name__ == "__main__": 196 | 197 | parser = argparse.ArgumentParser() 198 | 199 | # input json 200 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 201 | parser.add_argument('--output_json', default='data.json', help='output json file') 202 | parser.add_argument('--output_h5', default='data', help='output h5 file') 203 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 204 | 205 | # options 206 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 207 | parser.add_argument('--symbol_count', default=10000, type=int, help='only words that occur more than this number of times will be put in vocab') 208 | 209 | args = parser.parse_args() 210 | params = vars(args) # convert to ordinary dict 211 | print('parsed input parameters:') 212 | print(json.dumps(params, indent = 2)) 213 | main(params) 214 | 215 | 216 | -------------------------------------------------------------------------------- /scripts/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
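Example usage, and a sketch of how the outputs can be read back (the file names below are illustrative, not fixed by this script):

    python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk

    import h5py, json
    info = json.load(open('data/cocotalk.json'))     # 'ix_to_word', 'images'
    h5 = h5py.File('data/cocotalk_label.h5', 'r')    # note the '_label.h5' suffix added by the script
    first = h5['label_start_ix'][0]                  # pointers are 1-indexed and inclusive
    last = h5['label_end_ix'][0]
    caps_img0 = h5['labels'][first - 1:last]         # encoded captions of the first image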
24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | import numpy as np 38 | import torch 39 | import torchvision.models as models 40 | import skimage.io 41 | from PIL import Image 42 | 43 | def build_vocab(imgs, params): 44 | count_thr = params['word_count_threshold'] 45 | 46 | # count up the number of words 47 | counts = {} 48 | for img in imgs: 49 | for sent in img['sentences']: 50 | for w in sent['tokens']: 51 | counts[w] = counts.get(w, 0) + 1 52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 53 | print('top words and their counts:') 54 | print('\n'.join(map(str,cw[:20]))) 55 | 56 | # print some stats 57 | total_words = sum(counts.values()) 58 | print('total words:', total_words) 59 | bad_words = [w for w,n in counts.items() if n <= count_thr] 60 | vocab = [w for w,n in counts.items() if n > count_thr] 61 | bad_count = sum(counts[w] for w in bad_words) 62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 63 | print('number of words in vocab would be %d' % (len(vocab), )) 64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 65 | 66 | # lets look at the distribution of lengths as well 67 | sent_lengths = {} 68 | for img in imgs: 69 | for sent in img['sentences']: 70 | txt = sent['tokens'] 71 | nw = len(txt) 72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 73 | max_len = max(sent_lengths.keys()) 74 | print('max length sentence in raw data: ', max_len) 75 | print('sentence length distribution (count, number of words):') 76 | sum_len = sum(sent_lengths.values()) 77 | for i in range(max_len+1): 78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 79 | 80 | # lets now produce the final annotations 81 | if bad_count > 0: 82 | # additional special UNK token we will use below to map infrequent words to 83 | print('inserting the special UNK token') 84 | vocab.append('UNK') 85 | 86 | for img in imgs: 87 | img['final_captions'] = [] 88 | for sent in img['sentences']: 89 | txt = sent['tokens'] 90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 91 | img['final_captions'].append(caption) 92 | 93 | return vocab 94 | 95 | def encode_captions(imgs, params, wtoi): 96 | """ 97 | encode all captions into one large array, which will be 1-indexed. 98 | also produces label_start_ix and label_end_ix which store 1-indexed 99 | and inclusive (Lua-style) pointers to the first and last caption for 100 | each image in the dataset. 
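For example, if the first two images have 5 captions each, label_start_ix begins [1, 6, ...] and label_end_ix begins [5, 10, ...]; i.e. the (1-indexed) rows 1..5 of /labels belong to the first image.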
101 | """ 102 | 103 | max_length = params['max_length'] 104 | N = len(imgs) 105 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 106 | 107 | label_arrays = [] 108 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 109 | label_end_ix = np.zeros(N, dtype='uint32') 110 | label_length = np.zeros(M, dtype='uint32') 111 | caption_counter = 0 112 | counter = 1 113 | for i,img in enumerate(imgs): 114 | n = len(img['final_captions']) 115 | assert n > 0, 'error: some image has no captions' 116 | 117 | Li = np.zeros((n, max_length), dtype='uint32') 118 | for j,s in enumerate(img['final_captions']): 119 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 120 | caption_counter += 1 121 | for k,w in enumerate(s): 122 | if k < max_length: 123 | Li[j,k] = wtoi[w] 124 | 125 | # note: word indices are 1-indexed, and captions are padded with zeros 126 | label_arrays.append(Li) 127 | label_start_ix[i] = counter 128 | label_end_ix[i] = counter + n - 1 129 | 130 | counter += n 131 | 132 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 133 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 134 | assert np.all(label_length > 0), 'error: some caption had no words?' 135 | 136 | print('encoded captions to array of size ', L.shape) 137 | return L, label_start_ix, label_end_ix, label_length 138 | 139 | def main(params): 140 | 141 | imgs = json.load(open(params['input_json'], 'r')) 142 | imgs = imgs['images'] 143 | 144 | seed(123) # make reproducible 145 | 146 | # create the vocab 147 | vocab = build_vocab(imgs, params) 148 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 149 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 150 | 151 | # encode captions in large arrays, ready to ship to hdf5 file 152 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 153 | 154 | # create output h5 file 155 | N = len(imgs) 156 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 157 | f_lb.create_dataset("labels", dtype='uint32', data=L) 158 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 159 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 160 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 161 | f_lb.close() 162 | 163 | # create output json file 164 | out = {} 165 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 166 | out['images'] = [] 167 | for i,img in enumerate(imgs): 168 | 169 | jimg = {} 170 | jimg['split'] = img['split'] 171 | if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need 172 | if 'cocoid' in img: 173 | jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 174 | elif 'imgid' in img: 175 | jimg['id'] = img['imgid'] 176 | 177 | if params['images_root'] != '': 178 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 179 | jimg['width'], jimg['height'] = _img.size 180 | 181 | out['images'].append(jimg) 182 | 183 | json.dump(out, open(params['output_json'], 'w')) 184 | print('wrote ', params['output_json']) 185 | 186 | if __name__ == "__main__": 187 | 188 | parser = argparse.ArgumentParser() 189 | 190 | # input json 191 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 192 | parser.add_argument('--output_json', default='data.json', help='output json file') 193 | parser.add_argument('--output_h5', default='data', help='output h5 file') 194 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 195 | 196 | # options 197 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 198 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 199 | 200 | args = parser.parse_args() 201 | params = vars(args) # convert to ordinary dict 202 | print('parsed input parameters:') 203 | print(json.dumps(params, indent = 2)) 204 | main(params) 205 | -------------------------------------------------------------------------------- /models/FCModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | import misc.utils as utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class LSTMCore(nn.Module): 14 | def __init__(self, opt): 15 | super(LSTMCore, self).__init__() 16 | self.input_encoding_size = opt.input_encoding_size 17 | self.rnn_size = opt.rnn_size 18 | self.drop_prob_lm = opt.drop_prob_lm 19 | 20 | # Build a LSTM 21 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 22 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 23 | self.dropout = nn.Dropout(self.drop_prob_lm) 24 | 25 | def forward(self, xt, state): 26 | 27 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 28 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 29 | sigmoid_chunk = torch.sigmoid(sigmoid_chunk) 30 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 31 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 32 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 33 | 34 | in_transform = torch.max(\ 35 | all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size), 36 | all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size)) 37 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 38 | next_h = out_gate * torch.tanh(next_c) 39 | 40 | output = self.dropout(next_h) 41 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 42 | return output, state 43 | 44 | class FCModel(CaptionModel): 45 | def __init__(self, opt): 46 | super(FCModel, self).__init__() 47 | self.vocab_size = opt.vocab_size 48 | self.input_encoding_size = opt.input_encoding_size 49 | self.rnn_type = opt.rnn_type 50 | self.rnn_size = opt.rnn_size 51 | self.num_layers = opt.num_layers 
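        # (The options below mirror the command-line options defined in opts.py: drop_prob_lm is the
        # dropout rate used inside LSTMCore, seq_length caps the number of generated words at sampling
        # time, and fc_feat_size is the size of the pooled CNN feature that img_embed projects down to
        # input_encoding_size.)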
52 | self.drop_prob_lm = opt.drop_prob_lm 53 | self.seq_length = opt.seq_length 54 | self.fc_feat_size = opt.fc_feat_size 55 | 56 | self.ss_prob = 0.0 # Schedule sampling probability 57 | 58 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 59 | self.core = LSTMCore(opt) 60 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 61 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 62 | 63 | self.init_weights() 64 | 65 | def init_weights(self): 66 | initrange = 0.1 67 | self.embed.weight.data.uniform_(-initrange, initrange) 68 | self.logit.bias.data.fill_(0) 69 | self.logit.weight.data.uniform_(-initrange, initrange) 70 | 71 | def init_hidden(self, bsz): 72 | weight = next(self.parameters()) 73 | if self.rnn_type == 'lstm': 74 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 75 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 76 | else: 77 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 78 | 79 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 80 | batch_size = fc_feats.size(0) 81 | state = self.init_hidden(batch_size) 82 | outputs = [] 83 | 84 | for i in range(seq.size(1)): 85 | if i == 0: 86 | xt = self.img_embed(fc_feats) 87 | else: 88 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 89 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 90 | sample_mask = sample_prob < self.ss_prob 91 | if sample_mask.sum() == 0: 92 | it = seq[:, i-1].clone() 93 | else: 94 | sample_ind = sample_mask.nonzero().view(-1) 95 | it = seq[:, i-1].data.clone() 96 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 97 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 98 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 99 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 100 | else: 101 | it = seq[:, i-1].clone() 102 | # break if all the sequences end 103 | if i >= 2 and seq[:, i-1].sum() == 0: 104 | break 105 | xt = self.embed(it) 106 | 107 | output, state = self.core(xt, state) 108 | output = F.log_softmax(self.logit(output), dim=1) 109 | outputs.append(output) 110 | 111 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 112 | 113 | def get_logprobs_state(self, it, state): 114 | # 'it' is contains a word index 115 | xt = self.embed(it) 116 | 117 | output, state = self.core(xt, state) 118 | logprobs = F.log_softmax(self.logit(output), dim=1) 119 | 120 | return logprobs, state 121 | 122 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 123 | beam_size = opt.get('beam_size', 10) 124 | batch_size = fc_feats.size(0) 125 | 126 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 127 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 128 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 129 | # lets process every image independently for now, for simplicity 130 | 131 | self.done_beams = [[] for _ in range(batch_size)] 132 | for k in range(batch_size): 133 | state = self.init_hidden(beam_size) 134 | for t in range(2): 135 | if t == 0: 136 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 137 | elif t == 1: # input 138 | it = fc_feats.data.new(beam_size).long().zero_() 139 | xt = self.embed(it) 140 | 141 | output, state = self.core(xt, state) 142 | logprobs = F.log_softmax(self.logit(output), dim=1) 143 | 144 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 145 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 146 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 147 | # return the samples and their log likelihoods 148 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 149 | 150 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 151 | sample_method = opt.get('sample_method', 'greedy') 152 | beam_size = opt.get('beam_size', 1) 153 | temperature = opt.get('temperature', 1.0) 154 | if beam_size > 1: 155 | return self._sample_beam(fc_feats, att_feats, opt) 156 | 157 | batch_size = fc_feats.size(0) 158 | state = self.init_hidden(batch_size) 159 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 160 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 161 | for t in range(self.seq_length + 2): 162 | if t == 0: 163 | xt = self.img_embed(fc_feats) 164 | else: 165 | if t == 1: # input 166 | it = fc_feats.data.new(batch_size).long().zero_() 167 | xt = self.embed(it) 168 | 169 | output, state = self.core(xt, state) 170 | logprobs = F.log_softmax(self.logit(output), dim=1) 171 | 172 | # sample the next_word 173 | if t == self.seq_length + 1: # skip if we achieve maximum length 174 | break 175 | if sample_method == 'greedy': 176 | sampleLogprobs, it = torch.max(logprobs.data, 1) 177 | it = it.view(-1).long() 178 | else: 179 | if temperature == 1.0: 180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 181 | else: 182 | # scale logprobs by temperature 183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 184 | it = torch.multinomial(prob_prev, 1).cuda() 185 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 186 | it = it.view(-1).long() # and flatten indices for downstream processing 187 | 188 | if t >= 1: 189 | # stop when all finished 190 | if t == 1: 191 | unfinished = it > 0 192 | else: 193 | unfinished = unfinished * (it > 0) 194 | it = it * unfinished.type_as(it) 195 | seq[:,t-1] = it #seq[t] the input of t+2 time step 196 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 197 | if unfinished.sum() == 0: 198 | break 199 | 200 | return seq, seqLogprobs 201 | -------------------------------------------------------------------------------- /models/AoAModel.py: -------------------------------------------------------------------------------- 1 | # Implementation for paper 'Attention on Attention for Image Captioning' 2 | # https://arxiv.org/abs/1908.06954 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import misc.utils as 
utils 12 | 13 | from .AttModel import pack_wrapper, AttModel, Attention 14 | from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward 15 | 16 | class MultiHeadedDotAttention(nn.Module): 17 | def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3): 18 | super(MultiHeadedDotAttention, self).__init__() 19 | assert d_model * scale % h == 0 20 | # We assume d_v always equals d_k 21 | self.d_k = d_model * scale // h 22 | self.h = h 23 | 24 | # Do we need to do linear projections on K and V? 25 | self.project_k_v = project_k_v 26 | 27 | # normalize the query? 28 | if norm_q: 29 | self.norm = LayerNorm(d_model) 30 | else: 31 | self.norm = lambda x:x 32 | self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v) 33 | 34 | # output linear layer after the multi-head attention? 35 | self.output_layer = nn.Linear(d_model * scale, d_model) 36 | 37 | # apply aoa after attention? 38 | self.use_aoa = do_aoa 39 | if self.use_aoa: 40 | self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU()) 41 | # dropout to the input of AoA layer 42 | if dropout_aoa > 0: 43 | self.dropout_aoa = nn.Dropout(p=dropout_aoa) 44 | else: 45 | self.dropout_aoa = lambda x:x 46 | 47 | if self.use_aoa or not use_output_layer: 48 | # AoA doesn't need the output linear layer 49 | del self.output_layer 50 | self.output_layer = lambda x:x 51 | 52 | self.attn = None 53 | self.dropout = nn.Dropout(p=dropout) 54 | 55 | def forward(self, query, value, key, mask=None): 56 | if mask is not None: 57 | if len(mask.size()) == 2: 58 | mask = mask.unsqueeze(-2) 59 | # Same mask applied to all h heads. 60 | mask = mask.unsqueeze(1) 61 | 62 | single_query = 0 63 | if len(query.size()) == 2: 64 | single_query = 1 65 | query = query.unsqueeze(1) 66 | 67 | nbatches = query.size(0) 68 | 69 | query = self.norm(query) 70 | 71 | # Do all the linear projections in batch from d_model => h x d_k 72 | if self.project_k_v == 0: 73 | query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 74 | key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 75 | value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 76 | else: 77 | query_, key_, value_ = \ 78 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 79 | for l, x in zip(self.linears, (query, key, value))] 80 | 81 | # Apply attention on all the projected vectors in batch. 
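        # query_, key_, value_ now each have shape (nbatches, h, seq_len, d_k).
        # `attention` below is imported from TransformerModel and computes the usual scaled
        # dot-product attention, softmax(Q K^T / sqrt(d_k)) V, applying `mask` to the scores
        # before the softmax; it returns the weighted values together with the attention map.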
82 | x, self.attn = attention(query_, key_, value_, mask=mask, 83 | dropout=self.dropout) 84 | 85 | # "Concat" using a view 86 | x = x.transpose(1, 2).contiguous() \ 87 | .view(nbatches, -1, self.h * self.d_k) 88 | 89 | if self.use_aoa: 90 | # Apply AoA 91 | x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1))) 92 | x = self.output_layer(x) 93 | 94 | if single_query: 95 | query = query.squeeze(1) 96 | x = x.squeeze(1) 97 | return x 98 | 99 | class AoA_Refiner_Layer(nn.Module): 100 | def __init__(self, size, self_attn, feed_forward, dropout): 101 | super(AoA_Refiner_Layer, self).__init__() 102 | self.self_attn = self_attn 103 | self.feed_forward = feed_forward 104 | self.use_ff = 0 105 | if self.feed_forward is not None: 106 | self.use_ff = 1 107 | self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff) 108 | self.size = size 109 | 110 | def forward(self, x, mask): 111 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 112 | return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x 113 | 114 | class AoA_Refiner_Core(nn.Module): 115 | def __init__(self, opt): 116 | super(AoA_Refiner_Core, self).__init__() 117 | attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3)) 118 | layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1) 119 | self.layers = clones(layer, 6) 120 | self.norm = LayerNorm(layer.size) 121 | 122 | def forward(self, x, mask): 123 | for layer in self.layers: 124 | x = layer(x, mask) 125 | return self.norm(x) 126 | 127 | class AoA_Decoder_Core(nn.Module): 128 | def __init__(self, opt): 129 | super(AoA_Decoder_Core, self).__init__() 130 | self.drop_prob_lm = opt.drop_prob_lm 131 | self.d_model = opt.rnn_size 132 | self.use_multi_head = opt.use_multi_head 133 | self.multi_head_scale = opt.multi_head_scale 134 | self.use_ctx_drop = getattr(opt, 'ctx_drop', 0) 135 | self.out_res = getattr(opt, 'out_res', 0) 136 | self.decoder_type = getattr(opt, 'decoder_type', 'AoA') 137 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1 138 | self.out_drop = nn.Dropout(self.drop_prob_lm) 139 | 140 | if self.decoder_type == 'AoA': 141 | # AoA layer 142 | self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU()) 143 | elif self.decoder_type == 'LSTM': 144 | # LSTM layer 145 | self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size) 146 | else: 147 | # Base linear layer 148 | self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU()) 149 | 150 | # if opt.use_multi_head == 1: # TODO, not implemented for now 151 | # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale) 152 | if opt.use_multi_head == 2: 153 | self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1) 154 | else: 155 | self.attention = Attention(opt) 156 | 157 | if self.use_ctx_drop: 158 | self.ctx_drop = nn.Dropout(self.drop_prob_lm) 159 | else: 160 | self.ctx_drop = lambda x :x 161 | 162 | def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None): 163 | # state[0][1] is the context vector at the last step 164 | h_att, c_att = 
self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0])) 165 | 166 | if self.use_multi_head == 2: 167 | att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks) 168 | else: 169 | att = self.attention(h_att, att_feats, p_att_feats, att_masks) 170 | 171 | ctx_input = torch.cat([att, h_att], 1) 172 | if self.decoder_type == 'LSTM': 173 | output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1])) 174 | state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic))) 175 | else: 176 | output = self.att2ctx(ctx_input) 177 | # save the context vector to state[0][1] 178 | state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1]))) 179 | 180 | if self.out_res: 181 | # add residual connection 182 | output = output + h_att 183 | 184 | output = self.out_drop(output) 185 | return output, state 186 | 187 | class AoAModel(AttModel): 188 | def __init__(self, opt): 189 | super(AoAModel, self).__init__(opt) 190 | self.num_layers = 2 191 | # mean pooling 192 | self.use_mean_feats = getattr(opt, 'mean_feats', 1) 193 | if opt.use_multi_head == 2: 194 | del self.ctx2att 195 | self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size) 196 | 197 | if self.use_mean_feats: 198 | del self.fc_embed 199 | if opt.refine: 200 | self.refiner = AoA_Refiner_Core(opt) 201 | else: 202 | self.refiner = lambda x,y : x 203 | self.core = AoA_Decoder_Core(opt) 204 | 205 | 206 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 207 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 208 | 209 | # embed att feats 210 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 211 | att_feats = self.refiner(att_feats, att_masks) 212 | 213 | if self.use_mean_feats: 214 | # meaning pooling 215 | if att_masks is None: 216 | mean_feats = torch.mean(att_feats, dim=1) 217 | else: 218 | mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1)) 219 | else: 220 | mean_feats = self.fc_embed(fc_feats) 221 | 222 | # Project the attention feats first to reduce memory and computation. 223 | p_att_feats = self.ctx2att(att_feats) 224 | 225 | return mean_feats, att_feats, p_att_feats, att_masks -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.optim as optim 10 | import os 11 | 12 | import six 13 | from six.moves import cPickle 14 | 15 | bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that'] 16 | bad_endings += ['the'] 17 | 18 | def pickle_load(f): 19 | """ Load a pickle. 20 | Parameters 21 | ---------- 22 | f: file-like object 23 | """ 24 | if six.PY3: 25 | return cPickle.load(f, encoding='latin-1') 26 | else: 27 | return cPickle.load(f) 28 | 29 | 30 | def pickle_dump(obj, f): 31 | """ Dump a pickle. 
32 | Parameters 33 | ---------- 34 | obj: pickled object 35 | f: file-like object 36 | """ 37 | if six.PY3: 38 | return cPickle.dump(obj, f, protocol=2) 39 | else: 40 | return cPickle.dump(obj, f) 41 | 42 | 43 | def if_use_feat(caption_model): 44 | # Decide if load attention feature according to caption model 45 | if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']: 46 | use_att, use_fc = False, True 47 | elif caption_model == 'language_model': 48 | use_att, use_fc = False, False 49 | elif caption_model in ['topdown', 'aoa']: 50 | use_fc, use_att = True, True 51 | else: 52 | use_att, use_fc = True, False 53 | return use_fc, use_att 54 | 55 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 56 | def decode_sequence(ix_to_word, seq): 57 | N, D = seq.size() 58 | out = [] 59 | for i in range(N): 60 | txt = '' 61 | for j in range(D): 62 | ix = seq[i,j] 63 | if ix > 0 : 64 | if j >= 1: 65 | txt = txt + ' ' 66 | txt = txt + ix_to_word[str(ix.item())] 67 | else: 68 | break 69 | if int(os.getenv('REMOVE_BAD_ENDINGS', '0')): 70 | flag = 0 71 | words = txt.split(' ') 72 | for j in range(len(words)): 73 | if words[-j-1] not in bad_endings: 74 | flag = -j 75 | break 76 | txt = ' '.join(words[0:len(words)+flag]) 77 | out.append(txt.replace('@@ ', '')) 78 | return out 79 | 80 | def to_contiguous(tensor): 81 | if tensor.is_contiguous(): 82 | return tensor 83 | else: 84 | return tensor.contiguous() 85 | 86 | class RewardCriterion(nn.Module): 87 | def __init__(self): 88 | super(RewardCriterion, self).__init__() 89 | 90 | def forward(self, input, seq, reward): 91 | input = to_contiguous(input).view(-1) 92 | reward = to_contiguous(reward).view(-1) 93 | mask = (seq>0).float() 94 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 95 | output = - input * reward * mask 96 | output = torch.sum(output) / torch.sum(mask) 97 | 98 | return output 99 | 100 | class LanguageModelCriterion(nn.Module): 101 | def __init__(self): 102 | super(LanguageModelCriterion, self).__init__() 103 | 104 | def forward(self, input, target, mask): 105 | # truncate to the same size 106 | target = target[:, :input.size(1)] 107 | mask = mask[:, :input.size(1)] 108 | 109 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask 110 | output = torch.sum(output) / torch.sum(mask) 111 | 112 | return output 113 | 114 | class LabelSmoothing(nn.Module): 115 | "Implement label smoothing." 
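# With smoothing eps, the target distribution keeps confidence = 1 - eps on the gold token
# and spreads eps uniformly over the other size - 1 entries; e.g. for a 4-way output and
# eps = 0.1 the row for gold index 2 becomes roughly [0.033, 0.033, 0.9, 0.033]. The KL term
# in forward() is then averaged over the unmasked tokens. (The numbers here are illustrative,
# not taken from any config in this repo.)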
116 | def __init__(self, size=0, padding_idx=0, smoothing=0.0): 117 | super(LabelSmoothing, self).__init__() 118 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False) 119 | # self.padding_idx = padding_idx 120 | self.confidence = 1.0 - smoothing 121 | self.smoothing = smoothing 122 | # self.size = size 123 | self.true_dist = None 124 | 125 | def forward(self, input, target, mask): 126 | # truncate to the same size 127 | target = target[:, :input.size(1)] 128 | mask = mask[:, :input.size(1)] 129 | 130 | input = to_contiguous(input).view(-1, input.size(-1)) 131 | target = to_contiguous(target).view(-1) 132 | mask = to_contiguous(mask).view(-1) 133 | 134 | # assert x.size(1) == self.size 135 | self.size = input.size(1) 136 | # true_dist = x.data.clone() 137 | true_dist = input.data.clone() 138 | # true_dist.fill_(self.smoothing / (self.size - 2)) 139 | true_dist.fill_(self.smoothing / (self.size - 1)) 140 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 141 | # true_dist[:, self.padding_idx] = 0 142 | # mask = torch.nonzero(target.data == self.padding_idx) 143 | # self.true_dist = true_dist 144 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() 145 | 146 | def set_lr(optimizer, lr): 147 | for group in optimizer.param_groups: 148 | group['lr'] = lr 149 | 150 | def get_lr(optimizer): 151 | for group in optimizer.param_groups: 152 | return group['lr'] 153 | 154 | def clip_gradient(optimizer, grad_clip): 155 | for group in optimizer.param_groups: 156 | for param in group['params']: 157 | param.grad.data.clamp_(-grad_clip, grad_clip) 158 | 159 | def build_optimizer(params, opt): 160 | if opt.optim == 'rmsprop': 161 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay) 162 | elif opt.optim == 'adagrad': 163 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay) 164 | elif opt.optim == 'sgd': 165 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay) 166 | elif opt.optim == 'sgdm': 167 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay) 168 | elif opt.optim == 'sgdmom': 169 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True) 170 | elif opt.optim == 'adam': 171 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) 172 | else: 173 | raise Exception("bad option opt.optim: {}".format(opt.optim)) 174 | 175 | 176 | def penalty_builder(penalty_config): 177 | if penalty_config == '': 178 | return lambda x,y: y 179 | pen_type, alpha = penalty_config.split('_') 180 | alpha = float(alpha) 181 | if pen_type == 'wu': 182 | return lambda x,y: length_wu(x,y,alpha) 183 | if pen_type == 'avg': 184 | return lambda x,y: length_average(x,y,alpha) 185 | 186 | def length_wu(length, logprobs, alpha=0.): 187 | """ 188 | NMT length re-ranking score from 189 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 190 | """ 191 | 192 | modifier = (((5 + length) ** alpha) / 193 | ((5 + 1) ** alpha)) 194 | return (logprobs / modifier) 195 | 196 | def length_average(length, logprobs, alpha=0.): 197 | """ 198 | Returns the average probability of tokens in a sequence. 199 | """ 200 | return logprobs / length 201 | 202 | 203 | class NoamOpt(object): 204 | "Optim wrapper that implements rate." 
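# The schedule implemented in rate() below is
#     lr(step) = factor * model_size**(-0.5) * min(step**(-0.5), step * warmup**(-1.5)),
# i.e. a linear warmup followed by inverse-square-root decay; for example, with model_size=512,
# factor=1 and warmup=2000 it peaks near 9.9e-4 at step 2000. (These example values are
# illustrative; the actual ones come from opt.noamopt_factor / opt.noamopt_warmup in train.py.)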
205 | def __init__(self, model_size, factor, warmup, optimizer): 206 | self.optimizer = optimizer 207 | self._step = 0 208 | self.warmup = warmup 209 | self.factor = factor 210 | self.model_size = model_size 211 | self._rate = 0 212 | 213 | def step(self): 214 | "Update parameters and rate" 215 | self._step += 1 216 | rate = self.rate() 217 | for p in self.optimizer.param_groups: 218 | p['lr'] = rate 219 | self._rate = rate 220 | self.optimizer.step() 221 | 222 | def rate(self, step = None): 223 | "Implement `lrate` above" 224 | if step is None: 225 | step = self._step 226 | return self.factor * \ 227 | (self.model_size ** (-0.5) * 228 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 229 | 230 | def __getattr__(self, name): 231 | return getattr(self.optimizer, name) 232 | 233 | class ReduceLROnPlateau(object): 234 | "Optim wrapper around torch.optim.lr_scheduler.ReduceLROnPlateau." 235 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): 236 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps) 237 | self.optimizer = optimizer 238 | self.current_lr = get_lr(optimizer) 239 | 240 | def step(self): 241 | "Update parameters and rate" 242 | self.optimizer.step() 243 | 244 | def scheduler_step(self, val): 245 | self.scheduler.step(val) 246 | self.current_lr = get_lr(self.optimizer) 247 | 248 | def state_dict(self): 249 | return {'current_lr':self.current_lr, 250 | 'scheduler_state_dict': self.scheduler.state_dict(), 251 | 'optimizer_state_dict': self.optimizer.state_dict()} 252 | 253 | def load_state_dict(self, state_dict): 254 | if 'current_lr' not in state_dict: 255 | # it's a plain optimizer state dict 256 | self.optimizer.load_state_dict(state_dict) 257 | set_lr(self.optimizer, self.current_lr) # use the lr from the option 258 | else: 259 | # it's a scheduler state dict 260 | self.current_lr = state_dict['current_lr'] 261 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) 262 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) 263 | # current_lr is actually useless in this case 264 | 265 | def rate(self, step = None): 266 | "Copied from NoamOpt; not used by this wrapper" 267 | if step is None: 268 | step = self._step 269 | return self.factor * \ 270 | (self.model_size ** (-0.5) * 271 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 272 | 273 | def __getattr__(self, name): 274 | return getattr(self.optimizer, name) 275 | 276 | def get_std_opt(model, factor=1, warmup=2000): 277 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, 278 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 279 | return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup, 280 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 281 | 282 | -------------------------------------------------------------------------------- /models/OldModel.py: -------------------------------------------------------------------------------- 1 | # This file contains ShowAttendTell and AllImg model 2 | 3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention 4 | # https://arxiv.org/abs/1502.03044 5 | 6 | # AllImg is a model where 7 | # img feature is concatenated with word embedding at every time step as the input of lstm 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 |
import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | from .CaptionModel import CaptionModel 19 | 20 | class OldModel(CaptionModel): 21 | def __init__(self, opt): 22 | super(OldModel, self).__init__() 23 | self.vocab_size = opt.vocab_size 24 | self.input_encoding_size = opt.input_encoding_size 25 | self.rnn_type = opt.rnn_type 26 | self.rnn_size = opt.rnn_size 27 | self.num_layers = opt.num_layers 28 | self.drop_prob_lm = opt.drop_prob_lm 29 | self.seq_length = opt.seq_length 30 | self.fc_feat_size = opt.fc_feat_size 31 | self.att_feat_size = opt.att_feat_size 32 | 33 | self.ss_prob = 0.0 # Schedule sampling probability 34 | 35 | self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size 36 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 37 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 38 | self.dropout = nn.Dropout(self.drop_prob_lm) 39 | 40 | self.init_weights() 41 | 42 | def init_weights(self): 43 | initrange = 0.1 44 | self.embed.weight.data.uniform_(-initrange, initrange) 45 | self.logit.bias.data.fill_(0) 46 | self.logit.weight.data.uniform_(-initrange, initrange) 47 | 48 | def init_hidden(self, fc_feats): 49 | image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1) 50 | if self.rnn_type == 'lstm': 51 | return (image_map, image_map) 52 | else: 53 | return image_map 54 | 55 | def forward(self, fc_feats, att_feats, seq): 56 | batch_size = fc_feats.size(0) 57 | state = self.init_hidden(fc_feats) 58 | 59 | outputs = [] 60 | 61 | for i in range(seq.size(1) - 1): 62 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 63 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 64 | sample_mask = sample_prob < self.ss_prob 65 | if sample_mask.sum() == 0: 66 | it = seq[:, i].clone() 67 | else: 68 | sample_ind = sample_mask.nonzero().view(-1) 69 | it = seq[:, i].data.clone() 70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 74 | else: 75 | it = seq[:, i].clone() 76 | # break if all the sequences end 77 | if i >= 1 and seq[:, i].sum() == 0: 78 | break 79 | 80 | xt = self.embed(it) 81 | 82 | output, state = self.core(xt, fc_feats, att_feats, state) 83 | output = F.log_softmax(self.logit(self.dropout(output)), dim=1) 84 | outputs.append(output) 85 | 86 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 87 | 88 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state): 89 | # 'it' contains a word index 90 | xt = self.embed(it) 91 | 92 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) 93 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) 94 | 95 | return logprobs, state 96 | 97 | def sample_beam(self, fc_feats, att_feats, opt={}): 98 | beam_size = opt.get('beam_size', 10) 99 | batch_size = fc_feats.size(0) 100 | 101 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 102 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 103 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 104 | # lets process every image independently for now, for simplicity 105 | 106 | self.done_beams = [[] for _ in range(batch_size)] 107 | for k in range(batch_size): 108 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size) 109 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 110 | 111 | state = self.init_hidden(tmp_fc_feats) 112 | 113 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 114 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 115 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 116 | done_beams = [] 117 | for t in range(1): 118 | if t == 0: # input 119 | it = fc_feats.data.new(beam_size).long().zero_() 120 | xt = self.embed(it) 121 | 122 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) 123 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) 124 | 125 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt) 126 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 127 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 128 | # return the samples and their log likelihoods 129 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 130 | 131 | def sample(self, fc_feats, att_feats, opt={}): 132 | sample_method = opt.get('sample_method', 'greedy') 133 | beam_size = opt.get('beam_size', 1) 134 | temperature = opt.get('temperature', 1.0) 135 | if beam_size > 1: 136 | return self.sample_beam(fc_feats, att_feats, opt) 137 | 138 | batch_size = fc_feats.size(0) 139 | state = self.init_hidden(fc_feats) 140 | 141 | seq = [] 142 | seqLogprobs = [] 143 | for t in range(self.seq_length + 1): 144 | if t == 0: # input 145 | it = fc_feats.data.new(batch_size).long().zero_() 146 | elif sample_method == 'greedy': 147 | sampleLogprobs, it = torch.max(logprobs.data, 1) 148 | it = it.view(-1).long() 149 | else: 150 | if temperature == 1.0: 151 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 152 | else: 153 | # scale logprobs by temperature 154 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 155 | it = torch.multinomial(prob_prev, 1).cuda() 156 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 157 | it = it.view(-1).long() # and flatten indices for downstream processing 158 | 159 | xt = self.embed(it) 160 | 161 | if t >= 1: 162 | # stop when all finished 163 | if t == 1: 164 | unfinished = it > 0 165 | else: 166 | unfinished = unfinished * (it > 0) 167 | if unfinished.sum() == 0: 168 | break 169 | it = it * unfinished.type_as(it) 170 | seq.append(it) #seq[t] the input of t+2 time step 171 | seqLogprobs.append(sampleLogprobs.view(-1)) 172 | 173 | output, state = self.core(xt, fc_feats, att_feats, state) 174 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) 175 | 176 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 177 | 178 | 179 | class ShowAttendTellCore(nn.Module): 180 | def __init__(self, opt): 181 | super(ShowAttendTellCore, self).__init__() 182 | self.input_encoding_size = opt.input_encoding_size 183 | self.rnn_type = opt.rnn_type 184 | self.rnn_size = opt.rnn_size 185 | self.num_layers = opt.num_layers 186 | 
self.drop_prob_lm = opt.drop_prob_lm 187 | self.fc_feat_size = opt.fc_feat_size 188 | self.att_feat_size = opt.att_feat_size 189 | self.att_hid_size = opt.att_hid_size 190 | 191 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size, 192 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 193 | 194 | if self.att_hid_size > 0: 195 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) 196 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 197 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 198 | else: 199 | self.ctx2att = nn.Linear(self.att_feat_size, 1) 200 | self.h2att = nn.Linear(self.rnn_size, 1) 201 | 202 | def forward(self, xt, fc_feats, att_feats, state): 203 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size 204 | att = att_feats.view(-1, self.att_feat_size) 205 | if self.att_hid_size > 0: 206 | att = self.ctx2att(att) # (batch * att_size) * att_hid_size 207 | att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size 208 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size 209 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 210 | dot = att + att_h # batch * att_size * att_hid_size 211 | dot = F.tanh(dot) # batch * att_size * att_hid_size 212 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 213 | dot = self.alpha_net(dot) # (batch * att_size) * 1 214 | dot = dot.view(-1, att_size) # batch * att_size 215 | else: 216 | att = self.ctx2att(att) # (batch * att_size) * 1 217 | att = att.view(-1, att_size) # batch * att_size 218 | att_h = self.h2att(state[0][-1]) # batch * 1 219 | att_h = att_h.expand_as(att) # batch * att_size 220 | dot = att_h + att # batch * att_size 221 | 222 | weight = F.softmax(dot, dim=1) 223 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size 224 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 225 | 226 | output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state) 227 | return output.squeeze(0), state 228 | 229 | class AllImgCore(nn.Module): 230 | def __init__(self, opt): 231 | super(AllImgCore, self).__init__() 232 | self.input_encoding_size = opt.input_encoding_size 233 | self.rnn_type = opt.rnn_type 234 | self.rnn_size = opt.rnn_size 235 | self.num_layers = opt.num_layers 236 | self.drop_prob_lm = opt.drop_prob_lm 237 | self.fc_feat_size = opt.fc_feat_size 238 | 239 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size, 240 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 241 | 242 | def forward(self, xt, fc_feats, att_feats, state): 243 | output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state) 244 | return output.squeeze(0), state 245 | 246 | class ShowAttendTellModel(OldModel): 247 | def __init__(self, opt): 248 | super(ShowAttendTellModel, self).__init__(opt) 249 | self.core = ShowAttendTellCore(opt) 250 | 251 | class AllImgModel(OldModel): 252 | def __init__(self, opt): 253 | super(AllImgModel, self).__init__(opt) 254 | self.core = AllImgCore(opt) 255 | 256 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import
torch.nn as nn 7 | import torch.optim as optim 8 | 9 | import numpy as np 10 | 11 | import time 12 | import os 13 | from six.moves import cPickle 14 | import traceback 15 | 16 | import opts 17 | import models 18 | from dataloader import * 19 | import skimage.io 20 | import eval_utils 21 | import misc.utils as utils 22 | from misc.rewards import init_scorer, get_self_critical_reward 23 | from misc.loss_wrapper import LossWrapper 24 | 25 | try: 26 | import tensorboardX as tb 27 | except ImportError: 28 | print("tensorboardX is not installed") 29 | tb = None 30 | 31 | def add_summary_value(writer, key, value, iteration): 32 | if writer: 33 | writer.add_scalar(key, value, iteration) 34 | 35 | def train(opt): 36 | # Deal with feature things before anything 37 | opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model) 38 | if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 39 | 40 | acc_steps = getattr(opt, 'acc_steps', 1) 41 | 42 | loader = DataLoader(opt) 43 | opt.vocab_size = loader.vocab_size 44 | opt.seq_length = loader.seq_length 45 | 46 | tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) 47 | 48 | infos = {} 49 | histories = {} 50 | if opt.start_from is not None: 51 | # open old infos and check if models are compatible 52 | with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f: 53 | infos = utils.pickle_load(f) 54 | saved_model_opt = infos['opt'] 55 | need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] 56 | for checkme in need_be_same: 57 | assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme 58 | 59 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): 60 | with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f: 61 | histories = utils.pickle_load(f) 62 | else: 63 | infos['iter'] = 0 64 | infos['epoch'] = 0 65 | infos['iterators'] = loader.iterators 66 | infos['split_ix'] = loader.split_ix 67 | infos['vocab'] = loader.get_vocab() 68 | infos['opt'] = opt 69 | 70 | iteration = infos.get('iter', 0) 71 | epoch = infos.get('epoch', 0) 72 | 73 | val_result_history = histories.get('val_result_history', {}) 74 | loss_history = histories.get('loss_history', {}) 75 | lr_history = histories.get('lr_history', {}) 76 | ss_prob_history = histories.get('ss_prob_history', {}) 77 | 78 | loader.iterators = infos.get('iterators', loader.iterators) 79 | loader.split_ix = infos.get('split_ix', loader.split_ix) 80 | if opt.load_best_score == 1: 81 | best_val_score = infos.get('best_val_score', None) 82 | 83 | opt.vocab = loader.get_vocab() 84 | model = models.setup(opt).cuda() 85 | del opt.vocab 86 | dp_model = torch.nn.DataParallel(model) 87 | lw_model = LossWrapper(model, opt) 88 | dp_lw_model = torch.nn.DataParallel(lw_model) 89 | 90 | epoch_done = True 91 | # Assure in training mode 92 | dp_lw_model.train() 93 | 94 | if opt.noamopt: 95 | assert opt.caption_model in ['transformer','aoa'], 'noamopt can only work with transformer' 96 | optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) 97 | optimizer._step = iteration 98 | elif opt.reduce_on_plateau: 99 | optimizer = utils.build_optimizer(model.parameters(), opt) 100 | optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) 101 | else: 102 | optimizer = utils.build_optimizer(model.parameters(), opt) 103 | # Load the optimizer 104 | if vars(opt).get('start_from', None) is not None and 
os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): 105 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) 106 | 107 | 108 | def save_checkpoint(model, infos, optimizer, histories=None, append=''): 109 | if len(append) > 0: 110 | append = '-' + append 111 | # if checkpoint_path doesn't exist 112 | if not os.path.isdir(opt.checkpoint_path): 113 | os.makedirs(opt.checkpoint_path) 114 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append)) 115 | torch.save(model.state_dict(), checkpoint_path) 116 | print("model saved to {}".format(checkpoint_path)) 117 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append)) 118 | torch.save(optimizer.state_dict(), optimizer_path) 119 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f: 120 | utils.pickle_dump(infos, f) 121 | if histories: 122 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f: 123 | utils.pickle_dump(histories, f) 124 | 125 | try: 126 | while True: 127 | if epoch_done: 128 | if not opt.noamopt and not opt.reduce_on_plateau: 129 | # Assign the learning rate 130 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: 131 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every 132 | decay_factor = opt.learning_rate_decay_rate ** frac 133 | opt.current_lr = opt.learning_rate * decay_factor 134 | else: 135 | opt.current_lr = opt.learning_rate 136 | utils.set_lr(optimizer, opt.current_lr) # set the decayed rate 137 | # Assign the scheduled sampling prob 138 | if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: 139 | frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every 140 | opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) 141 | model.ss_prob = opt.ss_prob 142 | 143 | # If start self critical training 144 | if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: 145 | sc_flag = True 146 | init_scorer(opt.cached_tokens) 147 | else: 148 | sc_flag = False 149 | 150 | epoch_done = False 151 | 152 | start = time.time() 153 | if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup): 154 | opt.current_lr = opt.learning_rate * (iteration+1) / opt.noamopt_warmup 155 | utils.set_lr(optimizer, opt.current_lr) 156 | # Load data from train split (0) 157 | data = loader.get_batch('train') 158 | print('Read data:', time.time() - start) 159 | 160 | if (iteration % acc_steps == 0): 161 | optimizer.zero_grad() 162 | 163 | torch.cuda.synchronize() 164 | start = time.time() 165 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] 166 | tmp = [_ if _ is None else _.cuda() for _ in tmp] 167 | fc_feats, att_feats, labels, masks, att_masks = tmp 168 | 169 | model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag) 170 | 171 | loss = model_out['loss'].mean() 172 | loss_sp = loss / acc_steps 173 | 174 | loss_sp.backward() 175 | if ((iteration+1) % acc_steps == 0): 176 | utils.clip_gradient(optimizer, opt.grad_clip) 177 | optimizer.step() 178 | torch.cuda.synchronize() 179 | train_loss = loss.item() 180 | end = time.time() 181 | if not sc_flag: 182 | print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ 183 | .format(iteration, epoch, train_loss, end - start)) 184 | else: 185 | print("iter 
{} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ 186 | .format(iteration, epoch, model_out['reward'].mean(), end - start)) 187 | 188 | # Update the iteration and epoch 189 | iteration += 1 190 | if data['bounds']['wrapped']: 191 | epoch += 1 192 | epoch_done = True 193 | 194 | # Write the training loss summary 195 | if (iteration % opt.losses_log_every == 0): 196 | add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) 197 | if opt.noamopt: 198 | opt.current_lr = optimizer.rate() 199 | elif opt.reduce_on_plateau: 200 | opt.current_lr = optimizer.current_lr 201 | add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) 202 | add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) 203 | if sc_flag: 204 | add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration) 205 | 206 | loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean() 207 | lr_history[iteration] = opt.current_lr 208 | ss_prob_history[iteration] = model.ss_prob 209 | 210 | # update infos 211 | infos['iter'] = iteration 212 | infos['epoch'] = epoch 213 | infos['iterators'] = loader.iterators 214 | infos['split_ix'] = loader.split_ix 215 | 216 | # make evaluation on validation set, and save model 217 | if (iteration % opt.save_checkpoint_every == 0): 218 | # eval model 219 | eval_kwargs = {'split': 'val', 220 | 'dataset': opt.input_json} 221 | eval_kwargs.update(vars(opt)) 222 | val_loss, predictions, lang_stats = eval_utils.eval_split( 223 | dp_model, lw_model.crit, loader, eval_kwargs) 224 | 225 | if opt.reduce_on_plateau: 226 | if 'CIDEr' in lang_stats: 227 | optimizer.scheduler_step(-lang_stats['CIDEr']) 228 | else: 229 | optimizer.scheduler_step(val_loss) 230 | # Write validation result into summary 231 | add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) 232 | if lang_stats is not None: 233 | for k,v in lang_stats.items(): 234 | add_summary_value(tb_summary_writer, k, v, iteration) 235 | val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} 236 | 237 | # Save model if is improving on validation result 238 | if opt.language_eval == 1: 239 | current_score = lang_stats['CIDEr'] 240 | else: 241 | current_score = - val_loss 242 | 243 | best_flag = False 244 | 245 | if best_val_score is None or current_score > best_val_score: 246 | best_val_score = current_score 247 | best_flag = True 248 | 249 | # Dump miscalleous informations 250 | infos['best_val_score'] = best_val_score 251 | histories['val_result_history'] = val_result_history 252 | histories['loss_history'] = loss_history 253 | histories['lr_history'] = lr_history 254 | histories['ss_prob_history'] = ss_prob_history 255 | 256 | save_checkpoint(model, infos, optimizer, histories) 257 | if opt.save_history_ckpt: 258 | save_checkpoint(model, infos, optimizer, append=str(iteration)) 259 | 260 | if best_flag: 261 | save_checkpoint(model, infos, optimizer, append='best') 262 | 263 | # Stop if reaching max epochs 264 | if epoch >= opt.max_epochs and opt.max_epochs != -1: 265 | break 266 | except (RuntimeError, KeyboardInterrupt): 267 | print('Save ckpt on exception ...') 268 | save_checkpoint(model, infos, optimizer) 269 | print('Save ckpt done.') 270 | stack_trace = traceback.format_exc() 271 | print(stack_trace) 272 | 273 | 274 | opt = opts.parse_opt() 275 | train(opt) 276 | -------------------------------------------------------------------------------- 
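The checkpoints written by save_checkpoint() above can be resumed through the --start_from option; a minimal sketch, assuming the flags defined in opts.py (log_aoa and the option values are placeholders):

    python train.py --id aoa --caption_model aoa --checkpoint_path log_aoa --start_from log_aoa --load_best_score 1

train() then reloads infos_aoa.pkl, plus histories_aoa.pkl and optimizer.pth when they exist, before continuing the loop.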
/models/CaptionModel.py: -------------------------------------------------------------------------------- 1 | # This file contains the CaptionModel base class shared by all caption models 2 | 3 | # It dispatches forward/sample calls through the 'mode' keyword and implements 4 | # (diverse) beam search used at decoding time 5 | 6 | # sample_next_word provides greedy, temperature, gumbel-softmax, top-k and 7 | # nucleus sampling for the subclasses 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.autograd import * 17 | import misc.utils as utils 18 | 19 | from functools import reduce 20 | 21 | 22 | class CaptionModel(nn.Module): 23 | def __init__(self): 24 | super(CaptionModel, self).__init__() 25 | 26 | # implements beam search 27 | # calls beam_step and returns the final set of beams 28 | # augments log-probabilities with diversity terms when number of groups > 1 29 | 30 | def forward(self, *args, **kwargs): 31 | mode = kwargs.get('mode', 'forward') 32 | if 'mode' in kwargs: 33 | del kwargs['mode'] 34 | return getattr(self, '_'+mode)(*args, **kwargs) 35 | 36 | def beam_search(self, init_state, init_logprobs, *args, **kwargs): 37 | 38 | # subtracts a diversity penalty from tokens already chosen by previous groups and returns the unaugmented logprobs 39 | def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): 40 | local_time = t - divm 41 | unaug_logprobsf = logprobsf.clone() 42 | for prev_choice in range(divm): 43 | prev_decisions = beam_seq_table[prev_choice][local_time] 44 | for sub_beam in range(bdash): 45 | for prev_labels in range(bdash): 46 | logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda 47 | return unaug_logprobsf 48 | 49 | # does one step of classical beam search 50 | 51 | def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): 52 | #INPUTS: 53 | #logprobsf: probabilities augmented after diversity 54 | #beam_size: number of beams per group 55 | #t : time instant 56 | #beam_seq : tensor containing the beams 57 | #beam_seq_logprobs: tensor containing the beam logprobs 58 | #beam_logprobs_sum: tensor containing joint logprobs 59 | #OUTPUTS: 60 | #beam_seq : tensor containing the word indices of the decoded captions 61 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq 62 | #beam_logprobs_sum : joint log-probability of each beam 63 | 64 | ys,ix = torch.sort(logprobsf,1,True) 65 | candidates = [] 66 | cols = min(beam_size, ys.size(1)) 67 | rows = beam_size 68 | if t == 0: 69 | rows = 1 70 | for c in range(cols): # for each column (word, essentially) 71 | for q in range(rows): # for each beam expansion 72 | #compute logprob of expanding beam q with word in (sorted) position c 73 | local_logprob = ys[q,c].item() 74 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 75 | local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] 76 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob}) 77 | candidates = sorted(candidates, key=lambda x: -x['p']) 78 | 79 | new_state = [_.clone() for _ in state] 80 | #beam_seq_prev, beam_seq_logprobs_prev 81 | if t >= 1: 82 | #we'll need these as reference when we fork beams around 83 | beam_seq_prev = beam_seq[:t].clone() 84 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() 85 | for vix in range(beam_size): 86 | v =
candidates[vix] 87 | #fork beam index q into index vix 88 | if t >= 1: 89 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']] 90 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] 91 | #rearrange recurrent states 92 | for state_ix in range(len(new_state)): 93 | # copy over state in previous beam q to new beam at vix 94 | new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step 95 | #append new end terminal at the end of this beam 96 | beam_seq[t, vix] = v['c'] # c'th word is the continuation 97 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here 98 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 99 | state = new_state 100 | return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates 101 | 102 | # Start diverse_beam_search 103 | opt = kwargs['opt'] 104 | temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs 105 | beam_size = opt.get('beam_size', 10) 106 | group_size = opt.get('group_size', 1) 107 | diversity_lambda = opt.get('diversity_lambda', 0.5) 108 | decoding_constraint = opt.get('decoding_constraint', 0) 109 | remove_bad_endings = opt.get('remove_bad_endings', 0) 110 | length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) 111 | bdash = beam_size // group_size # beam per group 112 | 113 | # INITIALIZATIONS 114 | beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] 115 | beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] 116 | beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] 117 | 118 | # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) 119 | done_beams_table = [[] for _ in range(group_size)] 120 | # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] 121 | state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state])) 122 | logprobs_table = list(init_logprobs.chunk(group_size, 0)) 123 | # END INIT 124 | 125 | # Chunk elements in the args 126 | args = list(args) 127 | if self.__class__.__name__ == 'AttEnsemble': 128 | args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name 129 | args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name 130 | else: 131 | args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args] 132 | args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] 133 | 134 | for t in range(self.seq_length + group_size - 1): 135 | for divm in range(group_size): 136 | if t >= divm and t <= self.seq_length + divm - 1: 137 | # add diversity 138 | logprobsf = logprobs_table[divm].data.float() 139 | # suppress previous word 140 | if decoding_constraint and t-divm > 0: 141 | logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf')) 142 | if remove_bad_endings and t-divm > 0: 143 | logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix).astype('uint8')), 0] = float('-inf') 144 | # suppress UNK tokens in the decoding 145 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 146 | # diversity is added here 147 | # the function directly modifies the logprobsf values and hence, we need to return 148 | # the unaugmented ones for sorting the 
candidates in the end. # for historical 149 | # reasons :-) 150 | unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) 151 | 152 | # infer new beams 153 | beam_seq_table[divm],\ 154 | beam_seq_logprobs_table[divm],\ 155 | beam_logprobs_sum_table[divm],\ 156 | state_table[divm],\ 157 | candidates_divm = beam_step(logprobsf, 158 | unaug_logprobsf, 159 | bdash, 160 | t-divm, 161 | beam_seq_table[divm], 162 | beam_seq_logprobs_table[divm], 163 | beam_logprobs_sum_table[divm], 164 | state_table[divm]) 165 | 166 | # if time's up... or if end token is reached then copy beams 167 | for vix in range(bdash): 168 | if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1: 169 | final_beam = { 170 | 'seq': beam_seq_table[divm][:, vix].clone(), 171 | 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), 172 | 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), 173 | 'p': beam_logprobs_sum_table[divm][vix].item() 174 | } 175 | final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) 176 | done_beams_table[divm].append(final_beam) 177 | # don't continue beams from finished sequences 178 | beam_logprobs_sum_table[divm][vix] = -1000 179 | 180 | # move the current group one step forward in time 181 | 182 | it = beam_seq_table[divm][t-divm] 183 | logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]])) 184 | logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) 185 | 186 | # all beams are sorted by their log-probabilities 187 | done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] 188 | done_beams = reduce(lambda a,b:a+b, done_beams_table) 189 | return done_beams 190 | 191 | 192 | def sample_next_word(self, logprobs, sample_method, temperature): 193 | if sample_method == 'greedy': 194 | sampleLogprobs, it = torch.max(logprobs.data, 1) 195 | it = it.view(-1).long() 196 | elif sample_method == 'gumbel': # gumbel softmax 197 | # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f 198 | def sample_gumbel(shape, eps=1e-20): 199 | U = torch.rand(shape).cuda() 200 | return -torch.log(-torch.log(U + eps) + eps) 201 | def gumbel_softmax_sample(logits, temperature): 202 | y = logits + sample_gumbel(logits.size()) 203 | return F.log_softmax(y / temperature, dim=-1) 204 | _logprobs = gumbel_softmax_sample(logprobs, temperature) 205 | _, it = torch.max(_logprobs.data, 1) 206 | sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions 207 | else: 208 | logprobs = logprobs / temperature 209 | if sample_method.startswith('top'): # topk sampling 210 | top_num = float(sample_method[3:]) 211 | if 0 < top_num < 1: 212 | # nucleus sampling from # The Curious Case of Neural Text Degeneration 213 | probs = F.softmax(logprobs, dim=1) 214 | sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) 215 | _cumsum = sorted_probs.cumsum(1) 216 | mask = _cumsum < top_num 217 | mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1) 218 | sorted_probs = sorted_probs * mask.float() 219 | sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) 220 | logprobs.scatter_(1, sorted_indices, sorted_probs.log()) 221 | else: 222 | the_k = int(top_num) 223 | tmp = torch.empty_like(logprobs).fill_(float('-inf')) 224 | topk, indices = torch.topk(logprobs, the_k, dim=1) 225 | tmp = tmp.scatter(1, indices, topk) 226 | logprobs = tmp 227 | it = 
torch.distributions.Categorical(logits=logprobs.detach()).sample() 228 | sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions 229 | return it, sampleLogprobs -------------------------------------------------------------------------------- /models/TransformerModel.py: -------------------------------------------------------------------------------- 1 | # This file contains Transformer network 2 | # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html 3 | 4 | # The cfg name correspondance: 5 | # N=num_layers 6 | # d_model=input_encoding_size 7 | # d_ff=rnn_size 8 | # h is always 8 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import misc.utils as utils 18 | 19 | import copy 20 | import math 21 | import numpy as np 22 | 23 | from .CaptionModel import CaptionModel 24 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel 25 | 26 | class EncoderDecoder(nn.Module): 27 | """ 28 | A standard Encoder-Decoder architecture. Base for this and many 29 | other models. 30 | """ 31 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 32 | super(EncoderDecoder, self).__init__() 33 | self.encoder = encoder 34 | self.decoder = decoder 35 | self.src_embed = src_embed 36 | self.tgt_embed = tgt_embed 37 | self.generator = generator 38 | 39 | def forward(self, src, tgt, src_mask, tgt_mask): 40 | "Take in and process masked src and target sequences." 41 | return self.decode(self.encode(src, src_mask), src_mask, 42 | tgt, tgt_mask) 43 | 44 | def encode(self, src, src_mask): 45 | return self.encoder(self.src_embed(src), src_mask) 46 | 47 | def decode(self, memory, src_mask, tgt, tgt_mask): 48 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 49 | 50 | class Generator(nn.Module): 51 | "Define standard linear + softmax generation step." 52 | def __init__(self, d_model, vocab): 53 | super(Generator, self).__init__() 54 | self.proj = nn.Linear(d_model, vocab) 55 | 56 | def forward(self, x): 57 | return F.log_softmax(self.proj(x), dim=-1) 58 | 59 | def clones(module, N): 60 | "Produce N identical layers." 61 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 62 | 63 | class Encoder(nn.Module): 64 | "Core encoder is a stack of N layers" 65 | def __init__(self, layer, N): 66 | super(Encoder, self).__init__() 67 | self.layers = clones(layer, N) 68 | self.norm = LayerNorm(layer.size) 69 | 70 | def forward(self, x, mask): 71 | "Pass the input (and mask) through each layer in turn." 72 | for layer in self.layers: 73 | x = layer(x, mask) 74 | return self.norm(x) 75 | 76 | class LayerNorm(nn.Module): 77 | "Construct a layernorm module (See citation for details)." 78 | def __init__(self, features, eps=1e-6): 79 | super(LayerNorm, self).__init__() 80 | self.a_2 = nn.Parameter(torch.ones(features)) 81 | self.b_2 = nn.Parameter(torch.zeros(features)) 82 | self.eps = eps 83 | 84 | def forward(self, x): 85 | mean = x.mean(-1, keepdim=True) 86 | std = x.std(-1, keepdim=True) 87 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 88 | 89 | class SublayerConnection(nn.Module): 90 | """ 91 | A residual connection followed by a layer norm. 92 | Note for code simplicity the norm is first as opposed to last. 
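In other words, each sublayer sees LayerNorm(x), while the residual branch adds back the un-normalized x (a pre-norm residual block).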
93 | """ 94 | def __init__(self, size, dropout): 95 | super(SublayerConnection, self).__init__() 96 | self.norm = LayerNorm(size) 97 | self.dropout = nn.Dropout(dropout) 98 | 99 | def forward(self, x, sublayer): 100 | "Apply residual connection to any sublayer with the same size." 101 | return x + self.dropout(sublayer(self.norm(x))) 102 | 103 | class EncoderLayer(nn.Module): 104 | "Encoder is made up of self-attn and feed forward (defined below)" 105 | def __init__(self, size, self_attn, feed_forward, dropout): 106 | super(EncoderLayer, self).__init__() 107 | self.self_attn = self_attn 108 | self.feed_forward = feed_forward 109 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 110 | self.size = size 111 | 112 | def forward(self, x, mask): 113 | "Follow Figure 1 (left) for connections." 114 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 115 | return self.sublayer[1](x, self.feed_forward) 116 | 117 | class Decoder(nn.Module): 118 | "Generic N layer decoder with masking." 119 | def __init__(self, layer, N): 120 | super(Decoder, self).__init__() 121 | self.layers = clones(layer, N) 122 | self.norm = LayerNorm(layer.size) 123 | 124 | def forward(self, x, memory, src_mask, tgt_mask): 125 | for layer in self.layers: 126 | x = layer(x, memory, src_mask, tgt_mask) 127 | return self.norm(x) 128 | 129 | class DecoderLayer(nn.Module): 130 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)" 131 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 132 | super(DecoderLayer, self).__init__() 133 | self.size = size 134 | self.self_attn = self_attn 135 | self.src_attn = src_attn 136 | self.feed_forward = feed_forward 137 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 138 | 139 | def forward(self, x, memory, src_mask, tgt_mask): 140 | "Follow Figure 1 (right) for connections." 141 | m = memory 142 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 143 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 144 | return self.sublayer[2](x, self.feed_forward) 145 | 146 | def subsequent_mask(size): 147 | "Mask out subsequent positions." 148 | attn_shape = (1, size, size) 149 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 150 | return torch.from_numpy(subsequent_mask) == 0 151 | 152 | def attention(query, key, value, mask=None, dropout=None): 153 | "Compute 'Scaled Dot Product Attention'" 154 | d_k = query.size(-1) 155 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 156 | / math.sqrt(d_k) 157 | if mask is not None: 158 | scores = scores.masked_fill(mask == 0, -1e9) 159 | p_attn = F.softmax(scores, dim = -1) 160 | if dropout is not None: 161 | p_attn = dropout(p_attn) 162 | return torch.matmul(p_attn, value), p_attn 163 | 164 | class MultiHeadedAttention(nn.Module): 165 | def __init__(self, h, d_model, dropout=0.1): 166 | "Take in model size and number of heads." 167 | super(MultiHeadedAttention, self).__init__() 168 | assert d_model % h == 0 169 | # We assume d_v always equals d_k 170 | self.d_k = d_model // h 171 | self.h = h 172 | self.linears = clones(nn.Linear(d_model, d_model), 4) 173 | self.attn = None 174 | self.dropout = nn.Dropout(p=dropout) 175 | 176 | def forward(self, query, key, value, mask=None): 177 | "Implements Figure 2" 178 | if mask is not None: 179 | # Same mask applied to all h heads. 
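# (unsqueeze(1) below adds a head axis, e.g. (nbatches, 1, seq_len) -> (nbatches, 1, 1, seq_len),
#  so the same mask broadcasts against the (nbatches, h, query_len, key_len) attention scores.)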
180 | mask = mask.unsqueeze(1) 181 | nbatches = query.size(0) 182 | 183 | # 1) Do all the linear projections in batch from d_model => h x d_k 184 | query, key, value = \ 185 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 186 | for l, x in zip(self.linears, (query, key, value))] 187 | 188 | # 2) Apply attention on all the projected vectors in batch. 189 | x, self.attn = attention(query, key, value, mask=mask, 190 | dropout=self.dropout) 191 | 192 | # 3) "Concat" using a view and apply a final linear. 193 | x = x.transpose(1, 2).contiguous() \ 194 | .view(nbatches, -1, self.h * self.d_k) 195 | return self.linears[-1](x) 196 | 197 | class PositionwiseFeedForward(nn.Module): 198 | "Implements FFN equation." 199 | def __init__(self, d_model, d_ff, dropout=0.1): 200 | super(PositionwiseFeedForward, self).__init__() 201 | self.w_1 = nn.Linear(d_model, d_ff) 202 | self.w_2 = nn.Linear(d_ff, d_model) 203 | self.dropout = nn.Dropout(dropout) 204 | 205 | def forward(self, x): 206 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 207 | 208 | class Embeddings(nn.Module): 209 | def __init__(self, d_model, vocab): 210 | super(Embeddings, self).__init__() 211 | self.lut = nn.Embedding(vocab, d_model) 212 | self.d_model = d_model 213 | 214 | def forward(self, x): 215 | return self.lut(x) * math.sqrt(self.d_model) 216 | 217 | class PositionalEncoding(nn.Module): 218 | "Implement the PE function." 219 | def __init__(self, d_model, dropout, max_len=5000): 220 | super(PositionalEncoding, self).__init__() 221 | self.dropout = nn.Dropout(p=dropout) 222 | 223 | # Compute the positional encodings once in log space. 224 | pe = torch.zeros(max_len, d_model) 225 | position = torch.arange(0, max_len).unsqueeze(1).float() 226 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 227 | -(math.log(10000.0) / d_model)) 228 | pe[:, 0::2] = torch.sin(position * div_term) 229 | pe[:, 1::2] = torch.cos(position * div_term) 230 | pe = pe.unsqueeze(0) 231 | self.register_buffer('pe', pe) 232 | 233 | def forward(self, x): 234 | x = x + self.pe[:, :x.size(1)] 235 | return self.dropout(x) 236 | 237 | class TransformerModel(AttModel): 238 | 239 | def make_model(self, src_vocab, tgt_vocab, N=6, 240 | d_model=512, d_ff=2048, h=8, dropout=0.1): 241 | "Helper: Construct a model from hyperparameters." 242 | c = copy.deepcopy 243 | attn = MultiHeadedAttention(h, d_model) 244 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 245 | position = PositionalEncoding(d_model, dropout) 246 | model = EncoderDecoder( 247 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 248 | Decoder(DecoderLayer(d_model, c(attn), c(attn), 249 | c(ff), dropout), N), 250 | lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), 251 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), 252 | Generator(d_model, tgt_vocab)) 253 | 254 | # This was important from their code. 255 | # Initialize parameters with Glorot / fan_avg. 
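# (Only parameters with dim > 1, i.e. weight matrices, are re-initialized here;
#  biases and LayerNorm gains/offsets keep their PyTorch defaults.)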
256 | for p in model.parameters(): 257 | if p.dim() > 1: 258 | nn.init.xavier_uniform_(p) 259 | return model 260 | 261 | def __init__(self, opt): 262 | super(TransformerModel, self).__init__(opt) 263 | self.opt = opt 264 | # self.config = yaml.load(open(opt.config_file)) 265 | # d_model = self.input_encoding_size # 512 266 | 267 | delattr(self, 'att_embed') 268 | self.att_embed = nn.Sequential(*( 269 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ 270 | (nn.Linear(self.att_feat_size, self.input_encoding_size), 271 | nn.ReLU(), 272 | nn.Dropout(self.drop_prob_lm))+ 273 | ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ()))) 274 | 275 | delattr(self, 'embed') 276 | self.embed = lambda x : x 277 | delattr(self, 'fc_embed') 278 | self.fc_embed = lambda x : x 279 | delattr(self, 'logit') 280 | del self.ctx2att 281 | 282 | tgt_vocab = self.vocab_size + 1 283 | self.model = self.make_model(0, tgt_vocab, 284 | N=opt.num_layers, 285 | d_model=opt.input_encoding_size, 286 | d_ff=opt.rnn_size) 287 | 288 | def logit(self, x): # unsafe way 289 | return self.model.generator.proj(x) 290 | 291 | def init_hidden(self, bsz): 292 | return [] 293 | 294 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 295 | 296 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 297 | memory = self.model.encode(att_feats, att_masks) 298 | 299 | return fc_feats[...,:1], att_feats[...,:1], memory, att_masks 300 | 301 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): 302 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 303 | 304 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 305 | 306 | if att_masks is None: 307 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) 308 | att_masks = att_masks.unsqueeze(-2) 309 | 310 | if seq is not None: 311 | # crop the last one 312 | seq = seq[:,:-1] 313 | seq_mask = (seq.data > 0) 314 | seq_mask[:,0] += 1 315 | 316 | seq_mask = seq_mask.unsqueeze(-2) 317 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) 318 | else: 319 | seq_mask = None 320 | 321 | return att_feats, seq, att_masks, seq_mask 322 | 323 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 324 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 325 | 326 | out = self.model(att_feats, seq, att_masks, seq_mask) 327 | 328 | outputs = self.model.generator(out) 329 | return outputs 330 | # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 331 | 332 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 333 | """ 334 | state = [ys.unsqueeze(0)] 335 | """ 336 | if len(state) == 0: 337 | ys = it.unsqueeze(1) 338 | else: 339 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 340 | out = self.model.decode(memory, mask, 341 | ys, 342 | subsequent_mask(ys.size(1)) 343 | .to(memory.device)) 344 | return out[:, -1], [ys.unsqueeze(0)] -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import lmdb 8 | import os 9 | import numpy as np 10 | import random 11 | 12 | import torch 13 | import torch.utils.data as data 14 | 15 | import multiprocessing 16 | import six 17 | 18 | class HybridLoader: 19 | """ 20 | If 
db_path is a directory, then use normal file loading 21 | If lmdb, then load from lmdb 22 | The loading method depends on the extension. 23 | """ 24 | def __init__(self, db_path, ext): 25 | self.db_path = db_path 26 | self.ext = ext 27 | if self.ext == '.npy': 28 | self.loader = lambda x: np.load(x) 29 | else: 30 | self.loader = lambda x: np.load(x)['feat'] 31 | if db_path.endswith('.lmdb'): 32 | self.db_type = 'lmdb' 33 | self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path), 34 | readonly=True, lock=False, 35 | readahead=False, meminit=False) 36 | elif db_path.endswith('.pth'): # Assume a key,value dictionary 37 | self.db_type = 'pth' 38 | self.feat_file = torch.load(db_path) 39 | self.loader = lambda x: x 40 | print('HybridLoader: ext is ignored') 41 | else: 42 | self.db_type = 'dir' 43 | 44 | def get(self, key): 45 | 46 | if self.db_type == 'lmdb': 47 | env = self.env 48 | with env.begin(write=False) as txn: 49 | byteflow = txn.get(key) 50 | f_input = six.BytesIO(byteflow) 51 | elif self.db_type == 'pth': 52 | f_input = self.feat_file[key] 53 | else: 54 | f_input = os.path.join(self.db_path, key + self.ext) 55 | 56 | # load image 57 | feat = self.loader(f_input) 58 | 59 | return feat 60 | 61 | 62 | class DataLoader(data.Dataset): 63 | 64 | def reset_iterator(self, split): 65 | del self._prefetch_process[split] 66 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train') 67 | self.iterators[split] = 0 68 | 69 | def get_vocab_size(self): 70 | return self.vocab_size 71 | 72 | def get_vocab(self): 73 | return self.ix_to_word 74 | 75 | def get_seq_length(self): 76 | return self.seq_length 77 | 78 | def __init__(self, opt): 79 | self.opt = opt 80 | self.batch_size = self.opt.batch_size 81 | self.seq_per_img = opt.seq_per_img 82 | 83 | # feature related options 84 | self.use_fc = getattr(opt, 'use_fc', True) 85 | self.use_att = getattr(opt, 'use_att', True) 86 | self.use_box = getattr(opt, 'use_box', 0) 87 | self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) 88 | self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) 89 | 90 | # load the json file which contains additional information about the dataset 91 | print('DataLoader loading json file: ', opt.input_json) 92 | self.info = json.load(open(self.opt.input_json)) 93 | if 'ix_to_word' in self.info: 94 | self.ix_to_word = self.info['ix_to_word'] 95 | self.vocab_size = len(self.ix_to_word) 96 | print('vocab size is ', self.vocab_size) 97 | 98 | # open the hdf5 file 99 | print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) 100 | if self.opt.input_label_h5 != 'none': 101 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') 102 | # load in the sequence data 103 | seq_size = self.h5_label_file['labels'].shape 104 | self.label = self.h5_label_file['labels'][:] 105 | self.seq_length = seq_size[1] 106 | print('max sequence length in data is', self.seq_length) 107 | # load the pointers in full to RAM (should be small enough) 108 | self.label_start_ix = self.h5_label_file['label_start_ix'][:] 109 | self.label_end_ix = self.h5_label_file['label_end_ix'][:] 110 | else: 111 | self.seq_length = 1 112 | 113 | self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy') 114 | self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz') 115 | self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy') 116 | 117 | self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] 118 | print('read %d image features' %(self.num_images)) 119 | 120
| # separate out indexes for each of the provided splits 121 | self.split_ix = {'train': [], 'val': [], 'test': []} 122 | for ix in range(len(self.info['images'])): 123 | img = self.info['images'][ix] 124 | if not 'split' in img: 125 | self.split_ix['train'].append(ix) 126 | self.split_ix['val'].append(ix) 127 | self.split_ix['test'].append(ix) 128 | elif img['split'] == 'train': 129 | self.split_ix['train'].append(ix) 130 | elif img['split'] == 'val': 131 | self.split_ix['val'].append(ix) 132 | elif img['split'] == 'test': 133 | self.split_ix['test'].append(ix) 134 | elif opt.train_only == 0: # restval 135 | self.split_ix['train'].append(ix) 136 | 137 | print('assigned %d images to split train' %len(self.split_ix['train'])) 138 | print('assigned %d images to split val' %len(self.split_ix['val'])) 139 | print('assigned %d images to split test' %len(self.split_ix['test'])) 140 | 141 | self.iterators = {'train': 0, 'val': 0, 'test': 0} 142 | 143 | self._prefetch_process = {} # The three prefetch process 144 | for split in self.iterators.keys(): 145 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train') 146 | # Terminate the child process when the parent exists 147 | def cleanup(): 148 | print('Terminating BlobFetcher') 149 | for split in self.iterators.keys(): 150 | del self._prefetch_process[split] 151 | import atexit 152 | atexit.register(cleanup) 153 | 154 | def get_captions(self, ix, seq_per_img): 155 | # fetch the sequence labels 156 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 157 | ix2 = self.label_end_ix[ix] - 1 158 | ncap = ix2 - ix1 + 1 # number of captions available for this image 159 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' 160 | 161 | if ncap < seq_per_img: 162 | # we need to subsample (with replacement) 163 | seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') 164 | for q in range(seq_per_img): 165 | ixl = random.randint(ix1,ix2) 166 | seq[q, :] = self.label[ixl, :self.seq_length] 167 | else: 168 | ixl = random.randint(ix1, ix2 - seq_per_img + 1) 169 | seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] 170 | 171 | return seq 172 | 173 | def get_batch(self, split, batch_size=None): 174 | batch_size = batch_size or self.batch_size 175 | seq_per_img = self.seq_per_img 176 | 177 | fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32') 178 | att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32') 179 | label_batch = [] #np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int') 180 | 181 | wrapped = False 182 | 183 | infos = [] 184 | gts = [] 185 | 186 | for i in range(batch_size): 187 | # fetch image 188 | tmp_fc, tmp_att, tmp_seq, \ 189 | ix, tmp_wrapped = self._prefetch_process[split].get() 190 | if tmp_wrapped: 191 | wrapped = True 192 | 193 | fc_batch.append(tmp_fc) 194 | att_batch.append(tmp_att) 195 | 196 | tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') 197 | if hasattr(self, 'h5_label_file'): 198 | tmp_label[:, 1 : self.seq_length + 1] = tmp_seq 199 | label_batch.append(tmp_label) 200 | 201 | # Used for reward evaluation 202 | if hasattr(self, 'h5_label_file'): 203 | gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) 204 | else: 205 | gts.append([]) 206 | 207 | # record associated info as well 208 | info_dict = {} 209 | info_dict['ix'] = ix 210 | info_dict['id'] = self.info['images'][ix]['id'] 211 | 
info_dict['file_path'] = self.info['images'][ix].get('file_path', '') 212 | infos.append(info_dict) 213 | 214 | # #sort by att_feat length 215 | # fc_batch, att_batch, label_batch, gts, infos = \ 216 | # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) 217 | fc_batch, att_batch, label_batch, gts, infos = \ 218 | zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) 219 | data = {} 220 | data['fc_feats'] = np.stack(sum([[_]*seq_per_img for _ in fc_batch], [])) 221 | # merge att_feats 222 | max_att_len = max([_.shape[0] for _ in att_batch]) 223 | data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32') 224 | for i in range(len(att_batch)): 225 | data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i] 226 | data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') 227 | for i in range(len(att_batch)): 228 | data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1 229 | # set att_masks to None if attention features have same length 230 | if data['att_masks'].sum() == data['att_masks'].size: 231 | data['att_masks'] = None 232 | 233 | data['labels'] = np.vstack(label_batch) 234 | # generate mask 235 | nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) 236 | mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') 237 | for ix, row in enumerate(mask_batch): 238 | row[:nonzeros[ix]] = 1 239 | data['masks'] = mask_batch 240 | 241 | data['gts'] = gts # all ground truth captions of each image 242 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} 243 | data['infos'] = infos 244 | 245 | data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor 246 | 247 | return data 248 | 249 | # It's not coherent to make DataLoader a subclass of Dataset, but essentially we only need to implement the following two functions, 250 | # so that the torch.utils.data.DataLoader can load the data according to the index. 251 | # However, it's a minimal change to switch to pytorch data loading. 252 | def __getitem__(self, index): 253 | """This function returns a tuple that is further passed to collate_fn 254 | """ 255 | ix = index #self.split_ix[index] 256 | if self.use_att: 257 | att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) 258 | # Reshape to K x C 259 | att_feat = att_feat.reshape(-1, att_feat.shape[-1]) 260 | if self.norm_att_feat: 261 | att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) 262 | if self.use_box: 263 | box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) 264 | # divided by image width and height 265 | x1,y1,x2,y2 = np.hsplit(box_feat, 4) 266 | h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] 267 | box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
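# At this point box_feat is a K x 5 geometry descriptor per region:
# [x1/w, y1/h, x2/w, y2/h, relative_area], i.e. corner coordinates normalized to
# [0, 1] plus the fraction of the image covered by the box. Below it is
# (optionally L2-normalized and) concatenated onto att_feat, growing each region
# feature from C to C + 5 dimensions, and the regions are then re-ordered by box area.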
268 | if self.norm_box_feat: 269 | box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) 270 | att_feat = np.hstack([att_feat, box_feat]) 271 | # sort the features by the size of boxes 272 | att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) 273 | else: 274 | att_feat = np.zeros((1,1,1), dtype='float32') 275 | if self.use_fc: 276 | fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) 277 | else: 278 | fc_feat = np.zeros((1), dtype='float32') 279 | if hasattr(self, 'h5_label_file'): 280 | seq = self.get_captions(ix, self.seq_per_img) 281 | else: 282 | seq = None 283 | return (fc_feat, 284 | att_feat, seq, 285 | ix) 286 | 287 | def __len__(self): 288 | return len(self.info['images']) 289 | 290 | class SubsetSampler(torch.utils.data.sampler.Sampler): 291 | r"""Samples elements sequentially (in the given order) from a list of indices, without replacement. 292 | Arguments: 293 | indices (list): a list of indices 294 | """ 295 | 296 | def __init__(self, indices): 297 | self.indices = indices 298 | 299 | def __iter__(self): 300 | return (self.indices[i] for i in range(len(self.indices))) 301 | 302 | def __len__(self): 303 | return len(self.indices) 304 | 305 | class BlobFetcher(): 306 | """Experimental class for prefetching blobs in a separate process.""" 307 | def __init__(self, split, dataloader, if_shuffle=False): 308 | """ 309 | split: which split ('train'/'val'/'test') to prefetch from; dataloader: the parent DataLoader; if_shuffle: reshuffle split_ix when an epoch wraps around 310 | """ 311 | self.split = split 312 | self.dataloader = dataloader 313 | self.if_shuffle = if_shuffle 314 | 315 | # Add more in the queue 316 | def reset(self): 317 | """ 318 | Two cases for this function to be triggered: 319 | 1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator 320 | 2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already.
321 | """ 322 | # batch_size is 1, the merge is done in DataLoader class 323 | self.split_loader = iter(data.DataLoader(dataset=self.dataloader, 324 | batch_size=1, 325 | sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]), 326 | shuffle=False, 327 | pin_memory=True, 328 | num_workers=4, # 4 is usually enough 329 | collate_fn=lambda x: x[0])) 330 | 331 | def _get_next_minibatch_inds(self): 332 | max_index = len(self.dataloader.split_ix[self.split]) 333 | wrapped = False 334 | 335 | ri = self.dataloader.iterators[self.split] 336 | ix = self.dataloader.split_ix[self.split][ri] 337 | 338 | ri_next = ri + 1 339 | if ri_next >= max_index: 340 | ri_next = 0 341 | if self.if_shuffle: 342 | random.shuffle(self.dataloader.split_ix[self.split]) 343 | wrapped = True 344 | self.dataloader.iterators[self.split] = ri_next 345 | 346 | return ix, wrapped 347 | 348 | def get(self): 349 | if not hasattr(self, 'split_loader'): 350 | self.reset() 351 | 352 | ix, wrapped = self._get_next_minibatch_inds() 353 | tmp = self.split_loader.next() 354 | if wrapped: 355 | self.reset() 356 | 357 | assert tmp[-1] == ix, "ix not equal" 358 | 359 | return tmp + [wrapped] -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_opt(): 4 | parser = argparse.ArgumentParser() 5 | # Data input settings 6 | parser.add_argument('--input_json', type=str, default='data/coco.json', 7 | help='path to the json file containing additional info and vocab') 8 | parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', 9 | help='path to the directory containing the preprocessed fc feats') 10 | parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', 11 | help='path to the directory containing the preprocessed att feats') 12 | parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box', 13 | help='path to the directory containing the boxes of att feats') 14 | parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', 15 | help='path to the h5file containing the preprocessed dataset') 16 | parser.add_argument('--start_from', type=str, default=None, 17 | help="""continue training from saved model at this path. Path must contain files saved by previous training process: 18 | 'infos.pkl' : configuration; 19 | 'checkpoint' : paths to model file(s) (created by tf). 
20 | Note: this file contains absolute paths, be careful when moving files around; 21 | 'model.ckpt-*' : file(s) with model definition (created by tf) 22 | """) 23 | parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', 24 | help='Cached token file for calculating cider score during self critical training.') 25 | 26 | # Model settings 27 | parser.add_argument('--caption_model', type=str, default="show_tell", 28 | help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt, transformer') 29 | parser.add_argument('--rnn_size', type=int, default=512, 30 | help='size of the rnn in number of hidden nodes in each layer') 31 | parser.add_argument('--num_layers', type=int, default=1, 32 | help='number of layers in the RNN') 33 | parser.add_argument('--rnn_type', type=str, default='lstm', 34 | help='rnn, gru, or lstm') 35 | parser.add_argument('--input_encoding_size', type=int, default=512, 36 | help='the encoding size of each token in the vocabulary, and the image.') 37 | parser.add_argument('--att_hid_size', type=int, default=512, 38 | help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') 39 | parser.add_argument('--fc_feat_size', type=int, default=2048, 40 | help='2048 for resnet, 4096 for vgg') 41 | parser.add_argument('--att_feat_size', type=int, default=2048, 42 | help='2048 for resnet, 512 for vgg') 43 | parser.add_argument('--logit_layers', type=int, default=1, 44 | help='number of layers in the RNN') 45 | 46 | 47 | parser.add_argument('--use_bn', type=int, default=0, 48 | help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed') 49 | 50 | # AoA settings 51 | parser.add_argument('--mean_feats', type=int, default=1, 52 | help='use mean pooling of feats?') 53 | parser.add_argument('--refine', type=int, default=1, 54 | help='refining feature vectors?') 55 | parser.add_argument('--refine_aoa', type=int, default=1, 56 | help='use aoa in the refining module?') 57 | parser.add_argument('--use_ff', type=int, default=1, 58 | help='keep feed-forward layer in the refining module?') 59 | parser.add_argument('--dropout_aoa', type=float, default=0.3, 60 | help='dropout_aoa in the refining module?') 61 | 62 | parser.add_argument('--ctx_drop', type=int, default=0, 63 | help='apply dropout to the context vector before fed into LSTM?') 64 | parser.add_argument('--decoder_type', type=str, default='AoA', 65 | help='AoA, LSTM, base') 66 | parser.add_argument('--use_multi_head', type=int, default=2, 67 | help='use multi head attention? 
0 for additive single-head; 1 for additive multi-head; 2 for dot-product multi-head.') 68 | parser.add_argument('--num_heads', type=int, default=8, 69 | help='number of attention heads?') 70 | parser.add_argument('--multi_head_scale', type=int, default=1, 71 | help='scale q,k,v?') 72 | 73 | parser.add_argument('--use_warmup', type=int, default=0, 74 | help='warm up the learning rate?') 75 | parser.add_argument('--acc_steps', type=int, default=1, 76 | help='gradient accumulation steps') 77 | 78 | 79 | # feature manipulation 80 | parser.add_argument('--norm_att_feat', type=int, default=0, 81 | help='If 1, normalize attention features') 82 | parser.add_argument('--use_box', type=int, default=0, 83 | help='If 1, use box features') 84 | parser.add_argument('--norm_box_feat', type=int, default=0, 85 | help='If using boxes, whether to normalize the box features') 86 | 87 | # Optimization: General 88 | parser.add_argument('--max_epochs', type=int, default=-1, 89 | help='number of epochs') 90 | parser.add_argument('--batch_size', type=int, default=16, 91 | help='minibatch size') 92 | parser.add_argument('--grad_clip', type=float, default=0.1, #5., 93 | help='clip gradients at this value') 94 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, 95 | help='strength of dropout in the Language Model RNN') 96 | parser.add_argument('--self_critical_after', type=int, default=-1, 97 | help='After what epoch do we start self-critical training? (-1 = disable; 0 = start from the beginning)') 98 | parser.add_argument('--seq_per_img', type=int, default=5, 99 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') 100 | 101 | # Sample related 102 | parser.add_argument('--beam_size', type=int, default=1, 103 | help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 104 | parser.add_argument('--max_length', type=int, default=20, 105 | help='Maximum length during sampling') 106 | parser.add_argument('--length_penalty', type=str, default='', 107 | help='wu_X or avg_X, X is the alpha') 108 | parser.add_argument('--block_trigrams', type=int, default=0, 109 | help='block repeated trigram.') 110 | parser.add_argument('--remove_bad_endings', type=int, default=0, 111 | help='Remove bad endings') 112 | 113 | #Optimization: for the Language Model 114 | parser.add_argument('--optim', type=str, default='adam', 115 | help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam') 116 | parser.add_argument('--learning_rate', type=float, default=4e-4, 117 | help='learning rate') 118 | parser.add_argument('--learning_rate_decay_start', type=int, default=-1, 119 | help='at what iteration to start decaying learning rate?
(-1 = dont) (in epoch)') 120 | parser.add_argument('--learning_rate_decay_every', type=int, default=3, 121 | help='every how many epochs thereafter to drop the LR') 122 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, 123 | help='multiplicative factor applied to the LR at every decay step') 124 | parser.add_argument('--optim_alpha', type=float, default=0.9, 125 | help='alpha for adam') 126 | parser.add_argument('--optim_beta', type=float, default=0.999, 127 | help='beta used for adam') 128 | parser.add_argument('--optim_epsilon', type=float, default=1e-8, 129 | help='epsilon that goes into denominator for smoothing') 130 | parser.add_argument('--weight_decay', type=float, default=0, 131 | help='weight_decay') 132 | # Transformer 133 | parser.add_argument('--label_smoothing', type=float, default=0, 134 | help='label smoothing factor for the cross-entropy loss (0 = disabled)') 135 | parser.add_argument('--noamopt', action='store_true', 136 | help='use the Noam learning rate schedule (warmup then inverse-sqrt decay)') 137 | parser.add_argument('--noamopt_warmup', type=int, default=2000, 138 | help='number of warmup steps for the Noam schedule') 139 | parser.add_argument('--noamopt_factor', type=float, default=1, 140 | help='multiplicative factor for the Noam schedule') 141 | parser.add_argument('--reduce_on_plateau', action='store_true', 142 | help='use a reduce-on-plateau learning rate schedule') 143 | 144 | parser.add_argument('--scheduled_sampling_start', type=int, default=-1, 145 | help='at what iteration to start decaying the gt probability') 146 | parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5, 147 | help='every how many iterations thereafter to increase the scheduled sampling probability') 148 | parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05, 149 | help='How much to increase the prob at each step') 150 | parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25, 151 | help='Maximum scheduled sampling prob.') 152 | 153 | 154 | # Evaluation/Checkpointing 155 | parser.add_argument('--val_images_use', type=int, default=3200, 156 | help='how many images to use when periodically evaluating the validation loss? (-1 = all)') 157 | parser.add_argument('--save_checkpoint_every', type=int, default=2500, 158 | help='how often to save a model checkpoint (in iterations)?') 159 | parser.add_argument('--save_history_ckpt', type=int, default=0, 160 | help='If 1, keep a separate checkpoint at every save point') 161 | parser.add_argument('--checkpoint_path', type=str, default='save', 162 | help='directory to store checkpointed models') 163 | parser.add_argument('--language_eval', type=int, default=0, 164 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 165 | parser.add_argument('--losses_log_every', type=int, default=25, 166 | help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') 167 | parser.add_argument('--load_best_score', type=int, default=1, 168 | help='Do we load previous best score when resuming training.') 169 | 170 | # misc 171 | parser.add_argument('--id', type=str, default='', 172 | help='an id identifying this run/job.
used in cross-val and appended when writing progress files') 173 | parser.add_argument('--train_only', type=int, default=0, 174 | help='if 1 then train only on the ~80k COCO train images, else also use the restval images (~110k in total)') 175 | 176 | 177 | # Reward 178 | parser.add_argument('--cider_reward_weight', type=float, default=1, 179 | help='The reward weight from cider') 180 | parser.add_argument('--bleu_reward_weight', type=float, default=0, 181 | help='The reward weight from bleu4') 182 | 183 | args = parser.parse_args() 184 | 185 | # Check if args are valid 186 | assert args.rnn_size > 0, "rnn_size should be greater than 0" 187 | assert args.num_layers > 0, "num_layers should be greater than 0" 188 | assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0" 189 | assert args.batch_size > 0, "batch_size should be greater than 0" 190 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" 191 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0" 192 | assert args.beam_size > 0, "beam_size should be greater than 0" 193 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" 194 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0" 195 | assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1" 196 | assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1" 197 | assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1" 198 | 199 | return args 200 | 201 | def add_eval_options(parser): 202 | # Basic options 203 | parser.add_argument('--batch_size', type=int, default=0, 204 | help='if > 0 then overrule, otherwise load from checkpoint.') 205 | parser.add_argument('--num_images', type=int, default=-1, 206 | help='how many images to use when periodically evaluating the loss? (-1 = all)') 207 | parser.add_argument('--language_eval', type=int, default=0, 208 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 209 | parser.add_argument('--dump_images', type=int, default=1, 210 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') 211 | parser.add_argument('--dump_json', type=int, default=1, 212 | help='Dump json with predictions into vis folder? (1=yes,0=no)') 213 | parser.add_argument('--dump_path', type=int, default=0, 214 | help='Write image paths along with predictions into vis json? (1=yes,0=no)') 215 | 216 | # Sampling options 217 | parser.add_argument('--sample_method', type=str, default='greedy', 218 | help='greedy; sample; gumbel; top<k>, top<0-1>') 219 | parser.add_argument('--beam_size', type=int, default=2, 220 | help='indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 221 | parser.add_argument('--max_length', type=int, default=20, 222 | help='Maximum length during sampling') 223 | parser.add_argument('--length_penalty', type=str, default='', 224 | help='wu_X or avg_X, X is the alpha') 225 | parser.add_argument('--group_size', type=int, default=1, 226 | help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') 227 | parser.add_argument('--diversity_lambda', type=float, default=0.5, 228 | help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
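# Note on group_size / diversity_lambda (standard diverse beam search semantics,
# Vijayakumar et al.; stated here as background, not as a description of this
# repository's exact implementation): the beam is divided into group_size groups,
# and diversity_lambda penalizes a group for re-using tokens that earlier groups
# already selected at the same time step, so a larger lambda trades per-caption
# likelihood for diversity across the returned captions.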
229 | parser.add_argument('--temperature', type=float, default=1.0, 230 | help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.') 231 | parser.add_argument('--decoding_constraint', type=int, default=0, 232 | help='If 1, do not allow the same word twice in a row') 233 | parser.add_argument('--block_trigrams', type=int, default=0, 234 | help='block repeated trigram.') 235 | parser.add_argument('--remove_bad_endings', type=int, default=0, 236 | help='Remove bad endings') 237 | # For evaluation on a folder of images: 238 | parser.add_argument('--image_folder', type=str, default='', 239 | help='If this is nonempty then will predict on the images in this folder path') 240 | parser.add_argument('--image_root', type=str, default='', 241 | help='In case the image paths have to be prepended with a root path to an image folder') 242 | # For evaluation on MSCOCO images from some split: 243 | parser.add_argument('--input_fc_dir', type=str, default='', 244 | help='path to the directory containing the preprocessed fc feats') 245 | parser.add_argument('--input_att_dir', type=str, default='', 246 | help='path to the directory containing the preprocessed att feats') 247 | parser.add_argument('--input_box_dir', type=str, default='', 248 | help='path to the directory containing the boxes of att feats') 249 | parser.add_argument('--input_label_h5', type=str, default='', 250 | help='path to the h5file containing the preprocessed dataset') 251 | parser.add_argument('--input_json', type=str, default='', 252 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') 253 | parser.add_argument('--split', type=str, default='test', 254 | help='if running on MSCOCO images, which split to use: val|test|train') 255 | parser.add_argument('--coco_json', type=str, default='', 256 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') 257 | # misc 258 | parser.add_argument('--id', type=str, default='', 259 | help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') 260 | parser.add_argument('--verbose_beam', type=int, default=1, 261 | help='if we need to print out all beam search beams.') 262 | parser.add_argument('--verbose_loss', type=int, default=0, 263 | help='If 1, calculate loss using ground truth during evaluation') --------------------------------------------------------------------------------
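# Illustrative usage sketch: parse_opt() builds and validates the full training
# option set, while add_eval_options() only attaches the shared evaluation and
# sampling flags to an existing parser, so a standalone evaluation script can
# reuse it. The flag values below are hypothetical examples, not repository defaults.
import argparse

import opts

parser = argparse.ArgumentParser()
opts.add_eval_options(parser)
eval_opt = parser.parse_args(['--split', 'test', '--beam_size', '3',
                              '--language_eval', '1', '--sample_method', 'greedy'])
print(eval_opt.beam_size, eval_opt.language_eval)  # -> 3 1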